# Crop Disease Detection - ResNet50 Training

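This notebook implements the training pipeline for crop disease detection using ResNet50 transfer learning. It loads the dataset, builds the model, trains it, plots the learning curves, and evaluates the saved checkpoint on the test split.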

In [None]:
# Import required libraries
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# Make the project modules under ../src importable
sys.path.append('../src')

from dataset import create_data_loaders
from evaluate import evaluate_model
from model import create_model, get_model_summary
from train import Trainer

print("Libraries imported successfully!")

In [None]:
# Configuration: every value a reader might want to tune lives in this cell.
config = {
    'data_dir': '../data',
    'batch_size': 8,  # Small batch size for demo
    'num_epochs': 5,  # Reduced for quick training
    'learning_rate': 1e-4,
    'fine_tune_epoch': 3,  # Passed to Trainer.train; fine-tuning starts earlier than usual for this demo
    'num_workers': 0,  # Passed to create_data_loaders (presumably forwarded to DataLoader; 0 = no worker processes)
    'checkpoint_path': '../models/crop_disease_resnet50.pth',  # Written by training, read by evaluation
    'seed': 42,
    'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu')
}

# Seed the torch and numpy RNGs so random operations (e.g. data shuffling,
# initialization of newly added layers) are reproducible on Restart & Run All.
# torch.manual_seed also seeds all CUDA devices.
torch.manual_seed(config['seed'])
np.random.seed(config['seed'])

print(f"Using device: {config['device']}")
print(f"Configuration: {config}")

In [None]:
# Load dataset: build train/val/test loaders and the list of class names.
print("Loading dataset...")
train_loader, val_loader, test_loader, class_names = create_data_loaders(
    data_dir=config['data_dir'],
    batch_size=config['batch_size'],
    num_workers=config.get('num_workers', 0)
)

# Fail fast here, rather than deep inside training, if the data directory
# is empty or laid out incorrectly.
assert len(class_names) > 0, f"No classes found under {config['data_dir']}"
assert len(train_loader.dataset) > 0, "Training split is empty"
assert len(val_loader.dataset) > 0, "Validation split is empty"

print("Dataset loaded successfully!")
print(f"Number of classes: {len(class_names)}")
print(f"Classes: {class_names}")
print(f"Training samples: {len(train_loader.dataset)}")
print(f"Validation samples: {len(val_loader.dataset)}")
print(f"Test samples: {len(test_loader.dataset)}")

In [None]:
# Create model
print("Creating ResNet50 model...")
model = create_model(num_classes=len(class_names), device=config['device'])
get_model_summary(model)

# Sanity-check the forward pass with one random 224x224 RGB input.
# Run it in eval mode under no_grad: in train mode, BatchNorm layers (ResNet50
# has them) would update their running mean/variance from this random tensor,
# corrupting the pretrained statistics, and autograd would build an unused graph.
was_training = model.training
model.eval()
with torch.no_grad():
    dummy_input = torch.randn(1, 3, 224, 224, device=config['device'])
    output = model(dummy_input)
model.train(was_training)  # restore whatever mode create_model left the model in

# Expect one row of per-class scores for the single input image.
assert output.shape == (1, len(class_names)), f"Unexpected output shape {tuple(output.shape)}"
print(f"\nModel test - Input: {dummy_input.shape}, Output: {output.shape}")

In [None]:
# Bundle the model, data loaders, class names and device into the project's
# Trainer; the next cell calls trainer.train() to run the actual training.
trainer = Trainer(model, train_loader, val_loader, class_names, config['device'])

print("Trainer initialized successfully!\nReady to start training...")

In [None]:
# Start training
# Read these from the config cell (falling back to the original demo values) so
# the training and evaluation cells always agree on the checkpoint location.
checkpoint_path = config.get('checkpoint_path', '../models/crop_disease_resnet50.pth')
fine_tune_epoch = config.get('fine_tune_epoch', 3)  # Start fine-tuning earlier for demo

# Ensure the checkpoint directory exists before the run. On a fresh checkout
# ../models may be missing, and saving the checkpoint would then fail after
# training time has already been spent.
# NOTE(review): Trainer may already create it; this is a harmless no-op if so.
Path(checkpoint_path).parent.mkdir(parents=True, exist_ok=True)

print("Starting training process...")

trained_model, history = trainer.train(
    num_epochs=config['num_epochs'],
    learning_rate=config['learning_rate'],
    checkpoint_path=checkpoint_path,
    fine_tune_epoch=fine_tune_epoch,
)

print("\nTraining completed!")

In [None]:
# Plot training results; create the output directory first so saving the
# figure cannot fail on a fresh checkout.
curves_path = Path('../outputs/training_curves.png')
curves_path.parent.mkdir(parents=True, exist_ok=True)
trainer.plot_training_curves(str(curves_path))

# Display training history, one line per epoch (epochs are numbered from 1)
print("Training History:")
epoch_metrics = zip(history['train_loss'], history['train_acc'],
                    history['val_loss'], history['val_acc'])
for epoch, (train_loss, train_acc, val_loss, val_acc) in enumerate(epoch_metrics, start=1):
    print(f"Epoch {epoch}: Train Loss: {train_loss:.4f}, "
          f"Train Acc: {train_acc:.4f}, "
          f"Val Loss: {val_loss:.4f}, "
          f"Val Acc: {val_acc:.4f}")

In [None]:
# Evaluate model: reload the checkpoint written during training and report test accuracy.
# Use the same data directory and checkpoint path as training, so editing the
# config cell cannot silently evaluate a different dataset or model file.
print("Evaluating trained model...")
results = evaluate_model(
    checkpoint_path=config.get('checkpoint_path', '../models/crop_disease_resnet50.pth'),
    data_dir=config['data_dir'],
    batch_size=config['batch_size']
)

print("\nModel evaluation completed!")
print(f"Final test accuracy: {results['metrics']['accuracy']:.4f}")

## Training Complete!

The ResNet50 model has been successfully trained for crop disease detection. The model checkpoint has been saved to `../models/crop_disease_resnet50.pth`.

### Next Steps:
1. Implement knowledge base (Step 5)
2. Add Grad-CAM visualization (Step 6)
3. Build FastAPI backend (Step 8)

### Files Generated (paths relative to the project root, one level above this notebook):
- Model checkpoint: `models/crop_disease_resnet50.pth`
- Training curves: `outputs/training_curves.png`
- Evaluation results: `outputs/results.json`
- Confusion matrix: `outputs/confusion_matrix.png`