In [1]:
import copy
import os
from pathlib import Path

from monai_training import preprocess, training
from mri_data import file_manager as fm

In [2]:
# Paths used throughout the notebook. The project root may be overridden with the
# MS_MRI_HOME environment variable so the notebook can run on another machine;
# when it is unset, the original hard-coded location is used.
project_home = Path(os.environ.get("MS_MRI_HOME", "/home/srs-9/Projects/ms_mri"))
# Directory holding the dataset1.json ... dataset4.json split files.
curr_dir = project_home / "analysis" / "choroid_pineal_pituitary_crosstrain"
drive_root = fm.get_drive_root()
# NOTE(review): dataroot is not referenced by any cell shown in this notebook.
dataroot = drive_root / "3Tpioneer_bids"

In [3]:
# Build four cross-training splits from dataset1.json. The intent is that each scan is
# a test case in exactly one split (cell 6 checks this):
#   * dataset1 keeps the train/test assignment stored on disk.
#   * dataset2 and dataset3 each draw N_TEST_PER_SPLIT new test scans from the previous
#     split's remaining training scans; all earlier test scans are folded back in as
#     training scans.
#   * dataset4 tests on whatever training scans are left after dataset3.
# NOTE(review): training.assign_conditions is presumably random and is not seeded here,
# so re-running this cell may produce different splits for datasets 2-4 — confirm.

# Number of new test scans drawn for each of dataset2 and dataset3.
N_TEST_PER_SPLIT = 10


def _with_cond(scans, cond):
    """Return the scans in ``scans`` whose ``cond`` equals ``cond``, in order."""
    return [scan for scan in scans if scan.cond == cond]


def _relabel(scans, cond):
    """Return shallow copies of ``scans`` with ``cond`` set on each copy.

    The original scan objects are left untouched, so a scan owned by one split is
    never relabeled by accident when it is reused in another split.
    """
    relabeled = []
    for scan in scans:
        scan_copy = copy.copy(scan)
        scan_copy.cond = cond
        relabeled.append(scan_copy)
    return relabeled


def _resplit(train_scans, earlier_tests_as_train):
    """Draw a new test set from ``train_scans`` and assemble the next split.

    Args:
        train_scans: the previous split's training scans. They are copied before
            ``training.assign_conditions`` sees them, so the previous split keeps its
            labels even if that call (or ``DataSet.from_scans``) modifies or shares
            the scan objects.
        earlier_tests_as_train: test scans of all earlier splits, already relabeled
            "tr"; they are appended after the re-split scans.

    Returns:
        ``(part, dataset, tests_as_train, part_trains)``:
        ``part`` is the re-split DataSet; ``dataset`` is ``part`` plus
        ``earlier_tests_as_train``; ``tests_as_train`` are copies of the new test
        scans relabeled "tr"; ``part_trains`` are the scans of ``part`` labeled "tr".
    """
    part = fm.DataSet.from_scans(_relabel(train_scans, "tr"))
    part = training.assign_conditions(part, n_ts=N_TEST_PER_SPLIT)
    dataset = fm.DataSet.from_scans(list(part) + earlier_tests_as_train)
    tests_as_train = _relabel(_with_cond(dataset, "ts"), "tr")
    part_trains = _with_cond(part, "tr")
    return part, dataset, tests_as_train, part_trains


dataset1, _ = preprocess.load_dataset(curr_dir / "dataset1.json")
# dataset1's test scans, copied and relabeled as training scans for later splits.
dataset1_tests = _relabel(_with_cond(dataset1, "ts"), "tr")
dataset1_trains = _with_cond(dataset1, "tr")

dataset2_part, dataset2, dataset2_tests, dataset2_part_trains = _resplit(
    dataset1_trains, dataset1_tests
)
dataset3_part, dataset3, dataset3_tests, dataset3_part_trains = _resplit(
    dataset2_part_trains, dataset1_tests + dataset2_tests
)

# The last split tests on every scan that has not been a test case yet.
dataset4_tests = _relabel(dataset3_part_trains, "ts")
dataset4 = fm.DataSet.from_scans(
    dataset4_tests + dataset3_tests + dataset2_tests + dataset1_tests
)

In [6]:
# Sanity-check the splits. Each dataset's own test scans (cond == "ts") are collected;
# together they must be 40 distinct scans, and no scan may be a test case in more than
# one dataset.
# NOTE: these names deliberately replace cell 3's relabeled ("tr") copies; from here on
# they hold each dataset's test scans. Distinctness is judged by the scan type's own
# __eq__/__hash__ (identity unless the type overrides them) — confirm it compares scans
# by subject/session if copies must count as the same scan.
dataset1_tests = [scan for scan in dataset1 if scan.cond == "ts"]
dataset2_tests = [scan for scan in dataset2 if scan.cond == "ts"]
dataset3_tests = [scan for scan in dataset3 if scan.cond == "ts"]
dataset4_tests = [scan for scan in dataset4 if scan.cond == "ts"]

all_tests = set(dataset1_tests + dataset2_tests + dataset3_tests + dataset4_tests)
assert len(all_tests) == 40
# The union alone can still be 40 when one scan is tested in several datasets, so also
# require that the per-dataset test lists do not overlap.
n_test_entries = sum(
    len(tests)
    for tests in (dataset1_tests, dataset2_tests, dataset3_tests, dataset4_tests)
)
assert n_test_entries == len(all_tests), (
    f"{n_test_entries} test entries but only {len(all_tests)} distinct scans: "
    "some scan is a test case in more than one dataset"
)

In [7]:
# Order every split by its "cond" label; the labels compare as strings, so "tr" scans
# come before "ts" scans. The return value of sort() is discarded, so this relies on
# DataSet.sort reordering each dataset in place.
for split in (dataset1, dataset2, dataset3, dataset4):
    split.sort(key=lambda s: s.cond)

In [8]:
# Write each split to curr_dir. Note that this overwrites dataset1.json, the file that
# the split-building cell reads, so a later re-run starts from the re-saved dataset1.
split_files = (
    (dataset1, "dataset1.json"),
    (dataset2, "dataset2.json"),
    (dataset3, "dataset3.json"),
    (dataset4, "dataset4.json"),
)
for split, filename in split_files:
    preprocess.save_dataset(split, curr_dir / filename)