In [None]:
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import StepLR
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import Dataset

import os
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt

from datasets.unet_dataset import UNetDataset
from models.unext import UNext

from scripts.train import train
from scripts.validate import validate

from utils.image import torch_to_numpy

First, let's load the datasets:

In [None]:
# Root directories of the training and validation splits.
train_dir = "challenge_data/train/train"
val_dir = "challenge_data/validation/validation"

# Validation data is only converted to tensors. Training data additionally gets a
# random rescaling in [0.8, 1.2]; rotation and translation ranges are fixed at zero.
val_transform = transforms.Compose([transforms.ToTensor()])
train_augmentations = [
    transforms.ToTensor(),
    transforms.RandomAffine(degrees=(0, 0), translate=(0., 0.), scale=(0.8, 1.2)),
]
train_transform = transforms.Compose(train_augmentations)

# NOTE(review): assumes UNetDataset applies the same random affine draw to the
# inputs and the label so they stay spatially aligned — confirm in datasets/unet_dataset.py.
val_dataset = UNetDataset(val_dir, transforms=val_transform)
train_dataset = UNetDataset(train_dir, transforms=train_transform)

Some plots of the dataset:

- Inputs:

In [None]:
# Take the first element yielded by iterating the training set. It is presumably
# augmented, since train_dataset was built with the random-affine train_transform.
elem = next(iter(train_dataset))

# Show each channel of the first input tensor in its own titled figure.
# NOTE(review): per the model cell these are presumably the CT scan, the
# possible-dose mask and the 10 organ masks, in that order — confirm.
for channel_idx, channel in enumerate(elem['input'][0]):
    fig, ax = plt.subplots()
    # Add a single leading dimension before converting to numpy for display.
    ax.imshow(torch_to_numpy(channel.unsqueeze(0).detach().cpu()))
    ax.set_title(f"Input channel {channel_idx}")
    plt.show()

- Labels:

In [None]:
# Display the label that belongs to the sample whose inputs are plotted above.
label = elem["label"]
label_image = torch_to_numpy(label.detach().cpu())
fig, ax = plt.subplots()
ax.imshow(label_image)
plt.show()

Let's build our model:

In [None]:
# Input channels: 1 CT scan + 1 possible-dose mask + 10 organ masks = 12.
n_input_channels = 1 + 1 + 10
model = UNext(n_channels=n_input_channels)
print(model)

Train the model:

In [None]:
# Module.cuda() moves the parameters in place and returns the same module,
# so rebinding `model` here does not change which object is trained.
model = model.cuda()
train(model, train_dataset, val_dataset, batch_size=32)

Load our best model:

In [None]:
# Path to a saved state_dict of the best model. Leave empty to keep the model
# trained above; re-creating UNext without loading weights would replace the
# trained network with a freshly initialised (untrained) one.
PATH = ""
if PATH:
    model = UNext(n_channels=12)
    # Load onto the CPU first so the checkpoint opens regardless of the device
    # it was saved from; the model is moved to the GPU below.
    model.load_state_dict(torch.load(PATH, map_location="cpu"))
# Inference mode (e.g. dropout/batch-norm behaviour) for validation and sampling.
model = model.cuda().eval()

Validate the model:

In [None]:
# Evaluate the current model on the validation set with the project's validate helper.
validate(model.cuda(), val_dataset)

Prediction on a random validation sample:

In [None]:
# Commented out in the original notebook; uncomment to run inference on one validation sample.
# val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1, shuffle=True)
# # pick one random example of the validation set (shuffle=True)
# val_sample = next(iter(val_loader))
# val_elem = [val_sample["input"][0].cuda(), val_sample["input"][1].cuda()]
# # ZeroPad2d padding order is (left, right, top, bottom)
# m = nn.ZeroPad2d((24,24, 9, 9))
# # inference only: no_grad avoids building the autograd graph
# with torch.no_grad():
#     out = m(model(val_elem))
# # multiply by the possible-dose mask (presumably binary) to zero predictions outside it
# out = torch.multiply(out, val_sample["possible_dose_mask"].cuda().unsqueeze(1))

Left: prediction

Right: ground truth

In [None]:
# Commented out in the original notebook. Requires `out`, `val_sample` and `m`
# from the sample-example cell; uncomment both to compare prediction and label.
# plt.figure()
# plt.subplot(1,2,1)
# plt.title("Prediction")
# plt.imshow(torch_to_numpy(out[0].detach().cpu()))
# plt.subplot(1,2,2)
# plt.title("Groundtruth")
# plt.imshow(torch_to_numpy(m(val_sample["label"])[0].cpu()))
# plt.show()