In [4]:
import torch, torch.nn as nn, torch.optim as optim
import torch.nn.functional as F 
from torchvision import models, datasets, transforms
from torch.utils.data import DataLoader, Subset, ConcatDataset
from sklearn.metrics import classification_report, confusion_matrix, roc_curve, auc
from sklearn.model_selection import KFold
import numpy as np, os
from pathlib import Path
import matplotlib.pyplot as plt
import mlflow
import mlflow.pytorch
import itertools # For plotting CM

In [5]:
DATASET_TYPE = "pressure"
IMG_SIZE = 224
BATCH_SIZE = 32  # small dataset (~hundreds of images); a batch of 32 is fine
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-4
EPOCHS = 15      # 15-20 is plenty for fine-tuning just the head
PATIENCE = 4     # Stop after 4 epochs with no val-accuracy improvement
K_FOLDS = 5
SEED = 42        # global seed so reruns are comparable

# Seed the global RNGs: torch's generator drives DataLoader shuffling, the
# torchvision random transforms and nn.Linear head initialisation.
# NOTE(review): cuDNN kernels may still be nondeterministic on GPU.
torch.manual_seed(SEED)
np.random.seed(SEED)

In [6]:
# --- Data augmentation (kept light) ---
# ImageNet channel statistics, matching the IMAGENET1K_V1 pretrained weights.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: resize + centre crop, then mild flip / colour / rotation jitter.
tfm_train = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(IMG_SIZE),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
        transforms.RandomAffine(degrees=5),  # small rotation only
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
)

# Evaluation pipeline: deterministic resize + centre crop with the same normalisation.
tfm_val = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(IMG_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
)



In [7]:
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU.
_use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if _use_cuda else torch.device("cpu")
print(f"Device: {device}")

Device: cuda


In [8]:
# %%
# --- Load datasets & combine train+val for K-fold ---
# NOTE(review): paths resolve against the parent of the current working
# directory, so this assumes the notebook is launched from a repo sub-folder.
PROJECT_ROOT = Path.cwd().parent
DATA_ROOT = PROJECT_ROOT / "data" / "raw" / DATASET_TYPE
TRAIN_DIR = DATA_ROOT / "train"
VAL_DIR   = DATA_ROOT / "val"   # merged with train into the K-fold pool below
TEST_DIR  = DATA_ROOT / "test"  # final hold-out: never used for training or selection

print(f"Using dataset: {DATASET_TYPE}")
print(f"Train dir: {TRAIN_DIR}")
print(f"Val dir:   {VAL_DIR}")
print(f"Test dir (Final Holdout): {TEST_DIR}")

# 1. Load train and val with the TRAINING transform: any image in this pool
#    can land in a K-fold training split.
train_ds_part = datasets.ImageFolder(TRAIN_DIR, transform=tfm_train)
val_ds_part   = datasets.ImageFolder(VAL_DIR,   transform=tfm_train)

# Each ImageFolder builds its own class->index map from its sub-directories,
# and ConcatDataset keeps those labels as-is. The splits must therefore agree,
# or the same integer label would mean different classes.
if val_ds_part.class_to_idx != train_ds_part.class_to_idx:
    raise ValueError(
        f"Class folders differ between train {train_ds_part.class_to_idx} "
        f"and val {val_ds_part.class_to_idx}"
    )

# 2. Combine them into one pool for K-fold.
combined_train_ds = ConcatDataset([train_ds_part, val_ds_part])
num_classes = len(train_ds_part.classes)

# Keep the test set separate, with the deterministic evaluation transform.
test_ds = datasets.ImageFolder(TEST_DIR, transform=tfm_val)
if test_ds.class_to_idx != train_ds_part.class_to_idx:
    raise ValueError(
        f"Class folders differ between train {train_ds_part.class_to_idx} "
        f"and test {test_ds.class_to_idx}"
    )

print(f"Classes: {train_ds_part.classes}")
print(f"Total images for K-Fold (Train+Val): {len(combined_train_ds)}")
print(f"Total test images (hold-out): {len(test_ds)}")

# --- Check class balance on the combined pool ---
# minlength makes a class with zero images appear as 0 instead of vanishing.
combined_targets = np.concatenate([train_ds_part.targets, val_ds_part.targets])
counts = np.bincount(combined_targets, minlength=num_classes)
print(f"Combined training counts: {dict(zip(train_ds_part.classes, counts.tolist()))}")

MAX_CLASS_RATIO = 1.1  # largest/smallest class count still treated as "balanced"
if counts.min() > 0 and counts.max() / counts.min() <= MAX_CLASS_RATIO:
    print("Dataset appears balanced. No class weights needed.")
else:
    print("WARNING: classes are imbalanced; consider class weights in CrossEntropyLoss.")

Using dataset: pressure
Train dir: c:\Users\Sai\Desktop\Tyre health POC\tyre-health-poc\data\raw\pressure\train
Val dir:   c:\Users\Sai\Desktop\Tyre health POC\tyre-health-poc\data\raw\pressure\val
Test dir (Final Holdout): c:\Users\Sai\Desktop\Tyre health POC\tyre-health-poc\data\raw\pressure\test
Classes: ['flat', 'full']
Total images for K-Fold (Train+Val): 540
Total test images (hold-out): 60
Combined training counts: {'flat': np.int64(270), 'full': np.int64(270)}
Dataset appears balanced. No class weights needed.


In [9]:
def create_model():
    """Build a fresh ResNet-18 with a frozen ImageNet backbone and a new linear head.

    Every pretrained parameter has ``requires_grad`` switched off. The freshly
    created ``fc`` replacement (sized for the global ``num_classes``) is the only
    trainable part. The model is moved to the global ``device`` before returning.

    NOTE(review): ``requires_grad=False`` does not stop BatchNorm running
    statistics from being updated while the model is in ``train()`` mode.
    """
    net = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

    # Freeze the whole pretrained network. The head is swapped in afterwards,
    # so its new parameters keep the default requires_grad=True.
    net.requires_grad_(False)

    head_in = net.fc.in_features
    net.fc = nn.Linear(head_in, num_classes)

    return net.to(device)


In [10]:
def evaluate(model, loader, crit, device):
    """Run ``model`` over ``loader`` without gradients and collect metrics.

    Args:
        model: classifier producing logits; switched to eval mode here.
        loader: iterable of ``(inputs, labels)`` batches.
        crit: loss function applied per batch. It is assumed to return the batch
            mean, as ``nn.CrossEntropyLoss`` does by default.
        device: device the batches are moved to.

    Returns:
        ``(accuracy, mean_loss, y_true, y_pred)``. The last two are NumPy arrays
        of labels and argmax predictions, in loader order.
    """
    model.eval()
    n_seen = 0
    n_correct = 0
    weighted_loss = 0.0
    labels_all = []
    preds_all = []
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            logits = model(inputs)
            batch_n = labels.size(0)
            # Re-weight the batch-mean loss by batch size to get a per-sample mean.
            weighted_loss += crit(logits, labels).item() * batch_n
            preds = logits.argmax(1)
            n_correct += (preds == labels).sum().item()
            n_seen += batch_n
            labels_all.extend(labels.cpu().numpy())
            preds_all.extend(preds.cpu().numpy())
    return n_correct / n_seen, weighted_loss / n_seen, np.array(labels_all), np.array(preds_all)

In [12]:
# --- K-FOLD CROSS-VALIDATION LOOP ---

experiment_name = "Tyre_Pressure_ResNet18"
mlflow.set_experiment(experiment_name)
print(f"MLflow Experiment: {experiment_name}")

# Folds are drawn from the combined train+val pool; the test set stays untouched.
kfold = KFold(n_splits=K_FOLDS, shuffle=True, random_state=42)

# A Subset has no transform of its own; it uses its parent dataset's transform.
# Validation folds therefore index into a parallel copy of the same folders
# built with the deterministic tfm_val. ImageFolder sorts its samples, so index
# i is the same image in both copies. The copy is built once for all folds.
combined_eval_ds = ConcatDataset([
    datasets.ImageFolder(TRAIN_DIR, transform=tfm_val),
    datasets.ImageFolder(VAL_DIR, transform=tfm_val),
])
assert len(combined_eval_ds) == len(combined_train_ds), "train/eval views out of sync"

cv_scores = []
best_models_list = []  # exactly one entry per fold: {'fold', 'model_state', 'acc'}


def _snapshot_state(net):
    """Return a detached CPU copy of ``net``'s state_dict.

    ``state_dict()`` hands back references to the live parameter/buffer tensors,
    so a shallow ``.copy()`` keeps changing as training continues. Cloning
    freezes the weights (and BatchNorm buffers) as they are right now.
    """
    return {k: v.detach().cpu().clone() for k, v in net.state_dict().items()}


def _train_one_epoch(net, loader, opt, crit, device):
    """Run one optimisation pass over ``loader``; return the per-sample mean loss."""
    # NOTE(review): train() also lets the frozen backbone's BatchNorm layers
    # update their running statistics.
    net.train()
    loss_sum, n_seen = 0.0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        opt.zero_grad()
        loss = crit(net(x), y)
        loss.backward()
        opt.step()
        loss_sum += loss.item() * y.size(0)
        n_seen += y.size(0)
    return loss_sum / n_seen


print(f"--- Starting {K_FOLDS}-Fold Cross Validation ---")

for fold, (train_idx, val_idx) in enumerate(kfold.split(combined_train_ds)):
    print(f"\n{'='*50}")
    print(f"FOLD {fold + 1}/{K_FOLDS}")
    print(f"{'='*50}")

    # Training split sees augmentation (tfm_train); validation split does not (tfm_val).
    train_subset = Subset(combined_train_ds, train_idx)
    val_subset   = Subset(combined_eval_ds, val_idx)

    train_loader = DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)
    val_loader   = DataLoader(val_subset, batch_size=BATCH_SIZE*2, shuffle=False, num_workers=2, pin_memory=True)

    # Fresh model / optimiser / scheduler per fold so folds are independent.
    model = create_model()
    opt = optim.AdamW(model.fc.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)  # head only
    crit = nn.CrossEntropyLoss()  # unweighted; see the class-count check above
    # `verbose` is deprecated on PyTorch LR schedulers, so it is not passed.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(opt, mode='max', factor=0.5, patience=3)

    best_acc = float("-inf")  # guarantees the first epoch is snapshotted, even at 0% accuracy
    best_state = None
    patience_counter = 0

    run_name = f"ResNet18_Fold{fold+1}"
    with mlflow.start_run(run_name=run_name):
        mlflow.log_param("model", "resnet18_frozen_head")
        mlflow.log_param("fold", fold + 1)
        mlflow.log_param("epochs", EPOCHS)
        mlflow.log_param("batch_size", BATCH_SIZE)
        mlflow.log_param("learning_rate", LEARNING_RATE)
        mlflow.log_param("weight_decay", WEIGHT_DECAY)

        for epoch in range(1, EPOCHS + 1):
            train_loss = _train_one_epoch(model, train_loader, opt, crit, device)
            acc, vloss = evaluate(model, val_loader, crit, device)[:2]

            print(f"Epoch {epoch:2d}/{EPOCHS}: train_loss={train_loss:.4f} val_loss={vloss:.4f} val_acc={acc:.4f}")

            mlflow.log_metric("train_loss", train_loss, step=epoch)
            mlflow.log_metric("val_loss", vloss, step=epoch)
            mlflow.log_metric("val_acc", acc, step=epoch)

            scheduler.step(acc)

            if acc > best_acc:
                best_acc = acc
                best_state = _snapshot_state(model)  # deep copy, not a live view
                patience_counter = 0
                print(f"  🎉 New best! (acc={acc:.4f})")
            else:
                patience_counter += 1
                if patience_counter >= PATIENCE:
                    print(f"  ⏹️  Early stopping at epoch {epoch}")
                    break

        best_models_list.append({'fold': fold + 1, 'model_state': best_state, 'acc': best_acc})
        cv_scores.append(best_acc)
        mlflow.log_metric("best_val_acc", best_acc)
        print(f"Fold {fold + 1} completed. Best accuracy: {best_acc:.4f}")

if not best_models_list:
    raise RuntimeError("No cross-validation folds completed; nothing to summarise.")

print(f"\n{'='*50}")
print("CROSS-VALIDATION RESULTS")
print(f"{'='*50}")
print(f"Individual fold scores: {[f'{score:.4f}' for score in cv_scores]}")
print(f"Mean CV Score: {np.mean(cv_scores):.4f} ± {np.std(cv_scores):.4f}")
print(f"Best single fold: {max(cv_scores):.4f}")

# NOTE(review): each fold is scored on a different validation split, so taking
# the max across folds is an optimistic choice. The held-out test set below is
# the unbiased estimate.
best_overall_model_info = max(best_models_list, key=lambda info: info['acc'])
print(f"Best model: Fold {best_overall_model_info['fold']} with accuracy {best_overall_model_info['acc']:.4f}")

# Rebuild the architecture and load the best snapshot. load_state_dict copies
# the CPU tensors onto the model's device.
model = create_model()
model.load_state_dict(best_overall_model_info['model_state'])
print(f"\n✅ Cross-validation complete!")



MLflow Experiment: Tyre_Pressure_ResNet18
--- Starting 5-Fold Cross Validation ---

FOLD 1/5




Epoch  1/15: train_loss=0.7326 val_loss=1.0906 val_acc=0.4352
  🎉 New best! (acc=0.4352)
Epoch  2/15: train_loss=0.6492 val_loss=1.2827 val_acc=0.4352
Epoch  3/15: train_loss=0.5753 val_loss=0.6945 val_acc=0.5370
  🎉 New best! (acc=0.5370)
Epoch  4/15: train_loss=0.5072 val_loss=0.5233 val_acc=0.7963
  🎉 New best! (acc=0.7963)
Epoch  5/15: train_loss=0.4799 val_loss=0.5239 val_acc=0.7130
Epoch  6/15: train_loss=0.4658 val_loss=0.4393 val_acc=0.8333
  🎉 New best! (acc=0.8333)
Epoch  7/15: train_loss=0.4207 val_loss=0.4221 val_acc=0.8611
  🎉 New best! (acc=0.8611)
Epoch  8/15: train_loss=0.4345 val_loss=0.4004 val_acc=0.8519
Epoch  9/15: train_loss=0.3905 val_loss=0.4075 val_acc=0.8241
Epoch 10/15: train_loss=0.3750 val_loss=0.3856 val_acc=0.8333
Epoch 11/15: train_loss=0.3514 val_loss=0.3659 val_acc=0.8704
  🎉 New best! (acc=0.8704)
Epoch 12/15: train_loss=0.3287 val_loss=0.3479 val_acc=0.8426
Epoch 13/15: train_loss=0.3489 val_loss=0.3350 val_acc=0.8704
Epoch 14/15: train_loss=0.3113 v

KeyboardInterrupt: 

In [None]:
# --- FINAL EVALUATION ON HELD-OUT TEST SET ---
print("--- Final Evaluation on Test Set ---")

test_dl = DataLoader(test_ds, batch_size=BATCH_SIZE*2, shuffle=False, num_workers=2, pin_memory=True)

# Build the loss here rather than reusing `crit` leaked from the K-fold loop,
# so this cell has no hidden dependency on loop-local state.
test_crit = nn.CrossEntropyLoss()

# `model` holds the best fold's weights, loaded at the end of the K-fold cell.
test_acc, test_loss, y_true_test, y_pred_test = evaluate(model, test_dl, test_crit, device)

# Fix the label set explicitly. Otherwise a class that never appears in the
# predictions/labels shrinks the confusion matrix and makes
# classification_report reject the target_names list.
class_labels = list(range(len(test_ds.classes)))

print(f"Test Accuracy: {test_acc:.4f}")
print(f"Test Loss: {test_loss:.4f}")
print("Test Confusion Matrix:\n", confusion_matrix(y_true_test, y_pred_test, labels=class_labels))
print("\nTest Classification Report:\n", classification_report(
    y_true_test, y_pred_test, labels=class_labels, target_names=test_ds.classes, zero_division=0
))


In [None]:
# Bundle the selected model's weights with the metadata needed to reuse or report it.
checkpoint = {
    "model": model.state_dict(),
    "classes": test_ds.classes,
    "cv_scores": cv_scores,
    "test_acc": test_acc,
}
save_path = f"best_{DATASET_TYPE}_resnet18.pt"
torch.save(checkpoint, save_path)
print(f"\nModel saved as: {save_path}")


In [None]:
def evaluate_for_roc(model, loader, device, positive_class=1):
    """Collect true labels and positive-class probabilities for ROC analysis.

    Args:
        model: classifier returning logits of shape (batch, num_classes);
            switched to eval mode here.
        loader: iterable of ``(inputs, labels)`` batches.
        device: device the inputs are moved to.
        positive_class: softmax column used as the "positive" score. The
            default of 1 keeps the original behaviour (second class index).

    Returns:
        ``(y_true, y_proba)`` NumPy arrays in loader order.
    """
    model.eval()
    y_true_all, y_proba_all = [], []
    with torch.no_grad():
        for x, y in loader:
            logits = model(x.to(device))
            probas = F.softmax(logits, dim=1)
            # Labels are only collected, never used on-device, so skip the transfer.
            y_true_all.extend(y.cpu().numpy())
            y_proba_all.extend(probas[:, positive_class].cpu().numpy())
    return np.array(y_true_all), np.array(y_proba_all)

# This ROC/AUC is binary-only: evaluate_for_roc scores softmax column 1 (by
# default), and that column is treated as the positive label here.
if len(test_ds.classes) != 2:
    raise ValueError(f"ROC analysis assumes exactly 2 classes, got {test_ds.classes}")

print("\nGenerating ROC-AUC plot on test set...")
print(f"Positive class: '{test_ds.classes[1]}'")
y_true_roc, y_proba_roc = evaluate_for_roc(model, test_dl, device)

fpr, tpr, thresholds = roc_curve(y_true_roc, y_proba_roc, pos_label=1)
roc_auc = auc(fpr, tpr)
print(f"Test Set AUC Score: {roc_auc:.4f}")


In [None]:
# Test-set ROC curve with the chance diagonal for reference.
fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(fpr, tpr, color='darkorange', lw=2,
        label=f'ROC curve (area = {roc_auc:.2f})')
ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')  # random-classifier baseline
ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.set_xlabel('False Positive Rate (FPR)')
ax.set_ylabel('True Positive Rate (TPR)')
ax.set_title(f'ROC Curve - {DATASET_TYPE} (Test Set)')
ax.legend(loc="lower right")
ax.grid(True)

plot_filename = f"roc_auc_plot_{DATASET_TYPE}_resnet18.png"
fig.savefig(plot_filename)
print(f"Saved ROC plot to {plot_filename}")
plt.show()


In [None]:
# One-screen recap of cross-validation and held-out test metrics.
banner = "=" * 60
summary_lines = [
    f"\n{banner}",
    "FINAL SUMMARY",
    banner,
    f"Cross-Validation Results: {np.mean(cv_scores):.4f} ± {np.std(cv_scores):.4f}",
    f"Test Set Accuracy: {test_acc:.4f}",
    f"Test Set AUC: {roc_auc:.4f}",
    f"Best Model: Fold {best_overall_model_info['fold']} (CV acc: {best_overall_model_info['acc']:.4f})",
    banner,
]
for line in summary_lines:
    print(line)