# In [None]:
import os
import json
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.utils.data import Dataset, DataLoader
import timm
import numpy as np
from torchvision import transforms
from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau
from PIL import Image
import random
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    roc_auc_score
)
import time
from tqdm import tqdm
import logging
from pathlib import Path
from datetime import datetime
import gc
from collections import Counter
import matplotlib.pyplot as plt
import seaborn as sns
from torch.amp import autocast, GradScaler

# psutil is an optional dependency: record whether the import succeeded so
# other code can check PSUTIL_AVAILABLE before using the module.
try:
    import psutil
    PSUTIL_AVAILABLE = True
except ImportError:
    PSUTIL_AVAILABLE = False

# Logging Configuration
# Root logger: INFO level, timestamped records, emitted through a single
# stream (console) handler.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler()
    ]
)
# Named logger used by the top-level training code in this file.
logger = logging.getLogger('DLTraining')

# Global Config Variables
# Filesystem layout: images are read from DATA_DIR, every artefact (models,
# results, visualizations, logs) is written below OUTPUT_DIR.
BASE_PATH = Path(r"D:\iate_project")
DATA_DIR = BASE_PATH / "original_dataset"
OUTPUT_DIR = BASE_PATH / "output" / "dl_training"

# DataLoader / training-loop sizing.
BATCH_SIZE = 32
NUM_WORKERS = 0  # Windows compatibility
NUM_EPOCHS = 30  # upper bound per fold; early stopping may end a fold sooner
NUM_FOLDS = 5    # number of StratifiedKFold splits

USE_AMP = True     # enables autocast + GradScaler in the training loops
PIN_MEMORY = True  # forwarded to every DataLoader
DROP_RATE = 0.2  # Reduced further for lightweight models

EARLY_STOPPING_PATIENCE = 5  # epochs without sufficient val-loss improvement
GRAD_CLIP_VALUE = 1.0        # max gradient norm passed to clip_grad_norm_

# Updated model definitions - Focusing on lightweight, high-performance models
# Each key is passed verbatim to timm.create_model() as the architecture name.
# Every entry supplies the 'pretrained' flag, the AdamW learning rate ('lr')
# and a human-readable 'description'.
MODELS = {
    # Very lightweight models
    'mobilenetv3_small_100': {
        'pretrained': True,
        'lr': 5e-5,
        'description': 'Ultra lightweight MobileNetV3 Small variant'
    },
    'efficientnet_lite0': {
        'pretrained': True,
        'lr': 5e-5,
        'description': 'Lightweight EfficientNet variant optimized for mobile'
    },
    'mobilevit_xxs': {
        'pretrained': True,
        'lr': 5e-5,
        'description': 'Extra small MobileViT - combines mobility and transformer benefits'
    },
    # NOTE(review): confirm this name is registered in the installed timm
    # version; a failing create_model() call is logged and the model skipped.
    'shufflenet_v2_x1_0': {
        'pretrained': True,
        'lr': 5e-5,
        'description': 'ShuffleNetV2 with 1.0x output channels - very efficient'
    },
    # Slightly larger but still efficient models
    'efficientnet_b0': {
        'pretrained': True,
        'lr': 5e-5,
        'description': 'Smallest EfficientNet variant with good accuracy-size tradeoff'
    },
    'mobilevit_xs': {
        'pretrained': True,
        'lr': 5e-5,
        'description': 'Extra small MobileViT - larger than xxs but still efficient'
    }
}

# Device Setup
# Use CUDA when a GPU is visible, otherwise fall back to the CPU. DEVICE is
# the single target every model and batch is moved to.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
    logger.info(f"CUDA is available! Using {torch.cuda.get_device_name(0)}")
    logger.info(f"CUDA Version: {torch.version.cuda}")
    # Enable cuDNN algorithm auto-tuning and TF32 math for matmul/cuDNN.
    # NOTE: set_random_seed() assigns cudnn.benchmark again when it is called.
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
else:
    DEVICE = torch.device("cpu")
    logger.info("CUDA is not available. Using CPU.")

# Set random seed
def set_random_seed(seed=42):
    """Seed every random number generator used by this script.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU and, when CUDA is
    available, all CUDA devices), then switches cuDNN into deterministic mode.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    cudnn.deterministic = True
    # Benchmark mode auto-tunes convolution algorithms at runtime and may pick
    # a different one on each run, which defeats the determinism requested on
    # the line above; it must be off for reproducible results.
    cudnn.benchmark = False

# Seed the global RNGs before any dataset or model is created.
set_random_seed(42)

# Create output directories if necessary
# (exist_ok=True makes re-running this cell harmless).
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR / 'logs', exist_ok=True)
os.makedirs(OUTPUT_DIR / 'models', exist_ok=True)
os.makedirs(OUTPUT_DIR / 'results', exist_ok=True)
os.makedirs(OUTPUT_DIR / 'visualizations', exist_ok=True)

# Enhanced CoffeeDataset class to handle the specified directory structure
class CoffeeDataset(Dataset):
    """Binary coffee-bean image dataset (label 0 = Normal, label 1 = Defect).

    Images are discovered under
    ``<root_dir>/<Normal|Defect>/<method>/<roast>/<image>`` where ``method`` is
    one of dry/honey/wet and ``roast`` one of dark/light/medium (directory
    names are matched case-insensitively). Only ``.png``/``.jpg``/``.jpeg``
    files are used.

    The discovered files are partitioned into train/val/test with a shuffle
    that depends only on the set of file paths and ``seed``. Separate
    instances built with the same arguments therefore see the same,
    mutually disjoint partitions.

    Args:
        root_dir: Directory containing the ``Normal`` and ``Defect`` folders.
        split: ``'train'`` or ``'val'``; any other value selects the test
            partition. Only ``'train'`` gets augmenting transforms.
        test_ratio: Fraction of all images assigned to the test partition.
        val_ratio: Fraction of all images assigned to the validation partition.
        seed: Seed of the private RNG used for the partition shuffle.
    """

    METHOD_DIRS = ('dry', 'honey', 'wet')
    ROAST_DIRS = ('dark', 'light', 'medium')
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root_dir, split='train', test_ratio=0.15, val_ratio=0.15, seed=42):
        self.root_dir = root_dir
        self.split = split
        self.logger = logging.getLogger('CoffeeDataset')
        self.samples = []
        self.max_retries = 3

        # Load all samples from the specified directory structure
        normal_path = self._resolve_class_dir("Normal")
        defect_path = self._resolve_class_dir("Defect")

        self.logger.info(f"Loading samples from Normal: {normal_path}")
        self.logger.info(f"Loading samples from Defect: {defect_path}")

        labelled_paths = [(path, 0) for path in self._collect_images(normal_path)]
        labelled_paths += [(path, 1) for path in self._collect_images(defect_path)]

        # Split the dataset into train, validation, and test sets
        partitions = self._split_paths(
            [path for path, _ in labelled_paths], test_ratio, val_ratio, seed
        )
        # Any split name other than 'train'/'val' selects the test partition.
        target_paths = partitions.get(split, partitions['test'])

        self.samples = [
            {'filepath': path, 'label': label}
            for path, label in labelled_paths
            if path in target_paths
        ]

        self.logger.info(f"Loaded {len(self.samples)} samples for split: {self.split}")

        # Analyze class distribution
        class_counts = Counter(sample['label'] for sample in self.samples)
        total = len(self.samples)

        self.logger.info(f"\n{self.split} set class distribution:")
        for label, count in class_counts.items():
            percentage = (count / total) * 100
            self.logger.info(f"Class {label}: {count} samples ({percentage:.2f}%)")

        self.transform = self._build_transform()

    def _resolve_class_dir(self, name):
        """Return ``root_dir/name``, falling back to the lower-cased name."""
        path = os.path.join(self.root_dir, name)
        if not os.path.exists(path):
            path = os.path.join(self.root_dir, name.lower())
        return path

    @staticmethod
    def _matching_subdirs(parent, wanted):
        """Return sorted sub-directories of ``parent`` whose lower-cased name is in ``wanted``.

        The directory is listed once and the on-disk names are compared, so
        each directory is returned exactly once. Probing several spellings
        with ``os.path.exists`` would instead report the same directory under
        every spelling on a case-insensitive filesystem.
        """
        if not os.path.isdir(parent):
            return []
        return sorted(
            os.path.join(parent, entry)
            for entry in os.listdir(parent)
            if entry.lower() in wanted and os.path.isdir(os.path.join(parent, entry))
        )

    @classmethod
    def _collect_images(cls, class_dir):
        """Return the sorted image paths below ``class_dir/<method>/<roast>/``."""
        paths = []
        for method_dir in cls._matching_subdirs(class_dir, cls.METHOD_DIRS):
            for roast_dir in cls._matching_subdirs(method_dir, cls.ROAST_DIRS):
                for img_name in sorted(os.listdir(roast_dir)):
                    if img_name.lower().endswith(cls.IMAGE_EXTENSIONS):
                        paths.append(os.path.join(roast_dir, img_name))
        return paths

    @staticmethod
    def _split_paths(paths, test_ratio, val_ratio, seed):
        """Partition ``paths`` into disjoint ``'train'``/``'val'``/``'test'`` sets.

        Paths are de-duplicated and sorted before being shuffled with a
        private ``random.Random(seed)``. The result therefore depends only on
        the set of paths and the seed, not on input order, set iteration
        order or the state of the global RNG.
        """
        ordered = sorted(set(paths))
        random.Random(seed).shuffle(ordered)

        test_size = int(len(ordered) * test_ratio)
        val_size = int(len(ordered) * val_ratio)

        return {
            'test': set(ordered[:test_size]),
            'val': set(ordered[test_size:test_size + val_size]),
            'train': set(ordered[test_size + val_size:]),
        }

    def _build_transform(self):
        """Return the augmenting pipeline for 'train', a plain resize/crop otherwise."""
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
        # Enhanced but lightweight data augmentation for training (optimized for speed)
        if self.split == 'train':
            return transforms.Compose([
                transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomVerticalFlip(p=0.3),
                transforms.RandomRotation(
                    degrees=30,
                    fill=1.0,
                    interpolation=transforms.InterpolationMode.BILINEAR
                ),
                transforms.ColorJitter(
                    brightness=0.2, contrast=0.2,
                    saturation=0.2, hue=0.1
                ),
                transforms.ToTensor(),
                normalize
            ])
        # Validation and test transforms - keep simple for efficiency
        return transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize
        ])

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image_tensor, label)`` for sample ``idx``.

        If the image cannot be loaded or transformed, the next sample (with
        wrap-around) is tried instead, up to ``max_retries`` samples in total.

        Raises:
            RuntimeError: If ``max_retries`` consecutive samples fail. A
                placeholder with an out-of-range label is deliberately not
                returned, because it would only fail later inside the loss.
        """
        first_idx = idx
        last_error = None

        for attempt in range(self.max_retries):
            sample = self.samples[idx]
            image_path = sample['filepath']
            try:
                # convert() returns a new image, so the file handle can be closed.
                with Image.open(image_path) as raw_image:
                    image = raw_image.convert('RGB')
                return self.transform(image), sample['label']
            except Exception as e:
                last_error = e
                self.logger.warning(
                    f"Failed to load/transform image {image_path} on attempt {attempt+1}: {str(e)}"
                )
                # Actually move on to the neighbouring sample for the next try.
                idx = (idx + 1) % len(self)

        self.logger.error(
            f"Failed to load a sample starting at index {first_idx} "
            f"after {self.max_retries} attempts."
        )
        raise RuntimeError(
            f"Could not load any of {self.max_retries} consecutive samples "
            f"starting at index {first_idx}"
        ) from last_error

# Create Datasets
logger.info("Creating train/val/test datasets...")

# Create separate instances for train, val, test
# (each instance scans DATA_DIR itself and keeps only its own split).
train_dataset = CoffeeDataset(str(DATA_DIR), split='train')
val_dataset = CoffeeDataset(str(DATA_DIR), split='val')
test_dataset = CoffeeDataset(str(DATA_DIR), split='test')

# Create Data Loaders
# Training batches are reshuffled every epoch and the final partial batch is
# dropped; validation/test keep a fixed order and every sample.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
    drop_last=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY
)

test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY
)

logger.info("Dataset sizes:")
logger.info(f"Train: {len(train_dataset)} samples")
logger.info(f"Validation: {len(val_dataset)} samples")
logger.info(f"Test: {len(test_dataset)} samples")

# Early stopping parameters
# Read as globals by train_model_with_kfold: training of a fold stops after
# `early_stopping_patience` epochs whose validation loss did not improve on
# the best one by more than `early_stopping_min_delta`.
early_stopping_patience = EARLY_STOPPING_PATIENCE
early_stopping_min_delta = 0.001

# Metric calculation function
def calculate_metrics_inline(targets, predictions, probabilities=None):
    """Compute binary-classification metrics for labels 0 (Normal) and 1 (Defect).

    Args:
        targets: Sequence of ground-truth labels.
        predictions: Sequence of predicted labels, same length as ``targets``.
        probabilities: Optional sequence of positive-class scores, same length
            as ``targets``; enables the ``roc_auc`` entry.

    Returns:
        dict with ``accuracy``, ``precision``, ``recall``, ``f1_score``,
        ``specificity`` and ``confusion_matrix`` (a 2x2 nested list), plus
        ``roc_auc`` when ``probabilities`` is given. If anything fails, the
        error is logged and a dict of zeroed metrics with an extra ``error``
        message is returned instead of raising.
    """
    try:
        if len(targets) != len(predictions):
            raise ValueError("Length of targets and predictions must match")
        if probabilities is not None and len(targets) != len(probabilities):
            raise ValueError("Length of targets and probabilities must match")

        metrics = {
            'accuracy': accuracy_score(targets, predictions),
            'precision': precision_score(targets, predictions, average='binary', zero_division=0),
            'recall': recall_score(targets, predictions, average='binary', zero_division=0),
            'f1_score': f1_score(targets, predictions, average='binary', zero_division=0)
        }

        if probabilities is not None:
            try:
                metrics['roc_auc'] = roc_auc_score(targets, probabilities)
            except ValueError as e:
                # Raised e.g. when only one class is present in `targets`.
                logger.warning(f"ROC-AUC calculation failed: {str(e)}")
                metrics['roc_auc'] = 0.0

        # Pin the label set so the matrix is always 2x2. Without it, input
        # containing a single class yields a 1x1 matrix, and 'specificity'
        # and 'confusion_matrix' would silently be missing from the result.
        cm = confusion_matrix(targets, predictions, labels=[0, 1])
        tn, fp = cm[0]
        metrics.update({
            'specificity': tn / (tn + fp) if (tn + fp) > 0 else 0,
            'confusion_matrix': cm.tolist()
        })

        return metrics
    except Exception as e:
        logger.error(f"Error in metrics calculation: {str(e)}")
        fallback = {
            'accuracy': 0.0,
            'precision': 0.0,
            'recall': 0.0,
            'f1_score': 0.0,
            'confusion_matrix': [[0, 0], [0, 0]],
            'specificity': 0.0,
            'error': str(e)
        }
        # Keep the key set consistent with the success path.
        if probabilities is not None:
            fallback['roc_auc'] = 0.0
        return fallback

# Enhanced model metrics calculation
def compute_model_metrics_inline(model):
    """Report how large ``model`` is.

    Returns:
        dict with ``param_count`` (total number of parameter elements) and
        ``size_mb`` (bytes occupied by all parameters and buffers, divided by
        1024**2).
    """
    def _storage_bytes(tensors):
        # Element count times bytes-per-element, summed over the tensors.
        return sum(t.nelement() * t.element_size() for t in tensors)

    total_bytes = _storage_bytes(model.parameters()) + _storage_bytes(model.buffers())

    return {
        'param_count': sum(p.numel() for p in model.parameters()),
        'size_mb': total_bytes / 1024**2
    }

# Plot confusion matrix
def plot_confusion_matrix_inline(cm, output_dir, model_name):
    """Draw ``cm`` as an annotated heatmap and save it as a PNG.

    The image is written to ``<output_dir>/<model_name>_confusion_matrix.png``;
    ``output_dir`` is created if it does not exist.
    """
    class_names = ['Normal', 'Defect']
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(
        cm,
        annot=True,
        fmt='d',
        cmap='Blues',
        xticklabels=class_names,
        yticklabels=class_names,
        ax=ax,
    )
    ax.set_title(f'Confusion Matrix - {model_name}')
    ax.set_ylabel('True Label')
    ax.set_xlabel('Predicted Label')

    os.makedirs(output_dir, exist_ok=True)
    target_file = os.path.join(output_dir, f'{model_name}_confusion_matrix.png')
    fig.savefig(target_file)
    plt.close(fig)

# Plot training curves
def plot_training_curves_inline(history, output_dir, model_name):
    """Save a 2x2 figure of loss, accuracy, learning rate and epoch time.

    ``history`` must provide the lists ``train_loss``, ``val_loss``,
    ``train_acc``, ``val_acc``, ``lr_history`` and ``epoch_times``. The figure
    is written to ``<output_dir>/<model_name>_training_curves.png``.
    """
    sns.set_theme(style='whitegrid')
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle(f'Training Curves - {model_name}', fontsize=16)

    # (axis, title, y-label, [(history key, line label), ...], draw legend?)
    panels = (
        (axes[0, 0], 'Loss Curves', 'Loss',
         (('train_loss', 'Training Loss'), ('val_loss', 'Validation Loss')), True),
        (axes[0, 1], 'Accuracy Curves', 'Accuracy',
         (('train_acc', 'Training Accuracy'), ('val_acc', 'Validation Accuracy')), True),
        (axes[1, 0], 'Learning Rate Schedule', 'Learning Rate',
         (('lr_history', 'Learning Rate'),), False),
        (axes[1, 1], 'Training Time per Epoch', 'Time (seconds)',
         (('epoch_times', 'Training Time'),), False),
    )

    for ax, title, y_label, series, with_legend in panels:
        for key, line_label in series:
            ax.plot(history[key], label=line_label)
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        if with_legend:
            ax.legend()

    # The learning-rate panel spans orders of magnitude.
    axes[1, 0].set_yscale('log')

    fig.tight_layout(rect=[0, 0.03, 1, 0.95])
    os.makedirs(output_dir, exist_ok=True)
    fig.savefig(os.path.join(output_dir, f'{model_name}_training_curves.png'))
    plt.close(fig)

# Enhanced model evaluation function with detailed timing
def evaluate_model_inline(model, test_loader, criterion):
    """Evaluate ``model`` on ``test_loader`` and time its forward passes.

    Args:
        model: Network already placed on ``DEVICE``.
        test_loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss function applied to ``(outputs, labels)``.

    Returns:
        ``(metrics, targets, predictions, probabilities)``. ``metrics`` is the
        dict from ``calculate_metrics_inline`` extended with ``loss``,
        ``avg_inference_time`` (seconds per batch),
        ``inference_time_per_image``, ``total_inference_time`` and
        ``images_per_second``. The three lists hold per-sample labels,
        predicted labels and class-1 probabilities.
    """
    model.eval()
    running_loss = 0.0
    predictions = []
    targets = []
    probabilities = []
    inference_times = []
    batch_sizes = []

    def _wait_for_device(tensor):
        # CUDA kernels run asynchronously, so timings need an explicit sync.
        # On CPU there is nothing to wait for, and calling
        # torch.cuda.synchronize() without a GPU raises.
        if tensor.is_cuda:
            torch.cuda.synchronize()

    with torch.no_grad():
        for images, labels in test_loader:
            batch_sizes.append(images.shape[0])
            images = images.to(DEVICE, non_blocking=True)
            labels = labels.to(DEVICE, non_blocking=True)

            # Warm-up run to eliminate first-batch overhead
            if not inference_times:
                _ = model(images)

            # Measure inference time with synchronization for accuracy;
            # perf_counter is monotonic and has higher resolution than time().
            _wait_for_device(images)
            start_time = time.perf_counter()
            outputs = model(images)
            _wait_for_device(images)
            inference_times.append(time.perf_counter() - start_time)

            loss = criterion(outputs, labels)
            probs = torch.softmax(outputs, dim=1)[:, 1]

            running_loss += loss.item() * images.size(0)
            _, preds = torch.max(outputs, 1)

            predictions.extend(preds.cpu().numpy())
            targets.extend(labels.cpu().numpy())
            probabilities.extend(probs.cpu().numpy())

    total_samples = sum(batch_sizes)
    total_inference_time = sum(inference_times)

    # Compute metrics (normalise by the samples actually evaluated; guard
    # against an empty loader).
    test_loss = running_loss / total_samples if total_samples else 0.0
    metrics = calculate_metrics_inline(targets, predictions, probabilities)
    metrics['loss'] = test_loss

    # Skip the first batch timing (warm-up)
    if len(inference_times) > 1:
        timed_seconds = sum(inference_times[1:])
        avg_inference_time = timed_seconds / (len(inference_times) - 1)
        inference_time_per_image = timed_seconds / sum(batch_sizes[1:])
    elif total_samples:
        avg_inference_time = total_inference_time
        inference_time_per_image = total_inference_time / total_samples
    else:
        avg_inference_time = 0.0
        inference_time_per_image = 0.0

    metrics['avg_inference_time'] = avg_inference_time  # Average time per batch
    metrics['inference_time_per_image'] = inference_time_per_image  # Average time per image
    metrics['total_inference_time'] = total_inference_time
    metrics['images_per_second'] = 1.0 / inference_time_per_image if inference_time_per_image > 0 else 0

    return metrics, targets, predictions, probabilities

# Start model evaluation and selection
logger.info("Initializing models for evaluation...")

# Maps architecture name -> dict holding the instantiated model together with
# its learning rate, size statistics and description.
models_info = {}

# Initialize and evaluate each model
# A model that cannot be created is logged and left out of models_info, so
# the rest of the pipeline only sees architectures that initialised.
for model_name, model_cfg in MODELS.items():
    try:
        start_time = time.time()

        # Create the model
        # (two output classes, moved to DEVICE immediately)
        model = timm.create_model(
            model_name,
            pretrained=model_cfg['pretrained'],
            num_classes=2,
            drop_rate=DROP_RATE
        ).to(DEVICE)

        # Calculate model metrics
        model_metrics = compute_model_metrics_inline(model)

        # Store model information
        models_info[model_name] = {
            'model': model,
            'lr': model_cfg['lr'],
            'size_mb': model_metrics['size_mb'],
            'param_count': model_metrics['param_count'],
            'description': model_cfg.get('description', 'No description provided')
        }

        logger.info(f"Initialized model: {model_name}")
        logger.info(f"  - Parameters: {model_metrics['param_count']:,}")
        logger.info(f"  - Size: {model_metrics['size_mb']:.2f} MB")
        logger.info(f"  - Description: {model_cfg.get('description', 'Not provided')}")

    except Exception as e:
        logger.error(f"Failed to initialize model {model_name}: {str(e)}")
# Single fold validation to quickly assess each model
# Every initialised model is fine-tuned for a few epochs on train_loader, then
# scored on val_loader and test_loader; the results feed the model selection.
logger.info("\nPerforming quick validation of all models...")

validation_results = {}

# Number of fine-tuning epochs used for the quick assessment.
QUICK_VAL_EPOCHS = 5
# Mixed precision only applies on CUDA; on the CPU fallback it is switched off
# explicitly instead of requesting CUDA autocast on a machine without a GPU.
QUICK_VAL_AMP = USE_AMP and DEVICE.type == 'cuda'

for model_name, m_info in models_info.items():
    logger.info(f"\nValidating model: {model_name}")

    # Set up optimizer and criterion
    model = m_info['model']
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=m_info['lr'],
        weight_decay=0.01,
        betas=(0.9, 0.999)
    )

    scheduler = CosineAnnealingLR(
        optimizer,
        T_max=10,  # Shorter for quick validation
        eta_min=1e-6
    )

    criterion = nn.CrossEntropyLoss()
    scaler = GradScaler(enabled=QUICK_VAL_AMP)

    # Train for just a few epochs to get a sense of model performance
    train_time_start = time.time()

    for epoch in range(QUICK_VAL_EPOCHS):
        # Train step
        model.train()
        running_loss = 0.0
        samples_seen = 0

        for images, labels in train_loader:
            images = images.to(DEVICE, non_blocking=True)
            labels = labels.to(DEVICE, non_blocking=True)
            optimizer.zero_grad(set_to_none=True)

            with autocast(device_type=DEVICE.type, enabled=QUICK_VAL_AMP):
                outputs = model(images)
                loss = criterion(outputs, labels)

            if QUICK_VAL_AMP:
                scaler.scale(loss).backward()
                # Gradients must be unscaled before their norm is clipped.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP_VALUE)
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP_VALUE)
                optimizer.step()

            running_loss += loss.item() * images.size(0)
            samples_seen += images.size(0)

        # train_loader drops its last partial batch, so average over the
        # samples actually processed rather than over the dataset length.
        train_loss = running_loss / max(samples_seen, 1)
        scheduler.step()

        # Log progress for each epoch
        logger.info(f"  Epoch {epoch+1}/{QUICK_VAL_EPOCHS} - Train Loss: {train_loss:.4f}")

    train_time = time.time() - train_time_start

    # Evaluate on validation set
    logger.info(f"Evaluating {model_name} on validation set...")
    val_metrics, val_targets, val_preds, val_probs = evaluate_model_inline(
        model, val_loader, criterion
    )

    # Evaluate on test set
    logger.info(f"Evaluating {model_name} on test set...")
    test_metrics, test_targets, test_preds, test_probs = evaluate_model_inline(
        model, test_loader, criterion
    )

    # Store results
    validation_results[model_name] = {
        'val_metrics': val_metrics,
        'test_metrics': test_metrics,
        'train_time_total': train_time,
        'train_time_per_epoch': train_time / QUICK_VAL_EPOCHS,
        'size_mb': m_info['size_mb'],
        'param_count': m_info['param_count']
    }

    # Log model performance
    logger.info(f"\nPerformance Summary for {model_name}:")
    logger.info(f"  Model Size: {m_info['size_mb']:.2f} MB")
    logger.info(f"  Parameters: {m_info['param_count']:,}")
    logger.info(f"  Training Time: {train_time:.2f}s ({train_time/QUICK_VAL_EPOCHS:.2f}s per epoch)")
    logger.info(f"  Inference Speed: {test_metrics['images_per_second']:.2f} images/second")
    logger.info(f"  Validation Accuracy: {val_metrics['accuracy']:.4f}")
    logger.info(f"  Validation F1: {val_metrics['f1_score']:.4f}")
    logger.info(f"  Test Accuracy: {test_metrics['accuracy']:.4f}")
    logger.info(f"  Test F1: {test_metrics['f1_score']:.4f}")

    # Save the quick-validation weights of every model
    model_save_path = os.path.join(str(OUTPUT_DIR), 'models', f"{model_name}_quick_val.pth")
    torch.save(model.state_dict(), model_save_path)

# Choose best models based on validation results
# (the per-model entries collected in validation_results are scored below).
logger.info("\nSelecting best models for full training...")

# Define a balanced score that considers accuracy, speed, and model size
def calculate_model_efficiency_score(result):
    """Return a scalar in [0, 1] ranking a model; higher is better.

    Weights: 30% accuracy, 20% F1, 30% inference speed (saturating at
    100 images/second), 20% compactness (0 at 100 MB or more).

    Accuracy and F1 are read from ``result['val_metrics']`` so that the
    held-out test set does not influence which models are selected. Entries
    without ``val_metrics`` fall back to ``result['test_metrics']``.

    Args:
        result: dict with ``test_metrics`` (providing ``images_per_second``),
            ``size_mb`` and normally ``val_metrics`` (providing ``accuracy``
            and ``f1_score``).
    """
    quality_metrics = result.get('val_metrics') or result['test_metrics']
    accuracy = quality_metrics['accuracy']
    # Named `f1` so the imported sklearn `f1_score` function is not shadowed.
    f1 = quality_metrics['f1_score']
    inference_speed = result['test_metrics']['images_per_second']
    model_size = result['size_mb']

    # Balanced score formula - higher is better
    # 50% for performance (F1 + accuracy), 30% for inference speed, 20% for model size efficiency
    performance_score = 0.3 * accuracy + 0.2 * f1
    speed_score = 0.3 * min(1.0, inference_speed / 100)  # Normalize speed, cap at 100 img/sec
    size_score = 0.2 * (1.0 - min(1.0, model_size / 100))  # Smaller size is better, cap at 100MB

    return performance_score + speed_score + size_score

# Calculate efficiency scores
# model_scores maps architecture name -> scalar efficiency score.
model_scores = {}
for model_name, result in validation_results.items():
    score = calculate_model_efficiency_score(result)
    model_scores[model_name] = score

# Sort models by score
# (descending, so the best-scoring model comes first)
sorted_models = sorted(model_scores.items(), key=lambda x: x[1], reverse=True)

# Select top 3 models for full training
# (fewer if fewer than three models were validated)
top_models = [name for name, _ in sorted_models[:3]]

logger.info("Selected models for full training:")
for i, model_name in enumerate(top_models):
    result = validation_results[model_name]
    logger.info(f"{i+1}. {model_name}")
    logger.info(f"   - Score: {model_scores[model_name]:.4f}")
    logger.info(f"   - Size: {result['size_mb']:.2f} MB")
    logger.info(f"   - F1 Score: {result['test_metrics']['f1_score']:.4f}")
    logger.info(f"   - Inference: {result['test_metrics']['images_per_second']:.2f} img/sec")

# Full training with K-Fold cross-validation for selected models
logger.info("\nStarting full K-Fold cross-validation for selected models...")

# Training function for full training
def _kfold_train_epoch(model, loader, criterion, optimizer, scaler, amp_enabled):
    """Run one optimisation pass over ``loader`` and return its metrics dict (with 'loss')."""
    model.train()
    running_loss = 0.0
    samples_seen = 0
    predictions = []
    targets_list = []
    probabilities_list = []

    for images, labels in loader:
        images = images.to(DEVICE, non_blocking=True)
        labels = labels.to(DEVICE, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)

        with autocast(device_type=DEVICE.type, enabled=amp_enabled):
            outputs = model(images)
            loss = criterion(outputs, labels)
            probs = torch.softmax(outputs, dim=1)[:, 1]

        if amp_enabled:
            scaler.scale(loss).backward()
            # Gradients must be unscaled before their norm is clipped.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP_VALUE)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP_VALUE)
            optimizer.step()

        running_loss += loss.item() * images.size(0)
        samples_seen += images.size(0)
        _, preds = torch.max(outputs, 1)
        predictions.extend(preds.detach().cpu().numpy())
        probabilities_list.extend(probs.detach().cpu().numpy())
        targets_list.extend(labels.cpu().numpy())

    metrics = calculate_metrics_inline(targets_list, predictions, probabilities_list)
    # The loader drops its last partial batch, so average over the samples
    # actually processed rather than over the dataset length.
    metrics['loss'] = running_loss / max(samples_seen, 1)
    return metrics


def _kfold_validate_epoch(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` without gradients and return its metrics dict (with 'loss')."""
    model.eval()
    running_loss = 0.0
    samples_seen = 0
    predictions = []
    targets_list = []
    probabilities_list = []

    with torch.no_grad():
        for images, labels in loader:
            images = images.to(DEVICE, non_blocking=True)
            labels = labels.to(DEVICE, non_blocking=True)
            outputs = model(images)
            loss = criterion(outputs, labels)
            probs = torch.softmax(outputs, dim=1)[:, 1]

            running_loss += loss.item() * images.size(0)
            samples_seen += images.size(0)
            _, preds = torch.max(outputs, 1)
            predictions.extend(preds.detach().cpu().numpy())
            probabilities_list.extend(probs.detach().cpu().numpy())
            targets_list.extend(labels.cpu().numpy())

    metrics = calculate_metrics_inline(targets_list, predictions, probabilities_list)
    metrics['loss'] = running_loss / max(samples_seen, 1)
    return metrics


def train_model_with_kfold(model_name, model_info):
    """Train ``model_name`` with stratified K-fold cross-validation.

    For each of ``NUM_FOLDS`` folds a fresh pretrained model is fine-tuned for
    up to ``NUM_EPOCHS`` epochs (AdamW, cosine annealing, gradient clipping,
    early stopping on validation loss). The weights with the lowest
    validation loss across all folds are loaded into the returned model.

    Args:
        model_name: timm architecture name.
        model_info: dict providing at least ``'lr'`` (AdamW learning rate).

    Returns:
        dict with ``model`` (network carrying the best weights),
        ``all_fold_metrics`` (one training-history dict per fold),
        ``training_time`` (seconds) and ``best_fold`` (1-based index, or
        ``None`` if no fold produced a snapshot).
    """
    import copy

    logger.info(f"\nStarting full training for model: {model_name}")
    start_time = time.time()
    current_fold_metrics = []
    best_global_state = None
    best_global_loss = float('inf')
    best_global_fold = None

    # Mixed precision only applies on CUDA; disable it on the CPU fallback.
    amp_enabled = USE_AMP and DEVICE.type == 'cuda'

    # Validation folds are indices into the training split, whose transform
    # applies random augmentation. Evaluate them through a shallow copy that
    # uses the deterministic validation transform instead, so validation loss
    # (and therefore early stopping / best-epoch selection) is not noisy.
    eval_view = copy.copy(train_dataset)
    eval_view.transform = val_dataset.transform

    # Prepare K-Fold splits
    labels_train_dataset = [sample['label'] for sample in train_dataset.samples]
    kf = StratifiedKFold(n_splits=NUM_FOLDS, shuffle=True, random_state=42)

    for fold, (train_indices, val_indices) in enumerate(kf.split(train_dataset.samples, labels_train_dataset), 1):
        logger.info(f"\nStarting fold {fold}/{NUM_FOLDS}")

        # Create subset for train/val
        train_subset = torch.utils.data.Subset(train_dataset, train_indices)
        val_subset = torch.utils.data.Subset(eval_view, val_indices)

        # Re-init model for each fold
        model = timm.create_model(
            model_name,
            pretrained=True,
            num_classes=2,
            drop_rate=DROP_RATE
        ).to(DEVICE)

        # Set up optimizer
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=model_info['lr'],
            weight_decay=0.01,
            betas=(0.9, 0.999)
        )

        # Set up learning rate scheduler
        scheduler = CosineAnnealingLR(
            optimizer,
            T_max=NUM_EPOCHS,
            eta_min=1e-6
        )

        criterion = nn.CrossEntropyLoss()
        scaler = GradScaler(enabled=amp_enabled)

        # Create DataLoader for subset
        fold_train_loader = DataLoader(
            train_subset,
            batch_size=BATCH_SIZE,
            shuffle=True,
            num_workers=NUM_WORKERS,
            pin_memory=PIN_MEMORY,
            drop_last=True
        )

        fold_val_loader = DataLoader(
            val_subset,
            batch_size=BATCH_SIZE,
            shuffle=False,
            num_workers=NUM_WORKERS,
            pin_memory=PIN_MEMORY
        )

        # Prepare training history
        training_history = {
            'train_loss': [], 'val_loss': [],
            'train_acc': [], 'val_acc': [],
            'train_metrics': [], 'val_metrics': [],
            'lr_history': [], 'epoch_times': []
        }

        best_val_loss = float('inf')
        best_model_state = None

        # Early stopping state
        es_best_loss = None
        es_counter = 0

        # Start epoch loop
        for epoch in range(NUM_EPOCHS):
            epoch_start_time = time.time()

            train_metrics_dict = _kfold_train_epoch(
                model, fold_train_loader, criterion, optimizer, scaler, amp_enabled
            )
            val_metrics_dict = _kfold_validate_epoch(model, fold_val_loader, criterion)
            val_loss = val_metrics_dict['loss']

            # Logging training/val info (learning rate used during this epoch)
            current_lr = optimizer.param_groups[0]['lr']
            logger.info(
                f"Epoch {epoch+1}/{NUM_EPOCHS} | "
                f"Train Loss: {train_metrics_dict['loss']:.4f} | "
                f"Val Loss: {val_metrics_dict['loss']:.4f} | "
                f"Val F1: {val_metrics_dict['f1_score']:.4f} | "
                f"Val Acc: {val_metrics_dict['accuracy']:.4f}"
            )

            # Step the scheduler
            scheduler.step()

            # Update training history
            training_history['train_loss'].append(train_metrics_dict['loss'])
            training_history['val_loss'].append(val_metrics_dict['loss'])
            training_history['train_acc'].append(train_metrics_dict['accuracy'])
            training_history['val_acc'].append(val_metrics_dict['accuracy'])
            training_history['train_metrics'].append(train_metrics_dict)
            training_history['val_metrics'].append(val_metrics_dict)
            training_history['lr_history'].append(current_lr)
            training_history['epoch_times'].append(time.time() - epoch_start_time)

            # Save best model for this fold. state_dict() returns references
            # to the live tensors and dict.copy() is shallow, so the snapshot
            # must clone every tensor or it would keep changing as training
            # continues. Clones are kept on the CPU to spare GPU memory.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_model_state = {
                    key: value.detach().cpu().clone()
                    for key, value in model.state_dict().items()
                }

            # Inline early stopping
            if es_best_loss is None:
                es_best_loss = val_loss
            elif val_loss > es_best_loss - early_stopping_min_delta:
                es_counter += 1
                if es_counter >= early_stopping_patience:
                    logger.info("Early stopping triggered.")
                    break
            else:
                es_best_loss = val_loss
                es_counter = 0

        # Summarize fold results (metrics of the last epoch that was run)
        final_val_metrics = training_history['val_metrics'][-1]

        summary_str = (
            f"\nFold {fold} Summary - {model_name}:\n"
            f"Validation Metrics:\n"
            f"- Accuracy: {final_val_metrics['accuracy']:.4f}\n"
            f"- F1 Score: {final_val_metrics['f1_score']:.4f}\n"
            f"- Precision: {final_val_metrics['precision']:.4f}\n"
            f"- Recall: {final_val_metrics['recall']:.4f}\n"
        )
        logger.info(summary_str)

        current_fold_metrics.append(training_history)

        # Check if this fold is better globally
        if best_val_loss < best_global_loss:
            best_global_loss = best_val_loss
            best_global_state = best_model_state
            best_global_fold = fold

        # Clean up
        del train_subset, val_subset, fold_train_loader, fold_val_loader, optimizer, criterion, model
        torch.cuda.empty_cache()
        gc.collect()

    # After all folds, recreate the model and load best weights
    best_model = timm.create_model(
        model_name,
        pretrained=False,
        num_classes=2,
        drop_rate=DROP_RATE
    ).to(DEVICE)

    if best_global_state is not None:
        best_model.load_state_dict(best_global_state)
        logger.info(f"Loaded best fold state (fold {best_global_fold}) for {model_name}.")

    training_time = time.time() - start_time

    return {
        'model': best_model,
        'all_fold_metrics': current_fold_metrics,
        'training_time': training_time,
        'best_fold': best_global_fold
    }

# Full training for top models
# Each shortlisted architecture goes through the k-fold routine once; the
# returned dict (trained model, per-fold metrics, training time, best fold)
# is kept under the model's name for the evaluation stage that follows.
full_training_results = {}
for model_name in top_models:
    model_config = models_info[model_name]
    full_training_results[model_name] = train_model_with_kfold(model_name, model_config)

# Final evaluation on test set
logger.info("\nEvaluating fully trained models on test set...")

# The output folders do not depend on the model, so resolve and create them
# once up front. Neither open() nor torch.save() creates missing parent
# directories, so a missing 'models' folder would otherwise abort the run
# only after training and evaluation had already been paid for.
results_dir = os.path.join(str(OUTPUT_DIR), 'results')
models_dir = os.path.join(str(OUTPUT_DIR), 'models')
os.makedirs(results_dir, exist_ok=True)
os.makedirs(models_dir, exist_ok=True)

# The loss module is identical for every model; build it once.
criterion_test = nn.CrossEntropyLoss()

final_test_results = {}
for model_name, training_result in full_training_results.items():
    logger.info(f"Evaluating {model_name} on test set...")
    model = training_result['model']

    # Evaluate the model: returns the metrics dict plus the per-sample
    # targets, predictions and probabilities collected over test_loader.
    metrics, targets_test, predictions_test, probabilities_test = evaluate_model_inline(
        model, test_loader, criterion_test
    )

    # Log detailed metrics
    logger.info(f"Test Results for {model_name}:")
    logger.info(f"  Accuracy: {metrics['accuracy']:.4f}")
    logger.info(f"  F1 Score: {metrics['f1_score']:.4f}")
    logger.info(f"  Precision: {metrics['precision']:.4f}")
    logger.info(f"  Recall: {metrics['recall']:.4f}")
    logger.info(f"  ROC-AUC: {metrics.get('roc_auc', 'N/A')}")
    logger.info(f"  Inference Speed: {metrics['images_per_second']:.2f} images/second")

    # Keep the raw outputs in memory for the model-selection stage.
    final_test_results[model_name] = {
        'metrics': metrics,
        'targets': targets_test,
        'predictions': predictions_test,
        'probabilities': probabilities_test
    }

    # Save test results
    test_results_path = os.path.join(results_dir, f"{model_name}_final_test_results.json")

    test_data_dict = {
        "targets": [int(t) for t in targets_test],
        "predictions": [int(p) for p in predictions_test],
        # NOTE(review): assumes one scalar probability per sample; a
        # per-class row would make float() fail - confirm against
        # evaluate_model_inline.
        "probabilities": [float(prob) for prob in probabilities_test],
        # NumPy scalars (np.float32, np.int64, ...) are not JSON serialisable
        # and most are not subclasses of the built-in numbers, so they are
        # converted explicitly alongside plain int/float. The confusion
        # matrix is left out here; it is plotted separately below.
        "metrics": {
            k: float(v) if isinstance(v, (int, float, np.integer, np.floating)) else v
            for k, v in metrics.items() if k != 'confusion_matrix'
        }
    }

    with open(test_results_path, 'w') as f_out:
        json.dump(test_data_dict, f_out, indent=4)

    # Confusion matrix plot
    if 'confusion_matrix' in metrics:
        cm_array = np.array(metrics['confusion_matrix'])
        plot_confusion_matrix_inline(
            cm_array,
            os.path.join(str(OUTPUT_DIR), 'visualizations'),
            f"{model_name}_final_test"
        )

    # Save the final model (weights only, via state_dict)
    model_save_path = os.path.join(models_dir, f"{model_name}_final_model.pth")
    torch.save(model.state_dict(), model_save_path)
    logger.info(f"Saved final model to {model_save_path}")

# Select best final model
logger.info("\nSelecting best final model...")

best_final_model = None
best_final_score = -float('inf')

for model_name, result in final_test_results.items():
    metrics = result['metrics']

    # Calculate balanced score for final selection
    # Balance accuracy, F1 score, and efficiency
    accuracy = metrics['accuracy']
    # Deliberately NOT bound to the name `f1_score`: that name is the
    # sklearn.metrics function imported at the top of the file, and rebinding
    # it at module level would turn it into a float for the rest of the
    # session.
    f1_value = metrics['f1_score']
    inference_speed = metrics['images_per_second']
    model_size = models_info[model_name]['size_mb']

    # Score formula: 40% F1, 30% accuracy, 20% speed, 10% size
    # Speed credit saturates at 100 images/second; the size term reaches
    # zero once size_mb is 100 or more, so smaller models score higher.
    final_score = (0.4 * f1_value +
                   0.3 * accuracy +
                   0.2 * min(1.0, inference_speed / 100) +
                   0.1 * (1.0 - min(1.0, model_size / 100)))

    logger.info(f"{model_name} final score: {final_score:.4f}")

    # Strict '>' keeps the first model encountered when scores tie.
    if final_score > best_final_score:
        best_final_score = final_score
        best_final_model = model_name

# Save comprehensive results
# Build one summary entry per shortlisted model from its test metrics and
# its static info (size, parameter count).
model_summaries = {}
for model_name in top_models:
    test_metrics = final_test_results[model_name]['metrics']
    model_info = models_info[model_name]

    model_summaries[model_name] = {
        'accuracy': test_metrics['accuracy'],
        'f1_score': test_metrics['f1_score'],
        'precision': test_metrics['precision'],
        'recall': test_metrics['recall'],
        # A missing ROC-AUC is recorded as 0 in the comparison file.
        'roc_auc': test_metrics.get('roc_auc', 0),
        'size_mb': model_info['size_mb'],
        'param_count': model_info['param_count'],
        # Multiplied by 1000 to match the '_ms' key.
        'inference_time_ms': test_metrics['inference_time_per_image'] * 1000,
        'images_per_second': test_metrics['images_per_second']
    }

# Top-level key order is kept as 'models', 'best_model', 'training_time' so
# the written JSON layout does not change.
final_comparison = {
    'models': model_summaries,
    'best_model': best_final_model,
    'training_time': {
        name: outcome['training_time']
        for name, outcome in full_training_results.items()
    }
}

# Save final comparison
comparison_path = os.path.join(str(OUTPUT_DIR), 'results', 'final_model_comparison.json')
with open(comparison_path, 'w') as f:
    json.dump(final_comparison, f, indent=4)

# Final summary
# Report the selected model: static info first, then its test-set metrics.
logger.info("\n=== FINAL RESULTS ===")
logger.info(f"Best model: {best_final_model}")

best_info = models_info[best_final_model]
logger.info(f"  - Size: {best_info['size_mb']:.2f} MB")
logger.info(f"  - Parameters: {best_info['param_count']:,}")

best_metrics = final_test_results[best_final_model]['metrics']
logger.info(f"  - Accuracy: {best_metrics['accuracy']:.4f}")
logger.info(f"  - F1 Score: {best_metrics['f1_score']:.4f}")
logger.info(f"  - Inference: {best_metrics['images_per_second']:.2f} img/sec")

# Create summary report file
# Plain-text report: a detailed section for the selected model followed by a
# short size / F1 / speed comparison of every shortlisted model.
summary_path = os.path.join(str(OUTPUT_DIR), 'best_model_summary.txt')
with open(summary_path, 'w') as f:
    f.write("=== COFFEE BEAN DEFECT DETECTION - BEST MODEL SUMMARY ===\n\n")
    f.write(f"Best Model: {best_final_model}\n")

    winner_info = models_info[best_final_model]
    f.write(f"Model Size: {winner_info['size_mb']:.2f} MB\n")
    f.write(f"Parameters: {winner_info['param_count']:,}\n")

    winner_metrics = final_test_results[best_final_model]['metrics']
    f.write(f"Accuracy: {winner_metrics['accuracy']:.4f}\n")
    f.write(f"F1 Score: {winner_metrics['f1_score']:.4f}\n")
    f.write(f"Precision: {winner_metrics['precision']:.4f}\n")
    f.write(f"Recall: {winner_metrics['recall']:.4f}\n")
    # Per-image inference time is scaled by 1000 for the ms/image figure.
    f.write(f"Inference Time: {winner_metrics['inference_time_per_image']*1000:.2f} ms/image\n")
    f.write(f"Processing Speed: {winner_metrics['images_per_second']:.2f} images/second\n\n")
    f.write("=== Model Comparison ===\n")

    # One four-line entry per model, separated by a blank line.
    for model_name in top_models:
        metrics = final_test_results[model_name]['metrics']
        f.write(
            f"{model_name}:\n"
            f"  - Size: {models_info[model_name]['size_mb']:.2f} MB\n"
            f"  - F1 Score: {metrics['f1_score']:.4f}\n"
            f"  - Speed: {metrics['images_per_second']:.2f} img/sec\n\n"
        )

logger.info("Training pipeline completed successfully!")

# Final cleanup
# Run Python's garbage collector first so tensors that are only reachable
# through reference cycles are actually destroyed; only then can
# empty_cache() hand their now-unused cached blocks back to the GPU driver.
# Emptying the cache before collecting would leave that memory cached.
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
logger.info("Cleanup done.")