In [None]:
import os
import time
import copy
import torch
import imageio
import numpy as np
import pandas as pd
import nibabel as nib
from glob import glob
from sklearn.metrics import *
from torch import FloatTensor
from tqdm.notebook import tqdm

In [None]:
class BraTS(torch.utils.data.Dataset):
    """Dataset of multi-modal MRI cases stored as one sub-folder per case.

    Every sub-folder of ``path`` is treated as one case. Inside a case folder
    the four modalities and the segmentation are located with the glob
    patterns ``*t1.*``, ``*t1c*``, ``*t2*``, ``*fl*`` and ``*seg*``.

    Each item is a tuple ``(image, segmentation)``:

    * ``image``: float32 tensor with the four modalities stacked on axis 0,
      its first and last spatial axes swapped, rescaled to ``[0, 1]``.
    * ``segmentation``: float32 tensor with the same spatial axis order as
      ``image`` (without the channel axis).

    NOTE(review): assumes every volume of a case is 3D and that all five
    volumes share the same shape - confirm against the data.
    """

    def __init__(self, path: str) -> None:
        """Index the case folders found directly under ``path``."""
        super().__init__()
        self.path = path
        # os.listdir returns entries in arbitrary order and also lists plain
        # files and hidden folders (e.g. ``.ipynb_checkpoints``). Keep only the
        # visible directories, sorted, so that indices are reproducible.
        self.dirs = sorted(
            entry for entry in os.listdir(path)
            if not entry.startswith('.')
            and os.path.isdir(os.path.join(path, entry))
        )


    def __len__(self) -> int:
        """Number of case folders."""
        return len(self.dirs)


    def __getitem__(self, idx: int) -> tuple[FloatTensor, FloatTensor]:
        """Load case ``idx`` and return ``(image, segmentation)`` tensors.

        Raises:
            FileNotFoundError: if a modality or the segmentation is missing.
        """
        # determine the MRI paths
        mri_path = os.path.join(self.path, self.dirs[idx])
        t1_path  = self._find_volume(mri_path, '*t1.*')
        t1c_path = self._find_volume(mri_path, '*t1c*')
        t2_path  = self._find_volume(mri_path, '*t2*')
        t2f_path = self._find_volume(mri_path, '*fl*')
        seg_path = self._find_volume(mri_path, '*seg*')

        # read the MRI
        t1_img  = nib.load(t1_path ).get_fdata()
        t1c_img = nib.load(t1c_path).get_fdata()
        t2_img  = nib.load(t2_path ).get_fdata()
        t2f_img = nib.load(t2f_path).get_fdata()
        seg_img = nib.load(seg_path).get_fdata()

        # compose a 4 channel image
        mri_4x = np.stack((t1_img, t1c_img, t2_img, t2f_img))

        # transform to tensors; the image swaps its first and last spatial
        # axes (axes 1 and 3, axis 0 being the channel), so the segmentation
        # must swap the same spatial axes (0 and 2) to stay voxel-aligned
        mri_tensor = torch.tensor(mri_4x, dtype=torch.float32).swapaxes(1, 3)
        seg_tensor = torch.tensor(seg_img, dtype=torch.float32).swapaxes(0, 2)

        # rescale to interval [0, 1]
        mri_tensor = self.normalize(mri_tensor)
        return mri_tensor, seg_tensor


    def normalize(self, data: FloatTensor) -> FloatTensor:
        """Min-max rescale ``data`` to ``[0, 1]`` using its global min and max.

        A constant tensor has no range to rescale; it is mapped to all zeros
        instead of the NaNs a division by zero would produce.
        """
        data_min, data_max = torch.min(data), torch.max(data)
        value_range = data_max - data_min
        if value_range == 0:
            return torch.zeros_like(data)
        return (data - data_min) / value_range


    @staticmethod
    def _find_volume(case_dir: str, pattern: str) -> str:
        """Return the first (sorted) file in ``case_dir`` matching ``pattern``."""
        matches = sorted(glob(os.path.join(case_dir, pattern)))
        if not matches:
            raise FileNotFoundError(
                f"no file matching '{pattern}' found in '{case_dir}'"
            )
        return matches[0]

In [None]:
# build the dataset from the case folders stored under ./data
brats = BraTS(path='./data')
# sanity check: display the shape of the first sample's image tensor
brats[0][0].shape

In [None]:
def cnnLayer3d(in_filters, out_filters, kernel_size=3, leak_rate=0.01):
    """Build one 3D convolutional stage: Conv3d -> BatchNorm3d -> LeakyReLU.

    Args:
        in_filters: number of input channels of the convolution.
        out_filters: number of output channels of the convolution.
        kernel_size: size of the convolution kernel; the padding is set to
            ``kernel_size // 2``.
        leak_rate: negative slope of the LeakyReLU activation.

    Returns:
        A ``torch.nn.Sequential`` holding the three modules in that order.
    """
    stages = [
        torch.nn.Conv3d(
            in_filters, out_filters, kernel_size, padding=kernel_size // 2
        ),
        torch.nn.BatchNorm3d(out_filters),
        torch.nn.LeakyReLU(leak_rate),
    ]
    return torch.nn.Sequential(*stages)

In [None]:
class UNetBlock3d(torch.nn.Module):
    """One resolution level of a 3D U-Net, optionally wrapping a deeper level.

    The block runs ``layers`` convolutional stages on its input (``in_model``).
    When a ``sub_network`` is given, the result is max-pooled by 2, passed
    through the sub-network, up-sampled by a stride-2 transposed convolution
    and concatenated with the full-scale result on the channel axis
    (``bottleneck``). ``out_model`` then runs ``layers`` more stages and, if
    ``out_channels`` is set, a final 1x1x1 convolution.

    Args:
        in_channels: channels of the block input.
        mid_channels: channels used by the block's convolutional stages; the
            sub-network must also output this many channels.
        out_channels: if not ``None``, channels produced by a trailing 1x1x1
            convolution; otherwise the block outputs ``mid_channels``.
        layers: number of convolutional stages before and after the bottleneck.
        sub_network: module applied at half resolution, or ``None``.
        filter_size: kernel size of the convolutions.
    """

    def __init__(self, in_channels, mid_channels, out_channels=None, layers=1, sub_network=None, filter_size=3):
        super().__init__()

        # preparing the layers used to process the input
        in_layers = [cnnLayer3d(in_channels, mid_channels, filter_size)]

        # double the number of inputs to the output if there's a subnetwork
        inputs_to_outputs = 1 if sub_network is None else 2

        # preparing the layers used to process the final output, which has extra input channels from any sub-network
        out_layers = [cnnLayer3d(mid_channels*inputs_to_outputs, mid_channels, filter_size)]

        # make the additional hidden layers used for the input and output
        for _ in range(layers-1):
            in_layers.append(cnnLayer3d(mid_channels, mid_channels, filter_size))
            out_layers.append(cnnLayer3d(mid_channels, mid_channels, filter_size))

        # use 1x1 Convolutions to ensure a specific output size
        if out_channels is not None:
            out_layers.append(torch.nn.Conv3d(mid_channels, out_channels, 1, padding=0))

        # define the three sub-networks:

        #1) in_model performs the initial rounds of convolution
        self.in_model = torch.nn.Sequential(*in_layers)

        #2) our subnetwork works on the max-pooled result. We will add the pooling and up-scaling directly into the sub-model
        if sub_network is not None:
            self.bottleneck = torch.nn.Sequential(
                torch.nn.MaxPool3d(2),
                sub_network,
                torch.nn.ConvTranspose3d(mid_channels, mid_channels, filter_size, padding=filter_size//2, output_padding=1, stride=2)
            )
        else:
            self.bottleneck = None

        #3) the output model that processes the concatenated result, or just the output from in_model if no sub-network was given
        self.out_model = torch.nn.Sequential(*out_layers)


    def forward(self, x):
        """Apply the block to ``x`` of shape (B, C, D, W, H)."""

        # compute the convolutions at full scale
        full_scale_result = self.in_model(x) # shape (B, C, D, W, H)

        # check if there's a bottleneck to apply
        if self.bottleneck is not None:
            # initial shape (B, C, D, W, H)
            bottle_result = self.bottleneck(full_scale_result)
            # MaxPool3d(2) floors odd sizes, so the up-sampled branch can come
            # back one voxel short along an odd axis; align it with the skip
            # connection before concatenating
            bottle_result = self._match_spatial_size(bottle_result, full_scale_result)
            # final shape (B, 2*C, D, W, H)
            full_scale_result = torch.cat([full_scale_result, bottle_result], dim=1)

        # compute the output on the concatenated result
        return self.out_model(full_scale_result)


    @staticmethod
    def _match_spatial_size(upsampled, reference):
        """Zero-pad (or crop) the trailing edge of ``upsampled`` so that its
        spatial dimensions (axes 2 onwards) equal those of ``reference``."""
        if upsampled.shape[2:] == reference.shape[2:]:
            return upsampled
        # torch.nn.functional.pad expects (left, right) pairs starting from
        # the LAST dimension; a negative amount crops instead of padding
        padding = []
        for current, target in zip(reversed(upsampled.shape[2:]), reversed(reference.shape[2:])):
            padding.extend([0, target - current])
        return torch.nn.functional.pad(upsampled, padding)


In [None]:
# three-level 3D U-Net (32 -> 64 -> 128 channels) followed by a convolution
# that maps the 32 feature channels to 4 output channels
unet3d_model = torch.nn.Sequential(
    UNetBlock3d(4, 32, layers=2,
    sub_network= UNetBlock3d(32, 64, out_channels=32, layers=2,
        sub_network=UNetBlock3d(64, 128, out_channels=64, layers=2)
        ),
    ),
    # a Conv3d kernel needs three spatial extents: the previous (3,3) tuple
    # created a 2D-shaped weight that cannot be applied to 5D volumes
    torch.nn.Conv3d(32, 4, 3, padding=1), # shape (B, 4, D, W, H)
)

In [None]:
def _prepare_for_scoring(labels, outputs):
    """Convert one batch of labels/outputs to arrays usable by score functions.

    * 2D outputs with more than one column are treated as class scores and
      reduced with ``argmax`` over axis 1.
    * Outputs with more than 2 dimensions and exactly one more dimension than
      the labels (e.g. ``(N, C, D, W, H)`` vs ``(N, D, W, H)``) are treated as
      per-element class scores: ``argmax`` over axis 1, then both arrays are
      flattened so that every element counts as one sample.
    * Anything else is passed through unchanged (regression).
    """
    labels = np.atleast_1d(labels)
    outputs = np.atleast_1d(outputs)
    if outputs.ndim == 2 and outputs.shape[1] > 1:
        outputs = np.argmax(outputs, axis=1)
    elif outputs.ndim > 2 and outputs.ndim == labels.ndim + 1:
        outputs = np.argmax(outputs, axis=1).ravel()
        labels = labels.ravel()
    return labels, outputs


def train_evaluate(
    model,              # model's instance
    optimizer,          # optimizer instance
    criterion,          # loss function
    dataloaders,        # dictionary of train and (optional) valid dataloaders
    epochs=10,          # number of epochs
    device='auto',      # device for computations
    score_funcs=None,   # score functions to use from sklearn
    checkpoint='',      # name of the file for the checkpoint
    save_best=False
    ):
    """Train ``model`` and, when a validation loader is given, evaluate it.

    Args:
        model: the ``torch.nn.Module`` to train (moved to ``device`` in place).
        optimizer: optimizer built on ``model``'s parameters.
        criterion: callable ``criterion(outputs, labels)`` returning the loss.
        dataloaders: a single training dataloader, or a dict with a ``'train'``
            entry and an optional ``'valid'`` entry. Other keys are ignored.
        epochs: number of passes over the data.
        device: ``'auto'`` selects ``cuda:0`` when available, else ``cpu``;
            any other value is used as given.
        score_funcs: optional dict ``name -> func(y_true, y_pred)``; every
            function is evaluated once per phase and epoch. A function that
            raises is logged as NaN and a warning is emitted.
        checkpoint: if non-empty, file path where the model/optimizer state
            and the log are saved after every epoch.
        save_best: if true, remember the weights of the epoch with the lowest
            validation loss and load them into ``model`` before returning.
            Has no effect on the weights when there is no ``'valid'`` loader.

    Returns:
        A ``pandas.DataFrame`` indexed by epoch with the cumulative time,
        the mean loss and the scores of every phase. When ``save_best`` is
        true, a tuple ``(log, model)`` is returned instead.

    Raises:
        KeyError: if ``dataloaders`` is a dict without a ``'train'`` entry.
    """
    import warnings

    # never share or mutate a default/caller dictionary
    score_funcs = dict(score_funcs) if score_funcs else {}

    # place the model on the selected device
    if device == 'auto':
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    # create a dictionary dataloader if a single dataloader is passed
    if not isinstance(dataloaders, dict):
        dataloaders = {'train': dataloaders}
    if 'train' not in dataloaders:
        raise KeyError("dataloaders must contain a 'train' entry")

    # phases to run: training always, validation only if a loader was given
    phases = ['train', 'valid'] if 'valid' in dataloaders else ['train']

    # set up the dictionary of items to track
    track_list = ['epoch', 'time']
    track_list.extend(phase + '_loss' for phase in phases)
    track_list.extend(
        phase + '_' + key for key in score_funcs for phase in phases
    )
    train_log = {name: [] for name in track_list}

    # cumulative training time and best-model bookkeeping
    train_time = 0.0
    best_loss = float('inf')    # any finite validation loss improves on this
    best_state = None

    # iterate through epochs
    for epoch in tqdm(range(1, epochs+1)):

        start_time = time.time()
        epoch_losses = {phase: [] for phase in phases}

        # iterate through phases
        for phase in phases:

            is_training = phase == 'train'
            y_true, y_pred = [], []

            # set model to train or eval mode
            model.train(is_training)

            # iterate over batches
            for inputs, labels in tqdm(dataloaders[phase], leave=False):

                # place the data on the selected device
                inputs, labels = inputs.to(device), labels.to(device)

                # enable autograd differentiation only for training
                with torch.set_grad_enabled(is_training):

                    # clear gradients first so stale ones never leak into a step
                    if is_training:
                        optimizer.zero_grad()

                    # forward propagation
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)

                    # backward propagation
                    if is_training:
                        loss.backward()
                        optimizer.step()

                    # logging the loss
                    epoch_losses[phase].append(loss.item())

                # keep labels & predictions for the score functions; they are
                # reduced per batch and stored as arrays to limit memory use
                if score_funcs and isinstance(labels, torch.Tensor):
                    batch_true, batch_pred = _prepare_for_scoring(
                        labels.detach().cpu().numpy(),
                        outputs.detach().cpu().numpy(),
                    )
                    y_true.append(batch_true)
                    y_pred.append(batch_pred)

            if score_funcs:
                y_true = np.concatenate(y_true) if y_true else np.asarray([])
                y_pred = np.concatenate(y_pred) if y_pred else np.asarray([])

            # calculate and log the score functions
            for key, score_func in score_funcs.items():
                try:
                    score = score_func(y_true, y_pred)
                except Exception as error:
                    # keep the run alive, but do not hide the reason
                    warnings.warn(
                        f"score function '{key}' failed during phase "
                        f"'{phase}': {error!r}; logging NaN"
                    )
                    score = float('nan')
                train_log[phase + '_' + key].append(score)

        # stop timer and accumulate the elapsed time
        train_time += time.time() - start_time

        # logging the epochs, time, losses
        train_log['epoch'].append(epoch)
        train_log['time'].append(round(train_time, 2))
        for phase in phases:
            losses = epoch_losses[phase]
            train_log[phase + '_loss'].append(
                float(np.mean(losses)) if losses else float('nan')
            )

        # save a checkpoint
        if checkpoint:
            torch.save({
                'epoch': epoch,
                'model_state': model.state_dict(),
                'optimizer_state': optimizer.state_dict(),
                'log' : train_log
                },
                checkpoint
            )

        # deep copy the weights whenever the validation loss improves
        if save_best and 'valid' in phases:
            current_loss = train_log['valid_loss'][-1]
            if current_loss < best_loss:
                best_loss = current_loss
                best_state = copy.deepcopy(model.state_dict())

    # return a dataframe object with all the logging information
    log_frame = pd.DataFrame.from_dict(train_log).set_index('epoch')
    if save_best:
        if best_state is not None:
            model.load_state_dict(best_state)
        return log_frame, model
    return log_frame

In [None]:
from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score

optimizer = torch.optim.AdamW(unet3d_model.parameters())
criterion = torch.nn.CrossEntropyLoss()
# NOTE(review): with 4 output channels the loss only accepts labels in
# {0, 1, 2, 3}; confirm the segmentation files do not use other label values.

# hold out everything beyond the first 100 randomly chosen cases; the size is
# clamped so smaller datasets still work, and the split is seeded so it is
# reproducible across kernel restarts
n_train = min(100, len(brats))
train_data, test_data = torch.utils.data.random_split(
    brats,
    (n_train, len(brats) - n_train),
    generator=torch.Generator().manual_seed(42),
)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=16, shuffle=True)

train_results = train_evaluate(
    model=unet3d_model,
    optimizer=optimizer,
    # the dataset yields float label volumes, while CrossEntropyLoss expects
    # integer class indices for targets of shape (B, D, W, H)
    criterion=lambda outputs, labels: criterion(outputs, labels.long()),
    dataloaders=train_loader,
    epochs=1,
    score_funcs={
        'Ac': accuracy_score,
        # f1_score defaults to average='binary', which raises on multi-class
        # targets; macro-average over the classes instead
        'F1': lambda y_true, y_pred: f1_score(y_true, y_pred, average='macro'),
    },
)