In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torchio as tio
import torchvision.transforms as transforms

from data.Dataset import *
from data.transforms.DataAugmentations import *
from models import Evaluator, models
from models.Predictor import Predictor
from utils.LossFunctions import DC_and_CE_loss, GDiceLossV2
from utils.Utils import *

In [2]:
# Seed torch first, so any randomness consumed while building datasets,
# transforms or loaders below is reproducible.
torch.manual_seed(0)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Several data folders belong to the same dataset, each processed in a different
# way, so the folder (and therefore the preprocessing variant) is set explicitly.
dataset_path = "../data/processed/rel3_dhcp_anat_pipeline//"
cv_ = "cv3"  # Fold of the 5-fold cross-validation to use (cv1-cv5).
# NOTE(review): cv_ is not passed to MRIDataset here -- confirm the fold is actually applied.

# No augmentation at evaluation time. For a robustness check, this can be set to
# transforms.Compose([RandomMotion(), RandomAffine(degrees=[15])]).
transform_eval = None

# Build the train / val / test splits.
train = MRIDataset(LateWeeks, "train", dataset_path, transform=transform_eval)
val = MRIDataset(LateWeeks, "val", dataset_path, transform=transform_eval)
test = MRIDataset(LateWeeks, "test", dataset_path, transform=transform_eval)

# One subject per batch; no shuffling is requested.
val_loader = torch.utils.data.DataLoader(dataset=val, batch_size=1)
test_loader = torch.utils.data.DataLoader(dataset=test, batch_size=1)

In [3]:
# Build the 3D U-Net and its loss, then load the trained weights for inference.
weights_path = "../models/20221228/dHCP/weights/36_model.pth"

model = models.UNet3D().to(device)
criterion = DC_and_CE_loss({'batch_dice': True, 'smooth': 1e-5, 'do_bg': False, 'square': False}, {})

# map_location remaps tensors saved on a GPU onto `device`, so this cell also
# works on the CPU-only fallback selected when CUDA is unavailable.
load_result = model.load_state_dict(torch.load(weights_path, map_location=device))

# Put layers such as dropout / batch-norm into inference behaviour before predicting.
# (Harmless if Predictor also does this internally.)
model.eval()

load_result  # display the key-matching report of load_state_dict

<All keys matched successfully>

In [6]:
# Predict a label map for every validation subject and save it as a NIfTI file.
# NOTE(review): the weights loaded above are 36_model.pth but the output folder is
# named "model_40" -- confirm which checkpoint this folder is meant to hold.
pred_dir = 'Predictions/UNet3D/model_40/'
patch_size = (128, 128, 128)

# save_nii may not create missing folders itself; make sure the target exists.
os.makedirs(pred_dir, exist_ok=True)

# The predictor gets the same arguments for every subject, so build it once.
# Assumes Predictor.predict keeps no per-subject state -- TODO confirm.
predictor = Predictor(model, patch_size)

with torch.no_grad():  # inference only: no autograd graph is needed
    for i in range(len(val)):
        sub = val[i]
        output = predictor.predict(sub)
        output = output.argmax(dim=1)  # argmax over dim 1 (presumably the class channel)
        # .cpu() is a no-op for CPU tensors and required before .numpy() on CUDA tensors.
        save_nii(pred_dir,
                 f'{sub.sub_id}_pred',
                 output.squeeze(0).cpu().numpy().astype(np.float64), sub.mri.affine)

In [None]:
# Predict a label map for every test subject and save it as a NIfTI file.
# NOTE(review): the weights loaded above are 36_model.pth but the output folder is
# named "model_40" -- confirm which checkpoint this folder is meant to hold.
pred_dir = 'Predictions/UNet3D/model_40/'
patch_size = (128, 128, 128)

# save_nii may not create missing folders itself; make sure the target exists.
os.makedirs(pred_dir, exist_ok=True)

# The predictor gets the same arguments for every subject, so build it once.
# Assumes Predictor.predict keeps no per-subject state -- TODO confirm.
predictor = Predictor(model, patch_size)

with torch.no_grad():  # inference only: no autograd graph is needed
    for i in range(len(test)):
        sub = test[i]
        output = predictor.predict(sub)
        output = output.argmax(dim=1)  # argmax over dim 1 (presumably the class channel)
        # .cpu() is a no-op for CPU tensors and required before .numpy() on CUDA tensors.
        save_nii(pred_dir,
                 f'{sub.sub_id}_pred',
                 output.squeeze(0).cpu().numpy().astype(np.float64), sub.mri.affine)