In [4]:
import os
import numpy as np
import nibabel as nib
from tqdm import tqdm
from torch.utils.data import Dataset
import random
import torch
import torchio as tio
import matplotlib.pyplot as plt
from dataset_tio import Normalize, ImagePair, calculate_overlap, data_split
import torchvision
from models.generator import GeneratorRRDB
from models.generator_small import GeneratorRRDBSmall
from models.generator_ds import GeneratorRRDBDeepSupervision

from models.discriminator import Discriminator
from models.feature_extractor import FeatureExtractor

from trainer_org import LitTrainer as LitTrainer_org
from trainer_gan import LitTrainer as LitTrainer_gan
import pytorch_lightning as pl

from torchvision.utils import save_image
from torchsummary import summary
import time
from utils import save_subject
from torch.utils.tensorboard import SummaryWriter


# Show the kernel's working directory; relative paths used later (e.g. 'log/...') resolve against it.
working_dir = os.getcwd()
print(working_dir)

/mnt/beta/djboonstoppel/Code


In [2]:
# IPython autoreload: mode 2 re-imports modules before each cell executes, so edits to the
# project modules imported above (dataset_tio, models.*, trainer_*) are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [13]:
# Hyper-parameters shared by all three generator variants (single-channel input,
# 64 filters, one residual block) so they are built from one place.
rrdb_kwargs = dict(channels=1, filters=64, num_res_blocks=1)

# Construction order is kept (full, small, deep-supervision) so weight initialisation
# draws from the RNG in the same sequence as before.
generator = GeneratorRRDB(**rrdb_kwargs)
generator_small = GeneratorRRDBSmall(**rrdb_kwargs)
generator_ds = GeneratorRRDBDeepSupervision(**rrdb_kwargs)


In [6]:
# --- Data-pipeline configuration -------------------------------------------------------
root_dir = '/mnt/beta/djboonstoppel/Code'
patients_frac = .5     # forwarded to data_split as patients_frac
std = 0.3548           # forwarded to Normalize(std=...)
patch_size = 64        # in-plane patch edge; patches are (patch_size, patch_size, 1)
patch_overlap = .5     # fractional overlap handed to calculate_overlap
batch_size = 16
seed = 42              # fixes shuffling / augmentation randomness for repeatable batches

# Seed every RNG the pipeline may draw from (subject/patch shuffling, RandomFlip).
# torch's seed also determines the base seed PyTorch gives DataLoader worker processes.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

data_path = os.path.join(root_dir, 'data')
train_subjects = data_split('training', patients_frac=patients_frac, root_dir=data_path)

training_transform = tio.Compose([
    Normalize(std=std),
    # tio.RandomNoise(p=0.5),
    tio.RandomFlip(axes=(0, 1)),
])

training_set = tio.SubjectsDataset(
    train_subjects, transform=training_transform)

# Derive the grid overlap and the number of patches per volume from the first subject's
# LR image; assumes all training volumes share that spatial size — TODO confirm.
overlap, nr_patches = calculate_overlap(train_subjects[0]['LR'],
                                        (patch_size, patch_size),
                                        (patch_overlap, patch_overlap)
                                        )

# 2-D patches are taken as slabs of depth 1 from each volume.
sampler = tio.data.GridSampler(patch_size=(patch_size, patch_size, 1),
                                    patch_overlap=overlap,
                                    # padding_mode=0,
                                    )

training_queue = tio.Queue(
    subjects_dataset=training_set,
    max_length=nr_patches * 10,      # buffer up to 10 volumes' worth of patches
    samples_per_volume=nr_patches,
    sampler=sampler,
    num_workers=4,                   # the queue loads subjects in its own workers
    shuffle_subjects=True,
    shuffle_patches=True,
)
# torchio's Queue already does multiprocessing, so the outer DataLoader must use
# num_workers=0.
training_loader = torch.utils.data.DataLoader(
    training_queue,
    batch_size=batch_size,
    num_workers=0,
)

# One batch of patches, used below to trace and inspect the generator.
batch = next(iter(training_loader))

Loading training set...


In [17]:
# Trace the generator on a real batch of LR patches and write its graph to TensorBoard.
# Patches were sampled as (patch_size, patch_size, 1); squeezing dim 4 drops that
# singleton depth axis. Assumes the tensor layout is (B, C, W, H, 1) — TODO confirm.
batch_LR = batch['LR'][tio.DATA].squeeze(4)

# Fall back to CPU so the cell also runs on machines without a GPU (same as .cuda() otherwise).
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Using the writer as a context manager closes it, which flushes the event file.
# Without this the graph may only appear after the periodic flush or at kernel exit.
with SummaryWriter('log/architectures/generator') as writer:
    writer.add_graph(generator.to(device), batch_LR.to(device))

In [19]:
# Layer-by-layer summary on CPU. Input shape and batch size come from the pipeline config
# (single channel, patch_size x patch_size, batch_size) instead of repeating the literals,
# so the printed shapes track any change to the patch settings.
summary(generator.cpu(), input_size=(1, patch_size, patch_size), batch_size=batch_size, device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [16, 64, 64, 64]             640
            Conv2d-2           [16, 64, 64, 64]          36,928
         LeakyReLU-3           [16, 64, 64, 64]               0
            Conv2d-4           [16, 64, 64, 64]          73,792
         LeakyReLU-5           [16, 64, 64, 64]               0
            Conv2d-6           [16, 64, 64, 64]         110,656
         LeakyReLU-7           [16, 64, 64, 64]               0
            Conv2d-8           [16, 64, 64, 64]         147,520
         LeakyReLU-9           [16, 64, 64, 64]               0
           Conv2d-10           [16, 64, 64, 64]         184,384
DenseResidualBlock-11           [16, 64, 64, 64]               0
           Conv2d-12           [16, 64, 64, 64]          36,928
        LeakyReLU-13           [16, 64, 64, 64]               0
           Conv2d-14           [16, 64