# Model 4b: Auxiliary Pretrained Baseline

## Overview

**Model 4b** is our current best-performing model with **R²=+0.6852 validation, R²=+0.51 Kaggle**.

This notebook serves as a clean baseline to understand Model 4b's architecture and identify improvements.

---

## Why Model 4b Performs Well

Model 4b uses a **two-phase training approach**:

### Phase 1: Auxiliary Pretraining (15 epochs)
- Train CNN to predict **tabular features** from images:
  - NDVI (vegetation greenness)
  - Height (plant height in cm)
  - Weather (14 features: rainfall, temp, ET0, etc.)
  - State (4 classes: NSW, Tas, Vic, WA)
  - Species (15 classes)

**Why this works:**
- Forces model to learn visual patterns that correlate with environment
- Model achieves **82% state accuracy** - it can "see" location from image!
- Creates rich image representations that encode environmental context
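The notebook does not show the Phase 1 loss. Below is a minimal sketch that assumes an unweighted sum of MSE for the regression targets (NDVI, height, weather) and cross-entropy for the classification targets (state, species). `auxiliary_loss` and `state_accuracy` are hypothetical helper names, and the equal task weights are an assumption:

```python
from typing import Dict, Optional

import torch
import torch.nn.functional as F


def auxiliary_loss(preds: Dict[str, torch.Tensor],
                   targets: Dict[str, torch.Tensor],
                   weights: Optional[Dict[str, float]] = None) -> torch.Tensor:
    """Hypothetical Phase 1 loss: weighted sum of per-task losses.

    preds:   output of model(x, mode='auxiliary')
    targets: 'ndvi' [B, 1], 'height' [B, 1], 'weather' [B, 14] as floats
             (assumed already standardized); 'state' [B] and 'species' [B]
             as int64 class indices
    weights: optional per-task weights; missing tasks default to 1.0 (assumption)
    """
    weights = weights or {}
    losses = {
        # Regression targets: mean squared error
        "ndvi": F.mse_loss(preds["ndvi"], targets["ndvi"]),
        "height": F.mse_loss(preds["height"], targets["height"]),
        "weather": F.mse_loss(preds["weather"], targets["weather"]),
        # Classification targets: cross-entropy on raw logits
        "state": F.cross_entropy(preds["state"], targets["state"]),
        "species": F.cross_entropy(preds["species"], targets["species"]),
    }
    return sum(weights.get(name, 1.0) * loss for name, loss in losses.items())


def state_accuracy(state_logits: torch.Tensor, state_targets: torch.Tensor) -> float:
    """Fraction of images whose predicted state matches the label."""
    return (state_logits.argmax(dim=1) == state_targets).float().mean().item()
```

Down-weighting the 14-dimensional weather regression relative to the classification heads is one obvious knob if a single task dominates the sum.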

### Phase 2: Biomass Fine-tuning (30 epochs)
- Fine-tune pretrained CNN for biomass prediction
- Uses **differential learning rates**:
  - Backbone (pretrained): 1e-5 (small, preserve learned features)
  - Biomass head (new): 3e-4 (larger, learn biomass patterns)
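Differential learning rates map directly onto PyTorch optimizer parameter groups. The sketch below assumes AdamW, which the notebook does not confirm: it lists a weight decay but does not name the optimizer. `make_phase2_optimizer` is a hypothetical helper. The auxiliary heads are left out because the biomass forward pass never uses them:

```python
import torch
import torch.nn as nn


def make_phase2_optimizer(model: nn.Module,
                          head_lr: float = 3e-4,
                          backbone_lr: float = 1e-5,
                          weight_decay: float = 1e-4) -> torch.optim.Optimizer:
    """Phase 2 optimizer with separate learning rates per parameter group.

    Assumes `model.backbone` and `model.biomass_head` exist, as in
    AuxiliaryPretrainedModel. AdamW is an assumed choice of optimizer.
    """
    param_groups = [
        # Pretrained backbone: small LR to preserve Phase 1 features
        {"params": model.backbone.parameters(), "lr": backbone_lr},
        # Freshly initialized biomass head: larger LR
        {"params": model.biomass_head.parameters(), "lr": head_lr},
    ]
    # weight_decay passed here applies to both groups
    return torch.optim.AdamW(param_groups, weight_decay=weight_decay)
```

If the notebook actually uses Adam or SGD, the parameter-group structure stays the same.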

**At inference:** Only needs images (no tabular features!) ✅ Kaggle-compatible

---

## Architecture Details

```
Input: 224×224 RGB Image
         ↓
   ResNet18 Backbone (pretrained on ImageNet, final FC removed)
         ↓
   512-dim features (shared by both head sets)
         ├──────────────────────────────────────────┐
         ↓ Phase 1 (mode='auxiliary')               ↓ Phase 2 (mode='biomass')
┌─────────────────────────────┐      ┌─────────────────────────────┐
│  Auxiliary Heads            │      │  Biomass Head               │
│  - NDVI head (1 output)     │      │  Linear 512 → 256           │
│  - Height head (1 output)   │      │  → ReLU                     │
│  - Weather head (14 outputs)│      │  → Dropout(0.2)             │
│  - State head (4 classes)   │      │  → Linear 256 → 5           │
│  - Species head (15 classes)│      └─────────────────────────────┘
└─────────────────────────────┘                    ↓
                                          5 biomass predictions
```

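**Total parameters:** ~11.3M
- ResNet18 backbone (without its ImageNet FC layer): ~11.18M
- Auxiliary heads: ~18K
- Biomass head: ~133K

(The often-quoted ~11.7M is the full ImageNet ResNet18, including the 1000-class FC layer that is removed here.)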
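The head sizes follow from the layer shapes in `AuxiliaryPretrainedModel`: an `nn.Linear(in, out)` layer with bias has `in*out + out` parameters. The plain-Python check below puts the auxiliary heads at about 18K and the whole model at about 11.3M. The larger 11.7M figure includes ResNet18's ImageNet FC layer. The constant 11,689,512 is torchvision's published parameter count for ResNet18:

```python
def linear_params(n_in: int, n_out: int) -> int:
    """Trainable parameters in nn.Linear(n_in, n_out) with bias."""
    return n_in * n_out + n_out


FEAT = 512  # ResNet18 feature size after global average pooling

aux_heads = {
    "ndvi": linear_params(FEAT, 1),
    "height": linear_params(FEAT, 1),
    "weather": linear_params(FEAT, 14),
    "state": linear_params(FEAT, 4),
    "species": linear_params(FEAT, 15),
}
aux_total = sum(aux_heads.values())

# 512 -> 256 -> (ReLU, Dropout have no parameters) -> 5
biomass_total = linear_params(FEAT, 256) + linear_params(256, 5)

# Full torchvision ResNet18 minus its 1000-class ImageNet FC layer
RESNET18_FULL = 11_689_512
backbone_total = RESNET18_FULL - linear_params(FEAT, 1000)

print(f"aux heads:    {aux_total:,}")
print(f"biomass head: {biomass_total:,}")
print(f"backbone:     {backbone_total:,}")
print(f"total:        {backbone_total + aux_total + biomass_total:,}")
```

With PyTorch available, the same numbers can be read off the instantiated model with `sum(p.numel() for p in module.parameters())`.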

---

## Current Performance

**Validation (72 images):**
- Overall R²: **+0.6852**
- Per-target R²:
  - Dry_Green_g: +0.6903
  - Dry_Dead_g: +0.5243
  - Dry_Clover_g: +0.5017
  - GDM_g: +0.7254
  - Dry_Total_g: +0.7243

**Kaggle Test (unknown size):**
- Overall R²: **+0.51**
- Gap: **-0.175** (concerning; suggests overfitting or distribution shift, see Issue 1 below)

---

## Known Issues & Improvement Opportunities

### Issue 1: Large Validation-Test Gap (-0.175)
**Possible causes:**
1. **Overfitting** (most likely)
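   - Trained 30 epochs on a small dataset (285 training images)
   - Validation R² bounced around (not monotonic)
   - Best epoch was 29/30 (late in training); with only 72 noisy validation images, picking the single best epoch can overstate validation R²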

2. **Distribution shift**
   - Test set may have different:
     - Seasons, locations, species mix
     - Image quality/lighting conditions
   - Model 1 (simpler) scored lower on Kaggle (0.48 vs 0.51), suggesting the architecture itself is not the problem

3. **Validation split not representative**
   - Random 80/20 split (not stratified)
   - Only 72 validation images
   - High variance in scores

**Potential solutions:**
- ✅ Early stopping (stop at epoch ~20-25)
- ✅ Reduce Phase 2 epochs (30 → 20)
- ✅ Stronger regularization (dropout 0.2 → 0.3-0.4)
- ✅ Data augmentation (add more transforms)
- ✅ Learning rate scheduling (reduce LR after plateaus)
- ⚠️ Stratified split (hard with small dataset)
- ⚠️ K-fold CV (too slow for iteration)
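Early stopping can be driven by validation R² rather than loss. Below is a minimal, framework-agnostic sketch. `EarlyStopping`, `patience` and `min_delta` are hypothetical names, and the values and scores shown are illustrative, not tuned:

```python
class EarlyStopping:
    """Signal a stop when validation R² has not improved by more than
    `min_delta` for `patience` consecutive epochs (hypothetical helper)."""

    def __init__(self, patience: int = 5, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("-inf")
        self.best_epoch = -1
        self.bad_epochs = 0

    def step(self, score: float, epoch: int) -> bool:
        """Record this epoch's score; return True if training should stop."""
        if score > self.best + self.min_delta:
            self.best = score
            self.best_epoch = epoch
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Usage with made-up validation R² values
stopper = EarlyStopping(patience=3)
for epoch, val_r2 in enumerate([0.50, 0.60, 0.62, 0.61, 0.615, 0.60], start=1):
    if stopper.step(val_r2, epoch):
        print(f"stop at epoch {epoch}; best {stopper.best:.2f} at epoch {stopper.best_epoch}")
        break
# prints: stop at epoch 6; best 0.62 at epoch 3
```

Pair this with saving the weights whenever `best` improves, so the best-epoch model can be restored after stopping.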

### Issue 2: Target Normalization Uses Split Stats
**Current:** Uses training split (285 images) statistics
```text
Dry_Green_g: mean=27.49g, std=26.19g
```

**Should use:** Full dataset (357 images) statistics
```text
Dry_Green_g: mean=26.624722g, std=25.401232g
```

**Impact:** Small (the Dry_Green_g mean shifts by about 0.87 g, roughly 3% of one standard deviation), but it could contribute to the test set mismatch
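Target normalization is an affine map that is undone at inference. The key requirement is therefore that the same mean and std are used both to normalize the training targets and to de-normalize predictions. Below is a minimal sketch with hypothetical helper names. The notebook does not say whether it uses population or sample standard deviation, so `pstdev` here is an assumption:

```python
import statistics
from typing import Sequence, Tuple


def fit_target_stats(values: Sequence[float]) -> Tuple[float, float]:
    """Mean and standard deviation of one target column.

    Population std (pstdev) is an assumption; sample std (stdev) is the alternative.
    """
    return statistics.fmean(values), statistics.pstdev(values)


def normalize(y: float, mean: float, std: float) -> float:
    return (y - mean) / std


def denormalize(z: float, mean: float, std: float) -> float:
    # Must use the SAME mean/std that normalized the training targets
    return z * std + mean


# Hypothetical Dry_Green_g values; in practice pass all 357 labels
dry_green = [10.0, 20.0, 30.0, 40.0]
mean, std = fit_target_stats(dry_green)
print(mean, std, denormalize(normalize(33.0, mean, std), mean, std))
```

Storing the fitted `(mean, std)` pairs alongside the checkpoint keeps training and Kaggle inference consistent.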

### Issue 3: No Learning Rate Scheduling
**Current:** Fixed LR throughout training
- Phase 2 head: 3e-4 (all epochs)
- Phase 2 backbone: 1e-5 (all epochs)

**Potential improvement:** Reduce LR on plateau
- Could help fine-tune without overfitting
- Standard practice in deep learning
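Below is a sketch of `torch.optim.lr_scheduler.ReduceLROnPlateau` driven by validation R². It uses `mode="max"` because higher R² is better. `factor=0.5, patience=3` match the plan under "Next Experiment". The linear stand-in model and the validation scores are made up for illustration:

```python
import torch
import torch.nn as nn

model = nn.Linear(512, 5)  # stand-in for AuxiliaryPretrainedModel
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="max", factor=0.5, patience=3
)

# Made-up per-epoch validation R² values
for val_r2 in [0.60, 0.65, 0.66, 0.66, 0.65, 0.655, 0.64]:
    # ... train one epoch, then evaluate on the validation split ...
    scheduler.step(val_r2)  # pass the monitored metric, not the epoch

print([group["lr"] for group in optimizer.param_groups])
```

With `patience=3`, the LR is halved only after more than three consecutive epochs without improvement (here, after the seventh score). If the optimizer holds the two Phase 2 parameter groups, each group's LR is scaled by the same factor.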

### Issue 4: Limited Data Augmentation
**Current augmentations:**
- Random horizontal flip
- Random vertical flip
- Random rotation (10°)

**Could add:**
- Color jitter (brightness, contrast, saturation)
- Random crop + resize
- Gaussian blur
- Random erasing

**Caution:** Too much augmentation could hurt (Model 2 with ColorJitter scored worse)

---

## Recommended Next Steps

### Priority 1: Reduce Overfitting (Most Important!)
1. **Reduce Phase 2 epochs**: 30 → 20-25
2. **Add learning rate scheduling**: ReduceLROnPlateau
3. **Increase dropout**: 0.2 → 0.3-0.4
4. **Use full dataset normalization stats**

**Expected impact:** Close validation-test gap by ~0.05-0.10

### Priority 2: Better Training Monitoring
1. **Track validation R² every epoch** (not just loss)
2. **Save checkpoints at multiple epochs** (not just best)
3. **Plot learning curves** to identify overfitting visually
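Below is a per-target R² helper for epoch-by-epoch tracking, written with plain tensor operations. `r2_per_target`, `log_epoch` and `history` are hypothetical names. The helper computes R² = 1 − SS_res/SS_tot for each column. The notebook does not say how its "Overall R²" is aggregated. The unweighted mean of the five per-target scores reported earlier is about 0.633, not 0.6852, so the overall figure presumably uses a different (e.g. pooled or weighted) aggregation. This sketch reports per-target values only:

```python
import torch

TARGETS = ["Dry_Green_g", "Dry_Dead_g", "Dry_Clover_g", "GDM_g", "Dry_Total_g"]
history = []  # one row of per-target R² per epoch, for plotting learning curves


def r2_per_target(preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """R² = 1 - SS_res / SS_tot for each column of [N, T] tensors."""
    ss_res = ((targets - preds) ** 2).sum(dim=0)
    ss_tot = ((targets - targets.mean(dim=0)) ** 2).sum(dim=0)
    return 1.0 - ss_res / ss_tot


def log_epoch(epoch: int, val_preds: torch.Tensor, val_targets: torch.Tensor) -> dict:
    """Compute per-target validation R² and append it to `history`."""
    scores = r2_per_target(val_preds, val_targets)
    row = {"epoch": epoch, **{name: s.item() for name, s in zip(TARGETS, scores)}}
    history.append(row)
    return row
```

Per-target R² is unchanged by target normalization. Applying the same affine map to predictions and targets scales SS_res and SS_tot by the same factor, so it can be computed in either space.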

### Priority 3: Ensemble Models
1. **Train 3-5 models with different:**
   - Random seeds
   - Train/val splits
   - Augmentation strategies
2. **Average predictions** (simple ensemble)

**Expected impact:** +0.02-0.05 R² improvement
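Below is a minimal mean-ensemble sketch; `ensemble_predict` is a hypothetical helper. It calls each model's default forward, which for `AuxiliaryPretrainedModel` is `mode='biomass'`. It also switches each model to eval mode so dropout is off. If all members share the same target normalization stats, averaging normalized outputs gives the same result as averaging de-normalized grams, because de-normalization is affine:

```python
from typing import Iterable

import torch
import torch.nn as nn


def ensemble_predict(models: Iterable[nn.Module], images: torch.Tensor) -> torch.Tensor:
    """Mean of the predictions of several independently trained models."""
    preds = []
    with torch.no_grad():
        for model in models:
            model.eval()                  # disable dropout for inference
            preds.append(model(images))   # default forward (mode='biomass' for Model 4b)
    return torch.stack(preds).mean(dim=0)  # [n_models, B, 5] -> [B, 5]
```

Members trained on different train/val splits each see all 357 images across the ensemble, which is part of why this can help on a small dataset.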

---

## Training Configuration (Current Baseline)

```python
# Phase 1: Auxiliary Pretraining
PHASE1_EPOCHS = 15
PHASE1_LR = 3e-4
PHASE1_WEIGHT_DECAY = 1e-4

# Phase 2: Biomass Fine-tuning
PHASE2_EPOCHS = 30  # ⚠️ Too many? Try 20-25
PHASE2_HEAD_LR = 3e-4
PHASE2_BACKBONE_LR = 1e-5
PHASE2_WEIGHT_DECAY = 1e-4

# Architecture
HIDDEN_DIM = 256
DROPOUT = 0.2  # ⚠️ Too low? Try 0.3-0.4

# Data
BATCH_SIZE = 16
TRAIN_VAL_SPLIT = 0.8  # 285 train, 72 val
IMAGE_SIZE = 224

# Augmentation
AUGMENT_HFLIP = True
AUGMENT_VFLIP = True
AUGMENT_ROTATE = 10  # degrees
```
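`TRAIN_VAL_SPLIT = 0.8` on 357 images gives `int(357 * 0.8) = 285` training images and 72 validation images. The config does not record a split seed. Fixing one keeps the same 72 validation images across experiments, which matters when comparing runs on such a small validation set. The sketch below uses `torch.utils.data.random_split`. The seed value 42 is an arbitrary assumption, and `range(N_IMAGES)` stands in for the real dataset:

```python
import torch
from torch.utils.data import random_split

N_IMAGES = 357
n_train = int(N_IMAGES * 0.8)  # 285
n_val = N_IMAGES - n_train     # 72

# Seeded generator -> identical split on every run (seed is an assumption)
train_idx, val_idx = random_split(
    range(N_IMAGES), [n_train, n_val], generator=torch.Generator().manual_seed(42)
)
print(len(train_idx), len(val_idx))
```

Saving `train_idx.indices` and `val_idx.indices` next to the checkpoints also makes it possible to recompute normalization stats or rebuild ensembles on known splits.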

---


## Code Implementation

Below is the Model 4b model definition for reference (the training loops are not included).

```python
import torch
import torch.nn as nn
import torchvision.models as models

class AuxiliaryPretrainedModel(nn.Module):
    """Model 4b: Two-phase training with auxiliary tasks.

    Phase 1: Train to predict tabular features from images
    Phase 2: Fine-tune for biomass prediction

    At inference: Only needs the image (tabular patterns are learned implicitly)

    Args:
        num_outputs: Number of biomass targets (5)
        hidden_dim: Hidden layer size in biomass head (256 or 512)
        dropout: Dropout rate in biomass head (0.2-0.4)
        num_states: Number of state classes (4: NSW, Tas, Vic, WA)
        num_species: Number of species classes (15)
    """
    def __init__(self, num_outputs=5, hidden_dim=256, dropout=0.2,
                 num_states=4, num_species=15):
        super().__init__()

        # Shared backbone: ResNet18 pretrained on ImageNet.
        # `weights=` replaces the deprecated `pretrained=True` (torchvision >= 0.13).
        resnet = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        # Remove final FC layer, keep global avg pool -> [B, 512, 1, 1]
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])

        # Phase 1: Auxiliary heads (predict tabular features from image)
        self.ndvi_head = nn.Linear(512, 1)               # Predict NDVI
        self.height_head = nn.Linear(512, 1)             # Predict height
        self.weather_head = nn.Linear(512, 14)           # Predict 14 weather features
        self.state_head = nn.Linear(512, num_states)     # State logits (4 classes)
        self.species_head = nn.Linear(512, num_species)  # Species logits (15 classes)

        # Phase 2: Biomass prediction head (used after pretraining)
        self.biomass_head = nn.Sequential(
            nn.Linear(512, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_outputs)
        )

    def forward(self, x, mode='biomass'):
        """Forward pass.

        Args:
            x: Input image tensor [B, 3, 224, 224]
            mode: 'auxiliary' (Phase 1) or 'biomass' (Phase 2)

        Returns:
            If mode='auxiliary': dict with keys [ndvi, height, weather, state, species]
            If mode='biomass': tensor [B, num_outputs] with biomass predictions
        """
        # Extract features from image
        features = self.backbone(x)     # [B, 512, 1, 1]
        features = features.flatten(1)  # [B, 512]

        if mode == 'auxiliary':
            # Phase 1: Predict tabular features
            return {
                'ndvi': self.ndvi_head(features),
                'height': self.height_head(features),
                'weather': self.weather_head(features),
                'state': self.state_head(features),
                'species': self.species_head(features)
            }
        if mode == 'biomass':
            # Phase 2: Predict biomass
            return self.biomass_head(features)
        raise ValueError(f"mode must be 'auxiliary' or 'biomass', got {mode!r}")

# Example usage:
# Phase 1: model(images, mode='auxiliary') → dict with tabular predictions
# Phase 2: model(images, mode='biomass') → [B, 5] biomass predictions
```

## Key Insights from Model 4b

### What Works Well

1. **Auxiliary pretraining is effective**
   - 82% state accuracy suggests the model learns meaningful visual features
   - The model can infer location from the image alone (and plausibly season and plant type, though only state accuracy is reported here)
   - Validation R²=+0.6852 is the best validation score so far

2. **Differential learning rates help**
   - Low LR (1e-5) for pretrained backbone preserves features
   - Higher LR (3e-4) for new head allows learning biomass patterns

3. **Architecture is appropriate for dataset size**
   - ResNet18 (11.2M params) is reasonable for 285 training images
   - Model 1 (simpler) scored worse, suggesting the extra capacity helps

### What Needs Improvement

1. **Overfitting is the main issue**
   - 0.175 gap between validation and test is too large
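   - Model 1 (10 epochs) scored 0.48 vs Model 4b (30 epochs) 0.51 on Kaggle; the architectures differ, so this comparison alone does not isolate the effect of epoch count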
   - Need: early stopping, more regularization

2. **Training could be more stable**
   - Validation R² bounces between epochs
   - Need: learning rate scheduling, better monitoring

3. **Small dataset is limiting**
   - 285 training images is very small for deep learning
   - Need: stronger augmentation, ensemble methods

---

## Next Experiment: Model 4b Improved

Based on this analysis, the next notebook will implement:

1. **Reduce Phase 2 epochs**: 30 → 20
2. **Add LR scheduling**: ReduceLROnPlateau (factor=0.5, patience=3)
3. **Increase dropout**: 0.2 → 0.3
4. **Use full dataset normalization stats**
5. **Better monitoring**: Track R² every epoch, save multiple checkpoints

**Expected Kaggle score:** 0.55-0.58 (vs current 0.51)

---