In [1]:
# =========================
# [SEGMENT 0] Install dependencies
# =========================
!pip install --upgrade pip
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
!pip install timm==1.0.9
!pip install albumentations==1.4.18 opencv-python-headless==4.10.0.84 pydicom==2.4.4
!pip install pandas==2.2.3 scikit-learn==1.5.2 matplotlib==3.9.2 tqdm==4.66.5


Collecting pip
  Downloading pip-25.2-py3-none-any.whl.metadata (4.7 kB)
Downloading pip-25.2-py3-none-any.whl (1.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m23.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 24.1.2
    Uninstalling pip-24.1.2:
      Successfully uninstalled pip-24.1.2
Successfully installed pip-25.2
Looking in indexes: https://download.pytorch.org/whl/cu121
Collecting timm==1.0.9
  Downloading timm-1.0.9-py3-none-any.whl.metadata (42 kB)
Downloading timm-1.0.9-py3-none-any.whl (2.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.3/2.3 MB[0m [31m54.3 MB/s[0m  [33m0:00:00[0m
[?25hInstalling collected packages: timm
  Attempting uninstall: timm
    Found existing installation: timm 1.0.19
    Uninstalling timm-1.0.19:
      Successfully uninstalled timm-1.0.19
Successfully installed timm-1.0.9
Collecting albumenta

In [None]:
# =========================
# [SEGMENT 1] Imports & Config
# =========================
import os, random, warnings, json
from pathlib import Path
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
import cv2
import pydicom

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import GradScaler, autocast

from sklearn.metrics import roc_auc_score, classification_report, confusion_matrix

import timm
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm

# -------------------------
# Reproducibility
# -------------------------
def seed_all(seed: int = 42, deterministic: bool = True):
    """
    Seed every random number generator used by the notebook.

    Args:
        seed: value applied to Python's `random`, NumPy's legacy global RNG,
            and PyTorch's CPU and CUDA generators.
        deterministic: when True (default) cuDNN is restricted to deterministic
            kernels and its auto-tuner is switched off, so repeated runs pick
            the same convolution algorithms. Pass False to trade run-to-run
            reproducibility for the speed of `cudnn.benchmark = True`.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # cudnn.benchmark times several algorithms and keeps the fastest, which may
    # differ between runs; it must be off for results to be repeatable.
    torch.backends.cudnn.deterministic = deterministic
    torch.backends.cudnn.benchmark = not deterministic

seed_all(42)

# -------------------------
# Device
# -------------------------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# -------------------------
# Paths
# -------------------------
# The dataset root defaults to the Google Drive location; set the
# OSTEOVISION_DIR environment variable to point somewhere else without
# editing the notebook.
BASE_DIR     = Path(os.environ.get("OSTEOVISION_DIR", "/content/drive/MyDrive/osteovision"))
OUT_DIR      = BASE_DIR / "outputs"
OUT_DIR.mkdir(parents=True, exist_ok=True)

# Dataset folders (each is expected to contain one sub-folder per class name)
TRAIN_DIR    = BASE_DIR / "train"
VALID_DIR    = BASE_DIR / "validation"
TEST_DIR     = BASE_DIR / "test"

# -------------------------
# Model / Training parameters
# -------------------------
# Class index == position in CLASS_NAMES. NUM_CLASSES is derived from the list
# so the two can never disagree; to add a class, extend CLASS_NAMES only.
CLASS_NAMES   = ["normal", "osteoarthritis"]
NUM_CLASSES   = len(CLASS_NAMES)
IMG_SIZE      = 512
BATCH_SIZE    = 16
NUM_WORKERS   = 4

MODEL_NAME    = "tf_efficientnet_b3_ns"
FP16          = True   # enables autocast + GradScaler in the training loop

# Learning rates & schedule
INIT_LR_HEADS = 1e-3   # stage 1: classifier head only
INIT_LR_FULL  = 3e-4   # stage 2: all layers
EPOCHS_HEADS  = 5
EPOCHS_FULL   = 30
PATIENCE_LR   = 3      # epochs without improvement before ReduceLROnPlateau fires
PATIENCE_ES   = 7      # epochs without improvement before early stopping

print("Device:", device)


In [None]:
# =========================
# [SEGMENT 2] Generate DataFrames from folder structure
# =========================
def folder_to_df(base_dir: Path, class_names: list[str]) -> pd.DataFrame:
    """
    Build an image index from a `<base_dir>/<class_name>/...` folder layout.

    Each class folder is searched recursively for files ending in .png, .jpg,
    .jpeg or .dcm (case-insensitive).

    Args:
        base_dir: split root, e.g. the train folder (str or Path).
        class_names: folder names to scan; a file's label is the position of
            its class folder in this list.

    Returns:
        DataFrame with columns `file_path` (str) and `label` (int). The columns
        are present even when no image is found, and rows are ordered by class
        and then by sorted path so the index is identical across runs and
        filesystems.
    """
    image_exts = {".png", ".jpg", ".jpeg", ".dcm"}
    base_dir = Path(base_dir)
    rows = []
    for cls_idx, cls_name in enumerate(class_names):
        cls_dir = base_dir / cls_name
        if not cls_dir.exists():
            # Printed rather than warned: the notebook silences warnings globally,
            # and a missing class folder must not go unnoticed.
            print(f"[folder_to_df] class folder not found, skipping: {cls_dir}")
            continue
        # rglob yields entries in filesystem order; sort for reproducibility.
        for p in sorted(cls_dir.rglob("*")):
            if p.is_file() and p.suffix.lower() in image_exts:
                rows.append({"file_path": str(p), "label": cls_idx})
    return pd.DataFrame(rows, columns=["file_path", "label"])

# Generate DataFrames
# Index every split from its folder tree.
df_train = folder_to_df(TRAIN_DIR, CLASS_NAMES)
df_valid = folder_to_df(VALID_DIR, CLASS_NAMES)
df_test  = folder_to_df(TEST_DIR,  CLASS_NAMES)

print(f"Train images: {len(df_train)} | Validation images: {len(df_valid)} | Test images: {len(df_test)}")

# Training and validation are both consumed by the training loops below; stop
# here with a clear message instead of failing later inside a DataLoader.
for _split_name, _split_df, _split_dir in (("train", df_train, TRAIN_DIR),
                                           ("validation", df_valid, VALID_DIR)):
    if _split_df.empty:
        raise FileNotFoundError(
            f"No images found for the {_split_name} split under {_split_dir}; "
            f"expected sub-folders named {CLASS_NAMES}"
        )


In [None]:
# =========================
# [SEGMENT 3] Dataset class, augmentations, and DataLoaders
# =========================
import albumentations as A
from albumentations.pytorch import ToTensorV2

# --- 1️⃣ Read images (PNG/JPG/DICOM) ---
def _to_uint8(arr: np.ndarray) -> np.ndarray:
    """Min-max rescale an array of any numeric dtype to uint8 in [0, 255]."""
    arr = arr.astype(np.float32)
    arr = (arr - arr.min()) / (arr.max() - arr.min() + 1e-6)
    return (arr * 255).clip(0, 255).astype(np.uint8)


def read_image_any(path: str) -> np.ndarray:
    """
    Load a PNG/JPG/DICOM file as an 8-bit RGB array.

    Args:
        path: image location (str or Path); the extension selects the reader.

    Returns:
        uint8 array; single-channel inputs are replicated to three channels.

    Raises:
        FileNotFoundError: when OpenCV cannot read a non-DICOM file.

    Notes:
        * DICOM pixels are min-max scaled per image. Files whose
          PhotometricInterpretation is MONOCHROME1 store inverted intensities,
          so they are flipped to match the MONOCHROME2 convention.
        * Raster files that are not 8-bit (e.g. 16-bit PNG) are min-max scaled
          to uint8 so the downstream 0-255 normalisation stays valid.
    """
    path = str(path)
    ext = Path(path).suffix.lower()
    if ext == ".dcm":
        dcm = pydicom.dcmread(path)
        arr = _to_uint8(dcm.pixel_array)
        photometric = str(getattr(dcm, "PhotometricInterpretation", "")).upper()
        if photometric == "MONOCHROME1":
            arr = 255 - arr
        if arr.ndim == 2:
            arr = cv2.cvtColor(arr, cv2.COLOR_GRAY2RGB)
        return arr

    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if img is None:
        raise FileNotFoundError(path)
    if img.dtype != np.uint8:
        img = _to_uint8(img)
    if img.ndim == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    elif img.shape[2] == 4:
        # IMREAD_UNCHANGED keeps the alpha plane; drop it explicitly.
        img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
    else:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return img

# --- 2️⃣ Augmentations ---
# Training pipeline: resize so the longest side equals IMG_SIZE, pad to a
# square, apply light geometric/photometric augmentation, then normalise and
# convert to a CHW tensor.
train_tfms = A.Compose([
    A.LongestMaxSize(IMG_SIZE),
    A.PadIfNeeded(IMG_SIZE, IMG_SIZE, border_mode=cv2.BORDER_REFLECT_101),
    # The output size is passed as a (height, width) tuple; the positional
    # (size, width) integer form is deprecated in albumentations 1.4.
    A.RandomResizedCrop(size=(IMG_SIZE, IMG_SIZE), scale=(0.85, 1.0), ratio=(0.9, 1.1), p=0.6),
    A.HorizontalFlip(p=0.5),
    A.ShiftScaleRotate(shift_limit=0.02, scale_limit=0.1, rotate_limit=10, p=0.5),
    A.RandomBrightnessContrast(p=0.3),
    A.CoarseDropout(max_holes=1, max_height=IMG_SIZE//10, max_width=IMG_SIZE//10, p=0.25),
    A.Normalize(),
    ToTensorV2()
])
# Validation / test / inference pipeline: deterministic resize + pad only.
valid_tfms = A.Compose([
    A.LongestMaxSize(IMG_SIZE),
    A.PadIfNeeded(IMG_SIZE, IMG_SIZE, border_mode=cv2.BORDER_REFLECT_101),
    A.Normalize(),
    ToTensorV2()
])

# --- 3️⃣ Dataset class ---
class OAImageDataset(Dataset):
    """
    Map-style dataset over a DataFrame with `file_path` and `label` columns.

    Each item is an `(image, label)` pair: the image is loaded with
    `read_image_any` and, when a transform pipeline was supplied, replaced by
    the `"image"` entry of the pipeline's output; the label is a Python int.
    """

    def __init__(self, df: pd.DataFrame, transforms=None):
        self.transforms = transforms
        # Positional access via iloc below; a fresh 0..n-1 index keeps it tidy.
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        image = read_image_any(row["file_path"])
        if not self.transforms:
            return image, int(row["label"])
        return self.transforms(image=image)["image"], int(row["label"])

# --- 4️⃣ Create Dataset objects & DataLoaders ---
# One dataset per split; only the training split is augmented.
train_ds, valid_ds, test_ds = (
    OAImageDataset(df_train, transforms=train_tfms),
    OAImageDataset(df_valid, transforms=valid_tfms),
    OAImageDataset(df_test,  transforms=valid_tfms),
)

# Settings shared by all three loaders; only shuffling differs per split.
_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=True)

train_loader = DataLoader(train_ds, shuffle=True,  **_loader_kwargs)
valid_loader = DataLoader(valid_ds, shuffle=False, **_loader_kwargs)
test_loader  = DataLoader(test_ds,  shuffle=False, **_loader_kwargs)


In [None]:
# =========================
# [SEGMENT 4] Model, Loss, Optimizer
# =========================
import timm
import torch.nn as nn
import torch.optim as optim

# --- 1️⃣ Build model ---
def build_model(num_classes: int, model_name=None, pretrained: bool = True):
    """
    Create a 3-channel timm classifier.

    Args:
        num_classes: size of the classification layer.
        model_name: timm architecture name; defaults to the notebook-level
            MODEL_NAME when omitted.
        pretrained: load timm's pretrained weights (default). Pass False when
            the weights will be restored from a checkpoint anyway, which
            avoids the download.

    Returns:
        The model as returned by `timm.create_model` (not yet moved to a device).
    """
    name = MODEL_NAME if model_name is None else model_name
    return timm.create_model(name, pretrained=pretrained, in_chans=3, num_classes=num_classes)

# Instantiate the network with NUM_CLASSES outputs and move it to `device`.
model = build_model(NUM_CLASSES).to(device)

# --- 2️⃣ Loss function ---
# Cross-entropy loss; run_epoch calls it as criterion(logits, labels).
criterion = nn.CrossEntropyLoss()

# --- 3️⃣ Optimizer ---
def make_optimizer(model, lr, weight_decay: float = 1e-4):
    """
    Build an AdamW optimizer over the parameters that currently require grad.

    Frozen parameters are left out, so the optimizer must be re-created after
    layers are frozen or unfrozen.

    Args:
        model: module whose trainable parameters are optimised.
        lr: learning rate.
        weight_decay: AdamW weight decay (defaults to 1e-4).

    Returns:
        torch.optim.AdamW instance.
    """
    trainable = [p for p in model.parameters() if p.requires_grad]
    return optim.AdamW(trainable, lr=lr, weight_decay=weight_decay)

# --- Optional: print model summary (requires torchinfo) ---
# from torchinfo import summary
# summary(model, input_size=(BATCH_SIZE, 3, IMG_SIZE, IMG_SIZE))


In [None]:
# =========================
# [SEGMENT 5] Utilities: metrics, early stopping, checkpointing
# =========================
import torch

# --- 1️⃣ Accuracy metric ---
def accuracy(logits, y):
    """
    Fraction of rows whose highest-scoring column equals the target label.

    Takes the argmax over dimension 1 of `logits`, compares it element-wise
    with `y`, and returns the mean of the matches as a Python float.
    """
    hits = logits.argmax(dim=1).eq(y)
    return hits.float().mean().item()

# --- 2️⃣ EarlyStopping ---
class EarlyStopping:
    """
    Patience-based stop flag for a monitored metric.

    Call the instance once per epoch with the metric value. The first value
    becomes the reference. Afterwards a value counts as an improvement when it
    beats the reference by more than `delta` (lower for mode "min", higher for
    any other mode); an improvement replaces the reference and resets the
    counter. Once `patience` consecutive calls bring no improvement, `stop`
    is set to True.
    """

    def __init__(self, patience=PATIENCE_ES, mode="min", delta=0.0):
        self.patience, self.mode, self.delta = patience, mode, delta
        self.best = None    # best value seen so far
        self.num_bad = 0    # consecutive calls without improvement
        self.stop = False

    def __call__(self, value):
        # First observation only establishes the baseline.
        if self.best is None:
            self.best = value
            return

        if self.mode == "min":
            improved = value < self.best - self.delta
        else:
            improved = value > self.best + self.delta

        if improved:
            self.best, self.num_bad = value, 0
            return

        self.num_bad += 1
        if self.num_bad >= self.patience:
            self.stop = True

# --- 3️⃣ Checkpoint saving/loading ---
def save_ckpt(path, model, optimizer, epoch, best_metric):
    """
    Write a training checkpoint to `path` with torch.save.

    The stored dict holds the model and optimizer state dicts under "model"
    and "optimizer", plus the given `epoch` and `best_metric` values.
    """
    payload = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "epoch": epoch,
        "best_metric": best_metric
    }
    torch.save(payload, path)

def load_ckpt(path, model, optimizer=None):
    """
    Restore a checkpoint written by `save_ckpt`.

    Args:
        path: checkpoint file.
        model: module that receives the stored "model" state dict.
        optimizer: optional optimizer; restored only when given and the
            checkpoint contains an "optimizer" entry.

    Returns:
        Tuple `(epoch, best_metric)`; falls back to `(0, None)` for entries
        missing from the file.
    """
    # weights_only=True limits unpickling to tensors and plain containers,
    # which is all save_ckpt stores, so a tampered file cannot run arbitrary
    # code on load. Tensors are mapped to CPU; load_state_dict then copies
    # them onto whatever device the model already lives on.
    ckpt = torch.load(path, map_location="cpu", weights_only=True)
    model.load_state_dict(ckpt["model"])
    if optimizer is not None and "optimizer" in ckpt:
        optimizer.load_state_dict(ckpt["optimizer"])
    return ckpt.get("epoch", 0), ckpt.get("best_metric", None)


In [None]:
# =========================
# [SEGMENT 6] Train & Eval Epoch Loops (AMP + FP16)
# =========================
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm
from sklearn.metrics import roc_auc_score

def _epoch_auc(y_true, proba):
    """
    ROC AUC from class probabilities, or None when it cannot be computed.

    With two classes scikit-learn expects a 1-D score for the positive class,
    so column 1 is used; with more classes the full matrix is scored
    one-vs-rest. A failure (e.g. only one class present) is reported instead
    of being swallowed silently.
    """
    try:
        if proba.shape[1] == 2:
            return float(roc_auc_score(y_true, proba[:, 1]))
        return float(roc_auc_score(y_true, proba, multi_class="ovr"))
    except ValueError as exc:
        print(f"AUC not available: {exc}")
        return None


def run_epoch(model, loader, optimizer=None, scaler=None):
    """
    Run one pass over `loader`, training or evaluating.

    Args:
        model: network to run.
        loader: iterable yielding `(images, labels)` batches.
        optimizer: when given, the pass trains the model; when None the model
            is put in eval mode and no gradients are tracked.
        scaler: GradScaler used for the backward pass; required when training
            with FP16 enabled.

    Returns:
        dict with sample-weighted mean "loss" and "acc", and "auc" (float, or
        None when it cannot be computed or the loader was empty).

    Raises:
        ValueError: training with FP16 enabled but no scaler supplied.
    """
    is_train = optimizer is not None
    if is_train and FP16 and scaler is None:
        raise ValueError("run_epoch: a GradScaler is required when training with FP16 enabled")
    model.train(is_train)

    total_loss, total_acc, n = 0.0, 0.0, 0
    all_logits, all_y = [], []

    for imgs, labels in tqdm(loader, disable=False):
        imgs = imgs.to(device)
        labels = labels.to(device)

        # Skip autograd bookkeeping entirely during evaluation.
        with torch.set_grad_enabled(is_train), autocast(enabled=FP16):
            logits = model(imgs)
            loss = criterion(logits, labels)

        if is_train:
            optimizer.zero_grad(set_to_none=True)
            if FP16:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                optimizer.step()

        total_loss += loss.item() * imgs.size(0)
        total_acc  += accuracy(logits, labels) * imgs.size(0)
        # Under autocast the logits may be half precision; store them as
        # float32 so the CPU softmax below is supported and accurate.
        all_logits.append(logits.detach().float().cpu())
        all_y.append(labels.detach().cpu())
        n += imgs.size(0)

    avg_loss = total_loss / max(n,1)
    avg_acc  = total_acc  / max(n,1)

    # torch.cat rejects an empty list; an empty loader simply has no AUC.
    if not all_logits:
        return {"loss": avg_loss, "acc": avg_acc, "auc": None}

    proba  = torch.cat(all_logits).softmax(1).numpy()
    y_true = torch.cat(all_y).numpy()
    return {"loss": avg_loss, "acc": avg_acc, "auc": _epoch_auc(y_true, proba)}

# --- Example usage ---
# scaler = GradScaler(enabled=FP16)
# optimizer = make_optimizer(model, INIT_LR_HEADS)
# train_metrics = run_epoch(model, train_loader, optimizer, scaler)
# val_metrics   = run_epoch(model, valid_loader)


In [None]:
# =========================
# [SEGMENT 7] Stage 1: Train Heads Only
# =========================
# Freeze all except head/fc/classifier layers
# Keep a parameter trainable only when its name contains one of the
# classifier-style tags; everything else is frozen for this stage.
for name, param in model.named_parameters():
    param.requires_grad = any(tag in name for tag in ("fc", "classifier", "head"))

CKPT_HEADS = OUT_DIR / "model_heads_best.pt"
best_val = 1e9

# Fresh optimisation state for this stage.
optimizer = make_optimizer(model, INIT_LR_HEADS)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", patience=PATIENCE_LR, factor=0.5, verbose=True)
early = EarlyStopping(patience=PATIENCE_ES, mode="min")
scaler = GradScaler(enabled=FP16)

for epoch in range(1, EPOCHS_HEADS + 1):
    tr = run_epoch(model, train_loader, optimizer, scaler)
    va = run_epoch(model, valid_loader)

    # Validation loss drives the LR schedule, early stopping and checkpointing.
    val_loss = va["loss"]
    scheduler.step(val_loss)
    early(val_loss)
    if val_loss < best_val:
        best_val = val_loss
        save_ckpt(CKPT_HEADS, model, optimizer, epoch, best_val)

    print(f"[Heads][{epoch}] train {tr['loss']:.4f}/{tr['acc']:.4f} | valid {va['loss']:.4f}/{va['acc']:.4f} auc {va['auc']}")
    if early.stop:
        print("Early stopping (heads)")
        break

# Continue from the weights with the lowest validation loss of this stage.
_ = load_ckpt(CKPT_HEADS, model)


In [None]:
# =========================
# [SEGMENT 8] Stage 2: Full Fine-Tune
# =========================
# Unfreeze all layers
# Make every parameter trainable again.
for param in model.parameters():
    param.requires_grad_(True)

CKPT_FULL = OUT_DIR / "model_full_best.pt"
best_val = 1e9

# New optimizer (it must see the newly unfrozen parameters), scheduler,
# early-stopping tracker and gradient scaler for this stage.
optimizer = make_optimizer(model, INIT_LR_FULL)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", patience=PATIENCE_LR, factor=0.5, verbose=True)
early = EarlyStopping(patience=PATIENCE_ES, mode="min")
scaler = GradScaler(enabled=FP16)

for epoch in range(1, EPOCHS_FULL + 1):
    tr = run_epoch(model, train_loader, optimizer, scaler)
    va = run_epoch(model, valid_loader)

    # Validation loss drives the LR schedule, early stopping and checkpointing.
    val_loss = va["loss"]
    scheduler.step(val_loss)
    early(val_loss)
    if val_loss < best_val:
        best_val = val_loss
        save_ckpt(CKPT_FULL, model, optimizer, epoch, best_val)

    print(f"[Full][{epoch}] train {tr['loss']:.4f}/{tr['acc']:.4f} | valid {va['loss']:.4f}/{va['acc']:.4f} auc {va['auc']}")
    if early.stop:
        print("Early stopping (full)")
        break

# Continue from the weights with the lowest validation loss of this stage.
_ = load_ckpt(CKPT_FULL, model)


In [None]:
# =========================
# [SEGMENT 9] Optional Cosine Cooldown
# =========================
EPOCHS_COOLDOWN = 5
optimizer = make_optimizer(model, INIT_LR_FULL * 0.3)
cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS_COOLDOWN)
scaler = GradScaler(enabled=FP16)

# The cooldown is optional polish and can make the model worse. Record the
# validation loss of the incoming weights as the baseline, checkpoint only
# when an epoch beats it, and restore the best weights at the end so the
# exported model is never worse (by validation loss) than before this stage.
CKPT_COOLDOWN = OUT_DIR / "model_cooldown_best.pt"
best_cooldown_val = run_epoch(model, valid_loader)["loss"]
save_ckpt(CKPT_COOLDOWN, model, optimizer, 0, best_cooldown_val)

for epoch in range(1, EPOCHS_COOLDOWN+1):
    tr = run_epoch(model, train_loader, optimizer, scaler)
    va = run_epoch(model, valid_loader)
    cosine.step()
    if va["loss"] < best_cooldown_val:
        best_cooldown_val = va["loss"]
        save_ckpt(CKPT_COOLDOWN, model, optimizer, epoch, best_cooldown_val)
    print(f"[Cooldown][{epoch}] train {tr['loss']:.4f}/{tr['acc']:.4f} | valid {va['loss']:.4f}/{va['acc']:.4f} auc {va['auc']}")

# Keep whichever weights scored best: the pre-cooldown ones or a cooldown epoch.
_ = load_ckpt(CKPT_COOLDOWN, model)


In [None]:
# =========================
# [SEGMENT 10] Evaluation on Validation Set
# =========================
model.eval()
ys, ps = [], []

# Collect softmax probabilities and true labels over the validation split.
with torch.no_grad():
    for imgs, labels in DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS):
        imgs = imgs.to(device)
        probs = model(imgs).softmax(1).cpu().numpy()
        ps.append(probs)
        ys.append(labels.numpy())

ps = np.concatenate(ps)
ys = np.concatenate(ys)
preds = ps.argmax(1)

# Classification report
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score

# Pass the full label set explicitly so the report and matrix keep one row per
# class even when a class is missing from the labels or the predictions.
label_ids = list(range(len(CLASS_NAMES)))

print("Classification Report:")
print(classification_report(ys, preds, labels=label_ids, target_names=CLASS_NAMES, zero_division=0))

print("Confusion Matrix:")
print(confusion_matrix(ys, preds, labels=label_ids))

# Macro AUC
# For two classes roc_auc_score needs the 1-D positive-class probability; a
# two-column score matrix raises. With more classes, score one-vs-rest.
try:
    if ps.shape[1] == 2:
        auc = roc_auc_score(ys, ps[:, 1])
    else:
        auc = roc_auc_score(ys, ps, multi_class="ovr")
    print("Macro AUC:", auc)
except ValueError as exc:
    print("AUC calculation failed (check labels/probabilities).", exc)


In [None]:
# =========================
# [SEGMENT 11] Export Weights + Metadata
# =========================
EXPORT_DIR = OUT_DIR / "export"
EXPORT_DIR.mkdir(parents=True, exist_ok=True)

# Weights only (state dict), without optimizer state.
torch.save(model.state_dict(), EXPORT_DIR / "model_state_dict.pt")

# Settings an inference script needs in order to rebuild the model and
# preprocess inputs the same way as during training.
meta = {
    "model_name": MODEL_NAME,
    "img_size": IMG_SIZE,
    "num_classes": NUM_CLASSES,
    "class_names": CLASS_NAMES,
    "normalization": "albumentations.Normalize() default (ImageNet)",
}
(EXPORT_DIR / "meta.json").write_text(json.dumps(meta, indent=4))

print("Exported model and metadata to:", EXPORT_DIR)


In [None]:
# =========================
# [SEGMENT 12] Simple Inference Helper
# =========================
def infer_one(path: str):
    """
    Classify one image file with the notebook-level `model`.

    The image is read with `read_image_any`, passed through `valid_tfms`,
    batched to size one and moved to `device`; the model output is soft-maxed
    over dimension 1.

    Returns:
        dict mapping each entry of CLASS_NAMES to its probability as a float.
    """
    model.eval()
    batch = valid_tfms(image=read_image_any(path))["image"].unsqueeze(0).to(device)
    with torch.no_grad():
        probs = model(batch).softmax(1).cpu().numpy()[0]
    return dict((CLASS_NAMES[i], float(probs[i])) for i in range(NUM_CLASSES))

# --- Example usage ---
# sample_path = df_valid.iloc[0]["file_path"]
# print(infer_one(sample_path))
