In [18]:
#@title Copy Dataset
%%capture

from google.colab import drive
drive.mount('/content/drive')
!cp drive/MyDrive/new_dataset.zip .
drive.flush_and_unmount()
!unzip new_dataset.zip
!mv new_dataset/* .
!rm -r new_dataset* sample_data


In [17]:
#@title Import Libraries

import torch
import torch.nn as nn
import torchvision.transforms.functional as TF
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torchvision
import cv2
import os
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import gc
import torchvision.transforms as transforms
import torch.multiprocessing as mp
import re
from time import time
import multiprocessing


In [13]:
#@title Define Model and Loader

class DoubleConv(nn.Module):
    """Two stacked 3x3 convolutions, each followed by batch norm and ReLU.

    The first convolution maps ``in_channels`` to ``out_channels``; the second
    keeps ``out_channels``. Both use stride 1, padding 1 and no bias term.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # One (conv, batch norm, ReLU) triple per input width, in that order.
        for source_channels in (in_channels, out_channels):
            stages.append(
                nn.Conv2d(source_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
            )
            stages.append(nn.BatchNorm2d(out_channels))
            stages.append(nn.ReLU(inplace=True))
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the six-layer stack to ``x`` and return the result."""
        activated = self.conv(x)
        return activated


class UNET(nn.Module):
    """Encoder/decoder segmentation network with skip connections.

    ``features`` lists the channel width of each encoder level. The decoder
    mirrors it in reverse, and a bottleneck with twice the deepest width sits
    between the two halves. The final 1x1 convolution produces ``out_channels``
    raw (un-activated) output maps.
    """

    def __init__(self, in_channels=1, out_channels=1, features=[64, 128, 256, 512]):
        super().__init__()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder: each level consumes the previous level's width.
        channels = in_channels
        for width in features:
            self.downs.append(DoubleConv(channels, width))
            channels = width

        # Decoder: `ups` alternates [upsample, refine, upsample, refine, ...].
        for width in reversed(features):
            self.ups.extend(
                [
                    nn.ConvTranspose2d(width * 2, width, kernel_size=2, stride=2),
                    DoubleConv(width * 2, width),
                ]
            )

        deepest, shallowest = features[-1], features[0]
        self.bottleneck = DoubleConv(deepest, deepest * 2)
        self.final_conv = nn.Conv2d(shallowest, out_channels, kernel_size=1)

    def forward(self, x):
        """Run the encoder, bottleneck and decoder; return the output of ``final_conv``."""
        encoder_outputs = []
        for encoder in self.downs:
            x = encoder(x)
            encoder_outputs.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)

        # Walk the stored encoder outputs from deepest to shallowest.
        for level, skip in enumerate(reversed(encoder_outputs)):
            upsample = self.ups[2 * level]
            refine = self.ups[2 * level + 1]

            x = upsample(x)
            # Pooling floors odd sizes, so the upsampled map can be smaller
            # than the stored skip tensor; resize it to match before concatenating.
            if x.shape != skip.shape:
                x = TF.resize(x, size=skip.shape[2:])

            x = refine(torch.cat((skip, x), dim=1))

        return self.final_conv(x)


class PupilDataset(Dataset):
    """Paired image/mask dataset read from two directories.

    Every entry of ``image_dir`` is treated as a sample. Its mask is looked up
    in ``mask_dir`` under the same name with ``.jpg`` replaced by ``.gif``.
    """

    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir, self.mask_dir = image_dir, mask_dir
        self.transform = transform
        self.images = os.listdir(image_dir)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        """Return ``(image, mask)`` for ``index``, transformed when a transform is set."""
        file_name = self.images[index]
        mask_name = file_name.replace(".jpg", ".gif")

        # Both files are loaded as single-channel ("L") arrays.
        image = np.array(Image.open(os.path.join(self.image_dir, file_name)).convert("L"))
        mask = np.array(
            Image.open(os.path.join(self.mask_dir, mask_name)).convert("L"),
            dtype=np.float32,
        )
        # Map the value 255 to 1.0; other values are left as they are.
        mask[mask == 255.0] = 1.0

        if self.transform is None:
            return image, mask

        augmented = self.transform(image=image, mask=mask)
        return augmented["image"], augmented["mask"]


In [16]:
#@title Define Helper Functions

def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    """Announce the save on stdout, then serialize ``state`` to ``filename`` via ``torch.save``."""
    announcement = "=> Saving checkpoint"
    print(announcement)
    torch.save(state, filename)


def load_checkpoint(checkpoint, model):
    """Announce the load on stdout, then feed ``checkpoint["state_dict"]`` to ``model.load_state_dict``."""
    print("=> Loading checkpoint")
    weights = checkpoint["state_dict"]
    model.load_state_dict(weights)

def get_loaders(
    train_dir,
    train_maskdir,
    val_dir,
    val_maskdir,
    batch_size,
    train_transform,
    val_transform,
    num_workers=4,
    pin_memory=True,
):
    """Build the training and validation ``DataLoader`` pair.

    Both loaders share ``batch_size``, ``num_workers`` and ``pin_memory``.
    Only the training loader shuffles.

    Returns:
        ``(train_loader, val_loader)``.
    """
    shared_options = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory)

    train_ds = PupilDataset(image_dir=train_dir, mask_dir=train_maskdir, transform=train_transform)
    train_loader = DataLoader(train_ds, shuffle=True, **shared_options)

    val_ds = PupilDataset(image_dir=val_dir, mask_dir=val_maskdir, transform=val_transform)
    val_loader = DataLoader(val_ds, shuffle=False, **shared_options)

    return train_loader, val_loader

def check_accuracy(loader, model, device="cuda"):
    """Print pixel accuracy and the mean per-batch Dice score; return the latter.

    Model outputs are passed through a sigmoid and thresholded at 0.5 before
    being compared with the targets. The model is put in eval mode for the
    pass and switched to train mode before returning.
    """
    correct_pixels = 0
    total_pixels = 0
    dice_total = 0
    model.eval()

    with torch.no_grad():
        for images, masks in loader:
            images = images.to(device)
            # Targets arrive without a channel axis; add one to match the predictions.
            masks = masks.to(device).unsqueeze(1)
            binary = (torch.sigmoid(model(images)) > 0.5).float()

            overlap = (binary * masks).sum()
            correct_pixels += (binary == masks).sum()
            total_pixels += torch.numel(binary)
            # The small epsilon keeps the ratio defined when both are empty.
            dice_total += (2 * overlap) / ((binary + masks).sum() + 1e-8)

    print(f"Got {correct_pixels}/{total_pixels} with acc {correct_pixels/total_pixels*100:.2f}")
    mean_dice = dice_total / len(loader)
    print(f"Dice score: {mean_dice}")
    model.train()
    return mean_dice

def save_predictions_as_imgs(
    loader, model, folder="saved_images/", device="cuda"
):
    """Write thresholded predictions and their ground-truth masks as PNG grids.

    For batch ``idx`` of ``loader`` two files are written inside ``folder``:
    ``pred_{idx}.png`` (sigmoid output thresholded at 0.5) and
    ``truth_{idx}.png`` (the target mask scaled so its maximum is 1).

    Args:
        loader: iterable yielding ``(inputs, targets)`` batches.
        model: network to evaluate; switched to eval mode for the pass and
            to train mode before returning.
        folder: output directory, created if missing.
        device: device the inputs are moved to before the forward pass.
    """
    os.makedirs(folder, exist_ok=True)

    model.eval()
    for idx, (x, y) in enumerate(loader):
        x = x.to(device=device)
        with torch.no_grad():
            preds = torch.sigmoid(model(x))
            preds = (preds > 0.5).float()
        torchvision.utils.save_image(
            preds, f"{folder}/pred_{idx}.png"
        )

        # Scale the target to [0, 1], but leave an all-zero batch untouched:
        # dividing by a zero maximum would fill the image with NaNs.
        y = y.float()
        peak = y.max()
        if peak > 0:
            y = y / peak
        # One file per batch inside `folder`. Previously every batch was
        # written to the single path f"{folder}.png", so only the last
        # batch survived and it landed outside the folder.
        torchvision.utils.save_image(y.unsqueeze(1), f"{folder}/truth_{idx}.png")
    model.train()


In [None]:
#@title Train the Model

# Training configuration: module-level constants read by the training
# functions defined further down in this cell.
LEARNING_RATE = 1e-3  # initial learning rate handed to the optimizer
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # fall back to CPU when no GPU is visible
BATCH_SIZE = 40
NUM_EPOCHS = 1000
NUM_WORKERS = 8  # worker count passed on to the DataLoaders
IMAGE_HEIGHT = 380  # images and masks are resized to IMAGE_HEIGHT x IMAGE_WIDTH
IMAGE_WIDTH = 540
PIN_MEMORY = True
LOAD_MODEL = False  # when True, main() starts from "my_checkpoint.pth.tar"
# Dataset locations, relative to the current working directory.
TRAIN_IMG_DIR = "train/images"
TRAIN_MASK_DIR = "train/masks"
VAL_IMG_DIR = "val/images"
VAL_MASK_DIR = "val/masks"

def train_fn(loader, model, optimizer, bce_loss, scaler, epoch):
    """Train ``model`` for one pass over ``loader`` and return the mean batch loss.

    The forward pass runs under CUDA autocast and the backward pass and
    optimizer step go through ``scaler``. ``epoch`` is only used to label
    the progress bar.
    """
    model.train()
    running_loss = 0

    batches = tqdm(loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS} - Training", leave=True)

    for inputs, masks in batches:
        inputs = inputs.to(DEVICE)
        # Add a channel axis so the targets line up with the model output.
        masks = masks.float().unsqueeze(1).to(DEVICE)

        with torch.cuda.amp.autocast():
            loss = bce_loss(model(inputs), masks)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        batch_loss = loss.item()
        running_loss += batch_loss
        batches.set_postfix({"BCE Loss": f"{batch_loss:.4f}"})

    return running_loss / len(loader)

def validate_fn(loader, model, bce_loss, epoch):
    """Evaluate ``model`` on ``loader`` without gradients and return the mean batch loss.

    The model is left in eval mode on return. ``epoch`` is only used to label
    the progress bar.
    """
    model.eval()
    running_loss = 0

    batches = tqdm(loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS} - Validation", leave=True)

    with torch.no_grad():
        for inputs, masks in batches:
            inputs = inputs.to(DEVICE)
            # Add a channel axis so the targets line up with the model output.
            masks = masks.float().unsqueeze(1).to(DEVICE)

            batch_loss = bce_loss(model(inputs), masks).item()
            running_loss += batch_loss

            batches.set_postfix({"BCE Loss": f"{batch_loss:.4f}"})

    return running_loss / len(loader)

def _build_checkpoint(model, optimizer, best_bce):
    """Bundle the model weights, optimizer state and best loss into one dict."""
    return {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "best_bce": best_bce,
    }


def main():
    """Train the U-Net, checkpointing the best and periodic states.

    Each epoch runs one training pass and one validation pass, steps the
    plateau scheduler on the validation loss, saves "best_model.pth.tar"
    whenever the validation loss improves, saves "my_checkpoint.pth.tar"
    every fifth epoch (starting with the first), and writes prediction
    images for the validation set into "saved_images".
    """
    train_transform = A.Compose(
        [
            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            A.Rotate(limit=5, p=0.5),
            A.HorizontalFlip(p=0.25),
            A.VerticalFlip(p=0.25),
            A.RandomResizedCrop(height=IMAGE_HEIGHT, width=IMAGE_WIDTH, scale=(0.9, 1.0), ratio=(0.9, 1.1), p=0.2),
            A.RandomBrightnessContrast(p=0.2),
            A.RandomGamma(p=0.2),
            A.Normalize(mean=0.0, std=1.0, max_pixel_value=255.0),
            ToTensorV2(),
        ]
    )

    # Validation only resizes and normalizes; no random augmentation.
    val_transforms = A.Compose(
        [
            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            A.Normalize(mean=0.0, std=1.0, max_pixel_value=255.0),
            ToTensorV2(),
        ]
    )

    model = UNET(in_channels=1, out_channels=1).to(DEVICE)
    bce_loss = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

    train_loader, val_loader = get_loaders(
        TRAIN_IMG_DIR, TRAIN_MASK_DIR, VAL_IMG_DIR, VAL_MASK_DIR, BATCH_SIZE, train_transform, val_transforms, NUM_WORKERS, PIN_MEMORY
    )

    best_bce = float('inf')

    if LOAD_MODEL:
        # map_location lets a checkpoint written on a GPU be resumed on a
        # CPU-only runtime (and vice versa).
        checkpoint = torch.load("my_checkpoint.pth.tar", map_location=DEVICE)
        load_checkpoint(checkpoint, model)
        # The checkpoint stores the optimizer state; restore it so the
        # resumed run continues with the same Adam moments and learning rate.
        if "optimizer" in checkpoint:
            optimizer.load_state_dict(checkpoint["optimizer"])
        best_bce = checkpoint.get("best_bce", float('inf'))

    scaler = torch.cuda.amp.GradScaler()

    for epoch in range(NUM_EPOCHS):
        train_fn(train_loader, model, optimizer, bce_loss, scaler, epoch)

        current_bce = validate_fn(val_loader, model, bce_loss, epoch)

        scheduler.step(current_bce)

        if current_bce < best_bce:
            best_bce = current_bce
            print(f"New best model found! Saving model with validation BCE loss: {best_bce:.4f}")
            save_checkpoint(_build_checkpoint(model, optimizer, best_bce), filename="best_model.pth.tar")

        if epoch % 5 == 0:
            save_checkpoint(_build_checkpoint(model, optimizer, best_bce))

        # exist_ok avoids the check-then-create race and re-creates the
        # folder if it was removed while training was running.
        os.makedirs("saved_images", exist_ok=True)
        save_predictions_as_imgs(val_loader, model, folder="saved_images", device=DEVICE)

if __name__ == "__main__":
    main()


In [None]:
#@title Clear Cache and Display Memory

# Collect unreachable Python objects first so the GPU tensors they hold are
# released to PyTorch's caching allocator; only then can empty_cache() free
# those cached blocks.
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    print(torch.cuda.memory_summary())
    print(f"Memory Allocated: {torch.cuda.memory_allocated()} bytes")
    print(f"Memory Reserved: {torch.cuda.memory_reserved()} bytes")
else:
    # The notebook supports a CPU fallback; the CUDA statistics calls are
    # skipped there instead of failing the cell.
    print("CUDA is not available; no GPU memory statistics to report.")


In [None]:
#@title Generate the Video

# Inference configuration for the GIF-generation cell.
# NOTE: these assignments rebind any module-level names of the same spelling
# that an earlier cell may have set (DEVICE, IMAGE_HEIGHT, IMAGE_WIDTH,
# BATCH_SIZE, NUM_WORKERS).
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # fall back to CPU when no GPU is visible
IMAGE_HEIGHT = 380  # frames are resized to IMAGE_HEIGHT x IMAGE_WIDTH before inference
IMAGE_WIDTH = 540
BATCH_SIZE = 58
NUM_WORKERS = multiprocessing.cpu_count()  # one DataLoader worker per reported CPU

class ImageDataset(Dataset):
    """Inference dataset yielding transformed grayscale frames in natural order.

    Only files ending in ``.png`` or ``.jpg`` are picked up. They are sorted
    so that embedded numbers compare numerically and text compares
    case-insensitively (``frame2`` before ``frame10``).

    Args:
        image_dir: directory containing the frames.
        transform: callable invoked as ``transform(image=array)`` that returns
            a mapping with an ``'image'`` entry.
    """

    def __init__(self, image_dir, transform):
        self.image_dir = image_dir
        self.transform = transform
        self.image_files = [f for f in os.listdir(image_dir) if f.endswith(('.png', '.jpg'))]
        self.image_files.sort(key=self._natural_key)

    @staticmethod
    def _natural_key(name):
        """Split ``name`` into text and number runs for natural sorting."""
        # re.split with a capture group alternates text and digit runs, so
        # equal positions always hold comparable types. isdecimal() matches
        # exactly what `\d` matches; isdigit() also accepts characters such
        # as superscripts that int() rejects.
        return [int(part) if part.isdecimal() else part.lower() for part in re.split(r'(\d+)', name)]

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(transformed_image, path)`` for frame ``idx``.

        Raises:
            OSError: if OpenCV cannot read the file.
        """
        img_path = os.path.join(self.image_dir, self.image_files[idx])
        image = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread signals failure by returning None rather than raising;
        # fail here with the offending path instead of inside the transform.
        if image is None:
            raise OSError(f"Could not read image file: {img_path}")
        augmented = self.transform(image=image)
        return augmented['image'], img_path


def load_model(model_path):
    """Build a single-channel UNET on ``DEVICE``, load its weights from ``model_path`` and return it in eval mode."""
    network = UNET(in_channels=1, out_channels=1)
    network = network.to(DEVICE)

    saved = torch.load(model_path, map_location=DEVICE)
    network.load_state_dict(saved["state_dict"])

    # eval() returns the module itself.
    return network.eval()

def get_transforms():
    """Return the inference pipeline: resize, scale pixel values by 1/255, convert to tensor."""
    steps = [
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        # mean 0 / std 1 with max_pixel_value 255 only rescales to [0, 1].
        A.Normalize(mean=0.0, std=1.0, max_pixel_value=255.0),
        ToTensorV2(),
    ]
    return A.Compose(steps)

def generate_masks(model, dataloader):
    """Run ``model`` over ``dataloader`` and collect binary masks with their paths.

    Predictions go through a sigmoid, are thresholded at 0.5 and converted to
    ``uint8`` arrays holding 0 or 255.

    Returns:
        ``(masks, paths)``: two lists in loader order, one entry per sample.
    """
    collected_masks, collected_paths = [], []

    for images, paths in tqdm(dataloader, desc="Generating masks"):
        images = images.to(DEVICE)
        with torch.no_grad():
            binary = (torch.sigmoid(model(images)) > 0.5).float()

        as_uint8 = (binary.cpu().numpy() * 255).astype(np.uint8)
        # Extending with the batch array appends one entry per sample.
        collected_masks.extend(as_uint8)
        collected_paths.extend(paths)

    return collected_masks, collected_paths

def is_elliptical(mask, threshold=0.85):
    """Report whether the largest blob in ``mask`` fills its fitted ellipse.

    The largest external contour is fitted with an ellipse, the ellipse is
    rasterised, and the contour area is divided by the ellipse's pixel count.
    The result is True when that ratio exceeds ``threshold``. Masks with no
    contour, with fewer than five contour points, or with an empty
    rasterised ellipse yield False.
    """
    flat = mask.squeeze()
    contours, _ = cv2.findContours(flat, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return False

    largest = max(contours, key=cv2.contourArea)
    # Ellipse fitting needs at least five points.
    if len(largest) < 5:
        return False

    canvas = np.zeros_like(flat)
    cv2.ellipse(canvas, cv2.fitEllipse(largest), 255, -1)

    filled_pixels = np.sum(canvas > 0)
    if filled_pixels == 0:
        return False

    return cv2.contourArea(largest) / filled_pixels > threshold

def process_images(image_paths, masks, centroids, eye_areas, model_start_time):
    """Overlay each mask on its frame and stamp a one-line report on it.

    Frames whose mask passes ``is_elliptical`` get a half-transparent red
    overlay and show the area and centroid; the others are left unchanged
    apart from the text and show "Undefined" / "N/A". Every frame also shows
    the running frame rate measured from ``model_start_time``.

    Args:
        image_paths: paths of the frames, read in colour with OpenCV.
        masks: uint8 masks, one per frame; ``squeeze()`` must give a 2-D array.
        centroids: ``(cx, cy)`` pairs, one per frame.
        eye_areas: pixel counts, one per frame.
        model_start_time: ``time()`` value the frame rate is measured from.

    Returns:
        List of annotated BGR images, in input order.

    Raises:
        OSError: if a frame cannot be read.
    """
    # Text styling is the same for every frame.
    cyan_color = (255, 255, 0)
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.6
    thickness = 2

    results = []
    frames = zip(image_paths, masks, centroids, eye_areas)
    for total_processed, (image_path, mask, (cx, cy), eye_area) in enumerate(frames, 1):
        image = cv2.imread(image_path)
        # cv2.imread returns None on failure instead of raising.
        if image is None:
            raise OSError(f"Could not read image file: {image_path}")

        if is_elliptical(mask):
            overlay_mask = mask.squeeze()
            # The mask has the model's input resolution while the frame is
            # read at native size; cv2.bitwise_and needs both to match.
            # Nearest-neighbour keeps the mask strictly 0/255. The reported
            # area and centroid stay in mask coordinates.
            if overlay_mask.shape != image.shape[:2]:
                overlay_mask = cv2.resize(
                    overlay_mask,
                    (image.shape[1], image.shape[0]),  # dsize is (width, height)
                    interpolation=cv2.INTER_NEAREST,
                )
            red_overlay = np.zeros_like(image)
            red_overlay[:, :, 2] = 255  # channel 2 is red in BGR order
            red_mask = cv2.bitwise_and(red_overlay, red_overlay, mask=overlay_mask)
            result = cv2.addWeighted(image, 1, red_mask, 0.5, 0)
            eye_area = str(int(eye_area))
            cx_display = f"{cx:.2f}"
            cy_display = f"{cy:.2f}"
        else:
            result = image
            eye_area = "Undefined"
            cx_display = "N/A"
            cy_display = "N/A"

        current_time = time() - model_start_time
        fps = total_processed / current_time if current_time > 0 else 0

        report_text = f"Area: {eye_area}, CX: {cx_display}, CY: {cy_display}, FPS: {fps:.2f}"
        cv2.putText(result, report_text, (20, 30), font, font_scale, cyan_color, thickness)

        results.append(result)
    return results

def create_gif(image_list, output_path, duration=200):
    """Write ``image_list`` (BGR arrays) to ``output_path`` as a looping animated GIF.

    Args:
        image_list: iterable of BGR images; converted to RGB before saving.
        output_path: destination file.
        duration: display time per frame, passed to Pillow (milliseconds).

    Raises:
        ValueError: if ``image_list`` contains no frames.
    """
    frames = [Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) for img in image_list]
    # Without this guard an empty input fails with a bare IndexError on the
    # first-frame lookup, which says nothing about the actual problem.
    if not frames:
        raise ValueError("Cannot create a GIF from an empty list of images.")
    first, *rest = frames
    first.save(output_path, save_all=True, append_images=rest, duration=duration, loop=0)

def calculate_centroids(masks):
    """Return one ``(cx, cy)`` centroid per mask, computed from image moments.

    A mask whose zeroth moment is zero gets ``(0, 0)``.
    """
    def centroid_of(mask):
        moments = cv2.moments(mask.squeeze())
        mass = moments["m00"]
        if mass == 0:
            return (0, 0)
        return (moments["m10"] / mass, moments["m01"] / mass)

    return [centroid_of(mask) for mask in masks]

def calculate_eye_areas(masks):
    """Count the strictly positive pixels of every mask, preserving input order."""
    areas = []
    for mask in masks:
        foreground = mask.squeeze() > 0
        areas.append(np.sum(foreground))
    return areas

def main():
    """Segment every frame in the image directory and write an annotated GIF.

    Loads the best checkpoint, predicts a mask per frame, derives centroid
    and area from each mask, overlays the results on the frames and saves
    them as "output.gif", then prints timing statistics.
    """
    model_path = "best_model.pth.tar"
    # NOTE(review): the training cell reads from "train/images" (no "data/"
    # prefix) - confirm this directory is the intended input.
    image_dir = "data/train/images"
    output_gif_path = "output.gif"

    model = load_model(model_path)
    transform = get_transforms()

    dataset = ImageDataset(image_dir, transform)
    # With no frames there is nothing to animate; stop with a clear message
    # instead of failing later while assembling the GIF.
    if len(dataset) == 0:
        print(f"No .png or .jpg images found in {image_dir}; nothing to do.")
        return
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=True)

    masks, image_paths = generate_masks(model, dataloader)

    centroids = calculate_centroids(masks)
    eye_areas = calculate_eye_areas(masks)

    # NOTE(review): the clock starts after mask generation, so the reported
    # time and FPS cover overlay rendering and GIF writing only, not model
    # inference - confirm that this is the intended measurement.
    model_start_time = time()
    processed_images = process_images(image_paths, masks, centroids, eye_areas, model_start_time)

    create_gif(processed_images, output_gif_path)

    total_time = time() - model_start_time
    total_images = len(masks)
    # Guard against a zero reading from a coarse clock.
    average_fps = total_images / total_time if total_time > 0 else 0.0
    print(f"GIF animation created at {output_gif_path}")
    print(f"Total processing time: {total_time:.2f} seconds")
    print(f"Total images processed: {total_images}")
    print(f"Average FPS: {average_fps:.2f}")

if __name__ == "__main__":
    main()
