In [1]:
### METHODS TO READ DATA IN ###
import sys
from provided_code.data_loader import DataLoader #provided by the challenge creators
import torch
import torch.utils.data as data
import os
import numpy as np
import random

# Pre-processing method: correctly aligns the images
# and concatenates them into one list
def pre_processing(dict_images):
    """Reorient the loaded volumes and bundle them as [input, label, mask].

    Every volume is flipped along its third axis and its axes are reversed,
    which (per the dataset note in the original code) turns the OpenKBP
    layout (y, x, -z, c) into (c, z, x, y).

    Args:
        dict_images: mapping with the keys 'structure_masks', 'ct', 'dose'
            and 'possible_dose_mask', each holding a 4-D numpy array.

    Returns:
        A three-element list:
        [structure masks and CT stacked along the channel axis, dose,
         possible dose mask].
    """
    def _reorient(key):
        # Flip the third axis, then reverse the axis order.
        flipped = dict_images[key][:, :, ::-1, :]
        return np.transpose(flipped, (3, 2, 1, 0))

    oar_masks = _reorient('structure_masks')
    ct_volume = _reorient('ct')
    dose_volume = _reorient('dose')
    dose_mask = _reorient('possible_dose_mask')

    # Network input = all structure-mask channels followed by the CT channel(s).
    net_input = np.concatenate((oar_masks, ct_volume), axis=0)
    return [net_input, dose_volume, dose_mask]

# Custom dataset: fetches the correct patients per phase
# and pre-processes them with the above method
class MyDataset(data.Dataset):
    """Dataset over the patient folders of one phase.

    Args:
        num_samples_per_epoch: value reported by ``len(dataset)``, i.e. how
            many samples a loader draws per epoch.
        phase: 'train', 'val' or 'test'; selects the patient folders
            (pt_1..pt_200, pt_201..pt_240, pt_241..pt_340 respectively).

    Each item is a list of three float32 tensors produced by
    ``pre_processing``: [network input, dose, possible dose mask].
    """

    def __init__(self, num_samples_per_epoch, phase):
        # 'train', 'val' or 'test'
        self.phase = phase
        self.num_samples_per_epoch = num_samples_per_epoch

        self.list_case_id = {'train': ['provided-data/train-pats/pt_' + str(i) for i in range(1, 201)],
                             'val': ['provided-data/validation-pats/pt_' + str(i) for i in range(201, 241)],
                             'test': ['provided-data/test-pats/pt_' + str(i) for i in range(241, 341)]}[phase]
        # NOTE: the shuffle is unseeded, so the index -> patient mapping
        # differs between runs.
        random.shuffle(self.list_case_id)
        self.dl = DataLoader(self.list_case_id)
        self.sum_case = len(self.list_case_id)

    def __getitem__(self, index_):
        # __len__ reports num_samples_per_epoch, which may exceed the number
        # of patients; wrap around instead of raising IndexError.
        case_path = self.list_case_id[index_ % self.sum_case]
        dict_images = self.dl.load_and_shape_data(case_path)
        # .copy() makes the arrays contiguous: pre_processing returns
        # reversed views, which torch.from_numpy cannot wrap.
        return [torch.from_numpy(image.copy()).float()
                for image in pre_processing(dict_images)]

    def __len__(self):
        return self.num_samples_per_epoch

# Method to get loader: builds dataloaders using custom dataset
# num_samples_per_epoch indicates how many samples are looked at
# per epoch.
def get_loader(train_bs=1, val_bs=1, train_num_samples_per_epoch=4, val_num_samples_per_epoch=4, num_works=0):
    """Build the training and validation data loaders.

    Args:
        train_bs: batch size of the training loader.
        val_bs: batch size of the validation loader.
        train_num_samples_per_epoch: length reported by the training dataset.
        val_num_samples_per_epoch: length reported by the validation dataset.
        num_works: number of worker processes for both loaders.

    Returns:
        (train_loader, val_loader); only the training loader shuffles.
    """
    datasets = {
        'train': MyDataset(num_samples_per_epoch=train_num_samples_per_epoch, phase='train'),
        'val': MyDataset(num_samples_per_epoch=val_num_samples_per_epoch, phase='val'),
    }
    # Options shared by both loaders.
    common_options = dict(num_workers=num_works, pin_memory=False)

    train_loader = data.DataLoader(dataset=datasets['train'], batch_size=train_bs,
                                   shuffle=True, **common_options)
    val_loader = data.DataLoader(dataset=datasets['val'], batch_size=val_bs,
                                 shuffle=False, **common_options)
    return train_loader, val_loader


In [2]:
### MODEL ###
import torch
import torch.nn as nn
import torch.nn.functional as F

# Basic 3D convolution block: convolution with configurable
# kernel size, stride and padding (3x3x3 with padding=1 throughout
# this notebook). Followed by instance normalization and ReLU
class SingleConv(nn.Module):
    """Conv3d -> affine InstanceNorm3d -> in-place ReLU.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        kernel_size, stride, padding: forwarded to ``nn.Conv3d``.
    """

    def __init__(self, in_ch, out_ch, kernel_size, stride, padding):
        super().__init__()

        convolution = nn.Conv3d(in_ch, out_ch, kernel_size=kernel_size,
                                padding=padding, stride=stride, bias=True)
        normalization = nn.InstanceNorm3d(out_ch, affine=True)
        activation = nn.ReLU(inplace=True)
        self.single_conv = nn.Sequential(convolution, normalization, activation)

    def forward(self, x):
        features = self.single_conv(x)
        return features

# Upsampling block: first applies trilinear upsampling,
# then 3x3x3 convolution with stride=1, padding=1. Followed
# by instance normalization and ReLU
class UpConv(nn.Module):
    """Trilinear x2 upsampling followed by a 3x3x3 conv, instance norm and ReLU.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()

        layers = [
            nn.Conv3d(in_ch, out_ch, kernel_size=3, padding=1, stride=1, bias=True),
            nn.InstanceNorm3d(out_ch, affine=True),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        # Double every spatial dimension before convolving.
        upsampled = F.interpolate(x, scale_factor=2, mode='trilinear', align_corners=True)
        return self.conv(upsampled)

# Set of downsampling blocks: Uses basic
# convolution block sequence with stride=2
class Encoder(nn.Module):
    """Five-stage encoder; every stage is two ``SingleConv`` blocks.

    The first stage keeps the resolution (stride 1); stages 2-5 start with a
    stride-2 convolution.

    Args:
        in_ch: number of input channels.
        list_ch: channel widths; entries 1..5 are used for stages 1..5
            (entry 0 is not read).
    """

    def __init__(self, in_ch, list_ch):
        super().__init__()

        def make_stage(ch_in, ch_out, first_stride):
            # First conv changes width (and maybe resolution); second refines.
            return nn.Sequential(
                SingleConv(ch_in, ch_out, kernel_size=3, stride=first_stride, padding=1),
                SingleConv(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
            )

        self.encoder_1 = make_stage(in_ch, list_ch[1], 1)
        self.encoder_2 = make_stage(list_ch[1], list_ch[2], 2)
        self.encoder_3 = make_stage(list_ch[2], list_ch[3], 2)
        self.encoder_4 = make_stage(list_ch[3], list_ch[4], 2)
        self.encoder_5 = make_stage(list_ch[4], list_ch[5], 2)

    def forward(self, x):
        """Return the outputs of all five stages, shallowest first."""
        stage_outputs = []
        features = x
        for stage in (self.encoder_1, self.encoder_2, self.encoder_3,
                      self.encoder_4, self.encoder_5):
            features = stage(features)
            stage_outputs.append(features)
        return stage_outputs

# Set of upsampling blocks: uses upsampling
# blocks and basic blocks as described above
class Decoder(nn.Module):
    """Four-level decoder with skip connections.

    Each level upsamples with ``UpConv``, concatenates the matching encoder
    output along the channel axis and fuses the result with ``SingleConv``
    blocks (two per level, a single one at the last level).

    Args:
        list_ch: channel widths; entries 1..5 are used (entry 0 is not read).
    """

    def __init__(self, list_ch):
        super().__init__()

        def conv3(ch_in, ch_out):
            return SingleConv(ch_in, ch_out, kernel_size=3, stride=1, padding=1)

        self.upconv_4 = UpConv(list_ch[5], list_ch[4])
        self.decoder_conv_4 = nn.Sequential(
            conv3(2 * list_ch[4], list_ch[4]),
            conv3(list_ch[4], list_ch[4])
        )
        self.upconv_3 = UpConv(list_ch[4], list_ch[3])
        self.decoder_conv_3 = nn.Sequential(
            conv3(2 * list_ch[3], list_ch[3]),
            conv3(list_ch[3], list_ch[3])
        )
        self.upconv_2 = UpConv(list_ch[3], list_ch[2])
        self.decoder_conv_2 = nn.Sequential(
            conv3(2 * list_ch[2], list_ch[2]),
            conv3(list_ch[2], list_ch[2])
        )
        self.upconv_1 = UpConv(list_ch[2], list_ch[1])
        self.decoder_conv_1 = nn.Sequential(
            conv3(2 * list_ch[1], list_ch[1])
        )

    def forward(self, out_encoder):
        """Decode the five encoder outputs (shallowest first) into one feature map."""
        skip_1, skip_2, skip_3, skip_4, bottleneck = out_encoder

        # (upsampler, fusion convs, skip connection), deepest level first.
        levels = (
            (self.upconv_4, self.decoder_conv_4, skip_4),
            (self.upconv_3, self.decoder_conv_3, skip_3),
            (self.upconv_2, self.decoder_conv_2, skip_2),
            (self.upconv_1, self.decoder_conv_1, skip_1),
        )
        features = bottleneck
        for upsample, fuse, skip in levels:
            features = fuse(torch.cat((upsample(features), skip), dim=1))

        return features

# Define the overall structure of U-Net: uses
# set of downsampling and upsampling blocks as
# described above
class BaseUNet(nn.Module):
    """U-Net body: ``Encoder`` followed by ``Decoder``, with custom weight init.

    Args:
        in_ch: number of input channels.
        list_ch: channel widths handed to the encoder and decoder.
    """

    def __init__(self, in_ch, list_ch):
        super(BaseUNet, self).__init__()
        self.encoder = Encoder(in_ch, list_ch)
        self.decoder = Decoder(list_ch)

        # init
        self.initialize()

    @staticmethod
    def init_conv_IN(modules):
        """Initialise every Conv3d / InstanceNorm3d yielded by ``modules()``.

        Args:
            modules: a callable returning an iterable of modules, e.g. the
                bound method ``some_module.modules`` (note: the method itself,
                not its result).

        Conv3d weights get Kaiming-uniform init (fan_in, ReLU gain) and zero
        bias; InstanceNorm3d gets weight 1 and bias 0.
        """
        for m in modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.)
            elif isinstance(m, nn.InstanceNorm3d):
                # With affine=False the norm has no weight/bias (both None);
                # guard like the conv bias above instead of crashing.
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1.)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.)

    def initialize(self):
        """Apply ``init_conv_IN`` to the encoder and the decoder."""
        self.init_conv_IN(self.encoder.modules)
        self.init_conv_IN(self.decoder.modules)

    def forward(self, x):
        out_encoder = self.encoder(x)
        out_decoder = self.decoder(out_encoder)

        return out_decoder

# Define Model: U-Net, with output layer: a 1x1x1 convolution
# in_ch, out_ch are the input/output channels (11 and 1)
# list_ch are the number of feature maps of blocks in U-Net
class Model(nn.Module):
    """``BaseUNet`` followed by a 1x1x1 output convolution.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        list_ch: channel widths of the U-Net; ``list_ch[1]`` is the width fed
            into the output convolution.
    """

    def __init__(self, in_ch, out_ch, list_ch):
        super().__init__()

        self.net = BaseUNet(in_ch, list_ch)
        head_in_ch = list_ch[1]
        self.conv_out = nn.Conv3d(head_in_ch, out_ch, kernel_size=1, padding=0, bias=True)

    def forward(self, x):
        # U-Net features -> per-voxel projection to out_ch channels.
        return self.conv_out(self.net(x))


In [3]:
### LOSS FUNCTION ###
import torch.nn as nn

# Custom L1 loss function to zoom in on the voxels
# of possible dose mask only.
class Loss(nn.Module):
    """Mean absolute error restricted to voxels where the dose mask is positive."""

    def __init__(self):
        super().__init__()
        self.L1_loss_func = nn.L1Loss(reduction='mean')

    def forward(self, pred, gt):
        """Compute the masked L1 loss.

        Args:
            pred: predicted dose tensor.
            gt: pair ``(dose, possible_dose_mask)``; both are indexed with the
                same boolean mask as ``pred``.

        Returns:
            Scalar tensor: mean |pred - dose| over voxels with mask > 0.
        """
        target_dose, dose_mask = gt[0], gt[1]
        inside = dose_mask > 0
        return self.L1_loss_func(pred[inside], target_dose[inside])

In [4]:
### TRAINING LOOP ###
import time

# Function to train the model: Takes in
# a model, dataloader, optimizer, scheduler, and loss function.
# num_epochs controls max number of epochs the loop is run for
# verbose for logging
def train_model(model, dataloader, optimizer, scheduler, loss_fn, num_epochs = 50, verbose = False):
    """Train ``model`` and return it loaded with the best-validation weights.

    Args:
        model: the network to train; inputs are moved to the device of its
            first parameter (CPU if it has none).
        dataloader: dict with 'train' and 'val' entries, each an iterable of
            batches where ``batch[0]`` is the input, ``batch[1]`` the dose and
            ``batch[2]`` the possible dose mask.
        optimizer: optimizer stepped after every training batch.
        scheduler: optional LR scheduler, stepped once per epoch after the
            training phase; pass a falsy value to disable.
        loss_fn: called as ``loss_fn(output, [dose, mask])``.
        num_epochs: number of epochs to run.
        verbose: log every epoch instead of every tenth.

    Returns:
        (model, loss_dict) where ``loss_dict`` maps 'train' / 'val' to the
        list of per-epoch mean losses.
    """
    # Derive the device from the model instead of relying on a notebook-level
    # global that may not exist yet when this cell is run.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device('cpu')

    def _snapshot_weights():
        # state_dict() returns references to the live tensors; clone them so
        # later optimizer steps cannot overwrite the saved "best" weights.
        return {key: value.detach().clone() for key, value in model.state_dict().items()}

    loss_dict = {'train':[],'val':[]}
    best_loss = float('inf')
    # Start from the initial weights so there is always something to restore,
    # even if no validation loss is ever recorded as an improvement.
    best_model_wts = _snapshot_weights()
    phases = ['train','val']
    since = time.time()
    for i in range(num_epochs):
        log_this_epoch = verbose or (i%10 == 0)
        if log_this_epoch:
            print('Epoch: {}/{}'.format(i+1, num_epochs))
            print('-'*10)
        for p in phases:
            is_train = (p == 'train')
            running_loss = 0
            running_total = 0
            if is_train:
                model.train()
            else:
                model.eval()

            # No autograd graph is needed during validation.
            with torch.set_grad_enabled(is_train):
                for batch in dataloader[p]:
                    image = batch[0].to(run_device)
                    dose = batch[1].to(run_device)
                    pdm = batch[2].to(run_device)
                    label = [dose,pdm]
                    if is_train:
                        optimizer.zero_grad()
                    output = model(image)
                    loss = loss_fn(output, label)
                    num_imgs = image.size()[0]
                    # Weight by batch size so the epoch loss is a per-sample mean.
                    running_loss += loss.item()*num_imgs
                    running_total += num_imgs
                    if is_train:
                        loss.backward()
                        optimizer.step()
            # An empty loader yields NaN rather than a ZeroDivisionError.
            epoch_loss = float(running_loss/running_total) if running_total else float('nan')
            if log_this_epoch:
                print('Phase:{}, epoch loss: {:.4f}'.format(p, epoch_loss))

            loss_dict[p].append(epoch_loss)
            if p == 'val':
                if epoch_loss < best_loss:
                    best_loss = epoch_loss
                    best_model_wts = _snapshot_weights()
            else:
                if scheduler:
                    scheduler.step()
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val loss: {:.4f}'.format(best_loss))

    model.load_state_dict(best_model_wts)

    return model, loss_dict

In [5]:
### Define model, loss, optimizer, scheduler, dataloader for training loop ###
# 11 input channels, 1 output channel; list_ch[0] is a placeholder that the
# encoder/decoder never read (they index list_ch[1..5]).
model = Model(in_ch=11, out_ch=1, list_ch=[-1, 16, 32, 64, 128, 256])
# Fall back to the CPU when no GPU is visible instead of crashing on .to().
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = nn.DataParallel(model)
model.to(device)

loss = Loss()
optimizer = torch.optim.Adam(model.parameters(),lr = 3e-4, weight_decay=1e-4)
# NOTE(review): train_model steps the scheduler once per epoch, so with 50
# epochs and T_max=80000 the learning rate barely moves; T_max looks like a
# per-iteration budget. Confirm the intended schedule before changing it.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=80000, eta_min=1e-7, last_epoch=-1)
bs = 1
dataloader = {}
dataloader['train'] , dataloader['val'] = get_loader(
        train_bs=bs,
        val_bs=bs,
        train_num_samples_per_epoch=200,  
        val_num_samples_per_epoch=40,
        num_works=0
    )

In [6]:
### Train the model, 50 epochs ###
# Runs the train/val loop with per-epoch logging (verbose=True). `model` is
# rebound to the model returned by train_model and `loss_dict` holds the
# per-epoch losses for the 'train' and 'val' phases.
model, loss_dict = train_model(model, dataloader, optimizer, scheduler, loss, num_epochs=50, verbose = True)

Epoch: 1/50
----------
Phase:train, epoch loss: 21.9143
Phase:val, epoch loss: 21.3063
Epoch: 2/50
----------
Phase:train, epoch loss: 20.9893
Phase:val, epoch loss: 20.7821
Epoch: 3/50
----------
Phase:train, epoch loss: 20.5120
Phase:val, epoch loss: 20.3260
Epoch: 4/50
----------
Phase:train, epoch loss: 19.7598
Phase:val, epoch loss: 19.7626
Epoch: 5/50
----------
Phase:train, epoch loss: 19.3305
Phase:val, epoch loss: 19.3231
Epoch: 6/50
----------
Phase:train, epoch loss: 18.8603
Phase:val, epoch loss: 18.8134
Epoch: 7/50
----------
Phase:train, epoch loss: 18.4156
Phase:val, epoch loss: 18.3723
Epoch: 8/50
----------
Phase:train, epoch loss: 17.7459
Phase:val, epoch loss: 17.8372
Epoch: 9/50
----------
Phase:train, epoch loss: 17.2853
Phase:val, epoch loss: 17.2531
Epoch: 10/50
----------
Phase:train, epoch loss: 16.8403
Phase:val, epoch loss: 16.8154
Epoch: 11/50
----------
Phase:train, epoch loss: 16.1692
Phase:val, epoch loss: 16.3642
Epoch: 12/50
----------
Phase:train, epoc

In [7]:
### Save model locally, name reflects configuration tested ###
# NOTE(review): this pickles the whole model object (the nn.DataParallel
# wrapper created in the setup cell), not just its state_dict, so loading it
# presumably needs the same class definitions available — confirm this is the
# intended format before sharing the file.
torch.save(model,'lr3e-4_model.pkl')

In [8]:
### Also save the losses for ease of access ###
import pickle
# The context manager closes the file even if pickling raises.
with open('lr3e-4_loss.pkl','wb') as f:
    pickle.dump(loss_dict,f)

In [9]:
### Get the dose score on test set ###
import torch.utils.data as data
# Define test dataset and dataloader
test_dataset = MyDataset(num_samples_per_epoch=100, phase='test')
test_loader = data.DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, num_workers=0,pin_memory=False)
# Empty lists to save images for visualization later: requires a lot of memory
mads = []   # per-case mean absolute dose error inside the possible dose mask
cts= []     # last input channel (the CT) of each case
doses = []  # ground-truth dose, zeroed outside the possible dose mask
preds = []  # predicted dose, zeroed outside the possible dose mask

model.eval()
eval_device = next(model.parameters()).device
with torch.no_grad():
    # The loop variable is deliberately not called `data`: that would rebind
    # the torch.utils.data alias and break re-running earlier cells.
    for batch in test_loader:
        image = batch[0]
        cts.append(image.numpy()[0,-1,:,:,:])
        dose = batch[1].numpy()
        pdm = batch[2].numpy()
        in_mask = pdm > 0

        pred = model(image.to(eval_device)).cpu().numpy()
        doses.append(np.where(in_mask, dose, 0))
        # Store the prediction itself (the ground truth was stored here before).
        preds.append(np.where(in_mask, pred, 0))

        mads.append(np.mean(np.abs(pred[in_mask] - dose[in_mask])))

dose_score = np.mean(mads)

In [10]:
### Print out test dose score ###
# dose_score is the mean of the per-case masked mean absolute errors (mads).
print('Dose score is: ' + str(dose_score))

Dose score is: 3.729331


In [11]:
### Print out lowest/highest MAE in test set ###
import numpy as np
# The reported indices are positions in the test loader's order. MyDataset
# shuffles its case list without a seed, so they change from run to run.
print(f'The max MAE in Test Set is: {np.amax(mads)} at index {np.argmax(mads)}.')
print(f'The min MAE in Test Set is: {np.amin(mads)} at index {np.argmin(mads)}.')

The max MAE in Test Set is: 9.856223106384277 at index 82.
The min MAE in Test Set is: 1.981579065322876 at index 14.


In [12]:
### Get visuals for Figure 5 of paper ###
### Draws CT image and dose ground truth or prediction ###
### For highest/lowest MAE index ###
from mayavi import mlab

# Derive the case from the computed errors instead of hard-coding an index:
# the test order is shuffled without a seed, so a fixed index points at a
# different patient on every run.
case_idx = int(np.argmin(mads))  # use np.argmax(mads) for the highest-MAE case

# Swap `doses` for `preds` to draw the prediction instead of the ground truth.
dose = doses[case_idx][0,0,:,:,:].transpose(2,0,1)
ct = cts[case_idx].transpose(2,0,1)
mlab.pipeline.volume(mlab.pipeline.scalar_field(ct),color=(0,0,1))
mlab.pipeline.volume(mlab.pipeline.scalar_field(dose))


mlab.show()