# CSIRO Image2Biomass - Model Comparison

This notebook compares completed training runs (different model configurations and architectures) using their validation metrics, before and after constraint repair, to help you choose the best approach.

In [None]:
import json
from pathlib import Path
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Global plotting defaults. Cells below that pass `figsize` explicitly
# override the default figure size set here.
DEFAULT_FIGSIZE = (14, 6)

sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = DEFAULT_FIGSIZE

## Load Results from Multiple Runs

In [None]:
def collect_run_results(outputs_dir):
    """Gather validation metrics from every completed run under ``outputs_dir``.

    A run counts as completed when its ``run-*`` directory contains a
    ``validation_constraint_metrics.json`` file.

    Args:
        outputs_dir: Directory holding the ``run-*`` sub-directories.

    Returns:
        A list with one dict per completed run, ordered by run directory name.
        Each dict holds the run name, the raw/repaired RMSE and MAE (``None``
        when absent from the file), ``improvement`` (``rmse_raw - rmse_repaired``,
        or ``None`` when either value is missing), and one ``rmse_<target>``
        entry per key of ``per_target_rmse``.
    """
    records = []
    # glob order is filesystem-dependent; sort so row and plot order are
    # stable across re-runs.
    for run_dir in sorted(Path(outputs_dir).glob("run-*")):
        metrics_file = run_dir / "validation_constraint_metrics.json"
        if not metrics_file.exists():
            continue

        with open(metrics_file, encoding="utf-8") as f:
            metrics = json.load(f)

        rmse_raw = metrics.get('rmse_raw')
        rmse_repaired = metrics.get('rmse_repaired')
        # Only report an improvement when both sides exist; defaulting a
        # missing value to 0 would fabricate a misleading number.
        if rmse_raw is None or rmse_repaired is None:
            improvement = None
        else:
            improvement = rmse_raw - rmse_repaired

        # Tolerate an explicit null as well as a missing key.
        per_target = metrics.get('per_target_rmse') or {}

        records.append({
            'run': run_dir.name,
            'rmse_raw': rmse_raw,
            'rmse_repaired': rmse_repaired,
            'mae_raw': metrics.get('mae_raw'),
            'mae_repaired': metrics.get('mae_repaired'),
            'improvement': improvement,
            **{f'rmse_{k}': v for k, v in per_target.items()}
        })
    return records


# Scan outputs directory for completed runs
outputs_dir = Path.cwd().parent / "outputs" / "baseline"

results = collect_run_results(outputs_dir)

df_results = pd.DataFrame(results)
print(f"Found {len(df_results)} completed runs")
df_results.head()

## Overall Performance Comparison

In [None]:
if len(df_results) == 0:
    print("No completed runs found. Train a model first!")
else:
    # Best (lowest) repaired RMSE first.
    df_sorted = df_results.sort_values('rmse_repaired')
    bar_positions = range(len(df_sorted))

    # One entry per subplot: the raw and repaired bars are drawn at the same
    # x positions, so the repaired bar overlays the raw one.
    panels = [
        {
            'raw': 'rmse_raw',
            'repaired': 'rmse_repaired',
            'color': 'steelblue',
            'ylabel': 'RMSE (g)',
            'title': 'RMSE Comparison Across Runs',
        },
        {
            'raw': 'mae_raw',
            'repaired': 'mae_repaired',
            'color': 'forestgreen',
            'ylabel': 'MAE (g)',
            'title': 'MAE Comparison Across Runs',
        },
    ]

    fig, axes = plt.subplots(1, 2, figsize=(16, 6))
    for ax, panel in zip(axes, panels):
        ax.bar(bar_positions, df_sorted[panel['raw']],
               alpha=0.5, label='Raw', color='lightcoral')
        ax.bar(bar_positions, df_sorted[panel['repaired']],
               alpha=0.8, label='Repaired', color=panel['color'])
        ax.set_xlabel('Run')
        ax.set_ylabel(panel['ylabel'])
        ax.set_title(panel['title'])
        ax.legend()
        ax.set_xticks(bar_positions)
        ax.set_xticklabels(df_sorted['run'], rotation=45, ha='right')

    plt.tight_layout()
    plt.show()

    # The first row after sorting is the run with the lowest repaired RMSE.
    best_run = df_sorted.iloc[0]
    print("\nBest Run (by RMSE repaired):")
    print(f"  Run: {best_run['run']}")
    print(f"  RMSE (repaired): {best_run['rmse_repaired']:.2f} g")
    print(f"  MAE (repaired): {best_run['mae_repaired']:.2f} g")
    print(f"  Improvement from constraint repair: {best_run['improvement']:.2f} g")

## Per-Target Performance

In [None]:
if len(df_results) > 0:
    # Per-target RMSE columns are recognised by these name prefixes.
    per_target_prefixes = ('rmse_Dry', 'rmse_GDM')
    target_cols = [col for col in df_results.columns if col.startswith(per_target_prefixes)]

    if target_cols:
        # One row per run, one column per target, with the 'rmse_' text
        # stripped from the column labels.
        heatmap_data = (
            df_results[['run', *target_cols]]
            .set_index('run')
            .rename(columns=lambda col: col.replace('rmse_', ''))
        )

        # Transposed so targets run down the y-axis and runs along the x-axis.
        plt.figure(figsize=(12, 8))
        sns.heatmap(
            heatmap_data.T,
            annot=True,
            fmt='.2f',
            cmap='RdYlGn_r',
            linewidths=0.5,
            cbar_kws={'label': 'RMSE (g)'},
        )
        plt.title('Per-Target RMSE Across Runs')
        plt.xlabel('Run')
        plt.ylabel('Target Variable')
        plt.tight_layout()
        plt.show()

        # Distribution of each target's RMSE across runs.
        print("\nPer-Target RMSE Summary:")
        print(heatmap_data.describe())

## Constraint Repair Impact

In [None]:
if len(df_results) > 0 and 'improvement' in df_results.columns:
    fig, ax = plt.subplots(figsize=(10, 6))

    # Improvement expressed as a percentage of each run's raw RMSE.
    improvement_pct = df_results['improvement'] / df_results['rmse_raw'] * 100
    bar_positions = range(len(df_results))

    ax.bar(bar_positions, improvement_pct, color='green', alpha=0.7, edgecolor='black')
    ax.set_xlabel('Run')
    ax.set_ylabel('Improvement (%)')
    ax.set_title('RMSE Improvement from Constraint Repair')
    ax.set_xticks(bar_positions)
    ax.set_xticklabels(df_results['run'], rotation=45, ha='right')
    # Zero line: bars below it are runs where the repaired RMSE is higher than the raw RMSE.
    ax.axhline(0, color='red', linestyle='--', linewidth=1)
    ax.grid(axis='y', alpha=0.3)
    plt.tight_layout()
    plt.show()

    mean_pct = improvement_pct.mean()
    max_pct = improvement_pct.max()
    min_pct = improvement_pct.min()
    print(f"\nAverage improvement: {mean_pct:.2f}%")
    print(f"Max improvement: {max_pct:.2f}%")
    print(f"Min improvement: {min_pct:.2f}%")

## Recommendations

The cell below prints fixed starting-point configurations for different scenarios. They are not computed from the results above, so compare them against your own runs before adopting one.

In [None]:
# Static presets: (scenario heading, suggested command-line flags).
banner = "=" * 80
recommendations = [
    ("1. QUICK EXPERIMENTATION (5-10 min)",
     "--max_epochs 10 --batch_size 16 --backbone resnet18"),
    ("2. BALANCED PERFORMANCE (20-30 min)",
     "--max_epochs 50 --batch_size 8 --backbone efficientnet_b3"),
    ("3. BEST PERFORMANCE (1-2 hours)",
     "--max_epochs 100 --batch_size 4 --backbone efficientnet_b4"),
    ("4. CROSS-VALIDATION (2-4 hours)",
     "--max_epochs 50 --num_folds 5"),
]

print(banner)
print("HYPERPARAMETER RECOMMENDATIONS")
print(banner)
for heading, flags in recommendations:
    # Blank line before each heading, flags indented three spaces beneath it.
    print("\n" + heading)
    print("   " + flags)
print("\n" + banner)