# Knee Osteoarthritis Classifier — Refactored (Min 50 Epochs)
*Generated on 2025-09-23 04:26:44*

This notebook ensures training runs for **at least 50 epochs** before early stopping can trigger.


In [None]:
import os, random, time, math
import numpy as np
import torch

# ----- Device selection -----
# Prefer CUDA when PyTorch reports it as available; otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Using device:', device)

# ----- Run configuration -----
# Reproducibility
SEED = 42              # passed to set_seed (both as its default and explicitly)

# Data
# NOTE(review): BATCH_SIZE and NUM_WORKERS are not referenced in the visible cells;
# presumably they feed the DataLoaders built elsewhere — confirm.
IMG_SIZE = 224         # also stored in the checkpoint's 'config' entry
BATCH_SIZE = 32
NUM_WORKERS = 0

# Model / optimiser
# NOTE(review): not referenced in the visible cells; presumably consumed where the
# model and optimizer are created — confirm.
MODEL_NAME = 'resnet18'
LR = 3e-4
WEIGHT_DECAY = 1e-4

# Training schedule
NUM_EPOCHS = 100       # upper bound on the number of epochs
MIN_EPOCHS = 50        # early stopping cannot trigger before this epoch
PATIENCE = 5           # epochs without a new best val accuracy tolerated before stopping

def set_seed(seed=SEED):
    """Seed the random number generators used by this notebook.

    Seeds, in order: Python's ``random`` module, NumPy's global RNG, PyTorch's
    default generator, and PyTorch's generators for all CUDA devices. It then
    switches cuDNN to deterministic mode and turns its benchmark mode off.

    Args:
        seed: Value handed to every seeding call. Defaults to ``SEED``.

    Returns:
        None.
    """
    # Local imports: the function does not rely on the notebook's global names.
    import random
    import numpy as np
    import torch

    # Order matches the list in the docstring.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# Seed all generators with the configured SEED.
set_seed(SEED)


In [None]:
# ===== Training Loop with Minimum Epochs =====
# Trains for up to NUM_EPOCHS epochs. After every epoch the model is scored on the
# validation loader; whenever validation accuracy strictly improves, the weights are
# checkpointed to `ckpt_path`. Early stopping can only fire once MIN_EPOCHS epochs
# have completed, and then only if at least PATIENCE consecutive epochs have passed
# without a new best validation accuracy (epochs before MIN_EPOCHS count too).
# NOTE(review): model, train_loader, val_loader, optimizer, scaler, scheduler,
# train_one_epoch, evaluate and CLASSES are expected from cells not shown here —
# confirm they are defined before this cell runs.
best_val_acc = -1.0
# Per-epoch metrics only: every value is a list with one entry per completed epoch.
history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': [], 'lr': []}
ckpt_path = 'best_model.pt'
# Consecutive epochs without a new best val_acc. Kept in a plain variable so that
# `history` holds nothing but per-epoch lists.
patience_used = 0

for epoch in range(1, NUM_EPOCHS+1):
    # Monotonic clock: unaffected by system clock adjustments during a long epoch.
    t0 = time.perf_counter()
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, scaler)
    val_loss, val_acc, _, _ = evaluate(model, val_loader)
    # The scheduler is stepped with the validation accuracy.
    # NOTE(review): presumably a ReduceLROnPlateau in 'max' mode — confirm where it is built.
    scheduler.step(val_acc)

    # Read after scheduler.step(), so this is the rate in effect going into the next epoch.
    current_lr = optimizer.param_groups[0]['lr']

    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)
    history['lr'].append(current_lr)

    # Whole minutes and seconds; minutes are not wrapped at 60, so an epoch that
    # takes longer than an hour is still reported correctly.
    mins, secs = divmod(int(time.perf_counter() - t0), 60)
    print(f"Epoch {epoch:02d} | "
          f"train_loss={train_loss:.4f} train_acc={train_acc:.4f} | "
          f"val_loss={val_loss:.4f} val_acc={val_acc:.4f} | "
          f"lr={current_lr:.2e} | "
          f"{mins:02d}m {secs:02d}s")

    # Early stopping with min epochs
    if val_acc > best_val_acc:
        # New best: checkpoint the weights and reset the no-improvement counter.
        best_val_acc = val_acc
        torch.save({'model_state_dict': model.state_dict(),
                    'config': {'classes': CLASSES, 'img_size': IMG_SIZE}}, ckpt_path)
        patience_used = 0
    else:
        patience_used += 1

    if epoch >= MIN_EPOCHS and patience_used >= PATIENCE:
        print(f"Early stopping triggered at epoch {epoch}.")
        break

print('Best val acc:', best_val_acc)
