In [1]:
from glob import glob
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
from tqdm import tqdm
import pandas as pd
import torch.optim as optim


from google.colab import drive
drive.mount('/content/gdrive')
import os

# clean ds
CLEAN_DS_PATH = '/content/gdrive/MyDrive/ProjectDL/BraTS/SecondPart'

# train
CLEAN_TRAIN_PATH = f'{CLEAN_DS_PATH}/train'
CLEAN_TRAIN_IMG_PATH = f'{CLEAN_TRAIN_PATH}/images'
CLEAN_TRAIN_MSK_PATH = f'{CLEAN_TRAIN_PATH}/masks'
print(CLEAN_TRAIN_IMG_PATH)
# val
CLEAN_VAL_PATH = f'{CLEAN_DS_PATH}/val'
CLEAN_VAL_IMG_PATH = f'{CLEAN_VAL_PATH}/images'
CLEAN_VAL_MSK_PATH = f'{CLEAN_VAL_PATH}/masks'


# MAking dataset ready
class SimpleLogger:

    def __init__(self, debug=True):
        self.debug = debug

    def enable_debug(self):
        self.debug = True

    def disable_debug(self):
        self.debug = False

    def log(self, message, condition=True):
        if self.debug and condition:
            print(message)


logger = SimpleLogger(debug=True)

def to_categorical(y, n_classes):
    return np.eye(n_classes, dtype="uint8")[y]


class BraTSDataset(Dataset):
    def log(self, message):
        logger.log(message, condition=self.debug)

    def __init__(self, images_path, masks_path, transform=None, one_hot_target=True, debug=True):
              
        self.images = sorted(glob(f"{images_path}/*.npy"))
        self.masks = sorted(glob(f"{masks_path}/*.npy"))
        self.transform = transform
        self.one_hot_target = one_hot_target
        self.debug = debug
        self.log(f"images: {len(self.images)}, masks: {len(self.masks)} ")
        assert len(self.images) == len(self.masks), "images and masks lengths are not the same!"

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # if torch.is_tensor(idx):
        #     idx = idx.tolist()

        image = np.load(self.images[idx])
        mask = np.load(self.masks[idx])
        # resizing image and mask, experimental
        image = image[::2,::2,::2]
        mask = mask[::2,::2,::2]
        if self.one_hot_target:
            mask = to_categorical(mask, 4)
            mask = mask[::, ::, ::, 1::]  # discard background

        image = torch.from_numpy(image).float()  # .double()
        mask = torch.from_numpy(mask)  # .float() #.long()

        return image.permute((3, 0, 1, 2)), mask.permute((3, 0, 1, 2))


def get_dl(dataset, batch_size=4, pm=True, nw=1):
    return DataLoader(dataset, batch_size, shuffle=False, pin_memory=pm, num_workers=nw, ) # change

def get_train_ds():
    return BraTSDataset(CLEAN_TRAIN_IMG_PATH, CLEAN_TRAIN_MSK_PATH)


def get_val_ds():
    return BraTSDataset(CLEAN_VAL_IMG_PATH, CLEAN_VAL_MSK_PATH)

# this is for testing only
# if __name__ == '__main__':
#     train_ds = BraTSDataset(CLEAN_TRAIN_IMG_PATH, CLEAN_TRAIN_MSK_PATH)
#     print(train_ds[0][0].shape)
#     print(train_ds[0][1].shape)
#     dl = get_dl(train_ds, batch_size=1)
#     print("OK")


Mounted at /content/gdrive
/content/gdrive/MyDrive/ProjectDL/BraTS/SecondPart/train/images


In [2]:
# Residual 3DUNet model

import torch
import torch.nn as nn


# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # the output image will be (n + 2p — f + 1) * (n + 2p — f + 1) where p =1 in this case.
        # Convolutional Layer
        self.conv = nn.Sequential(
          nn.BatchNorm3d(in_channels),
          nn.ReLU(inplace=True),
          nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride),
          nn.BatchNorm3d(out_channels),
          nn.ReLU(inplace=True),
          nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1, stride=1))

        # Identity Mapping
        self.shortcut = nn.Conv3d(in_channels, out_channels, kernel_size=1, padding=0, stride=stride)

    def forward(self, inputs):
        x = self.conv(inputs) 
        s = self.shortcut(inputs)       
        skip = x + s
        return skip

class DecoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.upsample = nn.Upsample(scale_factor=2, mode="trilinear", align_corners=True) #  mode="trilinear"
        self.residual = ResidualBlock(in_channels + out_channels, out_channels)

    def forward(self, inputs, skip):
        x = self.upsample(inputs)
        x = torch.cat([x, skip], axis=1)
        x = self.residual(x)
        return x

class Res3DUNet(nn.Module):
    # the default dataset has 3 channels of data ->  T1CE, T2, FLAIR
    # The output has background, NCR/NET, ED, ET 

    def __init__(self, in_channels=3, out_channels=4):
        super().__init__()

        # Encoder 1 
        self.conv1 = nn.Conv3d(in_channels, 64, kernel_size=3, padding=1)
        self.bn1   = nn.BatchNorm3d(64)
        self.relu  = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(64, 64, kernel_size=3, padding=1)
        #Identity Mapping
        self.conv3 = nn.Conv3d(in_channels, 64, kernel_size=1, padding=0)
        
        # Encoder 2 
        self.r2 = ResidualBlock(64, 128, stride=2)
        # Encoder 2 
        self.r3 = ResidualBlock(128, 256, stride=2)
        # Bridge
        self.r4 = ResidualBlock(256, 512, stride=2)
        # Decoder 1
        self.d1 = DecoderBlock(512, 256)
        # Decoder 2
        self.d2 = DecoderBlock(256, 128)
        # Decoder 3
        self.d3 = DecoderBlock(128, 64)

        # Output 
        self.output = nn.Conv3d(64, out_channels, kernel_size=1, padding=0)


    def forward(self, inputs):
        # Encoder 1 
        x = self.conv1(inputs)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        s = self.conv3(inputs)
        skip1 = x + s
        # Encoder 2 
        skip2 = self.r2(skip1)
        # Encoder 3 
        skip3 = self.r3(skip2)
        # Bridge 
        b = self.r4(skip3)
        # Decoder 1
        d1 = self.d1(b, skip3)
        # Decoder 1
        d2 = self.d2(d1, skip2)
        # Decoder 1
        d3 = self.d3(d2, skip1)
        # output 
        output = self.output(d3)

        return output

In [7]:
# Extract training and validation masks for the second part of the project
# For training and validation of the second part, I need masks. Because I want to extract the volume of the tumour and use it for the training and validation part.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
import torch.nn.functional as Fun
import pandas as pd

def load_dataset(data_loader, model, device="cuda"):
    preds_masks = torch.empty((0, 3, 64, 64, 64)).to(device)
    y_masks = torch.empty((0, 3, 64, 64, 64)).to(device)
    model.eval()

    with torch.no_grad():
        for x, y in data_loader:
            x = x.to(device)
            y = y.to(device) #.unsqueeze(1)
            preds = torch.sigmoid(model(x))
            preds = (preds > 0.5).float()
            preds_masks = torch.cat((preds_masks, preds), dim=0)
            y_masks = torch.cat((y_masks, y), dim=0)
            # print(preds.shape)
            # print(preds_masks.shape)
    model.train()
    return preds_masks, y_masks


# Loading train masks (ground truth) for training the second model
# masks_paths = sorted(glob(f"{CLEAN_TRAIN_MSK_PATH}/*.npy"))
# train_masks_y = np.array([np.load(path) for path in masks_paths])
# print(train_masks_y.shape)
# After loading the fisrt model, we ge the predicted masks as a validation set for testing the second mdel 
# masks_list = []
# for file in masks_files:
#     mask = np.load(file)
#     masks_list.append(mask)
# masks_array = np.array(masks_list)
# print(masks_array.shape)

# test_masks_preds = load_dataset(val_dl,model,DEVICE)[0]
# train_masks_y = load_dataset(train_dl,model,DEVICE)[1]
# print(test_masks_preds.shape)


# For the training part, I load all the masks in the database (actual masks)
print("Size of training images:")
BATCH_SIZE = 4;
train_dl = get_dl(get_train_ds(), BATCH_SIZE, nw=1)
train_masks_y = torch.empty((0, 3, 64, 64, 64)).to(device)
for x, y in train_dl:
    x = x.to(device)
    y = y.to(device) #.unsqueeze(1)
    train_masks_y = torch.cat((train_masks_y, y), dim=0)
#print(train_masks_y.shape)


# For the validation part, first, I give all the validation images to the 3DUNET and get the masks and then use these masks to compute the tumour volume for the survival days prediction
# 1) Loading test dataset
print("Size of testing images:")
val_dl = get_dl(get_val_ds(), BATCH_SIZE, nw=1)

# 1) Loading 3DUNET model
model = Res3DUNet(3, 3).to(DEVICE)
path = "/content/gdrive/MyDrive/ProjectDL/BraTS/3d_100e_adam_bce-dice"
model.load_state_dict(torch.load(f"{path}.pt"))
model.eval()

# 3) Geting output of the model
test_masks_preds = load_dataset(val_dl,model,DEVICE)[0]
#print(test_masks_preds.shape)





# Calculate the ratios of the volume of each tumor sub-region to the size of the whole brain. 
# Train ratios
num_pixels = np.zeros((train_masks_y.shape[0], 3))
num_nonzero = np.zeros((train_masks_y.shape[0], 3))
ratios = np.zeros((train_masks_y.shape[0], 3))
# Loop over all train masks
for i in range(train_masks_y.shape[0]):
    for j in range(train_masks_y.shape[1]):
        # Count the number of elements in the current mask
        num_pixels[i,j] += torch.numel(train_masks_y[i, j, :, :, :])
        #print(num_pixels[i,j])

        # Count the number of elements with value 3 in the current image
        num_nonzero[i,j] += (train_masks_y[i, j, :, :, :] != 0).sum() # torch.count_nonzero(torch.numel(train_masks_y[i, j, :, :, :]) 
        #print(num_nonzero[i,j])
    
        # Calculate the ratio of elements 
        ratios[i, j] = num_nonzero[i,j] / num_pixels[i,j]


train_age = pd.read_csv('/content/gdrive/MyDrive/ProjectDL/BraTS/SecondPart/train/survival_age.csv', header=None)
train_dataset_sd = np.concatenate(((ratios), (train_age)), axis=1)
print("Size of the training matrix after adding age feature. (3 tomour volume and one age features")
print(train_dataset_sd.shape)

train_survival_days = pd.read_csv('/content/gdrive/MyDrive/ProjectDL/BraTS/SecondPart/train/survival_days_category.csv', header=None)
# train_survival_days_onehot = Fun.one_hot(torch.from_numpy(train_survival_days.values), num_classes = 3)
# train_label_sd = train_survival_days_onehot
# train_label_sd = torch.squeeze(train_label_sd)
train_label_sd = torch.squeeze(torch.from_numpy(train_survival_days.values))

train_data_sd = []
for i in range(len(train_dataset_sd)):
   train_data_sd.append([train_dataset_sd[i], train_label_sd[i]])


# Test ratios
num_pixels = np.zeros((test_masks_preds.shape[0], 3))
num_nonzero = np.zeros((test_masks_preds.shape[0], 3))
ratios = np.zeros((test_masks_preds.shape[0], 3))
# Loop over all train masks
for i in range(test_masks_preds.shape[0]):
    for j in range(test_masks_preds.shape[1]):
        # Count the number of elements in the current mask
        num_pixels[i,j] += torch.numel(test_masks_preds[i, j, :, :, :])
        #print(num_pixels[i,j])

        # Count the number of elements with value 3 in the current image
        num_nonzero[i,j] += (test_masks_preds[i, j, :, :, :] != 0).sum() # torch.count_nonzero(torch.numel(train_masks_y[i, j, :, :, :]) 
        #print(num_nonzero[i,j])
    
        # Calculate the ratio of elements 
        ratios[i, j] = num_nonzero[i,j] / num_pixels[i,j]


test_age = pd.read_csv('/content/gdrive/MyDrive/ProjectDL/BraTS/SecondPart/val/survival_age.csv', header=None)
test_dataset_sd = np.concatenate(((ratios), (test_age)), axis=1)
print("Size of the validation matrix after adding age feature. (3 tomour volume and one age features")
print(test_dataset_sd.shape)

test_survival_days = pd.read_csv('/content/gdrive/MyDrive/ProjectDL/BraTS/SecondPart/val/survival_days_category.csv', header=None)
# test_survival_days_onehot = Fun.one_hot(torch.from_numpy(test_survival_days.values), num_classes = 3)
# test_label_sd = test_survival_days_onehot
# test_label_sd = torch.squeeze(test_label_sd)
test_label_sd = torch.squeeze(torch.from_numpy(test_survival_days.values))

test_data_sd = []
for i in range(len(test_dataset_sd)):
   test_data_sd.append([test_dataset_sd[i], test_label_sd[i]])

Size of training images:
images: 71, masks: 71 
torch.Size([71, 3, 64, 64, 64])
Size of testing images:
images: 46, masks: 46 
Size of the training matrix after adding age feature. (3 tomour volume and one age features
(71, 4)
Size of the validation matrix after adding age feature. (3 tomour volume and one age features
(46, 4)


In [8]:
# Survival Prediction MOdel

import torch
import torch.nn as nn


# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class SurvivalPRED(nn.Module):
    def __init__(self, in_channels=4, out_channels=3):
        super().__init__()
        self.fc1 = nn.Linear(in_channels, 64)
        self.bn1 = nn.BatchNorm1d(64)
        self.fc2 = nn.Linear(64, 64)
        self.bn2 = nn.BatchNorm1d(64)
        self.fc3 = nn.Linear(64, out_channels)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        x = self.fc1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc3(x)

        return x



In [12]:
# Train Survival Days Model
# Hyper Parameters
BATCH_SIZE = 8
EPOCHS = 500
LR = 0.0001
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_dl = get_dl(train_data_sd, BATCH_SIZE,nw=1)

import gc
torch.cuda.empty_cache()
gc.collect()

model = SurvivalPRED(4, 3).to(DEVICE)

# Define the loss function and optimizer
loss = nn.CrossEntropyLoss()
#opt = optim.SGD(model.parameters(), lr=LR, momentum=0.9) # Accuracy: 39% 
#opt = optim.Adagrad(model.parameters(), lr=LR) # Accuracy: 19%
#opt = optim.RMSprop(model.parameters(), lr=LR) # Accuracy: 48%
opt = optim.Adam(model.parameters(), lr=LR) # Accuracy: 48%


def train(model, epochs=1, training_loader=None, loss_fn=None, device=None,
          optimizer: torch.optim.Optimizer = None):
    for epoch in range(epochs):
        tq_dl = tqdm(training_loader)
        for idx, (image, mask) in enumerate(tq_dl):
            image, mask = image.to(device), mask.to(device)
            # forward pass
            out = model(image.to(torch.float32))
            loss = loss_fn(out, mask)
            # backward pass
            optimizer.zero_grad()
            loss.backward()

            # optimize
            optimizer.step()

            tq_dl.set_description(f"At epoch [{epoch + 1}/{epochs}]")
            tq_dl.set_postfix(loss=loss.item())  # acc, ...

# do not give in the format - the format will be .pt
def save(model, path):
    torch.save(model.state_dict(), f"{path}.pt")


# training 
train(model, epochs=EPOCHS, training_loader=train_dl, loss_fn=loss, device=DEVICE, optimizer=opt)

# # saving sample
# save(model,"/content/gdrive/MyDrive/ProjectDL/BraTS/3d_100e_adam_dice")
# print("saved the model...")

# validation
test_dl = get_dl(test_data_sd, BATCH_SIZE, False, nw=1)
with torch.no_grad():
    correct = 0
    total = 0
    for images, masks in test_dl:
        images = images.to(device)
        masks = masks.to(device)
        outputs = model(images.to(torch.float32))
        #print(outputs.data)
        _, predicted = torch.max(outputs.data, 1)
        total += masks.size(0)
        # print(predicted.shape)
        # print(masks)
        correct += (predicted == masks).sum().item()
        del images, masks, outputs


    print('Accuracy of the network on the test: {} %'.format(100 * correct / total)) 

At epoch [1/500]: 100%|██████████| 9/9 [00:00<00:00, 40.71it/s, loss=1.53]
At epoch [2/500]: 100%|██████████| 9/9 [00:00<00:00, 48.67it/s, loss=1.43]
At epoch [3/500]: 100%|██████████| 9/9 [00:00<00:00, 49.10it/s, loss=1.49]
At epoch [4/500]: 100%|██████████| 9/9 [00:00<00:00, 48.11it/s, loss=1.28]
At epoch [5/500]: 100%|██████████| 9/9 [00:00<00:00, 48.44it/s, loss=1.49]
At epoch [6/500]: 100%|██████████| 9/9 [00:00<00:00, 48.74it/s, loss=1.31]
At epoch [7/500]: 100%|██████████| 9/9 [00:00<00:00, 45.51it/s, loss=1.07]
At epoch [8/500]: 100%|██████████| 9/9 [00:00<00:00, 47.04it/s, loss=1.34]
At epoch [9/500]: 100%|██████████| 9/9 [00:00<00:00, 49.27it/s, loss=1.5]
At epoch [10/500]: 100%|██████████| 9/9 [00:00<00:00, 46.54it/s, loss=1.12]
At epoch [11/500]: 100%|██████████| 9/9 [00:00<00:00, 49.27it/s, loss=1.38]
At epoch [12/500]: 100%|██████████| 9/9 [00:00<00:00, 48.13it/s, loss=1.16]
At epoch [13/500]: 100%|██████████| 9/9 [00:00<00:00, 44.32it/s, loss=1.21]
At epoch [14/500]: 100

Accuracy of the network on the test: 47.82608695652174 %
