In [None]:
# Mount Google Drive at /content/drive so the dataset archives stored under
# MyDrive can be copied into the Colab runtime by the following cells.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Copy the leave-one-out archive from Drive to the runtime's local disk, then
# extract it into /content. unzip's stdout is discarded to keep the cell output short.
!cp /content/drive/MyDrive/LOOCV_2models_VIT/loo_temp_Anthisnes_Chateau_de_Xhos_Camera_1_HIT.zip /content
!unzip /content/loo_temp_Anthisnes_Chateau_de_Xhos_Camera_1_HIT.zip -d /content > /dev/null

In [None]:
import os
import random
import shutil

# Root of the extracted dataset and its train / val splits
base_path = "/content/loo_temp_Anthisnes_Chateau_de_Xhos_Camera_1_HIT"
train_dir, val_dir = (os.path.join(base_path, split) for split in ("train", "val"))

# Per-class folders inside each split
background_dir, bats_dir = (
    os.path.join(train_dir, class_name) for class_name in ("background", "bats")
)
val_background_dir, val_bats_dir = (
    os.path.join(val_dir, class_name) for class_name in ("background", "bats")
)

# Candidate gîtes for the validation hold-out. The order matters: together
# with the seed it determines which entry random.choice() returns.
chosen_gites = [
    'Pont_de_Bousval_Photos_2022_PHOTO',
    'Modave_Camera_3_toiture_PHOTO',
    'Bornival_PHOTO_2023CAM04',
    'Pont_de_Bousval_Photos_2023_PHOTO_WK6HDBOUSVAL',
    'Pont_de_Bousval_Photos_2023_PHOTO_2022CAM12',
    'Pont_de_Bousval_Photos_2023_PHOTO_2023CAM06',
    'Bornival_PHOTO_2023CAM03',
    'Chaumont_Gistoux_Camera_2',
    'Chaumont_Gistoux_Camera_1',
    'Pont_de_Bousval_Photos_2023_PHOTO_2023CAM05',
    # 'Anthisnes_Chateau_de_Xhos_Camera_1_HIT',  # excluded from the candidates
    'Jenneret_Camera_1_PHOTO',
    'Modave_Camera_plancher_PHOTO'
]

# Seed right before drawing so the selection is the same on every run
RANDOM_SEED = 42
random.seed(RANDOM_SEED)
held_out_gite = random.choice(chosen_gites)
print(f"📦 Holding out gîte for validation: {held_out_gite}")
print(f"🧪 Reproducible with seed: {RANDOM_SEED}")

# Make sure both validation class folders exist before files are moved in
for val_class_dir in (val_background_dir, val_bats_dir):
    os.makedirs(val_class_dir, exist_ok=True)

# Function to move matching files
def move_files_by_gite(source_dir, target_dir, gite_name):
    """Move every entry of ``source_dir`` whose name contains ``gite_name``.

    Args:
        source_dir: Directory whose entries are scanned.
        target_dir: Existing directory that receives the matching entries.
        gite_name: Substring looked for anywhere in the entry name.

    Returns:
        The number of entries moved.
    """
    # Take a sorted snapshot first so the move order is deterministic and the
    # directory is not modified while it is being listed.
    matching = [entry for entry in sorted(os.listdir(source_dir)) if gite_name in entry]
    for entry in matching:
        shutil.move(os.path.join(source_dir, entry), os.path.join(target_dir, entry))
    return len(matching)

# Relocate the held-out gîte's images from train/ to val/, one class at a time
# (background first, then bats).
bkg_moved, bats_moved = (
    move_files_by_gite(class_src, class_dst, held_out_gite)
    for class_src, class_dst in (
        (background_dir, val_background_dir),
        (bats_dir, val_bats_dir),
    )
)

print(f"✅ Moved {bkg_moved} background images and {bats_moved} bat images to validation set.")


📦 Holding out gîte for validation: Jenneret_Camera_1_PHOTO
🧪 Reproducible with seed: 42
✅ Moved 6889 background images and 1330 bat images to validation set.


# ViT Bat Classifier (Google Colab Setup)

This script trains and evaluates a binary image classifier to distinguish bats from background images using a pretrained Vision Transformer (`ViT`).

## Features
- Mixed-precision training using `torch.amp.autocast` for performance
- `torch.compile()` for model acceleration (when supported)
- Balanced sampling and MixUp/CutMix augmentations for robust generalization
- Custom Focal Loss implementation to handle class imbalance
- Automatic threshold tuning with ROC/PR curve analysis
- Evaluation metrics and visualization (AUC, F1, Precision, Recall)

## Directory Structure
```
/content/loo_temp_<dataset_name>/
├── train/
│ ├── background/
│ └── bats/
├── val/
│ ├── background/
│ └── bats/
└── test/
    ├── background/
    └── bats/
```


Each folder should contain image files with recognizable gîte prefixes (e.g. `Jenneret_Camera_1_PHOTO_img123.png`).

---

## Key Modules

### Configuration
- Defines file paths, batch size, learning rate, number of epochs, and seeds for reproducibility.

### Transforms and Augmentation
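- Training images get a random resized crop, horizontal flip, small rotation, brightness/contrast and gamma jitter, Gaussian noise, blur, and CLAHE, followed by normalization; validation and test images are only resized and normalized.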
- MixUp and CutMix applied during training.

### Data Handling
- Loads data using `ImageFolder` from `torchvision.datasets`.
- Uses `WeightedRandomSampler` to address class imbalance in the training set.

### Model Architecture
- Loads `google/vit-base-patch16-224` from HuggingFace Transformers.
- Replaces the classifier head with `Dropout(0.5) → Linear(embed_dim, 2)`.
- Freezes the first two-thirds of encoder layers for transfer learning.

### Loss Function
- Implements Focal Loss to reduce the contribution of easy examples and focus on difficult ones.

### Training Loop
- Incorporates mixed-precision (`autocast`) and gradient scaling (`GradScaler`).
- Uses early stopping based on macro F1-score improvements.
- Learning rate scheduling via cosine annealing.

### Threshold Optimization
- Finds the threshold where precision ≈ recall from PR curve.
- Saves this threshold and plots ROC for visual reference.

### Inference
- Loads the best model and optimal threshold.
- Runs classification on new or test images and prints predictions.

### Final Evaluation
- Reports macro-averaged precision, recall, F1-score, and AUC on the test set.
- Plots ROC curve for test predictions.

---

## Hardware Requirements
- Optimized for Google Colab with NVIDIA T4 GPU.
- Falls back to CPU if GPU is not available.

---

## Outputs
- `best_model/`: Directory containing the saved model from best validation F1.
- `val_blob.pt`: Tensor file containing validation logits and labels.
- `best_thr.txt`: Optimal decision threshold for inference.
- Inline ROC plots (validation and test) via `matplotlib`; the precision–recall curve is computed for threshold selection but not plotted.

---

## Notes
- Ensure your dataset is pre-organized before starting training.
- Adjust `NUM_EPOCHS`, `BATCH_SIZE`, or learning rate depending on your dataset size and hardware.
- Script is modular and can be reused for other binary classification problems with minimal changes.


In [None]:
# ==========================  bat_classifier.py  ==========================
"""
ViT-based bat/background classifier
- Mixed-precision (torch.amp.autocast)
- torch.compile(), fast DataLoader, in-file FocalLoss
- Optimized for Colab (NVIDIA T4)
"""

import os, math, random, json, numpy as np, torch
from PIL import Image
from collections import Counter
from tqdm.auto import tqdm
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision.transforms import InterpolationMode
from timm.data.mixup import Mixup
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from transformers import ViTForImageClassification
from sklearn.metrics import (
    precision_score, recall_score, f1_score,
    precision_recall_curve, roc_curve, auc
)
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F

import albumentations as A
from albumentations.pytorch import ToTensorV2

# Reproducibility: seed every RNG the pipeline draws from
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True

# --- Config --- #
DATA_DIR      = "/content/loo_temp_Anthisnes_Chateau_de_Xhos_Camera_1_HIT"
TRAIN_DIR     = os.path.join(DATA_DIR, "train")
VAL_DIR       = os.path.join(DATA_DIR, "val")
TEST_DIR      = os.path.join(DATA_DIR, "test")
INFER_DIR     = TEST_DIR  # folder the per-image inference step scans
BATCH_SIZE    = 32
NUM_EPOCHS    = 10
LEARNING_RATE = 3e-5
MODEL_NAME    = "google/vit-base-patch16-224"
OUTPUT_DIR    = "/content/vit_finetuned_bats_Anthisnes_Chateau_de_Xhos_Camera_1_HIT"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# --- Performance settings --- #
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): cudnn.benchmark autotunes kernel choice per input shape, which
# can make runs differ slightly even with cudnn.deterministic set above; turn
# it off if bit-for-bit reproducibility matters more than speed.
torch.backends.cudnn.benchmark = True
torch.set_float32_matmul_precision("high")

# Mixed precision is only meaningful on CUDA. Passing an explicit `enabled`
# flag keeps the CPU fallback quiet instead of warning on every autocast call.
USE_AMP = DEVICE.type == "cuda"
try:
    from torch.amp import autocast, GradScaler
    AMP_KW = dict(device_type="cuda", enabled=USE_AMP)
except ImportError:
    # Older PyTorch: the CUDA-specific API takes no device_type argument
    from torch.cuda.amp import autocast, GradScaler
    AMP_KW = dict(enabled=USE_AMP)
scaler = GradScaler(enabled=USE_AMP)

IMAGE_SIZE = 224

# Channel statistics shared by the training and evaluation pipelines
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Training pipeline: geometric + photometric augmentation, then normalize
train_albu = A.Compose([
    A.RandomResizedCrop(
        size=(IMAGE_SIZE, IMAGE_SIZE), scale=(0.5, 1.0), ratio=(0.9, 1.1), p=1.0
    ),
    A.HorizontalFlip(p=0.5),
    A.Rotate(limit=15, p=0.3),
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.3),
    A.RandomGamma(gamma_limit=(80, 120), p=0.3),
    A.GaussNoise(std_range=(0.04, 0.20), p=0.2),
    # At most one blur flavour per image
    A.OneOf(
        [
            A.MotionBlur(blur_limit=5, p=1.0),
            A.MedianBlur(blur_limit=5, p=1.0),
            A.Blur(blur_limit=5, p=1.0),
        ],
        p=0.2,
    ),
    A.CLAHE(p=0.1),
    A.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ToTensorV2(),
])

# Evaluation pipeline: deterministic resize + the same normalization
val_albu = A.Compose([
    A.Resize(height=IMAGE_SIZE, width=IMAGE_SIZE),
    A.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ToTensorV2(),
])

# MixUp / CutMix applied to whole batches in the training loop
_mixup_settings = dict(
    num_classes=2,
    mixup_alpha=0.2,
    cutmix_alpha=1.0,
    prob=0.5,
    switch_prob=0.5,
    mode="batch",
    label_smoothing=0.0,
)
mixup_fn = Mixup(**_mixup_settings)

from torchvision.datasets import ImageFolder

class AlbumentationsFolder(ImageFolder):
    """``ImageFolder`` whose samples go through an albumentations pipeline.

    The torchvision ``transform`` hook is left unset; images are opened here,
    converted to an RGB numpy array and passed to the albumentations callable.
    """

    def __init__(self, root, transform):
        """Index ``root`` via ``ImageFolder`` and keep ``transform`` for later use."""
        super().__init__(root, transform=None)
        self.alb_transform = transform

    def __getitem__(self, index):
        """Return ``(augmented_image, label)`` for the sample at ``index``."""
        image_path, class_idx = self.samples[index]
        with Image.open(image_path) as pil_image:
            rgb_pixels = np.array(pil_image.convert("RGB"))
        return self.alb_transform(image=rgb_pixels)["image"], class_idx

# --- Dataset & sampler --- #
train_set = AlbumentationsFolder(TRAIN_DIR, transform=train_albu)
val_set   = AlbumentationsFolder(VAL_DIR,   transform=val_albu)

# Inverse-frequency class weights, normalized to sum to 1, then looked up per
# sample so the sampler draws both classes about equally often.
targets = train_set.targets
class_cnt = Counter(targets)
cnt = torch.tensor(
    [class_cnt[class_idx] for class_idx in range(len(class_cnt))], dtype=torch.float
)
inverse_freq = cnt.reciprocal()
class_weights = inverse_freq / inverse_freq.sum()
sample_weights = class_weights[targets]

# One "epoch" draws twice as many samples (with replacement) as the train set holds
sampler = WeightedRandomSampler(
    sample_weights, num_samples=2 * len(train_set), replacement=True
)

num_workers = os.cpu_count() or 2
_loader_common = dict(num_workers=num_workers, pin_memory=True, persistent_workers=True)
train_loader = DataLoader(
    train_set, batch_size=BATCH_SIZE, sampler=sampler, prefetch_factor=4, **_loader_common
)
val_loader = DataLoader(
    val_set, batch_size=BATCH_SIZE * 2, shuffle=False, **_loader_common
)
# --- Focal Loss --- #
class FocalLoss(nn.Module):
    """Focal loss on multi-class logits with optional per-class ``alpha`` weights.

    Per sample the loss is ``alpha_t * (1 - p_t) ** gamma * CE``, where ``CE``
    is that sample's cross-entropy and ``p_t = exp(-CE)``.

    Args:
        alpha: Scalar weight applied to every sample, or a per-class sequence /
            1-D tensor indexed by class id.
        gamma: Focusing exponent; larger values down-weight easy samples more.
        reduction: ``"mean"``, ``"sum"`` or ``"none"`` (per-sample losses).

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
    """

    _REDUCTIONS = ("mean", "sum", "none")

    def __init__(self, alpha=0.25, gamma=2.0, reduction="mean"):
        super().__init__()
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}"
            )
        if isinstance(alpha, (list, tuple)):
            alpha = torch.tensor(alpha, dtype=torch.float32)
        self.alpha, self.gamma, self.reduction = alpha, gamma, reduction

    def _sample_alpha(self, logits: torch.Tensor, targets: torch.Tensor):
        """Return the alpha factor for each sample (or a scalar when alpha is scalar)."""
        if not isinstance(self.alpha, torch.Tensor):
            return self.alpha
        alpha = self.alpha.to(logits.device)
        if alpha.dim() == 0:
            return alpha
        if targets.is_floating_point():
            # Soft targets of shape (batch, num_classes), e.g. after MixUp/CutMix:
            # blend the class weights with the target probabilities rather than
            # collapsing to alpha.mean(), which would discard the class weighting.
            return (targets.to(alpha.dtype) * alpha).sum(dim=-1)
        # Hard integer labels: plain per-class lookup
        return alpha[targets]

    def forward(self, logits: torch.Tensor, targets: torch.Tensor):
        """Compute the focal loss.

        Args:
            logits: Raw class scores, shape ``(batch, num_classes)``.
            targets: Integer class ids of shape ``(batch,)`` or soft targets of
                shape ``(batch, num_classes)``.

        Returns:
            A scalar for ``"mean"`` / ``"sum"``, otherwise the per-sample losses.
        """
        ce = F.cross_entropy(logits, targets, reduction="none")
        pt = torch.exp(-ce)
        loss = self._sample_alpha(logits, targets) * (1 - pt) ** self.gamma * ce

        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        return loss

# --- Model setup --- #
base_model = ViTForImageClassification.from_pretrained(
    MODEL_NAME,
    num_labels=2,
    id2label={0: "background", 1: "bats"},
    label2id={"background": 0, "bats": 1},
    ignore_mismatched_sizes=True
).to(DEVICE)


class DropoutLinear(nn.Linear):
    """Linear head that applies dropout to its input while training.

    Unlike ``nn.Sequential(Dropout, Linear)``, the parameters keep the names
    ``weight`` / ``bias``. A checkpoint written with ``save_pretrained`` can
    therefore be reloaded with ``ViTForImageClassification.from_pretrained``
    without the head being silently re-initialized (a Sequential head would
    store ``classifier.1.weight``, which the stock architecture does not load).
    """

    def __init__(self, in_features, out_features, p=0.5):
        super().__init__(in_features, out_features)
        self.p = p

    def forward(self, x):
        # Dropout is the identity in eval mode, so a reloaded plain Linear head
        # computes exactly the same outputs at inference time.
        return F.linear(F.dropout(x, p=self.p, training=self.training), self.weight, self.bias)


embed_dim = base_model.classifier.in_features
base_model.classifier = DropoutLinear(embed_dim, 2, p=0.5).to(DEVICE)

# Freeze the first two-thirds of the encoder blocks; only the remaining blocks
# and the new head are fine-tuned.
freeze_upto = int(len(base_model.vit.encoder.layer) * 2 / 3)
for i, blk in enumerate(base_model.vit.encoder.layer):
    if i < freeze_upto:
        for p in blk.parameters():
            p.requires_grad = False

# torch.compile is optional: fall back to the eager model where unsupported
try:
    model = torch.compile(base_model)
except (AttributeError, RuntimeError):
    model = base_model

# --- Training setup --- #
alpha = torch.tensor([0.05, 0.95])  # per-class focal weights: [background, bats]
loss_fn = FocalLoss(alpha=alpha, gamma=2.0).to(DEVICE)
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
# The scheduler is stepped once per batch, so T_max is the total number of batches
scheduler = CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS * len(train_loader))

# --- Training loop --- #
# best_f1 starts below the smallest possible macro-F1 (0.0) so the first epoch
# always writes best_model/ and val_blob.pt, which later steps load from disk.
patience, best_f1, epochs_no_imp = 3, -1.0, 0
for epoch in range(1, NUM_EPOCHS + 1):
    model.train()
    train_loss = 0.0
    for px, y in tqdm(train_loader, desc=f"Train {epoch}/{NUM_EPOCHS}", unit="batch"):
        px, y = px.to(DEVICE, non_blocking=True), y.to(DEVICE, non_blocking=True)
        # MixUp/CutMix blends images within the batch and returns mixed targets
        px, y = mixup_fn(px, y)
        with autocast(**AMP_KW):
            logits = model(pixel_values=px).logits
            loss = loss_fn(logits, y)
        optimizer.zero_grad(set_to_none=True)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()  # cosine schedule advances once per batch
        train_loss += loss.item()
    avg_train_loss = train_loss / len(train_loader)

    model.eval()
    val_logits, val_labels = [], []
    val_loss = 0.0
    with torch.no_grad():
        for px, y in tqdm(val_loader, desc="Valid", unit="batch", leave=False):
            px, y = px.to(DEVICE), y.to(DEVICE)
            with autocast(**AMP_KW):
                out = model(pixel_values=px).logits
                loss = loss_fn(out, y)
            val_loss += loss.item()
            # Store logits as float32: autocast can emit half precision, which
            # is too coarse for the softmax / threshold search done afterwards.
            val_logits.append(out.float().cpu())
            val_labels.append(y.cpu())

    val_logits = torch.cat(val_logits)
    val_labels = torch.cat(val_labels)
    avg_val_loss = val_loss / len(val_loader)
    preds = val_logits.argmax(dim=1)

    # Metrics
    f1_macro = f1_score(val_labels, preds, average="macro")
    precision_macro = precision_score(val_labels, preds, average="macro")
    recall_macro = recall_score(val_labels, preds, average="macro")

    print(f"Epoch {epoch:02d} │ TrainL {avg_train_loss:.4f} │ ValL {avg_val_loss:.4f} │ F1 {f1_macro:.4f} │ P {precision_macro:.4f} │ R {recall_macro:.4f}")

    # Keep the checkpoint (and its validation logits) with the best macro-F1;
    # stop after `patience` consecutive epochs without improvement.
    if f1_macro > best_f1:
        best_f1, epochs_no_imp = f1_macro, 0
        model.save_pretrained(os.path.join(OUTPUT_DIR, "best_model"))
        torch.save({"logits": val_logits, "labels": val_labels}, os.path.join(OUTPUT_DIR, "val_blob.pt"))
    else:
        epochs_no_imp += 1
        if epochs_no_imp >= patience:
            print("Early stopping triggered.")
            break

# --- Threshold tuning --- #
# Reload the validation logits/labels saved alongside the best checkpoint
blob = torch.load(os.path.join(OUTPUT_DIR, "val_blob.pt"))
labels = blob["labels"].numpy()
# Softmax in float32: the stored logits may be half precision
probs  = blob["logits"].float().softmax(1)[:, 1].cpu().numpy()

prec, rec, thr = precision_recall_curve(labels, probs)
# prec/rec carry one trailing element (precision=1, recall=0) with no matching
# threshold; drop it so the chosen index is always valid for `thr`.
best_idx = int(np.argmin(np.abs(prec[:-1] - rec[:-1])))
best_thr = float(thr[best_idx])  # threshold where precision ≈ recall
with open(os.path.join(OUTPUT_DIR, "best_thr.txt"), "w") as f:
    f.write(str(best_thr))
print(f"Optimal threshold saved: {best_thr:.4f}")

# --- ROC curve --- #
fpr, tpr, roc_thr = roc_curve(labels, probs)
roc_auc = auc(fpr, tpr)
# Mark the ROC point whose threshold is closest to the selected one
idx_point = np.argmin(np.abs(roc_thr - best_thr))
plt.figure(figsize=(6,6))
plt.plot(fpr, tpr, label=f"ROC (AUC = {roc_auc:.3f})")
plt.scatter([fpr[idx_point]], [tpr[idx_point]], c='red', s=50, label=f"thr={best_thr:.3f}")
plt.plot([0,1], [0,1], 'k--', alpha=0.4)
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("ROC Curve")
plt.legend(loc="lower right")
plt.grid(True)
plt.tight_layout()
plt.show()

# --- Inference --- #
print("\nRunning inference on crops …")
model, load_info = ViTForImageClassification.from_pretrained(
    os.path.join(OUTPUT_DIR, "best_model"), output_loading_info=True
)
# from_pretrained rebuilds the stock architecture; any weight it could not
# find in the checkpoint is randomly initialized. Surface that loudly instead
# of producing meaningless predictions.
if load_info["missing_keys"]:
    print(f"WARNING: weights missing from checkpoint and randomly initialized: {load_info['missing_keys']}")
model = model.to(DEVICE).eval()
with open(os.path.join(OUTPUT_DIR, "best_thr.txt")) as f:
    best_thr = float(f.read())

IMAGE_EXTS = (".png", ".jpg", ".jpeg", ".bmp", ".tiff")
# Walk recursively: INFER_DIR may hold images directly or inside per-class
# subfolders (as TEST_DIR does), which a flat os.listdir would skip entirely.
for dirpath, dirnames, filenames in os.walk(INFER_DIR):
    dirnames.sort()  # deterministic traversal order
    for fn in sorted(filenames):
        if not fn.lower().endswith(IMAGE_EXTS):
            continue
        img_path = os.path.join(dirpath, fn)
        with Image.open(img_path) as im:
            arr = np.array(im.convert("RGB"))
        with torch.no_grad(), autocast(**AMP_KW):
            tensor = val_albu(image=arr)["image"].unsqueeze(0).to(DEVICE)
            logit = model(pixel_values=tensor).logits
        prob = logit.float().softmax(1)[0, 1].item()
        pred = "bats" if prob >= best_thr else "background"
        # Path relative to INFER_DIR; equals the bare file name for top-level files
        print(f"{os.path.relpath(img_path, INFER_DIR)}: {pred} (Pbat={prob:.3f})")

# --- Final test evaluation --- #
print("\nEvaluating on independent TEST_DIR ...")
test_set = AlbumentationsFolder(TEST_DIR, transform=val_albu)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE*2, shuffle=False, num_workers=os.cpu_count() or 2, pin_memory=True, persistent_workers=True)

model.eval()
test_logits, test_labels = [], []
with torch.no_grad():
    for px, y in tqdm(test_loader, desc="Test", unit="batch", leave=False):
        px, y = px.to(DEVICE), y.to(DEVICE)
        with autocast(**AMP_KW):
            out = model(pixel_values=px).logits
        test_logits.append(out.cpu())
        test_labels.append(y.cpu())

# Cast to float32 before softmax: autocast can return half-precision logits,
# which would quantize the probabilities used for thresholding and AUC.
test_logits = torch.cat(test_logits).float()
test_labels = torch.cat(test_labels)
test_probs = test_logits.softmax(1)[:, 1]
# Apply the decision threshold tuned on the validation set
test_preds = (test_probs >= best_thr).int()

precision = precision_score(test_labels, test_preds, average="macro")
recall    = recall_score(test_labels, test_preds, average="macro")
f1        = f1_score(test_labels, test_preds, average="macro")
fpr, tpr, _ = roc_curve(test_labels, test_probs)
roc_auc = auc(fpr, tpr)

print(f"\n TEST SET PERFORMANCE:\nPrecision: {precision:.4f} │ Recall: {recall:.4f} │ F1_macro: {f1:.4f} │ AUC: {roc_auc:.4f}")

plt.figure(figsize=(6,6))
plt.plot(fpr, tpr, label=f"ROC AUC = {roc_auc:.3f}")
plt.plot([0,1], [0,1], 'k--', alpha=0.3)
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("ROC Curve - TEST set")
plt.legend(loc="lower right")
plt.grid(True)
plt.tight_layout()
plt.show()


In [None]:
# Copy the full training archive from Drive to the runtime's local disk, then
# extract it into /content. unzip's stdout is discarded to keep the cell output short.
!cp /content/drive/MyDrive/Final_train.zip /content
!unzip /content/Final_train.zip -d /content > /dev/null

In [None]:
# ===========================  bat_classifier.py (ViT-base)  ===========================
"""
ViT-base-patch16-224 bat / background classifier
• Mixed-precision via torch.amp.autocast
• torch.compile(), fast DataLoader, in-file Focal-Loss
• Trained on the entire Final_train dataset
"""

import os
import random
import json
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.datasets import ImageFolder

from PIL import Image
from collections import Counter
from tqdm.auto import tqdm

from torch.utils.data import DataLoader, WeightedRandomSampler
from transformers import ViTForImageClassification
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

from sklearn.metrics import (
    precision_score, recall_score, f1_score,
    precision_recall_curve, roc_curve, auc
)

import matplotlib.pyplot as plt

from timm.data.mixup import Mixup

import albumentations as A
from albumentations.pytorch import ToTensorV2

# ── Reproducibility ────────────────────────────────────────────────────────────────
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True
# NOTE(review): cudnn.benchmark autotunes kernel choice per input shape, which
# can make runs differ slightly even with cudnn.deterministic set; turn it off
# if bit-for-bit reproducibility matters more than speed.
torch.backends.cudnn.benchmark     = True
torch.set_float32_matmul_precision("high")

# ── Config ─────────────────────────────────────────────────────────────────────────
DATA_DIR      = "/content/Final_train"
TRAIN_DIR     = DATA_DIR  # the whole dataset is used for training
OUTPUT_DIR    = "outputs"
os.makedirs(OUTPUT_DIR, exist_ok=True)

BATCH_SIZE    = 32
NUM_EPOCHS    = 10
LEARNING_RATE = 3e-5
IMAGE_SIZE    = 224

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ── AMP Setup ──────────────────────────────────────────────────────────────────────
# Mixed precision is only meaningful on CUDA. Passing an explicit `enabled`
# flag keeps the CPU fallback quiet instead of warning on every autocast call.
USE_AMP = DEVICE.type == "cuda"
try:
    from torch.amp import autocast, GradScaler
    AMP_KW = dict(device_type="cuda", enabled=USE_AMP)
except ImportError:
    # Older PyTorch: the CUDA-specific API takes no device_type argument
    from torch.cuda.amp import autocast, GradScaler
    AMP_KW = dict(enabled=USE_AMP)
scaler = GradScaler(enabled=USE_AMP)

# ── MixUp ───────────────────────────────────────────────────────────────────────────
# MixUp / CutMix applied to every batch (prob=1.0) in the training loop
_mixup_settings = dict(
    num_classes=2,
    mixup_alpha=0.2,
    cutmix_alpha=1.0,
    prob=1.0,
    switch_prob=0.5,
    mode="batch",
    label_smoothing=0.0,
)
mixup_fn = Mixup(**_mixup_settings)

# ── Albumentations pipelines ───────────────────────────────────────────────────────
# Exactly one of these blur flavours is used whenever the OneOf group fires
_blur_choices = [
    A.MotionBlur(blur_limit=5, p=1.0),
    A.MedianBlur(blur_limit=5, p=1.0),
    A.Blur(blur_limit=5, p=1.0),
]

# Geometric augmentation, then photometric augmentation, then normalization
_train_steps = [
    A.RandomResizedCrop(size=(IMAGE_SIZE, IMAGE_SIZE), scale=(0.5, 1.0), ratio=(0.9, 1.1), p=1.0),
    A.HorizontalFlip(p=0.5),
    A.Rotate(limit=15, p=0.3),
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.3),
    A.RandomGamma(gamma_limit=(80, 120), p=0.3),
    A.GaussNoise(std_range=(0.04, 0.20), p=0.2),
    A.OneOf(_blur_choices, p=0.2),
    A.CLAHE(p=0.1),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2(),
]
train_albu = A.Compose(_train_steps)

# ── Dataset wrapper ────────────────────────────────────────────────────────────────
class AlbumentationsDataset(torch.utils.data.Dataset):
    """Dataset wrapper that feeds ``ImageFolder`` samples through albumentations.

    Attributes:
        ds: The underlying ``ImageFolder`` (created without a torchvision transform).
        albu: Albumentations callable invoked as ``albu(image=ndarray)``.
    """

    def __init__(self, folder, albu_transform):
        """Index ``folder`` with ``ImageFolder`` and remember the transform."""
        self.albu = albu_transform
        self.ds   = ImageFolder(folder, transform=None)

    def __len__(self):
        """Number of samples in the wrapped dataset."""
        return len(self.ds)

    def __getitem__(self, idx):
        """Return ``(transformed_image, label)`` for sample ``idx``."""
        raw_image, label = self.ds[idx]
        pixels = np.array(raw_image)
        return self.albu(image=pixels)["image"], label

# ── DataLoader ────────────────────────────────────────────────────────────────────
train_set = AlbumentationsDataset(TRAIN_DIR, train_albu)

# Class-balanced sampling: inverse-frequency class weights, normalized to sum
# to 1, then looked up per sample.
targets       = train_set.ds.targets
class_cnt     = Counter(targets)
cnt           = torch.tensor(
    [class_cnt[class_idx] for class_idx in range(len(class_cnt))], dtype=torch.float
)
inverse_freq  = cnt.reciprocal()
class_weights = inverse_freq / inverse_freq.sum()
sample_weights = class_weights[targets]

# Draw (with replacement) as many samples per epoch as the dataset holds
sampler = WeightedRandomSampler(sample_weights, num_samples=len(train_set), replacement=True)

num_workers = os.cpu_count() or 2
_train_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    sampler=sampler,
    num_workers=num_workers,
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=4,
    drop_last=True,  # discard the final partial batch
)
train_loader = DataLoader(train_set, **_train_loader_kwargs)

# ── Focal Loss ────────────────────────────────────────────────────────────────────
class FocalLoss(nn.Module):
    """Focal loss on multi-class logits with optional per-class ``alpha`` weights.

    Per sample the loss is ``alpha_t * (1 - p_t) ** gamma * CE``, where ``CE``
    is that sample's cross-entropy and ``p_t = exp(-CE)``.

    Args:
        alpha: Scalar weight applied to every sample, or a per-class sequence /
            1-D tensor indexed by class id.
        gamma: Focusing exponent; larger values down-weight easy samples more.
        reduction: ``"mean"``, ``"sum"`` or ``"none"`` (per-sample losses).

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
    """

    _REDUCTIONS = ("mean", "sum", "none")

    def __init__(self, alpha=0.25, gamma=2.0, reduction="mean"):
        super().__init__()
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}"
            )
        if isinstance(alpha, (list, tuple)):
            alpha = torch.tensor(alpha, dtype=torch.float32)
        self.alpha, self.gamma, self.reduction = alpha, gamma, reduction

    def _sample_alpha(self, logits: torch.Tensor, targets: torch.Tensor):
        """Return the alpha factor for each sample (or a scalar when alpha is scalar)."""
        if not isinstance(self.alpha, torch.Tensor):
            return self.alpha
        alpha = self.alpha.to(logits.device)
        if alpha.dim() == 0:
            return alpha
        if targets.is_floating_point():
            # Soft targets of shape (batch, num_classes), e.g. after MixUp/CutMix:
            # blend the class weights with the target probabilities rather than
            # collapsing to alpha.mean(), which would discard the class weighting.
            return (targets.to(alpha.dtype) * alpha).sum(dim=-1)
        # Hard integer labels: plain per-class lookup
        return alpha[targets]

    def forward(self, logits: torch.Tensor, targets: torch.Tensor):
        """Compute the focal loss.

        Args:
            logits: Raw class scores, shape ``(batch, num_classes)``.
            targets: Integer class ids of shape ``(batch,)`` or soft targets of
                shape ``(batch, num_classes)``.

        Returns:
            A scalar for ``"mean"`` / ``"sum"``, otherwise the per-sample losses.
        """
        ce = F.cross_entropy(logits, targets, reduction="none")
        pt = torch.exp(-ce)
        loss = self._sample_alpha(logits, targets) * (1 - pt) ** self.gamma * ce

        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        return loss

# ── Model, optimizer, scheduler ─────────────────────────────────────────────────
# Pretrained ViT backbone with a freshly sized 2-class output
base_model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224", num_labels=2, ignore_mismatched_sizes=True
)

# Freeze the first two-thirds of the transformer blocks
total_blocks = len(base_model.vit.encoder.layer)
freeze_upto  = int(total_blocks * 2 / 3)
for frozen_block in base_model.vit.encoder.layer[:freeze_upto]:
    frozen_block.requires_grad_(False)

# Swap in a dropout-regularized classification head
embed_dim = base_model.classifier.in_features
head_layers = [nn.Dropout(p=0.5), nn.Linear(embed_dim, 2)]
base_model.classifier = nn.Sequential(*head_layers)

# Move to the target device and compile when possible; otherwise run eagerly
try:
    model = torch.compile(base_model.to(DEVICE))
except Exception:
    model = base_model.to(DEVICE)

# Loss with per-class focal weights, optimizer, and a per-batch cosine schedule
alpha     = torch.tensor([0.05, 0.95])
loss_fn   = FocalLoss(alpha=alpha, gamma=2.0).to(DEVICE)
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)
total_steps = NUM_EPOCHS * len(train_loader)
scheduler = CosineAnnealingLR(optimizer, T_max=total_steps)

# ── Training Loop ────────────────────────────────────────────────────────────────
for epoch in range(1, NUM_EPOCHS + 1):
    model.train()
    train_loss = 0.0

    for xb, yb in tqdm(train_loader, desc=f"Train {epoch}/{NUM_EPOCHS}", unit="batch"):
        xb, yb = xb.to(DEVICE), yb.to(DEVICE)
        # MixUp/CutMix blends images within the batch and returns mixed targets
        xb, yb = mixup_fn(xb, yb)

        with autocast(**AMP_KW):
            logits = model(xb).logits
            loss   = loss_fn(logits, yb)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()  # cosine schedule advances once per batch

        train_loss += loss.item()

    avg_train = train_loss / len(train_loader)
    print(f"Epoch {epoch:02d} │ Train Loss {avg_train:.4f}")

    # Save a checkpoint each epoch. The weights are taken from base_model, not
    # the torch.compile wrapper: the wrapper's state_dict prefixes every key
    # with "_orig_mod.", which a plain (uncompiled) model cannot load. Both
    # objects share the same parameters, so the saved values are identical.
    torch.save(base_model.state_dict(), os.path.join(OUTPUT_DIR, f"model_epoch{epoch}.pth"))

# ── Save final model ──────────────────────────────────────────────────────────────
# No validation happens in this script, so "best_model.pth" is simply the
# weights after the last epoch.
torch.save(base_model.state_dict(), os.path.join(OUTPUT_DIR, "best_model.pth"))
print("Training complete. Model saved to", OUTPUT_DIR)
