In [7]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn
import torch
from kerosene.utils.tensors import to_onehot
from torch.utils.data import DataLoader

from deepNormalize.inputs.datasets import SingleImageDataset
from deepNormalize.models.dcgan3d import DCGAN
from deepNormalize.models.unet3d import Unet
from deepNormalize.utils.image_slicer import ImageReconstructor


In [4]:
# --- Patch-extraction configuration -------------------------------------------------
# Integer mode flags; neither is referenced by the cells visible here.
PATCH, SLICE = 0, 1

# NOTE(review): the leading 1 presumably is the channel axis, followed by three 32-voxel
# spatial axes — confirm against SingleImageDataset. STEP equals PATCH_SIZE, so patches
# are presumably extracted without overlap.
PATCH_SIZE = (1, 32, 32, 32)
STEP = (1, 32, 32, 32)

BATCH_SIZE = 10

# --- Data locations (trailing slashes are kept deliberately) --------------------------
DATA_ROOT = "/mnt/md0/Data/"
DATA_SET_ROOT_PATH_MR_BRAINS = DATA_ROOT + "MRBrainS_scaled/DataNii/TrainingData/"
DATA_SET_ROOT_PATH_ISEG = DATA_ROOT + "iSEG_scaled/Training/"

# Subject identifiers are handed to SingleImageDataset as strings.
MR_BRAIN_SUBJECT = "2"
ISEG_SUBJECT = "9"

In [5]:
# Build one patch dataset per source domain (MRBrainS and iSEG), each for a single subject,
# and wrap both in identically configured loaders.
# Shared settings live in one place so the two datasets/loaders cannot drift apart.
dataset_kwargs = dict(patch_size=PATCH_SIZE, step=STEP)
# shuffle=False and drop_last=False keep every patch, in dataset order.
# NOTE(review): presumably required so the ImageReconstructor can stitch patches back
# into a full volume — confirm against later cells.
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=0, drop_last=False, shuffle=False, pin_memory=False)

# Both datasets now pass the root path by keyword (the iSEG call previously passed it positionally).
mrbrains_dataset = SingleImageDataset(root_path=DATA_SET_ROOT_PATH_MR_BRAINS, subject=MR_BRAIN_SUBJECT,
                                      **dataset_kwargs)
iseg_dataset = SingleImageDataset(root_path=DATA_SET_ROOT_PATH_ISEG, subject=ISEG_SUBJECT, **dataset_kwargs)

mrbrain_data_loader = DataLoader(mrbrains_dataset, **loader_kwargs)
iseg_data_loader = DataLoader(iseg_dataset, **loader_kwargs)

In [6]:
# Build the generator (Unet) and discriminator (DCGAN), restore their trained weights
# from checkpoint files, and move both to the GPU.
# NOTE(review): the positional constructor arguments must match the configuration used
# when the checkpoints were saved; their meaning is defined in deepNormalize.models.
GENERATOR_CHECKPOINT_PATH = "/mnt/md0/models/Generator/data_augmentation/Generator.tar"
# FIXME(review): this path does not mirror the generator's layout ("Generator/data_augmentation").
# It is missing a "/" after "Discriminator" and spells "agumentation". It is kept verbatim
# because the file on disk may really live there — verify and correct if it does not.
DISCRIMINATOR_CHECKPOINT_PATH = "/mnt/md0/models/Discriminatordata_agumentation/Discriminator.tar"

generator = Unet(1, 1, True, True)
discriminator = DCGAN(1, 3)

# Load checkpoints onto the CPU first, so restoring does not depend on which CUDA device
# the tensors were saved from. The .cuda() calls below then move the parameters to the
# current GPU. Each checkpoint gets its own name rather than reusing one variable.
generator_checkpoint = torch.load(GENERATOR_CHECKPOINT_PATH, map_location="cpu")
generator.load_state_dict(generator_checkpoint["model_state_dict"])
generator.cuda()

discriminator_checkpoint = torch.load(DISCRIMINATOR_CHECKPOINT_PATH, map_location="cpu")
discriminator.load_state_dict(discriminator_checkpoint["model_state_dict"])
discriminator.cuda()

# TODO(review): DCGAN contains Dropout3d and BatchNorm3d layers. If the later cells run
# inference, make sure they call .eval() on both models first.

reconstructor = ImageReconstructor()

DCGAN(
  (_layer_1): Sequential(
    (0): ReplicationPad3d((1, 1, 1, 1, 1, 1))
    (1): Conv3d(1, 16, kernel_size=(3, 3, 3), stride=(2, 2, 2))
    (2): LeakyReLU(negative_slope=0.2, inplace=True)
    (3): Dropout3d(p=0.25, inplace=False)
  )
  (_layer_2): Sequential(
    (0): ReplicationPad3d((1, 1, 1, 1, 1, 1))
    (1): Conv3d(16, 32, kernel_size=(3, 3, 3), stride=(2, 2, 2))
    (2): LeakyReLU(negative_slope=0.2, inplace=True)
    (3): Dropout3d(p=0.25, inplace=False)
    (4): BatchNorm3d(32, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
  )
  (_layer_3): Sequential(
    (0): ReplicationPad3d((1, 1, 1, 1, 1, 1))
    (1): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(2, 2, 2))
    (2): LeakyReLU(negative_slope=0.2, inplace=True)
    (3): Dropout3d(p=0.25, inplace=False)
    (4): BatchNorm3d(64, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
  )
  (_layer_4): Sequential(
    (0): ReplicationPad3d((1, 1, 1, 1, 1, 1))
    (1): Conv3d(64, 128, kernel_size=(3,