# Simple Baseline: Prove It Works First

**Goal**: Build a minimal CNN that shows positive learning in just 5 epochs.

## Problem Identified

From feature exploration:
- Simple linear regression on color: **R² = 0.20**
- Complex CNN (25M params): **R² = -1.25** ❌

**Root causes:**
1. ColorJitter destroying color signal (saturation=0.3 too aggressive)
2. Model too complex (25M params, 285 training samples → severe overfitting)
3. Training too long without validation (40 epochs wasted)
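
To make the first root cause concrete, here is a minimal pure-Python sketch of what saturation jitter does to a color feature. It assumes that saturation adjustment blends each pixel with its grayscale value (which is how torchvision's `adjust_saturation` behaves, as far as I know) and that `ColorJitter(saturation=0.3)` draws the factor uniformly from [0.7, 1.3]. The pixel value and the `green_excess` feature are hypothetical and chosen only for illustration.

```python
def adjust_saturation(rgb, factor):
    """Blend a pixel with its grayscale value; factor=1.0 leaves it unchanged."""
    r, g, b = rgb
    gray = 0.299 * r + 0.587 * g + 0.114 * b  # standard luma weights
    blended = [factor * c + (1.0 - factor) * gray for c in (r, g, b)]
    return tuple(min(255.0, max(0.0, c)) for c in blended)  # clamp to [0, 255]


def green_excess(rgb):
    """Hypothetical color feature: how much greener than red a pixel is."""
    r, g, _ = rgb
    return g - r


pixel = (90.0, 140.0, 60.0)  # hypothetical grass-like pixel

for factor in (0.7, 1.0, 1.3):
    jittered = adjust_saturation(pixel, factor)
    print(f"factor={factor}: green_excess={green_excess(jittered):.1f}")
```

As long as no channel is clipped, the gray term cancels in `g - r`, so the feature scales linearly with the jitter factor. The same image can therefore present a color feature up to 30% lower or higher while its biomass label stays fixed.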

## This Notebook's Approach

**Fixes applied:**
- ❌ **NO ColorJitter** - preserve color information
- 🔧 **ResNet18** instead of ResNet50 (~11M params instead of ~25M)
- 🔧 **Simple FC head** - one hidden layer (512 → 256 → 5)
- ⏱️ **5 epochs only** - fast validation (~5-7 minutes)
- 📉 **Lower LR** - 1e-4 instead of 3e-4

**Success criteria:**
- Epoch 1: R² > -1.0 (better than before)
- Epoch 3: R² > 0.0 (beat mean prediction)
- Epoch 5: R² > 0.20 (beat linear regression)
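
The thresholds above rely on how R² behaves: a model that always predicts the mean of the evaluated targets scores exactly 0, and anything worse goes negative. The sketch below shows this in pure Python, along with the weighted combination used as the competition score. The target values and per-target R² numbers are hypothetical, and `r2_simple` is an illustrative stand-in for `sklearn.metrics.r2_score`.

```python
def r2_simple(y_true, y_pred):
    """R^2 = 1 - SS_res / SS_tot."""
    mean_true = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot


def weighted_r2(per_target_r2, weights):
    """Weighted sum of per-target R^2 values."""
    return sum(w * r2 for w, r2 in zip(weights, per_target_r2))


y_true = [10.0, 20.0, 30.0, 40.0]         # hypothetical biomass values (g)
mean_pred = [25.0, 25.0, 25.0, 25.0]      # always predict the mean
reversed_pred = [40.0, 30.0, 20.0, 10.0]  # anti-correlated predictions

r2_mean = r2_simple(y_true, mean_pred)
r2_reversed = r2_simple(y_true, reversed_pred)
print(f"Predict mean: R2 = {r2_mean:+.2f}")
print(f"Reversed:     R2 = {r2_reversed:+.2f}")

competition_weights = [0.1, 0.1, 0.1, 0.2, 0.5]
example_r2 = [0.1, -0.2, 0.0, 0.3, 0.4]   # hypothetical per-target scores
score = weighted_r2(example_r2, competition_weights)
print(f"Weighted score: {score:+.2f}")
```

Because `Dry_Total_g` carries half the weight, a negative R² on that single target can pull the combined score below zero even when the other targets are learned reasonably well.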

If this works → scale up to 20-30 epochs with early stopping.
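
Early stopping is not implemented in this notebook. As a rough sketch of what the scaled-up run could use, the helper below stops once validation R² has failed to improve for `patience` consecutive epochs. The `EarlyStopping` class, its parameters, and the R² trajectory are all hypothetical.

```python
class EarlyStopping:
    """Stop when the monitored score has not improved for `patience` epochs."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_score = float("-inf")
        self.best_epoch = None
        self.epochs_without_improvement = 0

    def step(self, score, epoch):
        """Record a score (higher is better); return True if training should stop."""
        if score > self.best_score + self.min_delta:
            self.best_score = score
            self.best_epoch = epoch
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience


# Hypothetical validation R^2 values, one per epoch
val_r2_history = [-0.8, -0.1, 0.15, 0.22, 0.21, 0.20, 0.19, 0.18]

stopper = EarlyStopping(patience=3)
stopped_at = None
for epoch, val_r2 in enumerate(val_r2_history, start=1):
    if stopper.step(val_r2, epoch):
        stopped_at = epoch
        break

print(f"Stopped at epoch {stopped_at}, best epoch {stopper.best_epoch}")
```

In a real training loop, the checkpoint would be saved whenever `best_epoch` changes, so the weights from the best epoch can be restored after stopping.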

---
## Setup

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image

from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score, mean_absolute_error
from tqdm.auto import tqdm

sns.set_style('whitegrid')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Set seeds
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)

print("✓ Imports complete")

In [None]:
# Load data
train_enriched = pd.read_csv('competition/train_enriched.csv')
train_enriched['Sampling_Date'] = pd.to_datetime(train_enriched['Sampling_Date'])
train_enriched['full_image_path'] = train_enriched['image_path'].apply(lambda x: f'competition/{x}')

# Target columns and weights
target_cols = ['Dry_Green_g', 'Dry_Dead_g', 'Dry_Clover_g', 'GDM_g', 'Dry_Total_g']
competition_weights = [0.1, 0.1, 0.1, 0.2, 0.5]

# Train/val split
train_data, val_data = train_test_split(train_enriched, test_size=0.2, random_state=42)

print(f"Total samples: {len(train_enriched)}")
print(f"Training samples: {len(train_data)}")
print(f"Validation samples: {len(val_data)}")
print(f"\nTargets: {target_cols}")
print(f"Competition weights: {competition_weights}")

---
## Dataset - NO ColorJitter!

In [None]:
class SimpleDataset(Dataset):
    """Simple image dataset WITHOUT ColorJitter."""
    
    def __init__(self, dataframe, image_size=224, augment=False):
        self.df = dataframe.reset_index(drop=True)
        self.image_size = image_size
        
        if augment:
            print("Augmentation (NO ColorJitter):")
            print("  - RandomHorizontalFlip")
            print("  - RandomVerticalFlip")
            print("  - RandomRotation(10 degrees)")
            print("  - Standard normalization")
            
            self.transform = transforms.Compose([
                transforms.Resize((image_size, image_size)),
                transforms.RandomHorizontalFlip(),
                transforms.RandomVerticalFlip(),
                transforms.RandomRotation(10),
                # NO ColorJitter - preserve color information!
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])
        else:
            self.transform = transforms.Compose([
                transforms.Resize((image_size, image_size)),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        
        # Load image
        img = Image.open(row['full_image_path']).convert('RGB')
        img = self.transform(img)
        
        # Targets
        targets = torch.tensor(
            row[target_cols].values.astype('float32'),
            dtype=torch.float32
        )
        
        return {'image': img, 'targets': targets}

# Create datasets
batch_size = 16

train_dataset = SimpleDataset(train_data, augment=True)
val_dataset = SimpleDataset(val_data, augment=False)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

print(f"\n✓ Datasets created")
print(f"  Train batches: {len(train_loader)}")
print(f"  Val batches: {len(val_loader)}")

---
## Simple Model - ResNet18
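
The regression head defined in the cell below is small next to the backbone. This short sketch counts its trainable parameters by hand, assuming the layer sizes used in this notebook (512 → 256 → 5, with one `BatchNorm1d` in between). The helper names are illustrative.

```python
def linear_params(in_features, out_features):
    """Weights plus biases of a fully connected layer."""
    return in_features * out_features + out_features


def batchnorm1d_params(num_features):
    """Learnable scale and shift; running statistics are buffers, not parameters."""
    return 2 * num_features


head_params = (
    linear_params(512, 256)      # first FC layer
    + batchnorm1d_params(256)    # BatchNorm1d(256)
    + linear_params(256, 5)      # output layer, one unit per target
)
print(f"Head parameters: {head_params:,}")
```

The head adds a little over 130K parameters, so almost all of the roughly 11M parameters sit in the ResNet18 backbone.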

In [None]:
class SimpleModel(nn.Module):
    """Simplified CNN: ResNet18 + single hidden layer."""
    
    def __init__(self, num_outputs=5):
        super().__init__()
        
        # ResNet18 backbone (lighter than ResNet50)
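        # `pretrained=True` is deprecated in torchvision >= 0.13; the weights enum loads the same ImageNet weights
        self.resnet = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)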
        num_features = self.resnet.fc.in_features  # 512 for ResNet18
        
        # Simple FC head - just one hidden layer
        self.resnet.fc = nn.Sequential(
            nn.Linear(num_features, 256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Dropout(0.2),  # Less dropout than before
            nn.Linear(256, num_outputs)
        )
    
    def forward(self, x):
        return self.resnet(x)

# Create model
model = SimpleModel(num_outputs=5).to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print("✓ Model architecture:")
print(f"  Backbone: ResNet18 (pre-trained ImageNet)")
print(f"  FC head: 512 → 256 → 5")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")
print(f"\n  Compare to previous: 25M+ params")
print(f"  Reduction: {100 * (1 - trainable_params/25e6):.1f}% fewer parameters")

---
## Loss Function & Training Setup
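
The loss in the cell below multiplies the element-wise squared error by the competition weights and then averages over every batch-by-target element. As a sanity check on that arithmetic, this pure-Python sketch shows that the result equals the weighted sum of per-target MSEs divided by the number of targets (5). The batch values are hypothetical.

```python
weights = [0.1, 0.1, 0.1, 0.2, 0.5]


def weighted_mse_mean(preds, targets, weights):
    """Mirror of (mse * weights).mean(): average over all batch x target elements."""
    total, count = 0.0, 0
    for pred_row, target_row in zip(preds, targets):
        for p, t, w in zip(pred_row, target_row, weights):
            total += w * (p - t) ** 2
            count += 1
    return total / count


def weighted_sum_of_target_mse(preds, targets, weights):
    """Sum over targets of weight_i * MSE_i."""
    n = len(preds)
    result = 0.0
    for i, w in enumerate(weights):
        mse_i = sum((p[i] - t[i]) ** 2 for p, t in zip(preds, targets)) / n
        result += w * mse_i
    return result


# Hypothetical batch of two samples with five targets each
preds = [[1.0, 2.0, 3.0, 4.0, 5.0], [2.0, 2.0, 2.0, 2.0, 2.0]]
targets = [[0.0] * 5, [0.0] * 5]

loss_as_coded = weighted_mse_mean(preds, targets, weights)
loss_per_target = weighted_sum_of_target_mse(preds, targets, weights)
print(f"(mse * weights).mean(): {loss_as_coded:.3f}")
print(f"sum_i w_i * MSE_i:      {loss_per_target:.3f}")
```

The constant factor of 1/5 only rescales the loss and does not change which predictions minimize it, but it matters when comparing loss values against a metric computed as a plain weighted sum.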

In [None]:
class CompetitionLoss(nn.Module):
    """MSE loss weighted by competition metric."""
    def __init__(self):
        super().__init__()
        self.weights = torch.tensor([0.1, 0.1, 0.1, 0.2, 0.5]).to(device)
    
    def forward(self, pred, target):
        mse = F.mse_loss(pred, target, reduction='none')
        weighted_mse = (mse * self.weights).mean()
        return weighted_mse

criterion = CompetitionLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)  # Lower LR
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)

print("✓ Training setup:")
print(f"  Loss: Competition-weighted MSE")
print(f"  Optimizer: AdamW")
print(f"  Learning rate: 1e-4 (lower than before)")
print(f"  Weight decay: 1e-4")
print(f"  Scheduler: ReduceLROnPlateau (patience=2)")

---
## Training Loop - 5 Epochs Only

In [None]:
def train_epoch(model, train_loader, criterion, optimizer):
    """Train for one epoch."""
    model.train()
    train_loss = 0
    
    for batch in train_loader:
        images = batch['image'].to(device)
        targets = batch['targets'].to(device)
        
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item() * images.size(0)
    
    return train_loss / len(train_loader.dataset)

def validate(model, val_loader, criterion):
    """Validate and calculate R² scores."""
    model.eval()
    val_loss = 0
    all_preds = []
    all_targets = []
    
    with torch.no_grad():
        for batch in val_loader:
            images = batch['image'].to(device)
            targets = batch['targets'].to(device)
            
            outputs = model(images)
            loss = criterion(outputs, targets)
            val_loss += loss.item() * images.size(0)
            
            all_preds.append(outputs.cpu().numpy())
            all_targets.append(targets.cpu().numpy())
    
    val_loss /= len(val_loader.dataset)
    
    # Calculate R² for each target
    all_preds = np.vstack(all_preds)
    all_targets = np.vstack(all_targets)
    
    r2_scores = []
    competition_score = 0
    
    for i in range(5):
        r2 = r2_score(all_targets[:, i], all_preds[:, i])
        r2_scores.append(r2)
        competition_score += competition_weights[i] * r2
    
    return val_loss, competition_score, r2_scores

print("✓ Training functions defined")

In [None]:
# Training loop
num_epochs = 5
history = {
    'train_loss': [],
    'val_loss': [],
    'val_r2': [],
    'epoch': []
}

best_r2 = -float('inf')

print("="*80)
print("TRAINING SIMPLE BASELINE - 5 EPOCHS")
print("="*80)
print("\nSuccess criteria:")
print("  Epoch 1: R² > -1.0 (better than previous -1.25)")
print("  Epoch 3: R² > 0.0 (beat mean prediction)")
print("  Epoch 5: R² > 0.20 (beat linear regression)")
print("\nTraining...\n")

for epoch in range(num_epochs):
    # Train
    train_loss = train_epoch(model, train_loader, criterion, optimizer)
    
    # Validate
    val_loss, val_r2, r2_scores = validate(model, val_loader, criterion)
    
    # Update scheduler
    scheduler.step(val_loss)
    
    # Store history
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['val_r2'].append(val_r2)
    history['epoch'].append(epoch + 1)
    
    # Print progress
    print(f"Epoch {epoch+1}/{num_epochs}:")
    print(f"  Train Loss: {train_loss:.4f}")
    print(f"  Val Loss:   {val_loss:.4f}")
    print(f"  Val R²:     {val_r2:+.4f}")
    
    # Check milestones
    if epoch == 0 and val_r2 > -1.0:
        print("  ✓ Milestone 1: Better than previous baseline!")
    if epoch == 2 and val_r2 > 0.0:
        print("  ✓ Milestone 2: Beat mean prediction!")
    if epoch == 4 and val_r2 > 0.20:
        print("  ✓ Milestone 3: Beat linear regression!")
    
    # Save best model
    if val_r2 > best_r2:
        best_r2 = val_r2
        torch.save(model.state_dict(), 'simple_baseline_best.pth')
        print(f"  💾 New best R² = {best_r2:+.4f}")
    
    print()

print("="*80)
print(f"TRAINING COMPLETE")
print("="*80)
print(f"\nBest validation R²: {best_r2:+.4f}")
print(f"\nComparison:")
print(f"  Simple linear regression: +0.2048")
print(f"  Previous CNN baseline: -1.2527")
print(f"  This simple CNN: {best_r2:+.4f}")

---
## Detailed Evaluation

In [None]:
# Load best model
# map_location keeps the checkpoint loadable on CPU-only machines
model.load_state_dict(torch.load('simple_baseline_best.pth', map_location=device))
model.eval()

# Full evaluation
all_preds = []
all_targets = []

with torch.no_grad():
    for batch in val_loader:
        images = batch['image'].to(device)
        targets = batch['targets'].to(device)
        
        outputs = model(images)
        all_preds.append(outputs.cpu().numpy())
        all_targets.append(targets.cpu().numpy())

all_preds = np.vstack(all_preds)
all_targets = np.vstack(all_targets)

# Calculate detailed metrics
print("="*80)
print("DETAILED RESULTS")
print("="*80)

competition_score = 0
for i, target in enumerate(target_cols):
    r2 = r2_score(all_targets[:, i], all_preds[:, i])
    mae = mean_absolute_error(all_targets[:, i], all_preds[:, i])
    competition_score += competition_weights[i] * r2
    
    print(f"\n{target}:")
    print(f"  R² = {r2:+.4f} (weight: {competition_weights[i]})")
    print(f"  MAE = {mae:.2f}g")

print(f"\n{'='*80}")
print(f"Competition Score: {competition_score:+.4f}")
print(f"{'='*80}")

---
## Training Curves

In [None]:
# Plot training curves
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Loss curves
ax = axes[0]
ax.plot(history['epoch'], history['train_loss'], 'o-', label='Train Loss', linewidth=2, markersize=8)
ax.plot(history['epoch'], history['val_loss'], 's-', label='Val Loss', linewidth=2, markersize=8)
ax.set_xlabel('Epoch', fontsize=12)
ax.set_ylabel('Loss', fontsize=12)
ax.set_title('Loss Curves (5 Epochs)', fontsize=14, fontweight='bold')
ax.legend(fontsize=11)
ax.grid(alpha=0.3)

# R² curve
ax = axes[1]
ax.plot(history['epoch'], history['val_r2'], 'o-', color='green', linewidth=2, markersize=8, label='Val R²')
ax.axhline(y=0.0, color='gray', linestyle='--', linewidth=2, label='Baseline (predict mean)')
ax.axhline(y=0.2048, color='orange', linestyle='--', linewidth=2, label='Linear regression')
ax.axhline(y=-1.2527, color='red', linestyle='--', linewidth=2, label='Previous CNN')
ax.set_xlabel('Epoch', fontsize=12)
ax.set_ylabel('R² Score', fontsize=12)
ax.set_title('R² Progress (5 Epochs)', fontsize=14, fontweight='bold')
ax.legend(fontsize=10)
ax.grid(alpha=0.3)

plt.tight_layout()
plt.savefig('simple_baseline_curves.png', dpi=150, bbox_inches='tight')
plt.show()

print("✓ Training curves saved to: simple_baseline_curves.png")

---
## Summary & Next Steps

In [None]:
print("="*80)
print("SUMMARY")
print("="*80)

print("\n📊 Results Comparison:")
print(f"\n  Metric                  | Score")
print(f"  " + "-"*60)
print(f"  Linear regression       | {0.2048:+.4f}")
print(f"  Previous CNN (40 epochs)| {-1.2527:+.4f}")
print(f"  This simple CNN (5)     | {best_r2:+.4f}")

print("\n🔧 What Changed:")
print("  ❌ Removed ColorJitter (was destroying color signal)")
print("  🔧 Simpler model: ResNet18 vs ResNet50")
print("  🔧 Fewer parameters: ~11M vs 25M")
print("  📉 Lower learning rate: 1e-4 vs 3e-4")
print("  ⏱️  Faster validation: 5 epochs vs 40")

print("\n" + "="*80)
print("NEXT STEPS")
print("="*80)

if best_r2 > 0.20:
    print("\n✅ SUCCESS! Simple CNN beats linear regression!")
    print("\nRecommended next steps:")
    print("  1. Scale up to 20-30 epochs with early stopping")
    print("  2. Try slightly larger model (ResNet34?)")
    print("  3. Experiment with learning rate (1e-4 to 3e-4)")
    print("  4. Consider ensemble predictions")
    print("  5. Generate test set predictions")
    
elif best_r2 > 0.0:
    print("\n⚠️  PARTIAL SUCCESS: CNN beats mean prediction but not linear model")
    print("\nRecommended next steps:")
    print("  1. Try more epochs (15-20)")
    print("  2. Experiment with learning rate")
    print("  3. Try ResNet34 (slightly larger)")
    print("  4. Add more aggressive geometric augmentation")
    print("  5. Consider feature concatenation (add color features to CNN)")
    
else:
    print("\n❌ STILL FAILING: R² < 0.0")
    print("\nDeeper investigation needed:")
    print("  1. Check data loading: Are images loading correctly?")
    print("  2. Check loss function: Is it computing correctly?")
    print("  3. Check image-label alignment: IDs matching?")
    print("  4. Try even simpler model (linear layer on flattened images)")
    print("  5. Verify ImageNet normalization is appropriate")

print("\n" + "="*80)
print("✓ Simple baseline complete!")
print("="*80)