In [3]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import colors
import logging
import sys
from brats_dataset import BratsDataset
import brats_dataset as bd
import visualisation
import inference
import data_processing
import elastic_transform
import pipeline
import image_pipeline
import logging_config
import pickle
import serialization
import importlib

device = "cuda" if torch.cuda.is_available() else "cpu"

# Root directory under which run result folders are created.
# NOTE(review): absolute, machine-specific path — adjust for your environment.
output_dir = "/Users/sw/work/msc_ai_diss/output"

# Set to an existing run directory (e.g.
# "/Users/sw/work/msc_ai_diss/output/2023_09_14__00_22_01_val") to log to and
# read/write pickles in that run instead of a freshly generated one.
# Leave as None for a new result folder (the default, reproducible behavior).
resume_result_path = None

if resume_result_path is None:
    result_path = logging_config.get_result_path(output_dir=output_dir, suffix="_val")
else:
    # Explicit override replaces the old commented-out reassignment, so the
    # directory a later cell reads from is visible in the config, not hidden state.
    result_path = resume_result_path

logging_config.configure_logging(result_path)
logging.info(f"Result path: {result_path}")

# All dump_to_file / read_from_file calls below are relative to result_path.
ser = serialization.Serialization(result_path)


RESULT PATH: /Users/sw/work/msc_ai_diss/output/2023_09_15__12_52_22_val
2023-09-15 14:52:22,336 [INFO] Result path: /Users/sw/work/msc_ai_diss/output/2023_09_15__12_52_22_val


In [4]:
# Deformation parameters forwarded to ImagePipeline (presumably magnitude and
# smoothing of the elastic transform — confirm in image_pipeline/elastic_transform).
# Settings tried in earlier runs (alpha, sigma): (10000, 40), (1000, 20), (400, 10).
alpha = 200
sigma = 6

logging.info(f"alpha = {alpha}, sigma = {sigma}")

BRATS_PATH = "/Users/sw/work/msc_ai_diss/BraTS/validation_10"
brats = BratsDataset(BRATS_PATH)
n_samples = len(brats)
logging.info(f"Samples found: {n_samples}")

# NOTE(review): unlike output_dir and BRATS_PATH ("/Users/sw/work/msc_ai_diss/..."),
# this path has no "work/" component — confirm it is the intended checkpoint.
checkpoint_path = "/Users/sw/msc_ai_diss/output/_vast/2023_08_29__21_48_13/checkpoint_epoch1000.pth"
inf = inference.Inference(checkpoint_path, device=device)

# Passed to ImagePipeline as transformation_count.
transformation_count = 10
# Only the sample with this id is built with keep_pipelines=True.
# NOTE(review): compared with ==, so this only matches if get_sample_id returns
# an int (not e.g. a zero-padded string) — verify its return type.
demo_sample_id = 1552

image_pipelines = []
for sample_idx in range(n_samples):
    mri_4d, segmentation_4d = brats[sample_idx]
    sample_id = brats.get_sample_id(sample_idx)
    logging.info(f"Processing sample {sample_idx + 1} of {n_samples} (id: {sample_id})")
    keep_pipelines = sample_id == demo_sample_id
    ip = image_pipeline.ImagePipeline(mri_4d, segmentation_4d, inf,
                                      transformation_count=transformation_count, device=device,
                                      alpha=alpha, sigma=sigma, restore_order=3,
                                      keep_pipelines=keep_pipelines)
    image_pipelines.append(ip)

    # Dump each pipeline as soon as it is built so finished samples are
    # persisted even if a later sample fails.
    logging.info("Dumping ImagePipeline")
    ser.dump_to_file(ip, f"ip_{sample_idx}.pkl")

logging.info("Dumping all pipelines")
ser.dump_to_file(image_pipelines, "all_ips.pkl")


2023-09-15 14:52:24,494 [INFO] alpha = 200, sigma = 6
2023-09-15 14:52:24,496 [INFO] Samples found: 10
2023-09-15 14:52:24,497 [INFO] Loading model from [/Users/sw/msc_ai_diss/output/_vast/2023_08_29__21_48_13/checkpoint_epoch1000.pth]
2023-09-15 14:52:25,017 [INFO] Loaded image: /Users/sw/work/msc_ai_diss/BraTS/validation_10/BraTS2021_00547/BraTS2021_00547_flair.nii.gz (took 0.0s)
2023-09-15 14:52:25,049 [INFO] Loaded image: /Users/sw/work/msc_ai_diss/BraTS/validation_10/BraTS2021_00547/BraTS2021_00547_t1.nii.gz (took 0.0s)
2023-09-15 14:52:25,081 [INFO] Loaded image: /Users/sw/work/msc_ai_diss/BraTS/validation_10/BraTS2021_00547/BraTS2021_00547_t1ce.nii.gz (took 0.0s)
2023-09-15 14:52:25,111 [INFO] Loaded image: /Users/sw/work/msc_ai_diss/BraTS/validation_10/BraTS2021_00547/BraTS2021_00547_t2.nii.gz (took 0.0s)
2023-09-15 14:52:25,133 [INFO] MRI image shape: (4, 240, 240, 144)
2023-09-15 14:52:25,144 [INFO] Loaded image: /Users/sw/work/msc_ai_diss/BraTS/validation_10/BraTS2021_00547/

In [4]:
# Reload the list of ImagePipeline objects pickled into result_path by the
# processing cell (same filename it dumps to).
# NOTE(review): the saved output below is from an older session (a 2023_09_12
# run directory) and this cell shares execution count In[4] with the processing
# cell. On a clean Restart & Run All it reads the file written just above.
# Presumably pickle-based (.pkl): only load files this project produced.
ALL_IPS_PICKLE = "all_ips.pkl"
all_ips = ser.read_from_file(ALL_IPS_PICKLE)

2023-09-12 02:08:48,380 [INFO] Reading /Users/sw/work/msc_ai_diss/output/2023_09_12__00_02_55_val/all_ips.pkl
