In [None]:
# Dataset

from torch.utils.data import Dataset
import os
import torch
import numpy as np


class RadDataset(Dataset):
    """Per-patient radiotherapy samples stored as ``.npy`` files.

    Expected on-disk layout (one directory per patient)::

        <root>/<split>/sample_<id>/ct.npy
        <root>/<split>/sample_<id>/dose.npy                (absent for 'test')
        <root>/<split>/sample_<id>/possible_dose_mask.npy
        <root>/<split>/sample_<id>/structure_masks.npy

    ``<split>`` is ``mode`` for 'train'/'validation' and ``test_nodose`` for
    'test'.

    Args:
        mode: one of 'train', 'validation', 'test'.
        root: dataset root directory (defaults to 'data', relative to the CWD).

    Raises:
        ValueError: if ``mode`` is not a known split.

    Each item is a dict of float tensors with a new leading channel axis.
    'test' items carry no 'dose' key, because there is no dose file to load.
    """

    # Sub-directory under ``root`` that holds each split.
    _SPLIT_DIRS = {'train': 'train', 'validation': 'validation', 'test': 'test_nodose'}

    def __init__(self, mode, root='data'):
        super().__init__()

        if mode not in self._SPLIT_DIRS:
            raise ValueError(f"mode must be one of {sorted(self._SPLIT_DIRS)}, got {mode!r}")

        split_dir = os.path.join(root, self._SPLIT_DIRS[mode])
        has_dose = mode != 'test'

        # sorted(set(...)) both de-duplicates and yields a deterministic order.
        # The previous list(set(...)) order changed between interpreter runs
        # because str hashing is randomized.
        patient_ids = sorted({name.split('_')[1] for name in os.listdir(split_dir)})

        self.samples = []
        for pid in patient_ids:
            sample_dir = os.path.join(split_dir, 'sample_' + pid)
            self.samples.append((
                os.path.join(sample_dir, 'ct.npy'),
                # '' marks "no dose available"; __getitem__ skips it.
                os.path.join(sample_dir, 'dose.npy') if has_dose else '',
                os.path.join(sample_dir, 'possible_dose_mask.npy'),
                os.path.join(sample_dir, 'structure_masks.npy'),
            ))

    @staticmethod
    def _load(path):
        """Load a ``.npy`` array as a tensor with a new leading (channel) axis."""
        return torch.from_numpy(np.load(path)).unsqueeze(0)

    def __getitem__(self, item):
        ct_path, dose_path, mask_path, structures_path = self.samples[item]

        sample = {'ct': self._load(ct_path)}
        if dose_path:
            # Test samples have no dose file; loading '' would raise.
            sample['dose'] = self._load(dose_path)
        sample['possible_dose_mask'] = self._load(mask_path)
        sample['structure_masks'] = self._load(structures_path)
        return sample

    def __len__(self):
        return len(self.samples)

In [None]:
# Model
import torch
from torch import nn
import torch.nn.functional as F


class UNetDown(nn.Module):
    def __init__(self, in_size, out_size):
        super(UNetDown, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_size, out_size, kernel_size=3, stride=2, padding=1),
            nn.InstanceNorm2d(out_size),
            nn.ReLU()
          )

    def forward(self, x):
        return self.model(x)


class UNetUp(nn.Module):
    def __init__(self, in_size, out_size):
        super(UNetUp, self).__init__()
        self.model = nn.Sequential(
            nn.ConvTranspose2d(in_size, out_size, kernel_size=4,
                               stride=2, padding=1),
            nn.InstanceNorm2d(out_size),
            nn.ReLU()
        )

    def forward(self, x, skip_input=None):
        if skip_input is not None:
            x = torch.cat((x, skip_input), 1)  # add the skip connection
        x = self.model(x)
        return x


class FinalLayer(nn.Module):
    """Output head: optional skip concat -> 2x nearest upsample -> 3x3 conv -> Tanh.

    NOTE(review): Tanh bounds every prediction to [-1, 1]. train() compares the
    prediction directly against the loaded dose with an L1 loss, so confirm the
    dose arrays are normalized to this range.
    """

    def __init__(self, in_size, out_size):
        super().__init__()
        head = [
            nn.Upsample(scale_factor=2),
            nn.Conv2d(in_size, out_size, kernel_size=3, padding=1),
            nn.Tanh(),
        ]
        self.model = nn.Sequential(*head)

    def forward(self, x, skip_input=None):
        if skip_input is None:
            return self.model(x)
        # Skip features are stacked after x along the channel axis.
        return self.model(torch.cat((x, skip_input), 1))


class UNet(nn.Module):
    """Five-level 2D U-Net built from UNetDown / UNetUp / FinalLayer stages.

    Each encoder output (except the bottleneck) is concatenated into the
    matching decoder stage. Input height/width divisible by 32 (five stride-2
    downsamplings) keeps every skip connection the same size as the decoder
    tensor it is concatenated with.
    """

    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()

        # Encoder: in_channels -> 64 -> 128 -> 256 -> 512 -> 512.
        self.down1 = UNetDown(in_channels, 64)
        self.down2 = UNetDown(64, 128)
        self.down3 = UNetDown(128, 256)
        self.down4 = UNetDown(256, 512)
        self.down5 = UNetDown(512, 512)

        # Decoder: after up1, every stage's input = its predecessor's output
        # plus the same-resolution encoder skip (channels add up).
        self.up1 = UNetUp(512, 512)
        self.up2 = UNetUp(512 + 512, 256)
        self.up3 = UNetUp(256 + 256, 128)
        self.up4 = UNetUp(128 + 128, 64)

        self.final = FinalLayer(64 + 64, out_channels)

    def forward(self, x):
        # Collect d1..d5; they are consumed in reverse (LIFO) by the decoder.
        skips = []
        h = x
        for down in (self.down1, self.down2, self.down3, self.down4, self.down5):
            h = down(h)
            skips.append(h)

        h = self.up1(skips.pop())  # bottleneck (d5), no skip input
        for up in (self.up2, self.up3, self.up4):
            h = up(h, skips.pop())  # pairs with d4, d3, d2 in turn

        return self.final(h, skips.pop())  # pairs with d1

In [None]:
# Utils
import torch
import os
import time
import datetime
import sys
from tqdm.notebook import tqdm


def train(model, train_loader, val_loader, num_epoch=10, lr=0.0001):
    """Train a dose-prediction model with an L1 loss and the Adam optimizer.

    For every batch the model maps ``batch["ct"]`` to a prediction, which is
    compared against ``batch["dose"]``. After each epoch the mean training
    loss and, when ``val_loader`` is given, the mean validation loss are
    printed.

    Args:
        model: (nn.Module) network to train; moved to the GPU when available
        train_loader: (DataLoader) yields dict batches with "ct" and "dose"
        val_loader: (DataLoader or None) validation batches with the same keys;
            pass None to skip validation
        num_epoch: (int) number of epochs performed during training
        lr: (float) learning rate of the Adam optimizer

    Returns:
        model: (nn.Module) the trained model. This is the same object that was
            passed in, now on the training device.
    """

    cuda = torch.cuda.is_available()
    print(f"Using cuda device: {cuda}")  # check if GPU is used
    device = torch.device("cuda" if cuda else "cpu")

    model = model.to(device)  # nn.Module.to moves in place and returns self
    criterion = torch.nn.L1Loss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    def to_inputs(batch):
        # float32 on the training device. This replaces the deprecated
        # ``.type(torch.cuda.FloatTensor)`` casting pattern.
        return (batch["ct"].to(device=device, dtype=torch.float32),
                batch["dose"].to(device=device, dtype=torch.float32))

    for epoch in range(num_epoch):
        # ---------- Training ----------
        model.train()
        train_total, train_count = 0.0, 0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epoch}"):
            ct, dose = to_inputs(batch)

            optimizer.zero_grad()
            loss = criterion(model(ct), dose)
            loss.backward()
            optimizer.step()

            # Weight the per-batch mean by batch size so the epoch average is
            # not skewed by a smaller final batch.
            train_total += loss.item() * ct.size(0)
            train_count += ct.size(0)

        message = f"Epoch {epoch + 1}/{num_epoch} - train L1: {train_total / max(train_count, 1):.4f}"

        # ---------- Validation ----------
        if val_loader is not None:
            model.eval()
            val_total, val_count = 0.0, 0
            with torch.no_grad():
                for batch in val_loader:
                    ct, dose = to_inputs(batch)
                    val_total += criterion(model(ct), dose).item() * ct.size(0)
                    val_count += ct.size(0)
            message += f" - val L1: {val_total / max(val_count, 1):.4f}"

        print(message)

    return model

In [None]:
# Training
from torch.utils.data import DataLoader

lr = 0.001
batch_size = 16
num_epoch = 15
SEED = 0  # fixes weight init and shuffle order so runs are reproducible

# Seeds the global torch RNG, which both the parameter initializers and
# DataLoader's shuffling sampler draw from.
torch.manual_seed(SEED)

train_loader = DataLoader(RadDataset('train'), batch_size=batch_size, shuffle=True)
val_loader = DataLoader(RadDataset('validation'), batch_size=batch_size, shuffle=False)
test_loader = DataLoader(RadDataset('test'), batch_size=batch_size, shuffle=False)

model = UNet()

# train() updates `model`'s weights in place (and moves it to the GPU when available).
train(model, train_loader, val_loader, num_epoch=num_epoch, lr=lr)

In [None]:
# Mean per-element L1 loss of the trained model on the validation set.
# `loss` only ever existed inside train(), so the old `print(loss)` raised NameError.
eval_device = next(model.parameters()).device
model.eval()
l1_sum, n_elements = 0.0, 0
with torch.no_grad():
    for batch in val_loader:
        ct = batch['ct'].to(eval_device, dtype=torch.float32)
        dose = batch['dose'].to(eval_device, dtype=torch.float32)
        l1_sum += torch.nn.functional.l1_loss(model(ct), dose, reduction='sum').item()
        n_elements += dose.numel()
val_l1 = l1_sum / max(n_elements, 1)
val_l1