In [1]:
import os
import pycocotools
from pycocotools import mask
import pycocotools.mask as mask_util
import numpy as np
import json
from pycocotools.coco import COCO
from sklearn.model_selection import train_test_split
import random 
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.gridspec as gridspec
import matplotlib.patches as mpatches
import matplotlib as mpl

def np_encoder(object):
    """`default=` hook for ``json.dump(s)`` that converts numpy values.

    Args:
        object: The value the JSON encoder could not serialise on its own.
            (The name shadows the builtin; it is kept so existing keyword
            callers keep working.)

    Returns:
        The equivalent built-in Python scalar for numpy scalars
        (``np.generic``), or a nested list for ``np.ndarray``.

    Raises:
        TypeError: For any other type. Without this the function would fall
            through and return ``None``, which the encoder writes as ``null``
            and silently drops the data.
    """
    if isinstance(object, np.generic):
        return object.item()
    if isinstance(object, np.ndarray):
        return object.tolist()
    raise TypeError(
        f"Object of type {type(object).__name__} is not JSON serializable"
    )

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from matplotlib.colors import ListedColormap
import segmentation_models_pytorch as smp

In [2]:
# Move from the notebook's folder up to the project root so that the `src`
# package and the relative `datasets/...` paths used below resolve.
# The guard makes the cell idempotent: an unconditional chdir('../') climbs one
# more directory every time the cell is re-run.
if not os.path.isdir('src'):
    os.chdir('../')

os.getcwd()

'/Users/srikaranreddy/Desktop/Spring Semester/Computer Vision 6.8300/cv-project/gi-tract-image-segmentation'

In [4]:
from src.data import SegmentationDataset
from src.data import DataGenerator

In [3]:
# Dice Loss
class DiceLoss(nn.Module):
    """Soft Dice loss computed over all elements of the batch at once.

    Args:
        smooth: Constant added to numerator and denominator so the ratio is
            defined when both inputs sum to zero.
    """

    def __init__(self, smooth=1):
        super().__init__()
        self.smooth = smooth

    def forward(self, y_pred, y_true):
        """Return ``1 - dice`` for two tensors of identical shape.

        Both tensors are flattened, so the score is a single global value
        rather than a per-sample or per-channel average.
        """
        y_pred = y_pred.contiguous().view(-1)
        y_true = y_true.contiguous().view(-1)
        intersection = (y_pred * y_true).sum()
        dice = (2. * intersection + self.smooth) / (y_pred.sum() + y_true.sum() + self.smooth)
        return 1 - dice


# Combined BCE and Dice Loss
class BCEDiceLoss(nn.Module):
    """Equal-weight sum of binary cross-entropy and soft Dice loss.

    Args:
        smooth: Smoothing constant forwarded to :class:`DiceLoss`.
        from_logits: Set to ``True`` when the network emits raw logits. With
            the default ``False`` the predictions are treated as
            probabilities in [0, 1], which is what a model built with a
            sigmoid output activation produces. Feeding such probabilities to
            ``BCEWithLogitsLoss`` would apply a second sigmoid and distort the
            loss, so plain ``BCELoss`` is used in that case.
    """

    def __init__(self, smooth=1, from_logits=False):
        super().__init__()
        self.from_logits = from_logits
        self.bce = nn.BCEWithLogitsLoss() if from_logits else nn.BCELoss()
        self.dice = DiceLoss(smooth)

    def forward(self, y_pred, y_true):
        """Return ``0.5 * BCE + 0.5 * Dice`` for a batch of predictions."""
        # The BCE criteria need a floating-point target of the prediction's
        # dtype; masks frequently arrive as integer or float64 tensors.
        y_true = y_true.to(y_pred.dtype)
        bce_loss = self.bce(y_pred, y_true)
        # Dice is only meaningful on probabilities, so squash logits first.
        probs = torch.sigmoid(y_pred) if self.from_logits else y_pred
        dice_loss = self.dice(probs, y_true)
        return 0.5 * bce_loss + 0.5 * dice_loss

def dice_coef_func(y_true, y_pred, smooth=1):
    """Global soft Dice coefficient between two same-shaped tensors.

    Both inputs are flattened before comparison, so a single scalar tensor is
    returned for the whole batch. ``smooth`` is added to the numerator and the
    denominator, which yields 1 when both inputs are all zeros.
    """
    truth_flat = y_true.contiguous().view(-1)
    pred_flat = y_pred.contiguous().view(-1)

    overlap = (truth_flat * pred_flat).sum()
    total_mass = truth_flat.sum() + pred_flat.sum()

    return (2. * overlap + smooth) / (total_mass + smooth)


In [None]:
# Arguments shared by all three splits: every split reads images from the same
# folder and is resized to 128x128; only the annotation file, the `subset`
# flag and the shuffle flag differ.
# NOTE(review): CLASSES is not defined in any cell shown here — it must exist
# in the kernel before this cell runs; confirm where it comes from.
_generator_common = dict(
    dataset_dir='datasets/train',
    classes=CLASSES,
    input_image_size=(128,128),
)

train_generator_class = DataGenerator(
    subset="train",
    annFile='datasets/coco/train_json.json',
    shuffle=True,
    **_generator_common,
)

# NOTE(review): the validation split is built with subset="train" and
# shuffle=True — confirm this is intended.
val_generator_class = DataGenerator(
    subset="train",
    annFile='datasets/coco/val_json.json',
    shuffle=True,
    **_generator_common,
)

test_generator_class = DataGenerator(
    subset="test",
    annFile='datasets/coco/test_json.json',
    shuffle=False,
    **_generator_common,
)

In [None]:
# Sanity check: overlay one mask channel on one training image.
X, y = train_generator_class[9]

# Two-entry colormap: the first entry ('none') is transparent, the second
# paints the mask in red.
cmap = ListedColormap(['none', 'red'])

fig, ax = plt.subplots()

# Background: first image channel rendered in grayscale.
# (Dividing by 255 presumably rescales 0-255 pixel values — TODO confirm.)
background = (X / 255.)[:, :, 0]
ax.imshow(background, cmap='gray')

# Foreground: mask channel 1, half-transparent so the image stays visible.
overlay = y[:, :, 1]
ax.imshow(overlay, cmap=cmap, alpha=0.5)

plt.show()

In [None]:
# Wrap each split in a DataLoader with identical settings: batches of 32,
# loaded in the main process (num_workers=0). No shuffle argument is passed
# to DataLoader here.
train_loader, val_loader, test_loader = (
    DataLoader(split, batch_size=32, num_workers=0)
    for split in (train_generator_class, val_generator_class, test_generator_class)
)

In [None]:
# Quick summary of the training loader configuration and size.
for label, value in (
    ("Batch size:", train_loader.batch_size),
    ("Num workers:", train_loader.num_workers),
    ("Dataset size:", len(train_loader.dataset)),
    ("Number of batches:", len(train_loader)),
):
    print(label, value)

In [None]:
# U-Net segmentation network from segmentation_models_pytorch:
#   - EfficientNet-B7 encoder initialised with ImageNet weights
#   - 3 input channels, 3 output mask channels
#   - activation='sigmoid' is requested for the output head, so downstream
#     code should treat the model output as already activated, not as logits.
model = smp.Unet(
    in_channels=3,
    classes=3,
    encoder_name="efficientnet-b7",
    encoder_weights="imagenet",
    activation='sigmoid',
)

In [None]:
# Bare expression: shows the model's repr as this cell's output.
model

In [None]:
# Adam with an initial learning rate of 1e-3.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Multiply the learning rate by 0.2 when the monitored value (mode 'min') has
# not improved for more than `patience` epochs.
# min_lr must be lower than the initial lr: with min_lr equal to the starting
# 0.001 every reduction was clamped straight back, so the scheduler did nothing.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.2, patience=5, verbose=True, min_lr=1e-6)

# Equal-weight BCE + Dice loss defined earlier in the notebook.
criterion = BCEDiceLoss()

In [None]:
from tqdm import tqdm
# Per-epoch history, filled by the training loop (one entry per epoch).
train_losses, val_losses = [], []
train_dice_coefs, val_dice_coefs = [], []

# Number of passes over the training set.
epochs = 5

class AverageMeter:
    """Running-average tracker.

    Attributes:
        val: Most recent value passed to ``update``.
        sum: Weighted sum of all values seen since the last ``reset``.
        count: Total weight (sum of ``n``) seen since the last ``reset``.
        avg: ``sum / count`` after the latest ``update``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running mean."""
        weighted = val * n
        self.val = val
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count


for epoch in range(epochs):
    # ------------------------------ training ------------------------------
    model.train()
    train_loss = 0.0
    dice_coef_meter = AverageMeter()
    num_train_batches = len(train_loader)

    batches = tqdm(enumerate(train_loader), total=num_train_batches)
    batches.set_description("Epoch {:d}: Loss (NA), Train Dice (NA %)".format(epoch))
    for batch_idx, (data, target) in batches:
        # Move the last axis to position 1 (presumably channels-last batches
        # converted to the channels-first layout the model expects).
        data = data.permute(0, 3, 1, 2)
        target = target.permute(0, 3, 1, 2)

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

        # Store a plain float: keeping the tensor would hold on to the autograd
        # graph of every batch through the meter's running sum.
        dice_coef = dice_coef_func(target, output.detach()).item()
        dice_coef_meter.update(dice_coef)

        # Refresh the bar while it is still active; updating it after the loop
        # has finished has no visible effect. The loss shown is the running
        # mean, and the metric is labelled as what it is (Dice, not accuracy).
        batches.set_description(
            "Epoch {:d}: Loss ({:.2e}), Train Dice ({:02.0f}%)".format(
                epoch, train_loss / (batch_idx + 1), 100.0 * dice_coef_meter.avg
            )
        )

    # Per-epoch means. The Dice history uses the meter's average over all
    # batches rather than the last batch's value divided by the batch count.
    train_losses.append(train_loss / max(num_train_batches, 1))
    train_dice_coefs.append(dice_coef_meter.avg)

    # ----------------------------- validation -----------------------------
    model.eval()
    val_loss = 0.0
    # Fresh meter so validation Dice does not start from a training value.
    val_dice_meter = AverageMeter()
    num_val_batches = len(val_loader)

    with torch.no_grad():
        for data, target in tqdm(val_loader):
            data = data.permute(0, 3, 1, 2)
            target = target.permute(0, 3, 1, 2)

            output = model(data)
            val_loss += criterion(output, target).item()
            val_dice_meter.update(dice_coef_func(target, output).item())

    avg_val_loss = val_loss / max(num_val_batches, 1)
    val_losses.append(avg_val_loss)
    val_dice_coefs.append(val_dice_meter.avg)

    # Report the same mean value that is stored in the history.
    print(f"Epoch {epoch}, Val Loss: {avg_val_loss}, Val Dice: {val_dice_meter.avg}")

    # Plateau detection on the mean validation loss.
    scheduler.step(avg_val_loss)

    # Save model checkpoint
    torch.save(model.state_dict(), f'UNET_model_epoch_{epoch}.pth')
