# In [None]:
import requests
import os

import zipfile

# The direct download link for the dataset
url = "https://prod-dcd-datasets-cache-zipfiles.s3.eu-west-1.amazonaws.com/bwh3zbpkpv-1.zip"

# Define the local filename to save the downloaded file
# You can change this if you want a different name for the zip file
local_filename = "dataset.zip"

# The archive is streamed into a temporary ".part" file and only renamed to
# `local_filename` once the transfer has completed, so an interrupted download
# can never leave (or overwrite with) a truncated archive.
partial_filename = local_filename + ".part"

# (connect, read) timeouts in seconds. Without a timeout, requests waits
# indefinitely on a stalled connection.
download_timeout = (10, 60)

# --- Option 1: Using requests library (Recommended) ---
print(f"Attempting to download {url} using requests...")
try:
    with requests.get(url, stream=True, timeout=download_timeout) as r:
        r.raise_for_status()  # Raise an HTTPError for bad responses (4xx or 5xx)
        with open(partial_filename, 'wb') as f:
            for chunk in r.iter_content(chunk_size=8192):
                f.write(chunk)
    # Atomic on the same filesystem: the final name only ever holds a complete file.
    os.replace(partial_filename, local_filename)
    print(f"Successfully downloaded {local_filename}")

except requests.exceptions.RequestException as e:
    print(f"Error during download: {e}")
    print("Please check the URL and your internet connection.")
finally:
    # Discard an incomplete transfer (nothing to do after a successful rename).
    if os.path.exists(partial_filename):
        os.remove(partial_filename)

# --- Unzipping the downloaded file ---
# If the download above failed, this still extracts an archive left by an
# earlier successful run, or reports that none exists.
print(f"Attempting to unzip {local_filename}...")
try:
    with zipfile.ZipFile(local_filename, 'r') as zip_ref:
        zip_ref.extractall("./")  # Extract to the current directory
    print("Successfully unzipped the dataset.")

    # Optional: List the contents of the current directory to see the extracted files
    print("\nContents of the current directory after extraction:")
    for item in os.listdir("./"):
        print(item)

except FileNotFoundError:
    print(f"Error: {local_filename} not found. Download might have failed.")
except zipfile.BadZipFile:
    print(f"Error: {local_filename} is not a valid zip file or is corrupted.")
except Exception as e:
    print(f"An unexpected error occurred during unzipping: {e}")

# Optional: Clean up the downloaded zip file after extraction
# if os.path.exists(local_filename):
#     os.remove(local_filename)
#     print(f"Cleaned up {local_filename}")

# In [None]:
import os

# Remove the downloaded archive (typically once it has been extracted).
zip_filename = "dataset.zip"

print(f"Attempting to delete {zip_filename}...")

try:
    if not os.path.exists(zip_filename):
        print(f"Warning: {zip_filename} not found, nothing to delete.")
    else:
        os.remove(zip_filename)
        print(f"Successfully deleted {zip_filename}.")
except OSError as err:
    # e.g. file locked by another process or insufficient permissions
    print(f"Error deleting {zip_filename}: {err}")
    print("Please check if the file is in use or if you have the necessary permissions.")

# In [None]:
"""
AgriSentry AI Hybrid Training Script - IMPROVED VERSION
Fixes overfitting issues and adds proper early stopping
Optimized for H100 GPU with torch.compile, AMP, and increased data throughput.
"""

import warnings
warnings.filterwarnings('ignore')

import os
import sys
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
from pathlib import Path
import re
from PIL import Image
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder

# PyTorch imports
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import colorama
from colorama import Fore, Style

print("✅ All imports successful!")

# =============================================================================
# DATASET AND PERFORMANCE CONFIGURATION
# =============================================================================

# Dataset configuration based on CCMT dataset structure
CROPS = ['Cashew', 'Cassava', 'Maize', 'Tomato']
DISEASES = {
    'Cashew': ['anthracnose', 'gumosis', 'healthy', 'leaf miner', 'red rust'],
    'Cassava': ['bacterial blight', 'brown spot', 'green mite', 'healthy', 'mosaic'],
    'Maize': ['fall armyworm', 'grasshoper', 'healthy', 'leaf beetle', 'leaf blight', 'leaf spot', 'streak virus'],
    'Tomato': ['healthy', 'leaf blight', 'leaf curl', 'septoria leaf spot', 'verticulium wilt']
}

# 🚀 PERFORMANCE-TUNED PARAMETERS
BATCH_SIZE = 512
NUM_WORKERS = 8
IMG_SIZE = (224, 224)

print(f"📊 Dataset Configuration:")
print(f"   Crops: {len(CROPS)}")
print(f"   Total Disease Classes: {sum(len(diseases) for diseases in DISEASES.values())}")
print(f"   Image Size: {IMG_SIZE}")
print(f"   Batch Size: {BATCH_SIZE}")
print(f"   Num Workers: {NUM_WORKERS}")

# =============================================================================
# UTILITY FUNCTIONS
# =============================================================================

def clean_disease_name(disease_folder_name):
    """Drop a trailing run of digits from a folder name, then trim whitespace.

    Example: ``"leaf miner3"`` -> ``"leaf miner"``.
    """
    return re.sub(r'\d+$', '', disease_folder_name).strip()

def get_image_files(directory, extensions=(".jpg", ".jpeg")):
    """Return the image files directly inside ``directory`` (non-recursive).

    Suffix matching is case-insensitive, so ``.jpg``, ``.JPG`` and ``.Jpg`` all
    match. Each file is returned exactly once. Separate per-case globs such as
    ``*.jpg`` plus ``*.JPG`` would return every file twice on case-insensitive
    filesystems (Windows, default macOS), silently duplicating samples.

    Args:
        directory: folder to scan (``Path`` or ``str``).
        extensions: accepted file suffixes, with or without the leading dot.

    Returns:
        Sorted list of ``Path`` objects for regular files with a matching suffix.
    """
    wanted = {
        (ext if ext.startswith(".") else "." + ext).lower() for ext in extensions
    }
    return sorted(
        entry
        for entry in Path(directory).iterdir()
        if entry.is_file() and entry.suffix.lower() in wanted
    )

# =============================================================================
# PYTORCH DATASET CLASS
# =============================================================================

class AgriSentryDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs from image file paths.

    Args:
        filepaths: sequence of image paths, parallel to ``labels``.
        labels: sequence of class indices, one per path.
        transform: optional callable applied to the RGB PIL image.

    Unreadable files are skipped. ``__getitem__`` moves forward to the next index
    (wrapping around) and returns that sample together with its own label.
    """
    def __init__(self, filepaths, labels, transform=None):
        self.filepaths = filepaths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.filepaths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx``, or for the next loadable sample.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
            RuntimeError: if no file in the dataset can be loaded.
        """
        n = len(self)
        # Keep IndexError for out-of-range indices; Python's legacy
        # __getitem__-based iteration relies on it to terminate.
        if not -n <= idx < n:
            raise IndexError(f"index {idx} out of range for dataset of size {n}")

        # Iterate rather than recurse. A long run of corrupt files then cannot hit
        # the recursion limit, and an all-corrupt dataset fails with a clear error.
        for offset in range(n):
            i = (idx + offset) % n
            filepath = self.filepaths[i]
            try:
                # The context manager closes the file handle promptly;
                # convert() returns an independent, loaded image.
                with Image.open(filepath) as img:
                    image = img.convert('RGB')
            except Exception as e:
                print(f"Warning: Could not load image {filepath}. Skipping. Error: {e}")
                continue

            if self.transform:
                image = self.transform(image)
            return image, self.labels[i]

        raise RuntimeError(f"Could not load any image in the dataset (requested index {idx}).")

# =============================================================================
# DEVICE UTILITIES
# =============================================================================

def get_default_device():
    """Return ``torch.device('cuda')`` if CUDA is available, else ``torch.device('cpu')``."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")

def to_device(data, device):
    """Recursively move ``data`` to ``device``.

    ``data`` is either an object with a ``.to`` method (a tensor or ``nn.Module``)
    or a possibly nested list/tuple of such objects. Lists and tuples are
    returned as lists. Transfers are requested with ``non_blocking=True``.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking=True)
    return [to_device(element, device) for element in data]

class DeviceDataLoader:
    """Iterable wrapper that moves every batch from ``dataloader`` onto ``device``.

    Batches are transferred lazily, one at a time, via ``to_device``.
    """

    def __init__(self, dataloader, device):
        self.dataloader, self.device = dataloader, device

    def __len__(self):
        """Number of batches in the wrapped loader."""
        return len(self.dataloader)

    def __iter__(self):
        """Yield each batch of the wrapped loader after moving it to ``self.device``."""
        yield from (to_device(batch, self.device) for batch in self.dataloader)

# =============================================================================
# IMPROVED CNN MODEL ARCHITECTURE WITH DROPOUT
# =============================================================================

def accuracy(outputs, labels):
    """Return the fraction of rows whose arg-max over dim 1 equals ``labels``.

    The result is a 0-dim float32 tensor on the same device as ``outputs``.
    Computing it on-device avoids a host-device synchronisation on every batch.
    An empty batch yields NaN.
    """
    preds = outputs.argmax(dim=1)
    return (preds == labels).float().mean()

class ImageClassificationBase(nn.Module):
    """Base module adding cross-entropy training/validation helpers.

    Subclasses implement ``forward(images)`` and return class logits along dim 1.
    """

    def training_step(self, batch):
        """Return the mean cross-entropy loss for one ``(images, labels)`` batch."""
        images, labels = batch
        return F.cross_entropy(self(images), labels)

    def validation_step(self, batch):
        """Return the detached loss, accuracy and size of one validation batch."""
        images, labels = batch
        out = self(images)
        loss = F.cross_entropy(out, labels)
        acc = accuracy(out, labels)
        # Batch size lets validation_epoch_end weight each batch correctly; the
        # final batch of an epoch is usually smaller than the rest.
        return {'val_loss': loss.detach(), 'val_acc': acc, 'batch_size': len(labels)}

    def validation_epoch_end(self, outputs):
        """Aggregate per-batch results into sample-weighted epoch metrics.

        Each batch is weighted by its ``'batch_size'`` entry, so the result equals
        the per-sample mean over the whole split. Entries without that key get
        weight 1 (a plain mean of batch values). ``outputs`` must be non-empty.

        Returns:
            dict with float ``'val_loss'`` and ``'val_acc'``.
        """
        weights = torch.tensor([float(x.get('batch_size', 1)) for x in outputs])
        # Stack on the source device, then make a single host copy per metric.
        losses = torch.stack([x['val_loss'] for x in outputs]).float().cpu()
        accs = torch.stack([x['val_acc'] for x in outputs]).float().cpu()
        total = weights.sum()
        return {
            'val_loss': ((losses * weights).sum() / total).item(),
            'val_acc': ((accs * weights).sum() / total).item(),
        }

    def epoch_end(self, epoch, result):
        """Print a one-line summary of an epoch's ``result`` dict."""
        print("Epoch [{}], train_loss: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}".format(
            epoch, result['train_loss'], result['val_loss'], result['val_acc']))

def ConvBlock(in_channels, out_channels, pool=False, dropout=0.0):
    """Build ``Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU`` as an ``nn.Sequential``.

    Args:
        in_channels: input channels of the convolution.
        out_channels: output channels of the convolution (and of the block).
        pool: if true, append ``MaxPool2d(4)``, which downsamples H and W by 4 (floor).
        dropout: if > 0, append ``Dropout2d(dropout)`` (placed before the pool).
    """
    stages = [
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    ]
    if dropout > 0:
        stages += [nn.Dropout2d(dropout)]
    if pool:
        stages += [nn.MaxPool2d(4)]
    return nn.Sequential(*stages)

class CNN_NeuralNet(ImageClassificationBase):
    """ResNet-style CNN classifier for AgriSentry AI, regularised with dropout.

    Args:
        in_channels: number of input image channels (3 is passed for RGB in ``main``).
        num_diseases: number of output classes (size of the final logit layer).
        dropout_rate: dropout before the first classifier ``Linear``; the second
            classifier dropout uses ``0.6 * dropout_rate``.
    """
    def __init__(self, in_channels, num_diseases, dropout_rate=0.5):
        super().__init__()

        # Stage 1: 64 -> 128 channels, with conv2 downsampling (pool=True),
        # then a shape-preserving residual pair at 128 channels.
        self.conv1 = ConvBlock(in_channels, 64, dropout=0.1)
        self.conv2 = ConvBlock(64, 128, pool=True, dropout=0.2)
        self.res1 = nn.Sequential(*(ConvBlock(128, 128, dropout=0.2) for _ in range(2)))

        # Stage 2: 128 -> 256 -> 512 channels, each step downsampling, then a
        # residual pair at 512 channels.
        self.conv3 = ConvBlock(128, 256, pool=True, dropout=0.3)
        self.conv4 = ConvBlock(256, 512, pool=True, dropout=0.3)
        self.res2 = nn.Sequential(*(ConvBlock(512, 512, dropout=0.3) for _ in range(2)))

        # Head: global average pool -> 512 -> 256 -> num_diseases logits.
        hidden_dropout = dropout_rate * 0.6  # slightly lighter dropout mid-head
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Dropout(dropout_rate),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Dropout(hidden_dropout),
            nn.Linear(256, num_diseases),
        )

    def forward(self, x):
        """Return class logits for a batch of images ``x``."""
        x = self.conv2(self.conv1(x))
        x = x + self.res1(x)  # residual connection
        x = self.conv4(self.conv3(x))
        x = x + self.res2(x)  # residual connection
        return self.classifier(x)

# =============================================================================
# DATASET PROCESSOR
# =============================================================================

class CCMTDatasetProcessor:
    """Processes the CCMT dataset for AgriSentry AI model training.

    Expected layout: ``<dataset_path>/<crop>/{train_set,test_set}/<disease folder>/``
    with JPEG images inside each disease folder. A trailing number on a disease
    folder name is stripped before matching against ``DISEASES``. Folders that
    still do not match a known class are skipped with a warning.

    Attributes:
        class_names: sorted ``"<crop>_<disease>"`` labels; a sample's integer
            label is its index in this list.
        num_classes: ``len(class_names)``.
    """
    def __init__(self, dataset_path):
        self.dataset_path = Path(dataset_path)
        if not self.dataset_path.exists():
            raise FileNotFoundError(f"Dataset directory not found at: {self.dataset_path}")
        self.crops = CROPS
        self.diseases = DISEASES
        self.img_size = IMG_SIZE
        self.batch_size = BATCH_SIZE
        self.class_names = self._create_class_mappings()
        self.num_classes = len(self.class_names)
        print(f"✅ Dataset processor initialized for {self.num_classes} classes.")

    def _create_class_mappings(self):
        """Return the sorted ``"<crop>_<disease>"`` names for every configured crop."""
        return sorted(f"{crop}_{disease}" for crop in self.crops for disease in self.diseases[crop])

    def _build_transforms(self):
        """Return ``(train_transform, val_transform)``.

        Training uses mild augmentation (small rotation, horizontal flip, light
        colour jitter, random resized crop); evaluation only resizes. Both end with
        ``ToTensor`` and the same per-channel normalisation.
        """
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        train_transform = transforms.Compose([
            transforms.Resize(self.img_size),
            transforms.RandomRotation(10),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05),
            # Pass the full (H, W) size so a non-square IMG_SIZE is honoured.
            transforms.RandomResizedCrop(self.img_size, scale=(0.8, 1.0)),
            transforms.ToTensor(),
            normalize,
        ])
        val_transform = transforms.Compose([
            transforms.Resize(self.img_size),
            transforms.ToTensor(),
            normalize,
        ])
        return train_transform, val_transform

    def _collect_samples(self):
        """Scan the dataset tree.

        Returns:
            ``{"train_set": (paths, labels), "test_set": (paths, labels)}`` with
            ``paths`` as strings and ``labels`` as indices into ``class_names``.
        """
        class_to_idx = {name: i for i, name in enumerate(self.class_names)}
        samples = {"train_set": ([], []), "test_set": ([], [])}
        skipped = []

        for crop in self.crops:
            for subset, (paths, labels) in samples.items():
                subset_path = self.dataset_path / crop / subset
                if not subset_path.exists():
                    continue

                for disease_folder in sorted(d for d in subset_path.iterdir() if d.is_dir()):
                    class_name = f"{crop}_{clean_disease_name(disease_folder.name)}"
                    class_idx = class_to_idx.get(class_name)
                    if class_idx is None:
                        skipped.append(str(disease_folder))
                        continue
                    for image_path in get_image_files(disease_folder):
                        paths.append(str(image_path))
                        labels.append(class_idx)

        if skipped:
            # Surfacing this matters: a naming mismatch would otherwise silently
            # drop an entire class from training.
            print(f"⚠️  Skipped {len(skipped)} folder(s) with unrecognised class names:")
            for folder in skipped:
                print(f"     {folder}")
        return samples

    def create_data_loaders(self, device):
        """Create device-aware DataLoaders for the training and test splits.

        Args:
            device: target device for batches (``torch.device`` or a string).

        Returns:
            ``(train_loader, test_loader)``, each a ``DeviceDataLoader``.

        Raises:
            ValueError: if either split contains no images.

        NOTE(review): the test split is what the training pipeline uses for early
        stopping and checkpoint selection, so metrics on it are optimistic rather
        than a held-out estimate. Consider carving a validation split out of
        train_set.
        """
        print("🔄 Creating Data Loaders...")
        train_transform, val_transform = self._build_transforms()

        samples = self._collect_samples()
        train_filepaths, train_labels = samples["train_set"]
        test_filepaths, test_labels = samples["test_set"]
        print(f"📊 Total images found: {len(train_filepaths)} train, {len(test_filepaths)} test")

        if not train_filepaths or not test_filepaths:
            raise ValueError(
                f"No images found for the {'training' if not train_filepaths else 'test'} split "
                f"under {self.dataset_path}; check the dataset layout and folder names."
            )

        train_dataset = AgriSentryDataset(train_filepaths, train_labels, transform=train_transform)
        test_dataset = AgriSentryDataset(test_filepaths, test_labels, transform=val_transform)

        # persistent_workers requires worker processes (num_workers > 0); pinned
        # host memory only helps when batches are copied to a CUDA device.
        loader_kwargs = dict(
            batch_size=self.batch_size,
            num_workers=NUM_WORKERS,
            pin_memory=torch.device(device).type == "cuda",
            persistent_workers=NUM_WORKERS > 0,
        )
        train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
        test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

        train_loader = DeviceDataLoader(train_loader, device)
        test_loader = DeviceDataLoader(test_loader, device)
        print(f"✅ DataLoaders created.")
        return train_loader, test_loader

# =============================================================================
# IMPROVED TRAINING FUNCTIONS WITH EARLY STOPPING
# =============================================================================

@torch.no_grad()
def evaluate(model, val_loader):
    """Run ``model`` over every batch of ``val_loader`` with autograd disabled.

    Puts the model in eval mode (it stays in eval mode on return). Calls
    ``model.validation_step`` per batch and returns the result of
    ``model.validation_epoch_end`` on the collected outputs.
    """
    model.eval()
    per_batch = []
    for batch in val_loader:
        per_batch.append(model.validation_step(batch))
    return model.validation_epoch_end(per_batch)

def get_lr(optimizer):
    """Return the learning rate of the optimizer's first param group (None if it has none)."""
    return next((group['lr'] for group in optimizer.param_groups), None)

def _train_one_epoch(model, train_loader, optimizer, scaler, grad_clip, use_amp):
    """Run one optimisation pass over ``train_loader``; return the mean batch loss (NaN if empty)."""
    model.train()
    batch_losses = []
    # With AMP disabled, a disabled CPU autocast context is a no-op on any device.
    amp_device = 'cuda' if use_amp else 'cpu'

    for batch in train_loader:
        with torch.autocast(device_type=amp_device, enabled=use_amp):
            loss = model.training_step(batch)
        # Log a detached value so the stored losses hold no autograd references.
        batch_losses.append(loss.detach())

        # Scale loss, step, and update
        scaler.scale(loss).backward()
        if grad_clip:
            # Unscale first so the clip threshold applies to the true gradients.
            scaler.unscale_(optimizer)
            nn.utils.clip_grad_value_(model.parameters(), grad_clip)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)

    if not batch_losses:
        return float('nan')
    return torch.stack(batch_losses).float().mean().item()


def fit_with_early_stopping(epochs, max_lr, model, train_loader, val_loader,
                            weight_decay=0, grad_clip=None, opt_func=torch.optim.Adam,
                            patience=7, min_delta=0.001, min_lr=1e-6,
                            checkpoint_path='best_model.pth'):
    """Train ``model`` with ReduceLROnPlateau scheduling and early stopping.

    Each epoch trains on ``train_loader``, evaluates on ``val_loader`` and steps
    the scheduler on validation loss. Whenever validation loss improves by more
    than ``min_delta``, a checkpoint is written to ``checkpoint_path``. The
    checkpoint contains ``epoch``, ``model_state_dict``, ``optimizer_state_dict``,
    ``val_loss`` and ``val_acc``. Training stops after ``patience`` epochs without
    such an improvement, or once the learning rate has decayed to ``min_lr``.

    Mixed precision (autocast + GradScaler) and a fused optimizer are used only
    when the model's parameters live on a CUDA device.

    Args:
        epochs: maximum number of epochs.
        max_lr: initial learning rate passed to ``opt_func``.
        model: an ``ImageClassificationBase``-style model (possibly compiled).
        train_loader, val_loader: iterables of device-resident batches.
        weight_decay: weight decay passed to ``opt_func``.
        grad_clip: if truthy, clip each gradient element to ``[-grad_clip, grad_clip]``.
        opt_func: optimizer class, called as ``opt_func(params, lr, weight_decay=..., [fused=True])``.
        patience: epochs without improvement tolerated before stopping.
        min_delta: minimum decrease in validation loss that counts as improvement.
        min_lr: floor for the LR scheduler; reaching it ends training.
        checkpoint_path: where the best checkpoint is saved.

    Returns:
        list of per-epoch dicts with ``val_loss``, ``val_acc``, ``train_loss`` and ``lr``.
    """
    torch.cuda.empty_cache()
    history = []
    best_val_loss = float('inf')
    best_val_acc = 0.0
    patience_counter = 0

    first_param = next(model.parameters(), None)
    use_amp = first_param is not None and first_param.device.type == 'cuda'

    # Fused optimizer kernels and float16 AMP with loss scaling are only
    # guaranteed on CUDA; enable them there so the loop also runs on CPU.
    opt_kwargs = {'weight_decay': weight_decay}
    if use_amp:
        opt_kwargs['fused'] = True
    optimizer = opt_func(model.parameters(), max_lr, **opt_kwargs)

    # ReduceLROnPlateau halves the LR after 3 stagnant epochs, never below min_lr.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=3, min_lr=min_lr
    )
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    print(f"🚀 Starting training with patience={patience}, min_delta={min_delta}")
    print(f"📊 Initial learning rate: {max_lr}")

    for epoch in range(epochs):
        train_loss = _train_one_epoch(model, train_loader, optimizer, scaler, grad_clip, use_amp)

        # Validation phase
        result = evaluate(model, val_loader)
        result['train_loss'] = train_loss
        result['lr'] = get_lr(optimizer)  # LR used during this epoch

        # Learning rate scheduling
        scheduler.step(result['val_loss'])

        # Display progress
        model.epoch_end(epoch, result)
        history.append(result)

        # Early stopping logic (a NaN val_loss never counts as an improvement)
        if result['val_loss'] < (best_val_loss - min_delta):
            best_val_loss = result['val_loss']
            best_val_acc = result['val_acc']
            patience_counter = 0

            # Save best model
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': result['val_loss'],
                'val_acc': result['val_acc'],
            }, checkpoint_path)
            print(f"   ✅ New best model saved! Val Loss: {best_val_loss:.4f}, Val Acc: {best_val_acc:.4f}")
        else:
            patience_counter += 1

        print(f"   📊 Patience: {patience_counter}/{patience}, Best Val Acc: {best_val_acc:.4f}")

        # Early stopping check
        if patience_counter >= patience:
            print(f"\n🛑 Early stopping triggered after {epoch + 1} epochs")
            print(f"   Best validation loss: {best_val_loss:.4f}")
            print(f"   Best validation accuracy: {best_val_acc:.4f}")
            break

        # The scheduler clamps the LR at exactly min_lr, so compare with <=;
        # a strict < could never fire.
        if get_lr(optimizer) <= min_lr:
            print(f"\n🛑 Learning rate too small, stopping training")
            break

    return history

def plot_training_history(history, output_dir):
    """Plot training curves and save them as ``training_history_improved.png``.

    Draws a 2x2 grid:
    - training vs. validation loss
    - validation accuracy
    - learning rate (log scale)
    - the validation-minus-training loss gap

    Missing ``train_loss`` / ``lr`` entries are plotted as 0.

    Args:
        history: list of per-epoch dicts with ``val_acc`` and ``val_loss`` keys,
            optionally ``train_loss`` and ``lr``.
        output_dir: directory for the PNG; created if it does not exist.
    """
    if not history:
        print("No history to plot")
        return

    epochs = range(len(history))
    val_acc = [x['val_acc'] for x in history]
    val_loss = [x['val_loss'] for x in history]
    train_loss = [x.get('train_loss', 0) for x in history]
    learning_rates = [x.get('lr', 0) for x in history]

    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))

    # Loss plot
    ax1.plot(epochs, train_loss, 'r--', label='Training Loss', linewidth=2)
    ax1.plot(epochs, val_loss, 'g-', label='Validation Loss', linewidth=2)
    ax1.set_title('Training and Validation Loss', fontsize=14)
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Accuracy plot
    ax2.plot(epochs, val_acc, 'b-', label='Validation Accuracy', linewidth=2)
    ax2.set_title('Validation Accuracy', fontsize=14)
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy')
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    # Learning rate plot
    ax3.plot(epochs, learning_rates, 'orange', label='Learning Rate', linewidth=2)
    ax3.set_title('Learning Rate Schedule', fontsize=14)
    ax3.set_xlabel('Epoch')
    ax3.set_ylabel('Learning Rate')
    ax3.set_yscale('log')
    ax3.legend()
    ax3.grid(True, alpha=0.3)

    # Loss difference plot: a growing positive gap suggests overfitting
    loss_diff = [val - train for val, train in zip(val_loss, train_loss)]
    ax4.plot(epochs, loss_diff, 'purple', label='Val Loss - Train Loss', linewidth=2)
    ax4.axhline(y=0, color='black', linestyle='--', alpha=0.5)
    ax4.set_title('Overfitting Monitor (Val - Train Loss)', fontsize=14)
    ax4.set_xlabel('Epoch')
    ax4.set_ylabel('Loss Difference')
    ax4.legend()
    ax4.grid(True, alpha=0.3)

    fig.tight_layout()
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    save_path = out_dir / 'training_history_improved.png'
    fig.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()
    # pyplot keeps every figure alive until it is closed; release this one so
    # repeated calls (or a non-interactive backend) do not accumulate figures.
    plt.close(fig)
    print(f"✅ Training history saved to: {save_path}")

# =============================================================================
# MAIN TRAINING PIPELINE
# =============================================================================

def main():
    """Run the end-to-end training pipeline.

    Steps:
    1. Build DataLoaders from the CCMT dataset.
    2. Build and ``torch.compile`` the CNN.
    3. Train with early stopping.
    4. Reload the best checkpoint written during this run.
    5. Save a self-describing final checkpoint into ``OUTPUT_DIR``.
    6. Plot the training history.

    Returns:
        ``(model, history, processor)``: the compiled model holding the best
        weights, the per-epoch history list, and the dataset processor.
    """
    print("🚀 AgriSentry AI IMPROVED Training Pipeline")
    print("=" * 60)

    # --- Configuration ---
    DATASET_PATH = "Dataset for Crop Pest and Disease Detection/CCMT Dataset-Augmented"
    OUTPUT_DIR = "agrisentry_pytorch_output_improved"
    BEST_CHECKPOINT = Path('best_model.pth')  # written by fit_with_early_stopping
    DROPOUT_RATE = 0.5
    Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)

    device = get_default_device()
    print(f"🖥️  Using device: {device}")

    # --- Data Processing ---
    processor = CCMTDatasetProcessor(DATASET_PATH)
    train_loader, test_loader = processor.create_data_loaders(device)

    # --- Model Building and Compilation ---
    print("\n🏗️  Building and Compiling Model...")
    model = to_device(CNN_NeuralNet(3, processor.num_classes, dropout_rate=DROPOUT_RATE), device)
    print(f"   Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Compile the model for speedups
    print("   Compiling model with torch.compile()...")
    model = torch.compile(model)
    print("✅ Model compiled!")

    # --- IMPROVED Training Parameters ---
    print("\n🎯 Training Configuration:")
    num_epochs = 100
    max_lr = 0.003  # REDUCED from 0.01
    grad_clip = 0.1
    weight_decay = 1e-3  # INCREASED from 1e-4
    patience = 10  # INCREASED for better training
    min_delta = 0.001  # Minimum improvement threshold

    print(f"   Max epochs: {num_epochs}")
    print(f"   Learning rate: {max_lr}")
    print(f"   Weight decay: {weight_decay}")
    print(f"   Patience: {patience}")
    print(f"   Min delta: {min_delta}")

    # Remember the state of any pre-existing checkpoint so a stale file from an
    # earlier run is not mistaken for this run's best model. This can happen if
    # no epoch improved, e.g. NaN validation loss.
    stale_mtime = BEST_CHECKPOINT.stat().st_mtime_ns if BEST_CHECKPOINT.exists() else None

    # --- Training ---
    print("\n🚀 Starting Training...")
    history = fit_with_early_stopping(
        num_epochs, max_lr, model, train_loader, test_loader,
        grad_clip=grad_clip, weight_decay=weight_decay,
        opt_func=torch.optim.Adam, patience=patience, min_delta=min_delta
    )

    # --- Load Best Model ---
    print("\n📥 Loading best model...")
    if BEST_CHECKPOINT.exists() and BEST_CHECKPOINT.stat().st_mtime_ns != stale_mtime:
        checkpoint = torch.load(BEST_CHECKPOINT, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        print(f"✅ Best model loaded from epoch {checkpoint['epoch']}")
        print(f"   Final validation accuracy: {checkpoint['val_acc']:.4f}")
        print(f"   Final validation loss: {checkpoint['val_loss']:.4f}")
    else:
        print("⚠️  No checkpoint was written during this run; keeping the final-epoch weights.")

    # --- Saving Final Model ---
    print("\n💾 Saving Final Model...")
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    model_path = Path(OUTPUT_DIR) / f"agrisentry_model_improved_{timestamp}.pth"
    # torch.compile wraps the network, and the wrapper's state_dict keys carry an
    # "_orig_mod." prefix. Save the underlying module's weights instead, so the
    # file loads directly into a plain CNN_NeuralNet for inference.
    base_model = getattr(model, '_orig_mod', model)
    torch.save({
        'model_state_dict': base_model.state_dict(),
        'class_names': processor.class_names,
        'num_classes': processor.num_classes,
        'history': history,
        'hyperparameters': {
            'max_lr': max_lr,
            'weight_decay': weight_decay,
            'patience': patience,
            'batch_size': BATCH_SIZE,
            'dropout_rate': DROPOUT_RATE,
            'img_size': IMG_SIZE,
        }
    }, model_path)
    print(f"✅ Final model saved to: {model_path}")

    # --- Plotting Results ---
    print("\n📊 Plotting Results...")
    plot_training_history(history, OUTPUT_DIR)

    # --- Final Summary ---
    if history:
        # Note: "best" here is by validation accuracy, whereas the checkpoint
        # above was selected by validation loss.
        best_epoch = max(range(len(history)), key=lambda i: history[i]['val_acc'])
        best_result = history[best_epoch]
        print(f"\n🎉 Training Complete!")
        print(f"   Total epochs: {len(history)}")
        print(f"   Best epoch: {best_epoch}")
        print(f"   Best validation accuracy: {best_result['val_acc']:.4f}")
        print(f"   Best validation loss: {best_result['val_loss']:.4f}")
        print(f"   Final learning rate: {history[-1].get('lr', 'N/A')}")

    return model, history, processor

if __name__ == "__main__":
    # Keep the trained model, its history and the processor as module-level
    # names so they remain inspectable after the run (e.g. in a notebook).
    pipeline_outputs = main()
    model, history, processor = pipeline_outputs