## Prepare Dataset for Segmentation Training

### Data Sources
- [NLM Tuberculosis Chest X-ray Image Data Sets](https://lhncbc.nlm.nih.gov/publication/pub9931)
- [Shenzhen subset segmentation masks](https://www.kaggle.com/yoctoman/shcxr-lung-mask)
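- Additional non-public, manually segmented images were also used for training; they are not included in this notebook.

The notebook hard-links each X-ray into `<dataset_dir>/<split>/xray/` and writes its lung mask as a single-channel PNG with pixel values 0/1 into `<dataset_dir>/<split>/mask/`. Each mask uses the same base name as its X-ray. Currently every image goes to the `train` split, so the `test` split is left empty.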

In [1]:
from fastai.vision.all import *
import numpy as np
import cv2
import os
import sys
import shutil
from sklearn.model_selection import train_test_split
import re
import pandas as pd

# Root directory of the prepared dataset; the next cell deletes and recreates it.
dataset_dir = "/data/datasets/NLM-shenzhen-montgomery/original"

In [2]:
# Start from a clean output tree: remove anything left by a previous run
# (errors during removal are ignored), then create
# <dataset_dir>/{train,test}/{xray,mask}.
shutil.rmtree(dataset_dir, ignore_errors=True)
for _split in ("train", "test"):
    for _kind in ("xray", "mask"):
        os.makedirs(os.path.join(dataset_dir, _split, _kind))
# Keep the notebook namespace identical to before: no loop variables left behind.
del _split, _kind

def show_img_stats(img):
    """Print quick diagnostics for an image array.

    Three lines are printed, in order:
    the array shape, the minimum and maximum values (space separated), and the
    sorted distinct values as returned by ``np.unique``. Returns None.
    """
    print(img.shape)
    extremes = (np.min(img), np.max(img))
    print(*extremes)
    print(np.unique(img))

def save_image(image_path, mask, split_dir, out_name, dilate_kernel=None):
    """Hard-link an X-ray into the dataset and write its binarized mask as PNG.

    The X-ray at ``image_path`` is hard-linked (not copied) to
    ``<dataset_dir>/<split_dir>/xray/<out_name>``. The mask is rescaled from
    0..255 to 0/1 and written to
    ``<dataset_dir>/<split_dir>/mask/<out_name without extension>.png``.

    Args:
        image_path: Path (str) of the source X-ray image.
        mask: Grayscale mask array, e.g. as returned by ``cv2.imread(...,
            cv2.IMREAD_GRAYSCALE)``. Presumably uint8 with 0 = background and
            255 = lung; TODO confirm for every data source.
        split_dir: Split sub-directory, e.g. "train" or "test".
        out_name: File name for the X-ray inside the dataset.
        dilate_kernel: Optional structuring element. When given, the 0/1 mask
            is dilated with ``cv2.dilate`` before saving. With the default
            (None), the mask is saved unchanged.

    Raises:
        ValueError: if ``mask`` is None (e.g. ``cv2.imread`` failed to read it).
        OSError: if the hard link cannot be created or the mask cannot be written.
    """
    if mask is None:
        raise ValueError(f"mask for {image_path} could not be read")

    # 0..255 -> 0/1; for uint8 input, pixels >= 128 become 1.
    mask = np.round(mask / 255.).astype(np.uint8)
    # Sanity check. A uint8 0..255 input always yields 0/1 here, so this only
    # fires for unexpected input (e.g. an array with a wider value range).
    if np.max(mask) > 1:
        print(image_path, np.unique(mask))

    if dilate_kernel is not None:
        mask = cv2.dilate(mask, dilate_kernel)

    os.link(image_path, os.path.join(dataset_dir, split_dir, "xray", out_name))

    # Always store masks losslessly as PNG, whatever the X-ray's extension;
    # a JPEG-encoded mask would no longer be strictly 0/1.
    mask_name = os.path.splitext(out_name)[0] + ".png"
    mask_path = os.path.join(dataset_dir, split_dir, "mask", mask_name)
    # cv2.imwrite reports failure through its return value, not an exception.
    if not cv2.imwrite(mask_path, mask):
        raise OSError(f"failed to write mask {mask_path}")

In [3]:
montgomery_dir = "/data/data/NLM-shenzhen-montgomery/NLM-MontgomeryCXRSet/MontgomerySet"
montgomery_files = sorted(get_image_files(os.path.join(montgomery_dir, "CXR_png")))
# Left and right lung masks are stored as separate files: index 0 = left, 1 = right.
montgomery_half_masks = [
    sorted(get_image_files(os.path.join(montgomery_dir, "ManualMask", "leftMask"))),
    sorted(get_image_files(os.path.join(montgomery_dir, "ManualMask", "rightMask"))),
]
print("Montgomery:", len(montgomery_files), len(montgomery_half_masks[0]), len(montgomery_half_masks[1]))

# zip() below would silently drop unmatched files, so require equal counts up front.
if not len(montgomery_files) == len(montgomery_half_masks[0]) == len(montgomery_half_masks[1]):
    raise ValueError("Montgomery image / left-mask / right-mask counts differ")

# All Montgomery images currently go to the train split. To hold out a test set,
# split montgomery_files with sklearn's train_test_split (an earlier version used
# test_size=0.2, random_state=2020) and choose split_dir per image below.
n_train = len(montgomery_files)
n_test = 0
count_m = {"train": 0, "test": 0}
# NOTE(review): defined but never passed to save_image, so no dilation is applied.
DILATE_KERNEL = np.ones((7, 7), np.uint8)

for (image_file, lmask_file, rmask_file) in zip(montgomery_files, *montgomery_half_masks):
    bimage = os.path.basename(image_file)
    blmask = os.path.basename(lmask_file)
    brmask = os.path.basename(rmask_file)

    # The sorted lists must line up file-for-file; a name mismatch means a missing file.
    if not bimage == blmask == brmask:
        raise ValueError(f"Montgomery file mismatch: image={bimage}, left={blmask}, right={brmask}")

    left_mask = cv2.imread(str(lmask_file), cv2.IMREAD_GRAYSCALE)
    right_mask = cv2.imread(str(rmask_file), cv2.IMREAD_GRAYSCALE)
    # cv2.imread returns None instead of raising when a file cannot be read.
    if left_mask is None or right_mask is None:
        raise FileNotFoundError(f"could not read Montgomery masks for {bimage}")

    # Combine both lungs with a pixel-wise maximum (a union for 0/255 masks).
    mask = np.maximum(left_mask, right_mask)

    split_dir = "train"
    count_m[split_dir] += 1

    save_image(str(image_file), mask, split_dir, f"montgomery_{bimage}")

    sys.stdout.write(f"\r{count_m['train']} / {n_train}, {count_m['test']} / {n_test}")
# Terminate the carriage-return progress line.
sys.stdout.write("\n")

Montgomery: 138 138 138
138 / 138, 0 / 0

In [4]:
shenzhen_dir = "/data/data/NLM-shenzhen-montgomery/ChinaSet_AllFiles/"
shenzhen_xray_dir = os.path.join(shenzhen_dir, "CXR_png")
shenzhen_files = sorted(get_image_files(shenzhen_xray_dir))
shenzhen_masks = sorted(get_image_files(os.path.join(shenzhen_dir, "shcxr-lung-mask", "mask")))
# There may be fewer masks than X-rays (see the printed counts), so the loop
# below is driven by the mask files, not the X-rays.
print("Shenzhen:", len(shenzhen_files), len(shenzhen_masks))

# All Shenzhen images currently go to the train split. To hold out a test set,
# split shenzhen_masks with sklearn's train_test_split (an earlier version used
# test_size=0.2, random_state=2019) and choose split_dir per mask below.
n_train = len(shenzhen_masks)
n_test = 0
count_s = {"train": 0, "test": 0}
# NOTE(review): defined but never passed to save_image, so no dilation is applied.
DILATE_KERNEL = np.ones((15, 15), np.uint8)

# Mask file names contain "_mask". Removing it (the last occurrence, because the
# first group is greedy) gives the matching X-ray file name.
shenzhen_pattern = re.compile(r"(.*)_mask(.*)")

for mask_file in shenzhen_masks:
    bmask = os.path.basename(mask_file)
    m = shenzhen_pattern.match(bmask)
    if m is None:
        raise ValueError(f"unexpected Shenzhen mask file name (no '_mask'): {bmask}")
    bimage = m.group(1) + m.group(2)
    xray_file = os.path.join(shenzhen_xray_dir, bimage)

    if not os.path.exists(xray_file):
        raise FileNotFoundError(f"no X-ray for mask {bmask}: {xray_file}")

    mask = cv2.imread(str(mask_file), cv2.IMREAD_GRAYSCALE)
    # cv2.imread returns None instead of raising when a file cannot be read.
    if mask is None:
        raise FileNotFoundError(f"could not read Shenzhen mask {mask_file}")

    split_dir = "train"
    count_s[split_dir] += 1

    save_image(str(xray_file), mask, split_dir, f"shenzhen_{bimage}")

    sys.stdout.write(f"\r{count_s['train']} / {n_train}, {count_s['test']} / {n_test}")
# Terminate the carriage-return progress line.
sys.stdout.write("\n")

Shenzhen: 662 566
566 / 566, 0 / 0