In [2]:
# ==== imports ====
import os
import json
import numpy as np
import pandas as pd
from pathlib import Path
from PIL import Image

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm.notebook import tqdm
from transformers import AutoModel

# ============================
# 0) CONFIG
# ============================
# Project root holding labels_PICD.csv and the PICD/ image tree. Set the
# CV_PROJECT_DIR env var to run from another location; falls back to the
# original scratch path when unset or empty.
LOCAL_DIR = os.environ.get('CV_PROJECT_DIR') or '/scratch.global/kanth042/CV_project'
csv_path = f'{LOCAL_DIR}/labels_PICD.csv'
img_dir = f'{LOCAL_DIR}/PICD'

# --- safer HF auth (optional) ---
# export HUGGING_FACE_HUB_TOKEN in your env instead of hardcoding
# (None -> no explicit token is passed to from_pretrained)
HF_TOKEN = os.environ.get('HUGGING_FACE_HUB_TOKEN', None)

# ============================
# 1) CATEGORY REDUCTION (24 -> 10)
# ============================
# One row per reduced group: (group name, original CSV category_ids 1..24 that
# collapse into it). A group's id is its position in this table, so GROUPS and
# GROUP_NAMES are both derived from it and cannot fall out of sync.
_GROUP_SPEC = [
    ("RuleOfThirds",      (1, 2)),
    ("Centered",          (3, 4)),
    ("Diagonal",          (5, 6)),
    ("Horizontal",        (7, 8, 9, 10)),
    ("Vertical",          (11, 12, 13, 14)),
    ("Triangle",          (15, 16)),
    ("Curves",            (17, 18, 19)),
    ("RadialPerspective", (20, 21)),
    ("DensePattern",      (22, 23)),
    ("Scatter",           (24,)),
]

# original category_id -> reduced group id (0..9)
GROUPS = {
    cat_id: group_id
    for group_id, (_, cat_ids) in enumerate(_GROUP_SPEC)
    for cat_id in cat_ids
}

# reduced group id -> display name
GROUP_NAMES = [name for name, _ in _GROUP_SPEC]
NUM_GROUPS = len(GROUP_NAMES)

# ============================
# 2) DATA PREP + REMAP
# ============================
print('\n' + '='*70)
print('📊 DATA PREPARATION (with category reduction)')
print('='*70)

df = pd.read_csv(csv_path)
# Rows whose category_id contains a comma carry several labels; cast to str so
# the comma test works whatever dtype pandas inferred for the column.
df['category_id'] = df['category_id'].astype(str)

# keep single-label rows
df_single = df[~df['category_id'].str.contains(',', regex=False)].copy()
df_single['category_id'] = df_single['category_id'].astype(int)

# remap 24->10. The .copy() makes df_single an owned frame, so the column
# assignment below does not raise pandas' SettingWithCopyWarning on a slice.
df_single = df_single[df_single['category_id'].between(1, 24)].copy()
df_single['group_id'] = df_single['category_id'].map(GROUPS).astype(int)

# verify images exist. Zipping the two path columns avoids iterrows(), which
# materialises a Series per row and is much slower on ~50k rows.
print('\n🔍 Verifying images exist on disk...')
img_dir_path = Path(img_dir)
exists_mask = [
    (img_dir_path / folder / img_id).exists()
    for folder, img_id in tqdm(
        zip(df_single['folder_name'], df_single['img_id']),
        total=len(df_single),
        desc='Checking files',
    )
]
df_existing = df_single[exists_mask].copy()

print(f'✓ Images on disk: {len(df_existing):,}')
print(f'✗ Missing images: {len(df_single)-len(df_existing):,}')

# show group distribution
print('\n📊 Group Distribution (after reduction):')
grp_counts = df_existing['group_id'].value_counts().sort_index()
for gid, cnt in grp_counts.items():
    print(f'  {gid:02d} {GROUP_NAMES[gid]:>18}: {cnt:5d}')

# ============================
# 3) TRAIN/VAL SPLIT (stratify by group)
# ============================
print('\n' + '='*70)
print('✂️  TRAIN/VAL SPLIT')
print('='*70)

# 80/20 split stratified on the reduced label, so each group keeps the same
# proportion in both halves; the fixed seed makes the split reproducible.
train_df, val_df = train_test_split(
    df_existing,
    test_size=0.2,
    random_state=42,
    stratify=df_existing['group_id']
)
# Write both halves next to the data so the exact split is kept on disk.
train_df.to_csv(f'{LOCAL_DIR}/train_split_reduced.csv', index=False)
val_df.to_csv(f'{LOCAL_DIR}/val_split_reduced.csv', index=False)
print(f'Train: {len(train_df):,}  Val: {len(val_df):,}')

# ============================
# 4) DATASET + TRANSFORMS
# ============================
# Per-channel ImageNet statistics used for input normalisation.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Deterministic preprocessing (no augmentation): resize to exactly 224x224
# (aspect ratio is not preserved), convert to a float tensor, then normalise.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)

class PICDDatasetReduced(Dataset):
    """Map-style dataset yielding ``(image, group_id)`` pairs for PICD.

    Each row of ``df`` must provide ``folder_name`` and ``img_id`` (the image
    is read from ``img_root/folder_name/img_id``) and an integer ``group_id``.

    Args:
        df: Split dataframe; its index is reset so ``idx`` is positional.
        img_root: Directory containing the per-folder image directories.
        transform: Optional callable applied to the RGB PIL image.
    """

    def __init__(self, df, img_root, transform=None):
        self.df = df.reset_index(drop=True)
        self.root = Path(img_root)
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Load row ``idx`` as an RGB image (transformed if configured) plus its int label.

        Raises:
            FileNotFoundError: if the image file does not exist on disk.
        """
        row = self.df.iloc[idx]
        img_path = self.root / row['folder_name'] / row['img_id']
        if not img_path.exists():
            raise FileNotFoundError(f"Missing image: {img_path}")
        # Image.open is lazy and may keep the file open (e.g. multi-frame files)
        # until the object is garbage-collected; convert() returns a new image,
        # so close the source deterministically inside each DataLoader worker.
        with Image.open(img_path) as src:
            img = src.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        label = int(row['group_id'])  # reduced group id (0..9)
        return img, label

# Both splits share the same deterministic preprocessing pipeline.
train_dataset = PICDDatasetReduced(train_df, img_dir, transform)
val_dataset = PICDDatasetReduced(val_df, img_dir, transform)

batch_size = 256
# Loader settings common to both splits; only shuffling differs.
_loader_kwargs = dict(
    batch_size=batch_size,
    num_workers=8,
    pin_memory=True,
    prefetch_factor=2,
)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)

print('\nDataloaders ready.')
print(f'Batch size: {batch_size} | Train batches: {len(train_loader)} | Val batches: {len(val_loader)}')

# ============================
# 5) MODEL (frozen DINOv3 + 10-way head)
# ============================
class DINOv3Classifier(nn.Module):
    """Linear classification head on a frozen Hugging Face DINOv3 backbone.

    The backbone is loaded with ``AutoModel.from_pretrained`` and all of its
    parameters have ``requires_grad`` disabled. It is kept in eval mode even
    when ``.train()`` is called on this module. Only ``self.classifier`` is
    trainable.

    Args:
        num_classes: Number of output logits (defaults to the reduced group count).
        model_name: Hugging Face model id of the DINOv3 checkpoint.
        token: Optional Hugging Face access token, forwarded to
            ``from_pretrained`` only when truthy.
    """

    def __init__(self, num_classes=NUM_GROUPS, model_name='facebook/dinov3-vitb16-pretrain-lvd1689m', token=None):
        super().__init__()
        print(f'Loading DINOv3: {model_name}')
        extra = {'token': token} if token else {}
        self.dinov3 = AutoModel.from_pretrained(model_name, **extra)
        for p in self.dinov3.parameters():
            p.requires_grad = False
        self.dinov3.eval()
        hidden = self.dinov3.config.hidden_size
        self.classifier = nn.Linear(hidden, num_classes)
        print(f'Backbone dim: {hidden}  -> classes: {num_classes}')

    def train(self, mode=True):
        """Set the head's training mode while pinning the frozen backbone to eval.

        nn.Module.train() recurses into every submodule. Without this override,
        ``model.train()`` would put the backbone into training mode as well,
        affecting any dropout / drop-path layers it contains, even though its
        weights are frozen. ``model.eval()`` routes through here too.
        """
        super().train(mode)
        self.dinov3.eval()
        return self

    def forward(self, x):
        """Return class logits computed from the backbone's ``pooler_output`` features."""
        # The backbone is frozen, so skip autograd bookkeeping for it entirely.
        with torch.no_grad():
            feats = self.dinov3(x).pooler_output
        return self.classifier(feats)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'\n🔧 Device: {device}')
model = DINOv3Classifier(token=HF_TOKEN).to(device)

# Split the parameter count by requires_grad rather than by submodule, so the
# report reflects which parameters actually receive gradients.
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
frozen_params = sum(p.numel() for p in model.parameters() if not p.requires_grad)
print(f'Frozen: {frozen_params:,} | Trainable: {trainable_params:,}')

# ============================
# 6) TRAINING
# ============================
# Only the linear head is optimised. Cosine-anneal the LR from 1e-3 towards
# 1e-6 over max_epochs, stepping once per epoch after validation.
optimizer = torch.optim.Adam(model.classifier.parameters(), lr=1e-3, weight_decay=1e-4)
max_epochs = 20
scheduler = CosineAnnealingLR(optimizer, T_max=max_epochs, eta_min=1e-6)
criterion = nn.CrossEntropyLoss()

# Early stopping: quit after `patience` consecutive epochs without a new best val accuracy.
best_val_acc, patience, patience_ctr = 0.0, 5, 0
history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': [], 'lr': []}

print('\n' + '='*70); print('🚀 TRAINING START'); print('='*70)
for epoch in range(max_epochs):
    # --- train ---
    model.train()
    model.dinov3.eval()  # keep frozen
    tr_loss = 0.0; tr_correct = 0; tr_total = 0
    pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{max_epochs} [Train]', leave=False)
    for images, labels in pbar:
        images, labels = images.to(device, non_blocking=True), labels.to(device, non_blocking=True)
        optimizer.zero_grad()
        logits = model(images)
        loss = criterion(logits, labels)
        loss.backward(); optimizer.step()
        batch_n = labels.size(0)
        # criterion returns the batch *mean*. Weight it by batch size so the
        # epoch loss is a true per-sample mean even when the last batch is short.
        tr_loss += loss.item() * batch_n
        preds = logits.argmax(1)
        tr_correct += (preds == labels).sum().item()
        tr_total += batch_n
        pbar.set_postfix({'loss': f'{loss.item():.4f}', 'acc': f'{100*tr_correct/tr_total:.2f}%'})
    tr_acc = 100.0 * tr_correct / tr_total
    tr_loss = tr_loss / tr_total

    # --- val ---
    model.eval()
    va_loss = 0.0; va_correct = 0; va_total = 0
    with torch.no_grad():
        pbar = tqdm(val_loader, desc=f'Epoch {epoch+1}/{max_epochs} [Val]', leave=False)
        for images, labels in pbar:
            images, labels = images.to(device, non_blocking=True), labels.to(device, non_blocking=True)
            logits = model(images)
            loss = criterion(logits, labels)
            batch_n = labels.size(0)
            va_loss += loss.item() * batch_n  # sample-weighted, as in training
            preds = logits.argmax(1)
            va_correct += (preds == labels).sum().item()
            va_total += batch_n
            pbar.set_postfix({'loss': f'{loss.item():.4f}', 'acc': f'{100*va_correct/va_total:.2f}%'})
    va_acc = 100.0 * va_correct / va_total
    va_loss = va_loss / va_total
    lr = optimizer.param_groups[0]['lr']  # LR used this epoch (scheduler steps at loop end)

    history['train_loss'].append(tr_loss); history['train_acc'].append(tr_acc)
    history['val_loss'].append(va_loss);   history['val_acc'].append(va_acc)
    history['lr'].append(lr)

    print(f'\nEpoch {epoch+1}/{max_epochs} | Train {tr_loss:.4f}/{tr_acc:.2f}% | Val {va_loss:.4f}/{va_acc:.2f}% | lr={lr:.6f}')

    if va_acc > best_val_acc:
        best_val_acc = va_acc; patience_ctr = 0
        # Only the head's weights are saved; the backbone is reloaded from the hub.
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.classifier.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'train_acc': tr_acc, 'val_acc': va_acc,
            'train_loss': tr_loss, 'val_loss': va_loss,
            'history': history,
            'group_names': GROUP_NAMES,
        }, f'{LOCAL_DIR}/best_dinov3_classifier_reduced.pth')
        print(f'  ✓ Best model saved (Val Acc {va_acc:.2f}%)')
    else:
        patience_ctr += 1
        print(f'  Patience {patience_ctr}/{patience}')
        if patience_ctr >= patience:
            print('⚠️  Early stopping triggered'); break

    scheduler.step()

# ============================
# 7) SAVE HISTORY + QUICK STATS
# ============================
# Per-epoch losses, accuracies and learning rates collected during training.
with open(f'{LOCAL_DIR}/training_history_reduced.json', 'w', encoding='utf-8') as f:
    json.dump(history, f, indent=2)
print(f'\n🎉 Done. Best Val Acc ({NUM_GROUPS} classes):', f'{best_val_acc:.2f}%')



📊 DATA PREPARATION (with category reduction)

🔍 Verifying images exist on disk...


Checking files:   0%|          | 0/48594 [00:00<?, ?it/s]

✓ Images on disk: 42,945
✗ Missing images: 5,649

📊 Group Distribution (after reduction):
  00       RuleOfThirds:  3305
  01           Centered:  4801
  02           Diagonal:  3694
  03         Horizontal:  8023
  04           Vertical:  5078
  05           Triangle:  3534
  06             Curves:  4720
  07  RadialPerspective:  3973
  08       DensePattern:  3949
  09            Scatter:  1868

✂️  TRAIN/VAL SPLIT
Train: 34,356  Val: 8,589

Dataloaders ready.
Batch size: 256 | Train batches: 135 | Val batches: 34

🔧 Device: cuda
Loading DINOv3: facebook/dinov3-vitb16-pretrain-lvd1689m
Backbone dim: 768  -> classes: 10
Frozen: 85,660,416 | Trainable: 7,690

🚀 TRAINING START


Epoch 1/20 [Train]:   0%|          | 0/135 [00:00<?, ?it/s]



Epoch 1/20 [Val]:   0%|          | 0/34 [00:00<?, ?it/s]


Epoch 1/20 | Train 1.3238/61.43% | Val 0.9288/72.28% | lr=0.001000
  ✓ Best model saved (Val Acc 72.28%)


Epoch 2/20 [Train]:   0%|          | 0/135 [00:00<?, ?it/s]



Epoch 2/20 [Val]:   0%|          | 0/34 [00:00<?, ?it/s]


Epoch 2/20 | Train 0.8292/74.94% | Val 0.7767/75.53% | lr=0.000994
  ✓ Best model saved (Val Acc 75.53%)


Epoch 3/20 [Train]:   0%|          | 0/135 [00:00<?, ?it/s]



Epoch 3/20 [Val]:   0%|          | 0/34 [00:00<?, ?it/s]


Epoch 3/20 | Train 0.7209/77.46% | Val 0.7202/76.54% | lr=0.000976
  ✓ Best model saved (Val Acc 76.54%)


Epoch 4/20 [Train]:   0%|          | 0/135 [00:00<?, ?it/s]



Epoch 4/20 [Val]:   0%|          | 0/34 [00:00<?, ?it/s]


Epoch 4/20 | Train 0.6690/78.61% | Val 0.6923/77.05% | lr=0.000946
  ✓ Best model saved (Val Acc 77.05%)


Epoch 5/20 [Train]:   0%|          | 0/135 [00:00<?, ?it/s]



Epoch 5/20 [Val]:   0%|          | 0/34 [00:00<?, ?it/s]


Epoch 5/20 | Train 0.6375/79.47% | Val 0.6758/77.35% | lr=0.000905
  ✓ Best model saved (Val Acc 77.35%)


Epoch 6/20 [Train]:   0%|          | 0/135 [00:00<?, ?it/s]



Epoch 6/20 [Val]:   0%|          | 0/34 [00:00<?, ?it/s]


Epoch 6/20 | Train 0.6154/79.87% | Val 0.6660/77.67% | lr=0.000854
  ✓ Best model saved (Val Acc 77.67%)


Epoch 7/20 [Train]:   0%|          | 0/135 [00:00<?, ?it/s]



Epoch 7/20 [Val]:   0%|          | 0/34 [00:00<?, ?it/s]


Epoch 7/20 | Train 0.5995/80.30% | Val 0.6587/77.94% | lr=0.000794
  ✓ Best model saved (Val Acc 77.94%)


Epoch 8/20 [Train]:   0%|          | 0/135 [00:00<?, ?it/s]



Epoch 8/20 [Val]:   0%|          | 0/34 [00:00<?, ?it/s]


Epoch 8/20 | Train 0.5874/80.68% | Val 0.6552/77.82% | lr=0.000727
  Patience 1/5


Epoch 9/20 [Train]:   0%|          | 0/135 [00:00<?, ?it/s]



Epoch 9/20 [Val]:   0%|          | 0/34 [00:00<?, ?it/s]


Epoch 9/20 | Train 0.5780/80.97% | Val 0.6512/77.98% | lr=0.000655
  ✓ Best model saved (Val Acc 77.98%)


Epoch 10/20 [Train]:   0%|          | 0/135 [00:00<?, ?it/s]



Epoch 10/20 [Val]:   0%|          | 0/34 [00:00<?, ?it/s]


Epoch 10/20 | Train 0.5699/81.26% | Val 0.6495/78.00% | lr=0.000579
  ✓ Best model saved (Val Acc 78.00%)


Epoch 11/20 [Train]:   0%|          | 0/135 [00:00<?, ?it/s]



Epoch 11/20 [Val]:   0%|          | 0/34 [00:00<?, ?it/s]


Epoch 11/20 | Train 0.5635/81.39% | Val 0.6477/78.02% | lr=0.000501
  ✓ Best model saved (Val Acc 78.02%)


Epoch 12/20 [Train]:   0%|          | 0/135 [00:00<?, ?it/s]



Epoch 12/20 [Val]:   0%|          | 0/34 [00:00<?, ?it/s]


Epoch 12/20 | Train 0.5573/81.54% | Val 0.6465/78.03% | lr=0.000422
  ✓ Best model saved (Val Acc 78.03%)


Epoch 13/20 [Train]:   0%|          | 0/135 [00:00<?, ?it/s]



Epoch 13/20 [Val]:   0%|          | 0/34 [00:00<?, ?it/s]


Epoch 13/20 | Train 0.5548/81.65% | Val 0.6457/78.02% | lr=0.000346
  Patience 1/5


Epoch 14/20 [Train]:   0%|          | 0/135 [00:00<?, ?it/s]



Epoch 14/20 [Val]:   0%|          | 0/34 [00:00<?, ?it/s]


Epoch 14/20 | Train 0.5487/81.79% | Val 0.6452/78.01% | lr=0.000274
  Patience 2/5


Epoch 15/20 [Train]:   0%|          | 0/135 [00:00<?, ?it/s]



Epoch 15/20 [Val]:   0%|          | 0/34 [00:00<?, ?it/s]


Epoch 15/20 | Train 0.5457/81.90% | Val 0.6447/78.11% | lr=0.000207
  ✓ Best model saved (Val Acc 78.11%)


Epoch 16/20 [Train]:   0%|          | 0/135 [00:00<?, ?it/s]



Epoch 16/20 [Val]:   0%|          | 0/34 [00:00<?, ?it/s]


Epoch 16/20 | Train 0.5441/81.98% | Val 0.6444/78.08% | lr=0.000147
  Patience 1/5


Epoch 17/20 [Train]:   0%|          | 0/135 [00:00<?, ?it/s]



Epoch 17/20 [Val]:   0%|          | 0/34 [00:00<?, ?it/s]


Epoch 17/20 | Train 0.5432/82.02% | Val 0.6443/78.10% | lr=0.000096
  Patience 2/5


Epoch 18/20 [Train]:   0%|          | 0/135 [00:00<?, ?it/s]



Epoch 18/20 [Val]:   0%|          | 0/34 [00:00<?, ?it/s]


Epoch 18/20 | Train 0.5421/82.04% | Val 0.6441/78.06% | lr=0.000055
  Patience 3/5


Epoch 19/20 [Train]:   0%|          | 0/135 [00:00<?, ?it/s]



Epoch 19/20 [Val]:   0%|          | 0/34 [00:00<?, ?it/s]


Epoch 19/20 | Train 0.5422/82.08% | Val 0.6441/78.08% | lr=0.000025
  Patience 4/5


Epoch 20/20 [Train]:   0%|          | 0/135 [00:00<?, ?it/s]



Epoch 20/20 [Val]:   0%|          | 0/34 [00:00<?, ?it/s]


Epoch 20/20 | Train 0.5406/82.11% | Val 0.6441/78.08% | lr=0.000007
  Patience 5/5
⚠️  Early stopping triggered

🎉 Done. Best Val Acc (10 classes): 78.11%
