In [None]:
import torch, torch.nn as nn, torch.optim as optim
import torch.nn.functional as F 
from torchvision import models, datasets, transforms
from torch.utils.data import DataLoader
from sklearn.metrics import classification_report, confusion_matrix, roc_curve, auc
import numpy as np, os
from pathlib import Path
import matplotlib.pyplot as plt  
import mlflow                   
import mlflow.pytorch    

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)


In [None]:
IMG_SIZE = 224
BATCH_SIZE = 32

# Both pipelines start with the same deterministic resize + centre crop and
# end with the same tensor conversion + channel normalisation; only the
# training pipeline inserts random augmentation in between.
_resize_and_crop = [
    transforms.Resize(256),
    transforms.CenterCrop(IMG_SIZE),
]
_to_normalized_tensor = [
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
]
_train_augmentation = [
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
]

# Training: resize/crop -> random flip + brightness/contrast jitter -> tensor.
tfm_train = transforms.Compose(_resize_and_crop + _train_augmentation + _to_normalized_tensor)
# Validation: same preprocessing with no randomness.
tfm_val = transforms.Compose(_resize_and_crop + _to_normalized_tensor)


In [None]:
# Name of the sub-folder under data/raw that holds this dataset.
DATASET_TYPE = "condition" 

# NOTE(review): assumes the notebook is launched from a direct sub-directory
# of the project root (e.g. notebooks/) — confirm before running elsewhere.
PROJECT_ROOT = Path.cwd().parent
DATA_ROOT = PROJECT_ROOT / "data" / "raw" / DATASET_TYPE
TRAIN_DIR = DATA_ROOT / "train"
VAL_DIR   = DATA_ROOT / "val"

print(f"Using dataset: {DATASET_TYPE}")
print(f"Train dir: {TRAIN_DIR}")
print(f"Val dir:   {VAL_DIR}")

# One sub-folder per class; each split builds its own class list.
train_ds = datasets.ImageFolder(TRAIN_DIR, transform=tfm_train)
val_ds   = datasets.ImageFolder(VAL_DIR,   transform=tfm_val)

# Label ids are assigned per split from that split's own folder names, so a
# missing or extra class folder in val/ would silently shift the ids and make
# every validation metric meaningless. Fail loudly instead.
if val_ds.classes != train_ds.classes:
    raise ValueError(
        f"Class folders differ between splits: train={train_ds.classes}, val={val_ds.classes}"
    )

num_classes = len(train_ds.classes)

print("Classes:", train_ds.classes)


In [None]:

# Get the count for each class in the *training* set.
# minlength guarantees exactly one slot per class in the result.
counts = np.bincount(train_ds.targets, minlength=num_classes)
print(f"Training counts: {dict(zip(train_ds.classes, counts))}")

# Weights inversely proportional to class frequency:
#   weight_c = n_samples / (n_classes * count_c)
# Computed for every class at once so the tensor always has num_classes
# entries (what nn.CrossEntropyLoss(weight=...) requires), not just two.
class_weights = torch.tensor(
    counts.sum() / (num_classes * counts), dtype=torch.float32
).to(device)

# Label the printed weights with the real class names instead of a fixed pair.
class_list = ", ".join(repr(name) for name in train_ds.classes)
print(f"Using weights (for {class_list}): {class_weights.cpu().numpy()}")

In [None]:
# Worker / memory settings shared by both loaders.
_loader_opts = dict(num_workers=2, pin_memory=True)

# Training batches are reshuffled every epoch.
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, **_loader_opts)
# Validation runs without gradients, so it uses twice the batch size and a fixed order.
val_dl = DataLoader(val_ds, batch_size=BATCH_SIZE*2, shuffle=False, **_loader_opts)


In [None]:

# Start from a ResNet-18 initialised with the IMAGENET1K_V1 weights.
pretrained_weights = models.ResNet18_Weights.IMAGENET1K_V1
model = models.resnet18(weights=pretrained_weights)

# 1. Turn gradients off for every parameter ...
model.requires_grad_(False)

# 2. ... then back on for the final residual stage (layer4) only.
model.layer4.requires_grad_(True)

# 3. Swap the classifier head for dropout (p=0.5) + a linear layer sized for
#    this dataset. The new layers are created after the freeze, so they are
#    trainable by default.
in_feats = model.fc.in_features
new_head = nn.Sequential(
    nn.Dropout(p=0.5),
    nn.Linear(in_feats, num_classes),
)
model.fc = new_head
model = model.to(device)

trainable_count = sum(w.numel() for w in model.parameters() if w.requires_grad)
print(f"Model has {trainable_count:,} trainable parameters.")



In [None]:
LEARNING_RATE_HEAD = 5e-4
LEARNING_RATE_BODY = 1e-6

# Two parameter groups: the unfrozen layer4 stage gets a far smaller learning
# rate than the freshly initialised classifier head.
param_groups = [
    dict(params=model.layer4.parameters(), lr=LEARNING_RATE_BODY),
    dict(params=model.fc.parameters(), lr=LEARNING_RATE_HEAD),
]
opt = optim.AdamW(param_groups, weight_decay=1e-4)

# Cross-entropy with the per-class weights computed from the training counts.
crit = nn.CrossEntropyLoss(weight=class_weights)

In [None]:

def evaluate():
    """Run the current model over the validation loader.

    Reads the notebook-level ``model``, ``val_dl``, ``crit`` and ``device``.
    Leaves the model in eval mode.

    Returns:
        (accuracy, loss, y_true, y_pred): accuracy is correct / total samples;
        loss is the sum of (batch loss * batch size) divided by total samples;
        y_true and y_pred are NumPy arrays of label ids in loader order.
    """
    model.eval()
    labels_seen, labels_predicted = [], []
    n_seen = 0
    n_correct = 0
    scaled_loss = 0.0
    with torch.no_grad():
        for images, targets in val_dl:
            images = images.to(device)
            targets = targets.to(device)
            batch_n = targets.size(0)

            scores = model(images)
            # Scale the batch loss by the batch size so a smaller final batch
            # does not count as much as a full one.
            scaled_loss += crit(scores, targets).item() * batch_n

            guesses = scores.argmax(1)
            n_correct += (guesses == targets).sum().item()
            n_seen += batch_n

            labels_seen.extend(targets.cpu().numpy())
            labels_predicted.extend(guesses.cpu().numpy())

    accuracy = n_correct / n_seen
    mean_loss = scaled_loss / n_seen
    return accuracy, mean_loss, np.array(labels_seen), np.array(labels_predicted)



In [None]:
EPOCHS = 15  
best_acc = 0.0

# Validation labels/predictions of the most recent epoch. Reset here so a
# re-run (or EPOCHS = 0) can never report stale values from an earlier run.
y_true, y_pred = None, None
# Validation labels/predictions of the epoch whose weights were saved as the
# best checkpoint; stays None if no epoch ever improved on best_acc.
best_y_true, best_y_pred = None, None


def _log_eval_artifacts(labels, preds, class_names, title, suffix=""):
    """Write a classification report and confusion-matrix plot and log both to MLflow.

    Must be called while an MLflow run is active.

    Args:
        labels: true label ids.
        preds: predicted label ids, same length as ``labels``.
        class_names: class names, indexed by label id.
        title: title drawn on the confusion-matrix figure.
        suffix: appended to the artifact base names, giving
            ``classification_report{suffix}.txt`` and ``confusion_matrix{suffix}.png``.
    """
    # Pin the label set so the report and matrix always have one row per
    # class, even if a class is absent from both labels and predictions.
    label_ids = list(range(len(class_names)))

    report_path = f"classification_report{suffix}.txt"
    report_text = classification_report(labels, preds, labels=label_ids, target_names=class_names)
    with open(report_path, "w") as f:
        f.write(report_text)
    mlflow.log_artifact(report_path)

    fig, ax = plt.subplots()
    image = ax.imshow(
        confusion_matrix(labels, preds, labels=label_ids),
        interpolation='nearest',
        cmap=plt.cm.Blues,
    )
    ax.set_title(title)
    fig.colorbar(image, ax=ax)
    ax.set_xticks(label_ids)
    ax.set_xticklabels(class_names, rotation=45)
    ax.set_yticks(label_ids)
    ax.set_yticklabels(class_names)
    ax.set_ylabel('True label')
    ax.set_xlabel('Predicted label')
    # Without this the rotated x tick labels can be cut off in the saved PNG.
    fig.tight_layout()
    figure_path = f"confusion_matrix{suffix}.png"
    fig.savefig(figure_path)
    mlflow.log_artifact(figure_path)
    plt.close(fig)


mlflow.set_experiment(f"Tyre_Model_{DATASET_TYPE}")

with mlflow.start_run():

    # Record the run configuration.
    mlflow.log_param("model", "resnet18_layer4_unfrozen")
    mlflow.log_param("epochs", EPOCHS)
    mlflow.log_param("batch_size", BATCH_SIZE)
    mlflow.log_param("lr_head", LEARNING_RATE_HEAD)
    mlflow.log_param("lr_body", LEARNING_RATE_BODY)
    mlflow.log_param("class_weights", class_weights.cpu().numpy())

    print(f"--- Starting Training for {EPOCHS} epochs ---")
    for epoch in range(1, EPOCHS+1):
        print(f"------------Commencing Epoch {epoch}/{EPOCHS}------------")
        model.train()
        train_loss_sum = 0.0 
        for x,y in train_dl:
            x,y = x.to(device), y.to(device)
            opt.zero_grad()
            loss = crit(model(x), y)
            loss.backward()
            opt.step()
            # Scale by batch size so the epoch figure is per-sample.
            train_loss_sum += loss.item() * y.size(0) 
        
        # Get epoch stats
        epoch_train_loss = train_loss_sum / len(train_ds)
        acc, vloss, y_true, y_pred = evaluate()
        
        print(f"Epoch {epoch}/{EPOCHS}: train_loss={epoch_train_loss:.4f} val_loss={vloss:.4f} val_acc={acc:.4f}")
        

        mlflow.log_metric("train_loss", epoch_train_loss, step=epoch)
        mlflow.log_metric("val_loss", vloss, step=epoch)
        mlflow.log_metric("val_acc", acc, step=epoch)

        if acc > best_acc:
            best_acc = acc
            # Remember what this checkpoint predicted, so the reports logged
            # after training can describe the saved weights.
            best_y_true, best_y_pred = y_true, y_pred
            print(f"  🎉 New best model! Saving... (acc={acc:.4f})")
            save_path = f"best_{DATASET_TYPE}.pt"
            torch.save({"model": model.state_dict(), "classes": train_ds.classes}, save_path)
            
            # --- Log best model & score ---
            mlflow.pytorch.log_model(model, "best_model")
            # step=epoch records *when* each new best was reached.
            mlflow.log_metric("best_val_acc", best_acc, step=epoch)
    
    print(f"\n✅ Training complete. Best val_acc={best_acc:.4f}")
    

    # Last-epoch report + confusion matrix (same artifact names as before).
    # y_true / y_pred are deliberately left as the last epoch's values because
    # the following cells print a "from last epoch" report from them.
    if y_true is not None:
        _log_eval_artifacts(y_true, y_pred, train_ds.classes, "Confusion Matrix (Last Epoch)")

    # The saved/logged model is the *best* epoch, which may not be the last,
    # so also log the report and matrix that describe that checkpoint.
    if best_y_true is not None:
        _log_eval_artifacts(
            best_y_true, best_y_pred, train_ds.classes,
            "Confusion Matrix (Best Epoch)", suffix="_best",
        )

In [None]:
# y_true / y_pred still hold the validation labels and predictions returned by
# the *last* epoch's evaluate() call, so this report describes the final epoch.
print(f"\n✅ Training complete. Best val_acc={best_acc:.4f}")
print("--- Final Report (from last epoch) ---")

last_epoch_cm = confusion_matrix(y_true, y_pred)
print("Confusion matrix:\n", last_epoch_cm)

last_epoch_report = classification_report(y_true, y_pred, target_names=train_ds.classes)
print("\nClassification report:\n", last_epoch_report)


In [None]:
def evaluate_for_roc(positive_class=1):
    """Collect validation labels, predictions and positive-class probabilities.

    Reads the notebook-level ``model``, ``val_dl`` and ``device``. Leaves the
    model in eval mode.

    Args:
        positive_class: column of the softmax output to report as the score.
            Defaults to 1, the index this notebook has always used for the
            positive class ("worn").

    Returns:
        (y_true, y_pred, y_proba): NumPy arrays in loader order holding the
        true label ids, the argmax predictions, and the softmax probability
        of ``positive_class`` for each sample.
    """
    model.eval()
    y_true_all, y_pred_all, y_proba_all = [], [], []

    with torch.no_grad():
        for x, y in val_dl:
            x, y = x.to(device), y.to(device)

            logits = model(x)
            # Softmax over the class dimension turns logits into probabilities.
            probas = F.softmax(logits, dim=1)
            pred = logits.argmax(1)

            y_true_all.extend(y.cpu().numpy())
            y_pred_all.extend(pred.cpu().numpy())
            # Keep only the score of the class treated as "positive".
            y_proba_all.extend(probas[:, positive_class].cpu().numpy())

    return np.array(y_true_all), np.array(y_pred_all), np.array(y_proba_all)

In [None]:
print("Generating ROC-AUC plot...")

# Score the validation set with the model currently in memory, i.e. the
# weights as the last training epoch left them.
roc_inputs = evaluate_for_roc()
y_true, y_pred, y_proba = roc_inputs

# False/true positive rates at every score threshold.
roc_points = roc_curve(y_true, y_proba)
fpr, tpr, thresholds = roc_points

# Area under the ROC curve.
roc_auc = auc(fpr, tpr)
print(f"\nModel AUC Score: {roc_auc:.4f}")


In [None]:
roc_fig, roc_ax = plt.subplots(figsize=(8, 6))

# Model curve, plus the diagonal for comparison.
roc_ax.plot(fpr, tpr, color='darkorange', lw=2,
            label=f'ROC curve (area = {roc_auc:.2f})')
roc_ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')

roc_ax.set_xlim([0.0, 1.0])
roc_ax.set_ylim([0.0, 1.05])
roc_ax.set_xlabel('False Positive Rate (FPR)')
roc_ax.set_ylabel('True Positive Rate (TPR)')
roc_ax.set_title(f'Receiver Operating Characteristic (ROC) Curve - {DATASET_TYPE}')
roc_ax.legend(loc="lower right")
roc_ax.grid(True)

# Save the figure to disk, named after the dataset.
plot_filename = f"roc_auc_plot_{DATASET_TYPE}.png"
roc_fig.savefig(plot_filename)
print(f"Saved ROC plot to {plot_filename}")

# The file is not logged to MLflow here. While a run is active it could be
# logged with: mlflow.log_artifact(plot_filename)

plt.show()