In [None]:
from utils.crop import postprocess_mask

import os
from PIL import Image
from pathlib import Path
import numpy as np
import pandas as pd
import cv2
import random
from tqdm import tqdm
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.functional as TF
import torch.nn.functional as F
import torch.optim.lr_scheduler as lr_scheduler
# from torchmetrics import Dice

# !pip install git+https://github.com/qubvel/segmentation_models.pytorch
import segmentation_models_pytorch as smp

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Paths
# NOTE(review): absolute, machine-specific path — consider making ROOT_DIR configurable.
ROOT_DIR = Path(r'E:\Prut\cxr\data\segmentation\data\Lung Segmentation')
TRAIN_IMG_DIR = ROOT_DIR / 'CXR_png'
TRAIN_MASKS_DIR = ROOT_DIR / 'masks'

# Hyperparameters
SEED = 1234
IMG_SIZE = 224         # images and masks are resized to IMG_SIZE x IMG_SIZE
TEST_SIZE = 0.2        # fraction of paired files held out as the test split
BATCH_SIZE = 32

CONTRAST_FACTOR = 1.8  # contrast factor used in the image preprocessing pipeline

EPOCHS = 15
LEARNING_RATE = 1e-2
PATIENCE = 2           # early-stopping patience, in epochs
MIN_DELTA = 1e-3       # early-stopping tolerance on the test loss
GAMMA = 0.5            # ExponentialLR decay factor, applied once per epoch



# Report how many files each folder holds. `directory` is used instead of `dir`
# so the builtin `dir()` is not shadowed for the rest of the kernel session.
for directory in [TRAIN_IMG_DIR, TRAIN_MASKS_DIR]:
    print('Length of', directory.stem, len(os.listdir(directory)))

# Only file names present in both folders can form an image/mask pair.
TRAIN_FILE_NAMES = sorted(set(os.listdir(TRAIN_IMG_DIR)) & set(os.listdir(TRAIN_MASKS_DIR)))
print('Files to be used:', len(TRAIN_FILE_NAMES))

In [None]:
# Preview one sample side by side: image | mask | 70/30 blend of image and mask.
# cv2.imread loads 3-channel BGR by default; for grayscale files all channels are equal.
path = TRAIN_FILE_NAMES[6]
img, mask = (
    cv2.resize(cv2.imread(str(folder / path)), (IMG_SIZE, IMG_SIZE))
    for folder in (TRAIN_IMG_DIR, TRAIN_MASKS_DIR)
)
added = cv2.addWeighted(img, 0.7, mask, 0.3, 0)
stacked = np.hstack((img, mask, added))
Image.fromarray(stacked)

In [None]:
# Reproducible random hold-out split of the paired file names
# (seeds the global `random` module, then samples TEST_SIZE of the files for testing).
random.seed(SEED)
n_test = round(len(TRAIN_FILE_NAMES) * TEST_SIZE)
test = random.sample(TRAIN_FILE_NAMES, n_test)
train = sorted(set(TRAIN_FILE_NAMES).difference(test))
assert not set(train).intersection(test)

for split_name, split_files in (('whole', TRAIN_FILE_NAMES), ('train', train), ('test', test)):
    print('Length of {} dataset: {}'.format(split_name, len(split_files)))

In [None]:
class ContrastTransform:
    """Callable transform applying a fixed contrast adjustment.

    Wraps ``TF.adjust_contrast`` so a constant factor can be used as a step
    inside ``transforms.Compose``.
    """

    def __init__(self, factor):
        # Forwarded unchanged to TF.adjust_contrast on every call.
        self.factor = factor

    def __call__(self, x):
        return TF.adjust_contrast(x, contrast_factor=self.factor)

class SegDataset(Dataset):
    """Chest X-ray / lung-mask pairs for segmentation.

    The image is read from TRAIN_IMG_DIR and the mask from TRAIN_MASKS_DIR, both
    under the same file name.

    Args:
        paths: file names present in both directories.
        visualise: if True, ``__getitem__`` returns PIL images instead of tensors.
        augment: if True (the default, same as before this parameter existed), the
            image pipeline includes a random ``ColorJitter``. Pass ``augment=False``
            for evaluation data so test images, and therefore scores, are deterministic.

    With ``visualise=False`` each item is ``(image, mask)``. Both are float32 tensors of
    shape (1, IMG_SIZE, IMG_SIZE), and the mask is binarised to {0, 1}.
    """

    def __init__(self, paths, visualise=False, augment=True):
        self.paths = paths

        image_steps = [
            transforms.Grayscale(num_output_channels=3),
            transforms.Resize((IMG_SIZE, IMG_SIZE)),
            transforms.ToTensor(),
        ]
        if augment:
            image_steps.append(transforms.ColorJitter(0.1, 0, 0.1, 0.1))  # Added augmentation 18-08-23
        image_steps += [
            ContrastTransform(CONTRAST_FACTOR),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            transforms.Grayscale(num_output_channels=1),
            # NOTE(review): this round trip maps the output into [0, 1] only because
            # ToPILImage multiplies a float tensor by 255 and casts it to uint8 without
            # clamping. Normalised values outside [0, 1] are therefore not preserved
            # (they may wrap around). It is kept because the model trained in this
            # notebook, and presumably utils.crop.postprocess_mask, expect exactly
            # this preprocessing. TODO confirm before changing.
            transforms.ToPILImage(),
            transforms.ToTensor(),
        ]
        self.transform = transforms.Compose(image_steps)

        # Masks: resize and reduce to one channel; binarised in __getitem__.
        self.mask_transform = transforms.Compose([
            transforms.Grayscale(num_output_channels=3),
            transforms.Resize((IMG_SIZE, IMG_SIZE)),
            transforms.ToTensor(),
            transforms.Grayscale(num_output_channels=1),
        ])
        self.visualise = visualise
        self.pil = transforms.ToPILImage()
        self.to_tensor = transforms.ToTensor()

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        name = self.paths[idx]
        img = Image.open(str(TRAIN_IMG_DIR / name))
        mask = Image.open(str(TRAIN_MASKS_DIR / name))

        if self.transform:
            img = self.transform(img)
            # Any non-zero pixel counts as lung. This includes edge pixels
            # interpolated by the resize.
            mask = (self.mask_transform(mask) != 0).to(torch.float32)

        if self.visualise:
            img = self.pil(img)
            mask = self.pil(mask)

        return img, mask

In [None]:
# Preview the first training sample after preprocessing (visualise=True returns PIL images).
visualise_dataset = SegDataset(paths=train, visualise=True)
preview_img, preview_mask = visualise_dataset[0]
preview_img

In [None]:
# Preview the binarised mask of the same sample.
mask_preview = visualise_dataset[0][1]
mask_preview

In [None]:
train_dataset = SegDataset(train)
test_dataset = SegDataset(test)

# Evaluation must be deterministic: SegDataset builds the same image pipeline for every
# split, so strip the random ColorJitter augmentation from the test pipeline.
test_dataset.transform = transforms.Compose([
    step for step in test_dataset.transform.transforms
    if not isinstance(step, transforms.ColorJitter)
])

# Reshuffle the training batches every epoch; the order of the test batches does not matter.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Sanity-check one preprocessed sample. Fetch it only once: the image pipeline contains
# a random ColorJitter, so repeated indexing can return different images.
sample_img, sample_mask = train_dataset[0]
expected_shape = torch.Size([1, IMG_SIZE, IMG_SIZE])
assert sample_img.min().item() >= 0., 'image has values below 0'
assert sample_img.max().item() <= 1., 'image has values above 1'
assert sample_img.size(0) == 1, 'image should have a single channel'
assert sample_img.shape == sample_mask.shape == expected_shape, (sample_img.shape, sample_mask.shape)

### U-Net

### Check that the model's input and output shapes work on a single batch

In [None]:
# U-Net with an ImageNet-pretrained ResNet-34 encoder. It is built here only to check
# shapes; the model that gets trained is created again in the training cell.
# With classes=2 the output has one logit channel per class: (N, 2, H, W).
model = smp.Unet("resnet34", encoder_weights="imagenet", in_channels=1, classes=2)  # encoder_depth, aux_params

model.to(device)
model.eval()  # shape check only: don't update BatchNorm running statistics

with torch.no_grad():  # no autograd graph is needed for a shape check
    # Shape checking on a random batch (`dummy_input` avoids shadowing the builtin `input`)
    dummy_input = torch.randn((BATCH_SIZE, 1, IMG_SIZE, IMG_SIZE)).to(device)
    print('Model accepts input of shape:', dummy_input.shape)
    output = model(dummy_input)
    print('Model generates output of shape:', output.shape)

    # Try on our dataloaders
    X, y = next(iter(train_dataloader))
    X, y = X.to(device), y.to(device)
    print('Size of real input:', X.shape)

    pred = model(X)
    print('Size of prediction:', pred.shape)

# Show input | mask | channel 0 | channel 1. This model has no output activation, so the
# logits go through a sigmoid first. Otherwise ToPILImage would multiply raw logits
# by 255 and cast them to uint8, which does not give a meaningful grey level.
pil = transforms.ToPILImage()
panels = [pil(X[0]), pil(y[0]), pil(torch.sigmoid(pred[0][0])), pil(torch.sigmoid(pred[0][1]))]
im1, im2, im3, im4 = panels
im = Image.new('RGB', (sum(panel.width for panel in panels), im1.height))
for i, panel in enumerate(panels):
    im.paste(panel, (im1.width * i, 0))
im

I asked ChatGPT "A U-Net with a ResNet encoder outputs 2 pictures — which one do I use?" and got:
```
Segmentation Map: The first output picture represents the segmentation map, which is a pixel-wise prediction map that assigns a class label to each pixel in the input image. This map indicates which class or category each pixel belongs to. The pixel values in this map are usually integers representing class labels, or in some cases, the probabilities of each class.

Auxiliary Output: The second output picture is often an auxiliary output that serves as an intermediate representation. This output is not always present in all U-Net with ResNet implementations, and its purpose can vary based on the specific architecture or task. Sometimes, this auxiliary output is used for regularization or for providing additional supervision during training.
```
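It says to use the segmentation map.

**Correction:** that answer does not describe this model. With `classes=2`, `smp.Unet` returns one output channel per class, shape `(N, 2, H, W)`. An auxiliary (classification) output only exists when `aux_params` is passed, and it is then returned as a separate element. Training below uses only channel 0 (with a sigmoid), so the loss never optimises channel 1. `classes=1` would be the natural choice for binary lung segmentation.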

In [None]:
# Full prediction shape vs. the single channel (channel dim kept) that training uses.
print(pred.shape)
print(pred[:, 0:1].shape)

## Training the model

In [None]:
# Disabled sanity checks for the metric and loss. `accuracy_fn` and `loss_fn` are only
# assigned in the next cell, so these lines cannot run at this point of the notebook.
# Expected values: Dice = 2*1 / (2 + 3) = 0.4, and the loss is 1 - 0.4 = 0.6.
# print('Accuracy:', accuracy_fn(torch.tensor([1,0,0,0,1]), torch.tensor([0,0,1,1,1]))) # 0.4
# print('Loss:', loss_fn(torch.tensor([1,0,0,0,1]), torch.tensor([0,0,1,1,1]))) # 0.6

In [None]:
def dice(preds, target, smooth=1e-6, round=True):
    """Dice coefficient between two tensors, computed over all their elements.

    Args:
        preds: predicted values. Rounded to the nearest integer first when ``round`` is True.
        target: ground-truth values, same number of elements as ``preds``.
        smooth: added to numerator and denominator to avoid division by zero.
        round: whether to binarise ``preds`` with ``torch.round`` before scoring.

    Returns:
        0-dim tensor ``(2 * sum(preds * target) + smooth) / (sum(preds) + sum(target) + smooth)``.
    """
    flat_preds = preds.flatten()
    if round:
        flat_preds = flat_preds.round()
    flat_target = target.flatten()

    overlap_sum = torch.sum(flat_preds * flat_target)
    denominator = flat_preds.sum() + flat_target.sum() + smooth
    return (2. * overlap_sum + smooth) / denominator

class DiceLoss(nn.Module):
    """Soft Dice loss: ``1 - Dice``, computed over all elements without rounding.

    Predictions are used as-is (no thresholding), so the loss stays differentiable.
    """

    def __init__(self):
        super().__init__()

    def forward(self, preds, target, smooth=1e-6):
        flat_preds = preds.reshape(-1)
        flat_target = target.reshape(-1)

        numerator = 2. * torch.sum(flat_preds * flat_target) + smooth
        denominator = flat_preds.sum() + flat_target.sum() + smooth
        return 1 - numerator / denominator
    
class EarlyStopper:
    '''Signals when training should stop because the validation loss got worse.

    Adapted from https://stackoverflow.com/questions/71998978/early-stopping-in-pytorch

    A new minimum loss resets the counter. A loss more than ``min_delta`` above the best
    loss seen so far increments it. A loss within ``min_delta`` of the best neither resets
    nor increments it. ``early_stop`` returns True once the counter reaches ``patience``.
    '''

    def __init__(self, patience=1, min_delta=0):
        self.patience, self.min_delta = patience, min_delta
        self.counter, self.min_validation_loss = 0, np.inf

    def early_stop(self, validation_loss):
        if validation_loss < self.min_validation_loss:
            # New best loss: remember it and start counting from zero again.
            self.min_validation_loss = validation_loss
            self.counter = 0
            return False

        got_worse = bool(validation_loss > self.min_validation_loss + self.min_delta)
        if got_worse:
            self.counter += 1
        return got_worse and self.counter >= self.patience


def train_step(dataloader, model, loss_fn, optimizer, accuracy_fn):
    """Run one training epoch over ``dataloader``.

    Channel 0 of the model output, with its channel dimension kept, is the prediction.

    Args:
        dataloader: yields ``(X, y)`` batches; both are moved to the global ``device``.
        model: segmentation model; switched to train mode here.
        loss_fn: called as ``loss_fn(preds, y)``; must return a scalar tensor.
        optimizer: stepped once per batch.
        accuracy_fn: metric called as ``accuracy_fn(preds, y)``; evaluated without
            gradient tracking because it is only reported.

    Returns:
        ``(model, mean_loss, mean_accuracy)``, where the means are weighted by batch
        size over the epoch and carry no autograd history.
    """
    train_loss = 0
    train_accuracy = 0
    total_samples = 0

    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        preds = model(X)
        preds = preds[:,0,:,:].unsqueeze(1)

        loss = loss_fn(preds, y)

        # Metric for reporting only: don't build (and keep alive) an autograd graph for it.
        with torch.no_grad():
            acc = accuracy_fn(preds, y)

        if batch % 4 == 1:
            print(f'Batch {batch}: Loss {loss:.6f} | Acc {acc * 100:.1f}%')

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate a detached loss. Adding `loss` itself would chain every batch's
        # autograd graph into `train_loss` for the rest of the epoch.
        batch_size = X.size(0)
        train_loss += loss.detach() * batch_size
        train_accuracy += acc * batch_size
        total_samples += batch_size

    train_loss /= total_samples
    train_accuracy /= total_samples

    return model, train_loss, train_accuracy

def test_step(dataloader, model, loss_fn, accuracy_fn):
    """Evaluate ``model`` on ``dataloader`` in eval mode without gradient tracking.

    Channel 0 of the model output, with its channel dimension kept, is the prediction.

    Returns:
        ``(mean_loss, mean_accuracy)``: averages over all batches, weighted by batch
        size, as tensors.
    """
    test_loss = 0
    test_accuracy = 0
    total_samples = 0

    model.eval()
    with torch.inference_mode():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)

            preds = model(X)
            preds = preds[:,0,:,:].unsqueeze(1)

            loss = loss_fn(preds, y)
            acc = accuracy_fn(preds, y)

            batch_size = X.size(0)
            test_loss += loss * batch_size
            test_accuracy += acc * batch_size
            total_samples += batch_size

        test_loss /= total_samples
        test_accuracy /= total_samples

    # Report the averages over the whole test set, not just the last batch's values.
    print(f'Loss {test_loss:.6f} | Dice Accuracy {test_accuracy * 100:.1f}%')

    return test_loss, test_accuracy


# ---------- START HERE ----------

train_losses = []
train_accuracies = []
test_losses = []
test_accuracies = []

torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

# Model: U-Net with an ImageNet-pretrained ResNet-34 encoder and a sigmoid output activation.
model = smp.Unet("resnet34", encoder_weights="imagenet", in_channels=1, classes=2, activation='sigmoid').to(device)
# Freeze every parameter outside the decoder and the segmentation head; only those are trained.
# NOTE(review): BatchNorm layers in the frozen encoder still update their running
# statistics while the model is in train mode.
for name, param in model.named_parameters():
    if 'segmentation' not in name and 'decoder' not in name:
        param.requires_grad = False

# Loss, Optimizer, Accuracy functions
loss_fn = DiceLoss().to(device)
accuracy_fn = dice  # the "accuracy" reported throughout is the Dice coefficient
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=GAMMA, last_epoch=-1, verbose=True) # LinearLR(optimizer, start_factor=1.0, end_factor=0.3, total_iters=5)
early_stopper = EarlyStopper(patience=PATIENCE, min_delta=MIN_DELTA)

# Weights (copied to CPU) of the epoch with the lowest test loss so far. A real copy is
# required: appending `model` to a list each epoch only stores references to the same
# object, so no earlier epoch could ever be restored from such a list.
# NOTE(review): the test split doubles as the validation split for early stopping and
# checkpoint selection, so the final test scores are optimistic.
best_test_loss = float('inf')
best_state = None

# Training and testing loops
for epoch in tqdm(range(EPOCHS)):
    print(f'Epoch {epoch + 1}')

    # Train and test loops
    model, train_loss, train_accuracy = train_step(dataloader=train_dataloader, model=model, loss_fn=loss_fn, optimizer=optimizer, accuracy_fn=accuracy_fn)
    test_loss, test_accuracy = test_step(dataloader=test_dataloader, model=model, loss_fn=loss_fn, accuracy_fn=accuracy_fn)

    # For overfitting visualisation
    train_losses.append(train_loss.item())
    train_accuracies.append(train_accuracy.item())
    test_losses.append(test_loss.item())
    test_accuracies.append(test_accuracy.item())

    # Snapshot the weights whenever the test loss reaches a new minimum (strict `<`,
    # the same rule EarlyStopper uses for a new best).
    if test_losses[-1] < best_test_loss:
        best_test_loss = test_losses[-1]
        best_state = {key: value.detach().to('cpu', copy=True) for key, value in model.state_dict().items()}

    # Early stopping
    if early_stopper.early_stop(test_losses[-1]):
        # Model checkpoint (p tong said use first model before plateau):
        # roll back to the weights with the lowest test loss.
        model.load_state_dict(best_state)
        break

    # Learning rate scheduler (stored in optimizer.param_groups[0]["lr"])
    scheduler.step()

    print('-' * 50)

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4))

# One entry per panel: (axis, y-label, title, [(values, legend label, colour, marker), ...]).
# "Accuracy" here is the Dice score returned by train_step / test_step.
panel_specs = [
    (ax1, 'Loss', 'Train and Test Loss Over Epochs', [
        (train_losses, 'Train Loss', 'blue', 'o'),
        (test_losses, 'Test Loss', 'red', 'x'),
    ]),
    (ax2, 'Accuracy', 'Train and Test Accuracy Over Epochs', [
        (train_accuracies, 'Train Accuracy', 'green', 'o'),
        (test_accuracies, 'Test Accuracy', 'orange', 'x'),
    ]),
]

for ax, ylabel, title, curves in panel_specs:
    for values, label, color, marker in curves:
        ax.plot(values, label=label, color=color, linestyle='-', marker=marker)
    ax.set_xlabel('Epochs')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()

plt.tight_layout()
plt.show()

In [None]:
# Visualise after model training
X, y = next(iter(test_dataloader))
X, y = X.to(device), y.to(device)
print('Size of real input:', X.shape)

pred = model(X)
print('Size of prediction:', pred.shape)

# Show result
pil = transforms.ToPILImage()
im1 = pil(X[0])
im2 = pil(y[0].to(torch.float32))
im3 = pil(pred[0][0])
im4 = pil(pred[0][1])
im = Image.new('RGB', (im1.width + im2.width + im3.width + im4.width, im1.height))
im.paste(im1, (0,0))
im.paste(im2, (im1.width, 0))
im.paste(im3, (im1.width * 2, 0))
im.paste(im4, (im1.width * 3, 0))
im

In [None]:
from torchinfo import summary

# Per-layer parameter counts, plus whether each layer is trainable
# (the encoder was frozen before training).
summary_columns = ['num_params', 'trainable']
summary(model, col_names=summary_columns)

In [None]:
# Saves the whole pickled model object. Loading it requires the same class definitions
# (segmentation_models_pytorch) to be importable.
MODEL_PATH = r'E:\Prut\cxr\models\lung_segment_model_180823.pt'
torch.save(model, MODEL_PATH)

#### Comparison of scores before and after postprocessing

In [None]:
# For every test file, compute the Dice of the raw U-Net mask and of the post-processed
# mask against the ground truth, then report both means and their difference.
# NOTE(review): the ground truth is resized with PIL's default resampling and is not
# binarised, unlike the training masks in SegDataset (any non-zero pixel -> 1). The
# 224 also duplicates IMG_SIZE. TODO confirm this is intended.
dice_unet_list = []
dice_postprocessed_list = []
p = transforms.ToTensor()

for stem in tqdm(test):
    img = Image.open(str(TRAIN_IMG_DIR / stem))
    mask_truth = Image.open(str(TRAIN_MASKS_DIR / stem)).resize((224,224))
    mask_unet = postprocess_mask(img, model, return_original=True)
    mask_postprocessed = postprocess_mask(img, model)

    truth_tensor = p(mask_truth)  # converted once and reused for both scores
    dice_unet = dice(p(mask_unet), truth_tensor)
    dice_postprocessed = dice(p(mask_postprocessed), truth_tensor)

    dice_unet_list.append(dice_unet)
    dice_postprocessed_list.append(dice_postprocessed)

avg_dice_unet, avg_dice_postprocessed = np.mean(dice_unet_list), np.mean(dice_postprocessed_list)
increase = avg_dice_postprocessed - avg_dice_unet

print(f'Average Dice score after U-Net: {avg_dice_unet * 100:.3f} %')
print(f'Average Dice score after postprocessing: {avg_dice_postprocessed * 100:.3f} %')
print(f'Increase in Dice score: {increase * 100:.3f} %')

In [None]:
def overlap(preds, target):
    """Fraction of the target's foreground covered by the rounded prediction.

    Computes ``sum(round(preds) * target) / sum(target)``. For a binary target this is
    the recall (sensitivity) of the prediction. Returns NaN/inf when ``target`` sums to 0.
    """
    hard_preds = torch.round(preds.flatten())
    flat_target = target.flatten()
    return (hard_preds * flat_target).sum() / flat_target.sum()

# Same comparison as the Dice cell, but with the overlap (recall-style) score, per test file.
# NOTE(review): as in the Dice cell, the ground truth is not binarised and 224 duplicates
# IMG_SIZE. TODO confirm this is intended.
overlap_unet_list = []
overlap_postprocessed_list = []
p = transforms.ToTensor()

for stem in tqdm(test):
    img = Image.open(str(TRAIN_IMG_DIR / stem))
    mask_truth = Image.open(str(TRAIN_MASKS_DIR / stem)).resize((224,224))
    mask_unet = postprocess_mask(img, model, return_original=True)
    mask_postprocessed = postprocess_mask(img, model)

    truth_tensor = p(mask_truth)  # converted once and reused for both scores
    overlap_unet = overlap(p(mask_unet), truth_tensor)
    overlap_postprocessed = overlap(p(mask_postprocessed), truth_tensor)

    overlap_unet_list.append(overlap_unet)
    overlap_postprocessed_list.append(overlap_postprocessed)

avg_overlap_unet, avg_overlap_postprocessed = np.mean(overlap_unet_list), np.mean(overlap_postprocessed_list)
increase = avg_overlap_postprocessed - avg_overlap_unet

print(f'Average overlap score after U-Net: {avg_overlap_unet * 100:.3f} %')
print(f'Average overlap score after postprocessing: {avg_overlap_postprocessed * 100:.3f} %')
print(f'Increase in overlap score: {increase * 100:.3f} %')

In [None]:
# Disabled example: apply CropLungsTransform from utils.crop, built with the trained model,
# to `img`. Presumably this crops the lung region; its behaviour is not visible in this notebook.
# from utils.crop import CropLungsTransform
# CropLungsTransform(model)(img)