# Humpback Whale Identification — Plan

Goal: Win a medal (MAP@5). Build a strong, fast baseline and iterate.

Plan v0:
- Environment & GPU: Verify GPU access; install torch cu121 stack once.
- Data audit: Inspect train.csv, image counts, class distribution (long-tail expected).
- CV: Stratified KFold on labels (Id). Fix folds to reuse across runs.
- Baseline model: timm pretrained classifier (e.g., tf_efficientnet_b3/b4 or convnext_tiny), CE with label smoothing, mixup/cutmix off initially.
- Augmentations: Resize ~384, horizontal flip, light color/geo. Keep simple for baseline.
- Training: 5 folds, early stopping by val MAP@5 proxy (Top-5 accuracy).
- Inference: TTA light (hflip), average logits across TTA and folds.
- Submission: Top-5 labels per image.

Next:
1) Setup GPU and PyTorch
2) Explore data and class distribution
3) Implement CV split and minimal training loop
4) Baseline train (1 seed), evaluate OOF top-5, generate submission
5) Iterate: better backbone, arcface head if time, TTA, ensembling

Checkpoints: request expert review after plan, after EDA, after baseline OOF.

In [None]:
# Environment and GPU check
import os, sys, subprocess, shutil, json, time, platform
from pathlib import Path

def run(cmd):
    return subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True).stdout

print('Python:', sys.version)
print('Platform:', platform.platform())
print('CWD:', os.getcwd())
print('nvidia-smi:')
print(run(['bash','-lc','nvidia-smi || true']))

print('List top-level files:')
for p in Path('.').iterdir():
    try:
        mtime = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(p.stat().st_mtime))
        print(mtime, p)
    except Exception as e:
        print(p, e)

train_dir = Path('train')
test_dir = Path('test')
print('Train images:', len(list(train_dir.glob('*'))))
print('Test images:', len(list(test_dir.glob('*'))))

print('Preview CSVs:')
for f in ['train.csv','sample_submission.csv']:
    if Path(f).exists():
        print(f'--- {f} head ---')
        print('\n'.join(Path(f).read_text().splitlines()[:5]))

In [None]:
# Install CUDA 12.1 torch stack and deps
import os, sys, subprocess, shutil
from pathlib import Path

def pip(*args):
    print('>', *args, flush=True)
    subprocess.run([sys.executable, '-m', 'pip', *args], check=True)

# Uninstall any preexisting torch stack
for pkg in ('torch','torchvision','torchaudio'):
    subprocess.run([sys.executable, '-m', 'pip', 'uninstall', '-y', pkg], check=False)

# Clean stray site dirs (idempotent)
for d in (
    '/app/.pip-target/torch',
    '/app/.pip-target/torchvision',
    '/app/.pip-target/torchaudio',
    '/app/.pip-target/torch-2.8.0.dist-info',
    '/app/.pip-target/torch-2.4.1.dist-info',
    '/app/.pip-target/torchvision-0.23.0.dist-info',
    '/app/.pip-target/torchvision-0.19.1.dist-info',
    '/app/.pip-target/torchaudio-2.8.0.dist-info',
    '/app/.pip-target/torchaudio-2.4.1.dist-info',
):
    if os.path.exists(d):
        print('Removing', d)
        shutil.rmtree(d, ignore_errors=True)

# Install exact cu121 torch stack
pip('install',
    '--index-url', 'https://download.pytorch.org/whl/cu121',
    '--extra-index-url', 'https://pypi.org/simple',
    'torch==2.4.1', 'torchvision==0.19.1', 'torchaudio==2.4.1')

# Freeze constraints
Path('constraints.txt').write_text('torch==2.4.1\ntorchvision==0.19.1\ntorchaudio==2.4.1\n')

# Non-torch deps (avoid upgrading torch)
pip('install', '-c', 'constraints.txt',
    'timm==1.0.9',
    'albumentations==1.4.8',
    'scikit-learn==1.5.2',
    'pandas', 'numpy',
    'opencv-python-headless==4.10.0.84',
    'faiss-cpu==1.8.0.post1',
    'matplotlib',
    '--upgrade-strategy', 'only-if-needed')

# Sanity check
import torch
print('torch:', torch.__version__, 'built CUDA:', getattr(torch.version, 'cuda', None))
print('CUDA available:', torch.cuda.is_available())
if torch.cuda.is_available():
    print('GPU:', torch.cuda.get_device_name(0))
assert str(getattr(torch.version,'cuda','')).startswith('12.1'), f'Wrong CUDA build: {torch.version.cuda}'

In [None]:
# EDA and CV fold creation
import pandas as pd, numpy as np
from pathlib import Path
from sklearn.model_selection import KFold

df = pd.read_csv('train.csv')
df['image_path'] = df['Image'].apply(lambda x: str(Path('train')/x))

# Basic stats
n_images = len(df)
n_ids = df['Id'].nunique()
vc = df['Id'].value_counts()
singletons = (vc==1).sum()
print(f'Train rows: {n_images}, unique Ids: {n_ids}, singletons: {singletons} ({singletons/n_ids:.1%} of classes)')
print('Top 5 classes by count:\n', vc.head().to_string())
print('Bottom 5 classes by count:\n', vc.tail().to_string())

# Create 5-fold KFold (stratified by exact Id is impossible due to many singletons).
# Retrieval CV tolerates plain KFold well; ensure shuffle and fixed seed.
n_splits = 5
kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)
folds = np.full(n_images, -1, dtype=int)
for i, (_, val_idx) in enumerate(kf.split(df)):
    folds[val_idx] = i
df['fold'] = folds
assert (df['fold']>=0).all()
df[['Image','Id','fold']].to_csv('folds.csv', index=False)
print('Saved folds.csv with shape:', df[['Image','Id','fold']].shape)

# Quick sanity: distribution by fold
by_fold = df.groupby('fold')['Id'].nunique().rename('unique_ids')
rows_by_fold = df['fold'].value_counts().sort_index().rename('rows')
print('Rows by fold:\n', rows_by_fold.to_string())
print('Unique Ids by fold:\n', by_fold.to_string())

# Preview
print(df.head().to_string(index=False))

In [None]:
# Zero-train retrieval baseline with timm convnext_tiny and FAISS (CPU) + tau tuning
import os, time, math, gc, faiss, numpy as np, pandas as pd, torch, timm
from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
IM_SIZE = 384
BATCH_SIZE = 64
NUM_WORKERS = min(8, os.cpu_count() or 4)
K_RETR = 50  # neighbors to retrieve
ALPHA = 15.0  # vote sharpness

class ImageDS(Dataset):
    def __init__(self, df, root='.', tta_hflip=False):
        self.paths = df['image_path'].tolist() if 'image_path' in df.columns else [str(Path(root)/p) for p in df]
        self.tta_hflip = tta_hflip
        self.transform = T.Compose([
            T.Resize((IM_SIZE, IM_SIZE), interpolation=T.InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
        ])
        self.hflip = T.RandomHorizontalFlip(p=1.0)
    def __len__(self): return len(self.paths)
    def __getitem__(self, i):
        p = self.paths[i]
        img = Image.open(p).convert('RGB')
        x = self.transform(img)
        if self.tta_hflip:
            x2 = self.transform(self.hflip(img))
            return x, x2, p
        return x, p

def get_backbone():
    # num_classes=0 returns feature extractor with global pooling
    model = timm.create_model('convnext_tiny', pretrained=True, num_classes=0, global_pool='avg')
    model.eval().to(device)
    return model

@torch.no_grad()
def extract_embeddings(df_or_paths, tta_hflip=True):
    if isinstance(df_or_paths, pd.DataFrame):
        ds = ImageDS(df_or_paths, tta_hflip=tta_hflip)
    else:
        ds = ImageDS(pd.DataFrame({'image_path': df_or_paths}), tta_hflip=tta_hflip)
    dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
    model = get_backbone()
    embs, paths = [], []
    t0 = time.time()
    for bi, batch in enumerate(dl):
        if tta_hflip:
            x, x2, p = batch
            x, x2 = x.to(device, non_blocking=True), x2.to(device, non_blocking=True)
            e1 = model(x)
            e2 = model(x2)
            e = (e1 + e2) / 2.0
        else:
            x, p = batch
            x = x.to(device, non_blocking=True)
            e = model(x)
        e = torch.nn.functional.normalize(e, dim=1).cpu().numpy()
        embs.append(e)
        paths += list(p)
        if (bi+1)%20==0:
            print(f'Emb batches {bi+1}, elapsed {time.time()-t0:.1f}s', flush=True)
    embs = np.concatenate(embs, axis=0) if len(embs)>0 else np.zeros((0, model.num_features), dtype=np.float32)
    return embs.astype('float32'), paths

def build_index(embs):
    # Cosine similarity via inner product on normalized vectors
    index = faiss.IndexFlatIP(embs.shape[1])
    index.add(embs)
    return index

def knn_search(index, query_embs, k):
    sims, idx = index.search(query_embs, k)
    return sims, idx

def map5_score(preds, truths):
    # preds: list of list of 5 labels; truths: list of true label
    s = 0.0
    for p, t in zip(preds, truths):
        try:
            r = p.index(t) + 1
            s += 1.0 / r
        except ValueError:
            s += 0.0
    return s / len(truths)

def rank_labels(nei_ids, nei_sims, tau=None):
    # nei_ids: [K] labels, nei_sims: [K] similarities
    scores = {}
    for lab, sim in zip(nei_ids, nei_sims):
        scores[lab] = scores.get(lab, 0.0) + math.exp(ALPHA * float(sim))
    ranked = sorted(scores.items(), key=lambda x: -x[1])
    ordered = [lab for lab,_ in ranked]
    # ensure unique labels in top-5
    ordered = list(dict.fromkeys(ordered))
    # new_whale logic handled outside using tau and max sim
    return ordered

def oof_tau_tune(train_df, folds_df, all_train_embs, all_train_paths, taus):
    # Build path->row mapping
    path2id = dict(zip((Path('train')/train_df['Image']).astype(str), train_df['Id']))
    path2idx = {p:i for i,p in enumerate(all_train_paths)}
    best_tau, best_map5 = None, -1.0
    for tau in taus:
        preds_all, truths_all = [], []
        t0 = time.time()
        for f in sorted(folds_df['fold'].unique()):
            tr_mask = folds_df['fold'] != f
            va_mask = folds_df['fold'] == f
            va_imgs = (Path('train')/folds_df.loc[va_mask, 'Image']).astype(str).tolist()
            tr_imgs = (Path('train')/folds_df.loc[tr_mask, 'Image']).astype(str).tolist()
            tr_idx = np.array([path2idx[p] for p in tr_imgs], dtype=np.int64)
            va_idx = np.array([path2idx[p] for p in va_imgs], dtype=np.int64)
            gallery = all_train_embs[tr_idx]
            queries = all_train_embs[va_idx]
            index = build_index(gallery)
            sims, idxs = knn_search(index, queries, min(K_RETR, gallery.shape[0]))
            for i in range(len(va_imgs)):
                nei_idx = idxs[i]
                nei_sims = sims[i]
                labs = [path2id[tr_imgs[j]] for j in nei_idx]
                ordered = rank_labels(labs, nei_sims)
                top5 = []
                if len(nei_sims)>0 and float(nei_sims[0]) < tau:
                    top5.append('new_whale')
                for lab in ordered:
                    if lab not in top5:
                        top5.append(lab)
                    if len(top5)==5: break
                if len(top5)<5:
                    # pad with most common label excluding duplicates (rarely needed)
                    top5 += ['new_whale']*(5-len(top5))
                preds_all.append(top5)
            truths_all += folds_df.loc[va_mask, 'Id'].tolist()
        m = map5_score(preds_all, truths_all)
        print(f'tau {tau:.3f} OOF MAP@5={m:.5f} in {time.time()-t0:.1f}s', flush=True)
        if m > best_map5:
            best_map5, best_tau = m, tau
    print(f'Best tau {best_tau:.3f} OOF MAP@5={best_map5:.5f}')
    return best_tau, best_map5

# Pipeline: 1) extract all train embeddings once; 2) tau tune via 5-fold OOF; 3) extract test embeddings; 4) build full gallery and predict; 5) write submission
t_start = time.time()
train_df = pd.read_csv('train.csv')
train_df['image_path'] = (Path('train')/train_df['Image']).astype(str)
folds = pd.read_csv('folds.csv')

print('Extracting train embeddings...')
train_embs, train_paths = extract_embeddings(train_df, tta_hflip=True)
print('Train embeddings shape:', train_embs.shape)
gc.collect();

taus = np.linspace(0.35, 0.65, 13)
best_tau, best_map5 = oof_tau_tune(train_df, folds, train_embs, train_paths, taus)

print('Extracting test embeddings...')
ss = pd.read_csv('sample_submission.csv')
test_df = pd.DataFrame({'Image': ss['Image']})
test_df['image_path'] = (Path('test')/test_df['Image']).astype(str)
test_embs, test_paths = extract_embeddings(test_df, tta_hflip=True)
print('Test embeddings shape:', test_embs.shape)

# Build full gallery on all train
index_full = build_index(train_embs)
path2id_full = dict(zip(train_paths, train_df['Id']))

print('Retrieving for test...')
sims, idxs = knn_search(index_full, test_embs, min(K_RETR, train_embs.shape[0]))
pred_rows = []
for i in range(len(test_paths)):
    nei_idx = idxs[i]
    nei_sims = sims[i]
    labs = [path2id_full[train_paths[j]] for j in nei_idx]
    ordered = rank_labels(labs, nei_sims)
    top5 = []
    if len(nei_sims)>0 and float(nei_sims[0]) < best_tau:
        top5.append('new_whale')
    for lab in ordered:
        if lab not in top5:
            top5.append(lab)
        if len(top5)==5: break
    if len(top5)<5:
        top5 += ['new_whale']*(5-len(top5))
    pred_rows.append(' '.join(top5[:5]))

sub = pd.DataFrame({'Image': ss['Image'], 'Id': pred_rows})
sub.to_csv('submission.csv', index=False)
print('Wrote submission.csv with shape', sub.shape, 'Elapsed', f'{time.time()-t_start:.1f}s')

# Show head
print(sub.head().to_string(index=False))

In [None]:
# Rebuild folds with perceptual duplicate grouping (aHash) + GroupKFold
import numpy as np, pandas as pd, os, time
from pathlib import Path
from PIL import Image
from sklearn.model_selection import GroupKFold

def ahash_image(path, size=8):
    try:
        img = Image.open(path).convert('L').resize((size, size), Image.BILINEAR)
        arr = np.asarray(img, dtype=np.float32)
        m = arr.mean()
        bits = (arr > m).astype(np.uint8)
        # pack to 64-bit integer
        val = 0
        for b in bits.flatten():
            val = (val << 1) | int(b)
        return val
    except Exception as e:
        return None

t0 = time.time()
df = pd.read_csv('train.csv')
df['image_path'] = (Path('train')/df['Image']).astype(str)
hashes = []
for i, p in enumerate(df['image_path'].tolist()):
    h = ahash_image(p)
    hashes.append(h)
    if (i+1)%1000==0:
        print(f'Hashed {i+1}/{len(df)} images...')
df['ahash'] = hashes

# Group by exact hash (fast). This catches exact/near-duplicates under aHash;
# we avoid O(N^2) hamming search for now due to time. Can refine later if needed.
df['dup_group'] = pd.factorize(df['ahash'].fillna(-1))[0]
print('Unique dup groups:', df['dup_group'].nunique())

# Build GroupKFold on dup_group; ensure balanced rows across folds
gkf = GroupKFold(n_splits=5)
folds = np.full(len(df), -1, dtype=int)
for k, (_, va_idx) in enumerate(gkf.split(df, groups=df['dup_group'])):
    folds[va_idx] = k
df['fold'] = folds
assert (df['fold']>=0).all()
df[['Image','Id','fold','dup_group','ahash']].to_csv('folds_grouped.csv', index=False)
print('Saved folds_grouped.csv:', df.shape, 'elapsed', f'{time.time()-t0:.1f}s')
print(df.head().to_string(index=False))

In [None]:
# ArcFace training: ConvNeXt-Tiny -> GeM -> BNNeck(512) -> ArcFace; PK sampler; 5-fold full training with EMA
import os, math, time, gc, random, numpy as np, pandas as pd, faiss, torch, timm
import torchvision.transforms as T
from pathlib import Path
from PIL import Image, ImageOps
from sklearn.model_selection import GroupKFold
from torch import nn
from torch.utils.data import Dataset, DataLoader, Sampler
from torch.cuda.amp import autocast, GradScaler
from timm.utils import ModelEmaV2

# Reduce fragmentation risk
os.environ.setdefault('PYTORCH_CUDA_ALLOC_CONF', 'expandable_segments:True')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
IMG_SIZE = 384
P_CLASSES = 24
K_IMGS = 2
# Use accumulation to keep VRAM safe while achieving effective batch 48
BATCH_SIZE = P_CLASSES * K_IMGS // 2  # physical batch 24
ACCUM_STEPS = 2  # effective batch = BATCH_SIZE * ACCUM_STEPS
EPOCHS = 13
NUM_WORKERS = min(8, os.cpu_count() or 4)
HEAD_LR = 3e-3
BB_LR = 2e-4
WD = 0.05
SCALE_S = 32.0
MARGIN_M = 0.30
WARMUP_EPOCHS = 1.0
K_RETR = 100
ALPHA = 20.0
SEED = 42
USE_BNNECK_FOR_RETR = False  # epochs 0-1 use GeM; switch to BNNeck from epoch>=2
torch.backends.cudnn.benchmark = True
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED);
torch.set_float32_matmul_precision('high')

def get_transforms():
    train_tf = T.Compose([
        T.RandomResizedCrop(IMG_SIZE, scale=(0.6, 1.0), interpolation=T.InterpolationMode.BICUBIC),
        T.RandomHorizontalFlip(p=0.5),
        T.RandomGrayscale(p=0.1),
        T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1, hue=0.05),
        T.RandomAffine(degrees=10, translate=(0.08,0.08), scale=(0.9,1.1)),
        T.ToTensor(),
        T.Normalize(mean=(0.485,0.456,0.406), std=(0.229,0.224,0.225)),
        T.RandomErasing(p=0.22, scale=(0.02, 0.1), ratio=(0.3, 3.3), value=0.0),
    ])
    val_tf = T.Compose([
        T.Resize((IMG_SIZE, IMG_SIZE), interpolation=T.InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=(0.485,0.456,0.406), std=(0.229,0.224,0.225)),
    ])
    return train_tf, val_tf

class WhalesDS(Dataset):
    def __init__(self, df, label2idx=None, mode='train', tta_hflip=False):
        self.df = df.reset_index(drop=True)
        self.mode = mode
        self.tta_hflip = tta_hflip
        self.train_tf, self.val_tf = get_transforms()
        self.tf = self.train_tf if mode=='train' else self.val_tf
        self.label2idx = label2idx
    def __len__(self): return len(self.df)
    def __getitem__(self, i):
        row = self.df.iloc[i]
        img = Image.open(str(row.image_path)).convert('RGB')
        x = self.tf(img)
        if self.mode != 'train' and self.tta_hflip:
            x2 = self.val_tf(ImageOps.mirror(img))
        if self.mode=='train':
            y = self.label2idx[row.Id]
            return x, y
        else:
            if self.tta_hflip:
                return x, x2, row.image_path, row.Id
            return x, row.image_path, row.Id

class PKSampler(Sampler):
    def __init__(self, df, label2idx, p=P_CLASSES, k=K_IMGS):
        self.df = df.reset_index(drop=True)
        self.p, self.k = p, k
        self.label2idx = label2idx
        self.cls2idxs = {}
        for i, lab in enumerate(self.df['Id']):
            self.cls2idxs.setdefault(lab, []).append(i)
        self.multi_classes = [c for c, idxs in self.cls2idxs.items() if len(idxs) >= 2]
        self.single_classes = [c for c, idxs in self.cls2idxs.items() if len(idxs) == 1]
        self.all_classes = list(self.cls2idxs.keys())
        self.n_batches = math.ceil(len(self.df) / (p*k))
        self.rng = random.Random(SEED)
    def __len__(self): return self.n_batches * self.p * self.k
    def __iter__(self):
        rng = self.rng
        for _ in range(self.n_batches):
            # 60-70% multi-instance, rest fill (may include singles) with no duplicates
            chosen = set()
            p_multi = min(len(self.multi_classes), max(0, int(self.p * 0.65)))
            if p_multi > 0:
                chosen.update(rng.sample(self.multi_classes, p_multi))
            pool = [c for c in self.all_classes if c not in chosen]
            need = self.p - len(chosen)
            if need > 0:
                take = min(len(pool), need)
                if take > 0:
                    chosen.update(rng.sample(pool, take))
            while len(chosen) < self.p:
                pool = [c for c in self.all_classes if c not in chosen]
                if not pool: break
                chosen.add(rng.choice(pool))
            chosen = list(chosen)
            batch_idxs = []
            for c in chosen:
                idxs = self.cls2idxs[c]
                if len(idxs) >= self.k:
                    sel = rng.sample(idxs, self.k)
                else:
                    sel = [rng.choice(idxs) for _ in range(self.k)]
                batch_idxs.extend(sel)
            yield from batch_idxs

class GeM(nn.Module):
    def __init__(self, p=3.0, eps=1e-6):
        super().__init__(); self.p = nn.Parameter(torch.ones(1)*p); self.eps = eps
    def forward(self, x):
        x = x.clamp(min=self.eps).pow(self.p)
        x = torch.mean(x, dim=(-1,-2)).pow(1.0/self.p)
        return x

class ArcMarginProduct(nn.Module):
    def __init__(self, in_features, out_features, s=SCALE_S, m=MARGIN_M):
        super().__init__(); self.in_features=in_features; self.out_features=out_features
        self.s = s; self.m = m
        self.weight = nn.Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)
    def forward(self, x, labels, margin_override=None):
        m = self.m if margin_override is None else margin_override
        x = nn.functional.normalize(x, dim=1)
        W = nn.functional.normalize(self.weight, dim=1)
        cosine = nn.functional.linear(x, W)
        cosine = cosine.clamp(-1+1e-7, 1-1e-7)  # clamp for numerical safety
        sine = torch.sqrt((1.0 - cosine**2).clamp(0,1))
        cos_m = math.cos(m); sin_m = math.sin(m); th = math.cos(math.pi - m); mm = math.sin(math.pi - m) * m
        phi = cosine * cos_m - sine * sin_m
        phi = torch.where(cosine > th, phi, cosine - mm)
        one_hot = torch.zeros_like(cosine); one_hot.scatter_(1, labels.view(-1,1), 1.0)
        logits = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        logits *= self.s
        return logits

class Net(nn.Module):
    def __init__(self, n_classes):
        super().__init__()
        self.backbone = timm.create_model('convnext_tiny', pretrained=True, num_classes=0, global_pool='')
        in_ch = self.backbone.num_features
        self.gem = GeM(p=3.0)
        self.emb = nn.Linear(in_ch, 512, bias=False)
        self.bnneck = nn.BatchNorm1d(512, eps=1e-5, momentum=0.1)
        self.arc = ArcMarginProduct(512, n_classes, s=SCALE_S, m=MARGIN_M)
    def forward(self, x, labels=None, margin_override=None):
        feat = self.backbone.forward_features(x)
        feat = self.gem(feat)
        feat_512 = self.emb(feat)
        logits = self.arc(feat_512, labels, margin_override=margin_override) if labels is not None else None
        feat_bn = self.bnneck(feat_512)
        return logits, feat_512, feat_bn

def model_feats(model, x, use_bnneck: bool):
    with torch.no_grad():
        feat = model.backbone.forward_features(x)
        feat = model.gem(feat)
        if use_bnneck:
            f512 = model.emb(feat)
            fbn = model.bnneck(f512)
            return nn.functional.normalize(fbn, dim=1)
        else:
            return nn.functional.normalize(feat, dim=1)

def build_label_mapping(df):
    labs = sorted([x for x in df['Id'].unique().tolist() if x != 'new_whale'])
    return {l:i for i,l in enumerate(labs)}

@torch.no_grad()
def extract_feats(model, df, tta_hflip=True):
    ds = WhalesDS(df, mode='val', tta_hflip=tta_hflip)
    dl = DataLoader(ds, batch_size=64, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
    all_feats, paths, labels = [], [], []
    t0=time.time()
    for bi, batch in enumerate(dl):
        if tta_hflip:
            x, x2, p, y = batch
            x = x.to(device, non_blocking=True).to(memory_format=torch.channels_last)
            x2 = x2.to(device, non_blocking=True).to(memory_format=torch.channels_last)
            f1 = model_feats(model, x, USE_BNNECK_FOR_RETR)
            f2 = model_feats(model, x2, USE_BNNECK_FOR_RETR)
            f = (f1 + f2) / 2.0
        else:
            x, p, y = batch
            x = x.to(device, non_blocking=True).to(memory_format=torch.channels_last)
            f = model_feats(model, x, USE_BNNECK_FOR_RETR)
        all_feats.append(f.cpu().numpy()); paths += list(p); labels += list(y)
        if (bi+1)%20==0: print(f'FE bi {bi+1}, {time.time()-t0:.1f}s', flush=True)
    if len(all_feats)>0:
        feats = np.concatenate(all_feats, axis=0)
    else:
        dim = int(model.emb.out_features) if USE_BNNECK_FOR_RETR else int(model.backbone.num_features)
        feats = np.zeros((0, dim), dtype=np.float32)
    return feats.astype('float32'), paths, labels

def build_index(embs):
    index = faiss.IndexFlatIP(embs.shape[1])
    index.add(embs)
    return index

def vote_rank(nei_ids, nei_sims):
    scores = {}; 
    for lab, sim in zip(nei_ids, nei_sims):
        scores[lab] = scores.get(lab, 0.0) + math.exp(ALPHA * float(sim))
    ordered = [k for k,_ in sorted(scores.items(), key=lambda x:-x[1])]
    return ordered

def map5(preds, truths):
    s=0.0
    for p,t in zip(preds, truths):
        try: s+=1.0/(p.index(t)+1)
        except ValueError: s+=0.0
    return s/len(truths)

def oof_eval(model, df_gallery, df_val_fold, taus, ambiguity_gate=True):
    tr_feats, tr_paths, tr_ids = extract_feats(model, df_gallery, tta_hflip=True)
    va_feats, va_paths, va_ids = extract_feats(model, df_val_fold, tta_hflip=True)
    tr_ids_map = dict(zip(tr_paths, tr_ids))
    index = build_index(tr_feats)
    sims, idxs = index.search(va_feats, min(K_RETR, tr_feats.shape[0]))
    try:
        max_sims = sims[:,0] if sims.size>0 else np.array([], dtype=np.float32)
        q25, q50, q75 = (np.quantile(max_sims, 0.25), np.quantile(max_sims, 0.50), np.quantile(max_sims, 0.75)) if len(max_sims)>0 else (0,0,0)
        gal_labels = set(tr_ids)
        va_in_gal = sum(1 for v in va_ids if v in gal_labels)
        print(f'[OOF] top1 sim q25/q50/q75: {q25:.3f}/{q50:.3f}/{q75:.3f} | val covered in gallery: {va_in_gal}/{len(va_ids)} ({(va_in_gal/len(va_ids))*100:.1f}%)', flush=True)
    except Exception as e:
        print('[OOF] diag error:', e, flush=True)
    gal_labels_set = set(tr_ids)
    va_ids_eval = [vid if vid in gal_labels_set else 'new_whale' for vid in va_ids]
    best_tau, best_score = None, -1.0
    for tau in taus:
        preds=[]
        for i in range(len(va_paths)):
            nei_idx = idxs[i]; nei_sims = sims[i]
            labs = [tr_ids_map[tr_paths[j]] for j in nei_idx]
            ordered = vote_rank(labs, nei_sims)
            top5=[]
            s1 = float(nei_sims[0]) if len(nei_sims)>0 else -1.0
            s2 = float(nei_sims[1]) if len(nei_sims)>1 else s1
            cond_tau = (len(nei_sims)>0 and s1 < tau)
            cond_margin = (len(nei_sims)>1 and (s1 - s2) < 0.03)
            cond_ratio = (len(nei_sims)>1 and (s1 / max(s2, 1e-6)) < 1.06)
            if cond_tau or (ambiguity_gate and (cond_margin or cond_ratio)):
                top5.append('new_whale')
            for lab in ordered:
                if lab not in top5:
                    top5.append(lab)
                if len(top5)==5: break
            if len(top5)<5: top5 += ['new_whale']*(5-len(top5))
            preds.append(top5)
        sc = map5(preds, va_ids_eval)
        if sc>best_score: best_score, best_tau = sc, tau
    return best_score, best_tau

def train_fold(df_tr_train, df_gallery, df_va, fold_idx):
    global USE_BNNECK_FOR_RETR
    label2idx = build_label_mapping(df_tr_train)
    n_classes = len(label2idx)
    print(f'[Fold {fold_idx}] classes (excl new_whale):', n_classes, 'rows:', len(df_tr_train))
    ds_tr = WhalesDS(df_tr_train, label2idx=label2idx, mode='train')
    sampler = PKSampler(df_tr_train, label2idx, p=P_CLASSES, k=K_IMGS)
    dl_tr = DataLoader(ds_tr, batch_size=BATCH_SIZE, sampler=sampler, num_workers=NUM_WORKERS, pin_memory=True)
    model = Net(n_classes).to(device)
    model = model.to(memory_format=torch.channels_last)
    # EMA
    ema_model = ModelEmaV2(model, decay=0.999)
    try:
        if hasattr(model.backbone, 'set_grad_checkpointing'):
            model.backbone.set_grad_checkpointing(True)
            print('[Fold', fold_idx, '] Grad checkpointing: ON', flush=True)
    except Exception as e:
        print('[Fold', fold_idx, '] Grad checkpointing not set:', e, flush=True)
    bb_params = []; head_params = []
    for n,p in model.named_parameters():
        if any(k in n for k in ['\u200bemb','emb','bnneck','arc']): head_params.append(p)
        else: bb_params.append(p)
    optim = torch.optim.AdamW([{'params': bb_params, 'lr': BB_LR}, {'params': head_params, 'lr': HEAD_LR}], weight_decay=WD)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=max(1,EPOCHS-1))
    scaler = GradScaler(enabled=True)
    best_oof, best_tau = -1.0, 0.7
    best_path = f'model_fold{fold_idx}_best.pth'
    for epoch in range(EPOCHS):
        USE_BNNECK_FOR_RETR = (epoch >= 2)
        taus = np.arange(0.93, 0.99, 0.005) if not USE_BNNECK_FOR_RETR else np.arange(0.68, 0.84, 0.01)
        torch.cuda.empty_cache()
        if epoch == 0:
            for i, g in enumerate(optim.param_groups):
                if i == 0: g['lr'] = 0.0
                else: g['lr'] = HEAD_LR
        else:
            for i, g in enumerate(optim.param_groups):
                if i == 0: g['lr'] = BB_LR
                else: g['lr'] = g['lr']
        model.train()
        t0=time.time(); run_loss=0.0; n_batches=0
        steps_per_epoch = max(1, len(sampler)//BATCH_SIZE)
        optim.zero_grad(set_to_none=True)
        for bi, (x,y) in enumerate(dl_tr):
            x = x.to(device, non_blocking=True).to(memory_format=torch.channels_last); y = torch.as_tensor(y, dtype=torch.long, device=device)
            progress = min(1.0, (bi+1)/max(1, steps_per_epoch)/max(1e-6, WARMUP_EPOCHS))
            m_cur = MARGIN_M * progress
            with autocast(enabled=(device.type=='cuda')):
                logits, _, _ = model(x, labels=y, margin_override=m_cur)
                loss = nn.functional.cross_entropy(logits, y) / ACCUM_STEPS
            scaler.scale(loss).backward()
            if ((bi+1) % ACCUM_STEPS) == 0:
                scaler.unscale_(optim)
                nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                scaler.step(optim)
                scaler.update()
                optim.zero_grad(set_to_none=True)
                # EMA update
                ema_model.update(model)
            run_loss += loss.item()*ACCUM_STEPS; n_batches += 1
            if (bi+1) % 100 == 0:
                print(f'[Fold {fold_idx}] Ep{epoch+1} B{bi+1} loss {run_loss/n_batches:.4f} elapsed {time.time()-t0:.1f}s', flush=True)
        if ((bi+1) % ACCUM_STEPS) != 0:
            scaler.unscale_(optim)
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optim)
            scaler.update()
            optim.zero_grad(set_to_none=True)
            ema_model.update(model)
        print(f'[Fold {fold_idx}] Ep{epoch+1} train_loss {run_loss/max(1,n_batches):.4f} epoch_time {time.time()-t0:.1f}s | retr_feats={"BNNeck" if USE_BNNECK_FOR_RETR else "GeM"} tau_range=({taus[0]:.2f}-{taus[-1]:.2f})')
        if epoch >= 1:
            scheduler.step()
        # Evaluate with EMA weights
        model.eval()
        with torch.no_grad():
            oof_sc, oof_tau = oof_eval(ema_model.module, df_gallery[['Image','Id','image_path']], df_va[['Image','Id','image_path']], taus, ambiguity_gate=True)
        print(f'[Fold {fold_idx}] Ep{epoch+1} OOF MAP@5={oof_sc:.4f} tau={oof_tau:.3f} (feats={"BNNeck" if USE_BNNECK_FOR_RETR else "GeM"})')
        if oof_sc > best_oof:
            best_oof, best_tau = oof_sc, oof_tau
            torch.save({'model': ema_model.module.state_dict(), 'best_oof': best_oof, 'best_tau': best_tau, 'epoch': epoch+1}, best_path)
            print(f'[Fold {fold_idx}] Saved new best to {best_path}')
    return best_oof, best_tau, best_path

# Prepare grouped folds and dataframes
folds = pd.read_csv('folds_grouped.csv')
folds['image_path'] = (Path('train')/folds['Image']).astype(str)

# 5-fold full training on all non-new_whale classes; gallery = full train fold excl new_whale
all_folds = [0,1,2,3,4]
oof_scores = []; taus_best = []; ckpts = []
t_all = time.time()
for f in all_folds:
    tr_df = folds[folds['fold'] != f].copy()
    va_df = folds[folds['fold'] == f].copy()
    gallery_df = tr_df[tr_df['Id']!='new_whale'].copy()
    # Train on ALL non-new_whale (include singletons) per expert advice
    tr_df_train = gallery_df.copy()
    print(f'[Fold {f}] training rows (all non-new_whale):', len(tr_df_train), 'classes:', tr_df_train['Id'].nunique())
    sc, tau, ckpt = train_fold(tr_df_train, gallery_df, va_df, f)
    oof_scores.append(sc); taus_best.append(tau); ckpts.append(ckpt)
    gc.collect()
print('5-fold OOF MAP@5 mean:', float(np.mean(oof_scores)))
print('per-fold taus:', taus_best)
print('checkpoints:', ckpts)
print('Elapsed total', time.time()-t_all)

print('Training complete. Next: extract train/test BNNeck embeddings with EMA weights, build ID prototypes, tune global tau (median of per-fold), and generate submission.')

In [None]:
# Fix albumentations/albucore version mismatch
import sys, subprocess
def pip(*args):
    print('>', *args, flush=True)
    subprocess.run([sys.executable, '-m', 'pip', *args], check=True)

# Align versions to resolve ImportError: preserve_channel_dim
pip('install', '-c', 'constraints.txt', '--upgrade', 'albumentations==1.4.11', 'albucore==0.0.13', '--upgrade-strategy', 'only-if-needed')

import albumentations as A
import albucore
print('albumentations:', A.__version__)
import inspect
from albucore import utils as ac_utils
print('albucore:', getattr(albucore, '__version__', 'unknown'))
print('has preserve_channel_dim:', hasattr(ac_utils, 'preserve_channel_dim'))

In [None]:
# Micro-overfit sanity: train on ~20 multi-instance classes to validate loop
import pandas as pd, numpy as np, torch, time, gc
from pathlib import Path
from torch import nn
from torch.cuda.amp import autocast, GradScaler

def build_subset(df_all, min_imgs=5, n_classes=20):
    dfmi = df_all[df_all['Id']!='new_whale'].copy()
    vc = dfmi['Id'].value_counts()
    keep_ids = vc[vc>=min_imgs].index.tolist()[:n_classes]
    sub = dfmi[dfmi['Id'].isin(keep_ids)].copy().reset_index(drop=True)
    return sub

def split_train_val(df_sub, val_frac=0.2, seed=42):
    rng = np.random.default_rng(seed)
    train_idx = []
    val_idx = []
    for gid, g in df_sub.groupby('Id'):
        idx = g.index.to_numpy()
        rng.shuffle(idx)
        n_val = max(1, int(len(idx)*val_frac))
        val_idx.extend(idx[:n_val].tolist())
        train_idx.extend(idx[n_val:].tolist())
    tr = df_sub.loc[sorted(train_idx)].reset_index(drop=True)
    va = df_sub.loc[sorted(val_idx)].reset_index(drop=True)
    return tr, va

def build_label_mapping_local(df):
    labs = sorted(df['Id'].unique().tolist())
    return {l:i for i,l in enumerate(labs)}

def micro_overfit_run(img_size=384, p=10, k=4, epochs=5, head_lr=3e-3, bb_lr=0.0, wd=0.05, m_max=0.30):
    global IMG_SIZE
    IMG_SIZE = img_size
    folds = pd.read_csv('folds_grouped.csv')
    folds['image_path'] = (Path('train')/folds['Image']).astype(str)
    sub = build_subset(folds, min_imgs=5, n_classes=20)
    tr, va = split_train_val(sub, val_frac=0.2, seed=42)
    print('Subset shapes:', sub.shape, 'train:', tr.shape, 'val:', va.shape)
    label2idx = build_label_mapping_local(tr)
    ds_tr = WhalesDS(tr, label2idx=label2idx, mode='train')
    sampler = PKSampler(tr, label2idx, p=p, k=k)
    dl_tr = DataLoader(ds_tr, batch_size=p*k, sampler=sampler, num_workers=NUM_WORKERS, pin_memory=True)
    model = Net(n_classes=len(label2idx)).to(device)
    # Freeze backbone if bb_lr==0
    bb_params = []; head_params = []
    for n, param in model.named_parameters():
        if any(t in n for t in ['emb','bnneck','arc']):
            head_params.append(param)
        else:
            if bb_lr==0.0:
                param.requires_grad = False
            bb_params.append(param)
    optim = torch.optim.AdamW([{'params': [p for p in bb_params if p.requires_grad], 'lr': bb_lr}, {'params': head_params, 'lr': head_lr}], weight_decay=wd)
    scaler = GradScaler(enabled=True)
    taus = np.arange(0.60, 0.86, 0.02)
    best_sc=-1.0; best_tau=0.7
    for ep in range(epochs):
        model.train(); t0=time.time(); run_loss=0.0; nb=0
        steps_per_epoch = max(1, len(sampler)//(p*k))
        for bi, (x,y) in enumerate(dl_tr):
            x = x.to(device, non_blocking=True); y = torch.as_tensor(y, dtype=torch.long, device=device)
            progress = min(1.0, (bi+1)/max(1, steps_per_epoch)/max(1e-6, WARMUP_EPOCHS))
            m_cur = m_max * progress
            optim.zero_grad(set_to_none=True)
            with autocast(enabled=(device.type=='cuda')):
                logits, _, _ = model(x, labels=y, margin_override=m_cur)
                loss = nn.functional.cross_entropy(logits, y)
            scaler.scale(loss).backward()
            scaler.unscale_(optim)
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optim); scaler.update()
            run_loss += loss.item(); nb+=1
            if (bi+1)%50==0:
                print(f'[Micro] Ep{ep+1} B{bi+1} loss {run_loss/nb:.4f} elapsed {time.time()-t0:.1f}s', flush=True)
        print(f'[Micro] Ep{ep+1} train_loss {run_loss/max(1,nb):.4f} time {time.time()-t0:.1f}s')
        model.eval()
        with torch.no_grad():
            sc, tau = oof_eval(model, tr[['Image','Id','image_path']], va[['Image','Id','image_path']], taus)
        print(f'[Micro] Ep{ep+1} OOF MAP@5={sc:.4f} tau={tau:.3f}')
        if sc>best_sc: best_sc, best_tau = sc, tau
    print('[Micro] Best OOF:', best_sc, 'tau:', best_tau)
    return best_sc, best_tau

# Run micro-overfit: expect clear learning (>0.5 MAP@5 on this tiny subset within a few epochs) if pipeline is healthy
best_sc, best_tau = micro_overfit_run(img_size=384, p=10, k=4, epochs=5, head_lr=3e-3, bb_lr=0.0)
print('Done micro-overfit. Score:', best_sc, 'tau:', best_tau)

In [1]:
# Minimal model + feature extractor definitions for inference (no training)
import os, math, time, torch, timm
import torchvision.transforms as T
import numpy as np
from torch import nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image, ImageOps
from pathlib import Path

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
IMG_SIZE = 384
NUM_WORKERS = min(8, os.cpu_count() or 4)
USE_BNNECK_FOR_RETR = False  # will be overridden by extract_feats_bnneck wrapper

def get_infer_transform():
    return T.Compose([
        T.Resize((IMG_SIZE, IMG_SIZE), interpolation=T.InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=(0.485,0.456,0.406), std=(0.229,0.224,0.225)),
    ])

class WhalesDS(Dataset):
    def __init__(self, df, tta_hflip=False):
        self.df = df.reset_index(drop=True)
        self.tta_hflip = tta_hflip
        self.tf = get_infer_transform()
    def __len__(self): return len(self.df)
    def __getitem__(self, i):
        row = self.df.iloc[i]
        img = Image.open(str(row.image_path)).convert('RGB')
        x = self.tf(img)
        if self.tta_hflip:
            x2 = self.tf(ImageOps.mirror(img))
            return x, x2, row.image_path, row.Id
        return x, row.image_path, row.Id

class GeM(nn.Module):
    def __init__(self, p=3.0, eps=1e-6):
        super().__init__(); self.p = nn.Parameter(torch.ones(1)*p); self.eps = eps
    def forward(self, x):
        x = x.clamp(min=self.eps).pow(self.p)
        x = torch.mean(x, dim=(-1,-2)).pow(1.0/self.p)
        return x

class Net(nn.Module):
    def __init__(self, n_classes):
        super().__init__()
        self.backbone = timm.create_model('convnext_tiny', pretrained=True, num_classes=0, global_pool='')
        in_ch = self.backbone.num_features
        self.gem = GeM(p=3.0)
        self.emb = nn.Linear(in_ch, 512, bias=False)
        self.bnneck = nn.BatchNorm1d(512, eps=1e-5, momentum=0.1)
    def forward(self, x):
        feat = self.backbone.forward_features(x)
        feat = self.gem(feat)
        f512 = self.emb(feat)
        fbn = self.bnneck(f512)
        return f512, fbn

@torch.no_grad()
def model_feats(model, x, use_bnneck: bool):
    feat = model.backbone.forward_features(x)
    feat = model.gem(feat)
    if use_bnneck:
        f512 = model.emb(feat)
        fbn = model.bnneck(f512)
        return nn.functional.normalize(fbn, dim=1)
    else:
        return nn.functional.normalize(feat, dim=1)

@torch.no_grad()
def extract_feats(model, df, tta_hflip=True):
    ds = WhalesDS(df, tta_hflip=tta_hflip)
    dl = DataLoader(ds, batch_size=64, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
    feats, paths, labels = [], [], []
    t0=time.time()
    for bi, batch in enumerate(dl):
        if tta_hflip:
            x, x2, p, y = batch
            x = x.to(device, non_blocking=True).to(memory_format=torch.channels_last)
            x2 = x2.to(device, non_blocking=True).to(memory_format=torch.channels_last)
            f1 = model_feats(model, x, USE_BNNECK_FOR_RETR)
            f2 = model_feats(model, x2, USE_BNNECK_FOR_RETR)
            f = (f1 + f2) / 2.0
        else:
            x, p, y = batch
            x = x.to(device, non_blocking=True).to(memory_format=torch.channels_last)
            f = model_feats(model, x, USE_BNNECK_FOR_RETR)
        feats.append(f.cpu().numpy()); paths += list(p); labels += list(y)
        if (bi+1)%20==0:
            print(f'FE bi {bi+1}, {time.time()-t0:.1f}s', flush=True)
    if len(feats)>0:
        feats = np.concatenate(feats, axis=0)
    else:
        dim = int(model.emb.out_features) if USE_BNNECK_FOR_RETR else int(model.backbone.num_features)
        feats = np.zeros((0, dim), dtype=np.float32)
    return feats.astype('float32'), paths, labels

In [2]:
# Inference: load EMA checkpoints, extract BNNeck embeddings, build ID prototypes or image-gallery, predict test, write submission.csv
import os, math, time, gc, numpy as np, pandas as pd, torch, faiss
from pathlib import Path
from collections import OrderedDict

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
ALPHA = 12.0  # expert: use 12.0

def l2_normalize(x, axis=1, eps=1e-9):
    x = np.asarray(x)
    if x.ndim == 1:
        n = np.linalg.norm(x) + eps
        return x / n
    n = np.linalg.norm(x, axis=axis, keepdims=True) + eps
    return x / n

def load_ckpt(ckpt_path):
    ckpt = torch.load(ckpt_path, map_location='cpu')
    state = ckpt.get('model', ckpt)  # tolerate raw state_dict
    # Drop ArcFace head weights to avoid class-dimension mismatch at inference
    if isinstance(state, (dict, OrderedDict)):
        new_state = OrderedDict()
        for k, v in state.items():
            if ('arc.' in k) or k.endswith('arc.weight') or (k == 'arc.weight') or ('arc_weight' in k):
                continue
            new_state[k] = v
        state = new_state
    best_tau = ckpt.get('best_tau', None)
    return state, best_tau

@torch.no_grad()
def extract_feats_bnneck(model, df, batch_size=64, tta_hflip=True):
    # Force BNNeck features for inference
    global USE_BNNECK_FOR_RETR
    prev_flag = USE_BNNECK_FOR_RETR
    USE_BNNECK_FOR_RETR = True
    feats, paths, labels = extract_feats(model, df[['Image','Id','image_path']], tta_hflip=tta_hflip)
    USE_BNNECK_FOR_RETR = prev_flag
    return feats, paths, labels

def build_id_prototypes(embs: np.ndarray, ids: list):
    # Average embeddings per ID (no normalization here; we'll normalize after stacking)
    df = pd.DataFrame({'Id': ids})
    df['idx'] = np.arange(len(ids))
    protos = []
    labels = []
    for gid, g in df.groupby('Id'):
        idxs = g['idx'].values
        m = embs[idxs].mean(axis=0, dtype=np.float32)
        protos.append(m)
        labels.append(gid)
    return np.stack(protos).astype('float32'), labels

def build_index_ip(embs):
    index = faiss.IndexFlatIP(embs.shape[1])
    index.add(embs)
    return index

def predict_with_gate(index, gallery_labels, query_embs, k=100, tau=0.74, alpha=12.0, margin=0.0, ratio=0.0):
    sims, idxs = index.search(query_embs, min(k, index.ntotal))
    preds = []
    for i in range(len(query_embs)):
        nei_idx = idxs[i]; nei_sims = sims[i]
        labs = [gallery_labels[j] for j in nei_idx]
        # exponential voting
        scores = {}
        for lab, sim in zip(labs, nei_sims):
            scores[lab] = scores.get(lab, 0.0) + math.exp(alpha * float(sim))
        ordered = [k for k,_ in sorted(scores.items(), key=lambda x:-x[1])]
        # ambiguity/new_whale gate
        s1 = float(nei_sims[0]) if len(nei_sims)>0 else -1.0
        s2 = float(nei_sims[1]) if len(nei_sims)>1 else s1
        top5 = []
        if (len(nei_sims)>0 and s1 < tau) or (margin>0.0 and len(nei_sims)>1 and (s1 - s2) < margin) or (ratio>0.0 and len(nei_sims)>1 and (s1/max(s2,1e-6) < ratio)):
            top5.append('new_whale')
        for lab in ordered:
            if lab not in top5:
                top5.append(lab)
            if len(top5)==5: break
        if len(top5)<5:
            top5 += ['new_whale']*(5-len(top5))
        preds.append(top5[:5])
    return preds

def map5_score(preds, truths):
    s = 0.0
    for p, t in zip(preds, truths):
        try:
            r = p.index(t) + 1
            s += 1.0 / r
        except ValueError:
            s += 0.0
    return s / len(truths) if len(truths)>0 else 0.0

def _build_dynamic_tau_grid_from_s1(s1_vals, lo_floor=0.20, hi_cap=0.90, pad=0.01, step=0.001):
    v = s1_vals[np.isfinite(s1_vals)]
    v = v[v > 0]
    if v.size > 0:
        p10, p90 = np.quantile(v, [0.10, 0.90])
        lo = max(lo_floor, float(p10) - pad)
        hi = min(hi_cap, float(p90) + pad)
        if hi > lo:
            return np.arange(lo, hi + 1e-9, step)
    return np.arange(lo_floor, hi_cap, 0.002)

def calibrate_tau_fold(model, folds_df, fold_idx, tta_hflip=True, tau_grid=None, margin=0.0, ratio=0.0, enable_dba=False, dba_M=8, dba_lambda=0.3):
    tr = folds_df[(folds_df['fold'] != fold_idx) & (folds_df['Id'] != 'new_whale')].copy()
    va = folds_df[folds_df['fold'] == fold_idx].copy()
    # Extract train/val features
    feats_tr, paths_tr, labs_tr = extract_feats_bnneck(model, tr, tta_hflip=tta_hflip)
    feats_va, paths_va, labs_va = extract_feats_bnneck(model, va, tta_hflip=tta_hflip)
    protos, proto_labels = build_id_prototypes(feats_tr, labs_tr)
    protos = l2_normalize(protos, axis=1).astype('float32')
    if enable_dba:
        protos = dba_smooth(protos, M=dba_M, lam=dba_lambda)
    index = build_index_ip(protos)
    # Determine eval ground truth (IDs not in gallery become new_whale)
    gal_id_set = set(proto_labels)
    truths = [lab if lab in gal_id_set else 'new_whale' for lab in labs_va]
    sims_all, idxs_all = index.search(l2_normalize(feats_va, axis=1).astype('float32'), min(50, index.ntotal))
    # Dynamic tau grid from s1 percentiles
    s1 = sims_all[:, 0] if sims_all.size > 0 else np.array([], dtype=np.float32)
    dyn_grid = _build_dynamic_tau_grid_from_s1(s1)
    if tau_grid is None or len(tau_grid) == 0:
        tau_grid = dyn_grid
    best_tau, best_sc = None, -1.0
    for tau in tau_grid:
        preds = []
        for i in range(len(paths_va)):
            nei_idx = idxs_all[i]; nei_sims = sims_all[i]
            ordered = [proto_labels[j] for j in nei_idx]
            # Apply gate using predict_with_gate logic but faster inline
            s1i = float(nei_sims[0]) if len(nei_sims)>0 else -1.0
            s2i = float(nei_sims[1]) if len(nei_sims)>1 else s1i
            top5 = []
            if (len(nei_sims)>0 and s1i < tau) or (margin>0.0 and len(nei_sims)>1 and (s1i - s2i) < margin) or (ratio>0.0 and len(nei_sims)>1 and (s1i/max(s2i,1e-6) < ratio)):
                top5.append('new_whale')
            # simple unique ordering based on sims (already by similarity)
            for lab in ordered:
                if lab not in top5:
                    top5.append(lab)
                if len(top5)==5: break
            if len(top5)<5:
                top5 += ['__DUMMY__']*(5-len(top5))  # do not reward the gate during calibration
            preds.append(top5[:5])
        sc = map5_score(preds, truths)
        if sc > best_sc:
            best_sc, best_tau = sc, float(tau)
    print(f'[Calib] fold {fold_idx} best_tau={best_tau:.3f} map5={best_sc:.4f}')
    return best_tau, best_sc

def dba_smooth(proto_mat, M=8, lam=0.3):
    # Database-side augmentation: smooth each vector with mean of its M nearest neighbors
    if M <= 0 or lam <= 0: return proto_mat
    index = build_index_ip(proto_mat)
    sims, idxs = index.search(proto_mat, min(M+1, index.ntotal))
    out = proto_mat.copy()
    for i in range(proto_mat.shape[0]):
        neigh = idxs[i][1:]  # exclude self at 0
        if len(neigh) == 0: continue
        mean_nei = proto_mat[neigh].mean(axis=0)
        out[i] = l2_normalize(proto_mat[i] + lam * mean_nei)
    return out.astype('float32')

def query_expansion(index, gallery_vecs, queries, L=8, lam=0.3, conditional_tau=None):
    if L <= 0 or lam <= 0: return queries
    sims, idxs = index.search(queries, min(L, index.ntotal))
    out = queries.copy()
    for i in range(queries.shape[0]):
        # Conditional QE: skip if top1 sim below (tau-0.02)
        if conditional_tau is not None:
            s1 = float(sims[i][0]) if sims.shape[1] > 0 else -1.0
            if s1 < (conditional_tau - 0.02):
                out[i] = queries[i]
                continue
        neigh = idxs[i]
        if len(neigh) == 0: continue
        mean_nei = gallery_vecs[neigh].mean(axis=0)
        out[i] = l2_normalize(queries[i] + lam * mean_nei)
    return out.astype('float32')

def calibrate_tau_global(proto_mat, proto_labels, folds_df, val_feats_per_fold, val_labels_per_fold, tau_grid=None, margin=0.0, ratio=0.0):
    # Build one index on final gallery (with DBA already applied); QE is OFF here by design
    index = build_index_ip(proto_mat)
    proto_label_set = set(proto_labels)
    # Collect s1 across all folds for dynamic grid
    s1_all = []
    for f in sorted(folds_df['fold'].unique()):
        feats = l2_normalize(val_feats_per_fold[f], axis=1).astype('float32')
        sims_all, _ = index.search(feats, min(100, index.ntotal))
        if sims_all.size > 0:
            s1_all.append(sims_all[:, 0])
    if len(s1_all) > 0:
        s1_cat = np.concatenate(s1_all, axis=0)
    else:
        s1_cat = np.array([], dtype=np.float32)
    dyn_grid = _build_dynamic_tau_grid_from_s1(s1_cat)
    if tau_grid is None or len(tau_grid) == 0:
        tau_grid = dyn_grid
    try:
        if s1_cat.size > 0:
            q25, q50, q75 = np.quantile(s1_cat, [0.25, 0.50, 0.75])
            print(f'[Calib-Global] s1 q25/q50/q75: {q25:.3f}/{q50:.3f}/{q75:.3f} | tau_grid: {tau_grid[0]:.3f}-{tau_grid[-1]:.3f} (n={len(tau_grid)})')
    except Exception:
        pass
    best_tau, best_sc = None, -1.0
    for tau in tau_grid:
        scores = []
        for f in sorted(folds_df['fold'].unique()):
            feats = l2_normalize(val_feats_per_fold[f], axis=1).astype('float32')
            labs = val_labels_per_fold[f]
            sims_all, idxs_all = index.search(feats, min(100, index.ntotal))
            preds = []
            truths = []
            for i in range(len(labs)):
                lab = labs[i]
                truths.append(lab if lab in proto_label_set else 'new_whale')
                nei_idx = idxs_all[i]; nei_sims = sims_all[i]
                # exponential voting
                scores_map = {}
                for j, sim in zip(nei_idx, nei_sims):
                    scores_map[proto_labels[j]] = scores_map.get(proto_labels[j], 0.0) + math.exp(ALPHA * float(sim))
                ordered = [k for k,_ in sorted(scores_map.items(), key=lambda x:-x[1])]
                s1i = float(nei_sims[0]) if len(nei_sims)>0 else -1.0
                s2i = float(nei_sims[1]) if len(nei_sims)>1 else s1i
                top5 = []
                if (len(nei_sims)>0 and s1i < tau) or (margin>0.0 and len(nei_sims)>1 and (s1i - s2i) < margin) or (ratio>0.0 and len(nei_sims)>1 and (s1i/max(s2i,1e-6) < ratio)):
                    top5.append('new_whale')
                for lab2 in ordered:
                    if lab2 not in top5:
                        top5.append(lab2)
                    if len(top5)==5: break
                if len(top5)<5:
                    top5 += ['__DUMMY__']*(5-len(top5))  # do not reward gate during calibration
                preds.append(top5[:5])
            scores.append(map5_score(preds, truths))
        sc = float(np.mean(scores)) if len(scores)>0 else -1.0
        if sc > best_sc:
            best_sc, best_tau = sc, float(tau)
    print(f'[Calib-Global] best_tau={best_tau:.3f} mean_oof_map5={best_sc:.4f}')
    return best_tau, best_sc

def run_full_inference(ckpt_paths, out_csv='submission.csv', tta_hflip=True, enable_dba=False, dba_M=8, dba_lambda=0.3, enable_qe=False, qe_L=8, qe_lambda=0.3, tau_offset=0.0, margin=0.02, ratio=1.04):
    t0_all = time.time()
    folds_df = pd.read_csv('folds_grouped.csv')
    folds_df['image_path'] = (Path('train')/folds_df['Image']).astype(str)
    train_all = folds_df[folds_df['Id']!='new_whale'].copy()
    ids_all = sorted(train_all['Id'].unique().tolist())
    n_classes = len(ids_all)
    # Build model skeleton once
    model = Net(n_classes=n_classes).to(device).eval()
    model = model.to(memory_format=torch.channels_last)
    # Prepare test df
    ss = pd.read_csv('sample_submission.csv')
    test_df = pd.DataFrame({'Image': ss['Image']})
    test_df['image_path'] = (Path('test')/test_df['Image']).astype(str)
    # Accumulators
    tau_list = []
    proto_accum = {}  # Id -> sum vector
    proto_counts = {} # Id -> count of folds
    test_emb_accum = None
    # Store per-fold val feats for global tau calibration later
    val_feats_per_fold = {}
    val_labels_per_fold = {}
    for fold_i, ck in enumerate(ckpt_paths):
        if not os.path.exists(ck):
            print(f'[Infer] Skip missing ckpt: {ck}')
            continue
        state, _ = load_ckpt(ck)
        # Additionally drop any state keys not present/shaped like current model to be extra safe
        model_state = model.state_dict()
        filtered = OrderedDict()
        for k, v in state.items():
            if k in model_state and tuple(model_state[k].shape) == tuple(v.shape):
                filtered[k] = v
        miss = model.load_state_dict(filtered, strict=False)
        print(f'[Infer] Loaded {ck}; used keys: {len(filtered)}/{len(state)}; missing keys: {len(miss.missing_keys)}; unexpected: {len(miss.unexpected_keys)}')
        # Calibrate tau on this fold using prototype retrieval with BNNeck (dynamic tau grid, tau-only gate, dummy padding)
        best_tau_fold, _ = calibrate_tau_fold(model, folds_df, fold_i, tta_hflip=tta_hflip, tau_grid=None, margin=0.0, ratio=0.0, enable_dba=enable_dba, dba_M=dba_M, dba_lambda=dba_lambda)
        tau_list.append(best_tau_fold)
        # Train embeddings -> prototypes on ALL train (for final gallery)
        feats_tr, paths_tr, labs_tr = extract_feats_bnneck(model, train_all, tta_hflip=tta_hflip)
        protos, proto_labels = build_id_prototypes(feats_tr, labs_tr)
        # accumulate prototypes
        for v, lab in zip(protos, proto_labels):
            if lab not in proto_accum:
                proto_accum[lab] = v.astype('float32').copy()
                proto_counts[lab] = 1
            else:
                proto_accum[lab] += v.astype('float32')
                proto_counts[lab] += 1
        # Test embeddings
        feats_te, paths_te, _ = extract_feats_bnneck(model, test_df.assign(Id='dummy'), tta_hflip=tta_hflip)
        if test_emb_accum is None:
            test_emb_accum = feats_te.astype('float32')
        else:
            test_emb_accum += feats_te.astype('float32')
        # Store fold val feats for diagnostics/global calib
        va = folds_df[folds_df['fold'] == fold_i].copy()
        feats_va, _, labs_va = extract_feats_bnneck(model, va, tta_hflip=tta_hflip)
        val_feats_per_fold[fold_i] = feats_va.astype('float32')
        val_labels_per_fold[fold_i] = list(labs_va)
        gc.collect()
    # Average accumulators
    if len(proto_accum)==0 or test_emb_accum is None:
        raise RuntimeError('No checkpoints processed; cannot run inference.')
    proto_keys = sorted(proto_accum.keys())
    # stack first, then normalize row-wise to avoid 1D normalization pitfalls
    proto_raw = np.stack([proto_accum[k] / max(1, proto_counts[k]) for k in proto_keys]).astype('float32')
    proto_mat = l2_normalize(proto_raw, axis=1).astype('float32')
    test_mat = l2_normalize(test_emb_accum / max(1, len(ckpt_paths))).astype('float32')
    # Optional DBA smoothing on prototypes
    if enable_dba:
        proto_mat = dba_smooth(proto_mat, M=dba_M, lam=dba_lambda)
    # Global calibration on FINAL gallery (DBA ON, QE OFF), with safety gates (margin/ratio)
    tau_global, oof_mean = calibrate_tau_global(proto_mat, proto_keys, folds_df, val_feats_per_fold, val_labels_per_fold, tau_grid=None, margin=margin, ratio=ratio)
    tau_global = float(np.clip(tau_global + float(tau_offset), 0.0, 1.0))
    print(f'[Infer] Global calib tau={tau_global:.4f} (offset={float(tau_offset):.4f}) | per-fold taus median={float(np.median(tau_list)):.4f}')
    # Build index on prototypes
    index = build_index_ip(proto_mat)
    # Optional Query Expansion (conditional)
    if enable_qe and qe_L>0 and qe_lambda>0:
        test_mat_qe = query_expansion(index, proto_mat, test_mat, L=qe_L, lam=qe_lambda, conditional_tau=tau_global)
    else:
        test_mat_qe = test_mat
    preds = predict_with_gate(index, proto_keys, test_mat_qe, k=50, tau=tau_global, alpha=ALPHA, margin=margin, ratio=ratio)
    # Write submission
    pred_rows = [' '.join(p[:5]) for p in preds]
    sub = pd.DataFrame({'Image': ss['Image'], 'Id': pred_rows})
    sub.to_csv(out_csv, index=False)
    print('Wrote', out_csv, 'shape', sub.shape, 'elapsed', f'{time.time()-t0_all:.1f}s')
    print(sub.head().to_string(index=False))
    print('Per-fold calibrated taus:', tau_list)
    return out_csv

def run_full_inference_image_gallery(ckpt_paths, out_csv='submission.csv', tta_hflip=True, enable_dba=True, dba_M=8, dba_lambda=0.3, enable_qe=True, qe_L=8, qe_lambda=0.3, tau_offset=0.0, margin=0.0, ratio=0.0):
    # Build an image-level gallery (vote-by-ID) by averaging per-image embeddings across folds, then retrieve
    t0_all = time.time()
    folds_df = pd.read_csv('folds_grouped.csv')
    folds_df['image_path'] = (Path('train')/folds_df['Image']).astype(str)
    train_all = folds_df[folds_df['Id']!='new_whale'].copy()
    ids_all = sorted(train_all['Id'].unique().tolist())
    n_classes = len(ids_all)
    model = Net(n_classes=n_classes).to(device).eval()
    model = model.to(memory_format=torch.channels_last)
    ss = pd.read_csv('sample_submission.csv')
    test_df = pd.DataFrame({'Image': ss['Image']})
    test_df['image_path'] = (Path('test')/test_df['Image']).astype(str)
    # Accumulators (per-image embeddings across folds)
    gal_feat_sum = None
    gal_labels = train_all['Id'].tolist()
    gal_paths = train_all['image_path'].tolist()
    test_feat_sum = None
    # For calibration, collect per-fold val features
    val_feat_sums = {}
    val_labels_per_fold = {}
    n_loaded = 0
    for fold_i, ck in enumerate(ckpt_paths):
        if not os.path.exists(ck):
            print(f'[Infer-IMG] Skip missing ckpt: {ck}')
            continue
        state, _ = load_ckpt(ck)
        model_state = model.state_dict()
        filtered = OrderedDict((k,v) for k,v in state.items() if k in model_state and tuple(model_state[k].shape)==tuple(v.shape))
        _ = model.load_state_dict(filtered, strict=False)
        # Extract gallery (all train non-new_whale, per-image)
        feats_gal, _, _ = extract_feats_bnneck(model, train_all, tta_hflip=tta_hflip)
        gal_feat_sum = feats_gal.astype('float32') if gal_feat_sum is None else gal_feat_sum + feats_gal.astype('float32')
        # Extract test
        feats_te, _, _ = extract_feats_bnneck(model, test_df.assign(Id='dummy'), tta_hflip=tta_hflip)
        test_feat_sum = feats_te.astype('float32') if test_feat_sum is None else test_feat_sum + feats_te.astype('float32')
        # Validation features for this fold for calibration averaging
        va_df = folds_df[folds_df['fold']==fold_i].copy()
        feats_va, _, labs_va = extract_feats_bnneck(model, va_df, tta_hflip=tta_hflip)
        if fold_i not in val_feat_sums:
            val_feat_sums[fold_i] = feats_va.astype('float32')
            val_labels_per_fold[fold_i] = list(labs_va)
        else:
            val_feat_sums[fold_i] += feats_va.astype('float32')
        n_loaded += 1
        gc.collect()
    if n_loaded == 0:
        raise RuntimeError('No checkpoints processed; cannot run inference.')
    gal_mat_raw = l2_normalize(gal_feat_sum / n_loaded, axis=1).astype('float32')
    test_mat = l2_normalize(test_feat_sum / n_loaded, axis=1).astype('float32')
    for f in val_feat_sums.keys():
        val_feat_sums[f] = l2_normalize(val_feat_sums[f] / n_loaded, axis=1).astype('float32')
    assert gal_mat_raw.shape[0] == len(gal_labels)
    # Leak-free per-fold tau calibration (QE OFF); recompute DBA on each sub-gallery
    print('[Infer-IMG] Starting leak-free per-fold tau calibration...')
    tau_grid = np.arange(0.95, 1.0001, 0.001)
    train_all_idx = train_all.reset_index(drop=True)  # align row order
    gal_labels_arr = np.array(gal_labels, dtype=object)
    gallery_fold_indices = train_all_idx['fold'].values
    taus_best = []
    fold_scores = []
    for f in sorted(val_feat_sums.keys()):
        sub_mask = (gallery_fold_indices != f)
        gal_sub = gal_mat_raw[sub_mask]
        labels_sub = gal_labels_arr[sub_mask].tolist()
        assert gal_sub.shape[0] == len(labels_sub)
        # DBA on sub-gallery to avoid leak via neighbors
        gal_sub_dba = dba_smooth(gal_sub, M=dba_M, lam=dba_lambda) if enable_dba else gal_sub
        index_sub = build_index_ip(gal_sub_dba)
        q = val_feat_sums[f]
        labs = val_labels_per_fold[f]
        assert q.shape[0] == len(labs)
        sims_all, idxs_all = index_sub.search(q, min(50, index_sub.ntotal))
        truths = [lab if lab in set(labels_sub) else 'new_whale' for lab in labs]
        best_tau_f, best_sc_f = None, -1.0
        for tau in tau_grid:
            preds = []
            for i in range(q.shape[0]):
                nei_idx = idxs_all[i]; nei_sims = sims_all[i]
                nei_ids = [labels_sub[j] for j in nei_idx]
                # vote-by-ID
                scores_map = {}
                for lab2, sim in zip(nei_ids, nei_sims):
                    scores_map[lab2] = scores_map.get(lab2, 0.0) + math.exp(ALPHA * float(sim))
                ordered = [k for k,_ in sorted(scores_map.items(), key=lambda x: -x[1])]
                s1 = float(nei_sims[0]) if len(nei_sims)>0 else -1.0
                s2 = float(nei_sims[1]) if len(nei_sims)>1 else s1
                top5 = []
                if (len(nei_sims)>0 and s1 < tau) or (len(nei_sims)>1 and (s1 - s2) < margin) or (len(nei_sims)>1 and (s1/max(s2,1e-6)) < ratio):
                    top5.append('new_whale')
                for lab2 in ordered:
                    if lab2 not in top5: top5.append(lab2)
                    if len(top5)==5: break
                if len(top5)<5: top5 += ['new_whale']*(5-len(top5))
                preds.append(top5[:5])
            sc = map5_score(preds, truths)
            if sc > best_sc_f:
                best_sc_f, best_tau_f = sc, float(tau)
        taus_best.append(best_tau_f if best_tau_f is not None else float(np.median(tau_grid)))
        fold_scores.append(best_sc_f if best_sc_f is not None else 0.0)
        print(f'[OOF-fold-img] f={f} tau*={taus_best[-1]:.3f} score={fold_scores[-1]:.4f}')
    tau_global = float(np.median(taus_best))
    mean_oof = float(np.mean(fold_scores))
    # Apply user-provided tau_offset to adjust gating for test distribution
    tau_global = float(np.clip(tau_global + float(tau_offset), 0.65, 0.999))
    print(f'[OOF-img] mean MAP@5={mean_oof:.4f} | median tau={tau_global:.3f} (after offset={float(tau_offset):.4f})')
    # Build final gallery (DBA exactly as configured) and index
    gal_mat = dba_smooth(gal_mat_raw, M=dba_M, lam=dba_lambda) if enable_dba else gal_mat_raw
    index = build_index_ip(gal_mat)
    # QE only for final inference (conditional)
    if enable_qe:
        test_q = query_expansion(index, gal_mat, test_mat, L=qe_L, lam=qe_lambda, conditional_tau=tau_global)
    else:
        test_q = test_mat
    sims_all, idxs_all = index.search(test_q, min(50, index.ntotal))
    pred_rows = []
    for i in range(test_q.shape[0]):
        nei_idx = idxs_all[i]; nei_sims = sims_all[i]
        nei_ids = [gal_labels[j] for j in nei_idx]
        scores_map = {}
        for lab, sim in zip(nei_ids, nei_sims):
            scores_map[lab] = scores_map.get(lab, 0.0) + math.exp(ALPHA * float(sim))
        ordered = [k for k,_ in sorted(scores_map.items(), key=lambda x:-x[1])]
        s1 = float(nei_sims[0]) if len(nei_sims)>0 else -1.0
        s2 = float(nei_sims[1]) if len(nei_sims)>1 else s1
        top5 = []
        if (len(nei_sims)>0 and s1 < tau_global) or (len(nei_sims)>1 and (s1 - s2) < margin) or (len(nei_sims)>1 and (s1/max(s2,1e-6)) < ratio):
            top5.append('new_whale')
        for lab2 in ordered:
            if lab2 not in top5: top5.append(lab2)
            if len(top5)==5: break
        if len(top5)<5: top5 += ['new_whale']*(5-len(top5))
        pred_rows.append(' '.join(top5[:5]))
    sub = pd.DataFrame({'Image': ss['Image'], 'Id': pred_rows})
    sub.to_csv(out_csv, index=False)
    print('Wrote (IMG) ', out_csv, 'shape', sub.shape, 'elapsed', f'{time.time()-t0_all:.1f}s')
    print(sub.head().to_string(index=False))
    return out_csv

# Example usage after training completes:
# ckpts = [f'model_fold{k}_best.pth' for k in range(5)]
# run_full_inference(ckpts, out_csv='submission.csv', tta_hflip=True)
# run_full_inference_image_gallery(ckpts, out_csv='submission_img.csv', tta_hflip=True)

In [None]:
# Run inference: prototype gallery (Path B) with global open-set calibration, DBA ON, QE OFF
from pathlib import Path
ckpts = [f'model_fold{k}_best.pth' for k in range(5)]
out_csv = run_full_inference(
    ckpt_paths=ckpts,
    out_csv='submission.csv',
    tta_hflip=True,
    enable_dba=True,
    dba_M=5,
    dba_lambda=0.15,
    enable_qe=False,
    qe_L=0,
    qe_lambda=0.0,
    tau_offset=-0.45,  # Force tau into the prototype similarity regime to target ~35-55% new_whale@1
    margin=0.02,
    ratio=1.04
)
print('submission.csv exists?', Path(out_csv).exists())
print('submission.csv size:', Path(out_csv).stat().st_size if Path(out_csv).exists() else -1)

In [None]:
# OOF tuner using image-level galleries (vote-by-ID), per-fold models, to match training OOF conditions
import math, time, gc, numpy as np, pandas as pd, torch, os
from pathlib import Path
from collections import OrderedDict

@torch.no_grad()
def build_fold_gallery_and_val(ckpt_paths, tta_hflip=True):
    folds_df = pd.read_csv('folds_grouped.csv')
    folds_df['image_path'] = (Path('train')/folds_df['Image']).astype(str)
    ids_all = sorted(folds_df.loc[folds_df['Id']!='new_whale', 'Id'].unique().tolist())
    n_classes = len(ids_all)
    gallery_by_fold = {}   # f -> (gallery_feats[L2], gallery_ids[list of Id])
    val_feats_by_fold = {}
    val_labels_by_fold = {}
    for fold_i, ck in enumerate(ckpt_paths):
        if not os.path.exists(ck):
            print(f'[Prep] Skip missing ckpt: {ck}')
            continue
        model = Net(n_classes=n_classes).to(device).eval().to(memory_format=torch.channels_last)
        state, _ = load_ckpt(ck)
        model_state = model.state_dict()
        filtered = OrderedDict((k,v) for k,v in state.items() if k in model_state and tuple(model_state[k].shape)==tuple(v.shape))
        _ = model.load_state_dict(filtered, strict=False)
        # Gallery = train excluding this fold, exclude new_whale (mirrors training OOF)
        gal_df = folds_df[(folds_df['fold']!=fold_i) & (folds_df['Id']!='new_whale')].copy()
        gal_feats, _, gal_ids = extract_feats_bnneck(model, gal_df, tta_hflip=tta_hflip)
        gal_feats = l2_normalize(gal_feats, axis=1).astype('float32')
        gallery_by_fold[fold_i] = (gal_feats, list(gal_ids))
        # Validation features for THIS fold with THIS model
        va_df = folds_df[folds_df['fold']==fold_i].copy()
        feats_va, _, labs_va = extract_feats_bnneck(model, va_df, tta_hflip=tta_hflip)
        val_feats_by_fold[fold_i] = l2_normalize(feats_va, axis=1).astype('float32')
        val_labels_by_fold[fold_i] = list(labs_va)
        gc.collect()
    return gallery_by_fold, val_feats_by_fold, val_labels_by_fold

def vote_rank_ids(nei_ids, nei_sims, alpha=ALPHA):
    scores = {}
    for lab, sim in zip(nei_ids, nei_sims):
        scores[lab] = scores.get(lab, 0.0) + math.exp(alpha * float(sim))
    ordered = [k for k,_ in sorted(scores.items(), key=lambda x:-x[1])]
    return ordered

def oof_score_image_gallery(gallery_by_fold, val_feats_by_fold, val_labels_by_fold, dba_M=8, dba_lambda=0.3, qe_L=8, qe_lambda=0.3, conditional_qe=True, tau_grid=np.arange(0.65,0.81,0.01)):
    taus_best = []
    fold_scores = []
    for f in sorted(gallery_by_fold.keys()):
        gal_feats, gal_ids = gallery_by_fold[f]
        # Apply DBA smoothing on gallery embeddings (safe on vectors)
        gal_mat = dba_smooth(gal_feats, M=dba_M, lam=dba_lambda) if (dba_M>0 and dba_lambda>0) else gal_feats
        index = build_index_ip(gal_mat)
        feats = val_feats_by_fold[f]
        labs = val_labels_by_fold[f]
        # Calibrate tau per-fold (QE OFF for calibration)
        best_tau, best_sc = None, -1.0
        sims_all, idxs_all = index.search(feats, min(100, index.ntotal))
        truths = [lab if lab in set(gal_ids) else 'new_whale' for lab in labs]
        for tau in tau_grid:
            preds = []
            for i in range(len(labs)):
                nei_idx = idxs_all[i]; nei_sims = sims_all[i]
                nei_lab = [gal_ids[j] for j in nei_idx]
                ordered = vote_rank_ids(nei_lab, nei_sims, alpha=ALPHA)
                s1 = float(nei_sims[0]) if len(nei_sims)>0 else -1.0
                s2 = float(nei_sims[1]) if len(nei_sims)>1 else s1
                top5 = []
                if (len(nei_sims)>0 and s1 < tau) or (len(nei_sims)>1 and (s1 - s2) < 0.03) or (len(nei_sims)>1 and (s1/max(s2,1e-6) < 1.06)):
                    top5.append('new_whale')
                for lab2 in ordered:
                    if lab2 not in top5: top5.append(lab2)
                    if len(top5)==5: break
                if len(top5)<5: top5 += ['new_whale']*(5-len(top5))
                preds.append(top5[:5])
            sc = map5_score(preds, truths)
            if sc > best_sc:
                best_sc, best_tau = sc, float(tau)
        taus_best.append(best_tau if best_tau is not None else float(np.median(tau_grid)))
        # QE for final scoring (optional)
        if qe_L>0 and qe_lambda>0:
            feats_q = query_expansion(index, gal_mat, feats, L=qe_L, lam=qe_lambda, conditional_tau=(best_tau if conditional_qe else None))
            sims_all, idxs_all = index.search(feats_q, min(100, index.ntotal))
        # Final scoring with best tau
        preds = []
        for i in range(len(labs)):
            nei_idx = idxs_all[i]; nei_sims = sims_all[i]
            nei_lab = [gal_ids[j] for j in nei_idx]
            ordered = vote_rank_ids(nei_lab, nei_sims, alpha=ALPHA)
            s1 = float(nei_sims[0]) if len(nei_sims)>0 else -1.0
            s2 = float(nei_sims[1]) if len(nei_sims)>1 else s1
            top5 = []
            if (len(nei_sims)>0 and s1 < best_tau) or (len(nei_sims)>1 and (s1 - s2) < 0.03) or (len(nei_sims)>1 and (s1/max(s2,1e-6) < 1.06)):
                top5.append('new_whale')
            for lab2 in ordered:
                if lab2 not in top5: top5.append(lab2)
                if len(top5)==5: break
            if len(top5)<5: top5 += ['new_whale']*(5-len(top5))
            preds.append(top5[:5])
        sc_final = map5_score(preds, truths)
        fold_scores.append(sc_final)
        print(f'[OOF-fold-img] f={f} tau*={best_tau:.3f} score={sc_final:.4f}')
    mean_sc = float(np.mean(fold_scores)) if len(fold_scores)>0 else -1.0
    tau_global = float(np.median(taus_best)) if len(taus_best)>0 else float(np.median(tau_grid))
    print(f'[OOF-img] mean MAP@5={mean_sc:.4f} | median tau={tau_global:.3f}')
    return mean_sc, tau_global

def grid_tune_and_submit(ckpts, dba_M_grid=(5,8,12), dba_lambda_grid=(0.2,0.3,0.4), qe_L_grid=(5,8,12), qe_lambda_grid=(0.2,0.3), conditional_qe=True):
    t0=time.time()
    gallery_by_fold, val_feats_by_fold, val_labels_by_fold = build_fold_gallery_and_val(ckpts, tta_hflip=True)
    print('[Tune] Prepared per-fold image-level galleries and val feats in', f'{time.time()-t0:.1f}s')
    best = (-1.0, None)
    for M in dba_M_grid:
        for lam in dba_lambda_grid:
            for L in qe_L_grid:
                for qel in qe_lambda_grid:
                    sc, tau = oof_score_image_gallery(gallery_by_fold, val_feats_by_fold, val_labels_by_fold, dba_M=M, dba_lambda=lam, qe_L=L, qe_lambda=qel, conditional_qe=conditional_qe)
                    print(f'[Tune] DBA(M={M},lam={lam}) QE(L={L},lam={qel},cond={conditional_qe}) -> OOF {sc:.4f} tau {tau:.3f}')
                    if sc > best[0]:
                        best = (sc, (M, lam, L, qel, tau))
    best_sc, (M, lam, L, qel, tau) = best
    print(f'[Tune] Best OOF {best_sc:.4f} with DBA(M={M},lam={lam}) QE(L={L},lam={qel}) tau={tau:.3f}')
    # Final inference still uses blended prototype pipeline for speed; params carried over
    out_csv = run_full_inference(ckpts, out_csv='submission.csv', tta_hflip=True, enable_dba=True, dba_M=M, dba_lambda=lam, enable_qe=True, qe_L=L, qe_lambda=qel)
    print('[Tune] Submission written to', out_csv)
    return out_csv, {'OOF': best_sc, 'DBA_M': M, 'DBA_lambda': lam, 'QE_L': L, 'QE_lambda': qel, 'tau': tau}

In [None]:
# Run DBA/QE grid tuning (reduced grid for sanity) and regenerate submission with best params
ckpts = [f'model_fold{k}_best.pth' for k in range(5)]
print('[Run-Tune] Starting DBA/QE grid (reduced) using per-fold OOF...')
t0=time.time()
out_csv, best_info = grid_tune_and_submit(
    ckpts,
    dba_M_grid=(8,),
    dba_lambda_grid=(0.3,),
    qe_L_grid=(8,),
    qe_lambda_grid=(0.3,),
    conditional_qe=True
)
print('[Run-Tune] Done in', f'{time.time()-t0:.1f}s', 'Best:', best_info)
from pathlib import Path
print('submission.csv exists?', Path(out_csv).exists(), 'size:', Path(out_csv).stat().st_size if Path(out_csv).exists() else -1)

In [None]:
# Run per-fold score-level fusion inference (DBA + conditional QE + k-reciprocal rerank) with leak-free calibration
ckpts = [f'model_fold{k}_best.pth' for k in range(5)]
print('[Infer-Fusion] Starting per-fold score-level fusion (DBA+QE+rerank, calibrated tau)...')
from pathlib import Path
import time
t0=time.time()
out_csv = run_infer_img_gallery_score_fusion(
    ckpt_paths=ckpts,
    out_csv='submission.csv',
    cache_dir='cache_feats',
    tta_hflip=True,
    enable_dba=True, dba_M=8, dba_lambda=0.3,
    K=200, alpha=12.0,
    margin=0.0, ratio=0.0,   # gates OFF for stable calibration
    tau_offset=-0.0020,      # S2: target NH@1 ~55–58%
    enable_qe=True, qe_L=10, qe_lambda=0.35,
    enable_rerank=True, rerank_k1=20, rerank_k2=6, rerank_lam=0.25
)
print('[Infer-Fusion] Done in', f'{time.time()-t0:.1f}s', '->', out_csv)
print('submission.csv exists?', Path(out_csv).exists(), 'size:', Path(out_csv).stat().st_size if Path(out_csv).exists() else -1)

[Infer-Fusion] Starting per-fold score-level fusion (DBA+QE+rerank, calibrated tau)...


  ckpt = torch.load(ckpt_path, map_location='cpu')


[DEBUG fold 0] sims range: [0.827, 0.997] mean=0.952 K=200 gal=5201 val=1448


[DEBUG fold 0] sims range: [0.964, 1.000] mean=0.989 K=200 gal=5201 val=1448


In [10]:
# Per-fold score-level fusion with leak-free calibration + caching + k-reciprocal re-ranking
from collections import OrderedDict, defaultdict
from pathlib import Path
import numpy as np, pandas as pd, math, time, os, gc

def _load_cached(npy_path, n_rows=None):
    if not Path(npy_path).exists(): return None
    arr = np.load(npy_path)
    if (n_rows is not None) and (arr.shape[0] != n_rows): return None
    return arr.astype('float32')

def _save_cached(npy_path, arr):
    np.save(npy_path, arr.astype('float32'))

def _get_feats_cached(model, df, cache_path, tta_hflip=True):
    arr = _load_cached(cache_path, len(df))
    if arr is not None: return arr
    feats, _, _ = extract_feats_bnneck(model, df, tta_hflip=tta_hflip)
    _save_cached(cache_path, feats)
    return feats

def _accumulate_scores(nei_ids, nei_sims, alpha, store: dict):
    for lab, sim in zip(nei_ids, nei_sims):
        store[lab] = store.get(lab, 0.0) + math.exp(alpha * float(sim))

def _accumulate_max_sims(nei_ids, nei_sims, store: dict):
    # Track per-ID maximum similarity across folds
    for lab, sim in zip(nei_ids, nei_sims):
        cur = store.get(lab, -1.0)
        if float(sim) > cur:
            store[lab] = float(sim)

# k-reciprocal re-ranking helpers
def _precompute_gal_nn(index, gal_feats, k2):
    _, idxs_g = index.search(gal_feats, min(k2+1, index.ntotal))
    return idxs_g

def _rerank_k_reciprocal(sims, idxs, gal_nn_idx, k1=20, k2=6, lam=0.3):
    if sims.size == 0: return sims, idxs
    sims_new = sims.copy()
    nq, K = sims.shape
    for i in range(nq):
        topk = idxs[i][:k1]
        topk_set = set(topk.tolist())
        for j in range(K):
            gid = idxs[i][j]
            neigh = gal_nn_idx[gid][1:1+k2]
            overlap = len(topk_set.intersection(neigh.tolist()))
            bonus = overlap / float(max(1, k2))
            sims_new[i][j] = (1.0 - lam) * sims[i][j] + lam * bonus
        order = np.argsort(-sims_new[i])
        sims_new[i] = sims_new[i][order]
        idxs[i] = idxs[i][order]
    return sims_new, idxs

def _gate_from_fused_max(maxsim_dict: dict, tau: float, margin: float, ratio: float) -> bool:
    # Fixed gating: proper s2 handling and clamping; avoid ambiguity when only one candidate
    if not maxsim_dict:
        return True
    vals = [max(0.0, min(1.0, float(v))) for v in maxsim_dict.values()]
    if not vals:
        return True
    vals.sort(reverse=True)
    s1 = vals[0]
    if s1 < tau:
        return True
    if len(vals) > 1:
        s2 = vals[1]
        if (s1 - s2) < margin:
            return True
        if (s1 / max(s2, 1e-6)) < ratio:
            return True
    return False

def calibrate_tau_score_fusion(ckpt_paths, folds_df, train_all, cache_dir, alpha=18.0, dba_M=12, dba_lambda=0.4, margin=0.02, ratio=1.04, tau_grid=None, tta_hflip=True, K=200, enable_rerank=True, rerank_k1=20, rerank_k2=6, rerank_lam=0.3):
    os.makedirs(cache_dir, exist_ok=True)
    n_classes = train_all['Id'].nunique()
    model = Net(n_classes=n_classes).to(device).eval().to(memory_format=torch.channels_last)

    gal_labels_all = train_all['Id'].tolist()
    gal_folds_all = train_all['fold'].to_numpy()
    taus_best = []
    for f in sorted(folds_df['fold'].unique()):
        va_df = folds_df[folds_df['fold']==f].copy().reset_index(drop=True)
        n_q = len(va_df)
        fused_scores = [defaultdict(float) for _ in range(n_q)]
        fused_maxsim = [dict() for _ in range(n_q)]  # per-query per-ID max similarity

        sub_mask = (gal_folds_all != f)
        sub_labels = np.array(gal_labels_all, dtype=object)[sub_mask].tolist()
        truths = [lab if lab in set(sub_labels) else 'new_whale' for lab in va_df['Id']]

        for k, ck in enumerate(ckpt_paths):
            if not os.path.exists(ck): continue
            state, _ = load_ckpt(ck)
            filtered = OrderedDict((kk,vv) for kk,vv in state.items() if kk in model.state_dict() and tuple(model.state_dict()[kk].shape)==tuple(vv.shape))
            _ = model.load_state_dict(filtered, strict=False)

            gal_all = _get_feats_cached(model, train_all, f'{cache_dir}/gal_feats_k{k}.npy', tta_hflip=tta_hflip)
            gal_all = l2_normalize(gal_all, axis=1); gal = gal_all[sub_mask]
            assert gal.shape[0] == len(sub_labels)
            if dba_M>0 and dba_lambda>0: gal = dba_smooth(gal, M=dba_M, lam=dba_lambda)
            index = build_index_ip(gal)
            gal_nn_idx = _precompute_gal_nn(index, gal, k2=rerank_k2)

            val_feats = _get_feats_cached(model, va_df, f'{cache_dir}/val_feats_k{k}_f{f}.npy', tta_hflip=tta_hflip)
            val_feats = l2_normalize(val_feats, axis=1)
            sims, idxs = index.search(val_feats, min(K, index.ntotal))
            # Debug stats for similarity distribution
            try:
                print(f"[DEBUG fold {f}] sims range: [{sims.min():.3f}, {sims.max():.3f}] mean={sims.mean():.3f} K={sims.shape[1]} gal={gal.shape[0]} val={val_feats.shape[0]}")
            except Exception:
                pass
            if enable_rerank:
                sims, idxs = _rerank_k_reciprocal(sims, idxs, gal_nn_idx, k1=rerank_k1, k2=rerank_k2, lam=rerank_lam)
            for qi in range(n_q):
                ns = sims[qi]; ni = idxs[qi]
                if ns.size == 0: continue
                nei_ids = [sub_labels[j] for j in ni]
                _accumulate_scores(nei_ids, ns, alpha, fused_scores[qi])
                _accumulate_max_sims(nei_ids, ns, fused_maxsim[qi])

        # Build dynamic tau_grid from fused s1 distribution for this fold
        s1_vals = np.array([max(v.values()) if len(v)>0 else -1.0 for v in fused_maxsim], dtype=np.float32)
        v = s1_vals[s1_vals > 0]
        if v.size > 0:
            p10, p90 = np.quantile(v, [0.10, 0.90])
            lo = max(0.90, float(p10) - 0.01)
            hi = min(0.999, float(p90) + 0.01)
            tau_grid_fold = np.arange(lo, hi + 1e-9, 0.001)
        else:
            tau_grid_fold = np.arange(0.95, 0.999, 0.001)

        best_tau, best_sc = None, -1.0
        for tau in tau_grid_fold:
            preds = []
            for qi in range(n_q):
                ordered = [k for k,_ in sorted(fused_scores[qi].items(), key=lambda kv: -kv[1])]
                # ensure uniqueness for calibration predictions as well
                seen = set()
                ordered_unique = []
                for lab in ordered:
                    if lab not in seen:
                        ordered_unique.append(lab); seen.add(lab)
                top5 = []
                # Use the SAME gate as inference (tau + margin/ratio)
                if _gate_from_fused_max(fused_maxsim[qi], tau, margin=margin, ratio=ratio):
                    top5.append('new_whale')
                for lab in ordered_unique:
                    if lab not in top5: top5.append(lab)
                    if len(top5)==5: break
                if len(top5) < 5:
                    top5 += ['__DUMMY__'] * (5 - len(top5))  # avoid rewarding gate with 'new_whale'
                preds.append(top5)
            sc = map5_score(preds, truths)
            if sc > best_sc:
                best_sc, best_tau = sc, float(tau)
        taus_best.append(float(best_tau))
        print(f'[Fusion-Calib] fold {f}: tau*={best_tau:.3f} oof_map5={best_sc:.4f}')
    tau_global = float(np.median(taus_best))
    print(f'[Fusion-Calib] median tau={tau_global:.3f}')
    return tau_global

def run_infer_img_gallery_score_fusion(ckpt_paths, out_csv='submission.csv', cache_dir='cache_feats', tta_hflip=True, enable_dba=True, dba_M=12, dba_lambda=0.4, K=200, alpha=18.0, margin=0.02, ratio=1.04, tau_offset=0.0, enable_qe=True, qe_L=8, qe_lambda=0.3, enable_rerank=True, rerank_k1=20, rerank_k2=6, rerank_lam=0.3):
    t0=time.time()
    os.makedirs(cache_dir, exist_ok=True)
    folds_df = pd.read_csv('folds_grouped.csv')
    folds_df['image_path'] = (Path('train')/folds_df['Image']).astype(str)
    train_all = folds_df[folds_df['Id']!='new_whale'].copy().reset_index(drop=True)

    ss = pd.read_csv('sample_submission.csv')
    test_df = pd.DataFrame({'Image': ss['Image']})
    test_df['image_path'] = (Path('test')/test_df['Image']).astype(str)

    n_classes = train_all['Id'].nunique()
    model = Net(n_classes=n_classes).to(device).eval().to(memory_format=torch.channels_last)

    # Leak-free tau calibration on fused setup (DBA ON, QE OFF)
    tau = calibrate_tau_score_fusion(ckpt_paths, folds_df, train_all, cache_dir, alpha=alpha, dba_M=dba_M, dba_lambda=dba_lambda, margin=margin, ratio=ratio, tta_hflip=tta_hflip, K=K, enable_rerank=enable_rerank, rerank_k1=rerank_k1, rerank_k2=rerank_k2, rerank_lam=rerank_lam)
    tau += float(tau_offset)

    # Final inference: per-fold retrieval + score fusion
    gal_labels_full = train_all['Id'].tolist()
    # Precompute fallback IDs (most frequent) to ensure 5 unique labels
    fallback_ids = train_all['Id'].value_counts().index.tolist()
    fused_scores = [defaultdict(float) for _ in range(len(test_df))]
    fused_maxsim = [dict() for _ in range(len(test_df))]  # per-query per-ID max similarity

    for k, ck in enumerate(ckpt_paths):
        if not os.path.exists(ck): continue
        state, _ = load_ckpt(ck)
        filtered = OrderedDict((kk,vv) for kk,vv in state.items() if kk in model.state_dict() and tuple(model.state_dict()[kk].shape)==tuple(vv.shape))
        _ = model.load_state_dict(filtered, strict=False)

        gal_feats = _get_feats_cached(model, train_all, f'{cache_dir}/gal_feats_k{k}.npy', tta_hflip=tta_hflip)
        gal_feats = l2_normalize(gal_feats, axis=1)
        if enable_dba and dba_M>0 and dba_lambda>0:
            gal_feats = dba_smooth(gal_feats, M=dba_M, lam=dba_lambda)
        index = build_index_ip(gal_feats)
        gal_nn_idx = _precompute_gal_nn(index, gal_feats, k2=rerank_k2)

        te_feats = _get_feats_cached(model, test_df.assign(Id='dummy'), f'{cache_dir}/test_feats_k{k}.npy', tta_hflip=tta_hflip)
        te_feats = l2_normalize(te_feats, axis=1)
        # Conditional QE only for final inference (not during calibration)
        if enable_qe and qe_L>0 and qe_lambda>0:
            te_feats = query_expansion(index, gal_feats, te_feats, L=qe_L, lam=qe_lambda, conditional_tau=tau)

        sims, idxs = index.search(te_feats, min(K, index.ntotal))
        if enable_rerank:
            sims, idxs = _rerank_k_reciprocal(sims, idxs, gal_nn_idx, k1=rerank_k1, k2=rerank_k2, lam=rerank_lam)
        for qi in range(len(test_df)):
            ns = sims[qi]; ni = idxs[qi]
            if ns.size == 0: continue
            nei_ids = [gal_labels_full[j] for j in ni]
            _accumulate_scores(nei_ids, ns, alpha, fused_scores[qi])
            _accumulate_max_sims(nei_ids, ns, fused_maxsim[qi])
        gc.collect()

    pred_rows = []
    new_whale_first = 0
    s1_list = []
    for qi in range(len(test_df)):
        # Build unique ordered labels by fused score
        ordered = [lab for lab,_ in sorted(fused_scores[qi].items(), key=lambda kv: -kv[1])]
        seen = set()
        ordered_unique = []
        for lab in ordered:
            if lab not in seen:
                ordered_unique.append(lab); seen.add(lab)
        top5 = []
        if _gate_from_fused_max(fused_maxsim[qi], tau, margin, ratio) and 'new_whale' not in top5:
            top5.append('new_whale')
        for lab in ordered_unique:
            if lab not in top5:
                top5.append(lab)
            if len(top5)==5: break
        # Pad with most frequent training IDs, ensuring uniqueness; only one 'new_whale'
        if len(top5) < 5:
            for fid in fallback_ids:
                if fid not in top5:
                    top5.append(fid)
                    if len(top5) == 5: break
        pred_rows.append(' '.join(top5[:5]))
        if len(top5)>0 and top5[0]=='new_whale': new_whale_first += 1
        # diagnostics: fused top1 similarity
        if len(fused_maxsim[qi])>0:
            s1_list.append(max(fused_maxsim[qi].values()))
        else:
            s1_list.append(-1.0)

    # Diagnostics
    nh_rate = new_whale_first / max(1, len(test_df))
    try:
        s1_arr = np.array(s1_list, dtype=np.float32)
        q25, q50, q75 = np.quantile(s1_arr, 0.25), np.quantile(s1_arr, 0.50), np.quantile(s1_arr, 0.75)
        print(f'[Fusion] new_whale@1 rate={nh_rate*100:.1f}% | fused s1 q25/q50/q75: {q25:.3f}/{q50:.3f}/{q75:.3f}')
    except Exception as e:
        print(f'[Fusion] new_whale@1 rate={nh_rate*100:.1f}% | diag error:', e)

    sub = pd.DataFrame({'Image': ss['Image'], 'Id': pred_rows})
    sub.to_csv(out_csv, index=False)
    print(f'[Fusion] Wrote {out_csv} shape {sub.shape} elapsed {time.time()-t0:.1f}s')
    return out_csv

In [49]:
# Quick diagnostics on submission: new_whale@1 rate and basic checks
import pandas as pd, numpy as np
from pathlib import Path
sub_path = Path('submission.csv')
assert sub_path.exists(), 'submission.csv not found'
sub = pd.read_csv(sub_path)
def first_label(s):
    try:
        return str(s).split()[0]
    except Exception:
        return ''
firsts = sub['Id'].map(first_label)
nh_rate = (firsts == 'new_whale').mean() * 100.0
print(f'new_whale@1 rate: {nh_rate:.2f}%')
print('Unique first labels:', firsts.nunique())
print('Head:')
print(sub.head().to_string(index=False))

new_whale@1 rate: 45.67%
Unique first labels: 736
Head:
       Image                                                Id
00087b01.jpg w_d9aab0a w_da2efe0 w_fea7fe6 w_73b705e w_17ee910
0014cfdf.jpg w_0e4ef50 w_fe8233d w_ea6651e w_823fcbb w_a74742c
0035632e.jpg w_95874a5 w_da2efe0 w_3c304db w_6e8486d w_8c1e2e4
004c5fb9.jpg new_whale w_17ee910 w_bb2d34d w_95874a5 w_8c1e2e4
00863b8c.jpg w_a646643 w_d19a884 w_1eafe46 w_b0e05b1 w_d36f58c


In [None]:
# Image-gallery pipeline inference (DBA, QE OFF) with corrected tau grid and tau_offset=0.0 (per expert backup plan)
from pathlib import Path
ckpts = [f'model_fold{k}_best.pth' for k in range(5)]
out_csv = run_full_inference_image_gallery(
    ckpt_paths=ckpts,
    out_csv='submission.csv',
    tta_hflip=True,
    enable_dba=True, dba_M=5, dba_lambda=0.2,
    enable_qe=False,
    tau_offset=0.0,
    margin=0.0, ratio=0.0
)
print('submission.csv exists?', Path(out_csv).exists())
print('submission.csv size:', Path(out_csv).stat().st_size if Path(out_csv).exists() else -1)

In [22]:
# Image-gallery inference v2 with explicit tau_grid override and diagnostics
import numpy as np, math, time, gc, os, pandas as pd, torch
from collections import OrderedDict
from pathlib import Path

def run_full_inference_image_gallery_v2(
    ckpt_paths,
    out_csv='submission.csv',
    tta_hflip=True,
    enable_dba=True, dba_M=5, dba_lambda=0.2,
    enable_qe=False, qe_L=0, qe_lambda=0.0,
    tau_offset=0.0,
    margin=0.0, ratio=0.0,
    tau_grid_override=None,
):
    t0_all = time.time()
    folds_df = pd.read_csv('folds_grouped.csv')
    folds_df['image_path'] = (Path('train')/folds_df['Image']).astype(str)
    train_all = folds_df[folds_df['Id']!='new_whale'].copy()
    ids_all = sorted(train_all['Id'].unique().tolist())
    n_classes = len(ids_all)
    model = Net(n_classes=n_classes).to(device).eval().to(memory_format=torch.channels_last)
    ss = pd.read_csv('sample_submission.csv')
    test_df = pd.DataFrame({'Image': ss['Image']})
    test_df['image_path'] = (Path('test')/test_df['Image']).astype(str)

    gal_feat_sum = None
    gal_labels = train_all['Id'].tolist()
    gal_paths = train_all['image_path'].tolist()
    test_feat_sum = None
    val_feat_sums = {}
    val_labels_per_fold = {}
    n_loaded = 0
    for fold_i, ck in enumerate(ckpt_paths):
        if not os.path.exists(ck):
            print(f'[Infer-IMG-v2] Skip missing ckpt: {ck}')
            continue
        state, _ = load_ckpt(ck)
        model_state = model.state_dict()
        filtered = OrderedDict((k,v) for k,v in state.items() if k in model_state and tuple(model_state[k].shape)==tuple(v.shape))
        _ = model.load_state_dict(filtered, strict=False)
        feats_gal, _, _ = extract_feats_bnneck(model, train_all, tta_hflip=tta_hflip)
        gal_feat_sum = feats_gal.astype('float32') if gal_feat_sum is None else gal_feat_sum + feats_gal.astype('float32')
        feats_te, _, _ = extract_feats_bnneck(model, test_df.assign(Id='dummy'), tta_hflip=tta_hflip)
        test_feat_sum = feats_te.astype('float32') if test_feat_sum is None else test_feat_sum + feats_te.astype('float32')
        va_df = folds_df[folds_df['fold']==fold_i].copy()
        feats_va, _, labs_va = extract_feats_bnneck(model, va_df, tta_hflip=tta_hflip)
        if fold_i not in val_feat_sums:
            val_feat_sums[fold_i] = feats_va.astype('float32')
            val_labels_per_fold[fold_i] = list(labs_va)
        else:
            val_feat_sums[fold_i] += feats_va.astype('float32')
        n_loaded += 1
        gc.collect()
    if n_loaded == 0:
        raise RuntimeError('No checkpoints processed; cannot run inference.')
    gal_mat_raw = l2_normalize(gal_feat_sum / n_loaded, axis=1).astype('float32')
    test_mat = l2_normalize(test_feat_sum / n_loaded, axis=1).astype('float32')
    for f in val_feat_sums.keys():
        val_feat_sums[f] = l2_normalize(val_feat_sums[f] / n_loaded, axis=1).astype('float32')
    assert gal_mat_raw.shape[0] == len(gal_labels)
    # Sanity: L2 norms
    gal_norms = np.linalg.norm(gal_mat_raw, axis=1)
    te_norms = np.linalg.norm(test_mat, axis=1)
    assert np.max(np.abs(gal_norms - 1.0)) < 1e-3, f'Gallery not L2-normalized: max|norm-1|={np.max(np.abs(gal_norms-1.0)):.3e}'
    assert np.max(np.abs(te_norms - 1.0)) < 1e-3, f'Test not L2-normalized: max|norm-1|={np.max(np.abs(te_norms-1.0)):.3e}'

    print('[Infer-IMG-v2] Starting leak-free per-fold tau calibration...')
    train_all_idx = train_all.reset_index(drop=True)
    gal_labels_arr = np.array(gal_labels, dtype=object)
    gallery_fold_indices = train_all_idx['fold'].values
    taus_best = []
    fold_scores = []
    for f in sorted(val_feat_sums.keys()):
        sub_mask = (gallery_fold_indices != f)
        gal_sub = gal_mat_raw[sub_mask]
        labels_sub = gal_labels_arr[sub_mask].tolist()
        gal_sub_dba = dba_smooth(gal_sub, M=dba_M, lam=dba_lambda) if enable_dba else gal_sub
        index_sub = build_index_ip(gal_sub_dba)
        q = val_feat_sums[f]
        labs = val_labels_per_fold[f]
        sims_all, idxs_all = index_sub.search(q, min(50, index_sub.ntotal))
        truths = [lab if lab in set(labels_sub) else 'new_whale' for lab in labs]
        # Diagnostics for s1 scale + grid selection
        s1 = sims_all[:,0] if sims_all.size>0 else np.array([], dtype=np.float32)
        if s1.size > 0:
            q25, q50, q75 = np.quantile(s1, [0.25, 0.50, 0.75])
        else:
            q25=q50=q75=0.0
        if tau_grid_override is not None:
            tau_grid_fold = np.array(list(tau_grid_override), dtype=float)
        else:
            if s1.size > 0:
                p10, p90 = np.quantile(s1, [0.10, 0.90])
                lo = float(max(0.0, min(1.0, p10 - 0.01)))
                hi = float(max(0.0, min(1.0, p90 + 0.01)))
                if hi <= lo:
                    lo, hi = 0.55, 0.75
                tau_grid_fold = np.arange(lo, hi + 1e-9, 0.001)
            else:
                tau_grid_fold = np.arange(0.55, 0.75, 0.005)
        print(f"[OOF-fold-img-v2] f={f} s1 q25/q50/q75={q25:.3f}/{q50:.3f}/{q75:.3f} | grid={tau_grid_fold[0]:.3f}-{tau_grid_fold[-1]:.3f} (n={len(tau_grid_fold)})", flush=True)
        best_tau_f, best_sc_f = None, -1.0
        for tau in tau_grid_fold:
            preds = []
            for i in range(q.shape[0]):
                nei_idx = idxs_all[i]; nei_sims = sims_all[i]
                nei_ids = [labels_sub[j] for j in nei_idx]
                # vote-by-ID
                scores_map = {}
                for lab2, sim in zip(nei_ids, nei_sims):
                    scores_map[lab2] = scores_map.get(lab2, 0.0) + math.exp(ALPHA * float(sim))
                ordered = [k for k,_ in sorted(scores_map.items(), key=lambda x: -x[1])]
                s1i = float(nei_sims[0]) if len(nei_sims)>0 else -1.0
                s2i = float(nei_sims[1]) if len(nei_sims)>1 else s1i
                top5 = []
                if (len(nei_sims)>0 and s1i < tau) or (len(nei_sims)>1 and (s1i - s2i) < margin) or (len(nei_sims)>1 and (s1i/max(s2i,1e-6)) < ratio):
                    top5.append('new_whale')
                for lab2 in ordered:
                    if lab2 not in top5: top5.append(lab2)
                    if len(top5)==5: break
                if len(top5)<5: top5 += ['__DUMMY__']*(5-len(top5))  # dummy padding during calib
                preds.append(top5[:5])
            sc = map5_score(preds, truths)
            if sc > best_sc_f:
                best_sc_f, best_tau_f = sc, float(tau)
        taus_best.append(best_tau_f if best_tau_f is not None else float(np.median(tau_grid_fold)))
        fold_scores.append(best_sc_f if best_sc_f is not None else 0.0)
        print(f'[OOF-fold-img-v2] f={f} tau*={taus_best[-1]:.3f} score={fold_scores[-1]:.4f}')

    tau_global = float(np.median(taus_best))
    mean_oof = float(np.mean(fold_scores))
    tau_global = float(np.clip(tau_global + float(tau_offset), 0.0, 1.0))
    print(f'[OOF-img-v2] mean MAP@5={mean_oof:.4f} | median tau={tau_global:.3f} (after offset={float(tau_offset):.4f})')

    # Build final gallery and perform final inference WITHOUT global override and WITHOUT test quantile gate
    gal_mat = dba_smooth(gal_mat_raw, M=dba_M, lam=dba_lambda) if enable_dba else gal_mat_raw
    index = build_index_ip(gal_mat)

    # QE for final inference (conditional if enabled)
    if enable_qe and qe_L>0 and qe_lambda>0:
        test_q = query_expansion(index, gal_mat, test_mat, L=qe_L, lam=qe_lambda, conditional_tau=tau_global)
    else:
        test_q = test_mat

    sims_all, idxs_all = index.search(test_q, min(50, index.ntotal))
    pred_rows = []
    new_first = 0
    for i in range(test_q.shape[0]):
        nei_idx = idxs_all[i]; nei_sims = sims_all[i]
        nei_ids = [gal_labels[j] for j in nei_idx]
        scores_map = {}
        for lab, sim in zip(nei_ids, nei_sims):
            scores_map[lab] = scores_map.get(lab, 0.0) + math.exp(ALPHA * float(sim))
        ordered = [k for k,_ in sorted(scores_map.items(), key=lambda x:-x[1])]
        s1 = float(nei_sims[0]) if len(nei_sims)>0 else -1.0
        s2 = float(nei_sims[1]) if len(nei_sims)>1 else s1
        top5 = []
        if (len(nei_sims)>0 and s1 < tau_global) or (len(nei_sims)>1 and (s1 - s2) < margin) or (len(nei_sims)>1 and (s1/max(s2,1e-6)) < ratio):
            top5.append('new_whale')
        for lab2 in ordered:
            if lab2 not in top5: top5.append(lab2)
            if len(top5)==5: break
        if len(top5)<5: top5 += ['new_whale']*(5-len(top5))
        if len(top5)>0 and top5[0]=='new_whale': new_first += 1
        pred_rows.append(' '.join(top5[:5]))
    nh_rate = 100.0 * new_first / max(1, len(test_df))
    print(f'[Infer-IMG-v2] Test new_whale@1={nh_rate:.2f}% | tau={tau_global:.3f}')

    sub = pd.DataFrame({'Image': ss['Image'], 'Id': pred_rows})
    sub.to_csv(out_csv, index=False)
    print('Wrote (IMG-v2) ', out_csv, 'shape', sub.shape, 'elapsed', f'{time.time()-t0_all:.1f}s')
    print(sub.head().to_string(index=False))
    return out_csv

# Example run (execute in a separate cell):
# out_csv = run_full_inference_image_gallery_v2(
#     ckpt_paths=[f'model_fold{k}_best.pth' for k in range(5)],
#     out_csv='submission.csv',
#     tta_hflip=True,
#     enable_dba=True, dba_M=5, dba_lambda=0.2,
#     enable_qe=False,
#     tau_offset=0.0,
#     margin=0.0, ratio=0.0,
#     tau_grid_override=np.arange(0.55, 0.75, 0.005),
# )

In [35]:
# Execute image-gallery v2 with expert backup params (stabilized gating)
ckpts = [f'model_fold{k}_best.pth' for k in range(5)]
out_csv = run_full_inference_image_gallery_v2(
    ckpt_paths=ckpts,
    out_csv='submission.csv',
    tta_hflip=True,
    enable_dba=True, dba_M=8, dba_lambda=0.3,
    enable_qe=False, qe_L=0, qe_lambda=0.0,
    tau_offset=-0.05,
    margin=0.02, ratio=1.04,
    tau_grid_override=np.arange(0.35, 0.55, 0.005),
)
from pathlib import Path
print('submission.csv exists?', Path(out_csv).exists())
print('submission.csv size:', Path(out_csv).stat().st_size if Path(out_csv).exists() else -1)

  ckpt = torch.load(ckpt_path, map_location='cpu')


FE bi 20, 7.7s


FE bi 40, 13.9s


FE bi 60, 20.0s


FE bi 80, 26.2s


FE bi 100, 32.3s


FE bi 20, 7.3s


FE bi 40, 13.5s


FE bi 20, 7.4s


  ckpt = torch.load(ckpt_path, map_location='cpu')


FE bi 20, 7.8s


FE bi 40, 14.0s


FE bi 60, 20.2s


FE bi 80, 26.4s


FE bi 100, 32.6s


FE bi 20, 7.4s


FE bi 40, 13.6s


FE bi 20, 7.4s


  ckpt = torch.load(ckpt_path, map_location='cpu')


FE bi 20, 7.9s


FE bi 40, 14.1s


FE bi 60, 20.3s


FE bi 80, 26.6s


FE bi 100, 32.8s


FE bi 20, 7.8s


FE bi 40, 14.1s


FE bi 20, 7.7s


  ckpt = torch.load(ckpt_path, map_location='cpu')


FE bi 20, 7.6s


FE bi 40, 13.9s


FE bi 60, 20.1s


FE bi 80, 26.4s


FE bi 100, 32.7s


FE bi 20, 7.5s


FE bi 40, 13.8s


FE bi 20, 7.9s


  ckpt = torch.load(ckpt_path, map_location='cpu')


FE bi 20, 7.7s


FE bi 40, 14.0s


FE bi 60, 20.3s


FE bi 80, 26.5s


FE bi 100, 32.8s


FE bi 20, 7.5s


FE bi 40, 13.8s


FE bi 20, 7.6s


[Infer-IMG-v2] Starting leak-free per-fold tau calibration...


[OOF-fold-img-v2] f=0 s1 q25/q50/q75=0.411/0.422/0.439 | grid=0.350-0.550 (n=41)


[OOF-fold-img-v2] f=0 tau*=0.350 score=0.5828


[OOF-fold-img-v2] f=1 s1 q25/q50/q75=0.495/0.502/0.510 | grid=0.350-0.550 (n=41)


[OOF-fold-img-v2] f=1 tau*=0.350 score=0.5838


[OOF-fold-img-v2] f=2 s1 q25/q50/q75=0.469/0.475/0.480 | grid=0.350-0.550 (n=41)


[OOF-fold-img-v2] f=2 tau*=0.350 score=0.5670


[OOF-fold-img-v2] f=3 s1 q25/q50/q75=0.387/0.415/0.437 | grid=0.350-0.550 (n=41)


[OOF-fold-img-v2] f=3 tau*=0.425 score=0.5795


[OOF-fold-img-v2] f=4 s1 q25/q50/q75=0.439/0.443/0.448 | grid=0.350-0.550 (n=41)


[OOF-fold-img-v2] f=4 tau*=0.350 score=0.5625
[OOF-img-v2] mean MAP@5=0.5751 | median tau=0.300 (after offset=-0.0500)


[Infer-IMG-v2] Test new_whale@1=99.50% | tau=0.300
Wrote (IMG-v2)  submission.csv shape (2610, 2) elapsed 292.1s
       Image                                                Id
00087b01.jpg new_whale w_d9aab0a w_3b0894d w_b942708 w_7c7a78c
0014cfdf.jpg new_whale w_3e0f25d w_a74742c w_ea6651e w_0e4ef50
0035632e.jpg new_whale w_b942708 w_da2efe0 w_511c464 w_3a7d86d
004c5fb9.jpg new_whale w_17ee910 w_da2efe0 w_b942708 w_cf00b01
00863b8c.jpg new_whale w_1eafe46 w_d36f58c w_64f3545 w_3f2a05c
submission.csv exists? True
submission.csv size: 164439


In [None]:
# Override: use high-range tau grid for prototype similarity scale (BNNeck+DBA ~0.95-1.00)
import numpy as np

def _build_dynamic_tau_grid_from_s1(s1_vals, lo_floor=0.90, hi_cap=0.999, pad=0.01, step=0.001):
    v = s1_vals[np.isfinite(s1_vals)]
    v = v[v > 0]
    if v.size > 0:
        p10, p90 = np.quantile(v, [0.10, 0.90])
        lo = max(lo_floor, float(p10) - pad)
        hi = min(hi_cap, float(p90) + pad)
        if hi > lo:
            return np.arange(lo, hi + 1e-9, step)
    return np.arange(lo_floor, hi_cap + 1e-9, step)

print('[Patch] _build_dynamic_tau_grid_from_s1 set to high-range [0.90, 0.999] with step=0.001')