# Deep Learning CNN Training for Coffee Bean Defect Detection
Trains multiple CNN architectures with transfer learning and advanced augmentation


In [1]:
import numpy as np
import json
import pickle
import time
from pathlib import Path
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.models import (
    EfficientNet_B0_Weights,
    MobileNet_V3_Small_Weights,
    ResNet50_Weights,
)
from PIL import Image
from PIL import ImageFile
from tqdm.auto import tqdm
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix, roc_auc_score, f1_score

ImageFile.LOAD_TRUNCATED_IMAGES = True

# Set random seeds for reproducibility
RANDOM_SEED = 42
torch.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

In [2]:
# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


In [3]:
# Set project paths
PROJECT_ROOT = Path('/home/tony/research_project/iate_project')
SPLITS_DIR = PROJECT_ROOT / 'data' / 'splits'
RESULTS_DIR = PROJECT_ROOT / 'results'
MODELS_DIR = RESULTS_DIR / 'models'
METRICS_DIR = RESULTS_DIR / 'metrics'

In [4]:
# Create directories
MODELS_DIR.mkdir(parents=True, exist_ok=True)
METRICS_DIR.mkdir(parents=True, exist_ok=True)

In [5]:
# Hyperparameters
BATCH_SIZE = 32
NUM_EPOCHS = 20
LEARNING_RATE = 1e-4
EARLY_STOPPING_PATIENCE = 3
NUM_WORKERS = 0
IMAGE_SIZE = 224

# 1. LOADING DATA SPLITS

In [6]:
with open(SPLITS_DIR / 'splits.pkl', 'rb') as f:
    splits = pickle.load(f)

train_paths = splits['train_paths']
train_labels = splits['train_labels']
val_paths = splits['val_paths']
val_labels = splits['val_labels']
test_paths = splits['test_paths']
test_labels = splits['test_labels']

print(f"Train: {len(train_paths)} images")
print(f"Validation: {len(val_paths)} images")
print(f"Test: {len(test_paths)} images")

Train: 3780 images
Validation: 810 images
Test: 810 images


# 2. DATASET & AUGMENTATION

In [7]:
class CoffeeBeanDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        label = self.labels[idx]

        # Load image
        image = Image.open(image_path).convert('RGB')

        # Apply transforms
        if self.transform:
            image = self.transform(image)

        return image, label

In [8]:
# Training augmentation - extensive augmentation
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(IMAGE_SIZE, scale=(0.9, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.2),
    transforms.RandomRotation(degrees=15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Validation/Test augmentation - only essential preprocessing
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# 3. CREATING DATALOADERS

In [9]:
train_dataset = CoffeeBeanDataset(train_paths, train_labels, transform=train_transform)
val_dataset = CoffeeBeanDataset(val_paths, val_labels, transform=val_transform)
test_dataset = CoffeeBeanDataset(test_paths, test_labels, transform=val_transform)

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=torch.cuda.is_available()
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=torch.cuda.is_available()
)

test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=torch.cuda.is_available()
)

print(f"Batch size: {BATCH_SIZE}")
print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")
print(f"Test batches: {len(test_loader)}")

Batch size: 32
Train batches: 119
Val batches: 26
Test batches: 26


# 4. SETTING UP CNN MODELS

In [10]:
def create_model(model_name='efficientnet_b0', use_pretrained=True):
    """Create a model with the modern 'weights' API (no deprecation warnings)."""

    if model_name == 'efficientnet_b0':
        weights = EfficientNet_B0_Weights.DEFAULT if use_pretrained else None
        model = models.efficientnet_b0(weights=weights)
        num_features = model.classifier[1].in_features
        model.classifier = nn.Sequential(
            nn.Dropout(p=0.2),
            nn.Linear(num_features, 256),
            nn.ReLU(),
            nn.Dropout(p=0.2),
            nn.Linear(256, 2)
        )

    elif model_name == 'mobilenet_v3':
        weights = MobileNet_V3_Small_Weights.DEFAULT if use_pretrained else None
        model = models.mobilenet_v3_small(weights=weights)
        num_features = model.classifier[0].in_features
        model.classifier = nn.Sequential(
            nn.Linear(num_features, 256),
            nn.Hardswish(),
            nn.Dropout(p=0.2),
            nn.Linear(256, 2)
        )

    elif model_name == 'resnet50':
        weights = ResNet50_Weights.DEFAULT if use_pretrained else None
        model = models.resnet50(weights=weights)
        num_features = model.fc.in_features
        model.fc = nn.Sequential(
            nn.Dropout(p=0.2),
            nn.Linear(num_features, 256),
            nn.ReLU(),
            nn.Dropout(p=0.2),
            nn.Linear(256, 2)
        )
    else:
        raise ValueError(f"Model {model_name} not supported")

    return model

# Training function
def train_model(model, train_loader, val_loader, model_name, num_epochs=NUM_EPOCHS):
    """Train a model with early stopping"""

    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                      factor=0.5, patience=3)

    best_val_loss = float('inf')
    best_val_acc = 0
    patience_counter = 0
    train_history = {'loss': [], 'acc': []}
    val_history = {'loss': [], 'acc': []}

    print(f"\nTraining {model_name}...")
    print("-" * 40)

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0

        train_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Train]')
        for images, labels in train_bar:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()

            train_bar.set_postfix({'loss': loss.item(),
                                  'acc': train_correct/train_total})

        avg_train_loss = train_loss / len(train_loader)
        train_acc = train_correct / train_total
        train_history['loss'].append(avg_train_loss)
        train_history['acc'].append(train_acc)

        # Validation phase
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        all_val_labels = []
        all_val_preds = []
        all_val_probs = []

        with torch.no_grad():
            val_bar = tqdm(val_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Val]')
            for images, labels in val_bar:
                images, labels = images.to(device), labels.to(device)

                outputs = model(images)
                loss = criterion(outputs, labels)

                val_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()

                # Store predictions for metrics
                probs = torch.softmax(outputs, dim=1)
                all_val_labels.extend(labels.cpu().numpy())
                all_val_preds.extend(predicted.cpu().numpy())
                all_val_probs.extend(probs[:, 1].cpu().numpy())

                val_bar.set_postfix({'loss': loss.item(),
                                    'acc': val_correct/val_total})

        avg_val_loss = val_loss / len(val_loader)
        val_acc = val_correct / val_total
        val_history['loss'].append(avg_val_loss)
        val_history['acc'].append(val_acc)

        # Calculate metrics
        precision, recall, f1, _ = precision_recall_fscore_support(
            all_val_labels, all_val_preds, average='binary', zero_division=0
        )

        print(f"Epoch {epoch+1}/{num_epochs}:")
        print(f"  Train Loss: {avg_train_loss:.4f}, Train Acc: {train_acc:.4f}")
        print(f"  Val Loss: {avg_val_loss:.4f}, Val Acc: {val_acc:.4f}")
        print(f"  Val Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")

        # Learning rate scheduling
        scheduler.step(avg_val_loss)

        # Early stopping check
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            best_val_acc = val_acc
            patience_counter = 0

            # Save best model
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': avg_val_loss,
                'val_acc': val_acc
            }, MODELS_DIR / f'cnn_{model_name}_best.pth')
        else:
            patience_counter += 1
            if patience_counter >= EARLY_STOPPING_PATIENCE:
                print(f"Early stopping triggered at epoch {epoch+1}")
                break

    return model, train_history, val_history

# 5. TRAINING CNN MODELS

In [11]:
models_to_train = ['efficientnet_b0', 'mobilenet_v3']  # Can add 'resnet50' if needed
results = {}

for model_name in models_to_train:
    print(f"\n{'='*60}")
    print(f"Model: {model_name.upper()}")
    print('='*60)

    # Create model
    model = create_model(model_name)

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    # Train model
    start_time = time.time()
    model, train_history, val_history = train_model(model, train_loader, val_loader, model_name)
    train_time = time.time() - start_time

    # Load best model
    checkpoint = torch.load(MODELS_DIR / f'cnn_{model_name}_best.pth')
    model.load_state_dict(checkpoint['model_state_dict'])

    # Evaluate on validation set
    model.eval()
    val_labels_all = []
    val_preds_all = []
    val_probs_all = []

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            probs = torch.softmax(outputs, dim=1)
            _, predicted = torch.max(outputs.data, 1)

            val_labels_all.extend(labels.cpu().numpy())
            val_preds_all.extend(predicted.cpu().numpy())
            val_probs_all.extend(probs[:, 1].cpu().numpy())

    # Calculate metrics
    val_acc = accuracy_score(val_labels_all, val_preds_all)
    val_precision, val_recall, val_f1, _ = precision_recall_fscore_support(
        val_labels_all, val_preds_all, average='binary', zero_division=0
    )
    val_auc = roc_auc_score(val_labels_all, val_probs_all)
    val_cm = confusion_matrix(val_labels_all, val_preds_all, labels=[0, 1])

    # Store results
    results[model_name] = {
        'val_accuracy': val_acc,
        'val_precision': val_precision,
        'val_recall': val_recall,
        'val_f1_score': val_f1,
        'val_auc': val_auc,
        'confusion_matrix': val_cm.tolist(),
        'train_time': train_time,
        'best_epoch': checkpoint['epoch'],
        'train_history': train_history,
        'val_history': val_history,
        'total_params': total_params,
        'trainable_params': trainable_params
    }

    print(f"\nValidation Results:")
    print(f"  Accuracy: {val_acc:.4f}")
    print(f"  Precision: {val_precision:.4f}")
    print(f"  Recall: {val_recall:.4f}")
    print(f"  F1-Score: {val_f1:.4f}")
    print(f"  AUC: {val_auc:.4f}")
    print(f"  Training time: {train_time/60:.2f} minutes")

# Select best model
best_model_name = max(results, key=lambda x: results[x]['val_f1_score'])
print(f"\nBEST MODEL SELECTION")
print("-"*40)
print(f"Best model: {best_model_name}")
print(f"Validation F1-Score: {results[best_model_name]['val_f1_score']:.4f}")


Model: EFFICIENTNET_B0
Total parameters: 4,335,998
Trainable parameters: 4,335,998

Training efficientnet_b0...
----------------------------------------


Epoch 1/20 [Train]:   0%|          | 0/119 [00:00<?, ?it/s]

Epoch 1/20 [Val]:   0%|          | 0/26 [00:00<?, ?it/s]

Epoch 1/20:
  Train Loss: 0.4239, Train Acc: 0.8071
  Val Loss: 0.2593, Val Acc: 0.8926
  Val Precision: 0.9645, Recall: 0.7037, F1: 0.8137


Epoch 2/20 [Train]:   0%|          | 0/119 [00:00<?, ?it/s]

Epoch 2/20 [Val]:   0%|          | 0/26 [00:00<?, ?it/s]

Epoch 2/20:
  Train Loss: 0.2215, Train Acc: 0.9135
  Val Loss: 0.1913, Val Acc: 0.9247
  Val Precision: 0.9485, Recall: 0.8185, F1: 0.8787


Epoch 3/20 [Train]:   0%|          | 0/119 [00:00<?, ?it/s]

Epoch 3/20 [Val]:   0%|          | 0/26 [00:00<?, ?it/s]

Epoch 3/20:
  Train Loss: 0.1779, Train Acc: 0.9310
  Val Loss: 0.1767, Val Acc: 0.9333
  Val Precision: 0.9186, Recall: 0.8778, F1: 0.8977


Epoch 4/20 [Train]:   0%|          | 0/119 [00:00<?, ?it/s]

Epoch 4/20 [Val]:   0%|          | 0/26 [00:00<?, ?it/s]

Epoch 4/20:
  Train Loss: 0.1648, Train Acc: 0.9370
  Val Loss: 0.1879, Val Acc: 0.9309
  Val Precision: 0.9820, Recall: 0.8074, F1: 0.8862


Epoch 5/20 [Train]:   0%|          | 0/119 [00:00<?, ?it/s]

Epoch 5/20 [Val]:   0%|          | 0/26 [00:00<?, ?it/s]

Epoch 5/20:
  Train Loss: 0.1423, Train Acc: 0.9421
  Val Loss: 0.1876, Val Acc: 0.9346
  Val Precision: 0.9465, Recall: 0.8519, F1: 0.8967


Epoch 6/20 [Train]:   0%|          | 0/119 [00:00<?, ?it/s]

Epoch 6/20 [Val]:   0%|          | 0/26 [00:00<?, ?it/s]

Epoch 6/20:
  Train Loss: 0.1244, Train Acc: 0.9534
  Val Loss: 0.1879, Val Acc: 0.9395
  Val Precision: 0.9585, Recall: 0.8556, F1: 0.9041
Early stopping triggered at epoch 6

Validation Results:
  Accuracy: 0.9333
  Precision: 0.9186
  Recall: 0.8778
  F1-Score: 0.8977
  AUC: 0.9776
  Training time: 1.83 minutes

Model: MOBILENET_V3
Total parameters: 1,075,234
Trainable parameters: 1,075,234

Training mobilenet_v3...
----------------------------------------


Epoch 1/20 [Train]:   0%|          | 0/119 [00:00<?, ?it/s]

Epoch 1/20 [Val]:   0%|          | 0/26 [00:00<?, ?it/s]

Epoch 1/20:
  Train Loss: 0.4496, Train Acc: 0.7889
  Val Loss: 1.1963, Val Acc: 0.6926
  Val Precision: 1.0000, Recall: 0.0778, F1: 0.1443


Epoch 2/20 [Train]:   0%|          | 0/119 [00:00<?, ?it/s]

Epoch 2/20 [Val]:   0%|          | 0/26 [00:00<?, ?it/s]

Epoch 2/20:
  Train Loss: 0.2548, Train Acc: 0.8981
  Val Loss: 0.7474, Val Acc: 0.7494
  Val Precision: 1.0000, Recall: 0.2481, F1: 0.3976


Epoch 3/20 [Train]:   0%|          | 0/119 [00:00<?, ?it/s]

Epoch 3/20 [Val]:   0%|          | 0/26 [00:00<?, ?it/s]

Epoch 3/20:
  Train Loss: 0.2085, Train Acc: 0.9148
  Val Loss: 0.3340, Val Acc: 0.8765
  Val Precision: 0.9885, Recall: 0.6370, F1: 0.7748


Epoch 4/20 [Train]:   0%|          | 0/119 [00:00<?, ?it/s]

Epoch 4/20 [Val]:   0%|          | 0/26 [00:00<?, ?it/s]

Epoch 4/20:
  Train Loss: 0.1873, Train Acc: 0.9249
  Val Loss: 0.3474, Val Acc: 0.8617
  Val Precision: 0.9938, Recall: 0.5889, F1: 0.7395


Epoch 5/20 [Train]:   0%|          | 0/119 [00:00<?, ?it/s]

Epoch 5/20 [Val]:   0%|          | 0/26 [00:00<?, ?it/s]

Epoch 5/20:
  Train Loss: 0.1800, Train Acc: 0.9310
  Val Loss: 0.2055, Val Acc: 0.9173
  Val Precision: 0.9319, Recall: 0.8111, F1: 0.8673


Epoch 6/20 [Train]:   0%|          | 0/119 [00:00<?, ?it/s]

Epoch 6/20 [Val]:   0%|          | 0/26 [00:00<?, ?it/s]

Epoch 6/20:
  Train Loss: 0.1535, Train Acc: 0.9431
  Val Loss: 0.2293, Val Acc: 0.9049
  Val Precision: 0.9573, Recall: 0.7481, F1: 0.8399


Epoch 7/20 [Train]:   0%|          | 0/119 [00:00<?, ?it/s]

Epoch 7/20 [Val]:   0%|          | 0/26 [00:00<?, ?it/s]

Epoch 7/20:
  Train Loss: 0.1472, Train Acc: 0.9439
  Val Loss: 0.2859, Val Acc: 0.8951
  Val Precision: 0.9894, Recall: 0.6926, F1: 0.8148


Epoch 8/20 [Train]:   0%|          | 0/119 [00:00<?, ?it/s]

Epoch 8/20 [Val]:   0%|          | 0/26 [00:00<?, ?it/s]

Epoch 8/20:
  Train Loss: 0.1490, Train Acc: 0.9386
  Val Loss: 0.1956, Val Acc: 0.9346
  Val Precision: 0.9617, Recall: 0.8370, F1: 0.8950


Epoch 9/20 [Train]:   0%|          | 0/119 [00:00<?, ?it/s]

Epoch 9/20 [Val]:   0%|          | 0/26 [00:00<?, ?it/s]

Epoch 9/20:
  Train Loss: 0.1288, Train Acc: 0.9524
  Val Loss: 0.2145, Val Acc: 0.9210
  Val Precision: 0.9813, Recall: 0.7778, F1: 0.8678


Epoch 10/20 [Train]:   0%|          | 0/119 [00:00<?, ?it/s]

Epoch 10/20 [Val]:   0%|          | 0/26 [00:00<?, ?it/s]

Epoch 10/20:
  Train Loss: 0.1364, Train Acc: 0.9497
  Val Loss: 0.1703, Val Acc: 0.9444
  Val Precision: 0.9412, Recall: 0.8889, F1: 0.9143


Epoch 11/20 [Train]:   0%|          | 0/119 [00:00<?, ?it/s]

Epoch 11/20 [Val]:   0%|          | 0/26 [00:00<?, ?it/s]

Epoch 11/20:
  Train Loss: 0.1167, Train Acc: 0.9558
  Val Loss: 0.2437, Val Acc: 0.9259
  Val Precision: 0.9817, Recall: 0.7926, F1: 0.8770


Epoch 12/20 [Train]:   0%|          | 0/119 [00:00<?, ?it/s]

Epoch 12/20 [Val]:   0%|          | 0/26 [00:00<?, ?it/s]

Epoch 12/20:
  Train Loss: 0.1198, Train Acc: 0.9556
  Val Loss: 0.2319, Val Acc: 0.9222
  Val Precision: 0.9770, Recall: 0.7852, F1: 0.8706


Epoch 13/20 [Train]:   0%|          | 0/119 [00:00<?, ?it/s]

Epoch 13/20 [Val]:   0%|          | 0/26 [00:00<?, ?it/s]

Epoch 13/20:
  Train Loss: 0.1183, Train Acc: 0.9513
  Val Loss: 0.1576, Val Acc: 0.9457
  Val Precision: 0.9520, Recall: 0.8815, F1: 0.9154


Epoch 14/20 [Train]:   0%|          | 0/119 [00:00<?, ?it/s]

Epoch 14/20 [Val]:   0%|          | 0/26 [00:00<?, ?it/s]

Epoch 14/20:
  Train Loss: 0.1148, Train Acc: 0.9569
  Val Loss: 0.1651, Val Acc: 0.9395
  Val Precision: 0.9139, Recall: 0.9037, F1: 0.9088


Epoch 15/20 [Train]:   0%|          | 0/119 [00:00<?, ?it/s]

Epoch 15/20 [Val]:   0%|          | 0/26 [00:00<?, ?it/s]

Epoch 15/20:
  Train Loss: 0.1046, Train Acc: 0.9587
  Val Loss: 0.2672, Val Acc: 0.9198
  Val Precision: 0.9952, Recall: 0.7630, F1: 0.8637


Epoch 16/20 [Train]:   0%|          | 0/119 [00:00<?, ?it/s]

Epoch 16/20 [Val]:   0%|          | 0/26 [00:00<?, ?it/s]

Epoch 16/20:
  Train Loss: 0.1273, Train Acc: 0.9624
  Val Loss: 0.1572, Val Acc: 0.9457
  Val Precision: 0.9280, Recall: 0.9074, F1: 0.9176


Epoch 17/20 [Train]:   0%|          | 0/119 [00:00<?, ?it/s]

Epoch 17/20 [Val]:   0%|          | 0/26 [00:00<?, ?it/s]

Epoch 17/20:
  Train Loss: 0.1018, Train Acc: 0.9606
  Val Loss: 0.2627, Val Acc: 0.9086
  Val Precision: 1.0000, Recall: 0.7259, F1: 0.8412


Epoch 18/20 [Train]:   0%|          | 0/119 [00:00<?, ?it/s]

Epoch 18/20 [Val]:   0%|          | 0/26 [00:00<?, ?it/s]

Epoch 18/20:
  Train Loss: 0.0887, Train Acc: 0.9680
  Val Loss: 0.1445, Val Acc: 0.9481
  Val Precision: 0.9161, Recall: 0.9296, F1: 0.9228


Epoch 19/20 [Train]:   0%|          | 0/119 [00:00<?, ?it/s]

Epoch 19/20 [Val]:   0%|          | 0/26 [00:00<?, ?it/s]

Epoch 19/20:
  Train Loss: 0.0800, Train Acc: 0.9680
  Val Loss: 0.1688, Val Acc: 0.9432
  Val Precision: 0.9628, Recall: 0.8630, F1: 0.9102


Epoch 20/20 [Train]:   0%|          | 0/119 [00:00<?, ?it/s]

Epoch 20/20 [Val]:   0%|          | 0/26 [00:00<?, ?it/s]

Epoch 20/20:
  Train Loss: 0.0829, Train Acc: 0.9696
  Val Loss: 0.1328, Val Acc: 0.9519
  Val Precision: 0.9639, Recall: 0.8889, F1: 0.9249

Validation Results:
  Accuracy: 0.9519
  Precision: 0.9639
  Recall: 0.8889
  F1-Score: 0.9249
  AUC: 0.9865
  Training time: 6.13 minutes

BEST MODEL SELECTION
----------------------------------------
Best model: mobilenet_v3
Validation F1-Score: 0.9249


# 6. EVALUATING ON TEST SET

In [12]:
# Load best model
best_model = create_model(best_model_name)
checkpoint = torch.load(MODELS_DIR / f'cnn_{best_model_name}_best.pth')
best_model.load_state_dict(checkpoint['model_state_dict'])
best_model = best_model.to(device)
best_model.eval()

test_labels_all = []
test_preds_all = []
test_probs_all = []

with torch.no_grad():
    test_bar = tqdm(test_loader, desc='Testing')
    for images, labels in test_bar:
        images, labels = images.to(device), labels.to(device)
        outputs = best_model(images)
        probs = torch.softmax(outputs, dim=1)
        _, predicted = torch.max(outputs.data, 1)

        test_labels_all.extend(labels.cpu().numpy())
        test_preds_all.extend(predicted.cpu().numpy())
        test_probs_all.extend(probs[:, 1].cpu().numpy())

# Calculate test metrics
test_acc = accuracy_score(test_labels_all, test_preds_all)
test_precision, test_recall, test_f1, _ = precision_recall_fscore_support(
    test_labels_all, test_preds_all, average='binary', zero_division=0
)
test_auc = roc_auc_score(test_labels_all, test_probs_all)
test_cm = confusion_matrix(test_labels_all, test_preds_all, labels=[0, 1])

print(f"\nTest Set Results ({best_model_name}):")
print(f"  Accuracy: {test_acc:.4f}")
print(f"  Precision: {test_precision:.4f}")
print(f"  Recall: {test_recall:.4f}")
print(f"  F1-Score: {test_f1:.4f}")
print(f"  AUC: {test_auc:.4f}")

print(f"\nConfusion Matrix:")
print(f"  TN: {test_cm[0,0]}, FP: {test_cm[0,1]}")
print(f"  FN: {test_cm[1,0]}, TP: {test_cm[1,1]}")

Testing:   0%|          | 0/26 [00:00<?, ?it/s]


Test Set Results (mobilenet_v3):
  Accuracy: 0.9395
  Precision: 0.9702
  Recall: 0.8444
  F1-Score: 0.9030
  AUC: 0.9822

Confusion Matrix:
  TN: 533, FP: 7
  FN: 42, TP: 228


In [13]:
# Save all results
all_results = {
    'models': results,
    'best_model': best_model_name,
    'test_results': {
        'model': best_model_name,
        'accuracy': test_acc,
        'precision': test_precision,
        'recall': test_recall,
        'f1_score': test_f1,
        'auc': test_auc,
        'confusion_matrix': test_cm.tolist()
    },
    'hyperparameters': {
        'batch_size': BATCH_SIZE,
        'learning_rate': LEARNING_RATE,
        'num_epochs': NUM_EPOCHS,
        'early_stopping_patience': EARLY_STOPPING_PATIENCE,
        'image_size': IMAGE_SIZE
    },
    'augmentation': {
        'train': 'RandomResizedCrop, Flips, Rotation, ColorJitter, Perspective, Grayscale, Blur',
        'val_test': 'Resize, CenterCrop'
    }
}

with open(METRICS_DIR / 'cnn_results.json', 'w') as f:
    json.dump(all_results, f, indent=2)

print(f"\nResults saved to: {METRICS_DIR / 'cnn_results.json'}")


Results saved to: /home/tony/research_project/iate_project/results/metrics/cnn_results.json


In [14]:
# Save final model for deployment
torch.save({
    'model_name': best_model_name,
    'model_state_dict': best_model.state_dict(),
    'test_accuracy': test_acc,
    'test_f1_score': test_f1
}, MODELS_DIR / 'final_model.pth')

print(f"Final model saved to: {MODELS_DIR / 'final_model.pth'}")

Final model saved to: /home/tony/research_project/iate_project/results/models/final_model.pth


In [15]:
print(f"\nBest Model: {best_model_name}")
print(f"Test F1-Score: {test_f1:.4f}")
print(f"Test Accuracy: {test_acc:.4f}")


Best Model: mobilenet_v3
Test F1-Score: 0.9030
Test Accuracy: 0.9395
