In [2]:
import os
import math
import copy
import time
import random
import pprint
import tqdm

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch import nn
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F
from torchvision import utils
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torch.utils.tensorboard import SummaryWriter
from sklearn.model_selection import KFold

import cv2
import nibabel as nib
import skimage.transform as skTrans
from numpy import logical_and as l_and, logical_not as l_not
from scipy.spatial.distance import directed_hausdorff

%matplotlib inline

In [3]:
from google.colab import drive
drive.mount('/content/drive')#, force_remount=True)

Mounted at /content/drive


In [4]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

In [5]:
channels = 4
resize_shape = (144,144,144)

# Transformations

In [6]:
class ScaleToFixed(object):

    def __init__(self, new_shape, interpolation=1, channels=4):
        self.shape= new_shape
        self.interpolation = interpolation
        self.channels = channels

    def __call__(self, image):
        # print('first shape', image.shape)
        if image is not None: # (some patients don't have segmentations)
            if self.channels == 1:
                short_shape = (self.shape[1], self.shape[2], self.shape[3])
                image = skTrans.resize(image, short_shape, order=self.interpolation, preserve_range=True)  #
                image = image.reshape(self.shape)
            else:
                image = skTrans.resize(image, self.shape, order=self.interpolation, preserve_range=True)  #

        # print('second shape', image.shape)
        # print()
        return image

class RandomFlip(object):
    """Randomly flips (horizontally as well as vertically) the given PIL.Image with a probability of 0.5
    """
    def __init__(self, prob_flip=0.5):
        self.prob_flip= prob_flip
    def __call__(self, image):

        if random.random() < self.prob_flip:
            flip_type = np.random.randint(0, 3) # flip across any 3D axis
            image = np.flip(image, flip_type)
        return image

class ZeroChannel(object):
    """Randomly sets channel to zero the given PIL.Image with a probability of 0.25
    """
    def __init__(self, prob_zero=0.25, channels=4):
        self.prob_zero= prob_zero
        self.channels = channels
    def __call__(self, image):

        if np.random.random() < self.prob_zero:
            channel_to_zero = np.random.randint(0, self.channels) # flip across any 3D axis
            zeros = np.zeros((image.shape[1], image.shape[2], image.shape[3]))
            image[channel_to_zero, :, :, :] = zeros
        return image

class ZeroSprinkle(object):
    def __init__(self, prob_zero=0.25, prob_true=0.5, channels=4):
        self.prob_zero=prob_zero
        self.prob_true=prob_true
        self.channels=channels
    def __call__(self, image):

        if self.prob_true:
            mask = np.random.rand(image.shape[0], image.shape[1], image.shape[2], image.shape[3])
            mask[mask < self.prob_zero] = 0
            mask[mask > 0] = 1
            image = image*mask

        return image


class MinMaxNormalize(object):
    """Min-Max normalization
    """
    def __call__(self, image):
        def norm(im):
            im = im.astype(np.float32)
            min_v = np.min(im)
            max_v = np.max(im)
            im = (im - min_v)/(max_v - min_v)
            return im
        image = norm(image)
        return image

class ToTensor(object):
    def __init__(self, scale=1):
        self.scale = scale

    def __call__(self, image):
        if image is not None:
            image = image.astype(np.float32)
            image = image.reshape((image.shape[0], int(image.shape[1]/self.scale), int(image.shape[2]/self.scale), int(image.shape[3]/self.scale)))
            image_tensor = torch.from_numpy(image)
            return image_tensor
        else:
            return image


class Compose(object):
    """
    Composes several transforms together.
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image):
        for i, t in enumerate(self.transforms):
            image = t(image)
        return image

In [7]:

# basic data augmentation
prob_voxel_zero = 0 # 0.1
prob_channel_zero = 0 # 0.5
prob_true = 0 # 0.8
randomflip = RandomFlip()

# MRI transformations
train_transformations = Compose([
    MinMaxNormalize(),
    ScaleToFixed((channels, resize_shape[0],resize_shape[1],resize_shape[2]),
                          interpolation=1,
                          channels=channels),
    ZeroSprinkle(prob_zero=prob_voxel_zero, prob_true=prob_true),
    ZeroChannel(prob_zero=prob_channel_zero),
    randomflip,
    ToTensor()
])

# GT segmentation mask transformations

seg_transformations = Compose([
            ScaleToFixed((1, resize_shape[0],resize_shape[1],resize_shape[2]),
                                      interpolation=0,
                                      channels=1),
            randomflip,
            ToTensor(),
        ])

# Dataloader

In [8]:
def get_bb_3D(img, pad=0):
    '''
    This function returns a tumor 3D bounding box using a segmentation mask
    '''
    xs = np.nonzero(np.sum(np.sum(img, axis=1), axis=1))
    ys = np.nonzero(np.sum(np.sum(img, axis=0), axis=1))
    zs = np.nonzero(np.sum(np.sum(img, axis=0), axis=0))
    xmin, xmax = np.min(xs), np.max(xs)
    ymin, ymax = np.min(ys), np.max(ys)
    zmin, zmax = np.min(zs), np.max(zs)
    bbox = (xmin-pad, ymin-pad, zmin-pad, xmax+pad, ymax+pad, zmax+pad)
    return bbox

def min_max(img):
    '''
    Min-max normalization
    '''
    return (img - img.min()) / (img.max() - img.min())

def read_mri(mr_path_dict, pad=0):

    image_shape = nib.load(mr_path_dict['flair']).get_fdata().shape
    bb_seg = get_bb_3D(nib.load(mr_path_dict['flair']).get_fdata())
    (xmin, ymin, zmin, xmax, ymax, zmax) = bb_seg

    xmin = np.max([0, xmin-pad])
    ymin = np.max([0, ymin-pad])
    zmin = np.max([0, zmin-pad])

    xmax = np.min([image_shape[0]-1, xmax+pad])
    ymax = np.min([image_shape[1]-1, ymax+pad])
    zmax = np.min([image_shape[2]-1, zmax+pad])


    img_dict = {}
    for key in ['flair', 't1', 't1ce', 't2', 'seg']:
        img = nib.load(mr_path_dict[key])
        img_data = img.get_fdata()
        img_dict[key] = img_data[xmin:xmax, ymin:ymax, zmin:zmax]

    stacked_img = np.stack([min_max(img_dict['flair']), min_max(img_dict['t1']),min_max(img_dict['t1ce']),min_max(img_dict['t2'])], axis=0)
    return stacked_img, img_dict['seg']


In [9]:
def plot_(image, seg, predicted=False):
    #Overlay with Predicted
    img = image[slice, :, :, :].squeeze()
    img = utils.make_grid(img)
    img = img.detach().cpu().numpy()
    
    print(img.shape)
    
    # plot images
    plt.figure(figsize=(10, 8))
    img_list = [img[i].T for i in range(channels)] # 1 image per channel
    plt.imshow(np.hstack(img_list), cmap='Greys_r')
    
    ## plot segmentation mask ##
    seg_img = torch.tensor(pred[slice].squeeze())
    if not predicted:
        seg_img = torch.tensor(seg_img.numpy()[:, ::-1].copy()) #flip
    seg_img = utils.make_grid(seg_img).detach().cpu().numpy()
    
    print(np.unique(seg_img))

    plt.imshow(np.hstack([seg_img[0].T]), cmap='Greys_r', alpha=0.3)
    plt.show()
    

In [10]:
def calculate_metrics(preds, targets, patient, tta=False):
    """
    Parameters
    ----------
    preds:
        torch tensor of size 1*C*Z*Y*X
    targets:
        torch tensor of same shape
    patient :
        The patient ID
    tta:
        is tta performed for this run
    """
    pp = pprint.PrettyPrinter(indent=4)
    assert preds.shape == targets.shape, "Preds and targets do not have the same size"

    labels = ["ET", "TC", "WT"]

    metrics_list = []

    for i, label in enumerate(labels):
        metrics = dict(
            patient_id=patient,
            label=label,
            tta=tta,
        )

        if np.sum(targets[i]) == 0:
            print(f"{label} not present for {patient}")
            sens = np.nan
            dice = 1 if np.sum(preds[i]) == 0 else 0
            tn = np.sum(l_and(l_not(preds[i]), l_not(targets[i])))
            fp = np.sum(l_and(preds[i], l_not(targets[i])))
            spec = tn / (tn + fp)
            haussdorf_dist = np.nan

        else:
            preds_coords = np.argwhere(preds[i])
            targets_coords = np.argwhere(targets[i])
            haussdorf_dist = directed_hausdorff(preds_coords, targets_coords)[0]

            tp = np.sum(l_and(preds[i], targets[i]))
            tn = np.sum(l_and(l_not(preds[i]), l_not(targets[i])))
            fp = np.sum(l_and(preds[i], l_not(targets[i])))
            fn = np.sum(l_and(l_not(preds[i]), targets[i]))

            sens = tp / (tp + fn)
            spec = tn / (tn + fp)

            dice = 2 * tp / (2 * tp + fp + fn)

        metrics[HAUSSDORF] = haussdorf_dist
        metrics[DICE] = dice
        metrics[SENS] = sens
        metrics[SPEC] = spec
        pp.pprint(metrics)
        metrics_list.append(metrics)

    return metrics_list


HAUSSDORF = "haussdorf"
DICE = "dice"
SENS = "sens"
SPEC = "spec"
METRICS = [HAUSSDORF, DICE, SENS, SPEC]


In [11]:
class GeneralDataset(Dataset):

    def __init__(self,
                metadata_df,
                root_dir,
                transform=None,
                seg_transform=None, ###
                dataformat=None, # indicates what shape (or content) should be returned (2D or 3D, etc.)
                returndims=None, # what size/shape 3D volumes should be returned as.
                visualize=False,
                modality=None,
                pad=2,
                device='cpu'):
        """
        Args:
            metadata_df (string): Path to the csv file w/ patient IDs
            root_dir (string): Directory for MR images
            transform (callable, optional)
        """
        self.device=device
        self.metadata_df = metadata_df
        self.root_dir = root_dir
        self.transform = transform
        self.seg_transform = seg_transform
        self.returndims=returndims
        self.modality = modality
        self.pad = pad


    def __len__(self):
        return len(self.metadata_df)

    def __getitem__(self, idx):
        #print(type(idx), idx)
        if torch.is_tensor(idx):
            idx = idx.tolist()

        BraTS20ID = self.metadata_df.iloc[idx].BraTS_2020_subject_ID

        # make dictonary of paths to MRI volumnes (modalities) and segmenation masks
        mr_path_dict = {}
        sequence_type = ['seg', 't1', 't1ce', 'flair', 't2']
        for seq in sequence_type:
            mr_path_dict[seq] = os.path.join(self.root_dir, BraTS20ID, BraTS20ID + '_'+seq+'.nii.gz')

        image, seg_image = read_mri(mr_path_dict=mr_path_dict, pad=self.pad)
        
        if seg_image is not None:
            seg_image[seg_image == 4] = 3

        if self.transform:
            image = self.transform(image)
        if self.seg_transform:
            seg_image = self.seg_transform(seg_image)
        else:
            print('no transform')
        # print(image.shape)
        return (image, seg_image), BraTS20ID

In [12]:
# Set random seed for reproduciablity
torch.manual_seed(42)
random.seed(42)


In [13]:

class DiceLoss(nn.Module):
    def __init__(self, epsilon=1e-5):
        super(DiceLoss, self).__init__()
        # smooth factor
        self.epsilon = epsilon

    def forward(self, targets, logits):
        batch_size = targets.size(0)
        # log_prob = torch.sigmoid(logits)
        logits = logits.view(batch_size, -1).type(torch.FloatTensor)
        targets = targets.view(batch_size, -1).type(torch.FloatTensor)
        intersection = (logits * targets).sum(-1)
        dice_score = 2. * intersection / ((logits + targets).sum(-1) + self.epsilon)
        # dice_score = 1 - dice_score.sum() / batch_size
        return torch.mean(1. - dice_score)

In [14]:

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, k_size=3, stride=1, padding=1):
        super(ConvBlock, self).__init__()
        self.conv3d = nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=k_size,
                                stride=stride, padding=padding)
        self.batch_norm = nn.BatchNorm3d(num_features=out_channels)

    def forward(self, x):
        x = self.batch_norm(self.conv3d(x))
        # x = self.conv3d(x)
        x = F.elu(x)
        return x


class EncoderBlock(nn.Module):
    def __init__(self, in_channels, init_features, model_depth=4, pool_size=2):
        super(EncoderBlock, self).__init__()
        self.root_feat_maps = init_features
        self.num_conv_blocks = 2
        self.module_dict = nn.ModuleDict()
        for depth in range(model_depth):
            feat_map_channels = 2 ** (depth + 1) * self.root_feat_maps
            for i in range(self.num_conv_blocks):
                self.conv_block = ConvBlock(in_channels=in_channels, out_channels=feat_map_channels)
                self.module_dict["conv_{}_{}".format(depth, i)] = self.conv_block
                in_channels, feat_map_channels = feat_map_channels, feat_map_channels * 2
            if depth == model_depth - 1:
                break
            else:
                self.pooling = nn.MaxPool3d(kernel_size=pool_size, stride=2, padding=0)
                self.module_dict["max_pooling_{}".format(depth)] = self.pooling

    def forward(self, x):
        down_sampling_features = []
        for k, op in self.module_dict.items():
            if k.startswith("conv"):
                x = op(x)
                #print(k, x.shape)
                if k.endswith("1"):
                    down_sampling_features.append(x)
            elif k.startswith("max_pooling"):
                x = op(x)
                #print(k, x.shape)

        return x, down_sampling_features


class ConvTranspose(nn.Module):
    def __init__(self, in_channels, out_channels, k_size=3, stride=2, padding=1, output_padding=1):
        super(ConvTranspose, self).__init__()
        self.conv3d_transpose = nn.ConvTranspose3d(in_channels=in_channels,
                                                   out_channels=out_channels,
                                                   kernel_size=k_size,
                                                   stride=stride,
                                                   padding=padding,
                                                   output_padding=output_padding)

    def forward(self, x):
        return self.conv3d_transpose(x)


class DecoderBlock(nn.Module):
    def __init__(self, out_channels, init_features, model_depth=4):
        super(DecoderBlock, self).__init__()
        self.num_conv_blocks = 2
        self.num_feat_maps = init_features
        # user nn.ModuleDict() to store ops
        self.module_dict = nn.ModuleDict()

        for depth in range(model_depth - 2, -1, -1):
            # print(depth)
            feat_map_channels = 2 ** (depth + 1) * self.num_feat_maps
            # print(feat_map_channels * 4)
            self.deconv = ConvTranspose(in_channels=feat_map_channels * 4, out_channels=feat_map_channels * 4)
            self.module_dict["deconv_{}".format(depth)] = self.deconv
            for i in range(self.num_conv_blocks):
                if i == 0:
                    self.conv = ConvBlock(in_channels=feat_map_channels * 6, out_channels=feat_map_channels * 2)
                    self.module_dict["conv_{}_{}".format(depth, i)] = self.conv
                else:
                    self.conv = ConvBlock(in_channels=feat_map_channels * 2, out_channels=feat_map_channels * 2)
                    self.module_dict["conv_{}_{}".format(depth, i)] = self.conv
            if depth == 0:
                self.final_conv = ConvBlock(in_channels=feat_map_channels * 2, out_channels=out_channels)
                self.module_dict["final_conv"] = self.final_conv

    def forward(self, x, down_sampling_features):
        """
        :param x: inputs
        :param down_sampling_features: feature maps from encoder path
        :return: output
        """
        for k, op in self.module_dict.items():
            if k.startswith("deconv"):
                x = op(x)
                #print(k, x.shape)
                x = torch.cat((down_sampling_features[int(k[-1])], x), dim=1)
            elif k.startswith("conv"):
                x = op(x)
                #print(k, x.shape)
            else:
                x = op(x)
                #print(k, x.shape)
        return x


In [26]:
class UnetModel(nn.Module):

    def __init__(self, in_channels, out_channels, init_features, model_depth=4, final_activation="sigmoid"):
        super(UnetModel, self).__init__()
        self.encoder = EncoderBlock(in_channels=in_channels,
                                    init_features=init_features,
                                    model_depth=model_depth)
        self.decoder = DecoderBlock(out_channels=out_channels,
                                    init_features=init_features,
                                    model_depth=model_depth)
        if final_activation == "sigmoid":
            self.sigmoid = nn.Sigmoid()
        else:
            self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x, downsampling_features = self.encoder(x)
        x = self.decoder(x, downsampling_features)
        x = self.sigmoid(x)
        # print("Final output shape: ", x.shape)
        return x



In [27]:
def kFoldRun(k_folds, num_epochs, train_batch_size, train_data, validation_data, network, criterion, optim, use_cuda=True):
    torch.manual_seed(42)
    
    if use_cuda:
        network = network.cuda()

    loss_function = criterion

    dataset = ConcatDataset([train_data, validation_data])

    # Define the K-fold Cross Validator
    kfold = KFold(n_splits=k_folds, shuffle=True)

    # Start print
    print('--------------------------------')

    # K-fold Cross Validation model evaluation
    for fold, (train_ids, test_ids) in enumerate(kfold.split(dataset)):
        # Print
        print(f'FOLD {fold}')
        print('--------------------------------')
    
        # Sample elements randomly from a given list of ids, no replacement.
        train_subsampler = torch.utils.data.SubsetRandomSampler(train_ids)
        test_subsampler = torch.utils.data.SubsetRandomSampler(test_ids)

        # Define data loaders for training and testing data in this fold
        dataloader_train = DataLoader(dataset, batch_size=train_batch_size,sampler=train_subsampler,num_workers=0)
        dataloader_valid = DataLoader(dataset, batch_size=train_batch_size,sampler=test_subsampler, num_workers=0)

        # Initialize optimizer
        optimizer = torch.optim.Adam(network.parameters(), lr=1e-4)

        # Run the training loop for defined number of epochs
        for epoch in range(0, num_epochs):
            # Print epoch
            print(f'Starting epoch {epoch+1}')
            start_time = time.time()

            # Set current loss value
            current_loss = 0.0

            # Iterate over the DataLoader for training data
            if dataloader_train is None or optimizer is None:
                print('None')
                break  # NotImplementedError
            for i, data in enumerate(tqdm.tqdm(dataloader_train)):
                # Get inputs
                (inputs, targets), ID = data
                #inputs = torch.squeeze(torch.permute(image, (0, 4, 1, 2, 3))) 
                #label = torch.squeeze(torch.permute(seg_image, (0, 4, 1, 2, 3))) 
                if use_cuda:
                    inputs, targets = inputs.cuda(), targets.cuda() # add this line
                # Zero the gradients
                optimizer.zero_grad()
                # Perform forward pass
                outputs = network(inputs)
                print(outputs.shape, targets.shape)
                # Compute loss
                loss = loss_function(outputs, targets.squeeze(1).long())
                print('Loss:', loss.item())
                # Perform backward pass
                loss.backward()

                # Perform optimization
                optimizer.step()

                # Print statistics
                current_loss += loss.item()
                if i % 500 == 499:
                    print('Loss after mini-batch %5d: %.3f' %
                          (i + 1, current_loss / 500))
                    current_loss = 0.0
            end_time = time.time()
            print(f"Epoch Time: {end_time - start_time}")
    # Process is complete.
    print('Training process has finished. Saving trained model.')
    
    # Saving the model
    save_path = f'./model-fold-{fold}.pth'
    torch.save(network.state_dict(), 'drive/MyDrive/Colab Notebooks/')

    # Print about testing
    print('Starting testing')
    # Evaluationfor this fold
    correct, total = 0, 0
    with torch.no_grad():
        # Iterate over the test data and generate predictions
        for i, data in enumerate(dataloader_valid, 0):
            # Get inputs
            inputs, targets = data
            # inputs = torch.squeeze(torch.permute(image, (0, 4, 1, 2, 3))) # 
            # label = torch.squeeze(torch.permute(seg_image, (0, 4, 1, 2, 3))) #, 0) 
            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda() # add this line

            # Generate outputs
            outputs = network(inputs)

            # Set total and correct
            _, predicted = torch.max(outputs.data, 1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

        # Print accuracy
        print('Accuracy for fold %d: %d %%' % (fold, 100.0 * correct / total))
        print('--------------------------------')
        results[fold] = 100.0 * (correct / total)
    
    # Print fold results
    print(f'K-FOLD CROSS VALIDATION RESULTS FOR {k_folds} FOLDS')
    print('--------------------------------')
    sum = 0.0
    for key, value in results.items():
        print(f'Fold {key}: {value} %')
        sum += value
    print(f'Average: {sum/len(results.items())} %')

    return results

In [39]:
def train_epochs(network, use_cuda, dataloader_train, loss_function, optimizer, num_epochs):
    ################
    history = {'train_loss': [], 'train_acc':[]}
    for epoch in range(num_epochs):
        print(f'Starting Train epoch: {epoch+1}')
        train_loss = 0.0
        train_correct = 0

        for i, data in enumerate(tqdm.tqdm(dataloader_train)):
            (inputs, targets), ID = data
            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda()
            optimizer.zero_grad()
            outputs = network(inputs)
            loss = loss_function(outputs, targets.squeeze(1).long())
            print('Train Loss:', loss.item())

            train_loss += loss.item() * outputs.size(0) #multiplying by batchsize

            _, predictions = torch.max(outputs.data, 1) #change
            train_correct += (predictions == targets).sum().item()

            loss.backward()
            optimizer.step()
            # Print statistics
           
        history['train_loss'].append(train_loss / len(dataloader_train.sampler))
        history['train_acc'].append(train_correct / len(dataloader_train.sampler))

        print(f"Epoch loss: {history['train_loss'][-1]}")

    return history['train_loss'][-1], history['train_acc'][-1], history


def valid_epoch(network, use_cuda, dataloader_valid, loss_function):
    valid_loss = 0.0
    correct, total = 0, 0
    with torch.no_grad():
        # Iterate over the test data and generate predictions
        for i, data in enumerate(dataloader_valid, 0):
            inputs, targets = data
            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda() 
            outputs = network(inputs)
            loss = loss_function(outputs, targets.squeeze(1).long())
            print('Valid Loss:', loss.item())
            #################
            valid_loss += loss.item() * inputs.size(0)

            _, predicted = torch.max(outputs.data, 1)
            correct += (predicted == targets).sum().item()
            #################
        # Print accuracy
        valid_loss /= len(dataloader_valid.sampler) 
        valid_acc = 100.0 * (correct / len(dataloader_valid.sampler))
        print(f" Fold Accuracy: {valid_acc}")

        if valid < best:
            best = validation_loss
            model_dict = model.state_dict()
            save_checkpoint(
                dict(
                    epoch=epoch, arch=args.arch,
                    state_dict=model_dict,
                    optimizer=optimizer.state_dict(),
                    scheduler=scheduler.state_dict(),
                ),
                save_folder=args.save_folder, )

    return valid_loss, valid_acc




In [70]:

def save_model(model, optimizer, fold, epoch, loss):
    # Saving the model
    save_path = f'model-fold-{fold}.pth'

    checkpoint = {'epoch': epoch,
                  'model_state_dict': model.state_dict(),
                  'optimizer_state_dict': optimizer.state_dict(),
                  'loss': loss,
                  }
    torch.save(checkpoint, save_path)


In [71]:
def train_test_epochs(model, loss_function, optimizer, dataloader_train, dataloader_valid, fold, num_epochs, use_cuda):
    train_history = {'train_loss': [], 'train_acc':[]}
    valid_history = {'valid_loss': [], 'valid_acc':[]}
    best = math.inf

    for epoch in range(num_epochs):
        print(f'Starting Train epoch: {epoch+1}')

        train_loss = 0.0
        train_correct = 0

        for i, data in enumerate(tqdm.tqdm(dataloader_train)):
            (inputs, targets), ID = data
            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda()
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, targets.squeeze(1).long())
            print('Train Loss:', loss.item())

            train_loss += loss.item() * outputs.size(0) #multiplying by batchsize

            _, predictions = torch.max(outputs.data, 1) #change
            train_correct += (predictions == targets).sum().item()

            loss.backward()
            optimizer.step()
           
        train_history['train_loss'].append(train_loss / len(dataloader_train.sampler))
        train_history['train_acc'].append(train_correct / len(dataloader_train.sampler))

        print(f"Train Epoch loss: {train_history['train_loss'][-1]}")

        valid_loss = 0.0
        correct, total = 0, 0

        with torch.no_grad():
            # Iterate over the test data and generate predictions
            for i, data in enumerate(tqdm.tqdm(dataloader_valid)):
                (inputs, targets), ID = data
                if use_cuda:
                    inputs, targets = inputs.cuda(), targets.cuda() 
                outputs = model(inputs)
                loss = loss_function(outputs, targets.squeeze(1).long())
                print('Valid Loss:', loss.item())
                #################
                valid_loss += loss.item() * inputs.size(0)
                _, predicted = torch.max(outputs.data, 1)
                correct += (predicted == targets).sum().item()
                #################
            # Print accuracy
            valid_loss /= len(dataloader_valid.sampler) 
            valid_acc = 100.0 * (correct / len(dataloader_valid.sampler))
            print(f" Fold Accuracy: {valid_acc}")

        valid_history['valid_loss'].append(valid_loss)
        valid_history['valid_acc'].append(valid_acc)

        print(f"Val Epoch loss: {valid_history['valid_loss'][-1]}")

        # saving best model for this fold
        if valid_loss < best:
            best = valid_loss
            save_model(model, optimizer, fold, epoch, loss)

    return train_history['train_loss'][-1], train_history['train_acc'][-1], train_history, 
    valid_history['valid_loss'][-1], valid_history['valid_acc'][-1], valid_history

In [65]:

kfold = KFold(n_splits=2, shuffle=True)
dataset = ConcatDataset([transformed_dataset_train, transformed_dataset_valid])

train_ids, test_ids =  next(iter(kfold.split(dataset)))
print(test_ids)
print('--------------------------------')
# Sample elements randomly from a given list of ids, no replacement.
train_subsampler = torch.utils.data.SubsetRandomSampler(train_ids)
test_subsampler = torch.utils.data.SubsetRandomSampler(test_ids)

# Define data loaders for training and testing data in this fold
dataloader_train = DataLoader(dataset, batch_size=2, sampler=train_subsampler, num_workers=0)
dataloader_valid = DataLoader(dataset, batch_size=2, sampler=test_subsampler, num_workers=0)




[ 1  2  5  8 10 11 14 16 18 19 20 23 25 26 27 28 33 35 36 39]
--------------------------------


In [66]:
with torch.no_grad():
    # Iterate over the test data and generate predictions
    (h, k), ID = next(iter(dataloader_train))
    print(type(h), type(k))
      

<class 'torch.Tensor'> <class 'torch.Tensor'>


In [72]:
def kFoldRunAll(model, criterion, optim, train_data, validation_data, k_folds, num_epochs, train_batch_size, use_cuda):
    torch.manual_seed(42)
    if use_cuda:
        model = model.cuda()

    loss_function = criterion
    dataset = ConcatDataset([train_data, validation_data])
    # Define the K-fold Cross Validator
    kfold = KFold(n_splits=k_folds, shuffle=True)

    fold_train_history = {}
    fold_valid_history = {}
    fold_train_and_valid_acc = {}
    fold_train_and_valid_loss = {}

    print('--------------------------------')
    # K-fold Cross Validation model evaluation
    for fold, (train_ids, test_ids) in enumerate(kfold.split(dataset)):
        print(f'FOLD {fold}')
        print('--------------------------------')
        # Sample elements randomly from a given list of ids, no replacement.
        train_subsampler = torch.utils.data.SubsetRandomSampler(train_ids)
        test_subsampler = torch.utils.data.SubsetRandomSampler(test_ids)

        # Define data loaders for training and testing data in this fold
        dataloader_train = DataLoader(dataset, batch_size=train_batch_size, sampler=train_subsampler, num_workers=0)
        dataloader_valid = DataLoader(dataset, batch_size=train_batch_size, sampler=test_subsampler, num_workers=0)

        # Initialize optimizer
        optimizer = optim

        # Run the training, testing and saving loop for defined number of epochs
        start_time = time.time()

        t_loss, t_acc, t_history, v_loss, v_acc, v_history = train_test_epochs(model, loss_function, optimizer,
                                                                               dataloader_train, dataloader_valid,
                                                                               fold, num_epochs, use_cuda)

        end_time = time.time()
        print(f"Epoch Time: {end_time - start_time}")

        #Saving loss results 
        fold_train_and_valid_loss[str(fold)] = [t_loss, v_loss]
        fold_train_and_valid_acc[str(fold)] = [t_acc, v_acc]
        fold_train_history[str(fold)] = t_history
        fold_valid_history[str(fold)] = v_history

        # Print accuracy
        print(f'Accuracy for fold {fold}: {v_acc}')
        print(f'Loss for fold {fold}: {v_loss}')
        print('--------------------------------')  

    return fold_train_history, fold_valid_history, fold_train_and_valid_loss, fold_train_and_valid_acc
  


In [73]:


def plot_metric(train, valid, metric_name):
    # Plot losses
    plt.figure(figsize=(10,8))
    plt.semilogy(train, label='Train')
    plt.semilogy(valid, label='Valid')
    plt.xlabel('Epoch')
    plt.ylabel(metric_name)
    plt.legend()
    plt.title(f'Model {metric_name} Plot')
    plt.show()
    plt.clf()

def plot_result(num_epochs, fold_train_history, fold_valid_history):
    final_fold = {'train_loss':[],'valid_loss':[],'train_acc':[],'valid_acc':[]}

    for epoch in range(num_epochs):                                      
        final_fold['train_loss'].append(np.mean([fold_train_history[str(fold)]['train_loss'][epoch] for fold in range(kfolds)]))
        final_fold['valid_loss'].append(np.mean([fold_valid_history[str(fold)]['valid_loss'][epoch]for fold in range(kfolds)]))
        final_fold['train_acc'].append(np.mean([fold_train_history[str(fold)]['train_acc'][epoch]for fold in range(kfolds)]))
        final_fold['valid_acc'].append(np.mean([fold_valid_history[str(fold)]['valid_acc'][epoch]for fold in range(kfolds)]))

    plot_metric(final_fold['train_loss'], final_fold['valid_loss'], 'Loss')
    plot_metric(final_fold['train_acc'], final_fold['valid_acc'], 'Accuracy')


In [74]:

def train_main(train, valid, in_channels, out_channels, init_features, learning_rate, k_folds, no_epochs, train_batch_size):
    """
    Train module
    :param data_folder: data folder
    :param in_channels: the input channel of input images
    :param out_channels: the final output channel
    :param learning_rate: set learning rate for training
    :param no_epochs: number of epochs to train model
    :return: None
    """
    model = UnetModel(in_channels=in_channels, out_channels=out_channels,
                      init_features=init_features)
    optim = torch.optim.Adam(params=model.parameters(), lr=learning_rate)
    #criterion = DiceLoss()
    criterion = CrossEntropyLoss()
    use_cuda = torch.cuda.is_available()
    t_history, v_history, tv_loss, tv_acc = kFoldRunAll(model, criterion, optim,
                                                        train, valid,
                                                        k_folds, no_epochs,
                                                        train_batch_size,
                                                        use_cuda)
    
    plot_result(no_epochs, t_history, v_history)

# Data

In [75]:

image_dir = 'drive/MyDrive/Colab Notebooks/MICCAI_BraTS2020_TrainingData/'
naming = pd.read_csv(f'drive/MyDrive/Colab Notebooks/MICCAI_BraTS2020_TrainingData/name_mapping.csv')

data_df = pd.DataFrame(naming['BraTS_2020_subject_ID'])
n_p = 20 # n_patients_to_train_with
train_df = data_df[:n_p]
valid_df = data_df[n_p:n_p*2]

In [76]:
transformed_dataset_train = GeneralDataset(metadata_df=train_df, 
                                           root_dir=image_dir,
                                           transform=train_transformations,
                                           seg_transform=seg_transformations,
                                           returndims=resize_shape)

transformed_dataset_valid = GeneralDataset(metadata_df=valid_df, 
                                           root_dir=image_dir,
                                           transform=train_transformations,
                                           seg_transform=seg_transformations,
                                           returndims=resize_shape)

In [77]:
channels = 4
resize_shape = (144,144,144)

In [78]:
train_main(transformed_dataset_train, transformed_dataset_valid, in_channels=4,
           out_channels=4, init_features=4, learning_rate=1e-4, k_folds=2,
           no_epochs=3, train_batch_size=2)

--------------------------------
FOLD 0
--------------------------------
Starting Train epoch: 1


 10%|█         | 1/10 [00:11<01:47, 11.92s/it]

Train Loss: 1.3912442922592163


 20%|██        | 2/10 [00:23<01:33, 11.73s/it]

Train Loss: 1.3883370161056519


 30%|███       | 3/10 [00:35<01:21, 11.70s/it]

Train Loss: 1.3818974494934082


 40%|████      | 4/10 [00:46<01:09, 11.62s/it]

Train Loss: 1.3788005113601685


 50%|█████     | 5/10 [00:58<00:58, 11.63s/it]

Train Loss: 1.3725768327713013


 60%|██████    | 6/10 [01:09<00:46, 11.54s/it]

Train Loss: 1.3752583265304565


 70%|███████   | 7/10 [01:21<00:34, 11.53s/it]

Train Loss: 1.370758056640625


 80%|████████  | 8/10 [01:32<00:23, 11.53s/it]

Train Loss: 1.3752810955047607


 90%|█████████ | 9/10 [01:44<00:11, 11.55s/it]

Train Loss: 1.3709559440612793


100%|██████████| 10/10 [01:55<00:00, 11.59s/it]


Train Loss: 1.3721274137496948
Train Epoch loss: 1.3777236938476562


 10%|█         | 1/10 [00:11<01:42, 11.44s/it]

Valid Loss: 1.3647671937942505


 20%|██        | 2/10 [00:23<01:32, 11.55s/it]

Valid Loss: 1.368791103363037


 30%|███       | 3/10 [00:34<01:20, 11.52s/it]

Valid Loss: 1.361297845840454


 40%|████      | 4/10 [00:45<01:08, 11.45s/it]

Valid Loss: 1.3694045543670654


 50%|█████     | 5/10 [00:57<00:57, 11.53s/it]

Valid Loss: 1.3709602355957031


 60%|██████    | 6/10 [01:09<00:46, 11.66s/it]

Valid Loss: 1.3612005710601807


 70%|███████   | 7/10 [01:21<00:35, 11.68s/it]

Valid Loss: 1.3604273796081543


 80%|████████  | 8/10 [01:32<00:23, 11.67s/it]

Valid Loss: 1.372279405593872


 90%|█████████ | 9/10 [01:44<00:11, 11.62s/it]

Valid Loss: 1.3640974760055542


100%|██████████| 10/10 [01:56<00:00, 11.62s/it]


Valid Loss: 1.3690590858459473
 Fold Accuracy: 282018200.0
Val Epoch loss: 1.3662284851074218
Starting Train epoch: 2


 10%|█         | 1/10 [00:11<01:45, 11.74s/it]

Train Loss: 1.3652409315109253


 20%|██        | 2/10 [00:23<01:34, 11.86s/it]

Train Loss: 1.367870807647705


 30%|███       | 3/10 [00:35<01:23, 11.86s/it]

Train Loss: 1.3649587631225586


 40%|████      | 4/10 [00:47<01:11, 11.85s/it]

Train Loss: 1.3621537685394287


 50%|█████     | 5/10 [00:58<00:58, 11.75s/it]

Train Loss: 1.355141520500183


 60%|██████    | 6/10 [01:10<00:46, 11.75s/it]

Train Loss: 1.36250638961792


 70%|███████   | 7/10 [01:22<00:35, 11.74s/it]

Train Loss: 1.3549476861953735


 80%|████████  | 8/10 [01:34<00:23, 11.71s/it]

Train Loss: 1.3590283393859863


 90%|█████████ | 9/10 [01:45<00:11, 11.64s/it]

Train Loss: 1.3567270040512085


100%|██████████| 10/10 [01:57<00:00, 11.73s/it]


Train Loss: 1.3590023517608643
Train Epoch loss: 1.3607577562332154


 10%|█         | 1/10 [00:11<01:44, 11.57s/it]

Valid Loss: 1.3537724018096924


 20%|██        | 2/10 [00:23<01:33, 11.66s/it]

Valid Loss: 1.3566924333572388


 30%|███       | 3/10 [00:34<01:21, 11.61s/it]

Valid Loss: 1.3565257787704468


 40%|████      | 4/10 [00:46<01:09, 11.62s/it]

Valid Loss: 1.3544209003448486


 50%|█████     | 5/10 [00:57<00:57, 11.58s/it]

Valid Loss: 1.3520203828811646


 60%|██████    | 6/10 [01:09<00:46, 11.60s/it]

Valid Loss: 1.358034610748291


 70%|███████   | 7/10 [01:21<00:34, 11.58s/it]

Valid Loss: 1.3539916276931763


 80%|████████  | 8/10 [01:32<00:22, 11.48s/it]

Valid Loss: 1.3550121784210205


 90%|█████████ | 9/10 [01:43<00:11, 11.47s/it]

Valid Loss: 1.3458833694458008


100%|██████████| 10/10 [01:55<00:00, 11.57s/it]


Valid Loss: 1.3518940210342407
 Fold Accuracy: 292705225.0
Val Epoch loss: 1.3538247704505921
Starting Train epoch: 3


 10%|█         | 1/10 [00:11<01:46, 11.82s/it]

Train Loss: 1.3526220321655273


 20%|██        | 2/10 [00:23<01:33, 11.68s/it]

Train Loss: 1.3503224849700928


 30%|███       | 3/10 [00:35<01:21, 11.69s/it]

Train Loss: 1.3529770374298096


 40%|████      | 4/10 [00:46<01:10, 11.70s/it]

Train Loss: 1.3478285074234009


 50%|█████     | 5/10 [00:58<00:58, 11.69s/it]

Train Loss: 1.345108985900879


 60%|██████    | 6/10 [01:09<00:46, 11.62s/it]

Train Loss: 1.345410943031311


 70%|███████   | 7/10 [01:21<00:35, 11.68s/it]

Train Loss: 1.3491337299346924


 80%|████████  | 8/10 [01:33<00:23, 11.62s/it]

Train Loss: 1.3509191274642944


 90%|█████████ | 9/10 [01:44<00:11, 11.64s/it]

Train Loss: 1.357401967048645


100%|██████████| 10/10 [01:56<00:00, 11.65s/it]


Train Loss: 1.3469284772872925
Train Epoch loss: 1.3498653292655944


 10%|█         | 1/10 [00:11<01:45, 11.72s/it]

Valid Loss: 1.3445147275924683


 20%|██        | 2/10 [00:23<01:33, 11.75s/it]

Valid Loss: 1.3524267673492432


 30%|███       | 3/10 [00:34<01:21, 11.62s/it]

Valid Loss: 1.344847321510315


 40%|████      | 4/10 [00:46<01:09, 11.65s/it]

Valid Loss: 1.346929907798767


 50%|█████     | 5/10 [00:58<00:57, 11.59s/it]

Valid Loss: 1.3496114015579224


 60%|██████    | 6/10 [01:09<00:46, 11.60s/it]

Valid Loss: 1.3422654867172241


 70%|███████   | 7/10 [01:21<00:34, 11.62s/it]

Valid Loss: 1.3443429470062256


 80%|████████  | 8/10 [01:32<00:23, 11.50s/it]

Valid Loss: 1.3480671644210815


 90%|█████████ | 9/10 [01:44<00:11, 11.46s/it]

Valid Loss: 1.3449431657791138


100%|██████████| 10/10 [01:55<00:00, 11.56s/it]

Valid Loss: 1.34572434425354
 Fold Accuracy: 330987855.0
Val Epoch loss: 1.34636732339859





ValueError: ignored

In [None]:
model = torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet',
                        in_channels=4, out_channels=4, init_features=4,
                        pretrained=False)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)

In [None]:
### Training loop here
use_cuda = torch.cuda.is_available()
if use_cuda:
    model = model.cuda()


num_epochs = 10

for epoch in range(num_epochs):
    losses = []
    if dataloader_train is None or optimizer is None:
        break  # NotImplementedError
    for data in tqdm.tqdm(dataloader_train):
        (image, seg_image), bratsID = data
        
        p_image = torch.squeeze(torch.permute(image, (0, 4, 1, 2, 3))) 
        p_seg_image = torch.squeeze(torch.permute(seg_image, (0, 4, 1, 2, 3))) 

        if use_cuda:
            p_image, p_seg_image = p_image.cuda(), p_seg_image.cuda() 
        pred = model(p_image.float())
        loss = criterion(pred, p_seg_image.long())
        
        print(loss.item())
        
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    losses.append(loss.item())

    print("Epoch:", epoch, "Mean Loss:", np.mean(losses))

###End

In [None]:
torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            }, 'drive/MyDrive/Colab Notebooks/resumable_third_10_50e_mriseg.pt')

In [None]:
model = torch.load('second_100_mriseg.pt', map_location=torch.device('cpu'))
model.eval()

In [None]:

pred = model(p_image.float())
#pred = model(p_image[:, 0, :, :].unsqueeze(1).float()) # only for 0(flair) modality model

In [None]:
slice = 72

In [None]:
plt.imshow(cv2.resize(pred[slice, 2, :, :].detach().cpu().numpy(), (160, 160), interpolation = cv2.INTER_NEAREST))


In [None]:
torch.argmax(pred[0, 1:, 0, 8])

In [None]:
pred_seg = torch.sigmoid(pred)
segs = pred_seg.permute(1, 0, 2, 3).detach().cpu().numpy() > 0.52   #4, 144, 144, 144 now i.e c, d, h, w

In [None]:
pred[0, :, 0, 8]

In [None]:
pred_seg[0, :, 0, 8]

In [None]:
#On predictions- get Regions
et = segs[1]
net = np.logical_and(segs[2], np.logical_not(et))
ed = np.logical_and(segs[3], np.logical_not(segs[2]))

labelmap = np.zeros(segs[1].shape)
labelmap[et] = 3
labelmap[net] = 1
labelmap[ed] = 2

print(f"voxel values : {np.unique(labelmap)}")

In [None]:
#visualize any 2d slice

plt.imshow(labelmap[slice])
print(f"pixel values: {np.unique(labelmap[slice])}")

In [None]:
ref_seg = p_seg_image.squeeze()
refmap_et, refmap_tc, refmap_wt = [np.zeros_like(ref_seg) for i in range(3)]
refmap_et = ref_seg == 3
refmap_tc = np.logical_or(refmap_et, ref_seg == 1)
refmap_wt = np.logical_or(refmap_tc, ref_seg == 2)
refmap = np.stack([refmap_et, refmap_tc, refmap_wt])

In [None]:
patient_metric_list = calculate_metrics(segs[1:], refmap, 'test')

In [None]:
#Overlay with Ground seg



img = p_image[slice, :, :, :].squeeze()
img = utils.make_grid(img)
img = img.detach().cpu().numpy()
print(img.shape)
# plot images
plt.figure(figsize=(10, 8))
img_list = [img[i].T for i in range(channels)] # 1 image per channel
plt.imshow(np.hstack(img_list), cmap='Greys_r')


## plot segmentation mask ##
seg_img = p_seg_image[slice, :, :, :].squeeze()
print(seg_img.shape)
seg_img = torch.tensor(seg_img.numpy()[:, ::-1].copy()) #flip
seg_img = utils.make_grid(seg_img).detach().cpu().numpy()

print(np.unique(p_seg_image))
#seg_img = seg_img > 1
plt.imshow(np.hstack([seg_img[0].T]), cmap='Greys_r', alpha=0.5)
plt.show()

In [None]:
#Overlay with Predicted
img = p_image[slice, :, :, :].squeeze()
img = utils.make_grid(img)
img = img.detach().cpu().numpy()
print(img.shape)
# plot images
plt.figure(figsize=(10, 8))
img_list = [img[i].T for i in range(channels)] # 1 image per channel
plt.imshow(np.hstack(img_list), cmap='Greys_r')
## plot segmentation mask ##
seg_img = torch.tensor(labelmap[slice, :, :].squeeze())
seg_img = utils.make_grid(seg_img).detach().cpu().numpy()
print(np.unique(seg_img))

#plt.figure(figsize=(4, 4))
plt.imshow(np.hstack([seg_img[0].T]), cmap='Greys_r', alpha=0.3)
plt.show()