In [3]:
"""
Contrastive training (CLIP‑style) with a UNI2 image encoder **and** the
**CoCa text encoder** from *OpenCLIP* (matching OmiCLIP’s causal masking
transformer).

Key points
==========
* **Image branch**  – UNI2 model that pools the 4 centre spatial tokens
  per cell.
* **Text branch**   – CoCa causal text encoder (pre‑trained on LAION‑2B)
  loaded via *open_clip*.
* **Projection heads** map both modalities → shared dim (`proj_dim`).
* **InfoNCE loss**   – symmetric (image→text & text→image).

Dependencies
------------
```bash
pip install open_clip_torch transformers timm opencv-python pillow
```
"""

from __future__ import annotations

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from typing import List
import timm
import pandas as pd 
import argparse 
import open_clip           # ⇦ CoCa lives here
from transformers import AutoTokenizer
from torch.utils.data import Dataset, DataLoader
from typing import List, Optional
from PIL import Image
import openslide
import scanpy as sc
import glob
import math
import csv, time, math
from contextlib import nullcontext
from tqdm import tqdm

In [2]:
# -----------------------------------------------------------------------------
# Argparse → CFG
# -----------------------------------------------------------------------------
# parser = argparse.ArgumentParser(description="CLIP-GO contrastive training (UNI2 + CoCa)")
# parser.add_argument("--cancer", type=str, default="lung",
#                     choices=["lung","breast","lymph_node","prostate","skin","ovarian","cervical"]) 
# parser.add_argument("--epochs", type=int, default=10)
# parser.add_argument("--run_name", type=str, default="clipgo_run")
# parser.add_argument("--freeze_text", action="store_true")
# parser.add_argument("--batch_size", type=int, default=72)
# parser.add_argument("--num_workers", type=int, default=8)
# parser.add_argument("--lr", type=float, default=1e-4)
# parser.add_argument("--weight_decay", type=float, default=1e-3)
# parser.add_argument("--proj_dim", type=int, default=256)
# parser.add_argument("--dropout", type=float, default=0.1)
# parser.add_argument("--context_len", type=int, default=76)
# parser.add_argument("--target_mpp", type=float, default=0.5, help="target µm/px (≈20x)")
# parser.add_argument("--go_sentence_col", type=str, default="go_sentence")
# args = parser.parse_args() if __name__ != "__main__" else parser.parse_args([])


class CFG:
    # data
    cancer = "lung"
    ground_truth = "refined"            # dataset variant
    level = 0                            # UNI2 spatial‑token level
    batch_size = 72
    num_workers = 8

    # optimisation
    temperature = 1.0
    patience = 2.0
    projection_dim = 256
    lr = 1e-4
    weight_decay = 1e-3
    dropout = 0.1
    epochs = 50

    # Embeddings
    morph_emb_dims = 1536
    patch_size = 224

    # Text / CoCa
    coca_model = "coca_ViT-L-14"
    coca_pretrain = "laion2B-s13b-b90k"
    context_len =76
    freeze_text = True
    # go_sentence_col = args.go_sentence_col

    # paths
    root = "/rsrch9/home/plm/idso_fa1_pathology/TIER1/paul-xenium/public_data/10x_genomics"
    xenium_sample_dict = {
        "lung":"Xenium_Prime_Human_Lung_Cancer_FFPE_outs",
        "breast": "Xenium_Prime_Breast_Cancer_FFPE_outs",
        "lymph_node": "Xenium_Prime_Human_Lymph_Node_Reactive_FFPE_outs",
        "prostate": "Xenium_Prime_Human_Prostate_FFPE_outs",
        "skin": "Xenium_Prime_Human_Skin_FFPE_outs",
        "ovarian": "Xenium_Prime_Ovarian_Cancer_FFPE_outs",
        "cervical": "Xenium_Prime_Cervical_Cancer_FFPE_outs",
    }
    go_dir = "/rsrch9/home/plm/idso_fa1_pathology/TIER2/paul-xenium/gene_ontology"
    model_dir = "/rsrch9/home/plm/idso_fa1_pathology/TIER2/paul-xenium/models/public/UNI2-h"  # pretrained UNI2 weights
    ckpt_dir = f"/rsrch9/home/plm/idso_fa1_pathology/TIER2/paul-xenium/models/fine_tuned/GoCLIP/{cancer}"  # outputs
    target_mpp = 0.5



In [12]:
# -----------------------------------------------------------------------------
# Projection head
# -----------------------------------------------------------------------------
class ProjectionHead(nn.Module):
    def __init__(self, in_dim: int, proj_dim: int = 256):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, proj_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(proj_dim, proj_dim)
        self.ln  = nn.LayerNorm(proj_dim)

    def forward(self, x):
        h = self.fc1(x)
        x = self.act(h)
        x = self.fc2(x)
        return self.ln(x + h)

# -----------------------------------------------------------------------------
# UNI2 wrapper (4-centre token pooling)
# -----------------------------------------------------------------------------
class UNI2Wrapper(nn.Module):
    def __init__(self, uni2: nn.Module, centre_idx: List[int]):
        super().__init__()
        self.uni2 = uni2
        self.centre_idx = torch.tensor(centre_idx, dtype=torch.long)
        self.prefix_tokens = 9

    def forward(self, x: torch.Tensor):
        tok = self.uni2.forward_features(x)               # (B, 265, 1536)
        spatial = tok[:, self.prefix_tokens :, :]
        centre  = spatial.index_select(1, self.centre_idx.to(x.device)).mean(1)
        return centre                                     # (B,1536)

# -----------------------------------------------------------------------------
# Dataset: cell-centred patch + GO sentence
# -----------------------------------------------------------------------------
class CellPatchTextDataset(Dataset):
    def __init__(
        self,
        slide,                        # OpenSlide object with read_region
        cell_df: pd.DataFrame,        # index=cell_id, has x_centroid, y_centroid
        sentences: pd.Series,         # index=cell_id, value=string sentence
        transform,                    # torchvision transforms
        scale: float,                 # px-per-px scaling for target magnification
        patch_size: int = 224,
    ):
        self.slide = slide
        self.cells = cell_df.reset_index(drop=False)   # keep cell_id in column 'index'
        self.sentences = sentences
        self.tfm = transform
        self.scale = scale
        self.patch_size = patch_size

    def __len__(self):
        return len(self.cells)

    def _read_patch(self, x, y):
        big = int(self.patch_size * self.scale)
        tlx, tly = int(x - big/2), int(y - big/2)
        patch = self.slide.read_region((tlx, tly), 0, (big, big)).convert("RGB")
        return patch.resize((self.patch_size, self.patch_size), Image.LANCZOS)

    def __getitem__(self, idx):
        row   = self.cells.iloc[idx]
        patch = self._read_patch(row.x_centroid, row.y_centroid)
        img_t = self.tfm(patch)
        cell_id = row["index"]
        sent = self.sentences.loc[cell_id]
        sent = "" if pd.isna(sent) else str(sent)
        return {"image": img_t, "text": sent}

# -----------------------------------------------------------------------------
# CLIP-GO: UNI2 + CoCa text encoder
# -----------------------------------------------------------------------------
class CLIPGO(nn.Module):
    def __init__(
        self,
        vision_backbone: nn.Module,
        coca_model: str = CFG.coca_model,
        coca_pretrain: str = CFG.coca_pretrain,
        proj_dim: int = CFG.projection_dim,
        context_len: int = CFG.context_len,
        freeze_text: bool = CFG.freeze_text,
    ):
        super().__init__()
        self.context_len = context_len
        self.freeze_text = freeze_text

        # ---- Vision branch ----
        self.vision_encoder = vision_backbone
        vision_dim = vision_backbone.uni2.embed_dim  # 1536 for ViT-Giant
        self.vision_proj  = ProjectionHead(vision_dim, proj_dim)

        # ---- Text branch (CoCa causal encoder) ----
        self.text_encoder, _, _ = open_clip.create_model_and_transforms(
            coca_model, pretrained=coca_pretrain,
            cache_dir=os.path.expanduser("~/.cache/open_clip")
        )
        self.tokenizer = open_clip.get_tokenizer(coca_model)
        if freeze_text:
            for p in self.text_encoder.parameters():
                p.requires_grad_(False)

        # Determine text width dynamically
        with torch.no_grad():
            dummy = self.tokenizer(["dummy"], context_length=context_len)
            txt_feat = self.text_encoder.encode_text(dummy)
        text_dim = txt_feat.shape[-1]
        self.text_proj = ProjectionHead(text_dim, proj_dim)

        # Learnable temperature
        self.logit_scale = nn.Parameter(torch.log(torch.tensor(1/0.07)))

    def encode_image(self, imgs):
        feats = self.vision_encoder(imgs)
        return self.vision_proj(feats)

    def encode_text(self, sentences: List[str]):
        tokens = self.tokenizer(sentences, context_length=self.context_len).to(next(self.parameters()).device)
        if self.freeze_text:
            with torch.no_grad():
                feats = self.text_encoder.encode_text(tokens)
        else:
            feats = self.text_encoder.encode_text(tokens)
        return self.text_proj(feats)

    def forward(self, imgs, sentences):
        img_emb = F.normalize(self.encode_image(imgs), dim=-1)
        txt_emb = F.normalize(self.encode_text(sentences), dim=-1)
        scale   = self.logit_scale.exp()
        logits  = scale * img_emb @ txt_emb.T

        targets = torch.arange(img_emb.size(0), device=img_emb.device)
        loss_i = F.cross_entropy(logits, targets)
        loss_t = F.cross_entropy(logits.T, targets)
        return 0.5 * (loss_i + loss_t), logits, logits.T

In [13]:
# -----------------------------------------------------------------------------
# Helpers: slide/MPP, transforms, config dump
# -----------------------------------------------------------------------------

def get_slide_and_mpp(slide_dir: str):
    # try common names: *_he_image_registered.ome.tif inside sample folder
    tifs = sorted(glob.glob(os.path.join(slide_dir, "**", "*he_image_registered*.ome.tif"), recursive=True))
    if not tifs:
        # fallback: any .tif in root
        tifs = sorted(glob.glob(os.path.join(slide_dir, "**", "*.tif"), recursive=True))
        assert tifs, f"No slide tif found under {slide_dir}"
    slide_path = tifs[0]
    slide = openslide.open_slide(slide_path)

    # mpp search (robust):
    props = slide.properties
    mpp = None
    for key in ("openslide.mpp-x", "aperio.MPP", "tiff.XResolution" ):
        if key in props:
            try:
                mpp = float(props[key])
                break
            except Exception:
                pass
    if mpp is None:
        # 10x fallback ≈ 0.5 µm/px
        mpp = CFG.target_mpp
    return slide, mpp, slide_path


def build_transforms():
    import torchvision.transforms as T
    return T.Compose([
        T.ToTensor(),
        T.Normalize(mean=(0.485,0.456,0.406), std=(0.229,0.224,0.225)),
    ])


def dump_cfg(path: str):
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "w") as f:
        yaml.safe_dump({k: v for k, v in vars(CFG).items() if not k.startswith("__")}, f, sort_keys=False)

In [14]:
# -----------------------------------------------------------------------------
# Training
# -----------------------------------------------------------------------------
@torch.inference_mode()
def batch_acc(logits):
    target = torch.arange(logits.size(0), device=logits.device)
    pred_i = logits.max(dim=1).indices
    pred_t = logits.max(dim=0).indices
    acc_i = (pred_i == target).float().mean()
    acc_t = (pred_t == target).float().mean()
    return acc_i.item(), acc_t.item()


def train_clipgo(model: CLIPGO, loader: DataLoader, val_loader: DataLoader = None,
                 grad_clip: float = 1.0, amp: bool = True):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    opt = torch.optim.AdamW(model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay)
    scaler = torch.cuda.amp.GradScaler(enabled=amp)

    os.makedirs(CFG.ckpt_dir, exist_ok=True)
    log_csv = os.path.join(CFG.ckpt_dir, "train_log.csv")
    if not os.path.isfile(log_csv):
        with open(log_csv, "w", newline="") as f:
            csv.writer(f).writerow(["epoch","step","loss","acc_i","acc_t","lr","logit_scale","imgs_per_sec"])

    best = math.inf

    for epoch in range(1, CFG.epochs + 1):
        model.train()
        running, ema_loss = 0.0, None
        epoch_start = time.time()

        # Throughput helpers
        seen = 0
        step_start = time.time()

        pbar = tqdm(loader, desc=f"Epoch {epoch}/{CFG.epochs}", dynamic_ncols=True)
        for step, batch in enumerate(pbar, 1):
            imgs  = batch["image"].to(device, non_blocking=True)
            texts = batch["text"]
            B = imgs.size(0)

            opt.zero_grad(set_to_none=True)
            autocast_ctx = torch.cuda.amp.autocast(enabled=amp) if device.type == "cuda" else nullcontext()

            try:
                with autocast_ctx:
                    loss, logits_it, logits_ti = model(imgs, texts)
                scaler.scale(loss).backward()
                if grad_clip is not None:
                    nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
                scaler.step(opt)
                scaler.update()
            except torch.cuda.OutOfMemoryError:
                if device.type == "cuda":
                    torch.cuda.empty_cache()
                print("[WARN] CUDA OOM: skipping batch")
                continue

            running += loss.item()
            ema_loss = loss.item() if ema_loss is None else 0.9*ema_loss + 0.1*loss.item()
            acc_i, acc_t = batch_acc(logits_it)

            # throughput
            seen += B
            dt = max(time.time() - step_start, 1e-6)
            ips = B / dt
            step_start = time.time()

            # current lr (first group)
            lr = opt.param_groups[0]["lr"]
            logit_scale = float(model.logit_scale.exp().detach().cpu())

            pbar.set_postfix(loss=f"{ema_loss:.4f}", acc_i=f"{acc_i:.3f}",
                             acc_t=f"{acc_t:.3f}", lr=f"{lr:.1e}",
                             τ=f"{1.0/logit_scale:.3f}", ips=f"{ips:.0f}")

            # CSV log (every ~50 steps)
            if step % 50 == 0 or step == len(loader):
                with open(log_csv, "a", newline="") as f:
                    csv.writer(f).writerow([epoch, step, loss.item(), acc_i, acc_t, lr, logit_scale, ips])

        # ----- epoch end -----
        avg = running / max(len(loader), 1)
        epoch_time = time.time() - epoch_start
        print(f"Epoch {epoch:03d} | loss {avg:.4f} | time {epoch_time/60:.1f} min | "
              f"logit_scale {logit_scale:.3f} (τ≈{1.0/logit_scale:.3f})")

        # checkpoints
        torch.save({"epoch": epoch, "model": model.state_dict()},
                   os.path.join(CFG.ckpt_dir, f"epoch_{epoch:03d}.pth"))
        torch.save({"epoch": epoch, "model": model.state_dict()},
                   os.path.join(CFG.ckpt_dir, "last.pth"))
        if avg < best:
            best = avg
            torch.save({"epoch": epoch, "model": model.state_dict()},
                       os.path.join(CFG.ckpt_dir, "best.pth"))
            print("✓ new best")

    dump_cfg(os.path.join(CFG.ckpt_dir, "config.yaml"))
    print("Training complete. Best loss:", best)


In [None]:
# Resolve dataset-specific paths
sample  = CFG.xenium_sample_dict[CFG.cancer]
sample_dir = os.path.join(CFG.root, sample)

# Load AnnData (expects x_centroid / y_centroid in .obs)
adata_path = os.path.join(
    sample_dir,
    "preprocessed",
    f"fine_tune_{CFG.ground_truth}_v2",
    f"processed_xenium_data_fine_tune_{CFG.ground_truth}_v2_annotated.h5ad",
)
adata = sc.read_h5ad(adata_path)
cell_df = adata.obs.copy()  # index = cell_id

# Sentences: external file (index=cell_id, column=go_sentences)
sentences = pd.read_csv(f"{CFG.go_dir}/{sample.replace('outs', 'GO.csv')}", index_col="cell_id")
sentences = sentences["go_sentences"].astype(str)       # ensure it's a Series
sentences = sentences.reindex(cell_df.index)            # align to adata
missing = sentences.isna().sum()
if missing > 0:
    print(f"[WARN] {missing} cells missing sentences; filling with empty strings.")
    sentences = sentences.fillna("")
    
# Slide + scale
slide, mpp, slide_path = get_slide_and_mpp(sample_dir)
scale_factor = max(CFG.target_mpp / float(mpp), 1e-6)

# Build vision backbone (UNI2)
uni2_cfg = {
    'model_name':'vit_giant_patch14_224','img_size':224,'patch_size':14,'depth':24,
    'num_heads':24,'init_values':1e-5,'embed_dim':1536,'mlp_ratio':2.66667*2,
    'num_classes':0,'no_embed_class':True,'mlp_layer':timm.layers.SwiGLUPacked,
    'act_layer':torch.nn.SiLU,'reg_tokens':8,'dynamic_img_size':True,
}

uni2 = timm.create_model(pretrained=False, **uni2_cfg)
# Load weights if you have them
uni2_weights = os.path.join(CFG.model_dir, "pytorch_model.bin")
uni2.load_state_dict(torch.load(os.path.join(uni2_weights), map_location="cpu"), strict=True)


centre_idx = [119, 120, 135, 136] if CFG.level == 0 else [
    102,103,104,105,118,119,120,121,134,135,136,137,150,151,152,153
]
vision_backbone = UNI2Wrapper(uni2, centre_idx)

# Dataset / DataLoader
transform = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=(0.485,0.456,0.406), std=(0.229,0.224,0.225)),
])
# DataLoader – consider drop_last=True for strict CLIP pairing
dataset = CellPatchTextDataset(slide, cell_df, sentences, transform,
                               scale=scale_factor, patch_size=CFG.patch_size)
loader  = DataLoader(dataset,
                     batch_size=CFG.batch_size,
                     shuffle=True,
                     num_workers=CFG.num_workers,
                     pin_memory=True,
                     persistent_workers=True,
                     prefetch_factor=4,
                     drop_last=True)   # <- recommended for contrastive matching

# Model + train
model = CLIPGO(vision_backbone)
os.makedirs(CFG.ckpt_dir, exist_ok=True)
train_clipgo(model, loader)

Epoch 1/50: 100%|██████████| 3398/3398 [22:24<00:00,  2.53it/s, acc_i=0.153, acc_t=0.097, ips=220, loss=3.4355, lr=1.0e-04, τ=0.068]


Epoch 001 | loss 3.5743 | time 22.4 min | logit_scale 14.633 (τ≈0.068)
✓ new best


Epoch 2/50: 100%|██████████| 3398/3398 [21:31<00:00,  2.63it/s, acc_i=0.194, acc_t=0.222, ips=224, loss=3.3200, lr=1.0e-04, τ=0.066]


Epoch 002 | loss 3.3916 | time 21.5 min | logit_scale 15.090 (τ≈0.066)
✓ new best


Epoch 3/50: 100%|██████████| 3398/3398 [22:31<00:00,  2.51it/s, acc_i=0.125, acc_t=0.153, ips=226, loss=3.3351, lr=1.0e-04, τ=0.064] 


Epoch 003 | loss 3.2885 | time 22.5 min | logit_scale 15.626 (τ≈0.064)
✓ new best


Epoch 4/50: 100%|██████████| 3398/3398 [21:37<00:00,  2.62it/s, acc_i=0.153, acc_t=0.208, ips=226, loss=3.1627, lr=1.0e-04, τ=0.061] 


Epoch 004 | loss 3.1783 | time 21.6 min | logit_scale 16.336 (τ≈0.061)
✓ new best


Epoch 5/50: 100%|██████████| 3398/3398 [22:02<00:00,  2.57it/s, acc_i=0.222, acc_t=0.236, ips=225, loss=3.1382, lr=1.0e-04, τ=0.057]


Epoch 005 | loss 3.0321 | time 22.0 min | logit_scale 17.460 (τ≈0.057)
✓ new best


Epoch 6/50: 100%|██████████| 3398/3398 [21:35<00:00,  2.62it/s, acc_i=0.236, acc_t=0.292, ips=227, loss=2.8000, lr=1.0e-04, τ=0.052]


Epoch 006 | loss 2.7805 | time 21.6 min | logit_scale 19.354 (τ≈0.052)
✓ new best


Epoch 7/50: 100%|██████████| 3398/3398 [22:11<00:00,  2.55it/s, acc_i=0.292, acc_t=0.306, ips=229, loss=2.3808, lr=1.0e-04, τ=0.043]


Epoch 007 | loss 2.3386 | time 22.2 min | logit_scale 23.008 (τ≈0.043)
✓ new best


Epoch 8/50: 100%|██████████| 3398/3398 [21:25<00:00,  2.64it/s, acc_i=0.403, acc_t=0.389, ips=227, loss=2.0970, lr=1.0e-04, τ=0.036]


Epoch 008 | loss 1.8095 | time 21.4 min | logit_scale 27.603 (τ≈0.036)
✓ new best


Epoch 9/50: 100%|██████████| 3398/3398 [22:01<00:00,  2.57it/s, acc_i=0.569, acc_t=0.611, ips=228, loss=1.2340, lr=1.0e-04, τ=0.029]


Epoch 009 | loss 0.9870 | time 22.0 min | logit_scale 34.260 (τ≈0.029)
✓ new best


Epoch 10/50: 100%|██████████| 3398/3398 [21:29<00:00,  2.64it/s, acc_i=0.792, acc_t=0.778, ips=228, loss=0.6872, lr=1.0e-04, τ=0.024]


Epoch 010 | loss 0.5523 | time 21.5 min | logit_scale 41.662 (τ≈0.024)
✓ new best


Epoch 11/50: 100%|██████████| 3398/3398 [21:46<00:00,  2.60it/s, acc_i=0.889, acc_t=0.931, ips=228, loss=0.2894, lr=1.0e-04, τ=0.020]


Epoch 011 | loss 0.3049 | time 21.8 min | logit_scale 49.302 (τ≈0.020)
✓ new best


Epoch 12/50: 100%|██████████| 3398/3398 [21:24<00:00,  2.65it/s, acc_i=0.931, acc_t=0.931, ips=224, loss=0.2590, lr=1.0e-04, τ=0.019]


Epoch 012 | loss 0.2946 | time 21.4 min | logit_scale 52.688 (τ≈0.019)
✓ new best


Epoch 13/50: 100%|██████████| 3398/3398 [22:16<00:00,  2.54it/s, acc_i=0.861, acc_t=0.889, ips=228, loss=0.2686, lr=1.0e-04, τ=0.017]


Epoch 013 | loss 0.1750 | time 22.3 min | logit_scale 57.727 (τ≈0.017)
✓ new best


Epoch 14/50: 100%|██████████| 3398/3398 [22:12<00:00,  2.55it/s, acc_i=0.931, acc_t=0.931, ips=228, loss=0.1612, lr=1.0e-04, τ=0.017] 


Epoch 014 | loss 0.2281 | time 22.2 min | logit_scale 59.263 (τ≈0.017)


Epoch 15/50: 100%|██████████| 3398/3398 [21:58<00:00,  2.58it/s, acc_i=0.958, acc_t=0.931, ips=228, loss=0.1314, lr=1.0e-04, τ=0.016]


Epoch 015 | loss 0.1266 | time 22.0 min | logit_scale 63.052 (τ≈0.016)
✓ new best


Epoch 16/50: 100%|██████████| 3398/3398 [21:37<00:00,  2.62it/s, acc_i=0.917, acc_t=0.931, ips=228, loss=0.1242, lr=1.0e-04, τ=0.015]


Epoch 016 | loss 0.1137 | time 21.6 min | logit_scale 65.670 (τ≈0.015)
✓ new best


Epoch 17/50: 100%|██████████| 3398/3398 [21:39<00:00,  2.62it/s, acc_i=0.986, acc_t=0.944, ips=226, loss=0.1138, lr=1.0e-04, τ=0.015]


Epoch 017 | loss 0.1029 | time 21.7 min | logit_scale 67.658 (τ≈0.015)
✓ new best


Epoch 18/50: 100%|██████████| 3398/3398 [22:07<00:00,  2.56it/s, acc_i=0.958, acc_t=0.958, ips=227, loss=0.1070, lr=1.0e-04, τ=0.014]


Epoch 018 | loss 0.0947 | time 22.1 min | logit_scale 69.055 (τ≈0.014)
✓ new best


Epoch 19/50: 100%|██████████| 3398/3398 [21:26<00:00,  2.64it/s, acc_i=0.958, acc_t=0.944, ips=227, loss=0.1086, lr=1.0e-04, τ=0.014]


Epoch 019 | loss 0.0899 | time 21.4 min | logit_scale 70.037 (τ≈0.014)
✓ new best


Epoch 20/50: 100%|██████████| 3398/3398 [21:26<00:00,  2.64it/s, acc_i=1.000, acc_t=0.986, ips=227, loss=0.0870, lr=1.0e-04, τ=0.014]


Epoch 020 | loss 0.0748 | time 21.4 min | logit_scale 71.829 (τ≈0.014)
✓ new best


Epoch 21/50: 100%|██████████| 3398/3398 [21:18<00:00,  2.66it/s, acc_i=0.972, acc_t=0.958, ips=229, loss=0.0897, lr=1.0e-04, τ=0.014]


Epoch 021 | loss 0.0797 | time 21.3 min | logit_scale 72.083 (τ≈0.014)


Epoch 22/50: 100%|██████████| 3398/3398 [21:07<00:00,  2.68it/s, acc_i=0.944, acc_t=0.931, ips=226, loss=0.0926, lr=1.0e-04, τ=0.014]


Epoch 022 | loss 0.0742 | time 21.1 min | logit_scale 72.651 (τ≈0.014)
✓ new best


Epoch 23/50: 100%|██████████| 3398/3398 [18:33<00:00,  3.05it/s, acc_i=0.972, acc_t=1.000, ips=228, loss=0.0672, lr=1.0e-04, τ=0.014]


Epoch 023 | loss 0.0708 | time 18.6 min | logit_scale 73.367 (τ≈0.014)
✓ new best


Epoch 24/50: 100%|██████████| 3398/3398 [18:26<00:00,  3.07it/s, acc_i=0.986, acc_t=0.986, ips=228, loss=0.0539, lr=1.0e-04, τ=0.014]


Epoch 024 | loss 0.0681 | time 18.4 min | logit_scale 73.637 (τ≈0.014)
✓ new best


Epoch 25/50: 100%|██████████| 3398/3398 [18:18<00:00,  3.09it/s, acc_i=1.000, acc_t=0.972, ips=230, loss=0.0579, lr=1.0e-04, τ=0.013]


Epoch 025 | loss 0.0628 | time 18.3 min | logit_scale 74.123 (τ≈0.013)
✓ new best


Epoch 26/50: 100%|██████████| 3398/3398 [18:23<00:00,  3.08it/s, acc_i=0.972, acc_t=0.903, ips=227, loss=0.0822, lr=1.0e-04, τ=0.013]


Epoch 026 | loss 0.0641 | time 18.4 min | logit_scale 74.279 (τ≈0.013)


Epoch 27/50: 100%|██████████| 3398/3398 [18:31<00:00,  3.06it/s, acc_i=0.972, acc_t=0.958, ips=229, loss=0.0550, lr=1.0e-04, τ=0.013]


Epoch 027 | loss 0.0615 | time 18.5 min | logit_scale 75.060 (τ≈0.013)
✓ new best


Epoch 28/50: 100%|██████████| 3398/3398 [18:44<00:00,  3.02it/s, acc_i=0.986, acc_t=0.972, ips=226, loss=0.0739, lr=1.0e-04, τ=0.013]


Epoch 028 | loss 0.0598 | time 18.7 min | logit_scale 75.427 (τ≈0.013)
✓ new best


Epoch 29/50: 100%|██████████| 3398/3398 [18:51<00:00,  3.00it/s, acc_i=0.986, acc_t=0.986, ips=227, loss=0.0531, lr=1.0e-04, τ=0.013]


Epoch 029 | loss 0.0569 | time 18.9 min | logit_scale 75.823 (τ≈0.013)
✓ new best


Epoch 30/50: 100%|██████████| 3398/3398 [18:38<00:00,  3.04it/s, acc_i=0.986, acc_t=0.986, ips=228, loss=0.0558, lr=1.0e-04, τ=0.013]


Epoch 030 | loss 0.0534 | time 18.6 min | logit_scale 77.017 (τ≈0.013)
✓ new best


Epoch 31/50: 100%|██████████| 3398/3398 [19:07<00:00,  2.96it/s, acc_i=0.986, acc_t=0.986, ips=227, loss=0.0513, lr=1.0e-04, τ=0.013]


Epoch 031 | loss 0.0507 | time 19.1 min | logit_scale 77.350 (τ≈0.013)
✓ new best


Epoch 32/50: 100%|██████████| 3398/3398 [19:42<00:00,  2.87it/s, acc_i=0.972, acc_t=0.958, ips=229, loss=0.0530, lr=1.0e-04, τ=0.013]


Epoch 032 | loss 0.0486 | time 19.7 min | logit_scale 77.923 (τ≈0.013)
✓ new best


Epoch 33/50:  11%|█         | 380/3398 [02:13<16:45,  3.00it/s, acc_i=1.000, acc_t=0.986, ips=227, loss=0.0346, lr=1.0e-04, τ=0.013]