# Imports

In [None]:
%%capture
!pip install segmentation-models-pytorch
!pip install torchinfo

In [None]:
# Data handling
import pandas as pd
import numpy as np

# Data visualization
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
import cv2

# Torch
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import segmentation_models_pytorch as smp
from torchinfo import summary

# os
import os

# Path
from pathlib import Path

# tqdm
from tqdm.auto import tqdm

from glob import iglob, glob
from itertools import chain

# warnings
import warnings
warnings.filterwarnings("ignore")

import random as rnd

import shutil

In [None]:
# Mini-batch size passed to the DataLoaders created later in the notebook.
BATCH_SIZE = 64
# Number of DataLoader worker processes.
NUM_WORKERS = os.cpu_count()

# Number of samples that will be used to train our model
K_SAMPLES_TRAIN = 5

NUM_CLASSES = 2

# CUDA: use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Dataloaders

## Kaggle API Setup

In [None]:
!pip install -q kaggle

In [None]:
# Opens the Colab file-upload dialog; the next cell expects a file named kaggle.json.
from google.colab import files
files.upload()

In [None]:
# Move the uploaded kaggle.json into ~/.kaggle and restrict it to owner read/write.
!mkdir -p ~/.kaggle
!mv kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

## Download GlaS

In [None]:
# Download the GlaS archive from Kaggle into /content/glas.
!kaggle datasets download -d sani84/glasmiccai2015-gland-segmentation -p /content/glas

In [None]:
# Extract the archive in place, next to the downloaded zip file.
!unzip /content/glas/glasmiccai2015-gland-segmentation.zip -d /content/glas/

In [None]:
# The zip file is not used again after extraction.
!rm /content/glas/glasmiccai2015-gland-segmentation.zip

## Defining the training split for meta-test

In [None]:
# Delete Grade.csv from the dataset directory; the cells below only work with the *.bmp files.
!rm /content/glas/Warwick_QU_Dataset/Grade.csv

In [None]:
def get_glas_train(data_dir="/content/glas/Warwick_QU_Dataset", k=None, seed=42):
    """Pick a reproducible random subset of image/annotation pairs for training.

    Args:
        data_dir: Directory holding the ``*.bmp`` images together with their
            ``*_anno.bmp`` annotation files.
        k: Number of pairs to sample. ``None`` falls back to the module-level
            ``K_SAMPLES_TRAIN``.
        seed: Value passed to ``rnd.seed`` right before sampling.

    Returns:
        A list of path strings: the sampled images first, followed by the
        sampled annotation files.

    Raises:
        ValueError: If ``data_dir`` contains fewer than ``k`` annotation files.
    """
    if k is None:
        k = K_SAMPLES_TRAIN

    # glob() yields entries in a filesystem-dependent order. Sorting first makes
    # the seeded sample identical across machines and across re-runs.
    all_files = sorted(glob(os.path.join(data_dir, "*.bmp")))
    all_labels = sorted(glob(os.path.join(data_dir, "*anno.bmp")))

    if k > len(all_labels):
        raise ValueError(
            f"Requested {k} training samples but only {len(all_labels)} "
            f"annotation files were found in {data_dir!r}"
        )

    rnd.seed(seed)
    train_labels = rnd.sample(all_labels, k=k)

    # An image belongs to the training set when its annotation was sampled.
    # Annotation files themselves never match: "<x>_anno.bmp" maps to
    # "<x>_anno_anno.bmp", which does not exist.
    sampled = set(train_labels)
    train_files = [p for p in all_files if p.replace(".bmp", "_anno.bmp") in sampled]

    return train_files + train_labels

In [None]:
!mkdir /content/glas/Warwick_QU_Dataset/train
!mkdir /content/glas/Warwick_QU_Dataset/test

In [None]:
# Route every bitmap in the dataset root into train/ or test/, depending on
# whether it was picked by get_glas_train().
train_files = get_glas_train()

for f in iglob("/content/glas/Warwick_QU_Dataset/*.bmp"):
    split_dir = (
        "/content/glas/Warwick_QU_Dataset/train"
        if f in train_files
        else "/content/glas/Warwick_QU_Dataset/test"
    )
    shutil.move(f, f.replace("/content/glas/Warwick_QU_Dataset", split_dir))

# Data preparation

In [None]:
# Remove any images/ and labels/ folders left in the split directories so the
# following cells start from an empty layout.
!rm -rf /content/glas/Warwick_QU_Dataset/train/images
!rm -rf /content/glas/Warwick_QU_Dataset/train/labels

!rm -rf /content/glas/Warwick_QU_Dataset/test/images
!rm -rf /content/glas/Warwick_QU_Dataset/test/labels

In [None]:
# Create the images/ and labels/ sub-folders for both splits.
!mkdir -p /content/glas/Warwick_QU_Dataset/train/images
!mkdir -p /content/glas/Warwick_QU_Dataset/train/labels

!mkdir -p /content/glas/Warwick_QU_Dataset/test/images
!mkdir -p /content/glas/Warwick_QU_Dataset/test/labels

In [None]:
# For each split, file the annotation bitmaps first; the broader "*.bmp"
# pattern then only matches the raw images that are still left in the folder.
for split in ("train", "test"):
    split_root = Path(f"/content/glas/Warwick_QU_Dataset/{split}/")
    moves = (
        ("*anno.bmp", f"/content/glas/Warwick_QU_Dataset/{split}/labels/"),
        ("*.bmp", f"/content/glas/Warwick_QU_Dataset/{split}/images"),
    )
    for pattern, target_dir in moves:
        for p in split_root.glob(pattern):
            shutil.move(p, target_dir)

In [None]:
# Drop the "_anno" marker so that every label file carries the same name as
# its image.
for labels_dir in ("/content/glas/Warwick_QU_Dataset/train/labels/",
                   "/content/glas/Warwick_QU_Dataset/test/labels/"):
    for p in Path(labels_dir).glob("*.bmp"):
        renamed = p.name.replace("_anno", "")
        os.rename(p, p.parent / renamed)

In [None]:
# Collapse every training annotation to a binary {0, 1} mask, overwriting the file.
for p in Path("/content/glas/Warwick_QU_Dataset/train/labels/").glob("*.bmp"):
    img = cv2.imread(str(p), cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise FileNotFoundError(f"Could not read mask: {p}")

    # Build the mask as uint8 explicitly instead of handing cv2.imwrite a
    # float64 array and relying on an implicit conversion.
    bin_mask = (img != 0).astype(np.uint8)

    if not cv2.imwrite(str(p), bin_mask):
        raise OSError(f"Could not write mask: {p}")

In [None]:
# Collapse every test annotation to a binary {0, 1} mask, overwriting the file.
for p in Path("/content/glas/Warwick_QU_Dataset/test/labels/").glob("*.bmp"):
    img = cv2.imread(str(p), cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise FileNotFoundError(f"Could not read mask: {p}")

    # Build the mask as uint8 explicitly instead of handing cv2.imwrite a
    # float64 array and relying on an implicit conversion.
    bin_mask = (img != 0).astype(np.uint8)

    if not cv2.imwrite(str(p), bin_mask):
        raise OSError(f"Could not write mask: {p}")

# Utils

## Data

In [None]:
def image_mask_path(image_path: str, mask_path: str):
    """Collect the ``*.bmp`` files of an image folder and of a mask folder.

    Args:
        image_path: Directory containing the image bitmaps.
        mask_path: Directory containing the mask bitmaps.

    Returns:
        A tuple ``(images, masks)`` of two lists of ``Path`` objects, each
        sorted independently.
    """
    images = sorted(Path(image_path).glob("*.bmp"))
    masks = sorted(Path(mask_path).glob("*.bmp"))
    return images, masks

In [None]:
def count_unique(mask_path_list):
    """Return the sorted set of pixel values found across a list of mask files.

    Args:
        mask_path_list: Iterable of paths to mask images. Each one is read as
            a single-channel grayscale image.

    Returns:
        1-D ``numpy`` array with the distinct pixel values over all masks.
        An empty ``uint8`` array is returned when ``mask_path_list`` is empty.

    Raises:
        FileNotFoundError: If a mask cannot be read.
    """
    per_mask_values = []

    for mask_path in mask_path_list:
        sample = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
        if sample is None:
            # cv2.imread returns None on failure. Without this check
            # np.unique(None) would silently yield array([None]).
            raise FileNotFoundError(f"Could not read mask: {mask_path}")
        per_mask_values.append(np.unique(sample))

    if not per_mask_values:
        # np.concatenate raises on an empty sequence.
        return np.array([], dtype=np.uint8)

    return np.unique(np.concatenate(per_mask_values))

In [None]:
class CustomImageMaskDataset(Dataset):
    """Dataset yielding ``(image, mask)`` pairs from a table of file paths.

    Column 0 of ``data`` is read as the image path and column 1 as the mask
    path. The torch RNG state is captured before the image transforms and
    restored before the mask transforms, so both pipelines start from the same
    random state.
    """

    def __init__(self, data:pd.DataFrame, image_transforms, mask_transforms):
        self.data = data
        self.image_transforms = image_transforms
        self.mask_transforms = mask_transforms

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        image_path, mask_path = self.data.iloc[idx, 0], self.data.iloc[idx, 1]

        rgb_image = Image.open(image_path).convert("RGB")
        raw_mask = Image.open(mask_path)

        # Snapshot the RNG, transform the image, rewind the RNG, transform the
        # mask: both transform pipelines start from the same random state.
        rng_snapshot = torch.get_rng_state()
        image = self.image_transforms(rgb_image)
        torch.set_rng_state(rng_snapshot)
        mask = self.mask_transforms(raw_mask)

        return image, mask

In [None]:
class CustomTestDataset(Dataset):
    """Image-only dataset: column 0 of ``data`` is read as the image path."""

    def __init__(self, data:pd.DataFrame, image_transforms):
        self.data = data
        self.image_transforms = image_transforms

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Open the file, force three channels, then apply the transform pipeline.
        rgb_image = Image.open(self.data.iloc[idx, 0]).convert("RGB")
        return self.image_transforms(rgb_image)

## Train

In [None]:
def train_step(model:torch.nn.Module, dataloader:torch.utils.data.DataLoader,
               loss_fn:torch.nn.Module, optimizer:torch.optim.Optimizer):
    """Run one training epoch and report the mean loss and mean Dice (micro F1).

    Args:
        model: Network producing per-class logits with classes on dim 1.
        dataloader: Yields ``(X, y)`` batches. ``y`` is assumed to carry a
            singleton channel axis at dim 1 (as produced by ``PILToTensor``)
            -- TODO confirm for other mask pipelines.
        loss_fn: Loss called as ``loss_fn(logits, target)``.
        optimizer: Optimizer stepped once per batch.

    Returns:
        Tuple ``(train_loss, train_dice)`` of Python floats, each averaged
        over the number of batches.
    """
    model.train()

    train_loss = 0.
    train_dice = 0.

    for X, y in dataloader:
        X = X.to(device = DEVICE, dtype = torch.float32)
        # squeeze(1) removes only the channel axis. A bare squeeze() would also
        # drop the batch axis whenever a batch holds a single sample.
        y = y.to(device = DEVICE, dtype = torch.long).squeeze(1)

        optimizer.zero_grad()
        logit_mask = model(X)
        loss = loss_fn(logit_mask, y)
        train_loss += loss.item()

        loss.backward()
        optimizer.step()

        # argmax over the logits picks the same class as argmax over softmax.
        pred_mask = logit_mask.argmax(dim = 1)

        # Take the class count from the logits instead of a hard-coded number,
        # so the statistics always match what the model actually outputs.
        tp,fp,fn,tn = smp.metrics.get_stats(output = pred_mask.detach().cpu().long(),
                                            target = y.cpu().long(),
                                            mode = "multiclass",
                                            num_classes = logit_mask.shape[1])

        train_dice += smp.metrics.f1_score(tp, fp, fn, tn, reduction = "micro").item()

    train_loss = train_loss / len(dataloader)
    train_dice = train_dice / len(dataloader)

    return train_loss, train_dice

In [None]:
def train(model:torch.nn.Module, train_dataloader:torch.utils.data.DataLoader,
          loss_fn:torch.nn.Module,
          optimizer:torch.optim.Optimizer, epochs:int = 10):
    """Train ``model`` for ``epochs`` epochs and collect per-epoch metrics.

    Each epoch is delegated to ``train_step``; its loss and Dice are printed
    and appended to the history.

    Returns:
        Dict with the keys ``'train_loss'`` and ``'train_dice'``, each mapping
        to a list with one entry per epoch.
    """
    history = {'train_loss':[], 'train_dice':[]}

    for epoch_idx in tqdm(range(epochs)):
        epoch_loss, epoch_dice = train_step(model = model,
                                            dataloader = train_dataloader,
                                            loss_fn = loss_fn,
                                            optimizer = optimizer)

        print(f'Epoch: {epoch_idx + 1} | ',
              f'Train Loss: {epoch_loss:.4f} | ',
              f'Train Dice: {epoch_dice:.4f}')

        history['train_loss'].append(epoch_loss)
        history['train_dice'].append(epoch_dice)

    return history

## Prediction

In [None]:
def predictions_mask(model, test_dataloader: torch.utils.data.DataLoader):
    """Predict a class-index mask for every image served by ``test_dataloader``.

    The model is switched to eval mode and run under ``torch.inference_mode``.
    Per-batch predictions (argmax over the class dimension) are moved to the
    CPU and concatenated along the batch axis.

    Returns:
        CPU tensor holding the predicted class index for every pixel.
    """
    model.eval()

    batch_predictions = []

    with torch.inference_mode():
        for X in test_dataloader:
            X = X.to(device = DEVICE, dtype = torch.float32)
            class_probs = model(X).softmax(dim = 1)
            batch_predictions.append(class_probs.argmax(dim = 1).detach().cpu())

    return torch.cat(batch_predictions)

# Augmentations

In [None]:
# Per-channel normalisation statistics applied to the input images.
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

# NOTE: CustomImageMaskDataset replays the same torch RNG state for the image
# and the mask pipeline, so both training pipelines must list the same random
# operations in the same order.
image_transforms = transforms.Compose([
                                       transforms.RandomHorizontalFlip(),
                                       transforms.RandomVerticalFlip(),
                                       transforms.RandomResizedCrop((224, 224), antialias=True),
                                       transforms.ToTensor(),
                                       transforms.Normalize(mean = MEAN, std = STD),
                                       ])

# Masks hold class indices, not intensities. Nearest-neighbour resampling keeps
# every output pixel equal to an existing label instead of blending neighbours.
mask_transforms = transforms.Compose([
                                      transforms.RandomHorizontalFlip(),
                                      transforms.RandomVerticalFlip(),
                                      transforms.RandomResizedCrop(
                                          (224, 224),
                                          interpolation=transforms.InterpolationMode.NEAREST),
                                      transforms.PILToTensor(),
                                    ])

image_transforms_test = transforms.Compose([
                                       transforms.Resize((224, 224), antialias=True),
                                       transforms.ToTensor(),
                                       transforms.Normalize(mean = MEAN, std = STD),
                                       ])

mask_transforms_test = transforms.Compose([
                                      transforms.Resize(
                                          (224, 224),
                                          interpolation=transforms.InterpolationMode.NEAREST),
                                      transforms.PILToTensor(),
                                    ])

# Data loading

In [None]:
# Locate the training images and their masks on disk.
image_path_train = "/content/glas/Warwick_QU_Dataset/train/images"
mask_path_train = "/content/glas/Warwick_QU_Dataset/train/labels"

train_paths = image_mask_path(image_path_train, mask_path_train)
IMAGE_PATH_LIST_TRAIN, MASK_PATH_LIST_TRAIN = train_paths

n_images_train = len(IMAGE_PATH_LIST_TRAIN)
n_masks_train = len(MASK_PATH_LIST_TRAIN)
print(f'Total Images Train: {n_images_train}')
print(f'Total Masks Train: {n_masks_train}')

In [None]:
# Show which pixel values occur in the training masks.
unique_values_train = count_unique(MASK_PATH_LIST_TRAIN)
print("Unique values Train:", unique_values_train, sep="\n")

# Preprocessing

In [None]:
# One row per sample: image path in column 0, mask path in column 1.
data_train = pd.DataFrame(dict(Image=IMAGE_PATH_LIST_TRAIN,
                               Mask=MASK_PATH_LIST_TRAIN))

In [None]:
# Training dataset with the random augmentation pipelines.
train_dataset = CustomImageMaskDataset(data = data_train,
                                       image_transforms = image_transforms,
                                       mask_transforms = mask_transforms)

In [None]:
# Shuffled mini-batches for training.
train_dataloader = DataLoader(
    dataset = train_dataset,
    batch_size = BATCH_SIZE,
    shuffle = True,
    num_workers = NUM_WORKERS,
)

In [None]:
# Pull one batch from the training loader and display the tensor shapes.
first_batch = next(iter(train_dataloader))
batch_images, batch_masks = first_batch

batch_images.shape, batch_masks.shape

# Model

## Load from checkpoint

In [None]:
# Mount Google Drive at /content/drive; the checkpoint loaded below is read from there.
from google.colab import drive
drive.mount("/content/drive")

In [None]:
# SECURITY: this checkpoint is a fully pickled model object, and unpickling can
# execute arbitrary code. Only load files from a source you trust.
# weights_only=False is spelled out because newer PyTorch releases default to
# weights_only=True, which refuses to unpickle a complete model object.
model = torch.load("/content/drive/MyDrive/model.pt",
                   map_location=torch.device('cpu'),
                   weights_only=False)

## Freeze layers

In [None]:
# Assigning to `out_channels` only edits an attribute: the layer's weight
# tensor keeps its original shape, so the model would keep emitting the old
# number of channels. Replace the layer with a freshly initialised one instead.
# NOTE(review): assumes segmentation_head[0] is an nn.Conv2d inside an
# nn.Sequential-like container -- confirm against the checkpoint.
old_head_conv = model.segmentation_head[0]

# Guarded so that re-running the cell does not re-initialise an adapted head.
if old_head_conv.weight.shape[0] != NUM_CLASSES:
    model.segmentation_head[0] = nn.Conv2d(
        in_channels = old_head_conv.in_channels,
        out_channels = NUM_CLASSES,
        kernel_size = old_head_conv.kernel_size,
        stride = old_head_conv.stride,
        padding = old_head_conv.padding,
        dilation = old_head_conv.dilation,
        bias = old_head_conv.bias is not None,
    )

In [None]:
# Exclude every encoder weight from gradient computation.
for encoder_param in model.encoder.parameters():
    encoder_param.requires_grad_(False)

In [None]:
# Print the model summary to verify that the encoder layers are frozen.
# The batch dimension comes from BATCH_SIZE rather than a repeated literal.
summary(model = model,
        input_size = [BATCH_SIZE, 3, 224, 224],
        col_width = 15,
        col_names = ['input_size', 'output_size', 'num_params', 'trainable'],
        row_settings = ['var_names'])

## Train

In [None]:
# Pixel-wise cross-entropy, optimised with Adam plus L2 weight decay.
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(
    params = model.parameters(),
    lr = 1e-3,
    weight_decay = 1e-4,
)

In [None]:
# Training!!!

SEED = 42
EPOCHS = 10

# Seed the CUDA and the default torch generators before training starts.
torch.cuda.manual_seed(SEED)
torch.manual_seed(SEED)

RESULTS = train(model = model.to(device = DEVICE),
                train_dataloader = train_dataloader,
                loss_fn = loss_fn,
                optimizer = optimizer,
                epochs = EPOCHS)

## Save

In [None]:
!mkdir /content/checkpoints

In [None]:
# Persist only the weights (state dict), not the pickled model object.
checkpoint_path = "/content/checkpoints/model.pth"
torch.save(model.state_dict(), checkpoint_path)

# Evaluation

In [None]:
# Locate the held-out images and their masks on disk.
image_path_val = "/content/glas/Warwick_QU_Dataset/test/images"
mask_path_val = "/content/glas/Warwick_QU_Dataset/test/labels"

val_paths = image_mask_path(image_path_val, mask_path_val)
IMAGE_PATH_LIST_VAL, MASK_PATH_LIST_VAL = val_paths

n_images_val = len(IMAGE_PATH_LIST_VAL)
n_masks_val = len(MASK_PATH_LIST_VAL)
print(f'Total Images Val: {n_images_val}')
print(f'Total Masks Val: {n_masks_val}')

In [None]:
# Evaluation data uses the deterministic resize-only pipelines.
data_val = pd.DataFrame(dict(Image=IMAGE_PATH_LIST_VAL,
                             Mask=MASK_PATH_LIST_VAL))
val_dataset = CustomImageMaskDataset(data = data_val,
                                     image_transforms = image_transforms_test,
                                     mask_transforms = mask_transforms_test)
val_dataloader = DataLoader(
    dataset = val_dataset,
    batch_size = BATCH_SIZE,
    shuffle = True,
    num_workers = NUM_WORKERS,
)

In [None]:
# Number of batches that will be used for evaluation
BATCH_TO_TEST = 2

In [None]:
# train_step() leaves the model in training mode. Switch to eval mode so that
# any dropout / batch-norm layers behave deterministically during evaluation.
model.eval()

test_dice = 0.
batches_evaluated = 0

with torch.inference_mode():
    for batch, (X, y) in tqdm(enumerate(val_dataloader), total=BATCH_TO_TEST):
        if batch >= BATCH_TO_TEST:
            break

        X = X.to(device = DEVICE, dtype = torch.float32)
        # squeeze(1) removes only the channel axis. A bare squeeze() would also
        # drop the batch axis whenever a batch holds a single sample.
        y = y.to(device = DEVICE, dtype = torch.long).squeeze(1)

        logit_mask = model(X)

        # argmax over the logits picks the same class as argmax over softmax.
        pred_mask = logit_mask.argmax(dim = 1)

        # Take the class count from the logits instead of a hard-coded number.
        tp, fp, fn, tn = smp.metrics.get_stats(output = pred_mask.detach().cpu().long(),
                                                target = y.cpu().long(),
                                                mode = "multiclass",
                                                num_classes = logit_mask.shape[1])

        test_dice += smp.metrics.f1_score(tp, fp, fn, tn, reduction = "micro").item()
        batches_evaluated += 1

# Average over the batches that were actually evaluated: the loader may hold
# fewer than BATCH_TO_TEST batches.
test_dice /= max(batches_evaluated, 1)

In [None]:
test_dice