In [1]:
!mkdir -p /content/data
%cd /content/data

# Endless Forams combined training set
!wget -O Endless_Forams_training_set.zip \
  "https://zenodo.org/records/3996436/files/Endless_Forams_training_set.zip?download=1"
!unzip -q Endless_Forams_training_set.zip

# Core MD022508 training set
!wget -O MD022508_training_set.zip \
  "https://zenodo.org/records/3996436/files/MD022508_training_set.zip?download=1"
!unzip -q MD022508_training_set.zip

# Core MD972138 training set
!wget -O MD972138_training_set.zip \
  "https://zenodo.org/records/3996436/files/MD972138_training_set.zip?download=1"
!unzip -q MD972138_training_set.zip


/content/data
--2025-12-15 06:01:07--  https://zenodo.org/records/3996436/files/Endless_Forams_training_set.zip?download=1
Resolving zenodo.org (zenodo.org)... 188.185.43.153, 188.185.48.75, 137.138.52.235, ...
Connecting to zenodo.org (zenodo.org)|188.185.43.153|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 714725290 (682M) [application/octet-stream]
Saving to: ‘Endless_Forams_training_set.zip’


2025-12-15 06:13:37 (932 KB/s) - ‘Endless_Forams_training_set.zip’ saved [714725290/714725290]

--2025-12-15 06:13:47--  https://zenodo.org/records/3996436/files/MD022508_training_set.zip?download=1
Resolving zenodo.org (zenodo.org)... 137.138.52.235, 188.185.43.153, 188.185.48.75, ...
Connecting to zenodo.org (zenodo.org)|137.138.52.235|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 824795913 (787M) [application/octet-stream]
Saving to: ‘MD022508_training_set.zip’


2025-12-15 06:18:14 (2.96 MB/s) - ‘MD022508_training_set.zip’ saved [824

In [None]:
# -*- coding: utf-8 -*-


!pip install -q timm torchmetrics albumentations

import os
import math
import random
from typing import Optional, Dict, Any, List

import numpy as np
from PIL import Image, ImageDraw

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

import timm
from torchmetrics.classification import MulticlassAccuracy, MulticlassF1Score


class CFG:
    # labels (will be overridden after scanning Endless Forams)
    NUM_FINE   = 35   # placeholder; updated dynamically
    NUM_COARSE = 2    # simple hierarchy: first half vs second half

    # training
    BATCH_SIZE = 64
    NUM_WORKERS = 2   # set 0 if debugging DataLoader errors
    IMAGE_SIZE = 224
    EPOCHS_PRETRAIN = 2        # small for demo
    EPOCHS_FINETUNE = 2
    LR_PRETRAIN = 3e-4
    LR_FINETUNE = 1e-4
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

    # losses / weights
    LAMBDA_FINE_CE = 1.0
    LAMBDA_METRIC  = 0.1
    TEMP_CONTRAST  = 0.1

    # pretraining multi-task weights
    ALPHA_SUP = 1.0
    BETA_MAE  = 0.5
    GAMMA_AUX = 0.2

    # active learning
    MC_DROPOUT_SAMPLES = 8
    ACTIVE_TOP_K = 200

cfg = CFG()
print("Using device:", cfg.DEVICE)


def supervised_contrastive_loss(embeddings: torch.Tensor,
                                labels: torch.Tensor,
                                temperature: float = 0.1):
    """
    embeddings: (B, D)
    labels:     (B,)
    """
    device = embeddings.device
    z = F.normalize(embeddings, dim=-1)
    logits = torch.matmul(z, z.T) / temperature  # (B, B)
    labels = labels.view(-1, 1)
    mask = torch.eq(labels, labels.T).float().to(device)

    # remove self-comparisons
    logits_mask = torch.ones_like(mask) - torch.eye(mask.size(0), device=device)
    mask = mask * logits_mask

    # log-softmax over rows
    logits_max, _ = torch.max(logits, dim=1, keepdim=True)
    logits = logits - logits_max.detach()
    exp_logits = torch.exp(logits) * logits_mask
    log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-12)

    pos_count = mask.sum(1)
    pos_count = torch.clamp(pos_count, min=1.0)
    mean_log_prob_pos = (mask * log_prob).sum(1) / pos_count

    loss = -mean_log_prob_pos.mean()
    return loss


class DFHViT(nn.Module):
    def __init__(self,
                 backbone_name: str = "vit_base_patch16_224",
                 num_fine: int = 10,
                 num_coarse: int = 2,
                 metadata_dim: Optional[int] = None):
        super().__init__()
        self.backbone = timm.create_model(
            backbone_name,
            pretrained=True,
            num_classes=0,     # feature extractor only
            global_pool="avg",
        )
        self.embed_dim = self.backbone.num_features

        self.use_metadata = metadata_dim is not None
        if self.use_metadata:
            self.meta_mlp = nn.Sequential(
                nn.Linear(metadata_dim, self.embed_dim),
                nn.ReLU(),
                nn.Linear(self.embed_dim, self.embed_dim)
            )

        # hierarchical heads
        self.coarse_head = nn.Linear(self.embed_dim, num_coarse)
        self.fine_head = nn.Sequential(
            nn.Linear(self.embed_dim + num_coarse, self.embed_dim),
            nn.ReLU(),
            nn.Linear(self.embed_dim, num_fine)
        )

        # reconstruction head (32x32 RGB)
        self.recon_head = nn.Sequential(
            nn.Linear(self.embed_dim, 3 * 32 * 32),
        )
        # aux head (e.g., rotation prediction, 4 classes)
        self.aux_head = nn.Linear(self.embed_dim, 4)

    def encode(self, images: torch.Tensor, metadata: Optional[torch.Tensor] = None):
        z = self.backbone(images)   # (B, D)
        if self.use_metadata and metadata is not None:
            m = self.meta_mlp(metadata)
            z = z + m
        return z

    def forward(self, images: torch.Tensor, metadata: Optional[torch.Tensor] = None):
        z = self.encode(images, metadata)
        coarse_logits = self.coarse_head(z)
        fine_input = torch.cat([z, coarse_logits], dim=-1)
        fine_logits = self.fine_head(fine_input)

        recon_logits = self.recon_head(z)
        aux_logits   = self.aux_head(z)
        return coarse_logits, fine_logits, z, recon_logits, aux_logits


class SyntheticFossilFractalDataset(Dataset):
    def __init__(self, num_samples: int = 20000, image_size: int = 224, transform=None):
        self.num_samples = num_samples
        self.image_size = image_size
        self.transform = transform

        self.num_coarse = cfg.NUM_COARSE
        self.num_fine   = cfg.NUM_FINE

    def __len__(self):
        return self.num_samples

    def _draw_spumellarian_like(self, draw: ImageDraw.Draw, size: int):
        cx, cy = size // 2, size // 2
        r = random.randint(size//6, size//3)
        draw.ellipse([cx-r, cy-r, cx+r, cy+r], outline="white", width=2)
        for _ in range(random.randint(6, 20)):
            angle = random.uniform(0, 2*math.pi)
            r2 = random.randint(r, size//2)
            x2 = cx + int(r2 * math.cos(angle))
            y2 = cy + int(r2 * math.sin(angle))
            draw.line([cx, cy, x2, y2], fill="white", width=1)

    def _draw_nassellarian_like(self, draw: ImageDraw.Draw, size: int):
        cx, cy = size // 2, size // 2
        top = cy - size//4
        bottom = cy + size//4
        w = size//6
        draw.polygon([(cx, top), (cx-w, bottom), (cx+w, bottom)],
                     outline="white", width=2)
        for i in range(4):
            y = top + (bottom - top) * (i+1) / 5
            draw.line([cx-w+2, y, cx+w-2, y], fill="white", width=1)

    def _draw_diatom_like(self, draw: ImageDraw.Draw, size: int):
        cx, cy = size // 2, size // 2
        w = size//4
        h = size//8
        draw.ellipse([cx-w, cy-h, cx+w, cy+h], outline="white", width=2)
        for _ in range(50):
            x = random.randint(cx-w, cx+w)
            y = random.randint(cy-h, cy+h)
            draw.point((x, y), fill="white")

    def _generate_image_and_labels(self):
        img = Image.new("L", (self.image_size, self.image_size), color=0)
        draw = ImageDraw.Draw(img)

        coarse = random.randint(0, self.num_coarse - 1)
        if coarse == 0:
            self._draw_spumellarian_like(draw, self.image_size)
        elif coarse == 1:
            self._draw_nassellarian_like(draw, self.image_size)
        else:
            self._draw_diatom_like(draw, self.image_size)

        fine = random.randint(0, self.num_fine - 1)
        img = img.convert("RGB")
        return img, coarse, fine

    def __getitem__(self, idx):
        img, coarse, fine = self._generate_image_and_labels()
        if self.transform:
            img = self.transform(img)
        return {
            "image": img,
            "coarse_label": torch.tensor(coarse, dtype=torch.long),
            "fine_label":   torch.tensor(fine,   dtype=torch.long),
        }


class EndlessForamsHierDataset(Dataset):
    """
    Hierarchical dataset for Endless Forams:
    - fine_label: species index (0..num_species-1)
    - coarse_label: simple binary split over species indices
    """
    def __init__(
        self,
        samples,
        class_to_idx,
        transform=None,
    ):
        """
        samples: list of (image_path, fine_label)
        class_to_idx: dict {species_name: index}
        """
        self.samples = samples
        self.class_to_idx = class_to_idx
        self.transform = transform

        self.classes = sorted(class_to_idx, key=lambda k: class_to_idx[k])
        self.num_classes = len(self.classes)

        # simple 2-way hierarchy: first half vs second half of classes
        midpoint = self.num_classes // 2
        self.coarse_map = {
            i: 0 if i < midpoint else 1
            for i in range(self.num_classes)
        }

    @classmethod
    def build_from_root(
        cls,
        root_dir: str,
        transform_train=None,
        transform_val=None,
        val_ratio: float = 0.2,
        seed: int = 42,
    ):
        """
        Scans root_dir/<species_name>/*.jpg and creates
        train/val splits and a shared class_to_idx mapping.
        """
        # discover species (subfolders)
        classes = sorted(
            d for d in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, d))
        )
        class_to_idx = {name: i for i, name in enumerate(classes)}

        # collect (path, fine_label)
        all_samples = []
        for name in classes:
            class_dir = os.path.join(root_dir, name)
            for fname in os.listdir(class_dir):
                if fname.lower().endswith((".jpg", ".jpeg", ".png")):
                    full_path = os.path.join(class_dir, fname)
                    all_samples.append((full_path, class_to_idx[name]))

        indices = np.arange(len(all_samples))
        rng = np.random.RandomState(seed)
        rng.shuffle(indices)

        split = int((1.0 - val_ratio) * len(indices))
        train_idx = indices[:split]
        val_idx   = indices[split:]

        train_samples = [all_samples[i] for i in train_idx]
        val_samples   = [all_samples[i] for i in val_idx]

        train_ds = cls(train_samples, class_to_idx, transform_train)
        val_ds   = cls(val_samples,   class_to_idx, transform_val)
        return train_ds, val_ds, class_to_idx

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, fine = self.samples[idx]
        img = Image.open(path).convert("RGB")
        if self.transform:
            img = self.transform(img)
        coarse = self.coarse_map[fine]
        return {
            "image": img,
            "coarse_label": torch.tensor(coarse, dtype=torch.long),
            "fine_label":   torch.tensor(fine,   dtype=torch.long),
        }


class EndlessForamsUnlabeledPool(Dataset):
    """
    Unlabeled pool built from MD022508 and MD972138 training sets.
    Returns only image + index (no labels), for active learning / pseudo-labeling.
    """
    def __init__(self, root_dirs, transform=None):
        """
        root_dirs: list of roots like
           ["/content/data/MD022508_training_set/MD022508_training_set",
            "/content/data/MD972138_training_set/MD972138_training_set"]
        """
        self.transform = transform
        self.paths = []

        for root_dir in root_dirs:
            if not os.path.isdir(root_dir):
                continue
            species_dirs = [
                d for d in os.listdir(root_dir)
                if os.path.isdir(os.path.join(root_dir, d))
            ]
            for s in species_dirs:
                class_dir = os.path.join(root_dir, s)
                for fname in os.listdir(class_dir):
                    if fname.lower().endswith((".jpg", ".jpeg", ".png")):
                        self.paths.append(os.path.join(class_dir, fname))

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        path = self.paths[idx]
        img = Image.open(path).convert("RGB")
        if self.transform:
            img = self.transform(img)
        return {
            "image": img,
            "index": idx,
        }


# ===================== Transforms ============================
train_transform = transforms.Compose([
    transforms.Resize((cfg.IMAGE_SIZE, cfg.IMAGE_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(cfg.IMAGE_SIZE, padding=4),
    transforms.ToTensor(),
])

val_transform = transforms.Compose([
    transforms.Resize((cfg.IMAGE_SIZE, cfg.IMAGE_SIZE)),
    transforms.ToTensor(),
])


# ===================== Dataset paths =========================
ENDLESS_ROOT   = "/content/data/Endless_Forams_training_set"
MD022508_ROOT  = "/content/data/MD022508_training_set"
MD972138_ROOT  = "/content/data/MD972138_training_set"

assert os.path.isdir(ENDLESS_ROOT), f"Endless Forams root not found: {ENDLESS_ROOT}"


# Build train/val splits for Endless Forams training set
train_real_dataset, val_real_dataset, class_to_idx = EndlessForamsHierDataset.build_from_root(
    root_dir=ENDLESS_ROOT,
    transform_train=train_transform,
    transform_val=val_transform,
    val_ratio=0.2,
    seed=42,
)

# Update config based on discovered classes
cfg.NUM_FINE   = len(class_to_idx)
cfg.NUM_COARSE = 2  # we defined a 2-way hierarchy above

print(f"NUM_FINE (species classes) = {cfg.NUM_FINE}")
print(f"NUM_COARSE (coarse groups) = {cfg.NUM_COARSE}")

train_real_loader = DataLoader(
    train_real_dataset,
    batch_size=cfg.BATCH_SIZE,
    shuffle=True,
    num_workers=cfg.NUM_WORKERS,
    pin_memory=True,
)

val_real_loader = DataLoader(
    val_real_dataset,
    batch_size=cfg.BATCH_SIZE,
    shuffle=False,
    num_workers=cfg.NUM_WORKERS,
    pin_memory=True,
)

# Synthetic pretraining data (after cfg.NUM_FINE is updated)
synthetic_dataset = SyntheticFossilFractalDataset(
    num_samples=10000,
    image_size=cfg.IMAGE_SIZE,
    transform=train_transform,
)
synthetic_loader = DataLoader(
    synthetic_dataset,
    batch_size=cfg.BATCH_SIZE,
    shuffle=True,
    num_workers=cfg.NUM_WORKERS,
    pin_memory=True,
)

# Unlabeled pool for active learning demo
unlabeled_pool_dataset = EndlessForamsUnlabeledPool(
    root_dirs=[MD022508_ROOT, MD972138_ROOT],
    transform=val_transform,
)

print(f"Train images (Endless Forams): {len(train_real_dataset)}")
print(f"Val images   (Endless Forams): {len(val_real_dataset)}")
print(f"Unlabeled pool images (MD022508 + MD972138): {len(unlabeled_pool_dataset)}")


def pretrain_epoch(model: DFHViT,
                   dataloader: DataLoader,
                   optimizer: torch.optim.Optimizer,
                   device: str = "cuda"):
    model.train()
    total_loss = 0.0
    for batch in dataloader:
        images = batch["image"].to(device)
        coarse_labels = batch["coarse_label"].to(device)
        fine_labels   = batch["fine_label"].to(device)

        optimizer.zero_grad()
        coarse_logits, fine_logits, z, recon_logits, aux_logits = model(images, metadata=None)

        loss_coarse = F.cross_entropy(coarse_logits, coarse_labels)
        loss_fine   = F.cross_entropy(fine_logits, fine_labels)
        L_sup = loss_coarse + loss_fine

        with torch.no_grad():
            img_small = F.interpolate(images, size=(32, 32),
                                      mode="bilinear", align_corners=False)
        img_target = img_small.view(images.size(0), -1)
        L_mae = F.mse_loss(recon_logits, img_target)

        # dummy aux loss: predict rotation class using fine_labels % 4
        L_aux = F.cross_entropy(aux_logits, fine_labels % 4)

        L_pre = cfg.ALPHA_SUP * L_sup + cfg.BETA_MAE * L_mae + cfg.GAMMA_AUX * L_aux
        L_pre.backward()
        optimizer.step()
        total_loss += L_pre.item() * images.size(0)

    return total_loss / len(dataloader.dataset)


def finetune_epoch(model: DFHViT,
                   dataloader: DataLoader,
                   optimizer: torch.optim.Optimizer,
                   device: str = "cuda"):
    model.train()
    total_loss = 0.0
    for batch in dataloader:
        images = batch["image"].to(device)
        coarse_labels = batch["coarse_label"].to(device)
        fine_labels   = batch["fine_label"].to(device)

        optimizer.zero_grad()
        coarse_logits, fine_logits, z, recon_logits, aux_logits = model(images, metadata=None)

        loss_coarse = F.cross_entropy(coarse_logits, coarse_labels)
        loss_fine   = F.cross_entropy(fine_logits, fine_labels)
        L_ce = loss_coarse + cfg.LAMBDA_FINE_CE * loss_fine

        L_metric = supervised_contrastive_loss(z, fine_labels, temperature=cfg.TEMP_CONTRAST)
        loss = L_ce + cfg.LAMBDA_METRIC * L_metric

        loss.backward()
        optimizer.step()
        total_loss += loss.item() * images.size(0)

    return total_loss / len(dataloader.dataset)


def evaluate(model: DFHViT,
             dataloader: DataLoader,
             device: str = "cuda"):
    model.eval()
    coarse_acc = MulticlassAccuracy(num_classes=cfg.NUM_COARSE).to(device)
    fine_acc   = MulticlassAccuracy(num_classes=cfg.NUM_FINE).to(device)
    fine_f1    = MulticlassF1Score(num_classes=cfg.NUM_FINE, average="macro").to(device)

    total_loss = 0.0
    with torch.no_grad():
        for batch in dataloader:
            images = batch["image"].to(device)
            coarse_labels = batch["coarse_label"].to(device)
            fine_labels   = batch["fine_label"].to(device)

            coarse_logits, fine_logits, z, recon_logits, aux_logits = model(images, metadata=None)

            loss_coarse = F.cross_entropy(coarse_logits, coarse_labels)
            loss_fine   = F.cross_entropy(fine_logits, fine_labels)
            L_ce = loss_coarse + cfg.LAMBDA_FINE_CE * loss_fine
            total_loss += L_ce.item() * images.size(0)

            coarse_acc.update(coarse_logits, coarse_labels)
            fine_acc.update(fine_logits, fine_labels)
            fine_f1.update(fine_logits, fine_labels)

    n = len(dataloader.dataset)
    metrics = {
        "loss": total_loss / n,
        "coarse_acc": coarse_acc.compute().item(),
        "fine_acc":   fine_acc.compute().item(),
        "fine_f1":    fine_f1.compute().item(),
    }
    return metrics


def predictive_entropy_mc_dropout(model: DFHViT,
                                  images: torch.Tensor,
                                  metadata: Optional[torch.Tensor],
                                  mc_samples: int = 8,
                                  device: str = "cuda"):
    model.train()  # keep dropout active
    images = images.to(device)

    probs_list = []
    with torch.no_grad():
        for _ in range(mc_samples):
            _, fine_logits, _, _, _ = model(images, metadata)
            probs = F.softmax(fine_logits, dim=-1)
            probs_list.append(probs)

    probs_stack = torch.stack(probs_list, dim=0)   # (T, B, K)
    probs_mean = probs_stack.mean(dim=0)           # (B, K)
    entropy = -(probs_mean * (probs_mean.clamp_min(1e-12).log())).sum(dim=-1)
    return entropy.cpu()


def select_uncertain_samples(model: DFHViT,
                             pool_dataset: Dataset,
                             batch_size: int = 128,
                             top_k: int = 200,
                             device: str = "cuda"):
    model.to(device)
    model.eval()  # we'll flip to train inside entropy fn

    pool_loader = DataLoader(pool_dataset, batch_size=batch_size,
                             shuffle=False, num_workers=cfg.NUM_WORKERS)

    all_indices = []
    all_entropy = []
    for batch in pool_loader:
        images = batch["image"]
        idxs = batch["index"]
        entropy = predictive_entropy_mc_dropout(
            model, images, metadata=None,
            mc_samples=cfg.MC_DROPOUT_SAMPLES,
            device=device,
        )
        all_indices.extend(idxs.numpy().tolist())
        all_entropy.extend(entropy.numpy().tolist())

    all_indices = np.array(all_indices)
    all_entropy = np.array(all_entropy)

    top_k = min(top_k, len(all_indices))
    selected_idx = np.argpartition(-all_entropy, top_k-1)[:top_k]
    pool_indices_selected = all_indices[selected_idx]
    ent_selected = all_entropy[selected_idx]
    return pool_indices_selected, ent_selected


def run_training():
    device = cfg.DEVICE
    model = DFHViT(
        backbone_name="vit_base_patch16_224",
        num_fine=cfg.NUM_FINE,
        num_coarse=cfg.NUM_COARSE,
        metadata_dim=None,
    ).to(device)

    print("----- Pretraining on synthetic fractal data -----")
    optimizer_pre = torch.optim.Adam(model.parameters(), lr=cfg.LR_PRETRAIN)
    for epoch in range(cfg.EPOCHS_PRETRAIN):
        loss_pre = pretrain_epoch(model, synthetic_loader, optimizer_pre, device=device)
        print(f"[Pretrain] Epoch {epoch+1}/{cfg.EPOCHS_PRETRAIN} - loss: {loss_pre:.4f}")

    print("\n----- Fine-tuning on Endless Forams -----")
    optimizer_ft = torch.optim.Adam(model.parameters(), lr=cfg.LR_FINETUNE)
    best_f1 = 0.0
    for epoch in range(cfg.EPOCHS_FINETUNE):
        train_loss = finetune_epoch(model, train_real_loader, optimizer_ft, device=device)
        val_metrics = evaluate(model, val_real_loader, device=device)
        print(f"[Finetune] Epoch {epoch+1}/{cfg.EPOCHS_FINETUNE} "
              f"train_loss={train_loss:.4f} "
              f"val_loss={val_metrics['loss']:.4f} "
              f"coarse_acc={val_metrics['coarse_acc']:.4f} "
              f"fine_acc={val_metrics['fine_acc']:.4f} "
              f"fine_f1={val_metrics['fine_f1']:.4f}")

        if val_metrics["fine_f1"] > best_f1:
            best_f1 = val_metrics["fine_f1"]
            torch.save(model.state_dict(), "/content/dfh_vit_forams_best.pth")
            print("  -> saved best model (F1 = {:.4f})".format(best_f1))
    return model


# ------------------- Run training ----------------------------
model = run_training()

# After training, you can run one active learning round:
selected_indices, entropies = select_uncertain_samples(
     model, unlabeled_pool_dataset,
     top_k=cfg.ACTIVE_TOP_K,
     device=cfg.DEVICE,
)
print("Selected pool indices (most uncertain):", selected_indices[:20])


# === One-shot evaluation & visualization cell ===
# This assumes CFG, DFHViT, val_real_loader, train_real_dataset, and evaluate()
# are already defined and that you saved the best model to:
# "/content/dfh_vit_forams_best.pth"

import matplotlib.pyplot as plt

# Install sklearn if not available (for confusion_matrix)
try:
    from sklearn.metrics import confusion_matrix
except ImportError:
    !pip install -q scikit-learn
    from sklearn.metrics import confusion_matrix

device = cfg.DEVICE

# 1) Load best model
model = DFHViT(
    backbone_name="vit_base_patch16_224",
    num_fine=cfg.NUM_FINE,
    num_coarse=cfg.NUM_COARSE,
    metadata_dim=None,
).to(device)

state_dict = torch.load("/content/dfh_vit_forams_best.pth", map_location=device)
model.load_state_dict(state_dict)
model.eval()

# 2) Overall metrics on validation set
val_metrics = evaluate(model, val_real_loader, device=device)
print("=== Overall Validation Metrics ===")
for k, v in val_metrics.items():
    print(f"{k}: {v:.4f}")

# 3) Collect predictions and labels for confusion matrix & per-class accuracy
all_preds = []
all_labels = []

with torch.no_grad():
    for batch in val_real_loader:
        images = batch["image"].to(device)
        labels = batch["fine_label"].to(device)
        _, fine_logits, _, _, _ = model(images, metadata=None)
        preds = fine_logits.argmax(dim=1)
        all_preds.append(preds.cpu().numpy())
        all_labels.append(labels.cpu().numpy())

all_preds = np.concatenate(all_preds)
all_labels = np.concatenate(all_labels)

# 4) Confusion matrix
cm = confusion_matrix(all_labels, all_preds, labels=list(range(cfg.NUM_FINE)))

plt.figure(figsize=(8, 8))
plt.imshow(cm, interpolation="nearest")
plt.title("Confusion Matrix (Fine Labels)")
plt.colorbar()
plt.xlabel("Predicted")
plt.ylabel("True")
plt.tight_layout()
plt.show()

# 5) Per-class accuracy bar plot
per_class_acc = cm.diagonal() / cm.sum(axis=1).clip(min=1)

# Get class names from EndlessForamsHierDataset
if hasattr(train_real_dataset, "classes"):
    class_names = train_real_dataset.classes
else:
    class_names = [f"cls_{i}" for i in range(cfg.NUM_FINE)]

plt.figure(figsize=(max(10, cfg.NUM_FINE * 0.4), 4))
plt.bar(range(cfg.NUM_FINE), per_class_acc)
plt.xticks(range(cfg.NUM_FINE), class_names, rotation=90, ha="right")
plt.ylim(0, 1.0)
plt.ylabel("Accuracy")
plt.title("Per-Class Accuracy (Fine Labels)")
plt.tight_layout()
plt.show()

# 6) Show a grid of example predictions
def show_example_predictions(model, dataloader, n=16):
    model.eval()
    images_shown = 0
    rows = 4
    cols = 4
    fig, axes = plt.subplots(rows, cols, figsize=(8, 8))
    axes = axes.flatten()

    with torch.no_grad():
        for batch in dataloader:
            images = batch["image"].to(device)
            labels = batch["fine_label"].to(device)
            _, fine_logits, _, _, _ = model(images, metadata=None)
            preds = fine_logits.argmax(dim=1)

            for i in range(images.size(0)):
                if images_shown >= n:
                    plt.tight_layout()
                    plt.show()
                    return
                ax = axes[images_shown]
                img = images[i].cpu().permute(1, 2, 0).numpy()
                true_label = labels[i].item()
                pred_label = preds[i].item()
                ax.imshow(img)
                ax.axis("off")
                t_name = class_names[true_label] if true_label < len(class_names) else str(true_label)
                p_name = class_names[pred_label] if pred_label < len(class_names) else str(pred_label)
                ax.set_title(f"T:{t_name}\nP:{p_name}", fontsize=8)
                images_shown += 1

    plt.tight_layout()
    plt.show()

print("\n=== Example Predictions (first 16 val images) ===")
show_example_predictions(model, val_real_loader, n=16)


[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/983.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m983.0/983.2 kB[0m [31m41.0 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m983.2/983.2 kB[0m [31m27.4 MB/s[0m eta [36m0:00:00[0m
[?25hUsing device: cuda
NUM_FINE (species classes) = 35
NUM_COARSE (coarse groups) = 2
Train images (Endless Forams): 22183
Val images   (Endless Forams): 5546
Unlabeled pool images (MD022508 + MD972138): 28275


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]



----- Pretraining on synthetic fractal data -----
[Pretrain] Epoch 1/2 - loss: 3.9457
[Pretrain] Epoch 2/2 - loss: 3.8564

----- Fine-tuning on Endless Forams -----
[Finetune] Epoch 1/2 train_loss=3.8364 val_loss=3.3174 coarse_acc=0.4999 fine_acc=0.0369 fine_f1=0.0175
  -> saved best model (F1 = 0.0175)
[Finetune] Epoch 2/2 train_loss=3.6429 val_loss=3.2632 coarse_acc=0.5367 fine_acc=0.0428 fine_f1=0.0269
  -> saved best model (F1 = 0.0269)
