In [None]:
from glob import glob
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
from tqdm import tqdm


# Mount Google Drive so the dataset stored in MyDrive is reachable from this Colab runtime.
from google.colab import drive
drive.mount('/content/gdrive')
import os

# Root of the cleaned BraTS dataset; each split below holds `images/` and `masks/` folders.
CLEAN_DS_PATH = '/content/gdrive/MyDrive/ProjectDL/BraTS'

# train split
CLEAN_TRAIN_PATH = '/'.join((CLEAN_DS_PATH, 'train'))
CLEAN_TRAIN_IMG_PATH = '/'.join((CLEAN_TRAIN_PATH, 'images'))
CLEAN_TRAIN_MSK_PATH = '/'.join((CLEAN_TRAIN_PATH, 'masks'))
print(CLEAN_TRAIN_IMG_PATH)  # quick visual check that the path was assembled as expected

# validation split
CLEAN_VAL_PATH = '/'.join((CLEAN_DS_PATH, 'val'))
CLEAN_VAL_IMG_PATH = '/'.join((CLEAN_VAL_PATH, 'images'))
CLEAN_VAL_MSK_PATH = '/'.join((CLEAN_VAL_PATH, 'masks'))


# Making the dataset ready
class SimpleLogger:
    """Minimal print-based logger whose output can be switched on and off at runtime."""

    def __init__(self, debug=True):
        # When falsy, every call to `log` is silenced.
        self.debug = debug

    def enable_debug(self):
        """Turn message printing on."""
        self.debug = True

    def disable_debug(self):
        """Turn message printing off."""
        self.debug = False

    def log(self, message, condition=True):
        """Print `message` only when debugging is enabled AND `condition` is truthy."""
        should_print = self.debug and condition
        if not should_print:
            return
        print(message)


# Module-wide logger instance (debug output on by default); BraTSDataset.log routes through it.
logger = SimpleLogger(debug=True)

def to_categorical(y, n_classes):
    """One-hot encode integer class labels.

    Args:
        y: integer label array (any shape); each value selects a row of an
            ``n_classes`` x ``n_classes`` identity matrix.
        n_classes: number of classes, i.e. the length of the appended one-hot axis.

    Returns:
        uint8 array of shape ``y.shape + (n_classes,)`` for an integer array ``y``.
    """
    identity = np.identity(n_classes, dtype=np.uint8)
    # Fancy-indexing the identity matrix with the labels picks one one-hot row per element.
    return identity[y]


class BraTSDataset(Dataset):
    """Paired BraTS image/mask volumes loaded lazily from ``.npy`` files.

    Image and mask files are discovered with ``glob`` and paired purely by sorted
    filename order, so the i-th image must correspond to the i-th mask.
    ``__getitem__`` returns ``(image, mask)`` tensors with axis 3 moved to the front
    (channels-first), assuming 4D channels-last arrays after loading/encoding.
    """

    def log(self, message):
        """Print ``message`` via the module-level ``logger`` when this dataset's ``debug`` flag is set."""
        logger.log(message, condition=self.debug)

    def __init__(self, images_path, masks_path, transform=None, one_hot_target=True, debug=True,
                 n_classes=4, downsample=2):
        """
        Args:
            images_path: directory holding the image ``.npy`` files.
            masks_path: directory holding the mask ``.npy`` files.
            transform: stored as ``self.transform``. NOTE(review): it is never applied in
                ``__getitem__``; wire it up (and define its signature) before relying on it.
            one_hot_target: if True, masks are one-hot encoded into ``n_classes`` channels
                and channel 0 (background) is dropped.
            debug: if True, log the number of files found.
            n_classes: number of label values in the masks (default 4, i.e. labels 0..3).
            downsample: keep every ``downsample``-th voxel along the first three axes
                (default 2, the original experimental resizing; 1 keeps full resolution).

        Raises:
            ValueError: if ``downsample`` < 1.
            AssertionError: if the number of image files differs from the number of mask files.
        """
        if downsample < 1:
            raise ValueError(f"downsample must be >= 1, got {downsample}")
        # Only file paths are kept in memory; volumes are read on demand in __getitem__.
        self.images = sorted(glob(f"{images_path}/*.npy"))
        self.masks = sorted(glob(f"{masks_path}/*.npy"))
        self.transform = transform
        self.one_hot_target = one_hot_target
        self.n_classes = n_classes
        self.downsample = downsample
        self.debug = debug
        self.log(f"images: {len(self.images)}, masks: {len(self.masks)} ")
        assert len(self.images) == len(self.masks), "images and masks lengths are not the same!"

    def __len__(self):
        """Number of image/mask pairs."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load, subsample and convert the ``idx``-th image/mask pair to tensors.

        Returns:
            ``(image, mask)``: ``image`` is float32; ``mask`` keeps the dtype produced by
            loading / one-hot encoding (uint8 when ``one_hot_target`` is True).
        """
        image = np.load(self.images[idx])
        mask = np.load(self.masks[idx])
        # Spatial subsampling along the first three axes (a stride of 1 leaves the volume untouched).
        step = self.downsample
        image = image[::step, ::step, ::step]
        mask = mask[::step, ::step, ::step]
        if self.one_hot_target:
            mask = to_categorical(mask, self.n_classes)
            mask = mask[:, :, :, 1:]  # discard background (class 0)

        image = torch.from_numpy(image).float()
        mask = torch.from_numpy(mask)

        # Move axis 3 (channels) to the front: (D, H, W, C) -> (C, D, H, W).
        return image.permute((3, 0, 1, 2)), mask.permute((3, 0, 1, 2))


def get_dl(dataset, batch_size=4, pm=True, nw=1, shuffle=True):
    """Wrap ``dataset`` in a ``DataLoader``.

    Args:
        dataset: any map-style dataset (e.g. ``BraTSDataset``).
        batch_size: samples per batch.
        pm: passed as ``pin_memory`` (page-locked host memory for faster host-to-GPU copies).
        nw: passed as ``num_workers`` (number of loader worker processes).
        shuffle: reshuffle the data every epoch. Defaults to True (the previous fixed
            behavior); pass False for evaluation loaders so batches arrive in a stable order.

    Returns:
        torch.utils.data.DataLoader
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        pin_memory=pm,
        num_workers=nw,
    )

def get_train_ds():
    """Return a ``BraTSDataset`` over the clean training split, using the dataset's default options."""
    image_dir, mask_dir = CLEAN_TRAIN_IMG_PATH, CLEAN_TRAIN_MSK_PATH
    return BraTSDataset(image_dir, mask_dir)


def get_val_ds():
    """Return a ``BraTSDataset`` over the clean validation split, using the dataset's default options."""
    image_dir, mask_dir = CLEAN_VAL_IMG_PATH, CLEAN_VAL_MSK_PATH
    return BraTSDataset(image_dir, mask_dir)

# this is for testing only
# if __name__ == '__main__':
#     train_ds = BraTSDataset(CLEAN_TRAIN_IMG_PATH, CLEAN_TRAIN_MSK_PATH)
#     print(train_ds[0][0].shape)
#     print(train_ds[0][1].shape)
#     dl = get_dl(train_ds, batch_size=1)
#     print("OK")


Mounted at /content/gdrive
/content/gdrive/MyDrive/ProjectDL/BraTS/train/images


In [None]:
# Residual 3DUNet model

import torch
import torch.nn as nn


# Device configuration: prefer the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class ResidualBlock(nn.Module):
    """Pre-activation 3D residual block: (BN -> ReLU -> Conv3d) x 2 plus a 1x1x1 conv shortcut.

    Per spatial dimension the output size is floor((n + 2p - f) / s) + 1. The main path
    uses f=3, p=1 (stride ``stride`` on the first conv, 1 on the second). The shortcut
    uses f=1, p=0 with the same stride. Both branches therefore produce the same size,
    floor((n - 1) / stride) + 1, and can be summed.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # Main path. The layer order fixes the state_dict keys conv.0 ... conv.5, so keep it
        # stable for checkpoint compatibility.
        main_path = [
            nn.BatchNorm3d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1, stride=1),
        ]
        self.conv = nn.Sequential(*main_path)

        # Projection shortcut: matches the main path's channel count and stride.
        self.shortcut = nn.Conv3d(in_channels, out_channels, kernel_size=1, padding=0, stride=stride)

    def forward(self, inputs):
        """Return ``conv(inputs) + shortcut(inputs)``."""
        return self.conv(inputs) + self.shortcut(inputs)

class DecoderBlock(nn.Module):
    """U-Net decoder stage: upsample, concatenate the encoder skip, then a ResidualBlock.

    ``in_channels`` is the channel count of the incoming (coarser) feature map and
    ``out_channels`` that of the skip tensor. The residual block maps their
    concatenation (``in_channels + out_channels``) down to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.upsample = nn.Upsample(scale_factor=2, mode="trilinear", align_corners=True)
        self.residual = ResidualBlock(in_channels + out_channels, out_channels)

    def forward(self, inputs, skip):
        """Upsample ``inputs`` to the resolution of ``skip``, fuse the two and refine.

        Args:
            inputs: coarse feature map, (N, in_channels, d, h, w).
            skip: encoder feature map, (N, out_channels, D, H, W).

        Returns:
            Tensor of shape (N, out_channels, D, H, W).
        """
        x = self.upsample(inputs)
        if x.shape[2:] != skip.shape[2:]:
            # An odd encoder size (e.g. 13 -> 7 after a stride-2 conv) makes a plain x2 upsample
            # miss the skip's size (14 != 13), which would break torch.cat. Resample the coarse
            # map straight to the skip's spatial size instead.
            x = nn.functional.interpolate(inputs, size=skip.shape[2:], mode="trilinear", align_corners=True)
        x = torch.cat([x, skip], dim=1)
        return self.residual(x)

class Res3DUNet(nn.Module):
    """Residual 3D U-Net for volumetric segmentation.

    Layout:
      - Residual stem (64 ch).
      - Three stride-2 ResidualBlocks (128/256/512 ch).
      - Stride-2 ResidualBlock bridge (1024 ch).
      - Four DecoderBlocks back to 64 ch.
      - 1x1x1 conv head producing ``out_channels`` logit maps.

    Inputs are (N, in_channels, D, H, W). The bridge sits four stride-2 stages below the
    input, so spatial sizes divisible by 16 make every encoder/decoder pair line up exactly.
    """
    # The default dataset has 3 input channels -> T1CE, T2, FLAIR.
    # The default 4 outputs are background, NCR/NET, ED, ET.

    def __init__(self, in_channels=3, out_channels=4):
        super().__init__()

        # Stem (encoder 1): conv -> BN -> ReLU -> conv, plus a 1x1x1 projection shortcut (conv3).
        self.conv1 = nn.Conv3d(in_channels, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm3d(64)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(64, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv3d(in_channels, 64, kernel_size=1, padding=0)

        # Encoders 2-4 and the bridge (r5): each halves the resolution and doubles the channels.
        self.r2 = ResidualBlock(64, 128, stride=2)
        self.r3 = ResidualBlock(128, 256, stride=2)
        self.r4 = ResidualBlock(256, 512, stride=2)
        self.r5 = ResidualBlock(512, 1024, stride=2)

        # Decoders 1-4: each upsamples to its skip's resolution and halves the channels.
        self.d1 = DecoderBlock(1024, 512)
        self.d2 = DecoderBlock(512, 256)
        self.d3 = DecoderBlock(256, 128)
        self.d4 = DecoderBlock(128, 64)

        # Output head: per-voxel logits (no activation is applied in this module).
        self.output = nn.Conv3d(64, out_channels, kernel_size=1, padding=0)

    def forward(self, inputs):
        """Map (N, in_channels, D, H, W) inputs to (N, out_channels, D, H, W) logits."""
        # Stem: main branch plus the 1x1x1 projection of the raw input.
        stem = self.conv2(self.relu(self.bn1(self.conv1(inputs))))
        skip1 = stem + self.conv3(inputs)

        # Contracting path.
        skip2 = self.r2(skip1)
        skip3 = self.r3(skip2)
        skip4 = self.r4(skip3)
        bridge = self.r5(skip4)

        # Expanding path: each decoder fuses the matching encoder output, deepest first.
        x = bridge
        for decoder, skip in ((self.d1, skip4), (self.d2, skip3), (self.d3, skip2), (self.d4, skip1)):
            x = decoder(x, skip)

        return self.output(x)


# def _test_Res3dUNet():
#     x = torch.randn((1, 3, 128, 128, 128)).to(device)
#     print(x.shape)
#     model = Res3DUNet(in_channels=3).to(device)
#     out = model(x)
#     print(out.shape)

# if __name__ == '__main__':
#     _test_Res3dUNet()



# class DoubleConv3D(nn.Module):
#     def __init__(self, in_channels, out_channels):
#         super().__init__()
#         # 1 + (L - l + 2P)/s
#         self.conv = nn.Sequential(
#             # 1 + out - 3 + 2 = out
#             nn.Conv3d(in_channels, out_channels, 3, stride=1, padding=1, bias=False),
#             nn.BatchNorm3d(out_channels),
#             nn.ReLU(inplace=True),
#             nn.Conv3d(out_channels, out_channels, 3, stride=1, padding=1, bias=False),
#             nn.BatchNorm3d(out_channels),
#             nn.ReLU(inplace=True)
#         )

#     def forward(self, inputs):
#         return self.conv(inputs)


# class Base3DUNet(nn.Module):
#     # the default dataset has 3 channels of data ->  T1CE, T2, FLAIR
#     # The output has background, NCR/NET, ED, ET
#     def __init__(self, in_channels=3, out_channels=4, features=[64, 128, 256, 512]):
#         super().__init__()
#         # 1 + (L - l + 2P)/s
#         # 1 + (L - 2)/2 = L
#         self.pooling = nn.MaxPool3d(kernel_size=2, stride=2)

#         self.downs = nn.ModuleList()
#         self.ups = nn.ModuleList()

#         # Each Layer - number of filters , see UNet architecture
#         input_channels = in_channels

#         for feature in features:
#             self.downs.append(DoubleConv3D(input_channels, feature))
#             input_channels = feature

#         for feature in reversed(features):
#             self.ups.append(nn.ConvTranspose3d(feature * 2, feature, kernel_size=2, stride=2))
#             self.ups.append(DoubleConv3D(feature * 2, feature))

#         self.bottleneck = DoubleConv3D(features[-1], features[-1] * 2)  # this connects downs to ups

#         self.output_conv = nn.Conv3d(features[0], out_channels, kernel_size=1)  # last layer - feature compression

#     def forward(self, inputs):
#         skips = []

#         x = inputs
#         for down in self.downs:
#             x = down(x)
#             skips.append(x)
#             x = self.pooling(x)

#         x = self.bottleneck(x)

#         for idx in range(0, len(self.ups), 2):  # going up 2 steps, as each step has convTranspose and DoubleConv
#             x = self.ups[idx](x)  # up sampling w/ the convTranspose
#             skip_connection = skips.pop()  # give me the last skip I added, to add it first on the ups
#             x = torch.cat((skip_connection, x), dim=1)  # dim 0 is batch, dim 1 is the channels
#             x = self.ups[idx + 1](x)  # double conv

#         return self.output_conv(x)

In [None]:
# Loss functions
# Dice similarity coefficient: DSC = 2 * |A ∩ B| / (|A| + |B|)
class DiceLoss(nn.Module):
    """Soft Dice loss on sigmoid probabilities, pooled over every element of the batch.

    dice = (2 * sum(p * t) + eps) / (sum(p) + sum(t) + eps);  loss = 1 - dice.
    """

    def __init__(self, eps: float = 1e-9):
        super(DiceLoss, self).__init__()
        # Smoothing term added to BOTH numerator and denominator. With eps only in the
        # numerator, an empty target plus a near-zero prediction gives eps / ~0, so the
        # loss blows up to a large negative value or -inf. In the symmetric form that
        # case scores a perfect dice of 1.
        self.eps = eps

    def forward(self,
                logits: torch.Tensor,
                targets: torch.Tensor) -> torch.Tensor:
        """
        Args:
            logits: raw model outputs; sigmoid is applied here.
            targets: binary targets with the same shape as ``logits``.

        Returns:
            Scalar tensor ``1 - dice``, which lies in [0, 1] for targets in [0, 1].
        """
        num = targets.size(0)
        # reshape (not view) so non-contiguous inputs, e.g. permuted tensors, are accepted.
        probability = torch.sigmoid(logits).reshape(num, -1)
        targets = targets.reshape(num, -1)
        assert (probability.shape == targets.shape)

        intersection = 2.0 * (probability * targets).sum()
        union = probability.sum() + targets.sum()
        dice_score = (intersection + self.eps) / (union + self.eps)
        return 1.0 - dice_score


class BCEDiceLoss(nn.Module):
    """Weighted sum of a BCE-with-logits loss and a Dice loss.

    loss = bce_weight * BCEWithLogits(logits, targets) + dice_weight * Dice(logits, targets)

    NOTE: the defaults (bce_weight=0.0, dice_weight=1.0) reproduce what this class
    previously returned: the Dice term only, with BCE computed but discarded. Pass
    ``bce_weight=1.0`` for the combined BCE + Dice objective the class name refers to.
    """

    def __init__(self, bce_weight: float = 0.0, dice_weight: float = 1.0):
        super(BCEDiceLoss, self).__init__()
        self.bce = nn.BCEWithLogitsLoss()
        self.dice = DiceLoss()
        self.bce_weight = bce_weight
        self.dice_weight = dice_weight

    def forward(self,
                logits: torch.Tensor,
                targets: torch.Tensor) -> torch.Tensor:
        """
        Args:
            logits: raw model outputs.
            targets: float targets with exactly the same shape as ``logits``.

        Returns:
            Scalar loss tensor.
        """
        assert (logits.shape == targets.shape)
        total = self.dice_weight * self.dice(logits, targets)
        # Only evaluate BCE when it actually contributes to the objective.
        if self.bce_weight:
            total = total + self.bce_weight * self.bce(logits, targets)
        return total

In [None]:
# Run all previous cells first.
# Training the Residual 3D U-Net

# Hyper Parameters
SEED = 42  # makes weight init and data shuffling repeatable across runs
BATCH_SIZE = 4
EPOCHS = 100
LR = 0.0001
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed torch's global RNG (all devices). It is used for parameter initialisation and, when
# no explicit generator is given, for DataLoader shuffling and worker seeds.
# NOTE: some cuDNN kernels remain non-deterministic unless
# torch.backends.cudnn.deterministic is also set.
torch.manual_seed(SEED)

train_dl = get_dl(get_train_ds(), BATCH_SIZE, nw=1)

# Release cached GPU memory and collect garbage left over from previous runs of this
# cell before allocating the model.
import gc
torch.cuda.empty_cache()
gc.collect()

# 3 input channels -> 3 output channels (the dataset's one-hot targets drop the background class).
model = Res3DUNet(3, 3).to(DEVICE)
# Alternative baseline (its definition is commented out in the model cell):
# model = Base3DUNet(3, 3, features=[64, 128, 256, 512]).to(DEVICE)
# print(model)

print(f"total parameters = {sum(p.numel() for p in model.parameters())}")
print(f"total learnable parameters = {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

# Optimizer and loss function
opt = torch.optim.Adam(model.parameters(), lr=LR)
loss = BCEDiceLoss()
#loss = DiceLoss()


def train(model, epochs=1, training_loader=None, loss_fn=None, device=None,
          optimizer: torch.optim.Optimizer = None):
    """Run a plain supervised training loop.

    Args:
        model: network to train; it is switched to training mode first.
        epochs: number of passes over ``training_loader``.
        training_loader: iterable yielding ``(image, mask)`` batches.
        loss_fn: called as ``loss_fn(model(image), mask.float())``.
        device: device both tensors are moved to before the forward pass.
        optimizer: optimizer stepping ``model``'s parameters.

    Returns:
        list[float]: mean training loss of each epoch (NaN for an epoch with no batches).
        The previous version returned None, so existing callers are unaffected.
    """
    # Make sure BatchNorm uses batch statistics even if model.eval() was called earlier
    # (e.g. by a validation run in the same kernel).
    model.train()
    epoch_means = []
    for epoch in range(epochs):
        tq_dl = tqdm(training_loader, desc=f"At epoch [{epoch + 1}/{epochs}]")
        running_loss = 0.0
        n_batches = 0
        for image, mask in tq_dl:
            image, mask = image.to(device), mask.to(device)
            # forward pass
            out = model(image)
            batch_loss = loss_fn(out, mask.float())
            # backward pass
            optimizer.zero_grad()
            batch_loss.backward()
            # optimize
            optimizer.step()

            batch_value = batch_loss.item()
            running_loss += batch_value
            n_batches += 1
            # The single-batch loss is noisy, so also show the running epoch mean.
            tq_dl.set_postfix(loss=batch_value, avg_loss=running_loss / n_batches)
        epoch_means.append(running_loss / n_batches if n_batches else float("nan"))
    return epoch_means

# Pass `path` without a file extension; `.pt` is appended automatically.
def save(model, path):
    """Write ``model.state_dict()`` to ``<path>.pt``.

    ``path`` should not include an extension; ``.pt`` is appended here.
    """
    checkpoint_file = "{}.pt".format(path)
    weights = model.state_dict()
    torch.save(weights, checkpoint_file)


# training
# EPOCHS is set once in the hyper-parameter section above. It is deliberately not
# re-assigned here, so that changing it there actually takes effect.
train(model, epochs=EPOCHS, training_loader=train_dl, loss_fn=loss, device=DEVICE, optimizer=opt)

# saving sample
# The checkpoint name is derived from EPOCHS and the dataset root so it cannot drift from
# the run it describes. For EPOCHS = 100 it resolves to the same path as before.
save(model, f"{CLEAN_DS_PATH}/3d_{EPOCHS}e_adam_dice")
print("saved the model...")


images: 221, masks: 221 
total parameters = 95882563
total learnable parameters = 95882563


At epoch [1/100]: 100%|██████████| 56/56 [04:15<00:00,  4.57s/it, loss=0.788]
At epoch [2/100]: 100%|██████████| 56/56 [02:12<00:00,  2.37s/it, loss=0.535]
At epoch [3/100]: 100%|██████████| 56/56 [02:08<00:00,  2.29s/it, loss=0.321]
At epoch [4/100]: 100%|██████████| 56/56 [02:10<00:00,  2.34s/it, loss=0.323]
At epoch [5/100]: 100%|██████████| 56/56 [02:08<00:00,  2.29s/it, loss=0.607]
At epoch [6/100]: 100%|██████████| 56/56 [02:11<00:00,  2.35s/it, loss=0.567]
At epoch [7/100]: 100%|██████████| 56/56 [02:08<00:00,  2.30s/it, loss=0.766]
At epoch [8/100]: 100%|██████████| 56/56 [02:11<00:00,  2.35s/it, loss=0.275]
At epoch [9/100]: 100%|██████████| 56/56 [02:14<00:00,  2.39s/it, loss=0.319]
At epoch [10/100]: 100%|██████████| 56/56 [02:11<00:00,  2.35s/it, loss=0.258]
At epoch [11/100]: 100%|██████████| 56/56 [02:09<00:00,  2.31s/it, loss=0.559]
At epoch [12/100]: 100%|██████████| 56/56 [02:10<00:00,  2.32s/it, loss=0.663]
At epoch [13/100]: 100%|██████████| 56/56 [02:08<00:00,  2.30

In [None]:
# # Validation
# # Run the first two cells first
# DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# # do not give in the format - the format will be .pt
# def load(model, path, eval=True):
#     model.load_state_dict(torch.load(f"{path}.pt"))
#     if eval:
#         model.eval()


# def check_accuracy(data_loader, model, device="cuda"):
#     num_correct = 0
#     num_pixels = 0
#     dice_score = 0
#     model.eval()

#     with torch.no_grad():
#         for x, y in data_loader:
#             x = x.to(device)
#             y = y.to(device) #.unsqueeze(1)
#             preds = torch.sigmoid(model(x))
#             preds = (preds > 0.5).float()
#             num_correct += (preds == y).sum()
#             num_pixels += torch.numel(preds)
#             dice_score += (2 * (preds * y).sum()) / (
#                     (preds + y).sum() + 1e-8
#             )

#     print(
#         f"Results: {num_correct}/{num_pixels} with accuracy {num_correct / num_pixels * 100:.4f}"
#     )
#     print(f"Dice score: {dice_score / len(data_loader)}")
#     model.train()

# BATCH_SIZE = 4;
# # loading sample
# val_dl = get_dl(get_val_ds(), BATCH_SIZE, nw=1)

# # Loading model
# model = Res3DUNet(3, 3).to(DEVICE)
# load(model,"/content/gdrive/MyDrive/ProjectDL/BraTS/3d_100e_adam_dice")

# check_accuracy(val_dl,model,DEVICE)
# #BaseUNET
# # 50 DICE Results: 115540617/116391936 with accuracy 99.2686 Dice score: 0.7475918531417847
# # 100 DICE Results: Results: 115510440/116391936 with accuracy 99.2427 Dice score: 0.7560997009277344
# # 100 BCE-DICE Results: 115626234/116391936 with accuracy 99.3421 Dice score: 0.7765376567840576

# #ResUNET
# # 100 BCE-DICE Results: Results: 115570441/116391936 with accuracy 99.2942 Dice score: 0.7470014691352844
# # 100 DICE Results: 115284011/116391936 with accuracy 99.0481 Dice score: 0.6380645036697388