In [None]:
#@title Preparation
# Mount Google Drive so the dataset archive stored there can be copied locally.
from google.colab import drive
drive.mount('/content/drive/')
# Delete the sample_data directory from the working directory.
!rm -rf sample_data
# Copy the archive out of Drive, unpack it, and keep only its data/ folder.
!cp drive/MyDrive/Finals.zip .
!unzip Finals.zip
!mv Finals/data .
!rm Finals.zip
# Drive is not needed after the copy: flush pending writes and unmount it.
drive.flush_and_unmount()

In [None]:
#@title Install Required Libraries and Import Libraries
!pip install torch torchvision matplotlib tqdm

import os
import torch
from torch import nn
import torchvision
from torchvision.transforms import functional as F
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torchvision.datasets import CocoDetection
import gc
import torch.optim as optim
import torchvision.transforms as transforms
import numpy as np
from PIL import Image
import pandas as pd
import cv2
from torch.utils.data import Dataset, DataLoader
from torchvision.models.detection import maskrcnn_resnet50_fpn, MaskRCNN_ResNet50_FPN_Weights
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
import imageio


In [None]:
#@title Generate Masks
import os
import cv2
import pandas as pd
import numpy as np
from PIL import Image

def create_binary_masks(input_folder, image_shape=(380, 540)):
    """Rasterise ellipse annotations into one binary mask image per labelled frame.

    For every ``.png``/``.jpg`` file in ``<input_folder>/images`` that has a CSV
    of the same stem in ``<input_folder>/labels``, all ellipses listed in the CSV
    (columns ``Center_X``, ``Center_Y``, ``Major_Axis``, ``Minor_Axis``,
    ``Angle``) are drawn filled with value 255 on a zero background and saved
    as ``<input_folder>/masks/<stem>.gif``. Images without a CSV are skipped.

    Args:
        input_folder: Dataset split directory containing ``images`` and ``labels``.
        image_shape: ``(height, width)`` of the generated masks. Defaults to
            the 380x540 frame size used throughout this notebook.

    Returns:
        None. Masks are written to disk and a completion message is printed.
    """
    image_folder = os.path.join(input_folder, 'images')
    label_folder = os.path.join(input_folder, 'labels')
    mask_folder = os.path.join(input_folder, 'masks')
    ellipse_columns = ('Center_X', 'Center_Y', 'Major_Axis', 'Minor_Axis', 'Angle')

    os.makedirs(mask_folder, exist_ok=True)

    image_files = sorted(f for f in os.listdir(image_folder) if f.endswith(('.png', '.jpg')))

    for image_file in image_files:
        stem = os.path.splitext(image_file)[0]
        csv_file = os.path.join(label_folder, stem + '.csv')
        if not os.path.exists(csv_file):
            continue

        df = pd.read_csv(csv_file)
        mask_image = np.zeros(image_shape, dtype=np.uint8)

        for _, row in df.iterrows():
            center_x, center_y, major_axis, minor_axis, angle = (
                float(row[column]) for column in ellipse_columns
            )
            # A row with a missing/NaN value cannot be drawn (int(NaN) raises),
            # so skip it instead of aborting the whole export.
            if not np.isfinite([center_x, center_y, major_axis, minor_axis, angle]).all():
                continue

            center = (int(center_x), int(center_y))
            # The CSV stores full axis lengths; cv2.ellipse expects half-axes.
            axes = (int(major_axis / 2), int(minor_axis / 2))
            # Drawing straight onto the shared mask yields the same union as
            # rendering each ellipse separately and saturating-adding them.
            cv2.ellipse(mask_image, center, axes, angle, 0, 360, 255, -1)

        output_path = os.path.join(mask_folder, stem + '.gif')
        Image.fromarray(mask_image).save(output_path)

    print('Masks created successfully.')

# Generate the mask files for both dataset splits.
for split_folder in ('data/train', 'data/val'):
    create_binary_masks(split_folder)



In [None]:
#@title Dataset and Data Loader

# Output size in pixels handed to A.RandomResizedCrop in get_transform().
IMAGE_HEIGHT = 380
IMAGE_WIDTH = 540

class PupilDataset(torch.utils.data.Dataset):
    """Instance-segmentation dataset of grayscale images and their mask files.

    Reads ``<root>/images`` and ``<root>/masks``. An image and a mask belong
    together when their file names share the same stem; images without a
    matching mask file are left out so the two lists can never drift apart.

    Each item is ``(image, target)`` where ``image`` is a float32 tensor of
    shape ``(1, H, W)`` scaled to ``[0, 1]`` and ``target`` is a dict with the
    keys ``boxes``, ``labels``, ``masks``, ``image_id``, ``area`` and
    ``iscrowd``.
    """

    def __init__(self, root, transforms=None):
        """Index the image/mask pairs found under ``root``.

        Args:
            root: Directory containing the ``images`` and ``masks`` folders.
            transforms: Optional callable invoked as
                ``transforms(image=ndarray, mask=ndarray)`` that returns a
                mapping with ``'image'`` and ``'mask'`` entries.
        """
        self.root = root
        self.transforms = transforms

        image_names = sorted(os.listdir(os.path.join(root, "images")))
        mask_by_stem = {
            os.path.splitext(name)[0]: name
            for name in sorted(os.listdir(os.path.join(root, "masks")))
        }
        # Pair by file stem rather than by list position: positional pairing
        # silently mismatches every following sample once one mask is missing.
        paired = [
            (name, mask_by_stem[os.path.splitext(name)[0]])
            for name in image_names
            if os.path.splitext(name)[0] in mask_by_stem
        ]
        self.imgs = [image_name for image_name, _ in paired]
        self.masks = [mask_name for _, mask_name in paired]

    @staticmethod
    def _build_target(mask, idx):
        """Convert a 2-D label mask into a detection target dict.

        Every distinct non-zero value in ``mask`` is treated as one instance.
        Instances whose bounding box has zero width or height are dropped from
        boxes, labels and masks alike so all per-instance tensors stay aligned.
        """
        obj_ids = torch.unique(mask)
        obj_ids = obj_ids[obj_ids != 0]  # 0 is background
        masks = mask == obj_ids[:, None, None]  # (num_instances, H, W) bool

        boxes = []
        keep = []
        for i in range(len(obj_ids)):
            ys, xs = torch.where(masks[i])
            xmin, xmax = xs.min().item(), xs.max().item()
            ymin, ymax = ys.min().item(), ys.max().item()
            if xmin < xmax and ymin < ymax:
                boxes.append([xmin, ymin, xmax, ymax])
                keep.append(i)

        # reshape keeps the (0, 4) layout when no instance survives.
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        keep = torch.as_tensor(keep, dtype=torch.long)

        target = {}
        target["boxes"] = boxes
        target["labels"] = torch.ones((len(keep),), dtype=torch.int64)
        target["masks"] = masks[keep].type(torch.uint8)
        target["image_id"] = torch.tensor([idx])
        target["area"] = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        target["iscrowd"] = torch.zeros((len(keep),), dtype=torch.int64)
        return target

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, target)``."""
        img_path = os.path.join(self.root, "images", self.imgs[idx])
        mask_path = os.path.join(self.root, "masks", self.masks[idx])

        img_np = np.array(Image.open(img_path).convert("L"))
        mask_np = np.array(Image.open(mask_path))

        if self.transforms is not None:
            transformed = self.transforms(image=img_np, mask=mask_np)
            img_np = transformed['image']
            mask_np = transformed['mask']

        img = img_np.squeeze().astype(np.float32) / 255.0
        img = torch.from_numpy(img).unsqueeze(0)  # add the channel axis: (1, H, W)

        mask = torch.as_tensor(mask_np, dtype=torch.uint8)
        return img, self._build_target(mask, idx)

    def __len__(self):
        """Number of image/mask pairs."""
        return len(self.imgs)

def get_transform():
    """Build the Albumentations augmentation pipeline used for training samples.

    The pipeline is called as ``pipeline(image=..., mask=...)``. ``mask`` is
    Albumentations' built-in mask target, so it must NOT be re-declared through
    ``additional_targets={'mask': 'image'}``: that mapping makes the mask be
    processed like an image, i.e. brightness/contrast/gamma changes and
    interpolated warps are applied to the label values and corrupt them.

    Returns:
        An ``A.Compose`` with random rotation, flips, resized crop and
        brightness/contrast/gamma jitter.
    """
    return A.Compose(
        [
            A.Rotate(limit=15, p=0.5),
            A.HorizontalFlip(p=0.25),
            A.VerticalFlip(p=0.25),
            A.RandomResizedCrop(height=IMAGE_HEIGHT, width=IMAGE_WIDTH, scale=(0.9, 1.0), ratio=(0.9, 1.1), p=0.2),
            A.RandomBrightnessContrast(p=0.2),
            A.RandomGamma(p=0.2),
        ]
    )

def collate_fn(batch):
    """Regroup a list of samples into one tuple per sample field.

    A batch ``[(img0, tgt0), (img1, tgt1)]`` becomes
    ``((img0, img1), (tgt0, tgt1))``; an empty batch becomes ``()``.
    """
    per_field = zip(*batch)
    return tuple(per_field)

def get_dataloader(root, batch_size=15, shuffle=False, num_workers=4, augment=None):
    """Create a DataLoader over the PupilDataset stored under ``root``.

    Args:
        root: Dataset split directory passed to ``PupilDataset``.
        batch_size: Number of samples per batch.
        shuffle: Whether the loader reshuffles the samples every epoch.
        num_workers: Number of worker processes used for loading.
        augment: Whether to attach the ``get_transform()`` pipeline. ``None``
            (the default) keeps the historical behaviour of this helper, which
            augments exactly when ``shuffle`` is False. Pass ``True``/``False``
            to choose augmentation independently of shuffling, e.g.
            ``shuffle=True, augment=True`` for training and
            ``shuffle=False, augment=False`` for validation.

    Returns:
        A ``DataLoader`` that yields ``(images, targets)`` tuples via ``collate_fn``.
    """
    if augment is None:
        # Legacy coupling preserved so existing calls behave exactly as before.
        augment = not shuffle
    dataset = PupilDataset(root, get_transform() if augment else None)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=collate_fn,
    )


In [None]:
#@title Train Model

import os
import torch
import torchvision
from tqdm import tqdm

def get_model_instance_segmentation(num_classes):
    """Create a Mask R-CNN (ResNet-50 FPN) with heads sized for ``num_classes``.

    The network is instantiated with torchvision's ``"DEFAULT"`` weights, then
    its box predictor and mask predictor are replaced by freshly initialised
    ones producing ``num_classes`` outputs. The returned model also carries a
    ``best_val_loss`` attribute initialised to infinity, which the training
    loop uses to track the best checkpoint.
    """
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights="DEFAULT")
    roi_heads = model.roi_heads

    # Box head: keep the input width, change the number of predicted classes.
    box_in_features = roi_heads.box_predictor.cls_score.in_features
    roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(
        box_in_features, num_classes
    )

    # Mask head: same idea, with a 256-channel hidden layer.
    mask_in_channels = roi_heads.mask_predictor.conv5_mask.in_channels
    mask_hidden_channels = 256
    roi_heads.mask_predictor = MaskRCNNPredictor(mask_in_channels, mask_hidden_channels, num_classes)

    model.best_val_loss = float('inf')
    return model

# Number of output classes for the detection heads (presumably background + pupil).
num_classes = 2
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
best_model_path = 'best_model.pth'

# Build the network exactly once; the checkpoint (if any) is loaded into it.
model = get_model_instance_segmentation(num_classes)
# The inference cell further down redefines get_model_instance_segmentation
# without this attribute, so make sure it exists before it is read.
if not hasattr(model, 'best_val_loss'):
    model.best_val_loss = float('inf')

if os.path.exists(best_model_path):
    print("Loading best model...")
    checkpoint = torch.load(best_model_path, map_location=device)
    if 'model_state_dict' in checkpoint and 'best_val_loss' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
        model.best_val_loss = checkpoint['best_val_loss']
        print("Best model loaded successfully.")
    else:
        print("Checkpoint does not contain expected keys. Initializing new model...")
else:
    print("No existing best model found. Initializing new model...")

model.to(device)

# Training data: shuffled and augmented. Validation data: fixed order and no
# augmentation, so the validation loss is comparable between epochs.
# The transforms are assigned explicitly because get_dataloader ties its
# default augmentation choice to the shuffle flag.
train_data_loader = get_dataloader('data/train', shuffle=True)
train_data_loader.dataset.transforms = get_transform()
val_data_loader = get_dataloader('data/val', shuffle=False)
val_data_loader.dataset.transforms = None

params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(params, lr=0.000001, weight_decay=0.000001)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

num_epochs = 100
patience = 10  # epochs without validation improvement before stopping
counter = 0
# torchvision detection models return their loss dict only in training mode,
# so the model deliberately stays in train mode for the validation pass too
# (gradients are disabled there with torch.no_grad()).
model.train()

for epoch in range(num_epochs):
    epoch_loss = 0.0
    train_loader = tqdm(train_data_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
    for images, targets in train_loader:
        images = list(image.to(device) for image in images)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())
        optimizer.zero_grad()
        losses.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        epoch_loss += losses.item()
        train_loader.set_postfix(loss=losses.item())
    lr_scheduler.step()

    val_loss = 0.0
    num_valid_batches = 0
    with torch.no_grad():
        val_loader = tqdm(val_data_loader, desc=f"Validation Epoch {epoch+1}/{num_epochs}")
        for images, targets in val_loader:
            images = list(image.to(device) for image in images)
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())
            val_loss += losses.item()
            val_loader.set_postfix(loss=losses.item())
            num_valid_batches += 1

    avg_train_loss = epoch_loss / max(len(train_data_loader), 1)
    # With no validation batches there is nothing to compare; report infinity
    # rather than 0 so an empty validation set can never become the "best" model.
    avg_val_loss = val_loss / num_valid_batches if num_valid_batches > 0 else float('inf')
    print(f"Epoch {epoch+1}, Training Loss: {avg_train_loss:.4f}, Validation Loss: {avg_val_loss:.4f}")

    if avg_val_loss < model.best_val_loss:
        model.best_val_loss = avg_val_loss
        torch.save({
            'model_state_dict': model.state_dict(),
            'best_val_loss': model.best_val_loss,
        }, best_model_path)
        print(f"Saved best model with validation loss: {model.best_val_loss:.4f}")
        counter = 0
    else:
        counter += 1
        print(f"Early Stopping counter: {counter} out of {patience}")

    if counter >= patience:
        print("Early stopping")
        break

print("Training completed.")

In [None]:
#@title Visualization
def visualize(image, masks):
    """Show an image with each mask drawn on top as a semi-transparent overlay.

    Args:
        image: Tensor of shape ``(C, H, W)``; a single-channel image is shown
            with a gray colormap.
        masks: Iterable of mask tensors, each either ``(H, W)`` or
            ``(1, H, W)`` (the per-instance layout of Mask R-CNN's
            ``(N, 1, H, W)`` output). Zero-valued pixels are left transparent.
    """
    image_np = image.permute(1, 2, 0).cpu().numpy()

    plt.figure(figsize=(10, 10))
    if image_np.shape[-1] == 1:
        plt.imshow(image_np[..., 0], cmap='gray')
    else:
        plt.imshow(image_np)

    for mask in masks:
        mask_np = mask.cpu().numpy()
        if mask_np.ndim == 3:
            # Drop the leading channel axis; imshow rejects (1, H, W) arrays.
            mask_np = mask_np[0]
        # Hide the background so stacked overlays do not tint the whole frame.
        overlay = np.ma.masked_equal(mask_np.astype(float), 0)
        plt.imshow(overlay, alpha=0.5, vmin=0, vmax=1)

    plt.axis('off')
    plt.show()

# `image` and `prediction` are not defined anywhere above, so build them here
# from the first validation sample using the model from the training cell.
model.eval()  # eval mode makes the detection model return predictions, not losses
image, _ = val_data_loader.dataset[0]
with torch.no_grad():
    prediction = model([image.to(device)])[0]

# Mask R-CNN returns soft masks shaped (N, 1, H, W): drop the channel axis and
# binarise them at 0.5 before plotting.
masks = prediction['masks'][:, 0] > 0.5
visualize(image, masks)


In [None]:
#@title Clear GPU Ram

# Collect unreachable Python objects first so the CUDA memory they hold is
# returned to PyTorch's allocator before the cache is emptied.
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    # Guarded: resetting CUDA statistics raises on a CPU-only runtime.
    torch.cuda.reset_peak_memory_stats()

In [None]:
#@title Get the Segmentation
class PupilDataset(Dataset):
    """Inference dataset yielding ``(image, label_table)`` pairs.

    NOTE: this class reuses the name of the training dataset defined earlier in
    the notebook and replaces it in the kernel once this cell runs; re-run the
    training dataset cell before training again.

    Images are listed in sorted file-name order so that frames assembled from
    this dataset follow a deterministic sequence.
    """

    # Same image types the mask generator accepts, matched case-insensitively.
    IMAGE_EXTENSIONS = ('.jpg', '.png')

    def __init__(self, image_dir, label_dir, transforms=None):
        """Index the images in ``image_dir``.

        Args:
            image_dir: Directory containing the image files.
            label_dir: Directory containing one ``<stem>.csv`` per image.
            transforms: Optional callable applied to the loaded PIL image.
        """
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.image_files = sorted(
            f for f in os.listdir(image_dir)
            if f.lower().endswith(self.IMAGE_EXTENSIONS)
        )
        self.transform = transforms

    def __len__(self):
        """Number of indexed images."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return the (optionally transformed) RGB image and its label table.

        When the image has no CSV file an empty DataFrame is returned instead
        of raising, mirroring how mask generation skips unlabelled images.
        """
        img_name = self.image_files[idx]
        image_path = os.path.join(self.image_dir, img_name)
        label_path = os.path.join(self.label_dir, os.path.splitext(img_name)[0] + '.csv')

        image = Image.open(image_path).convert("RGB")
        if os.path.exists(label_path):
            label_data = pd.read_csv(label_path)
        else:
            label_data = pd.DataFrame()

        if self.transform:
            image = self.transform(image)

        return image, label_data

def get_transform():
    """Return the inference preprocessing: resize to 256x256, then convert to a tensor.

    NOTE(review): the training dataset cell works with 380x540 frames, so this
    resize changes the aspect ratio relative to training — confirm it is intended.
    """
    steps = [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ]
    return transforms.Compose(steps)

def get_model_instance_segmentation(num_classes):
    """Create a COCO-pretrained Mask R-CNN whose heads output ``num_classes`` classes.

    Both the box predictor and the mask predictor (256 hidden channels) are
    replaced with new layers sized for ``num_classes``.
    """
    model = maskrcnn_resnet50_fpn(weights=MaskRCNN_ResNet50_FPN_Weights.COCO_V1)
    heads = model.roi_heads

    box_in_features = heads.box_predictor.cls_score.in_features
    heads.box_predictor = FastRCNNPredictor(box_in_features, num_classes)

    mask_in_channels = heads.mask_predictor.conv5_mask.in_channels
    mask_hidden_channels = 256
    heads.mask_predictor = MaskRCNNPredictor(mask_in_channels, mask_hidden_channels, num_classes)

    return model

def load_model(model_path, num_classes):
    """Rebuild the segmentation model and load trained weights into it.

    Args:
        model_path: Checkpoint file containing a ``'model_state_dict'`` entry.
        num_classes: Number of classes the model heads were trained with.

    Returns:
        The model in evaluation mode, with its weights on the CPU; the caller
        moves it to the target device.
    """
    model = get_model_instance_segmentation(num_classes)
    # map_location lets a checkpoint saved on a GPU be opened on a CPU-only runtime.
    checkpoint = torch.load(model_path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    return model

def draw_red_eyeball(image, mask):
    """Blend a half-transparent channel-0 tint over the masked region of an image.

    Args:
        image: Tensor of shape ``(C, H, W)`` with values scaled by 255 on conversion
            (assumed to lie in ``[0, 1]`` — TODO confirm against the caller).
        mask: Soft mask tensor; it is squeezed, resized to the image size and
            thresholded at 0.5.

    Returns:
        A ``uint8`` array ``(H, W, C)``: tinted where the mask exceeds 0.5 and
        identical to the input image everywhere else.
    """
    frame = (image.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
    frame_height, frame_width = frame.shape[0], frame.shape[1]

    mask_np = mask.squeeze().cpu().numpy()
    mask_np = cv2.resize(mask_np, (frame_width, frame_height))  # cv2 takes (width, height)
    outside_mask = ~(mask_np > 0.5)

    # Solid tint (channel 0 at full intensity) mixed 50/50 with the frame.
    tint = np.zeros_like(frame, dtype=np.uint8)
    tint[:, :, 0] = 255
    opacity = 0.5
    blended = cv2.addWeighted(frame, 1 - opacity, tint, opacity, 0)

    # Restore the original pixels wherever the mask is not set.
    blended[outside_mask] = frame[outside_mask]
    return blended

def custom_collate_fn(batch):
    """Split a batch of ``(image, label)`` samples into two parallel lists.

    Returns:
        ``(images, labels)`` where both are lists in batch order.
    """
    images, labels = [], []
    for sample in batch:
        images.append(sample[0])
        labels.append(sample[1])
    return images, labels

def visualize_images_to_gif(image_dir, label_dir, model_path, output_gif_path):
    """Run the trained model over a folder of images and save the overlays as a GIF.

    Args:
        image_dir: Directory with the input images.
        label_dir: Directory with the per-image CSV files (loaded by the
            dataset but not used for drawing).
        model_path: Checkpoint to load the model weights from.
        output_gif_path: Destination path of the animated GIF.

    Raises:
        ValueError: If the dataset yields no images, so no GIF can be written.
    """
    transform = get_transform()

    dataset = PupilDataset(image_dir, label_dir, transforms=transform)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=custom_collate_fn)

    model = load_model(model_path, num_classes=2)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    frames = []

    for images, _label_data in dataloader:
        image = images[0].to(device)  # Unpack single-element batch
        with torch.no_grad():
            prediction = model([image])

        masks = prediction[0]['masks']
        if len(masks) > 0:
            # Overlay the first returned detection.
            frame = draw_red_eyeball(image, masks[0, 0])
        else:
            # Nothing detected: keep the untouched frame so the animation
            # does not skip or crash on this image.
            frame = (image.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)

        # Frames are already RGB (PIL image -> tensor), which is what imageio
        # expects; swapping channels here would turn the red overlay blue.
        frames.append(frame)

    if not frames:
        raise ValueError(f"No images found in {image_dir}; nothing to write to {output_gif_path}")

    imageio.mimsave(output_gif_path, frames, fps=10)
    print(f"GIF saved to {output_gif_path}")

# Inputs for the GIF demo; note that the frames are taken from the training split.
image_dir = 'data/train/images'
label_dir = 'data/train/labels'
# Checkpoint file written by the training cell.
model_path = 'best_model.pth'
output_gif_path = 'output_video.gif'

# Run inference over every image and write the overlay animation.
visualize_images_to_gif(image_dir, label_dir, model_path, output_gif_path)
