In [1]:
import os
import json
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.utils.data import Dataset, DataLoader
import timm
import numpy as np
from torchvision import transforms
from torch.optim.lr_scheduler import CosineAnnealingLR
from PIL import Image
import random
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    roc_auc_score
)
import time
from pathlib import Path
import gc
from collections import Counter
import matplotlib.pyplot as plt
import seaborn as sns
from torch.amp import autocast, GradScaler

# Advanced augmentation imports
from torchvision.transforms import RandAugment

try:
    import psutil
    PSUTIL_AVAILABLE = True
except ImportError:
    PSUTIL_AVAILABLE = False

# Global Config Variables
BASE_PATH = Path(r"D:\iate_project\data")
raw_dataset_path = BASE_PATH / "raw"
output_path = BASE_PATH / "results"

# Create output directories if necessary
os.makedirs(output_path, exist_ok=True)
os.makedirs(output_path / 'models', exist_ok=True)
os.makedirs(output_path / 'results', exist_ok=True)
os.makedirs(output_path / 'visualizations', exist_ok=True)

# Training hyperparameters
BATCH_SIZE = 32
NUM_WORKERS = 0  # Windows compatibility
NUM_EPOCHS = 10
NUM_FOLDS = 5
GRAD_CLIP_VALUE = 1.0

USE_AMP = True  # Automatic Mixed Precision for faster training
PIN_MEMORY = True
DROP_RATE = 0.2  # Dropout rate for regularization

EARLY_STOPPING_PATIENCE = 3  # Stop training if no improvement after this many epochs
EARLY_STOPPING_MIN_DELTA = 0.0005  # Minimum change to qualify as improvement

# Enable advanced augmentation techniques - these provide more robust training
USE_ADVANCED_AUGMENTATION = True
USE_MIXUP = True  # MixUp augmentation technique
MIXUP_ALPHA = 0.2  # Alpha parameter for MixUp

"""
Model selection rationale:
These lightweight models are specifically chosen for practical deployment in coffee industry settings:
1. Edge device compatibility: Models need to run on low-power hardware in production environments
2. Real-time inference: Coffee bean sorting requires fast processing (>10 beans/second)
3. Memory efficiency: Limited RAM available on edge devices used in coffee processing facilities
4. Energy consumption: Processing facilities may have power constraints
5. Thermal considerations: Coffee processing environment may already be hot

Effectiveness in image recognition tasks was balanced against these practical constraints.
"""
MODELS = {
    # Very lightweight models for edge deployment
    'mobilenetv3_small_100': {
        'pretrained': True,
        'lr': 5e-5,
    },
    'efficientnet_lite0': {
        'pretrained': True,
        'lr': 5e-5,
    },
    'mobilevit_xxs': {
        'pretrained': True,
        'lr': 5e-5,
    },
    'regnetx_002': {
        'pretrained': True,
        'lr': 5e-5,
    },
    # Slightly larger but still efficient models
    'efficientnet_b0': {
        'pretrained': True,
        'lr': 5e-5,
    },
    'mobilevit_xs': {
        'pretrained': True,
        'lr': 5e-5,
    }
}

# Device Setup
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
    print(f"CUDA is available! Using {torch.cuda.get_device_name(0)}")
    print(f"CUDA Version: {torch.version.cuda}")
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
else:
    DEVICE = torch.device("cpu")
    print("CUDA is not available. Using CPU.")

# Set random seed for reproducibility
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)
cudnn.deterministic = True
cudnn.benchmark = True

"""
MixUp Augmentation Implementation
This technique creates new training samples by linearly combining pairs of images and their labels,
which has been shown to improve model generalization and robustness.
"""
def mixup_data(x, y, alpha=0.2):
    '''Returns mixed inputs, pairs of targets, and lambda'''
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size()[0]
    index = torch.randperm(batch_size).to(x.device)

    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    '''Calculates the mixed loss'''
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

# Dataset class with advanced augmentation strategies
class CoffeeDataset(Dataset):
    """
    Dataset class for loading coffee bean images with enhanced augmentation strategies.

    Coffee bean images are collected with varying conditions including:
    - Processing methods (dry, honey, wet)
    - Roast levels (dark, light, medium)
    - Bean condition (normal, defect)

    Advanced augmentation helps the model generalize across these variations.
    """
    def __init__(self, root_dir, split='train', test_ratio=0.15, val_ratio=0.15):
        self.root_dir = root_dir
        self.split = split
        self.samples = []
        self.max_retries = 3

        # Load all samples from the specified directory structure
        normal_path = os.path.join(self.root_dir, "Normal")
        defect_path = os.path.join(self.root_dir, "Defect")

        # Handle case sensitivity
        if not os.path.exists(normal_path):
            normal_path = os.path.join(self.root_dir, "normal")
        if not os.path.exists(defect_path):
            defect_path = os.path.join(self.root_dir, "defect")

        print(f"Loading samples from Normal: {normal_path}")
        print(f"Loading samples from Defect: {defect_path}")

        # Process methods (Dry, Honey, Wet)
        all_filepaths = []
        all_labels = []

        # Process Normal samples
        for method in ['Dry', 'Honey', 'Wet', 'dry', 'honey', 'wet']:
            method_path = os.path.join(normal_path, method)
            if not os.path.exists(method_path):
                continue

            for roast in ['Dark', 'Light', 'Medium', 'dark', 'light', 'medium']:
                roast_path = os.path.join(method_path, roast)
                if not os.path.exists(roast_path):
                    continue

                for img_name in os.listdir(roast_path):
                    if img_name.lower().endswith(('.png', '.jpg', '.jpeg')):
                        all_filepaths.append(os.path.join(roast_path, img_name))
                        all_labels.append(0)  # Normal

        # Process Defect samples
        for method in ['Dry', 'Honey', 'Wet', 'dry', 'honey', 'wet']:
            method_path = os.path.join(defect_path, method)
            if not os.path.exists(method_path):
                continue

            for roast in ['Dark', 'Light', 'Medium', 'dark', 'light', 'medium']:
                roast_path = os.path.join(method_path, roast)
                if not os.path.exists(roast_path):
                    continue

                for img_name in os.listdir(roast_path):
                    if img_name.lower().endswith(('.png', '.jpg', '.jpeg')):
                        all_filepaths.append(os.path.join(roast_path, img_name))
                        all_labels.append(1)  # Defect

        # Split the dataset into train, validation, and test sets
        unique_paths = list(set(all_filepaths))  # Ensure no duplicates
        random.shuffle(unique_paths)

        test_size = int(len(unique_paths) * test_ratio)
        val_size = int(len(unique_paths) * val_ratio)

        test_paths = set(unique_paths[:test_size])
        val_paths = set(unique_paths[test_size:test_size + val_size])
        train_paths = set(unique_paths[test_size + val_size:])

        # Assign samples based on split
        if split == 'train':
            target_paths = train_paths
        elif split == 'val':
            target_paths = val_paths
        else:  # 'test'
            target_paths = test_paths

        # Filter samples for the current split
        for filepath, label in zip(all_filepaths, all_labels):
            if filepath in target_paths:
                self.samples.append({
                    'filepath': filepath,
                    'label': label
                })

        print(f"Loaded {len(self.samples)} samples for split: {self.split}")

        # Analyze class distribution
        labels = [sample['label'] for sample in self.samples]
        class_counts = Counter(labels)
        total = len(labels)

        print(f"\n{self.split} set class distribution:")
        for label, count in class_counts.items():
            percentage = (count / total) * 100
            print(f"Class {label}: {count} samples ({percentage:.2f}%)")

        # Configure transformations based on dataset split
        if self.split == 'train':
            if USE_ADVANCED_AUGMENTATION:
                # Advanced augmentation pipeline for more robust training
                self.transform = transforms.Compose([
                    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
                    transforms.RandomHorizontalFlip(p=0.5),
                    transforms.RandomVerticalFlip(p=0.3),
                    transforms.RandomRotation(
                        degrees=30,
                        fill=1.0,
                        interpolation=transforms.InterpolationMode.BILINEAR
                    ),
                    # Add RandAugment for automated augmentation policy
                    RandAugment(num_ops=2, magnitude=9),
                    # Color jitter specifically tuned for coffee beans
                    transforms.ColorJitter(
                        brightness=0.3, contrast=0.3,
                        saturation=0.3, hue=0.1
                    ),
                    # Add random perspective for 3D-like variations
                    transforms.RandomPerspective(distortion_scale=0.2, p=0.5),
                    # Add occasional grayscale conversion to focus on texture
                    transforms.RandomGrayscale(p=0.1),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
                ])
            else:
                # Standard augmentation pipeline
                self.transform = transforms.Compose([
                    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
                    transforms.RandomHorizontalFlip(p=0.5),
                    transforms.RandomVerticalFlip(p=0.3),
                    transforms.RandomRotation(
                        degrees=30,
                        fill=1.0,
                        interpolation=transforms.InterpolationMode.BILINEAR
                    ),
                    transforms.ColorJitter(
                        brightness=0.2, contrast=0.2,
                        saturation=0.2, hue=0.1
                    ),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
                ])
        else:
            # Validation and test transforms - keep simple for efficiency
            self.transform = transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        image_path = sample['filepath']
        label = sample['label']

        for attempt in range(self.max_retries):
            try:
                image = Image.open(image_path).convert('RGB')
                return self.transform(image), label
            except Exception as e:
                print(f"Failed to load/transform image {image_path} on attempt {attempt+1}: {str(e)}")
                idx = (idx + 1) % len(self)

        print(f"Failed to load image {image_path} after {self.max_retries} attempts. Returning default sample.")
        return torch.zeros(3, 224, 224), -1

# Create Datasets
print("Creating train/val/test datasets...")

# Create separate instances for train, val, test
train_dataset = CoffeeDataset(str(raw_dataset_path), split='train')
val_dataset = CoffeeDataset(str(raw_dataset_path), split='val')
test_dataset = CoffeeDataset(str(raw_dataset_path), split='test')

# Create Data Loaders
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
    drop_last=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY
)

test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY
)

print("Dataset sizes:")
print(f"Train: {len(train_dataset)} samples")
print(f"Validation: {len(val_dataset)} samples")
print(f"Test: {len(test_dataset)} samples")

# Initialize model tracking variables
print("Initializing models for evaluation...")
models_info = {}
all_training_results = {}
all_test_results = {}

"""
K-fold Cross-Validation Strategy:

This implementation uses K-fold CV with a specific purpose different from typical hyperparameter tuning:
1. For each model architecture, we train multiple versions with different data splits
2. We select the best performing model checkpoint based on validation performance
3. This approach helps overcome potential bias from any single train/val split
4. The final selected model is evaluated on a completely separate test set

Why this approach versus traditional CV for hyperparameter tuning:
- In low-data regimes and with class imbalance, performance can vary significantly based on the specific data split
- This helps us identify the most robust model for real-world deployment
- The test set remains untouched during this process, maintaining proper evaluation protocol
"""

# Iterate through all models to train and evaluate
for model_name, model_cfg in MODELS.items():
    print(f"\n{'='*50}")
    print(f"Starting process for model: {model_name}")

    try:
        # Create the model
        model = timm.create_model(
            model_name,
            pretrained=model_cfg['pretrained'],
            num_classes=2,
            drop_rate=DROP_RATE
        ).to(DEVICE)

        # Calculate model parameters and size
        param_count = sum(p.numel() for p in model.parameters())

        # Calculate model size in MB
        param_size = 0
        for param in model.parameters():
            param_size += param.nelement() * param.element_size()
        buffer_size = 0
        for buffer in model.buffers():
            buffer_size += buffer.nelement() * buffer.element_size()
        size_mb = (param_size + buffer_size) / 1024**2

        # Store model information
        models_info[model_name] = {
            'model': model,
            'lr': model_cfg['lr'],
            'size_mb': size_mb,
            'param_count': param_count
        }

        print(f"Initialized model: {model_name}")
        print(f"  - Parameters: {param_count:,}")
        print(f"  - Size: {size_mb:.2f} MB")

        # Prepare for K-Fold cross-validation
        print(f"\nStarting K-Fold cross-validation for model: {model_name}")
        start_time = time.time()
        current_fold_metrics = []
        best_global_state = None
        best_global_loss = float('inf')
        best_global_fold = None

        # Get labels for stratified sampling
        labels_train_dataset = [sample['label'] for sample in train_dataset.samples]
        kf = StratifiedKFold(n_splits=NUM_FOLDS, shuffle=True, random_state=42)

        # Loop through each fold
        for fold, (train_indices, val_indices) in enumerate(kf.split(train_dataset.samples, labels_train_dataset), 1):
            print(f"\nStarting fold {fold}/{NUM_FOLDS}")

            # Create subset for train/val
            train_subset = torch.utils.data.Subset(train_dataset, train_indices)
            val_subset = torch.utils.data.Subset(train_dataset, val_indices)

            # Re-init model for each fold
            fold_model = timm.create_model(
                model_name,
                pretrained=True,
                num_classes=2,
                drop_rate=DROP_RATE
            ).to(DEVICE)

            # Set up optimizer
            optimizer = torch.optim.AdamW(
                fold_model.parameters(),
                lr=model_cfg['lr'],
                weight_decay=0.01,
                betas=(0.9, 0.999)
            )

            # Set up learning rate scheduler
            scheduler = CosineAnnealingLR(
                optimizer,
                T_max=NUM_EPOCHS,
                eta_min=1e-6
            )

            criterion = nn.CrossEntropyLoss()
            scaler = GradScaler(enabled=USE_AMP)

            # Create DataLoader for subset
            fold_train_loader = DataLoader(
                train_subset,
                batch_size=BATCH_SIZE,
                shuffle=True,
                num_workers=NUM_WORKERS,
                pin_memory=PIN_MEMORY,
                drop_last=True
            )

            fold_val_loader = DataLoader(
                val_subset,
                batch_size=BATCH_SIZE,
                shuffle=False,
                num_workers=NUM_WORKERS,
                pin_memory=PIN_MEMORY
            )

            # Prepare training history
            training_history = {
                'train_loss': [], 'val_loss': [],
                'train_acc': [], 'val_acc': [],
                'train_metrics': [], 'val_metrics': [],
                'lr_history': [], 'epoch_times': []
            }

            best_val_loss = float('inf')
            best_model_state = None

            # Early stopping state
            es_best_loss = None
            es_counter = 0
            es_early_stop = False

            # Start epoch loop
            for epoch in range(NUM_EPOCHS):
                epoch_start_time = time.time()

                # -------------------- TRAIN EPOCH --------------------
                fold_model.train()
                running_loss = 0.0
                predictions = []
                targets_list = []
                probabilities_list = []

                for images, labels in fold_train_loader:
                    images = images.to(DEVICE, non_blocking=True)
                    labels = labels.to(DEVICE, non_blocking=True)
                    optimizer.zero_grad(set_to_none=True)

                    # Apply MixUp augmentation if enabled
                    if USE_MIXUP and np.random.random() < 0.5:
                        mixed_images, labels_a, labels_b, lam = mixup_data(images, labels, MIXUP_ALPHA)

                        with autocast(device_type='cuda', enabled=USE_AMP):
                            outputs = fold_model(mixed_images)
                            loss = mixup_criterion(criterion, outputs, labels_a, labels_b, lam)
                            probs = torch.softmax(outputs, dim=1)[:, 1]
                    else:
                        with autocast(device_type='cuda', enabled=USE_AMP):
                            outputs = fold_model(images)
                            loss = criterion(outputs, labels)
                            probs = torch.softmax(outputs, dim=1)[:, 1]

                    if USE_AMP:
                        scaler.scale(loss).backward()
                        scaler.unscale_(optimizer)
                        torch.nn.utils.clip_grad_norm_(fold_model.parameters(), GRAD_CLIP_VALUE)
                        scaler.step(optimizer)
                        scaler.update()
                    else:
                        loss.backward()
                        torch.nn.utils.clip_grad_norm_(fold_model.parameters(), GRAD_CLIP_VALUE)
                        optimizer.step()

                    running_loss += loss.item() * images.size(0)
                    _, preds = torch.max(outputs, 1)
                    predictions.extend(preds.detach().cpu().numpy())
                    probabilities_list.extend(probs.detach().cpu().numpy())
                    targets_list.extend(labels.cpu().numpy())

                epoch_train_loss = running_loss / len(fold_train_loader.dataset)

                # Calculate training metrics
                train_metrics = {}
                train_metrics['loss'] = epoch_train_loss

                # Calculate accuracy
                train_metrics['accuracy'] = accuracy_score(targets_list, predictions)
                train_metrics['precision'] = precision_score(targets_list, predictions, average='binary', zero_division=0)
                train_metrics['recall'] = recall_score(targets_list, predictions, average='binary', zero_division=0)
                train_metrics['f1_score'] = f1_score(targets_list, predictions, average='binary', zero_division=0)

                # Try to calculate ROC-AUC if possible
                try:
                    train_metrics['roc_auc'] = roc_auc_score(targets_list, probabilities_list)
                except ValueError as e:
                    print(f"ROC-AUC calculation failed: {str(e)}")
                    train_metrics['roc_auc'] = 0.0

                # -------------------- VALIDATE EPOCH --------------------
                fold_model.eval()
                val_running_loss = 0.0
                val_predictions = []
                val_targets_list = []
                val_probabilities_list = []

                with torch.no_grad():
                    for images_val, labels_val in fold_val_loader:
                        images_val = images_val.to(DEVICE, non_blocking=True)
                        labels_val = labels_val.to(DEVICE, non_blocking=True)
                        outputs_val = fold_model(images_val)
                        loss_val = criterion(outputs_val, labels_val)
                        probs_val = torch.softmax(outputs_val, dim=1)[:, 1]

                        val_running_loss += loss_val.item() * images_val.size(0)
                        _, preds_val = torch.max(outputs_val, 1)
                        val_predictions.extend(preds_val.detach().cpu().numpy())
                        val_probabilities_list.extend(probs_val.detach().cpu().numpy())
                        val_targets_list.extend(labels_val.cpu().numpy())

                val_loss = val_running_loss / len(fold_val_loader.dataset)

                # Calculate validation metrics
                val_metrics = {}
                val_metrics['loss'] = val_loss
                val_metrics['accuracy'] = accuracy_score(val_targets_list, val_predictions)
                val_metrics['precision'] = precision_score(val_targets_list, val_predictions, average='binary', zero_division=0)
                val_metrics['recall'] = recall_score(val_targets_list, val_predictions, average='binary', zero_division=0)
                val_metrics['f1_score'] = f1_score(val_targets_list, val_predictions, average='binary', zero_division=0)

                # Try to calculate ROC-AUC if possible
                try:
                    val_metrics['roc_auc'] = roc_auc_score(val_targets_list, val_probabilities_list)
                except ValueError as e:
                    print(f"ROC-AUC calculation failed: {str(e)}")
                    val_metrics['roc_auc'] = 0.0

                # Calculate confusion matrix
                cm = confusion_matrix(val_targets_list, val_predictions)
                if cm.shape == (2, 2):
                    tn, fp, fn, tp = cm.ravel()
                    val_metrics['specificity'] = tn / (tn + fp) if (tn + fp) > 0 else 0
                    val_metrics['confusion_matrix'] = cm.tolist()

                # Logging training/val info
                current_lr = optimizer.param_groups[0]['lr']
                print(f"Epoch {epoch+1}/{NUM_EPOCHS} | "
                      f"Train Loss: {train_metrics['loss']:.4f} | "
                      f"Val Loss: {val_metrics['loss']:.4f} | "
                      f"Val F1: {val_metrics['f1_score']:.4f} | "
                      f"Val Acc: {val_metrics['accuracy']:.4f}")

                # Log GPU stats
                if torch.cuda.is_available():
                    memory_allocated = torch.cuda.memory_allocated() / 1024**2
                    memory_reserved = torch.cuda.memory_reserved() / 1024**2
                    print(f"GPU Memory: Allocated {memory_allocated:.1f}MB, Reserved {memory_reserved:.1f}MB")

                # Step the scheduler
                scheduler.step()

                # Update training history
                training_history['train_loss'].append(train_metrics['loss'])
                training_history['val_loss'].append(val_metrics['loss'])
                training_history['train_acc'].append(train_metrics['accuracy'])
                training_history['val_acc'].append(val_metrics['accuracy'])
                training_history['train_metrics'].append(train_metrics)
                training_history['val_metrics'].append(val_metrics)
                training_history['lr_history'].append(current_lr)
                training_history['epoch_times'].append(time.time() - epoch_start_time)

                # Save best model for this fold
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    best_model_state = fold_model.state_dict().copy()

                # Check for early stopping
                if es_best_loss is None:
                    es_best_loss = val_loss
                else:
                    if val_loss > es_best_loss - EARLY_STOPPING_MIN_DELTA:
                        es_counter += 1
                        if es_counter >= EARLY_STOPPING_PATIENCE:
                            print("Early stopping triggered.")
                            break
                    else:
                        es_best_loss = val_loss
                        es_counter = 0

                # Optional performance threshold for early termination
                if val_metrics['f1_score'] > 0.95:
                    print(f"Achieved target F1 score: {val_metrics['f1_score']:.4f}. Stopping early.")
                    break

            # After finishing epochs for this fold
            if best_model_state is not None:
                fold_model.load_state_dict(best_model_state)

            # Summarize fold results
            final_train_metrics = training_history['train_metrics'][-1]
            final_val_metrics = training_history['val_metrics'][-1]

            print(f"\nFold {fold} Summary - {model_name}:")
            print(f"Validation Metrics:")
            print(f"- Accuracy: {final_val_metrics['accuracy']:.4f}")
            print(f"- F1 Score: {final_val_metrics['f1_score']:.4f}")
            print(f"- Precision: {final_val_metrics['precision']:.4f}")
            print(f"- Recall: {final_val_metrics['recall']:.4f}")

            current_fold_metrics.append(training_history)

            # Check if this fold is better globally
            if best_val_loss < best_global_loss:
                best_global_loss = best_val_loss
                best_global_state = best_model_state
                best_global_fold = fold

            # Plot training curves for this fold
            plt.figure(figsize=(15, 10))
            plt.subplot(2, 2, 1)
            plt.plot(training_history['train_loss'], label='Training Loss')
            plt.plot(training_history['val_loss'], label='Validation Loss')
            plt.title('Loss Curves')
            plt.xlabel('Epoch')
            plt.ylabel('Loss')
            plt.legend()

            plt.subplot(2, 2, 2)
            plt.plot(training_history['train_acc'], label='Training Accuracy')
            plt.plot(training_history['val_acc'], label='Validation Accuracy')
            plt.title('Accuracy Curves')
            plt.xlabel('Epoch')
            plt.ylabel('Accuracy')
            plt.legend()

            plt.subplot(2, 2, 3)
            plt.plot(training_history['lr_history'], label='Learning Rate')
            plt.title('Learning Rate Schedule')
            plt.xlabel('Epoch')
            plt.ylabel('Learning Rate')
            plt.yscale('log')

            plt.subplot(2, 2, 4)
            plt.plot(training_history['epoch_times'], label='Training Time')
            plt.title('Training Time per Epoch')
            plt.xlabel('Epoch')
            plt.ylabel('Time (seconds)')

            plt.tight_layout()
            plt.savefig(os.path.join(str(output_path), 'visualizations', f"{model_name}_fold{fold}_training_curves.png"))
            plt.close()

            # Clean up
            del train_subset, val_subset, fold_train_loader, fold_val_loader, optimizer, criterion, fold_model
            torch.cuda.empty_cache()
            gc.collect()

        # After all folds, recreate the model and load best weights
        best_model = timm.create_model(
            model_name,
            pretrained=False,
            num_classes=2,
            drop_rate=DROP_RATE
        ).to(DEVICE)

        if best_global_state is not None:
            best_model.load_state_dict(best_global_state)
            print(f"Loaded best fold state (fold {best_global_fold}) for {model_name}.")

        training_time = time.time() - start_time

        training_result = {
            'model': best_model,
            'all_fold_metrics': current_fold_metrics,
            'training_time': training_time,
            'best_fold': best_global_fold
        }

        all_training_results[model_name] = training_result

        # Now evaluate this model on the test set
        print(f"Evaluating {model_name} on test set...")
        criterion_test = nn.CrossEntropyLoss()

        # Evaluation block
        best_model.eval()
        test_running_loss = 0.0
        test_predictions = []
        test_targets = []
        test_probabilities = []
        test_inference_times = []
        test_batch_sizes = []

        with torch.no_grad():
            for images, labels in test_loader:
                test_batch_sizes.append(images.shape[0])
                images = images.to(DEVICE, non_blocking=True)
                labels = labels.to(DEVICE, non_blocking=True)

                # Warm-up run to eliminate first-batch overhead
                if len(test_inference_times) == 0:
                    _ = best_model(images)
                    if torch.cuda.is_available():
                        torch.cuda.synchronize()

                # Measure inference time
                if torch.cuda.is_available():
                    torch.cuda.synchronize()
                start_time = time.time()
                outputs = best_model(images)
                if torch.cuda.is_available():
                    torch.cuda.synchronize()
                inference_time = time.time() - start_time
                test_inference_times.append(inference_time)

                loss = criterion_test(outputs, labels)
                probs = torch.softmax(outputs, dim=1)[:, 1]

                test_running_loss += loss.item() * images.size(0)
                _, preds = torch.max(outputs, 1)

                test_predictions.extend(preds.cpu().numpy())
                test_targets.extend(labels.cpu().numpy())
                test_probabilities.extend(probs.cpu().numpy())

        # Compute test metrics
        test_loss = test_running_loss / len(test_loader.dataset)

        # Calculate metrics
        test_metrics = {}
        test_metrics['loss'] = test_loss
        test_metrics['accuracy'] = accuracy_score(test_targets, test_predictions)
        test_metrics['precision'] = precision_score(test_targets, test_predictions, average='binary', zero_division=0)
        test_metrics['recall'] = recall_score(test_targets, test_predictions, average='binary', zero_division=0)
        test_metrics['f1_score'] = f1_score(test_targets, test_predictions, average='binary', zero_division=0)

        # Try to calculate ROC-AUC if possible
        try:
            test_metrics['roc_auc'] = roc_auc_score(test_targets, test_probabilities)
        except ValueError as e:
            print(f"ROC-AUC calculation failed: {str(e)}")
            test_metrics['roc_auc'] = 0.0

        # Calculate confusion matrix
        cm = confusion_matrix(test_targets, test_predictions)
        if cm.shape == (2, 2):
            tn, fp, fn, tp = cm.ravel()
            test_metrics['specificity'] = tn / (tn + fp) if (tn + fp) > 0 else 0
            test_metrics['confusion_matrix'] = cm.tolist()

        # Print classification report
        report_text = classification_report(
            test_targets,
            test_predictions,
            target_names=['Normal', 'Defect'],
            digits=4
        )
        print(f"\nDetailed Classification Report:\n{report_text}")

        # Inference time calculations
        total_samples = sum(test_batch_sizes)
        total_inference_time = sum(test_inference_times)

        # Skip the first batch timing (warm-up)
        if len(test_inference_times) > 1:
            avg_inference_time = sum(test_inference_times[1:]) / (len(test_inference_times) - 1)
            inference_time_per_image = sum(test_inference_times[1:]) / sum(test_batch_sizes[1:])
        else:
            avg_inference_time = total_inference_time
            inference_time_per_image = total_inference_time / total_samples

        test_metrics['avg_inference_time'] = avg_inference_time  # Average time per batch
        test_metrics['inference_time_per_image'] = inference_time_per_image  # Average time per image
        test_metrics['total_inference_time'] = total_inference_time
        test_metrics['images_per_second'] = 1.0 / inference_time_per_image if inference_time_per_image > 0 else 0

        # Save test results
        test_results_dir = os.path.join(str(output_path), 'results')
        os.makedirs(test_results_dir, exist_ok=True)

        # Store test results
        all_test_results[model_name] = {
            'metrics': test_metrics,
            'targets': test_targets,
            'predictions': test_predictions,
            'probabilities': test_probabilities
        }

        # Save results to file
        test_results_path = os.path.join(test_results_dir, f"{model_name}_test_results.json")
        test_data_dict = {
            "targets": [int(t) for t in test_targets],
            "predictions": [int(p) for p in test_predictions],
            "probabilities": [float(prob) for prob in test_probabilities],
            "metrics": {k: float(v) if isinstance(v, (int, float)) else v
                      for k, v in test_metrics.items() if k != 'confusion_matrix'}
        }

        with open(test_results_path, 'w') as f_out:
            json.dump(test_data_dict, f_out, indent=4)

        # Confusion matrix plot
        if 'confusion_matrix' in test_metrics:
            cm_array = np.array(test_metrics['confusion_matrix'])
            plt.figure(figsize=(8, 6))
            sns.heatmap(cm_array, annot=True, fmt='d', cmap='Blues',
                      xticklabels=['Normal', 'Defect'],
                      yticklabels=['Normal', 'Defect'])
            plt.title(f'Confusion Matrix - {model_name}')
            plt.ylabel('True Label')
            plt.xlabel('Predicted Label')
            plt.savefig(os.path.join(str(output_path), 'visualizations', f"{model_name}_test_confusion_matrix.png"))
            plt.close()

        # Save the final model
        model_save_path = os.path.join(str(output_path), 'models', f"{model_name}_final_model.pth")
        torch.save(best_model.state_dict(), model_save_path)
        print(f"Saved final model to {model_save_path}")

        # Print detailed metrics
        print(f"Test Results for {model_name}:")
        print(f"  Accuracy: {test_metrics['accuracy']:.4f}")
        print(f"  F1 Score: {test_metrics['f1_score']:.4f}")
        print(f"  Precision: {test_metrics['precision']:.4f}")
        print(f"  Recall: {test_metrics['recall']:.4f}")
        print(f"  ROC-AUC: {test_metrics.get('roc_auc', 'N/A')}")
        print(f"  Inference Speed: {test_metrics['images_per_second']:.2f} images/second")

        # Clean up
        del best_model, criterion_test
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

    except Exception as e:
        print(f"Error processing model {model_name}: {str(e)}")
        continue

# Select best model based on test results
print("\nSelecting best final model...")

best_final_model = None
best_final_score = -float('inf')

for model_name, result in all_test_results.items():
    metrics = result['metrics']

    # Calculate balanced score for final selection
    # Balance accuracy, F1 score, and efficiency
    accuracy = metrics['accuracy']
    f1_score_val = metrics['f1_score']
    inference_speed = metrics['images_per_second']
    model_size = models_info[model_name]['size_mb']

    # Score formula: 40% F1, 30% accuracy, 20% speed, 10% size efficiency
    # This balances model performance with practical deployment considerations
    final_score = (0.4 * f1_score_val +                         # Performance: F1 score
                  0.3 * accuracy +                              # Performance: Accuracy
                  0.2 * min(1.0, inference_speed / 100) +      # Efficiency: Speed
                  0.1 * (1.0 - min(1.0, model_size / 100)))    # Efficiency: Size

    print(f"{model_name} final score: {final_score:.4f}")

    if final_score > best_final_score:
        best_final_score = final_score
        best_final_model = model_name

# Save comprehensive results
final_comparison = {
    'models': {},
    'best_model': best_final_model,
    'training_time': {model: result['training_time'] for model, result in all_training_results.items()}
}

for model_name in models_info:
    if model_name in all_test_results:
        test_metrics = all_test_results[model_name]['metrics']
        model_size = models_info[model_name]['size_mb']
        param_count = models_info[model_name]['param_count']

        final_comparison['models'][model_name] = {
            'accuracy': test_metrics['accuracy'],
            'f1_score': test_metrics['f1_score'],
            'precision': test_metrics['precision'],
            'recall': test_metrics['recall'],
            'roc_auc': test_metrics.get('roc_auc', 0),
            'size_mb': model_size,
            'param_count': param_count,
            'inference_time_ms': test_metrics['inference_time_per_image'] * 1000,
            'images_per_second': test_metrics['images_per_second']
        }

# Save final comparison
comparison_path = os.path.join(str(output_path), 'results', 'final_model_comparison.json')
with open(comparison_path, 'w') as f:
    json.dump(final_comparison, f, indent=4)

# Final summary
print("\n=== FINAL RESULTS ===")
if best_final_model:
    print(f"Best model: {best_final_model}")
    print(f"  - Size: {models_info[best_final_model]['size_mb']:.2f} MB")
    print(f"  - Parameters: {models_info[best_final_model]['param_count']:,}")
    print(f"  - Accuracy: {all_test_results[best_final_model]['metrics']['accuracy']:.4f}")
    print(f"  - F1 Score: {all_test_results[best_final_model]['metrics']['f1_score']:.4f}")
    print(f"  - Inference: {all_test_results[best_final_model]['metrics']['images_per_second']:.2f} img/sec")

    # Create summary report file
    with open(os.path.join(str(output_path), 'best_model_summary.txt'), 'w') as f:
        f.write("=== COFFEE BEAN DEFECT DETECTION - BEST MODEL SUMMARY ===\n\n")
        f.write(f"Best Model: {best_final_model}\n")
        f.write(f"Model Size: {models_info[best_final_model]['size_mb']:.2f} MB\n")
        f.write(f"Parameters: {models_info[best_final_model]['param_count']:,}\n")
        f.write(f"Accuracy: {all_test_results[best_final_model]['metrics']['accuracy']:.4f}\n")
        f.write(f"F1 Score: {all_test_results[best_final_model]['metrics']['f1_score']:.4f}\n")
        f.write(f"Precision: {all_test_results[best_final_model]['metrics']['precision']:.4f}\n")
        f.write(f"Recall: {all_test_results[best_final_model]['metrics']['recall']:.4f}\n")
        f.write(f"Inference Time: {all_test_results[best_final_model]['metrics']['inference_time_per_image']*1000:.2f} ms/image\n")
        f.write(f"Processing Speed: {all_test_results[best_final_model]['metrics']['images_per_second']:.2f} images/second\n\n")
        f.write("=== Model Comparison ===\n")

        for model_name in models_info:
            if model_name in all_test_results:
                metrics = all_test_results[model_name]['metrics']
                f.write(f"{model_name}:\n")
                f.write(f"  - Size: {models_info[model_name]['size_mb']:.2f} MB\n")
                f.write(f"  - F1 Score: {metrics['f1_score']:.4f}\n")
                f.write(f"  - Speed: {metrics['images_per_second']:.2f} img/sec\n\n")
else:
    print("No models were successfully trained and evaluated.")

print("Training pipeline completed successfully!")

# Final cleanup
if torch.cuda.is_available():
    torch.cuda.empty_cache()
gc.collect()
print("Cleanup done.")

CUDA is available! Using NVIDIA GeForce RTX 4060 Laptop GPU
CUDA Version: 12.6
Creating train/val/test datasets...
Loading samples from Normal: D:\iate_project\original_dataset\Normal
Loading samples from Defect: D:\iate_project\original_dataset\Defect
Loaded 15120 samples for split: train

train set class distribution:
Class 0: 10121 samples (66.94%)
Class 1: 4999 samples (33.06%)
Loading samples from Normal: D:\iate_project\original_dataset\Normal
Loading samples from Defect: D:\iate_project\original_dataset\Defect
Loaded 3240 samples for split: val

val set class distribution:
Class 0: 2126 samples (65.62%)
Class 1: 1114 samples (34.38%)
Loading samples from Normal: D:\iate_project\original_dataset\Normal
Loading samples from Defect: D:\iate_project\original_dataset\Defect
Loaded 3240 samples for split: test

test set class distribution:
Class 0: 2149 samples (66.33%)
Class 1: 1091 samples (33.67%)
Dataset sizes:
Train: 15120 samples
Validation: 3240 samples
Test: 3240 samples
Initi