# Import Libraries and Packages

In [None]:
!pip install -qqq torchmetrics
!pip install -qqq pytorch-lightning
!pip install -qqq segmentation-models-pytorch

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import random
import glob

import gc
import time


from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import torchmetrics

import time
import warnings
warnings.filterwarnings("ignore")

%matplotlib inline
from IPython.display import Image
from skimage import io

from pprint import pprint

from sklearn.model_selection import train_test_split
import cv2
from sklearn.preprocessing import StandardScaler, normalize
from IPython.display import display

from PIL import Image

import torchvision
from torchvision import transforms

# Data Loading

In [None]:
# Load the CSV that sits next to the image folders and preview its first rows.
# NOTE(review): the table is not referenced again in the visible code and its
# contents are not evident from here - confirm what it holds before relying on it.
img_data = pd.read_csv('../input/lgg-mri-segmentation/kaggle_3m/data.csv')
img_data.head()

In [None]:
# Walk the dataset root and record, for every file one level down, the name of
# the folder it lives in and its full path.
data_path = []
for sub_dir_path in glob.glob("/kaggle/input/lgg-mri-segmentation/kaggle_3m/"+"*"):
    # The root also contains plain files (e.g. data.csv); skip them explicitly
    # instead of relying on os.listdir raising for a non-directory.
    if not os.path.isdir(sub_dir_path):
        continue
    dir_name = sub_dir_path.split('/')[-1]
    try:
        entries = os.listdir(sub_dir_path)
    except OSError as e:
        # Best effort: report an unreadable folder and carry on with the rest.
        print(e)
        continue
    for filename in entries:
        data_path.extend([dir_name, sub_dir_path + '/' + filename])

# data_path is interleaved as [folder, path, folder, path, ...].
# NOTE: despite the names, `filenames` holds folder names and `masks` holds the
# paths of ALL files (images and masks alike); they are split apart later.
filenames = data_path[::2]
masks = data_path[1::2]

In [None]:
# One row per collected file: the folder it came from and its full path.
df = pd.DataFrame({"patient_id": filenames}).assign(img_path=masks)
print(df.shape)
df

In [None]:
# A file counts as a mask purely because its path contains the text "mask";
# everything else is treated as an original image.
mask_img = df.loc[df['img_path'].str.contains("mask")]
original_img = df.loc[~df['img_path'].str.contains("mask")]
original_img, mask_img

In [None]:
def _slice_number(path):
    """Return the integer that ends the file name, ignoring a `_mask` suffix.

    Works on `<anything>_<n>.<ext>` and `<anything>_<n>_mask.<ext>` regardless
    of how long the directory prefix is (the previous fixed offset of 89
    characters only matched one specific absolute path layout).
    """
    stem = os.path.splitext(os.path.basename(path))[0]
    if stem.endswith("_mask"):
        stem = stem[:-len("_mask")]
    return int(stem.rsplit("_", 1)[-1])


# Order images and masks by slice number. sorted() is stable, so files sharing
# a slice number keep their relative order from original_img / mask_img.
# NOTE(review): the index-wise image/mask pairing used below assumes both
# frames list the folders in the same order - confirm if the listing changes.
imgs = sorted(original_img["img_path"].values, key=_slice_number)
masks = sorted(mask_img["img_path"].values, key=_slice_number)

# Spot-check one random pair.
idx = random.randint(0, len(imgs)-1)
print("Image path:", imgs[idx], "\nMask path:", masks[idx])

In [None]:
# Pair each image with its mask by position in the two sorted lists.
# The patient id is taken from each image's parent folder so it matches the
# row it sits in; original_img.patient_id is still in the pre-sort order and
# would label rows with the wrong patient.
mri_df = pd.DataFrame({"patient_id": [os.path.basename(os.path.dirname(p)) for p in imgs],
                       "img_path": imgs,
                       'mask_path': masks})
mri_df

In [None]:
def get_diagnosis(img_path):
    """Return 1 if the image at *img_path* has any non-zero pixel, else 0.

    Args:
        img_path: path of the image file to inspect (read with cv2.imread).

    Returns:
        int: 1 when the maximum pixel value is greater than zero, otherwise 0.

    Raises:
        FileNotFoundError: if cv2.imread could not read the file (it signals
            failure by returning None rather than raising).
    """
    pixels = cv2.imread(img_path)
    if pixels is None:
        raise FileNotFoundError(f"Could not read image: {img_path}")
    return int(np.max(pixels) > 0)

In [None]:
# Label each row from its mask file: 1 if the mask has any non-zero pixel, else 0.
mri_df['mask'] = mri_df['mask_path'].apply(get_diagnosis)

# Normalise the mask paths to plain str objects.
mri_df['mask_path'] = mri_df['mask_path'].apply(str)

print(mri_df.shape)
mri_df

In [None]:
# patient_id is not read again below, so remove the column in place.
mri_df.drop(columns=['patient_id'],inplace=True)

# Data Augmentation

In [None]:
# Oversample positive slices: write randomly scaled / rotated / transposed
# copies of them to the working directory and append them to mri_df.
NUM_AUGMENTED = 1183  # how many positive rows to augment (capped below)

# Read every mask once and reuse the labels for sampling and for the counts.
has_tumor = mri_df['mask_path'].apply(get_diagnosis)
tumor_count_before = has_tumor.sum()
non_tumor_count_before = len(mri_df) - tumor_count_before

positive_rows = mri_df[has_tumor == 1]
# Never ask for more rows than exist; sample() without replacement would raise.
positive_samples = positive_rows.sample(n=min(NUM_AUGMENTED, len(positive_rows)), random_state=42)

transformed_data = []

for idx, row in positive_samples.iterrows():
    image_path = row['img_path']
    mask_path = row['mask_path']

    image = cv2.imread(image_path)
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)

    new_image_path = f"{idx}.tif"
    new_mask_path = f"{idx}_mask.tif"

    scale_factor = random.uniform(0.9, 1.1)
    angle = random.uniform(-180, 180)

    # Apply scale and rotation in ONE affine warp onto a canvas of the original
    # size. Resizing the array itself (as before) produced images of varying
    # dimensions, which cannot be stacked into a batch with the originals.
    rows, cols = image.shape[:2]
    M = cv2.getRotationMatrix2D((cols / 2, rows / 2), angle, scale_factor)
    image = cv2.warpAffine(image, M, (cols, rows))
    # Nearest-neighbour keeps the mask strictly binary (no blended grey values).
    mask = cv2.warpAffine(mask, M, (cols, rows), flags=cv2.INTER_NEAREST)

    # NOTE(review): transpose keeps the size only for square inputs - confirm
    # that all source slices are square.
    if random.random() < 0.5:
        image = cv2.transpose(image)
        mask = cv2.transpose(mask)

    cv2.imwrite(new_image_path, image)
    cv2.imwrite(new_mask_path, mask)

    # Label the new row from the transformed mask so the 'mask' column has no
    # missing values after concatenation.
    transformed_data.append({'img_path': new_image_path,
                             'mask_path': new_mask_path,
                             'mask': int(mask.max() > 0)})

mri_df = pd.concat([mri_df, pd.DataFrame(transformed_data)], ignore_index=True)

# The original files are untouched, so only the new rows change the counts.
tumor_count_after = tumor_count_before + sum(item['mask'] for item in transformed_data)
non_tumor_count_after = len(mri_df) - tumor_count_after

print("Total images with tumors before transformations:", tumor_count_before)
print("Total images without tumors before transformations:", non_tumor_count_before)
print("Total images with tumors after transformations:", tumor_count_after)
print("Total images without tumors after transformations:", non_tumor_count_after)

# Datasets and DataLoaders

In [None]:
# Image and mask each get their own pipeline; both currently consist of a
# single ToTensor step.
image_transform = transforms.Compose([transforms.ToTensor()])

mask_transform = transforms.Compose([transforms.ToTensor()])

In [None]:
def adjust_data(img, mask):
    """Divide both arrays by 255 and binarise the mask at 0.5.

    Returns a tuple `(img, mask)` of new arrays; the inputs are not modified.
    Mask values above 0.5 (after scaling) become 1.0, all others 0.0.
    """
    scaled_img = img / 255.
    scaled_mask = mask / 255.
    binary_mask = (scaled_mask > 0.5).astype(scaled_mask.dtype)
    return (scaled_img, binary_mask)

In [None]:
class MyDataset(Dataset):
    """Dataset yielding `(image, mask)` pairs read from the paths in a DataFrame.

    Args:
        df: frame with `img_path` and `mask_path` columns. The default is the
            global `mri_df` as it was when this class was defined.
        adjust_data: callable `(image, mask) -> (image, mask)` applied to the
            raw arrays before the transforms.
        image_transform: optional callable applied to the image; its result is
            cast to float32.
        mask_transform: optional callable applied to the mask; its result is
            cast to float32.
    """

    def __init__(self, df= mri_df, 
                 adjust_data = adjust_data, 
                 image_transform=image_transform, mask_transform=mask_transform):
        self.df = df
        self.image_transform = image_transform
        self.mask_transform = mask_transform
        self.adjust_data= adjust_data

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # idx is a position (0..len-1), so index positionally; label-based
        # .loc only worked when the frame carried a fresh 0..n-1 index.
        row = self.df.iloc[idx]
        image_path = row['img_path']
        mask_path = row['mask_path']

        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread returns None instead of raising on failure.
            raise FileNotFoundError(f"Could not read image: {image_path}")
        mask = cv2.imread(mask_path)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")

        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)

        image, mask = self.adjust_data(image, mask)

        if self.image_transform:
            image = self.image_transform(image).float()

        if self.mask_transform:
            # Cast like the image so predictions and targets share one dtype
            # when they reach the loss.
            mask = self.mask_transform(mask).float()
        return image, mask

In [None]:
def prepare_loaders(df= mri_df,
                    train_num= int(mri_df.shape[0] * .6), 
                    valid_num= int(mri_df.shape[0] * .8), 
                    bs = 16,
                    shuffle_seed = None):
    """Split *df* by row position and wrap each part in a DataLoader.

    Rows `[0, train_num)` go to training, `[train_num, valid_num)` to
    validation and the rest to test. The default boundaries are computed once,
    from the global `mri_df` as it was when this function was defined.

    Args:
        df: frame accepted by `MyDataset`.
        train_num: row position where the validation part starts.
        valid_num: row position where the test part starts.
        bs: batch size for the train and validation loaders (test uses 4).
        shuffle_seed: when given, rows are shuffled reproducibly with this seed
            before splitting. With the default `None` the existing row order
            decides the split, so rows appended last all end up in the
            validation/test parts.

    Returns:
        tuple: `(train_loader, valid_loader, test_loader)`.
    """
    if shuffle_seed is not None:
        df = df.sample(frac=1, random_state=shuffle_seed)

    # os.cpu_count() may return None; DataLoader needs an int (0 = main process).
    num_workers = os.cpu_count() or 0

    train = df[:train_num].reset_index(drop=True)
    valid = df[train_num : valid_num].reset_index(drop=True)    
    test  = df[valid_num:].reset_index(drop=True)

    train_ds = MyDataset(df = train)
    valid_ds = MyDataset(df = valid)
    test_ds = MyDataset(df = test)

    train_loader = DataLoader(train_ds, batch_size = bs, num_workers = num_workers, shuffle = True)
    valid_loader = DataLoader(valid_ds, batch_size = bs, num_workers = num_workers, shuffle = False)
    test_loader = DataLoader(test_ds, batch_size = 4, num_workers = num_workers, shuffle = True)
    
    print("DataLoader Completed")
    
    return train_loader, valid_loader, test_loader

In [None]:
# Split by row position: first 65% train, next 20% validation, last 15% test.
train_loader, valid_loader, test_loader = prepare_loaders(
    df=mri_df,
    train_num=int(len(mri_df) * .65),
    valid_num=int(len(mri_df) * .85),
    bs=16,
)

In [None]:
# Pull a single batch from the training loader and show the shapes of its two
# elements (images first, masks second) as a quick sanity check.
data = next(iter(train_loader))
data[0].shape, data[1].shape

In [None]:
# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

# Attention UNet Model

In [None]:
class ConvBlock(nn.Module):
    """Two stacked 3x3 convolution stages, each followed by BatchNorm and ReLU.

    Padding of 1 with stride 1 keeps the spatial size unchanged; the first
    stage maps `in_channels` to `out_channels`, the second keeps `out_channels`.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        stages = []
        for stage_in in (in_channels, out_channels):
            stages += [
                nn.Conv2d(stage_in, out_channels, kernel_size=3, stride=1, padding=1, bias=True),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)


class UpConv(nn.Module):
    """Double the spatial size, then apply a 3x3 convolution, BatchNorm and ReLU."""

    def __init__(self, in_channels, out_channels):
        super().__init__()

        upsample = nn.Upsample(scale_factor=2)
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True)
        norm = nn.BatchNorm2d(out_channels)
        activation = nn.ReLU(inplace=True)
        self.up = nn.Sequential(upsample, conv, norm, activation)

    def forward(self, x):
        return self.up(x)


class AttentionBlock(nn.Module):
    """Attention gate that re-weights a skip connection using a gating signal."""

    def __init__(self, F_g, F_l, n_coefficients):
        """
        :param F_g: channel count of the gating signal
        :param F_l: channel count of the skip connection
        :param n_coefficients: channel count of the intermediate attention features
        """
        super().__init__()

        def project(channels):
            # 1x1 convolution into the shared attention space, then BatchNorm.
            return nn.Sequential(
                nn.Conv2d(channels, n_coefficients, kernel_size=1, stride=1, padding=0, bias=True),
                nn.BatchNorm2d(n_coefficients),
            )

        self.W_gate = project(F_g)
        self.W_x = project(F_l)

        # Collapse to a single-channel map squashed into (0, 1) by the sigmoid.
        self.psi = nn.Sequential(
            nn.Conv2d(n_coefficients, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),
        )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, gate, skip_connection):
        """
        :param gate: gating signal
        :param skip_connection: encoder activation to be re-weighted
        :return: skip_connection multiplied element-wise by the attention map
        """
        combined = self.relu(self.W_gate(gate) + self.W_x(skip_connection))
        attention = self.psi(combined)
        return skip_connection * attention


class AttentionUNet(nn.Module):
    """U-Net whose skip connections pass through attention gates.

    The encoder has five ConvBlock levels (64 to 1024 channels) separated by
    2x max-pooling. The decoder has four stages that each upsample, gate the
    matching encoder output, concatenate and fuse. The final 1x1 convolution
    emits `output_ch` channels with no activation applied.
    """

    def __init__(self, img_ch=3, output_ch=1):
        super().__init__()

        self.MaxPool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder
        self.Conv1 = ConvBlock(img_ch, 64)
        self.Conv2 = ConvBlock(64, 128)
        self.Conv3 = ConvBlock(128, 256)
        self.Conv4 = ConvBlock(256, 512)
        self.Conv5 = ConvBlock(512, 1024)

        # Decoder stage 5: 1024 -> 512
        self.Up5 = UpConv(1024, 512)
        self.Att5 = AttentionBlock(F_g=512, F_l=512, n_coefficients=256)
        self.UpConv5 = ConvBlock(1024, 512)

        # Decoder stage 4: 512 -> 256
        self.Up4 = UpConv(512, 256)
        self.Att4 = AttentionBlock(F_g=256, F_l=256, n_coefficients=128)
        self.UpConv4 = ConvBlock(512, 256)

        # Decoder stage 3: 256 -> 128
        self.Up3 = UpConv(256, 128)
        self.Att3 = AttentionBlock(F_g=128, F_l=128, n_coefficients=64)
        self.UpConv3 = ConvBlock(256, 128)

        # Decoder stage 2: 128 -> 64
        self.Up2 = UpConv(128, 64)
        self.Att2 = AttentionBlock(F_g=64, F_l=64, n_coefficients=32)
        self.UpConv2 = ConvBlock(128, 64)

        self.Conv = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)

    @staticmethod
    def _decode(upsample, attention, fuse, deeper, skip):
        """Run one decoder stage: upsample, gate the skip, concatenate, fuse."""
        gate = upsample(deeper)
        attended = attention(gate=gate, skip_connection=skip)
        # Attention-weighted skip first, upsampled features second.
        return fuse(torch.cat((attended, gate), dim=1))

    def forward(self, x):
        """
        Four 2x poolings are undone by four 2x upsamplings, so the input height
        and width need to be multiples of 16 for the concatenations to line up.
        """
        e1 = self.Conv1(x)
        e2 = self.Conv2(self.MaxPool(e1))
        e3 = self.Conv3(self.MaxPool(e2))
        e4 = self.Conv4(self.MaxPool(e3))
        e5 = self.Conv5(self.MaxPool(e4))

        d5 = self._decode(self.Up5, self.Att5, self.UpConv5, e5, e4)
        d4 = self._decode(self.Up4, self.Att4, self.UpConv4, d5, e3)
        d3 = self._decode(self.Up3, self.Att3, self.UpConv3, d4, e2)
        d2 = self._decode(self.Up2, self.Att2, self.UpConv2, d3, e1)

        return self.Conv(d2)


# Evaluation Metrics

In [None]:
def accuracy_function(target, predicted):
    """Return the fraction of elements on which two binary arrays agree.

    Args:
        target: array of 0/1 ground-truth values.
        predicted: array of 0/1 predictions, same shape as *target*.

    Returns:
        float: (true positives + true negatives) / number of elements.
    """
    target = np.asarray(target)
    predicted = np.asarray(predicted)
    true_positives = np.sum(predicted * target)
    true_negatives = np.sum((1 - predicted) * (1 - target))
    # .size counts every element; len() would only count the first dimension
    # and over-report accuracy for arrays that are not flattened.
    total_pixels = predicted.size
    acc = (true_positives + true_negatives) / total_pixels
    return float(acc)

def precision_function(target, predicted):
    """Return TP / (TP + FP) for binary arrays, as a Python float.

    A small epsilon in the denominator keeps the result finite (0.0) when
    nothing is predicted positive.
    """
    epsilon = 1e-7
    hits = np.sum(predicted * target)
    false_alarms = np.sum(predicted * (1 - target))
    denominator = hits + false_alarms + epsilon
    return (hits / denominator).item()

def recall_function(target, predicted):
    """Return TP / (TP + FN) for binary arrays, as a Python float.

    A small epsilon in the denominator keeps the result finite (0.0) when the
    target contains no positives.
    """
    epsilon = 1e-7
    hits = np.sum(predicted * target)
    misses = np.sum((1 - predicted) * target)
    denominator = hits + misses + epsilon
    return (hits / denominator).item()

def dice_coeff_binary(y_true, y_pred):
    """Compute the Dice coefficient for binary segmentation.

    Args:
        y_true: array of 0/1 ground-truth values.
        y_pred: array of 0/1 predictions, same shape as *y_true*.

    Returns:
        float: (2 * |intersection| + eps) / (|y_pred| + |y_true| + eps); the
        epsilon makes two empty masks score 1.0.
    """
    eps = 1e-9
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    # Element-wise product + sum works for any shape; np.dot is a matrix
    # product for 2-D input and only equals the overlap for flat vectors.
    inter = np.sum(y_pred * y_true)
    union = np.sum(y_pred) + np.sum(y_true)
    return float((2 * inter + eps) / (union + eps))

def specificity_function(y_true, y_pred):
    """Compute specificity, TN / (all actual negatives), for binary arrays.

    Args:
        y_true: array of 0/1 ground-truth values.
        y_pred: array of 0/1 predictions, same shape as *y_true*.

    Returns:
        float: the specificity. As in the precision/recall helpers, a small
        epsilon in the denominator avoids 0/0, so the result is 0.0 (not NaN)
        when *y_true* contains no negatives.
    """
    true_negatives = np.sum((1 - y_pred) * (1 - y_true))
    actual_negatives = np.sum(1 - y_true)
    specificity = true_negatives / (actual_negatives + 1e-7)
    return float(specificity)

def intersection_over_union(y_true, y_pred):
    """Return intersection / union of two binary arrays as a Python float.

    The same epsilon is added to numerator and denominator, so two empty
    masks score 1.0.
    """
    smooth = 1e-9
    overlap = np.sum(y_pred * y_true)
    combined = np.sum(y_pred) + np.sum(y_true) - overlap
    return ((overlap + smooth) / (combined + smooth)).item()


# Model Training

In [None]:
# Train the Attention U-Net, validate after every epoch and keep the weights
# with the lowest mean validation loss.
model = AttentionUNet(3, 1)
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

# The model's last layer has no activation, so its outputs are logits;
# BCEWithLogitsLoss applies the sigmoid itself.
criterion = torch.nn.BCEWithLogitsLoss()

num_epochs = 20
train_losses = []
val_losses = []
train_dices = []
val_dices = []

# Probability cut-off used to binarise predictions for the metrics.
threshold = 0.5


def _batch_metrics(target, predicted):
    """Return (accuracy, precision, recall, dice, specificity, iou) for flat 0/1 arrays."""
    return (
        accuracy_function(target, predicted),
        precision_function(target, predicted),
        recall_function(target, predicted),
        dice_coeff_binary(target, predicted),
        specificity_function(target, predicted),
        intersection_over_union(target, predicted),
    )


best_val_loss = float('inf') 
for epoch in range(num_epochs):
    model.train()
    total_train_loss = 0.0
    total_accuracy = 0.0
    total_precision = 0.0
    total_recall = 0.0
    total_dice = 0.0
    total_specificity = 0.0
    total_iou = 0.0

    for i, (images, masks) in enumerate(train_loader):
        images = images.to(device)
        # Cast so logits and targets share one dtype inside the loss.
        masks = masks.to(device).float()

        optimizer.zero_grad()
        outputs = model(images)

        loss = criterion(outputs.view(-1), masks.view(-1))
        print(f"Epoch {epoch+1}/{num_epochs}, Batch {i+1}, Loss: {loss.item()}")

        loss.backward()
        optimizer.step()
        total_train_loss += loss.item()

        # Threshold probabilities, not raw logits: logit > 0.5 would correspond
        # to a probability cut-off of about 0.62 and bias every metric.
        outputs_thresholded = (torch.sigmoid(outputs) > threshold).float()
        outputs_thresholded = outputs_thresholded.cpu().detach().numpy().reshape(-1)
        masks = masks.cpu().detach().numpy().reshape(-1)

        accuracy, precision, recall, dice_score, specificity, iou = _batch_metrics(masks, outputs_thresholded)

        total_accuracy += accuracy
        total_precision += precision
        total_recall += recall
        total_dice += dice_score
        total_specificity += specificity
        total_iou += iou

        print(f"Accuracy: {accuracy}, Precision: {precision}, Recall: {recall}, Dice Score: {dice_score}, Specificity: {specificity}, IOU: {iou}")

    avg_train_loss = total_train_loss / len(train_loader)
    avg_accuracy = total_accuracy / len(train_loader)
    avg_precision = total_precision / len(train_loader)
    avg_recall = total_recall / len(train_loader)
    avg_dice = total_dice / len(train_loader)
    avg_specificity = total_specificity / len(train_loader)
    avg_iou = total_iou / len(train_loader)

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_train_loss}, Average Accuracy: {avg_accuracy}, Average Precision: {avg_precision}, Average Recall: {avg_recall}, Average Dice: {avg_dice}, Average Specificity: {avg_specificity}, Average IOU: {avg_iou}")

    # Record the epoch mean (comparable with the validation mean) rather than
    # whatever the last batch happened to score. It is stored as a detached
    # CPU tensor because the plotting cell calls tensor methods on each entry.
    train_losses.append(torch.tensor(avg_train_loss))
    train_dices.append(avg_dice)
    
    model.eval()
    total_val_accuracy = 0.0
    total_val_precision = 0.0
    total_val_recall = 0.0
    total_val_dice = 0.0
    total_val_specificity = 0.0
    total_val_loss = 0.0
    total_val_iou = 0.0

    with torch.no_grad():
        for i, (images, masks) in enumerate(valid_loader):
            images = images.to(device)
            masks = masks.to(device).float()

            outputs = model(images)

            loss = criterion(outputs.view(-1), masks.view(-1))
            total_val_loss += loss.item()

            outputs_thresholded = (torch.sigmoid(outputs) > threshold).float()
            outputs_thresholded = outputs_thresholded.cpu().detach().numpy().reshape(-1)
            masks = masks.cpu().detach().numpy().reshape(-1)

            accuracy, precision, recall, dice_score, specificity, iou = _batch_metrics(masks, outputs_thresholded)

            total_val_accuracy += accuracy
            total_val_precision += precision
            total_val_recall += recall
            total_val_dice += dice_score
            total_val_specificity += specificity
            total_val_iou += iou

            print(f"Epoch {epoch+1}/{num_epochs}, Validation Batch {i+1}/{len(valid_loader)}, Loss: {loss.item()}")

        avg_val_accuracy = total_val_accuracy / len(valid_loader)
        avg_val_precision = total_val_precision / len(valid_loader)
        avg_val_recall = total_val_recall / len(valid_loader)
        avg_val_dice = total_val_dice / len(valid_loader)
        avg_val_specificity = total_val_specificity / len(valid_loader)
        avg_val_loss = total_val_loss / len(valid_loader)
        avg_val_iou = total_val_iou / len(valid_loader)
        
        # Checkpoint only when the mean validation loss improves.
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), '/kaggle/working/model_best.ckpt')
            print(f"Epoch {epoch+1}/{num_epochs}, Validation Loss improved. Model saved.")

        val_losses.append(avg_val_loss)
        val_dices.append(avg_val_dice)

        print(f"Epoch {epoch+1}/{num_epochs}, Validation Loss: {avg_val_loss}, Average Accuracy: {avg_val_accuracy}, Average Precision: {avg_val_precision}, Average Recall: {avg_val_recall}, Average Dice: {avg_val_dice}, Average Specificity: {avg_val_specificity}, Average IOU: {avg_val_iou}")

# Visualizations

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# float() accepts one-element tensors (on any device) as well as plain
# numbers, so the history can hold either.
train_losses_np = np.array([float(loss) for loss in train_losses])
val_losses_np = np.array([float(loss) for loss in val_losses])

# Size the axes from what was actually recorded: if training stopped before
# num_epochs, plotting against range(1, num_epochs + 1) raises a length error.
epochs = range(1, max(len(train_losses_np), len(val_losses_np)) + 1)

plt.figure(figsize=(10, 5))
plt.plot(range(1, len(train_losses_np) + 1), train_losses_np, label='Training Loss')
plt.plot(range(1, len(val_losses_np) + 1), val_losses_np, label='Validation Loss')
plt.title('Training and Validation Loss Curves')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.xticks(epochs)
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()

In [None]:
# Size the axes from the recorded history so a run that stopped before
# num_epochs still plots instead of raising a length mismatch.
epochs = range(1, max(len(train_dices), len(val_dices)) + 1)

plt.figure(figsize=(10, 5))
plt.plot(range(1, len(train_dices) + 1), train_dices, label='Training DICE')
plt.plot(range(1, len(val_dices) + 1), val_dices, label='Validation DICE')
plt.title('Training and Validation DICE Scores')
plt.xlabel('Epochs')
plt.ylabel('DICE Score')
plt.xticks(epochs)
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()

In [None]:
# Location of the checkpoint written by the training loop whenever the mean
# validation loss improved.
check_path = '/kaggle/working/model_best.ckpt'
check_path

In [None]:
# Restore the best weights. map_location remaps the stored tensors onto the
# device in use, so a checkpoint saved on a GPU also loads on a CPU-only runtime.
model_state_dict = torch.load(check_path, map_location=device)

model.load_state_dict(model_state_dict)

# Model Testing

In [None]:
    # Evaluate the restored model on the held-out test loader and report the
    # per-batch averages of the loss and all metrics.
    model.eval()
    total_test_accuracy = 0.0
    total_test_precision = 0.0
    total_test_recall = 0.0
    total_test_dice = 0.0
    total_test_specificity = 0.0
    total_test_loss = 0.0
    total_test_iou = 0.0

    with torch.no_grad():
        for i, (images, masks) in enumerate(test_loader):
            images = images.to(device)
            # Cast so logits and targets share one dtype inside the loss.
            masks = masks.to(device).float()

            outputs = model(images)
            threshold = 0.5
            # The model emits logits (it is trained with BCEWithLogitsLoss), so
            # convert to probabilities before applying the 0.5 cut-off.
            outputs_thresholded = (torch.sigmoid(outputs) > threshold).float()

            outputs_flat = outputs.view(-1)
            masks_flat = masks.view(-1)

            loss = criterion(outputs_flat, masks_flat)
            total_test_loss += loss.item()

            outputs_thresholded = outputs_thresholded.cpu().detach().numpy().reshape(-1)
            masks = masks.cpu().detach().numpy().reshape(-1)

            accuracy = accuracy_function(masks, outputs_thresholded)
            precision = precision_function(masks, outputs_thresholded)
            recall = recall_function(masks, outputs_thresholded)
            dice_score = dice_coeff_binary(masks, outputs_thresholded)
            specificity = specificity_function(masks, outputs_thresholded)
            iou = intersection_over_union(masks, outputs_thresholded)

            total_test_accuracy += accuracy
            total_test_precision += precision
            total_test_recall += recall
            total_test_dice += dice_score
            total_test_specificity += specificity
            # Accumulate into the test counter; adding to the training loop's
            # total_iou mixed the last training epoch into the test IoU.
            total_test_iou += iou

        avg_test_accuracy = total_test_accuracy / len(test_loader)
        avg_test_precision = total_test_precision / len(test_loader)
        avg_test_recall = total_test_recall / len(test_loader)
        avg_test_dice = total_test_dice / len(test_loader)
        avg_test_specificity = total_test_specificity / len(test_loader)
        avg_test_loss = total_test_loss / len(test_loader)
        avg_test_iou = total_test_iou / len(test_loader)
        
        print(f"Test Loss: {avg_test_loss}, Average Accuracy: {avg_test_accuracy}, Average Precision: {avg_test_precision}, Average Recall: {avg_test_recall}, Average Dice: {avg_test_dice}, Average Specificity: {avg_test_specificity}, Average IOU: {avg_test_iou}")

# Sample Segmentations

In [None]:
import matplotlib.pyplot as plt

# Show image / ground truth / prediction side by side for up to max_samples
# test samples.
model.eval()

max_samples = 50
shown = 0
threshold = 0.5

with torch.no_grad():
    for i, (images, masks) in enumerate(test_loader):
        images = images.to(device)
        masks = masks.to(device)

        outputs = model(images)
        # The model emits logits; convert to probabilities before thresholding.
        outputs_thresholded = (torch.sigmoid(outputs) > threshold).float()

        # Stop inside the batch once the cap is reached (checking only after a
        # whole batch let the total overshoot the limit).
        for j in range(min(images.shape[0], max_samples - shown)):
            plt.figure(figsize=(15, 5))

            plt.subplot(1, 3, 1)
            # Channels-first tensor -> channels-last array for imshow.
            plt.imshow(images[j].cpu().permute(1, 2, 0).numpy())
            plt.title("Image")
            plt.axis("off")

            plt.subplot(1, 3, 2)
            plt.imshow(masks[j].cpu().squeeze().numpy(), cmap='gray')
            plt.title("Ground truth")
            plt.axis("off")

            plt.subplot(1, 3, 3)
            plt.imshow(outputs_thresholded[j].cpu().squeeze().numpy(), cmap='gray')
            plt.title("Predicted mask")
            plt.axis("off")

            plt.show()
            shown += 1

        if shown >= max_samples:
            break