In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
from PIL import Image
import torchvision.transforms as transforms
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from torch.optim import Adam
import pandas as pd
import matplotlib.patches as patches

from GlaucomaDataset import GlaucomaDatasetBoundingBoxes
from unet import UNet

In [2]:
origa_path = os.path.join("..", "..", "data", "ORIGA")
images_path, masks_path = (
    os.path.join(origa_path, subdir) for subdir in ("Images_Square", "Masks_Square")
)

# NOTE(review): images and masks are paired purely by sorted position when they are
# split below, so both folders must hold matching, identically-sorting file sets
# (no stray files such as .DS_Store) — confirm against the data directory.
img_filenames, mask_filenames = (
    sorted(os.listdir(folder)) for folder in (images_path, masks_path)
)

In [3]:
def update_image_path(path):
    """Return everything after the final '/' in ``path`` (the whole string if there is none)."""
    return path.rsplit("/", 1)[-1]

In [4]:
# Build the CSV path from origa_path, like the image/mask directories above,
# instead of a hard-coded "/"-separated string.
bb_df = pd.read_csv(os.path.join(origa_path, "bounding_boxes.csv"))
# Keep only the final path component, presumably so rows can be matched to the
# file names listed from Images_Square.
bb_df['image_path'] = bb_df['image_path'].apply(update_image_path)
# NOTE(review): integer-halving the box corners presumably matches a 2x downscale
# of the original images into Images_Square — confirm against the preprocessing.
bb_df[['x1', 'y1', 'x2', 'y2']] //= 2

In [5]:
# 70 / 15 / 15 split: hold out 30% first, then halve that hold-out into
# validation and test. Images and masks are passed together so each pair
# lands in the same subset; the fixed seed keeps the split reproducible.
train_imgs, temp_imgs, train_masks, temp_masks = train_test_split(
    img_filenames,
    mask_filenames,
    random_state=42,
    test_size=0.3,
)
val_imgs, test_imgs, val_masks, test_masks = train_test_split(
    temp_imgs,
    temp_masks,
    random_state=42,
    test_size=0.5,
)

In [6]:
# Load data
batch_size = 8
n_workers = 4

train_set = GlaucomaDatasetBoundingBoxes(images_path, masks_path, train_imgs, train_masks, bb_df)
val_set = GlaucomaDatasetBoundingBoxes(images_path, masks_path, val_imgs, val_masks, bb_df)
test_set = GlaucomaDatasetBoundingBoxes(images_path, masks_path, test_imgs, test_masks, bb_df)

# Only the training loader shuffles. Evaluation loaders iterate in a fixed order:
# the eval loop averages per-batch values and the last batch may be smaller, so
# shuffling would make val/test metrics vary from run to run.
train_loader = DataLoader(train_set, batch_size=batch_size, num_workers=n_workers, shuffle=True)
val_loader = DataLoader(val_set, batch_size=batch_size, num_workers=n_workers, shuffle=False)
test_loader = DataLoader(test_set, batch_size=batch_size, num_workers=n_workers, shuffle=False)

In [7]:
# Take the first batch from the test loader for inspection. Unlike a `for ...: break`
# loop, next(iter(...)) raises on an empty loader instead of silently leaving stale
# values from an earlier run in these variables.
image, image_names, masks, mask_names = next(iter(test_loader))

In [8]:
# Batch shape of the sampled test images (recorded output: 8 x 4 x 256 x 256).
# NOTE(review): 4 channels — presumably RGB plus an extra channel added by
# GlaucomaDatasetBoundingBoxes; confirm in the dataset class.
image.size()

torch.Size([8, 4, 256, 256])

In [None]:
def display_image_with_bbox(img_array, bbox, color='red', linewidth=2, ax=None):
    """
    Display an image with a bounding box overlay.

    Parameters:
    -----------
    img_array : numpy.ndarray
        Image as a numpy array in a layout ``imshow`` accepts (e.g. H x W or H x W x C).
        NOTE(review): batches from the loaders appear to be C x H x W (4 channels),
        so transpose/convert before passing one in.
    bbox : list or tuple
        Bounding box coordinates in the format [x1, y1, x2, y2]
        where (x1, y1) is the top-left corner and (x2, y2) is the bottom-right corner.
        Entries may be plain numbers, numpy scalars, or 0-d torch tensors.
    color : str, default='red'
        Color of the bounding box.
    linewidth : int, default=2
        Width of the bounding box border.
    ax : matplotlib.axes.Axes, optional
        Axes to draw into. When omitted, a new figure is created, laid out and
        shown. When given, the caller is responsible for layout and ``plt.show()``,
        which allows several images to be drawn into one grid of subplots.
    """
    owns_figure = ax is None
    if owns_figure:
        _, ax = plt.subplots(1)

    ax.imshow(img_array)

    # float() normalises numpy scalars / 0-d tensors to plain numbers for matplotlib.
    x1, y1, x2, y2 = (float(v) for v in bbox)

    # Rectangle is anchored at (x1, y1) and extends by width and height.
    rect = patches.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=linewidth,
                             edgecolor=color, facecolor='none')
    ax.add_patch(rect)

    # Hide axis ticks; pixel coordinates are not useful in the rendered figure.
    ax.set_xticks([])
    ax.set_yticks([])

    # Only finalise/show a figure this call created itself.
    if owns_figure:
        plt.tight_layout()
        plt.show()

## Define Metrics

In [None]:
def dice_coefficient(targets, preds, smooth=1e-6):
    """Soft Dice coefficient averaged over batch and channel.

    For every (sample, channel) pair computes
    ``(2 * sum(preds * targets) + smooth) / (sum(preds) + sum(targets) + smooth)``,
    summing over all dimensions after the first two, and returns the mean of
    those values as a 0-d tensor. Values near 1 mean near-identical masks.

    The formula is symmetric in ``targets`` and ``preds``. ``preds`` is used as
    given, with no sigmoid or thresholding. NOTE(review): if the model emits
    logits, apply ``torch.sigmoid`` and/or a threshold before calling this.

    Parameters
    ----------
    targets, preds : torch.Tensor
        Same-shaped tensors laid out as ``(N, C, *spatial)`` with at least one
        spatial dimension (e.g. ``(N, C, H, W)`` or ``(N, C, D, H, W)``).
    smooth : float
        Added to numerator and denominator to avoid 0/0; two empty masks score 1.

    Raises
    ------
    ValueError
        If the inputs have fewer than 3 dimensions.
    """
    if preds.dim() < 3:
        # torch.sum(dim=()) would silently reduce over *all* dims, so reject early.
        raise ValueError(
            f"dice_coefficient expects (N, C, *spatial) tensors, got shape {tuple(preds.shape)}"
        )
    spatial_dims = tuple(range(2, preds.dim()))
    intersection = torch.sum(preds * targets, dim=spatial_dims)
    denominator = torch.sum(preds, dim=spatial_dims) + torch.sum(targets, dim=spatial_dims)
    dice = (2. * intersection + smooth) / (denominator + smooth)
    return dice.mean()

## Train and Test Loops

In [None]:
def trainloop(dataloader, model, loss_func, optimizer):
    num_batches = len(dataloader)
    train_loss, dice = 0., 0. 
    
    for image, image_name, mask, mask_name in dataloader:
        image, mask = image.to(device), mask.to(device)
        
        optimizer.zero_grad()
        pred = model(image)
        loss = loss_func(pred, mask)
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
        dice += dice_coefficient(pred, mask).item()
    train_loss /= num_batches
    dice /= num_batches
        
    return train_loss, dice

def testloop(dataloader, model, loss_func):
    num_batches = len(dataloader)
    test_loss, dice = 0. , 0.
    
    with torch.no_grad():
        for image, image_name, mask, mask_name in dataloader:
            image, mask = image.to(device), mask.to(device)
            pred = model(image)
            test_loss += loss_func(pred, mask).item()
            dice += dice_coefficient(pred, mask).item()
    test_loss /= num_batches
    dice /= num_batches
    
    return test_loss, dice