# UNet Segmentation of Optic Disc and Optic Cup

In [38]:
import os
import numpy as np
import torch
from PIL import Image
import torchvision.transforms as transforms
from sklearn.model_selection import train_test_split

In [6]:
# Dataset location, relative to this notebook: ../data/ORIGA/{Images_Square, Masks_Square}
origa_path = os.path.join(*("..", "data", "ORIGA"))
images_path, masks_path = (
    os.path.join(origa_path, subdir) for subdir in ("Images_Square", "Masks_Square")
)

In [41]:
# Pre-process fundus images: load as RGB, convert to float tensors in [0, 1],
# and resize to a fixed square size with (antialiased) bilinear interpolation.
IMAGE_SIZE = (256, 256)


def load_fundus_image(path, size=IMAGE_SIZE):
    """Load one fundus image and return it as a resized float tensor.

    Args:
        path: Path to an image file readable by PIL.
        size: Target (height, width) passed to torchvision's ``resize``.

    Returns:
        Float tensor of shape (3, size[0], size[1]) with values in [0, 1]
        (``to_tensor`` scales 8-bit RGB by 1/255).
    """
    # The context manager closes the file handle that PIL keeps open after Image.open.
    with Image.open(path) as pil_img:
        # to_tensor accepts a PIL image directly; no numpy round trip needed.
        tensor = transforms.functional.to_tensor(pil_img.convert('RGB'))
    # InterpolationMode replaces the deprecated PIL integer constants. antialias=True
    # makes tensor downsampling match PIL-style resizing on every torchvision version
    # (older versions default to no antialiasing for tensors).
    return transforms.functional.resize(
        tensor,
        size=size,
        interpolation=transforms.InterpolationMode.BILINEAR,
        antialias=True,
    )


# Skip hidden entries such as .DS_Store or .ipynb_checkpoints, which PIL cannot open.
img_filenames = sorted(f for f in os.listdir(images_path) if not f.startswith('.'))
images = [load_fundus_image(os.path.join(images_path, name)) for name in img_filenames]


In [42]:
# Pre-process mask images: read label maps, build binary optic-disc / optic-cup
# channels, convert to tensors, and resize with nearest-neighbour interpolation
# so the masks stay strictly binary.
MASK_SIZE = (256, 256)


def load_disc_cup_mask(path, size=MASK_SIZE):
    """Load a label-map mask and return a 2-channel float tensor [disc, cup].

    Pixels labelled 1 or 2 form the optic-disc channel; pixels labelled 2 form
    the optic-cup channel.

    Args:
        path: Path to a single-channel mask image readable by PIL.
        size: Target (height, width) passed to torchvision's ``resize``.

    Returns:
        Float tensor of shape (2, size[0], size[1]) with values in {0, 1}.

    Raises:
        ValueError: if the mask does not decode to a 2-D label map.
    """
    with Image.open(path) as pil_mask:
        label_map = np.array(pil_mask)
    if label_map.ndim != 2:
        raise ValueError(f"{path}: expected a 2-D label map, got shape {label_map.shape}")

    # The optic cup lies inside the optic disc, so the full disc is every pixel labelled
    # 1 or 2; `label_map == 1` alone would give only the disc rim around the cup.
    # NOTE(review): assumes 1 = disc (excluding cup) and 2 = cup — confirm against the ORIGA masks.
    disc = np.isin(label_map, (1, 2))
    cup = label_map == 2

    # Stack into a channel-first (2, H, W) tensor; the leading axis is the channel
    # axis, not a batch axis.
    channels = torch.from_numpy(np.stack([disc, cup]).astype(np.float32))
    # Nearest-neighbour keeps values exactly 0/1; both channels are resized in one call.
    return transforms.functional.resize(
        channels, size=size, interpolation=transforms.InterpolationMode.NEAREST
    )


# Skip hidden entries such as .DS_Store or .ipynb_checkpoints, which PIL cannot open.
mask_filenames = sorted(f for f in os.listdir(masks_path) if not f.startswith('.'))
masks = [load_disc_cup_mask(os.path.join(masks_path, name)) for name in mask_filenames]

In [43]:
# Two-stage split into train / validation / test (70 / 15 / 15):
# hold out 30% of the image-mask pairs, then cut that hold-out exactly in half.
# The same fixed seed is used for both stages so the split is reproducible.
split_kwargs = dict(random_state=42)

train_imgs, temp_imgs, train_masks, temp_masks = train_test_split(
    images, masks, test_size=0.3, **split_kwargs
)
val_imgs, test_imgs, val_masks, test_masks = train_test_split(
    temp_imgs, temp_masks, test_size=0.5, **split_kwargs
)

In [9]:
# TODO: Set parameters (learning rate, batch size, epochs, etc.)

In [10]:
# TODO: Load data (with DataLoader?) and initialize model (device, model, loss, optimizer)

In [11]:
# TODO: Train model 

In [None]:
# To consider: grayscale conversion and/or normalization (per image or per batch); a preprocessing function that removes the blood vessels ("nerves") from the fundus images