# In [None]:
import os
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torch.cuda.amp import autocast, GradScaler
from sklearn.metrics import confusion_matrix, classification_report
from tqdm import tqdm
from einops import rearrange, repeat
import albumentations as A
from albumentations.pytorch import ToTensorV2

# --- CONFIGURATION (Safe & Fast) ---
# Dataset root; expected to contain 'train' and 'val' sub-directories with one
# folder per class.
ROOT_DIR = '/kaggle/input/retinal-fundus-dr/retinal-dr-n/retinal-dr-n'
IMG_SIZE = 224   # Images are resized to IMG_SIZE x IMG_SIZE (kept small to prevent OOM)
BATCH_SIZE = 32  # Samples per optimisation step / validation batch
EPOCHS = 45      # Number of full training epochs
LR = 1e-4        # AdamW learning rate
# Use the GPU when one is visible, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"🚀 Starting Research Run on {DEVICE}...")

# --- 1. ROBUST DATASET (Handles your messy files) ---
class RetinalDataset(Dataset):
    """Fundus-image dataset laid out as ``<root_dir>/<split>/<class name>/<image>``.

    Only the class folders listed in ``class_map`` are scanned; missing folders
    are skipped, and files whose name does not end in .jpg/.jpeg/.png
    (case-insensitive) are ignored.

    Args:
        root_dir: Dataset root containing one sub-directory per split.
        split: Split sub-directory to read (e.g. ``'train'`` or ``'val'``).
        transform: Optional callable invoked as ``transform(image=img)['image']``
            (albumentations-style); when falsy the raw RGB array is returned.
    """

    IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, root_dir, split='train', transform=None):
        self.root_dir = os.path.join(root_dir, split)
        self.transform = transform
        # List of (image path, integer label) pairs.
        self.samples = []
        self.class_map = {"Normal Fundus": 0, "Mild DR": 1, "Moderate DR": 2, "Severe DR": 3, "Proliferate DR": 4}

        for cls_name, label in self.class_map.items():
            folder = os.path.join(self.root_dir, cls_name)
            if not os.path.isdir(folder):
                continue
            # os.listdir order is filesystem-dependent; sort so sample indices
            # (and hence sampler weights and runs) are reproducible.
            for f in sorted(os.listdir(folder)):
                if f.lower().endswith(self.IMG_EXTENSIONS):
                    self.samples.append((os.path.join(folder, f), label))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample *idx*; image is RGB, then transformed if set.

        Raises:
            OSError: if OpenCV cannot decode the file.
        """
        path, label = self.samples[idx]
        img = cv2.imread(path)
        if img is None:
            # cv2.imread reports unreadable/corrupt files by returning None
            # instead of raising; fail here with the offending path rather
            # than with an opaque cvtColor assertion.
            raise OSError(f"Could not read image: {path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        if self.transform:
            img = self.transform(image=img)['image']
        return img, label

# Augmentation: random flips are applied to training images only. Validation
# must be deterministic, otherwise val accuracy, best-checkpoint selection and
# the confusion matrix change between passes over the same weights.
train_transform = A.Compose([
    A.Resize(IMG_SIZE, IMG_SIZE),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.Normalize(),
    ToTensorV2()
])
val_transform = A.Compose([
    A.Resize(IMG_SIZE, IMG_SIZE),
    A.Normalize(),
    ToTensorV2()
])
transform = train_transform  # previous name, kept for any code still referring to it

# Loaders with Imbalance Fix
train_ds = RetinalDataset(ROOT_DIR, 'train', train_transform)
val_ds = RetinalDataset(ROOT_DIR, 'val', val_transform)

# An empty split would otherwise fail later with an unrelated-looking error
# (WeightedRandomSampler num_samples / division by zero in validation).
for split_name, split_ds in (('train', train_ds), ('val', val_ds)):
    if len(split_ds) == 0:
        raise RuntimeError(f"No images found for split '{split_name}' under {split_ds.root_dir}")

# Weighted sampler: each sample is drawn with probability inversely
# proportional to its class frequency, balancing classes per epoch.
targets = [label for _, label in train_ds.samples]
counts = np.bincount(targets, minlength=len(train_ds.class_map))
# Clamp to 1 so classes without training images do not divide by zero; those
# weights are never looked up because such classes never occur in `targets`.
weights = 1. / np.maximum(counts, 1)
sample_weights = [weights[t] for t in targets]
sampler = WeightedRandomSampler(sample_weights, len(sample_weights))

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, sampler=sampler, num_workers=2)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

print(f"✅ Data Loaded: {len(train_ds)} Training, {len(val_ds)} Validation")

# --- 2. MODEL: Simplified MambaVision (Pure PyTorch) ---
class MambaBlock(nn.Module):
    """Gated depthwise-convolution token mixer used as a stand-in for a Mamba block.

    The input is projected into two branches. The ``x`` branch is mixed along the
    sequence by a kernel-3 depthwise 1-D convolution followed by SiLU. The ``z``
    branch gates it through a sigmoid. There is no selective state-space scan,
    so each output token only sees its three neighbouring input tokens.

    Args:
        dim: Channel width ``D`` of the token sequence (input and output).

    Shape:
        input ``[B, L, D]`` -> output ``[B, L, D]``.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.in_proj = nn.Linear(dim, 2 * dim)
        # groups=dim makes this a depthwise conv: every channel is mixed only
        # along the sequence axis with its own 3-tap kernel.
        self.conv1d = nn.Conv1d(dim, dim, 3, padding=1, groups=dim)
        self.out_proj = nn.Linear(dim, dim)
        self.act = nn.SiLU()

    def forward(self, x):
        """Mix the tokens of ``x`` (``[B, L, D]``) and return a tensor of the same shape."""
        x_branch, z_branch = self.in_proj(x).chunk(2, dim=-1)

        # Conv1d expects channels-first [B, D, L], so transpose around the conv.
        x_branch = self.act(self.conv1d(x_branch.transpose(1, 2))).transpose(1, 2)

        # Simplified gating in place of the SSM scan (no custom CUDA kernels needed).
        y = x_branch * torch.sigmoid(z_branch)
        return self.out_proj(y)

class MambaVisionSimple(nn.Module):
    """CNN stem, a stack of residual MambaBlocks, and a mean-pooled linear classifier.

    Args:
        num_classes: Number of output logits.
        embed_dim: Channel width of the stem output and of every MambaBlock.
        depth: Number of residual MambaBlocks.

    Shape:
        input ``[B, 3, H, W]`` -> logits ``[B, num_classes]``. The stem has
        three stride-2 stages, so e.g. a 224x224 image becomes 28x28 = 784 tokens.
    """

    def __init__(self, num_classes=5, embed_dim=128, depth=6):
        super().__init__()
        # 1. Stem (CNN for features): conv /2 -> max-pool /2 -> conv /2.
        self.stem = nn.Sequential(
            nn.Conv2d(3, 64, 7, stride=2, padding=3), nn.BatchNorm2d(64), nn.ReLU(),
            nn.MaxPool2d(3, stride=2, padding=1),
            nn.Conv2d(64, embed_dim, 3, stride=2, padding=1), nn.BatchNorm2d(embed_dim), nn.ReLU()
        )

        # 2. Body: token mixers, applied with residual connections in forward().
        self.blocks = nn.ModuleList([MambaBlock(embed_dim) for _ in range(depth)])

        # 3. Head
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """Return class logits for a batch of images ``x``."""
        x = self.stem(x)                  # [B, C, H/8, W/8]
        x = x.flatten(2).transpose(1, 2)  # [B, L, C] with L = (H/8) * (W/8)

        for block in self.blocks:
            x = x + block(x)

        # Global average pool over tokens, then normalise and classify.
        return self.head(self.norm(x.mean(dim=1)))

model = MambaVisionSimple(num_classes=5).to(DEVICE)
optimizer = optim.AdamW(model.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss()
# Mixed precision only applies on CUDA. torch.cuda.amp warns and disables
# itself on CPU anyway, so request it explicitly only when a GPU is used.
use_amp = DEVICE.type == "cuda"
scaler = GradScaler(enabled=use_amp)  # Mixed Precision

# --- 3. TRAINING LOOP (With Graphs) ---
history = {'loss': [], 'acc': [], 'val_loss': [], 'val_acc': []}
# Start below any reachable accuracy so the first epoch always writes a
# checkpoint; the evaluation section reloads "best_model.pth" unconditionally.
best_acc = -1.0

for epoch in range(EPOCHS):
    model.train()
    total_loss, correct, total = 0.0, 0, 0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")
    for img, label in pbar:
        img, label = img.to(DEVICE), label.to(DEVICE)

        optimizer.zero_grad()
        with autocast(enabled=use_amp):  # Faster & Less Memory
            out = model(img)
            loss = criterion(out, label)

        # The scaler protects fp16 gradients from underflow; when disabled
        # these calls reduce to a plain backward() / optimizer.step().
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()
        preds = out.argmax(dim=1)
        correct += (preds == label).sum().item()
        total += label.size(0)
        pbar.set_postfix(loss=loss.item())

    # Epoch averages: loss is the mean of per-batch losses; accuracy is over
    # the sampler-drawn (class-balanced) training samples.
    history['loss'].append(total_loss / len(train_loader))
    history['acc'].append(correct / total)

    # Validation (full precision, no gradients)
    model.eval()
    val_correct, val_total, val_loss = 0, 0, 0.0
    with torch.no_grad():
        for img, label in val_loader:
            img, label = img.to(DEVICE), label.to(DEVICE)
            out = model(img)
            val_loss += criterion(out, label).item()
            val_correct += (out.argmax(1) == label).sum().item()
            val_total += label.size(0)

    val_acc = val_correct / val_total
    history['val_loss'].append(val_loss / len(val_loader))
    history['val_acc'].append(val_acc)

    print(f"Validation Acc: {val_acc:.4f}")
    # Keep only the weights with the highest validation accuracy so far.
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), "best_model.pth")

# --- 4. RESEARCH OUTPUTS ---
# Plot Curves: per-epoch train/val loss and accuracy recorded in `history`.
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.plot(history['loss'], label='Train')
plt.plot(history['val_loss'], label='Val')
plt.title("Loss")
plt.legend()
plt.subplot(1, 2, 2)
plt.plot(history['acc'], label='Train')
plt.plot(history['val_acc'], label='Val')
plt.title("Accuracy")
plt.legend()
plt.savefig("training_graphs.png")
plt.show()

# Confusion Matrix on the validation set using the best checkpoint saved during
# training. map_location lets the checkpoint load on whatever DEVICE is active.
model.load_state_dict(torch.load("best_model.pth", map_location=DEVICE))
model.eval()
all_preds, all_labels = [], []
with torch.no_grad():
    for img, label in val_loader:
        img, label = img.to(DEVICE), label.to(DEVICE)
        all_preds.extend(model(img).argmax(1).cpu().numpy())
        all_labels.extend(label.cpu().numpy())

# Pass the full label set explicitly. Otherwise sklearn infers labels from the
# data, and a class absent from both labels and predictions shrinks the matrix
# (misaligning the tick names) and makes classification_report reject the
# five target_names.
class_names = list(train_ds.class_map.keys())
class_ids = list(train_ds.class_map.values())

cm = confusion_matrix(all_labels, all_preds, labels=class_ids)
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
# sklearn convention: rows are true labels, columns are predictions.
plt.xlabel("Predicted")
plt.ylabel("True")
plt.title("Confusion Matrix")
plt.savefig("confusion_matrix.png")
plt.show()

print("\nFinal Report:\n", classification_report(all_labels, all_preds, labels=class_ids, target_names=class_names))