# Pituitary Segmentation Training via Auto3DSeg

In [1]:
import os
import json
import nibabel as nib
import nibabel as nibabel
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import glob
import random
from monai.apps.auto3dseg import AutoRunner
from monai.config import print_config
from mri_preproc.paths import hemond_data, init_paths
import importlib
from dataclasses import asdict

print_config()

ModuleNotFoundError: No module named 'mri_preproc'

In [None]:
# Run the path-setup entry point before the project imports below.
# NOTE(review): presumably this is what makes `mri_data` and `monai_training`
# importable — confirm against mri_preproc.paths.init_paths.
init_paths.main()
from mri_data.file_manager import DataSet, scan_3Tpioneer_bids
from monai_training import training, preprocess
# Re-execute the already-imported hemond_data module so that source edits are
# picked up without restarting the kernel.
importlib.reload(hemond_data)

## Prep the database

Get the data and labels organized

In [32]:
#! delete this
# dataset = hemond_data.get_pituitary_3Tpioneer_bids("/mnt/h/3Tpioneer_bids", suppress_output=True)
# omit_subs = ['ms1196']
# dataset = [data for data in dataset if data.label is not None and data.subid not in omit_subs]

In [None]:
# Directory that receives everything this training run writes.
work_dir = "/home/srs-9/Projects/ms_mri/training_work_dirs/pituitary1"
# work_dir = "/home/hemondlab/Dev/ms_mri/training_work_dirs/cp_work_dir6"
os.makedirs(work_dir, exist_ok=True)

# Root directory later handed to AutoRunner as "dataroot"; created if absent.
dataroot_dir = "/mnt/h"
os.makedirs(dataroot_dir, exist_ok=True)

In [None]:
# Collect every scan under the BIDS root, pairing each T1 image with its
# pituitary label file.
dataroot = Path("/mnt/h/3Tpioneer_bids")
dataset = scan_3Tpioneer_bids(dataroot, image="t1.nii.gz", label="pituitary.nii.gz")

# (subject id, session id) pairs to exclude; only the first entry is acted on.
omit_subs = [('1196', '20161004')]
omit_subid, omit_sesid = omit_subs[0]
bad_scan = dataset.find_scan(subid=omit_subid, sesid=omit_sesid)[0]
dataset.remove_scan(bad_scan)

In [None]:
# can't remember what fraction I had used
# Fraction of scans passed to assign_conditions as the test ('ts') share.
fraction_ts = 0.2
dataset = training.assign_conditions(dataset, fraction_ts)
dataset[1]
# work_dir is a plain str, so it must be wrapped in Path before the "/" join;
# `work_dir / "..."` on a str raises TypeError.
preprocess.save_dataset(dataset, Path(work_dir) / "training-dataset.json")

In [34]:
# Partition the scans by their assigned condition: 'tr' scans contribute an
# image/label pair, 'ts' scans contribute the image only.
training_data, test_data = [], []

for scan in dataset:
    if scan.cond == 'ts':
        test_data.append(scan.image)
    elif scan.cond == 'tr':
        training_data.append({"image": scan.image, "label": scan.label})

## Review

In [35]:
def display_slices(scan):
    """Plot the image and label of *scan* side by side at one shared slice.

    The slice is taken along the third array axis, at the index where the
    label values summed over the first two axes are largest. The per-slice
    sums and the chosen index are printed before plotting.

    Args:
        scan: object exposing ``image`` and ``label`` paths loadable by
            nibabel.
    """
    image_volume = nibabel.load(scan.image).get_fdata()
    label_volume = nibabel.load(scan.label).get_fdata()

    # Total label intensity contained in each slice along the third axis.
    per_slice_totals = np.sum(label_volume, axis=(0, 1))
    print(per_slice_totals)

    max_slice_index = np.argmax(per_slice_totals)
    print(f"Max slice: {max_slice_index}")

    panels = (
        (image_volume, f"Image 1 - Slice {max_slice_index}"),
        (label_volume, f"Image 2 - Slice {max_slice_index}"),
    )

    plt.figure(figsize=(10, 5))
    for position, (volume, caption) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.imshow(volume[:, :, max_slice_index], cmap='gray')
        plt.title(caption)
        plt.axis('off')
    plt.show()


In [36]:
# Load the label volume of the scan at index 0 into an array for ad-hoc inspection.
img2 = nibabel.load(dataset[0].label)
data2 = img2.get_fdata()

In [None]:
# Visual spot check of the scan at index 20.
display_slices(dataset[20])

In the original code, labels are included in the test data as well. The original code also has a function that checks that each label contains a nonzero number of voxels.

In [None]:
#? I don't know why they put labels for the test data. the brats segmentation code didn't.
# train_data = [{'image': path + '/flair.nii.gz', 'label': path + '/flair_chp_mask_qced.nii.gz'} for path in train_exams]
# test_data = [{'image': path + '/flair.nii.gz', 'label': path + '/flair_chp_mask_qced.nii.gz'} for path in test_exams]

# Build the image/label entries for both splits, keeping only labelled scans.
# Paths are stringified so the entries can be written out as JSON.
train_data = []
test_data = []
for scan in dataset:
    # has_label must be *called* in both branches. The 'ts' branch already
    # called it; the 'tr' branch tested the bare attribute, which for a method
    # is always truthy and therefore filtered nothing.
    # NOTE(review): assumes has_label is a method, not a property — confirm.
    if scan.cond == 'tr' and scan.has_label():
        train_data.append({"image": str(scan.image), "label": str(scan.label)})
    elif scan.cond == 'ts' and scan.has_label():
        test_data.append({"image": str(scan.image), "label": str(scan.label)})


print(f"Train num total: {len(train_data)}")
print(f"Test num: {len(test_data)}")

In [44]:
n_folds = 5

# Assign training cases to folds round-robin by their position in train_data.
training_entries = []
for index, case in enumerate(train_data):
    training_entries.append(
        {"fold": index % n_folds, "image": case["image"], "label": case["label"]}
    )

datalist = {"testing": test_data, "training": training_entries}

In [45]:
# sub_datalist = dict({'training':[], 'testing':[]})
# sub_datalist["training"] = datalist["training"][:100]
# sub_datalist["testing"] = datalist["testing"][:29]
# datalist = sub_datalist

In [None]:
# Number of training entries in the datalist.
len(datalist['training'])

In [48]:
# Serialize the datalist to JSON inside the work directory; the resulting
# path is what AutoRunner receives as its "datalist" input.
datalist_file = os.path.join(work_dir, "datalist.json")
with open(datalist_file, "w") as datalist_fp:
    datalist_fp.write(json.dumps(datalist))

In [None]:
# Configure the Auto3DSeg runner: outputs go to work_dir, and only the
# "swinunetr" algorithm is requested.
runner = AutoRunner(
    work_dir=work_dir,
    algos=["swinunetr"],
    input={
        "modality": "MRI",
        # JSON file holding the "training" and "testing" lists written to work_dir.
        "datalist": datalist_file,
        # NOTE(review): presumably only used to resolve relative paths in the
        # datalist — confirm how it interacts with the paths stored there.
        "dataroot": dataroot_dir,
    },
)

In [None]:
# Upper bound on training epochs, passed through as "num_epochs" below.
max_epochs = 100

# Training overrides handed to the runner.
# NOTE(review): key meanings are defined by the Auto3DSeg algorithm templates —
# confirm that the selected algorithm honours each of them.
train_param = {
    "num_epochs_per_validation": 1,
    #"num_images_per_batch": 2,
    "num_epochs": max_epochs,
    "num_warmup_epochs": 1,
}
runner.set_training_params(train_param)

In [None]:


# Execute the AutoRunner pipeline with the inputs and training params set on `runner`.
runner.run()

In [None]:
# Inspect the training entry at index 48.
datalist['training'][48]

In [29]:
# Ad-hoc load of a single NIfTI file. Note that this path lives under /mnt/t,
# not the /mnt/h root used for the dataset in this notebook.
scan_path = "/mnt/t/Data/3Tpioneer_bids/sub-ms1001/ses-20170215/proc/lesion_index.t3m20-mni_reg.nii.gz"
img = nib.load(scan_path)
