# UNet Segmentation of Optic Disc and Optic Cup

In [2]:
import os
import numpy as np
import torch
from PIL import Image
import torchvision.transforms as transforms
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from torch.optim import Adam

from GlaucomaDataset import GlaucomaDataset
from unet import UNet

In [3]:
origa_path = os.path.join("..", "data", "ORIGA")
images_path = os.path.join(origa_path, "Images_Square")
masks_path = os.path.join(origa_path, "Masks_Square")

img_filenames = sorted(os.listdir(images_path))
mask_filenames = sorted(os.listdir(masks_path))

In [4]:
# Split into train, validation, and test sets (70, 15, 15)
train_imgs, temp_imgs, train_masks, temp_masks = train_test_split(
    img_filenames, mask_filenames, test_size=0.3, random_state=42)

val_imgs, test_imgs, val_masks, test_masks = train_test_split(
    temp_imgs, temp_masks, test_size=0.5, random_state=42)

In [5]:
# Set parameters (learning rate, batch size, epochs, etc.)
lr = 1e-4
batch_size = 8
n_workers = 4
epochs = 40

In [6]:
# Load data
train_set = GlaucomaDataset(images_path, masks_path, train_imgs, train_masks)
val_set = GlaucomaDataset(images_path, masks_path, val_imgs, val_masks)
test_set = GlaucomaDataset(images_path, masks_path, test_imgs, test_masks)

train_loader = DataLoader(train_set, batch_size=batch_size, num_workers=n_workers, shuffle=True)
val_loader = DataLoader(val_set, batch_size=batch_size, num_workers=n_workers, shuffle=True)
test_loader = DataLoader(test_set, batch_size=batch_size, num_workers=n_workers, shuffle=True) # need to load test?

In [7]:
# Define metrics (dice coefficient)
def dice_coefficient(targets, preds, smooth=1e-6):
    # preds = (preds > 0.5).float 
    intersection = torch.sum(preds * targets, dim=(2,3))
    # want close to 1 (identical)
    dice = (2. * intersection + smooth) / (torch.sum(preds, dim=(2,3)) + torch.sum(targets, dim=(2,3)) + smooth)
    return dice.mean()

In [8]:
# Initialize model (device, model, loss, optimizer)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet(in_channels=1).to(device)
loss_func = torch.nn.BCELoss(reduction='mean')
optimizer = Adam(model.parameters(), lr=lr)

In [9]:
for images, masks in train_loader:
    print(images.shape)
    print(masks.shape)
    break

torch.Size([8, 1, 256, 256])
torch.Size([8, 2, 256, 256])


In [10]:
len(images)

8

In [None]:
# TODO: Train model 
train_losses = []
dice_scores = []
model.train()

# for epoch in epochs:
train_loss = 0.
dice = 0.

for images, masks in train_loader:
    images, masks = images.to(device), masks.to(device)

    optimizer.zero_grad()
    output = model(images)
    print(output.shape)
    loss = loss_func(output, masks)
    loss.backward()
    optimizer.step()

    train_loss += loss.item()
    dice += dice_coefficient(output, masks).item()

train_loss /= len(train_loader)
dice /= len(train_loader)
print("train loss", train_loss)
print("dice score", dice)

torch.Size([8, 2, 256, 256])
torch.Size([8, 2, 256, 256])
torch.Size([8, 2, 256, 256])
torch.Size([8, 2, 256, 256])
torch.Size([8, 2, 256, 256])
torch.Size([8, 2, 256, 256])
torch.Size([8, 2, 256, 256])
torch.Size([8, 2, 256, 256])
torch.Size([8, 2, 256, 256])
torch.Size([8, 2, 256, 256])
torch.Size([8, 2, 256, 256])
torch.Size([8, 2, 256, 256])
torch.Size([8, 2, 256, 256])
torch.Size([8, 2, 256, 256])
torch.Size([8, 2, 256, 256])
torch.Size([8, 2, 256, 256])
torch.Size([8, 2, 256, 256])
torch.Size([8, 2, 256, 256])
torch.Size([8, 2, 256, 256])
torch.Size([8, 2, 256, 256])
torch.Size([8, 2, 256, 256])
torch.Size([8, 2, 256, 256])
torch.Size([8, 2, 256, 256])
torch.Size([8, 2, 256, 256])
torch.Size([8, 2, 256, 256])
torch.Size([8, 2, 256, 256])
torch.Size([8, 2, 256, 256])
torch.Size([8, 2, 256, 256])
torch.Size([8, 2, 256, 256])
torch.Size([8, 2, 256, 256])
torch.Size([8, 2, 256, 256])
torch.Size([8, 2, 256, 256])
torch.Size([8, 2, 256, 256])
torch.Size([8, 2, 256, 256])
torch.Size([8,

In [None]:
# To consider: adding grayscale / normalization (either for images or batches), adding a remove nerves function
# Implement early stopping?

In [None]:
# TODO: Tune model (cv for lr, epochs, batch size, etc.)