# 2-Stage Robust Training: Field Variation Hardening

This notebook implements a **2-Pass Training Workflow** designed to improve robustness against field variation (lighting, motion blur, orientation) which is currently causing a gap between validation accuracy and production performance.

## The Strategy

1.  **Robust Architectures**: Uses **ConvNeXt V2** (Masked Autoencoder pre-training) or **DINOv2** (Self-supervised) for superior feature robustness.
2.  **Field Simulator Augmentation**: Gaussian blur (a rough stand-in for motion blur), color jitter, and affine transforms to simulate production conditions.
3.  **Hard Example Mining (2-Pass)**:
    *   **Stage 1**: Train a standard Binary Classifier (Clean vs Pit).
    *   **Mining**: Run the Stage 1 model over the training and validation sets to find "ambiguous" (low-confidence) or misclassified examples.
    *   **Stage 2**: Relabel these as `maybe` and fine-tune a 3-class model (Clean/Maybe/Pit).

## Prerequisites
- Google Colab Pro (GPU required)
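- Data at `/content/cherry_classification/data/{train,val}` (each with `clean/` and `pit/` subfolders); if it is missing, the notebook clones it from GitHub. Google Drive is used only to save results.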

## 1. Environment Setup

In [None]:
# Mount Google Drive. Checkpoints are written under DRIVE_ROOT, which lives
# inside this mount, so results survive the Colab session.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

# Install dependencies.
# NOTE: `!` is IPython shell syntax; this cell only runs inside Jupyter/Colab.
# NOTE(review): albumentations is installed here but not imported by the
# visible cells (augmentations use torchvision) — confirm it is still needed.
!pip install -q timm albumentations matplotlib scikit-learn

import os
import shutil
import random
import numpy as np
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, WeightedRandomSampler
import timm
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
from PIL import Image
from tqdm.notebook import tqdm
from pathlib import Path

# Global seed used for every random number generator in the notebook.
SEED = 42


def set_seed(seed):
    """Seed Python's, NumPy's and PyTorch's random number generators.

    The CUDA generators on every visible GPU are seeded as well when CUDA is
    available.

    Args:
        seed: Integer seed applied to each generator.
    """
    seeders = [random.seed, np.random.seed, torch.manual_seed]
    if torch.cuda.is_available():
        seeders.append(torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)

# Seed all generators before any model or data loader is created.
set_seed(SEED)

# Use the GPU when Colab has attached one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

## 2. Configuration

In [None]:
# === USER CONFIGURATION ===

# Model Selection
# Options: "convnextv2_tiny", "resnet50", "dinov2_vits14"
MODEL_ARCH = "convnextv2_tiny"

# Training Params (shared by Stage 1 and Stage 2)
BATCH_SIZE = 32
EPOCHS_STAGE1 = 15
EPOCHS_STAGE2 = 15
LR = 1e-4  # AdamW learning rate
WD = 1e-4  # AdamW weight decay

# Mining Thresholds
# An image is relabelled 'maybe' when Stage 1 misclassifies it OR its pit
# probability p satisfies MAYBE_MIN < p < MAYBE_MAX (both bounds exclusive).
MAYBE_MIN = 0.35
MAYBE_MAX = 0.65

# Robustness Params
# "standard" = random flips only; "heavy" = flips + affine, color jitter, blur.
AUGMENT_LEVEL = "heavy"

# Paths
DRIVE_ROOT = Path("/content/drive/MyDrive/cherry_experiments/2stage_robust")  # checkpoints (on Drive)
DATA_SOURCE = Path("/content/cherry_classification/data")  # cloned dataset with train/ and val/
DATA_STAGE1 = Path("/content/data_stage1")  # working copy used for Stage 1
DATA_STAGE2 = Path("/content/data_stage2")  # mined clean/maybe/pit splits for Stage 2

# Fail fast on typos. Otherwise an unknown AUGMENT_LEVEL silently falls back to
# flips-only augmentation, and an unknown MODEL_ARCH only errors after the
# dataset has already been copied.
_VALID_ARCHS = ("convnextv2_tiny", "resnet50", "dinov2_vits14")
if MODEL_ARCH not in _VALID_ARCHS:
    raise ValueError(f"MODEL_ARCH must be one of {_VALID_ARCHS}, got {MODEL_ARCH!r}")
if AUGMENT_LEVEL not in ("standard", "heavy"):
    raise ValueError(f'AUGMENT_LEVEL must be "standard" or "heavy", got {AUGMENT_LEVEL!r}')
if not 0.0 <= MAYBE_MIN < MAYBE_MAX <= 1.0:
    raise ValueError(
        f"Mining thresholds need 0 <= MAYBE_MIN < MAYBE_MAX <= 1, got {MAYBE_MIN}, {MAYBE_MAX}"
    )

# Setup output directory
DRIVE_ROOT.mkdir(parents=True, exist_ok=True)
print(f"Results will be saved to: {DRIVE_ROOT}")

## 3. Data Preparation & Augmentation (Field Simulator)

In [None]:
# Clone Data if needed: the clone lands in /content/cherry_classification,
# whose data/ subfolder is DATA_SOURCE. Skipped when DATA_SOURCE already exists.
if not DATA_SOURCE.exists():
    print("Cloning dataset...")
    !git clone --depth 1 https://github.com/weshavener/cherry_classification.git /content/cherry_classification

# Prepare Stage 1 Data: make a fresh working copy of train/ and val/ under
# DATA_STAGE1 so later steps never modify DATA_SOURCE. Any previous copy is
# removed first because shutil.copytree refuses to copy into an existing
# destination directory.
if DATA_STAGE1.exists(): shutil.rmtree(DATA_STAGE1)
shutil.copytree(DATA_SOURCE / "train", DATA_STAGE1 / "train")
shutil.copytree(DATA_SOURCE / "val", DATA_STAGE1 / "val")
print("Stage 1 Data Ready.")

# --- Augmentations ---
def get_transforms(mode="train", level="standard", img_size=128):
    """Build the torchvision preprocessing pipeline for one data split.

    Args:
        mode: "val" for a deterministic resize + normalize (also used when
            mining hard examples); "train" for randomized augmentation.
        level: "standard" applies only random horizontal/vertical flips;
            "heavy" additionally applies a random affine transform, color
            jitter and Gaussian blur. Ignored when ``mode == "val"``.
        img_size: Side length of the square resize applied to every image.

    Returns:
        A ``transforms.Compose`` mapping a PIL image to a normalized tensor.

    Raises:
        ValueError: If ``mode`` or ``level`` is not one of the values above.
    """
    # Reject typos instead of silently training with the wrong pipeline.
    if mode not in ("train", "val"):
        raise ValueError(f'mode must be "train" or "val", got {mode!r}')
    if level not in ("standard", "heavy"):
        raise ValueError(f'level must be "standard" or "heavy", got {level!r}')

    # ImageNet mean/std, matching the ImageNet-pretrained backbones. Any
    # inference/production pipeline must apply the same resize and
    # normalization, or its inputs will not match what the model was trained on.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    resize = transforms.Resize((img_size, img_size))

    if mode == "val":
        return transforms.Compose([resize, transforms.ToTensor(), normalize])

    # Train Augmentations
    aug_list = [
        resize,
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
    ]

    if level == "heavy":
        aug_list.extend([
            transforms.RandomAffine(degrees=15, translate=(0.1, 0.1), scale=(0.9, 1.1)),
            # Brightness/contrast/color variation, e.g. from strobe lighting.
            transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2, hue=0.05),
            # Isotropic Gaussian blur applied to every image with a random
            # sigma. Only a rough stand-in for motion blur, which is directional.
            transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
        ])

    aug_list.append(transforms.ToTensor())
    aug_list.append(normalize)

    return transforms.Compose(aug_list)

# --- Data Loaders ---
def create_loader(data_dir, batch_size, mode="train", weighted=True):
    """Wrap an ImageFolder directory in a DataLoader.

    When ``weighted`` is true and ``mode == "train"``, samples are drawn with
    replacement using inverse class-frequency weights, so each class is seen
    roughly equally often. In that case the sampler replaces shuffling.
    Without a sampler, training loaders shuffle and validation loaders keep
    file order.

    Args:
        data_dir: ImageFolder root containing one subfolder per class.
        batch_size: Images per batch.
        mode: "train" or "val"; selects the transforms and shuffling.
        weighted: Enable class-balanced sampling for training loaders.

    Returns:
        A ``DataLoader`` over the folder, using 2 worker processes.
    """
    preprocess = get_transforms(mode, level=AUGMENT_LEVEL)
    dataset = datasets.ImageFolder(data_dir, transform=preprocess)

    use_sampler = weighted and mode == "train"
    sampler = None
    if use_sampler:
        labels = dataset.targets
        inverse_freq = 1. / np.bincount(labels)
        sampler = WeightedRandomSampler(inverse_freq[labels], len(labels))

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=(mode == "train") and not use_sampler,
        sampler=sampler,
        num_workers=2,
    )

# Stage 1 loaders: class-balanced sampling for training (create_loader's
# default), plain in-order batches for validation.
train_loader_s1 = create_loader(DATA_STAGE1 / "train", BATCH_SIZE, mode="train")
val_loader_s1 = create_loader(DATA_STAGE1 / "val", BATCH_SIZE, mode="val", weighted=False)

stage1_classes = train_loader_s1.dataset.classes
print(f"Classes Stage 1: {stage1_classes}")

## 4. Model Definition

In [None]:
def create_model(arch, num_classes, img_size=128):
    """Create a pretrained timm backbone with a freshly initialized head.

    Args:
        arch: "convnextv2_tiny", "resnet50" or "dinov2_vits14".
        num_classes: Output size of the new classification head.
        img_size: Training resolution (square). Only used for the DINOv2 ViT,
            whose timm configuration defaults to 518x518 input and checks the
            input size at run time. Keep this equal to the resize used by the
            data transforms.

    Returns:
        The model, moved to the global ``device``.

    Raises:
        ValueError: If ``arch`` is not one of the supported names.
    """
    print(f"Creating {arch} for {num_classes} classes...")

    timm_names = {
        "convnextv2_tiny": "convnextv2_tiny.fcmae_ft_in1k",
        "resnet50": "resnet50",
        "dinov2_vits14": "vit_small_patch14_dinov2.lvd142m",
    }
    if arch not in timm_names:
        raise ValueError(f"Unknown architecture: {arch}")

    extra_kwargs = {}
    if arch == "dinov2_vits14":
        # Build the ViT at the training resolution; without this, timm expects
        # 518x518 input and the 128x128 batches fail. The pretrained position
        # embeddings are resampled to the smaller patch grid on load. With the
        # default 128, which is not a multiple of the patch size 14, the patch
        # embedding covers a 9x9 grid and drops the last 2 pixel rows/columns.
        extra_kwargs["img_size"] = img_size

    model = timm.create_model(
        timm_names[arch], pretrained=True, num_classes=num_classes, **extra_kwargs
    )
    return model.to(device)

## 5. Training Logic

In [None]:
def train_model(model, train_loader, val_loader, epochs, save_path):
    """Train ``model`` with AdamW + cosine LR decay, keeping the best checkpoint.

    Each epoch runs one pass over ``train_loader``, steps the scheduler, then
    measures top-1 accuracy on ``val_loader``. Whenever accuracy improves, the
    state dict is written to ``save_path``. After the last epoch, the best
    checkpoint is loaded back into ``model``.

    Args:
        model: Network already on the global ``device``.
        train_loader: Yields ``(images, labels)`` training batches.
        val_loader: Yields ``(images, labels)`` validation batches.
        epochs: Number of epochs; also the cosine schedule's ``T_max``.
        save_path: File the best ``state_dict`` is written to (overwritten).

    Returns:
        The same ``model`` object with its best-validation weights loaded
        (unchanged if ``epochs`` is 0).
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WD)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    # Start below any reachable accuracy so the first epoch always writes a
    # checkpoint. With 0.0, a run that never beats 0% would never save, and
    # the final load would crash or pick up a stale file left at save_path
    # by an earlier run.
    best_acc = -1.0

    for epoch in range(epochs):
        # Train
        model.train()
        train_loss = 0.0
        for imgs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", leave=False):
            imgs, labels = imgs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        scheduler.step()

        # Validate
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for imgs, labels in val_loader:
                imgs, labels = imgs.to(device), labels.to(device)
                predicted = model(imgs).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        # Guard against empty loaders instead of raising ZeroDivisionError.
        acc = correct / total if total else 0.0
        avg_loss = train_loss / max(len(train_loader), 1)
        print(f"Epoch {epoch+1}: Train Loss = {avg_loss:.4f}, Val Acc = {acc:.2%}")

        if acc > best_acc:
            best_acc = acc
            torch.save(model.state_dict(), save_path)
            print(f"--> New Best Model Saved: {save_path}")

    if best_acc < 0:
        # epochs == 0: nothing was trained or saved in this call.
        print("No epochs were run; returning the model unchanged.")
        return model

    print(f"Training Complete. Best Accuracy: {best_acc:.2%}")

    # Load best weights (map_location keeps tensors on the training device).
    model.load_state_dict(torch.load(save_path, map_location=device))
    return model

## 6. Stage 1: Binary Training

In [None]:
print("=== STARTING STAGE 1 (BINARY) ===")

# Two-way head. ImageFolder sorts class folders alphabetically, so for a
# dataset with exactly clean/ and pit/ folders index 0 is clean and 1 is pit.
model_s1 = create_model(MODEL_ARCH, num_classes=2)
save_path_s1 = DRIVE_ROOT / "stage1_best.pt"

model_s1 = train_model(model_s1, train_loader_s1, val_loader_s1,
                       epochs=EPOCHS_STAGE1, save_path=save_path_s1)

## 7. Hard Example Mining (The "Maybe" Creator)

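We now run the Stage 1 model over **both** the training and validation sets.
Images that are misclassified OR have low confidence (pit probability strictly between `MAYBE_MIN` and `MAYBE_MAX`, by default $0.35 < p_{\text{pit}} < 0.65$) are copied into the `maybe` class; all other images are copied into their original class.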

In [None]:
print("=== MINING HARD EXAMPLES ===")

def mine_data(source_dir, dest_dir, mode="train", batch_size=64):
    """Copy a Stage 1 split into a Stage 2 split with an extra 'maybe' class.

    Runs the global Stage 1 model ``model_s1`` over every image in
    ``source_dir`` using the deterministic "val" transforms. An image goes to
    ``dest_dir/maybe`` when the model misclassifies it or its pit probability
    lies strictly between MAYBE_MIN and MAYBE_MAX. Otherwise it goes to
    ``dest_dir/<original class>``. Files are copied, never moved, so
    ``source_dir`` is left untouched.

    Files copied into 'maybe' are prefixed with their original class folder
    (e.g. ``clean__img001.jpg``). Images from different classes may share a
    file name and would otherwise overwrite each other there.

    If no image qualifies, the empty 'maybe' folder is removed because
    ImageFolder cannot load a class directory with no images.

    Args:
        source_dir: ImageFolder root (one subfolder per class) to mine.
        dest_dir: Output root; class subfolders are created as needed.
        mode: Split name, used only in progress/log messages.
        batch_size: Images per inference batch.
    """
    print(f"Mining {mode} set: {source_dir} -> {dest_dir}")

    # Create inference dataset (No Augmentation)
    dataset_mine = datasets.ImageFolder(
        source_dir,
        transform=get_transforms("val")
    )
    # shuffle=False keeps batches in dataset_mine.samples order, so each
    # prediction can be matched back to its file by running position.
    loader_mine = DataLoader(dataset_mine, batch_size=batch_size, shuffle=False, num_workers=2)

    # Look up the pit column by name instead of assuming it is index 1.
    pit_idx = dataset_mine.class_to_idx.get("pit", 1)

    # Prepare destination folders for 'maybe' and every original class
    maybe_dest = dest_dir / "maybe"
    maybe_dest.mkdir(parents=True, exist_ok=True)
    for cls in dataset_mine.classes:
        (dest_dir / cls).mkdir(parents=True, exist_ok=True)

    model_s1.eval()
    moves = 0
    total = 0

    with torch.no_grad():
        for imgs, _ in tqdm(loader_mine, desc=f"Mining {mode}"):
            probs = torch.softmax(model_s1(imgs.to(device)), dim=1).cpu()
            preds = probs.argmax(dim=1).tolist()
            pit_probs = probs[:, pit_idx].tolist()

            for pred, pit_prob in zip(preds, pit_probs):
                src_path, label = dataset_mine.samples[total]
                src_cls = dataset_mine.classes[label]
                filename = Path(src_path).name

                is_wrong = pred != label
                is_ambiguous = MAYBE_MIN < pit_prob < MAYBE_MAX

                if is_wrong or is_ambiguous:
                    dest_path = maybe_dest / f"{src_cls}__{filename}"
                    moves += 1
                else:
                    dest_path = dest_dir / src_cls / filename

                # Copy file (safer than move, preserves original just in case)
                shutil.copy2(src_path, dest_path)
                total += 1

    print(f"  -> Moved {moves}/{total} ({moves/max(total, 1):.1%}) images to 'maybe' class.")

    # Safety check: if 'maybe' is empty, delete it to prevent ImageFolder crash
    if not any(maybe_dest.iterdir()):
        print(f"  WARNING: No 'maybe' samples found in {mode} set! Removing empty folder.")
        maybe_dest.rmdir()

# 1. Setup Stage 2 Directory Structure (Clean Slate)
if DATA_STAGE2.exists(): shutil.rmtree(DATA_STAGE2)
DATA_STAGE2.mkdir(parents=True)

# 2. Run Mining on TRAIN and VAL
mine_data(DATA_STAGE1 / "train", DATA_STAGE2 / "train", mode="train")
mine_data(DATA_STAGE1 / "val", DATA_STAGE2 / "val", mode="val")

print("Mining Complete.")

## 8. Stage 2: 3-Class Fine-Tuning

In [None]:
print("=== STARTING STAGE 2 (3-CLASS) ===")

# Create Loaders for Stage 2.
# mine_data removes an empty 'maybe' folder, so train and val can end up with
# different class lists. ImageFolder indexes each split alphabetically, so
# without the re-indexing below the same label index could mean different
# classes in train and val.
train_loader_s2 = create_loader(DATA_STAGE2 / "train", BATCH_SIZE, "train")
val_loader_s2 = create_loader(DATA_STAGE2 / "val", BATCH_SIZE, "val", weighted=False)

print(f"Classes Stage 2 (Train): {train_loader_s2.dataset.classes}")
print(f"Classes Stage 2 (Val):   {val_loader_s2.dataset.classes}")

# Put validation labels in the training label space (the head's output
# indices). A validation class with no training examples cannot be
# predicted at all, so that case is a hard error.
_train_class_to_idx = train_loader_s2.dataset.class_to_idx
_val_ds = val_loader_s2.dataset
if _val_ds.class_to_idx != _train_class_to_idx:
    _unknown = sorted(set(_val_ds.classes) - set(_train_class_to_idx))
    if _unknown:
        raise RuntimeError(
            f"Validation classes {_unknown} have no training examples, so the "
            "Stage 2 model cannot predict them. Adjust MAYBE_MIN/MAYBE_MAX and re-mine."
        )
    _remap = {old: _train_class_to_idx[name] for name, old in _val_ds.class_to_idx.items()}
    _val_ds.samples = [(path, _remap[target]) for path, target in _val_ds.samples]
    _val_ds.imgs = _val_ds.samples
    _val_ds.targets = [target for _, target in _val_ds.samples]
    _val_ds.classes = list(train_loader_s2.dataset.classes)
    _val_ds.class_to_idx = dict(_train_class_to_idx)
    print(f"Re-indexed validation labels to training classes: {_val_ds.classes}")

# Modify Model Head
num_classes_s2 = len(train_loader_s2.dataset.classes)
print(f"Adapting model head for {num_classes_s2} classes...")

# Replace the Stage 1 head with a freshly initialized one; the backbone keeps
# its Stage 1 weights. model_s2 is the same object as model_s1, modified in
# place. .to(device) moves the new head onto the training device.
model_s1.reset_classifier(num_classes=num_classes_s2)
model_s2 = model_s1.to(device)

save_path_s2 = DRIVE_ROOT / "stage2_best_3class.pt"

# Train
model_s2 = train_model(
    model_s2,
    train_loader_s2,
    val_loader_s2,
    epochs=EPOCHS_STAGE2,
    save_path=save_path_s2
)

## 9. Evaluation & Download

In [None]:
# Final Evaluation
print("Evaluating Final Model on Validation Set...")
model_s2.eval()
y_true = []
y_pred = []

with torch.no_grad():
    for imgs, labels in val_loader_s2:
        outputs = model_s2(imgs.to(device))
        y_true.extend(labels.tolist())
        y_pred.extend(outputs.argmax(dim=1).cpu().tolist())

# Prediction indices follow the training classes, because the head was sized
# from train_loader_s2, so name the report columns after those classes.
class_names = train_loader_s2.dataset.classes
if val_loader_s2.dataset.classes != class_names:
    print(f"WARNING: validation classes {val_loader_s2.dataset.classes} differ from "
          f"training classes {class_names}; the report below may be mislabeled.")

# Pass labels explicitly. Otherwise classification_report raises when some
# class never appears in y_true or y_pred (e.g. no 'maybe' images in val).
all_labels = list(range(len(class_names)))
print(classification_report(y_true, y_pred, labels=all_labels,
                            target_names=class_names, zero_division=0))
print("Confusion matrix (rows = true, columns = predicted):")
print(confusion_matrix(y_true, y_pred, labels=all_labels))

print("\nDone! Model saved to Google Drive.")