# In [1]:
import kagglehub
import os
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.model_selection import GroupShuffleSplit
from sklearn.metrics import classification_report, confusion_matrix
import numpy as np
import seaborn as sns

# --- 1. CONFIG ---
# Hyperparameters shared by the data pipeline, the model and the training loop.
BATCH_SIZE = 32  # samples per mini-batch
LEARNING_RATE = 1e-4  # initial optimizer step size
EPOCHS = 30  # number of passes over the training split
IMG_SIZE = 64  # images are resized to IMG_SIZE x IMG_SIZE

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# --- 2. DATA LOADING ---
# Downloads the dataset, then scans <dataset>/train/<class>/<image> and builds
# one DataFrame row per image with its path, integer label and subject ID.
print("Downloading Dataset...")
DATASET_PATH = kagglehub.dataset_download("prasadvpatil/mrl-dataset/versions/3")
TRAIN_DIR = os.path.join(DATASET_PATH, "train")

# Only sub-directories are classes. Stray files in TRAIN_DIR (e.g. .DS_Store)
# would otherwise shift the label indices or crash the per-class listing below.
classes = sorted(
    entry
    for entry in os.listdir(TRAIN_DIR)
    if os.path.isdir(os.path.join(TRAIN_DIR, entry))
)  # expected: ["Closed_Eyes", "Open_Eyes"]

_IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")
all_paths, all_labels, all_subjects = [], [], []

for label_idx, label in enumerate(classes):
    class_dir = os.path.join(TRAIN_DIR, label)
    # os.listdir returns entries in arbitrary order; sorting makes the row
    # order (and everything derived from it) reproducible across machines.
    for img_file in sorted(os.listdir(class_dir)):
        if not img_file.lower().endswith(_IMAGE_EXTENSIONS):
            continue
        all_paths.append(os.path.join(class_dir, img_file))
        all_labels.append(label_idx)
        # Subject ID is the filename prefix before the first underscore
        # (e.g. s0001 from s0001_00234...).
        all_subjects.append(img_file.split('_')[0])

df = pd.DataFrame({'path': all_paths, 'label': all_labels, 'subject': all_subjects})

# Fail here with a clear message instead of an obscure splitter error later.
if df.empty:
    raise RuntimeError(f"No images found under {TRAIN_DIR!r}")

# --- 3. SPLITTING (The Correct Way) ---
# Both splits are grouped by subject ID, so every image of a given subject
# lands in exactly one of train / val / test.

# Step A: hold out 10% as the TEST VAULT (never touched during training).
splitter_test = GroupShuffleSplit(n_splits=1, test_size=0.1, random_state=42)
test_split = splitter_test.split(df, groups=df['subject'])
train_val_idx, test_idx = next(test_split)
train_val_df, test_df = df.iloc[train_val_idx], df.iloc[test_idx]

# Step B: carve VAL out of the remaining 90%.
# 0.11 of that 90% is roughly 10% of the total, leaving about 80% for TRAIN.
splitter_val = GroupShuffleSplit(n_splits=1, test_size=0.11, random_state=42)
val_split = splitter_val.split(train_val_df, groups=train_val_df['subject'])
train_idx, val_idx = next(val_split)
train_df, val_df = train_val_df.iloc[train_idx], train_val_df.iloc[val_idx]

# Report how many subjects and images each split received.
print(f"Train Subjects: {train_df['subject'].nunique()} ({len(train_df)} images)")
print(f"Val Subjects:   {val_df['subject'].nunique()} ({len(val_df)} images)")
print(f"Test Subjects:  {test_df['subject'].nunique()} ({len(test_df)} images) <- HIDDEN VAULT")

# --- 4. DATASET CLASS ---
class EyeDataset(Dataset):
    """Map-style dataset backed by a DataFrame of image records.

    Each row must provide a ``path`` column (image file to open) and a
    ``label`` column. Images are opened with PIL and converted to
    single-channel grayscale before the optional ``transform`` is applied.
    """

    def __init__(self, dataframe, transform=None):
        """Store the record table and the optional per-image transform."""
        self.transform = transform
        self.df = dataframe

    def __len__(self):
        """Return the number of rows (images) in the dataset."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the row at positional index ``idx``."""
        record = self.df.iloc[idx]
        image = Image.open(record['path']).convert("L")  # Grayscale
        if not self.transform:
            return image, record['label']
        return self.transform(image), record['label']

# Transforms
# Training pipeline: resize, light augmentation, then tensor conversion and
# normalization with mean 0.5 / std 0.5.
train_tfm = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        # Augmentation (training only)
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        # Tensor conversion + normalization
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

# Evaluation pipeline: same preprocessing, no random augmentation.
eval_tfm = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

# Loaders
# Only the training split is augmented and shuffled; val/test keep row order.
train_ds = EyeDataset(dataframe=train_df, transform=train_tfm)
val_ds = EyeDataset(dataframe=val_df, transform=eval_tfm)
test_ds = EyeDataset(dataframe=test_df, transform=eval_tfm)

train_loader = DataLoader(dataset=train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(dataset=val_ds, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(dataset=test_ds, batch_size=BATCH_SIZE, shuffle=False)

# --- 5. MODEL (Custom CNN) ---
class CNN(nn.Module):
    """Small three-block convolutional classifier for single-channel images.

    The first convolution takes 1 input channel and the final linear layer
    emits 2 raw scores (no softmax is applied here). The first linear layer
    is sized for IMG_SIZE x IMG_SIZE inputs after three 2x2 max-pools.
    """

    def __init__(self):
        super().__init__()
        # Each of the three max-pools halves the side length: IMG_SIZE // 8.
        pooled_side = IMG_SIZE // 8
        flat_features = 128 * pooled_side * pooled_side

        feature_layers = [
            # Block 1
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),
            # Block 2
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
            # Block 3 (no batch norm in this block)
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        ]
        head_layers = [
            nn.Flatten(),
            nn.Dropout(p=0.5),
            nn.Linear(flat_features, 128),
            nn.ReLU(),
            nn.Linear(128, 2),
        ]
        # Single Sequential so layer indices (and state_dict keys) stay flat.
        self.net = nn.Sequential(*feature_layers, *head_layers)

    def forward(self, x):
        """Run ``x`` through the network and return the class scores."""
        return self.net(x)

# Model, optimizer, loss and LR schedule used by the training loop.
model = CNN().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss()
# Lowers the learning rate once the metric passed to scheduler.step() has
# stopped improving for more than `patience` epochs.
# `verbose=True` was dropped: the argument has been deprecated since
# PyTorch 2.2. Read the current LR from optimizer.param_groups[0]["lr"]
# if it needs to be logged.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3)

# --- 6. TRAINING ---
# Each epoch: one optimization pass over train_loader, then a no-grad pass over
# val_loader. The weights with the best validation accuracy so far are
# checkpointed to "best_eye_model.pth".
best_acc = 0.0
for epoch in range(EPOCHS):
    # ---- Train ----
    model.train()
    total_loss = 0.0
    for imgs, lbls in train_loader:
        imgs, lbls = imgs.to(device), lbls.to(device)
        optimizer.zero_grad()
        out = model(imgs)
        loss = criterion(out, lbls)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # ---- Validation ----
    model.eval()
    correct = 0
    total = 0
    val_loss_sum = 0.0
    with torch.no_grad():
        for imgs, lbls in val_loader:
            imgs, lbls = imgs.to(device), lbls.to(device)
            out = model(imgs)
            # criterion returns the batch mean; re-weight by batch size so a
            # smaller final batch does not skew the epoch average.
            val_loss_sum += criterion(out, lbls).item() * lbls.size(0)
            _, pred = torch.max(out, 1)
            correct += (pred == lbls).sum().item()
            total += lbls.size(0)

    val_acc = correct / total
    val_loss = val_loss_sum / total
    print(f"Epoch {epoch+1}: Loss {total_loss/len(train_loader):.4f} | Val Acc: {val_acc:.4f}")

    # Drive the plateau scheduler with the validation loss. The summed
    # training loss keeps falling while the model overfits, so it would hide
    # the plateau the scheduler is supposed to react to.
    lr_before = optimizer.param_groups[0]["lr"]
    scheduler.step(val_loss)
    lr_after = optimizer.param_groups[0]["lr"]
    if lr_after != lr_before:
        print(f"  Learning rate reduced: {lr_before:.2e} -> {lr_after:.2e}")

    # Keep only the weights with the best validation accuracy seen so far.
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), "best_eye_model.pth")

print("Training Done.")

# --- 7. FINAL TEST (The Truth) ---
# Reload the best checkpoint, score it once on the held-out test loader, and
# write the same weights out again under the export filename.
print("\n--- FINAL TEST SET EVALUATION ---")
# map_location keeps the load working when the checkpoint was written on a
# device that is not available now (e.g. saved on GPU, evaluated on CPU).
model.load_state_dict(torch.load("best_eye_model.pth", map_location=device))
model.eval()

y_true, y_pred = [], []
with torch.no_grad():
    for imgs, lbls in test_loader:
        imgs = imgs.to(device)
        out = model(imgs)
        _, pred = torch.max(out, 1)
        # Labels never left the CPU; predictions must be moved back first.
        y_true.extend(lbls.tolist())
        y_pred.extend(pred.cpu().tolist())

# Passing `labels` explicitly keeps target_names aligned with the class
# indices even if a class happens to be absent from the test split.
print(
    classification_report(
        y_true,
        y_pred,
        labels=list(range(len(classes))),
        target_names=classes,
    )
)

# Save for Export
torch.save(model.state_dict(), "final_model_weights.pth")

# [captured notebook cell output, not part of the script]
#   from .autonotebook import tqdm as notebook_tqdm