In [None]:
import nibabel as nib
import numpy as np
import os
from glob import glob
import matplotlib.pyplot as plt
from tqdm import tqdm
import sys
# Move the working directory up to the enclosing folder named "YODA" so the
# relative data paths and the `YODA.*` import used below resolve.
# The root is located first and the directory is changed only once it is found:
# - os.path.basename handles the platform's path separator (splitting on "/"
#   never matches on Windows);
# - stopping at the filesystem root avoids looping forever when the notebook
#   was started outside of a "YODA" directory.
_repo_root = os.getcwd()
while os.path.basename(_repo_root) != "YODA":
    _parent_dir = os.path.dirname(_repo_root)
    if _parent_dir == _repo_root:
        raise RuntimeError(
            'Could not find a parent directory named "YODA" above ' + os.getcwd()
        )
    _repo_root = _parent_dir
os.chdir(_repo_root)

from skimage.filters import gaussian
from skimage.filters import threshold_otsu
from scipy.ndimage import binary_fill_holes
    
from YODA.preprocessing.register_images import conform

Preprocessing (conforming) the pre-registered Gold Atlas dataset

In [None]:
# Source directory (resampled volumes) and destination directory (conformed
# output). Both are relative paths, resolved against the current working directory.
path, target = (
    "../data/gold_atlas/resampled",
    "../data/gold_atlas/conformed",
)

In [None]:
# For every subject folder under `path`: conform T1/T2, rescale the CT to
# 8 bit, derive a foreground mask from the T1, apply it to all three volumes
# and save everything under `target/<subject>/`.
BLUR_SIGMA = 5    # sigma of the Gaussian blur applied to the T1 before thresholding
EDGE_SLICES = 5   # slices zeroed at each end of the last axis in the saved mask

# os.listdir returns entries in arbitrary order and also lists plain files
# (e.g. ".DS_Store"); keep only directories and process them in a stable order.
subjects = sorted(
    entry for entry in os.listdir(path) if os.path.isdir(os.path.join(path, entry))
)

for subj in subjects:
    src_dir = os.path.join(path, subj)
    dst_dir = os.path.join(target, subj)
    os.makedirs(dst_dir, exist_ok=True)

    # `conform` comes from YODA.preprocessing.register_images. Its result is
    # expected at the second path: the masking step below reloads t1/t2 from there.
    conform(os.path.join(src_dir, "t1.nii.gz"), os.path.join(dst_dir, "t1.nii.gz"))
    conform(os.path.join(src_dir, "t2.nii.gz"), os.path.join(dst_dir, "t2.nii.gz"))

    # CT: shift by the 1st percentile, scale by the 99.9th percentile of the
    # shifted data, clip to [0, 1] and map to integer values 0..255.
    img = nib.load(os.path.join(src_dir, "ct.nii.gz"))
    data = img.get_fdata()
    data -= np.percentile(data, 1)
    upper = np.percentile(data, 99.9)
    # After the shift `upper` is 0 only for a (near-)constant volume; skip the
    # division then so no NaN/inf voxels reach the uint8 cast.
    if upper > 0:
        data /= upper
    data = np.clip(data, 0, 1)
    data *= 255.99  # 255.99 so that exactly 1.0 truncates to 255 in the cast below
    img = nib.Nifti1Image(data.astype(np.uint8), img.affine, img.header)
    nib.save(img, os.path.join(dst_dir, "ct.nii.gz"))

    # Foreground mask: blur the resampled T1, Otsu-threshold it, then fill
    # holes independently in every 2D slice along the last axis.
    # NOTE(review): the mask is built on the *resampled* T1 but multiplied with
    # the *conformed* volumes below - this assumes `conform` keeps the voxel
    # grid unchanged; confirm.
    mask_nii = nib.load(os.path.join(src_dir, "t1.nii.gz"))
    mask = mask_nii.get_fdata()
    mask = gaussian(mask, BLUR_SIGMA)
    mask = (mask > threshold_otsu(mask)).astype(np.uint8)
    for i in range(mask.shape[2]):
        mask[..., i] = binary_fill_holes(mask[..., i])

    # Apply the mask to the volumes already written to the target folder
    # (overwrites them in place).
    for s in ["t1", "t2", "ct"]:
        img = nib.load(os.path.join(dst_dir, f"{s}.nii.gz"))
        data = img.get_fdata() * mask
        img = nib.Nifti1Image(data, img.affine, img.header)
        nib.save(img, os.path.join(dst_dir, f"{s}.nii.gz"))

    # Zero the first and last EDGE_SLICES slices of the *saved* mask only; the
    # images above were multiplied with the mask before this step.
    # NOTE(review): confirm the images are not meant to be cropped as well.
    mask[..., :EDGE_SLICES] = 0
    mask[..., -EDGE_SLICES:] = 0
    mask_nii = nib.Nifti1Image(mask, mask_nii.affine, mask_nii.header)
    nib.save(mask_nii, os.path.join(dst_dir, "mask.nii.gz"))

In [None]:
import json

# Build the training/validation file list for the conformed subjects and
# write it as JSON. Paths inside the JSON are relative to the subject folders'
# parent directory (they start with the subject ID, not with `target`).

# Subject IDs held out from training; they form the "validation" list.
test = ["1_03_P", "1_06_P", "2_04_P", "2_11_P"]

# How many times each list is repeated in the written file.
TRAIN_REPEATS = 10
VAL_REPEATS = 8
split_file = "../data/test_datasets/gold_atlas_train.json"

train, val = [], []

# sorted(): os.listdir order is arbitrary, which would make the JSON differ
# between machines/runs.
for subj in sorted(os.listdir(target)):
    # Skip stray files; only subject directories hold the volumes listed below.
    if not os.path.isdir(os.path.join(target, subj)):
        continue
    # Subjects whose ID starts with "3" are excluded from both lists.
    # NOTE(review): presumably the leading digit is the acquisition site - confirm.
    if subj[0] == "3":
        continue
    d = {
        "t1": os.path.join(subj, "t1.nii.gz"),
        "t2": os.path.join(subj, "t2.nii.gz"),
        "ct": os.path.join(subj, "ct.nii.gz"),
        "mask": os.path.join(subj, "mask.nii.gz"),
        "subject_ID": subj,
    }
    if subj in test:
        val.append(d)
    else:
        train.append(d)

files = {"training": train * TRAIN_REPEATS, "validation": val * VAL_REPEATS}

# Create the output folder if needed so open() does not fail on a fresh checkout.
os.makedirs(os.path.dirname(split_file), exist_ok=True)
with open(split_file, "w") as f:
    json.dump(files, f)