# Cherry Classifier Training on Google Colab

This notebook trains the ResNet50 cherry pit classifier on Google Colab Pro with GPU acceleration.

## Prerequisites
- Google Colab Pro subscription (recommended: the free tier also offers GPUs, but with shorter sessions and less reliable availability)
- Google Drive (for saving model checkpoints)

## Steps
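1. **Setup**: Check GPU availability, install dependencies, and clone the code and dataset repositories (Steps 1, 3, and 4)
2. **Configure**: Mount Google Drive for checkpoint saving (Step 2)
3. **Train**: Verify the dataset and run the training script (Steps 5 and 6)
4. **Monitor**: Plot training curves and review the final metrics (Steps 7 and 8)
5. **Download**: Check the outputs saved to Drive and, optionally, download them locally (Step 9 and the optional section)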

## Step 1: Check GPU Availability

In [None]:
# Check if GPU is available
!nvidia-smi

import torch
print(f"\nPyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU device: {torch.cuda.get_device_name(0)}")

## Step 2: Mount Google Drive

In [None]:
# Mount Google Drive for saving checkpoints
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

# Create output directory in Drive
import os
drive_output_dir = '/content/drive/MyDrive/cherry_training/outputs'
os.makedirs(drive_output_dir, exist_ok=True)
print(f"\nOutput directory: {drive_output_dir}")
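Before starting a multi-hour run, it can be worth confirming that the Drive output directory actually accepts writes. The sketch below creates and then deletes a small probe file. The helper name `check_writable` and the probe-file prefix are hypothetical and are not part of the training repository.

```python
import os
import tempfile


def check_writable(directory: str) -> bool:
    """Return True if a file can be created and removed in `directory`.

    Hypothetical helper: creates the directory if needed, writes a
    temporary probe file, and lets it be deleted automatically on close.
    """
    try:
        os.makedirs(directory, exist_ok=True)
        # NamedTemporaryFile (delete=True by default) removes the probe on close
        with tempfile.NamedTemporaryFile(dir=directory, prefix='.write_probe_'):
            pass
        return True
    except OSError:
        # Covers permission errors and paths that clash with existing files
        return False


# In the notebook:
# print("Drive writable:", check_writable(drive_output_dir))
```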

## Step 3: Install Dependencies

In [None]:
# Install required packages
!pip install -q pyyaml scikit-learn matplotlib

print("\nDependencies installed successfully!")

## Step 4: Clone Repositories

In [None]:
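# Clone training code repository into a fixed path
# (skipped if it already exists, so the cell can be re-run safely)
![ -d /content/cherries ] || git clone https://github.com/usefulmove/cherries.git /content/cherries

# Clone dataset repository (shallow clone to save time)
![ -d /content/cherry_classification ] || git clone --depth 1 https://github.com/weshavener/cherry_classification.git /content/cherry_classification

print("\nRepositories ready.")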
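Steps 5 to 7 rely on fixed paths under `/content`. A quick existence check, sketched below, catches a failed or partial clone before training starts. The path list is taken from the commands in this notebook. `missing_paths` and `REQUIRED` are hypothetical names used only for illustration.

```python
from pathlib import Path


def missing_paths(paths):
    """Return the entries of `paths` that do not exist on disk."""
    return [str(p) for p in map(Path, paths) if not p.exists()]


# Paths used later in this notebook (Steps 5, 6 and 7)
REQUIRED = [
    '/content/cherries/training/scripts/train.py',
    '/content/cherries/training/configs/resnet50_baseline.yaml',
    '/content/cherries/training/scripts/plot_metrics.py',
    '/content/cherry_classification/data/train',
    '/content/cherry_classification/data/val',
]

# In the notebook:
# missing = missing_paths(REQUIRED)
# print("All paths present." if not missing else f"Missing: {missing}")
```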

## Step 5: Verify Dataset

In [None]:
# Check dataset structure and count images
import os
from pathlib import Path

data_root = Path('/content/cherry_classification/data')

train_clean = data_root / 'train' / 'cherry_clean'
train_pit = data_root / 'train' / 'cherry_pit'
val_clean = data_root / 'val' / 'cherry_clean'
val_pit = data_root / 'val' / 'cherry_pit'

print("Dataset Summary:")
print("=" * 50)
print(f"Training clean images: {len(list(train_clean.glob('*')))}")
print(f"Training pit images: {len(list(train_pit.glob('*')))}")
print(f"Validation clean images: {len(list(val_clean.glob('*')))}")
print(f"Validation pit images: {len(list(val_pit.glob('*')))}")
print("=" * 50)

## Step 6: Run Training

**Note**: Training runs for 30 epochs and takes roughly 1 to 2 hours on a Colab Pro GPU. The exact time depends on which GPU the session is assigned.

In [None]:
# Change to training directory
%cd /content/cherries

# Run training script
!python training/scripts/train.py \
    --config training/configs/resnet50_baseline.yaml \
    --data-root /content/cherry_classification/data \
    --output-dir {drive_output_dir}/resnet50_baseline
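The timing note above works out to roughly 2 to 4 minutes per epoch (60 to 120 minutes over 30 epochs). To estimate how much time is left partway through a run, extrapolate from the epochs completed so far. This back-of-the-envelope sketch assumes every epoch takes about the same time. `estimate_remaining_minutes` is a hypothetical helper.

```python
def estimate_remaining_minutes(elapsed_min: float, epochs_done: int,
                               total_epochs: int = 30) -> float:
    """Linearly extrapolate remaining training time from the average epoch time."""
    if epochs_done <= 0:
        raise ValueError("need at least one completed epoch")
    per_epoch = elapsed_min / epochs_done
    return per_epoch * (total_epochs - epochs_done)


# 10 epochs in 30 minutes -> 3 min/epoch -> 20 epochs left
print(estimate_remaining_minutes(30, 10))  # 60.0
```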

## Step 7: Plot Training Curves

In [None]:
# Generate training curves
!python training/scripts/plot_metrics.py \
    {drive_output_dir}/resnet50_baseline/metrics.json \
    --output {drive_output_dir}/resnet50_baseline/training_curves.png

# Display the plot
from IPython.display import Image, display
display(Image(f'{drive_output_dir}/resnet50_baseline/training_curves.png'))

## Step 8: View Training Summary

In [None]:
# Read and display final metrics
import json

metrics_file = f'{drive_output_dir}/resnet50_baseline/metrics.json'

# metrics.json is a JSON Lines log: one JSON record per line
val_metrics = []
with open(metrics_file, 'r') as f:
    for line in f:
        if not line.strip():
            continue  # skip blank lines, which json.loads would reject
        data = json.loads(line)
        if data.get('phase') == 'val':
            val_metrics.append(data)

# Print summary
if val_metrics:
    final = val_metrics[-1]
    # max() returns the first epoch that reaches the best accuracy
    best = max(val_metrics, key=lambda m: m['val_accuracy'])
    best_acc, best_epoch = best['val_accuracy'], best['epoch']
    
    print("\n" + "="*50)
    print("TRAINING COMPLETE!")
    print("="*50)
    print(f"Final Validation Accuracy: {final['val_accuracy']:.4f}")
    print(f"Final Validation Loss: {final['val_loss']:.4f}")
    print(f"\nBest Validation Accuracy: {best_acc:.4f} (Epoch {best_epoch})")
    print("\nPer-Class Metrics:")
    for class_name, metrics in final['val_per_class_metrics'].items():
        print(f"  {class_name}:")
        print(f"    Precision: {metrics['precision']:.4f}")
        print(f"    Recall: {metrics['recall']:.4f}")
        print(f"    F1: {metrics['f1']:.4f}")
    print("="*50)
else:
    print("No validation metrics found!")
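For reference, the per-class precision, recall, and F1 values printed above follow the standard definitions. For a class c, precision = TP / (TP + FP), recall = TP / (TP + FN), and F1 is their harmonic mean. The sketch below computes these values from label lists in plain Python. It mirrors the per-class output of scikit-learn's `precision_recall_fscore_support`, but whether `train.py` computes its metrics exactly this way is an assumption. `per_class_metrics` is a hypothetical helper.

```python
def per_class_metrics(y_true, y_pred, classes):
    """Per-class precision, recall and F1 from true and predicted labels.

    Zero denominators yield 0.0 (scikit-learn's default, without its warning).
    """
    results = {}
    for c in classes:
        pairs = list(zip(y_true, y_pred))
        tp = sum(1 for t, p in pairs if t == c and p == c)
        fp = sum(1 for t, p in pairs if t != c and p == c)
        fn = sum(1 for t, p in pairs if t == c and p != c)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = (2 * precision * recall / (precision + recall)
              if precision + recall else 0.0)
        results[c] = {'precision': precision, 'recall': recall, 'f1': f1}
    return results


# One pit image is misclassified as clean
y_true = ['cherry_pit', 'cherry_pit', 'cherry_clean', 'cherry_clean']
y_pred = ['cherry_pit', 'cherry_clean', 'cherry_clean', 'cherry_clean']
m = per_class_metrics(y_true, y_pred, ['cherry_clean', 'cherry_pit'])
# cherry_pit: precision 1.0, recall 0.5
print(m['cherry_pit']['precision'], m['cherry_pit']['recall'])
```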

## Step 9: List Output Files

In [None]:
# List all files in output directory
!ls -lh {drive_output_dir}/resnet50_baseline/

print("\n" + "="*50)
print("Output files saved to Google Drive:")
print(f"{drive_output_dir}/resnet50_baseline/")
print("\nKey files:")
print("  - model_best.pt: Best model (highest val accuracy)")
print("  - model_final.pt: Final model (last epoch)")
print("  - metrics.json: Training metrics log")
print("  - training_curves.png: Training visualization")
print("  - checkpoint_epoch_*.pt: Periodic checkpoints")
print("="*50)
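Colab sessions can disconnect during long runs, so it helps to know which periodic checkpoint is newest. The listing above only shows the `checkpoint_epoch_*.pt` pattern. This sketch assumes the wildcard is an integer epoch number. It picks the newest checkpoint by numeric epoch rather than by string order, because string order would put `checkpoint_epoch_10.pt` before `checkpoint_epoch_9.pt`. `latest_checkpoint` is a hypothetical helper, and whether `train.py` can resume from such a file is not covered here.

```python
import re
from pathlib import Path
from typing import Optional

# Assumed naming scheme: checkpoint_epoch_<integer>.pt
CKPT_RE = re.compile(r'^checkpoint_epoch_(\d+)\.pt$')


def latest_checkpoint(output_dir: str) -> Optional[Path]:
    """Return the checkpoint with the highest epoch number, or None if none exist."""
    best_epoch, best_path = -1, None
    for p in Path(output_dir).glob('checkpoint_epoch_*.pt'):
        m = CKPT_RE.match(p.name)
        # Compare epochs as integers, ignoring names that do not fit the pattern
        if m and int(m.group(1)) > best_epoch:
            best_epoch, best_path = int(m.group(1)), p
    return best_path


# In the notebook:
# print(latest_checkpoint(f'{drive_output_dir}/resnet50_baseline'))
```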

## Optional: Download Model to Local Machine

The models are already saved to your Google Drive, but you can also download them directly from Colab:

In [None]:
# Download best model
from google.colab import files

# Uncomment to download:
# files.download(f'{drive_output_dir}/resnet50_baseline/model_best.pt')
# files.download(f'{drive_output_dir}/resnet50_baseline/training_curves.png')

print("To download files, uncomment the lines above and run this cell.")
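Downloading files one at a time can be tedious. One option is to bundle the key outputs into a single zip archive with Python's standard `zipfile` module and pass that archive to `files.download`. The file selection below skips the periodic checkpoints to keep the archive small. That selection, and the helper name `bundle_outputs`, are assumptions.

```python
import zipfile
from pathlib import Path

# Assumed selection: skip periodic checkpoints to keep the archive small
DEFAULT_FILES = ('model_best.pt', 'model_final.pt', 'metrics.json', 'training_curves.png')


def bundle_outputs(run_dir: str, archive_path: str, names=DEFAULT_FILES) -> list:
    """Zip the listed files from `run_dir` that exist; return the names added."""
    added = []
    with zipfile.ZipFile(archive_path, 'w', compression=zipfile.ZIP_DEFLATED) as zf:
        for name in names:
            src = Path(run_dir) / name
            if src.is_file():
                # Store under the bare filename, not the full Drive path
                zf.write(str(src), arcname=name)
                added.append(name)
    return added


# In the notebook (writes the zip to local Colab storage, then downloads it):
# bundle_outputs(f'{drive_output_dir}/resnet50_baseline', '/content/resnet50_baseline.zip')
# files.download('/content/resnet50_baseline.zip')
```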