# Plan to Medal: Humpback Whale Identification (MAP@5)

## 0) Environment & GPU
- Verify GPU with nvidia-smi.
- Install PyTorch cu121 stack and key libs (timm, albumentations, torchvision).
- Add constraints.txt to lock torch versions.

## 1) Data Audit
- Inspect train.csv, sample_submission.csv.
- Check image counts in train/ and test/; verify filename matching.
- Class distribution, #classes, imbalance, images per class.
- Create stratified KFold splits (by class) and persist folds.

## 2) Baseline Model
- Image classification with strong pretrained backbone (timm):
  - Start: tf_efficientnet_b0_ns or convnext_tiny, 224–256 px.
  - Loss: CrossEntropy with label smoothing (0.05).
  - Augs: flips, random resized crop, color jitter, CutMix/Mixup.
  - Optim: AdamW, OneCycle/Cosine, AMP, weight decay 1e-4.
  - 5-fold CV, early stopping; log per-fold metrics.
- Produce OOF logits and test logits.

## 3) Iterate
- Upgrade backbone (convnext_base, eva02, NFNet), increase res (384).
- Fine-tune head-last-only warmup then full.
- TTA at inference (hflip, multi-scale).
- Calibrate/Blend multiple seeds/backbones (weighted avg).

## 4) Validation Discipline
- Single saved folds; transforms fit per-fold only.
- Monitor MAP@5 on OOF; ensure no leakage.
- Cache features/logits to npy for fast iteration.

## 5) Submission
- Generate top-5 labels per image in required format.
- Verify submission.csv shape/columns and spot-check.

## Checkpoints for Expert Review
- After this plan.
- After data audit + folds.
- After first baseline OOF.
- After each major model/ensemble change.

In [2]:
# Environment check: GPU + files + quick data audit
import os, sys, subprocess, time, json, pandas as pd, numpy as np, glob
from pathlib import Path

t0 = time.time()
print('Python:', sys.version)
print('CWD:', os.getcwd())
print('Listing top-level:')
for p in Path('.').iterdir():
    mtime = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(p.stat().st_mtime))
    print('-', p, ('[DIR]' if p.is_dir() else ''), mtime)

print('\n== nvidia-smi ==')
try:
    out = subprocess.check_output(['bash','-lc','nvidia-smi || true'], text=True)
    print(out)
except Exception as e:
    print('nvidia-smi error:', e)

# Quick file counts
train_dir = Path('train')
test_dir = Path('test')
train_imgs = sorted(str(p.name) for p in train_dir.glob('*') if p.is_file()) if train_dir.exists() else []
test_imgs = sorted(str(p.name) for p in test_dir.glob('*') if p.is_file()) if test_dir.exists() else []
print(f'Train images: {len(train_imgs)} | Test images: {len(test_imgs)}')
print('Sample train imgs:', train_imgs[:5])
print('Sample test imgs:', test_imgs[:5])

# Load CSVs
train_csv = pd.read_csv('train.csv')
sample_sub = pd.read_csv('sample_submission.csv')
print('\ntrain.csv shape:', train_csv.shape)
print(train_csv.head())
print('\nsample_submission.csv shape:', sample_sub.shape)
print(sample_sub.head())

# Basic integrity checks
img_col = 'Image' if 'Image' in train_csv.columns else train_csv.columns[0]
train_fns = set(train_csv[img_col].astype(str))
if len(train_imgs) and len(train_fns):
    inter = len(set(train_imgs) & train_fns)
    print(f'Filename match train.csv vs train/ : {inter}/{len(train_fns)} present')
else:
    print('Skip filename match check (missing files or column)')

print(f'Elapsed: {time.time()-t0:.2f}s')

Python: 3.11.0rc1 (main, Aug 12 2022, 10:02:14) [GCC 11.2.0]
CWD: /var/lib/simon/agent_run_states/whale-categorization-playground-20250929-051000
Listing top-level:
- sample_submission.csv  2025-09-29 05:13:46
- .00_eda_and_planning_kernel_state.json  2025-09-29 05:18:16
- docker_run.log  2025-09-29 05:18:16
- requirements.txt  2025-09-29 05:17:55
- test [DIR] 2025-09-29 05:13:47
- submission.csv  2025-09-29 05:13:51
- train [DIR] 2025-09-29 05:13:47
- train.csv  2025-09-29 05:13:46
- agent_metadata [DIR] 2025-09-29 05:14:31
- task.txt  2025-09-29 05:13:19
- 00_eda_and_planning.ipynb  2025-09-29 05:18:12
- description.md  2025-09-29 05:13:47

== nvidia-smi ==
Mon Sep 29 05:18:16 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.144.06             Driver Version: 550.144.06     CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name             

In [3]:
# Install PyTorch cu121 stack and key libs; sanity check GPU
import os, sys, subprocess, shutil, time
from pathlib import Path

def pip(*args):
    print('> pip', ' '.join(args), flush=True)
    subprocess.run([sys.executable, '-m', 'pip', *args], check=True)

t0=time.time()
print('Starting install...')
# Uninstall any pre-existing torch stack to avoid conflicts
for pkg in ('torch','torchvision','torchaudio'):
    subprocess.run([sys.executable, '-m', 'pip', 'uninstall', '-y', pkg], check=False)

# Clean stray site dirs that can shadow correct wheels (idempotent)
for d in (
    '/app/.pip-target/torch',
    '/app/.pip-target/torchvision',
    '/app/.pip-target/torchaudio',
    '/app/.pip-target/torch-2.8.0.dist-info',
    '/app/.pip-target/torch-2.4.1.dist-info',
    '/app/.pip-target/torchvision-0.23.0.dist-info',
    '/app/.pip-target/torchvision-0.19.1.dist-info',
    '/app/.pip-target/torchaudio-2.8.0.dist-info',
    '/app/.pip-target/torchaudio-2.4.1.dist-info',
    '/app/.pip-target/torchgen',
    '/app/.pip-target/functorch',
):
    if os.path.exists(d):
        print('Removing', d)
        shutil.rmtree(d, ignore_errors=True)

# Install exact cu121 wheels
pip('install',
    '--index-url','https://download.pytorch.org/whl/cu121',
    '--extra-index-url','https://pypi.org/simple',
    'torch==2.4.1','torchvision==0.19.1','torchaudio==2.4.1')

# Freeze torch versions
Path('constraints.txt').write_text('torch==2.4.1\ntorchvision==0.19.1\ntorchaudio==2.4.1\n')

# Install non-torch deps honoring constraints
pip('install','-c','constraints.txt',
    'timm==1.0.9','albumentations==1.4.14','opencv-python-headless==4.10.0.84',
    'faiss-cpu==1.8.0','scikit-learn','pandas','numpy','matplotlib',
    '--upgrade-strategy','only-if-needed')

# Sanity check torch CUDA
import torch
print('torch:', torch.__version__, 'built CUDA:', getattr(torch.version,'cuda',None))
print('CUDA available:', torch.cuda.is_available())
assert str(getattr(torch.version,'cuda','')).startswith('12.1'), f'Wrong CUDA build: {torch.version.cuda}'
assert torch.cuda.is_available(), 'CUDA not available'
print('GPU:', torch.cuda.get_device_name(0))
print(f'Install done in {time.time()-t0:.1f}s')

Starting install...






> pip install --index-url https://download.pytorch.org/whl/cu121 --extra-index-url https://pypi.org/simple torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1




Looking in indexes: https://download.pytorch.org/whl/cu121, https://pypi.org/simple


Collecting torch==2.4.1
  Downloading https://download.pytorch.org/whl/cu121/torch-2.4.1%2Bcu121-cp311-cp311-linux_x86_64.whl (799.0 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 799.0/799.0 MB 438.9 MB/s eta 0:00:00


Collecting torchvision==0.19.1
  Downloading https://download.pytorch.org/whl/cu121/torchvision-0.19.1%2Bcu121-cp311-cp311-linux_x86_64.whl (7.1 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 7.1/7.1 MB 466.6 MB/s eta 0:00:00


Collecting torchaudio==2.4.1
  Downloading https://download.pytorch.org/whl/cu121/torchaudio-2.4.1%2Bcu121-cp311-cp311-linux_x86_64.whl (3.4 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 3.4/3.4 MB 491.7 MB/s eta 0:00:00
Collecting nvidia-nccl-cu12==2.20.5
  Downloading nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl (176.2 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 176.2/176.2 MB 219.7 MB/s eta 0:00:00


Collecting nvidia-curand-cu12==10.3.2.106
  Downloading nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl (56.5 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 56.5/56.5 MB 243.7 MB/s eta 0:00:00
Collecting nvidia-cusparse-cu12==12.1.0.106
  Downloading nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl (196.0 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 196.0/196.0 MB 236.2 MB/s eta 0:00:00


Collecting nvidia-nvtx-cu12==12.1.105
  Downloading nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (99 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 99.1/99.1 KB 453.6 MB/s eta 0:00:00
Collecting triton==3.0.0
  Downloading triton-3.0.0-1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (209.4 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 209.4/209.4 MB 366.9 MB/s eta 0:00:00


Collecting filelock
  Downloading filelock-3.19.1-py3-none-any.whl (15 kB)


Collecting nvidia-cublas-cu12==12.1.3.1
  Downloading nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 410.6/410.6 MB 185.4 MB/s eta 0:00:00


Collecting nvidia-cusolver-cu12==11.4.5.107
  Downloading nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl (124.2 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 124.2/124.2 MB 208.6 MB/s eta 0:00:00


Collecting nvidia-cuda-cupti-cu12==12.1.105
  Downloading nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 14.1/14.1 MB 191.2 MB/s eta 0:00:00
Collecting nvidia-cuda-runtime-cu12==12.1.105
  Downloading nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 823.6/823.6 KB 312.1 MB/s eta 0:00:00


Collecting nvidia-cufft-cu12==11.0.2.54
  Downloading nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl (121.6 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 121.6/121.6 MB 261.6 MB/s eta 0:00:00


Collecting nvidia-cuda-nvrtc-cu12==12.1.105
  Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 23.7/23.7 MB 227.2 MB/s eta 0:00:00


Collecting networkx
  Downloading networkx-3.5-py3-none-any.whl (2.0 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.0/2.0 MB 481.5 MB/s eta 0:00:00
Collecting sympy
  Downloading sympy-1.14.0-py3-none-any.whl (6.3 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 6.3/6.3 MB 535.3 MB/s eta 0:00:00


Collecting typing-extensions>=4.8.0
  Downloading typing_extensions-4.15.0-py3-none-any.whl (44 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 44.6/44.6 KB 364.7 MB/s eta 0:00:00
Collecting fsspec
  Downloading fsspec-2025.9.0-py3-none-any.whl (199 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 199.3/199.3 KB 468.5 MB/s eta 0:00:00


Collecting jinja2
  Downloading jinja2-3.1.6-py3-none-any.whl (134 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 134.9/134.9 KB 426.4 MB/s eta 0:00:00
Collecting nvidia-cudnn-cu12==9.1.0.70
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl (664.8 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 664.8/664.8 MB 200.3 MB/s eta 0:00:00


Collecting pillow!=8.3.*,>=5.3.0
  Downloading pillow-11.3.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (6.6 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 6.6/6.6 MB 235.6 MB/s eta 0:00:00


Collecting numpy
  Downloading numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.3 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 18.3/18.3 MB 185.9 MB/s eta 0:00:00


Collecting nvidia-nvjitlink-cu12
  Downloading nvidia_nvjitlink_cu12-12.9.86-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl (39.7 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 39.7/39.7 MB 247.6 MB/s eta 0:00:00


Collecting MarkupSafe>=2.0
  Downloading markupsafe-3.0.3-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl (22 kB)
Collecting mpmath<1.4,>=1.1.0
  Downloading mpmath-1.3.0-py3-none-any.whl (536 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 536.2/536.2 KB 500.1 MB/s eta 0:00:00


Installing collected packages: mpmath, typing-extensions, sympy, pillow, nvidia-nvtx-cu12, nvidia-nvjitlink-cu12, nvidia-nccl-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, numpy, networkx, MarkupSafe, fsspec, filelock, triton, nvidia-cusparse-cu12, nvidia-cudnn-cu12, jinja2, nvidia-cusolver-cu12, torch, torchvision, torchaudio


Successfully installed MarkupSafe-3.0.3 filelock-3.19.1 fsspec-2025.9.0 jinja2-3.1.6 mpmath-1.3.0 networkx-3.5 numpy-1.26.4 nvidia-cublas-cu12-12.1.3.1 nvidia-cuda-cupti-cu12-12.1.105 nvidia-cuda-nvrtc-cu12-12.1.105 nvidia-cuda-runtime-cu12-12.1.105 nvidia-cudnn-cu12-9.1.0.70 nvidia-cufft-cu12-11.0.2.54 nvidia-curand-cu12-10.3.2.106 nvidia-cusolver-cu12-11.4.5.107 nvidia-cusparse-cu12-12.1.0.106 nvidia-nccl-cu12-2.20.5 nvidia-nvjitlink-cu12-12.9.86 nvidia-nvtx-cu12-12.1.105 pillow-11.3.0 sympy-1.14.0 torch-2.4.1+cu121 torchaudio-2.4.1+cu121 torchvision-0.19.1+cu121 triton-3.0.0 typing-extensions-4.15.0


> pip install -c constraints.txt timm==1.0.9 albumentations==1.4.14 opencv-python-headless==4.10.0.84 faiss-cpu==1.8.0 scikit-learn pandas numpy matplotlib --upgrade-strategy only-if-needed


Collecting timm==1.0.9
  Downloading timm-1.0.9-py3-none-any.whl (2.3 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.3/2.3 MB 55.3 MB/s eta 0:00:00
Collecting albumentations==1.4.14
  Downloading albumentations-1.4.14-py3-none-any.whl (177 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 178.0/178.0 KB 422.2 MB/s eta 0:00:00


Collecting opencv-python-headless==4.10.0.84
  Downloading opencv_python_headless-4.10.0.84-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (49.9 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 49.9/49.9 MB 212.8 MB/s eta 0:00:00
Collecting faiss-cpu==1.8.0
  Downloading faiss_cpu-1.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (27.0 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 27.0/27.0 MB 131.2 MB/s eta 0:00:00
Collecting scikit-learn
  Downloading scikit_learn-1.7.2-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (9.7 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 9.7/9.7 MB 211.1 MB/s eta 0:00:00
Collecting pandas
  Downloading pandas-2.3.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (12.4 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 12.4/12.4 MB 572.6 MB/s eta 0:00:00


Collecting numpy
  Downloading numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.3 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 18.3/18.3 MB 288.0 MB/s eta 0:00:00


Collecting matplotlib
  Downloading matplotlib-3.10.6-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (8.7 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 8.7/8.7 MB 407.5 MB/s eta 0:00:00
Collecting torch
  Downloading torch-2.4.1-cp311-cp311-manylinux1_x86_64.whl (797.1 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 797.1/797.1 MB 85.1 MB/s eta 0:00:00


Collecting pyyaml
  Downloading pyyaml-6.0.3-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl (806 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 806.6/806.6 KB 526.7 MB/s eta 0:00:00
Collecting huggingface_hub
  Downloading huggingface_hub-0.35.1-py3-none-any.whl (563 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 563.3/563.3 KB 529.6 MB/s eta 0:00:00
Collecting torchvision
  Downloading torchvision-0.19.1-cp311-cp311-manylinux1_x86_64.whl (7.0 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 7.0/7.0 MB 232.2 MB/s eta 0:00:00


Collecting safetensors
  Downloading safetensors-0.6.2-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (485 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 485.8/485.8 KB 508.6 MB/s eta 0:00:00
Collecting typing-extensions>=4.9.0
  Downloading typing_extensions-4.15.0-py3-none-any.whl (44 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 44.6/44.6 KB 362.6 MB/s eta 0:00:00
Collecting eval-type-backport
  Downloading eval_type_backport-0.2.2-py3-none-any.whl (5.8 kB)
Collecting albucore>=0.0.13
  Downloading albucore-0.0.33-py3-none-any.whl (18 kB)


Collecting scipy>=1.10.0
  Downloading scipy-1.16.2-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (35.9 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 35.9/35.9 MB 549.3 MB/s eta 0:00:00


Collecting pydantic>=2.7.0
  Downloading pydantic-2.11.9-py3-none-any.whl (444 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 444.9/444.9 KB 516.9 MB/s eta 0:00:00
Collecting scikit-image>=0.21.0
  Downloading scikit_image-0.25.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (14.8 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 14.8/14.8 MB 199.2 MB/s eta 0:00:00


Collecting threadpoolctl>=3.1.0
  Downloading threadpoolctl-3.6.0-py3-none-any.whl (18 kB)
Collecting joblib>=1.2.0
  Downloading joblib-1.5.2-py3-none-any.whl (308 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 308.4/308.4 KB 475.7 MB/s eta 0:00:00
Collecting python-dateutil>=2.8.2
  Downloading python_dateutil-2.9.0.post0-py2.py3-none-any.whl (229 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 229.9/229.9 KB 475.4 MB/s eta 0:00:00
Collecting tzdata>=2022.7
  Downloading tzdata-2025.2-py2.py3-none-any.whl (347 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 347.8/347.8 KB 507.5 MB/s eta 0:00:00
Collecting pytz>=2020.1
  Downloading pytz-2025.2-py2.py3-none-any.whl (509 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 509.2/509.2 KB 503.1 MB/s eta 0:00:00


Collecting pyparsing>=2.3.1
  Downloading pyparsing-3.2.5-py3-none-any.whl (113 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 113.9/113.9 KB 413.4 MB/s eta 0:00:00
Collecting cycler>=0.10
  Downloading cycler-0.12.1-py3-none-any.whl (8.3 kB)
Collecting kiwisolver>=1.3.1
  Downloading kiwisolver-1.4.9-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (1.4 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.4/1.4 MB 532.3 MB/s eta 0:00:00


Collecting fonttools>=4.22.0
  Downloading fonttools-4.60.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (5.0 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 5.0/5.0 MB 553.4 MB/s eta 0:00:00
Collecting packaging>=20.0
  Downloading packaging-25.0-py3-none-any.whl (66 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 66.5/66.5 KB 391.8 MB/s eta 0:00:00
Collecting contourpy>=1.0.1
  Downloading contourpy-1.3.3-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (355 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 355.2/355.2 KB 490.4 MB/s eta 0:00:00


Collecting pillow>=8
  Downloading pillow-11.3.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (6.6 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 6.6/6.6 MB 399.1 MB/s eta 0:00:00


Collecting stringzilla>=3.10.4
  Downloading stringzilla-4.0.14-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl (496 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 496.5/496.5 KB 359.8 MB/s eta 0:00:00


Collecting simsimd>=5.9.2
  Downloading simsimd-6.5.3-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (1.1 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.1/1.1 MB 167.1 MB/s eta 0:00:00
Collecting typing-inspection>=0.4.0
  Downloading typing_inspection-0.4.1-py3-none-any.whl (14 kB)


Collecting pydantic-core==2.33.2
  Downloading pydantic_core-2.33.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.0 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.0/2.0 MB 519.8 MB/s eta 0:00:00
Collecting annotated-types>=0.6.0
  Downloading annotated_types-0.7.0-py3-none-any.whl (13 kB)
Collecting six>=1.5
  Downloading six-1.17.0-py2.py3-none-any.whl (11 kB)
Collecting tifffile>=2022.8.12
  Downloading tifffile-2025.9.20-py3-none-any.whl (230 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 230.1/230.1 KB 452.9 MB/s eta 0:00:00
Collecting networkx>=3.0
  Downloading networkx-3.5-py3-none-any.whl (2.0 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.0/2.0 MB 546.0 MB/s eta 0:00:00
Collecting lazy-loader>=0.4
  Downloading lazy_loader-0.4-py3-none-any.whl (12 kB)


Collecting imageio!=2.35.0,>=2.33
  Downloading imageio-2.37.0-py3-none-any.whl (315 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 315.8/315.8 KB 500.2 MB/s eta 0:00:00
Collecting hf-xet<2.0.0,>=1.1.3
  Downloading hf_xet-1.1.10-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.2 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 3.2/3.2 MB 289.3 MB/s eta 0:00:00


Collecting fsspec>=2023.5.0
  Downloading fsspec-2025.9.0-py3-none-any.whl (199 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 199.3/199.3 KB 417.1 MB/s eta 0:00:00
Collecting tqdm>=4.42.1
  Downloading tqdm-4.67.1-py3-none-any.whl (78 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 78.5/78.5 KB 450.6 MB/s eta 0:00:00
Collecting filelock
  Downloading filelock-3.19.1-py3-none-any.whl (15 kB)
Collecting requests
  Downloading requests-2.32.5-py3-none-any.whl (64 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 64.7/64.7 KB 404.6 MB/s eta 0:00:00


Collecting sympy
  Downloading sympy-1.14.0-py3-none-any.whl (6.3 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 6.3/6.3 MB 348.5 MB/s eta 0:00:00
Collecting nvidia-cuda-cupti-cu12==12.1.105
  Downloading nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 14.1/14.1 MB 238.3 MB/s eta 0:00:00
Collecting nvidia-cusolver-cu12==11.4.5.107
  Downloading nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl (124.2 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 124.2/124.2 MB 227.3 MB/s eta 0:00:00
Collecting jinja2
  Downloading jinja2-3.1.6-py3-none-any.whl (134 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 134.9/134.9 KB 487.1 MB/s eta 0:00:00
Collecting nvidia-nvtx-cu12==12.1.105
  Downloading nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (99 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 99.1/99.1 KB 468.4 MB/s eta 0:00:00
Collecting nvidia-cuda-nvrtc-cu12==12.1.105
  Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 23.7/23.7 MB 226.7 MB/s eta 0:00:00
Collecting nvidia-cudnn-cu12==9.1.0.70
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl (664.8 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 664.8/664.8 MB 251.5 MB/s eta 0:00:00


Collecting nvidia-curand-cu12==10.3.2.106
  Downloading nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl (56.5 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 56.5/56.5 MB 285.9 MB/s eta 0:00:00
Collecting nvidia-cusparse-cu12==12.1.0.106
  Downloading nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl (196.0 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 196.0/196.0 MB 244.5 MB/s eta 0:00:00
Collecting nvidia-nccl-cu12==2.20.5
  Downloading nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl (176.2 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 176.2/176.2 MB 192.8 MB/s eta 0:00:00
Collecting triton==3.0.0
  Downloading triton-3.0.0-1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (209.4 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 209.4/209.4 MB 122.2 MB/s eta 0:00:00
Collecting nvidia-cuda-runtime-cu12==12.1.105
  Downloading nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 823.6/823.6 KB 501.9 MB/s eta 0:00:00
Collecting nvidia-cublas-cu12==12.1.3.1
  Downloading nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 410.6/410.6 MB 228.5 MB/s eta 0:00:00


Collecting nvidia-cufft-cu12==11.0.2.54
  Downloading nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl (121.6 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 121.6/121.6 MB 249.7 MB/s eta 0:00:00
Collecting nvidia-nvjitlink-cu12
  Downloading nvidia_nvjitlink_cu12-12.9.86-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl (39.7 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 39.7/39.7 MB 206.7 MB/s eta 0:00:00


Collecting MarkupSafe>=2.0
  Downloading markupsafe-3.0.3-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl (22 kB)
Collecting idna<4,>=2.5
  Downloading idna-3.10-py3-none-any.whl (70 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 70.4/70.4 KB 390.8 MB/s eta 0:00:00
Collecting urllib3<3,>=1.21.1
  Downloading urllib3-2.5.0-py3-none-any.whl (129 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 129.8/129.8 KB 411.6 MB/s eta 0:00:00
Collecting charset_normalizer<4,>=2
  Downloading charset_normalizer-3.4.3-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl (150 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 150.3/150.3 KB 488.0 MB/s eta 0:00:00


Collecting certifi>=2017.4.17
  Downloading certifi-2025.8.3-py3-none-any.whl (161 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 161.2/161.2 KB 477.8 MB/s eta 0:00:00
Collecting mpmath<1.4,>=1.1.0
  Downloading mpmath-1.3.0-py3-none-any.whl (536 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 536.2/536.2 KB 294.1 MB/s eta 0:00:00


Installing collected packages: simsimd, pytz, mpmath, urllib3, tzdata, typing-extensions, tqdm, threadpoolctl, sympy, stringzilla, six, safetensors, pyyaml, pyparsing, pillow, packaging, nvidia-nvtx-cu12, nvidia-nvjitlink-cu12, nvidia-nccl-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, numpy, networkx, MarkupSafe, kiwisolver, joblib, idna, hf-xet, fsspec, fonttools, filelock, eval-type-backport, cycler, charset_normalizer, certifi, annotated-types, typing-inspection, triton, tifffile, scipy, requests, python-dateutil, pydantic-core, opencv-python-headless, nvidia-cusparse-cu12, nvidia-cudnn-cu12, lazy-loader, jinja2, imageio, faiss-cpu, contourpy, scikit-learn, scikit-image, pydantic, pandas, nvidia-cusolver-cu12, matplotlib, huggingface_hub, albucore, torch, albumentations, torchvision, timm


Successfully installed MarkupSafe-3.0.3 albucore-0.0.33 albumentations-1.4.14 annotated-types-0.7.0 certifi-2025.8.3 charset_normalizer-3.4.3 contourpy-1.3.3 cycler-0.12.1 eval-type-backport-0.2.2 faiss-cpu-1.8.0 filelock-3.19.1 fonttools-4.60.0 fsspec-2025.9.0 hf-xet-1.1.10 huggingface_hub-0.35.1 idna-3.10 imageio-2.37.0 jinja2-3.1.6 joblib-1.5.2 kiwisolver-1.4.9 lazy-loader-0.4 matplotlib-3.10.6 mpmath-1.3.0 networkx-3.5 numpy-1.26.4 nvidia-cublas-cu12-12.1.3.1 nvidia-cuda-cupti-cu12-12.1.105 nvidia-cuda-nvrtc-cu12-12.1.105 nvidia-cuda-runtime-cu12-12.1.105 nvidia-cudnn-cu12-9.1.0.70 nvidia-cufft-cu12-11.0.2.54 nvidia-curand-cu12-10.3.2.106 nvidia-cusolver-cu12-11.4.5.107 nvidia-cusparse-cu12-12.1.0.106 nvidia-nccl-cu12-2.20.5 nvidia-nvjitlink-cu12-12.9.86 nvidia-nvtx-cu12-12.1.105 opencv-python-headless-4.10.0.84 packaging-25.0 pandas-2.3.2 pillow-11.3.0 pydantic-2.11.9 pydantic-core-2.33.2 pyparsing-3.2.5 python-dateutil-2.9.0.post0 pytz-2025.2 pyyaml-6.0.3 requests-2.32.5 safetens









torch: 2.4.1+cu121 built CUDA: 12.1
CUDA available: True
GPU: NVIDIA A10-24Q
Install done in 133.9s


In [4]:
# Duplicate-aware folds: pHash clustering + StratifiedGroupKFold
import os, time, math, gc
from pathlib import Path
import numpy as np
import pandas as pd
import cv2
from sklearn.model_selection import StratifiedGroupKFold

t0=time.time()
print('Starting duplicate audit and fold creation...')
train_df = pd.read_csv('train.csv')
img_col = 'Image'
id_col = 'Id'
train_dir = Path('train')

# Fast perceptual hash via DCT (OpenCV), returns 64-bit int as hex string
def phash_cv2(img_path, hash_size=8, highfreq_factor=4):
    try:
        img = cv2.imdecode(np.fromfile(str(img_path), dtype=np.uint8), cv2.IMREAD_GRAYSCALE)
        if img is None:
            return None
        img = cv2.resize(img, (hash_size*highfreq_factor, hash_size*highfreq_factor), interpolation=cv2.INTER_AREA)
        img = np.float32(img)
        dct = cv2.dct(img)
        dctlow = dct[:hash_size, :hash_size]
        med = np.median(dctlow)
        diff = dctlow > med
        bits = ''.join('1' if x else '0' for x in diff.flatten())
        return hex(int(bits, 2))
    except Exception:
        return None

def hamming_hex(h1, h2):
    if h1 is None or h2 is None:
        return 64
    n1 = int(h1, 16); n2 = int(h2, 16)
    return (n1 ^ n2).bit_count()

# Compute pHash for all train images
paths = [train_dir / fn for fn in train_df[img_col].tolist()]
hashes = []
log_every = 500
for i, p in enumerate(paths):
    if (i % log_every)==0:
        print(f'phash {i}/{len(paths)} elapsed {time.time()-t0:.1f}s', flush=True)
    hashes.append(phash_cv2(p))
train_df['phash'] = hashes

# Cluster near-duplicates by simple binning + union-find (Hamming <= 4)
parent = list(range(len(train_df)))
def find(x):
    while parent[x]!=x:
        parent[x]=parent[parent[x]]
        x=parent[x]
    return x
def union(a,b):
    ra, rb = find(a), find(b)
    if ra!=rb:
        parent[rb]=ra

# Bucket by first N hex chars to reduce comparisons
prefix = 4  # 16 bits bucket
buckets = {}
for idx,h in enumerate(train_df['phash']):
    if h is None:
        key = 'none'
    else:
        key = h[:2+prefix]  # '0x' + prefix chars
    buckets.setdefault(key, []).append(idx)

thr = 4
checked_pairs = 0
for key, idxs in buckets.items():
    n = len(idxs)
    if n<=1: continue
    # compare all pairs within bucket (typically small)
    for i in range(n):
        hi = train_df.at[idxs[i], 'phash']
        for j in range(i+1, n):
            hj = train_df.at[idxs[j], 'phash']
            d = hamming_hex(hi, hj)
            checked_pairs += 1
            if d <= thr:
                union(idxs[i], idxs[j])
print('Buckets:', len(buckets), 'pairs checked:', checked_pairs)

# Assign cluster ids
cluster_id = [find(i) for i in range(len(train_df))]
root_map = {}
next_cluster = 0
clusters = []
for r in cluster_id:
    if r not in root_map:
        root_map[r] = next_cluster; next_cluster += 1
    clusters.append(root_map[r])
train_df['dup_cluster'] = clusters

print('Unique dup clusters:', train_df['dup_cluster'].nunique())
print('Preparing StratifiedGroupKFold with groups = dup_cluster (fallback to Id if needed)')

y = train_df[id_col].values
groups = train_df['dup_cluster'].values

n_splits = 5
cv = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=42)
folds = np.full(len(train_df), -1, dtype=int)
for fold, (trn_idx, val_idx) in enumerate(cv.split(np.zeros(len(train_df)), y, groups)):
    folds[val_idx] = fold
    print(f'Fold {fold}: trn {len(trn_idx)} | val {len(val_idx)} | elapsed {time.time()-t0:.1f}s', flush=True)

assert (folds>=0).all(), 'Some rows not assigned to folds'
train_df['fold'] = folds
train_df[['Image','Id','phash','dup_cluster','fold']].to_csv('folds.csv', index=False)
print('Saved folds.csv')

# Sanity: per-fold label coverage
lab_counts = train_df.groupby(['fold'])[id_col].nunique()
print('Unique IDs per fold:\n', lab_counts)
print(f'Done in {time.time()-t0:.1f}s')

Starting duplicate audit and fold creation...
phash 0/7240 elapsed 0.0s


phash 500/7240 elapsed 0.9s


phash 1000/7240 elapsed 1.7s


phash 1500/7240 elapsed 2.6s


phash 2000/7240 elapsed 3.5s


phash 2500/7240 elapsed 4.4s


phash 3000/7240 elapsed 5.3s


phash 3500/7240 elapsed 6.2s


phash 4000/7240 elapsed 7.0s


phash 4500/7240 elapsed 7.9s


phash 5000/7240 elapsed 8.7s


phash 5500/7240 elapsed 9.6s


phash 6000/7240 elapsed 10.5s


phash 6500/7240 elapsed 11.3s


phash 7000/7240 elapsed 12.2s


Buckets: 3909 pairs checked: 14350
Unique dup clusters: 6861
Preparing StratifiedGroupKFold with groups = dup_cluster (fallback to Id if needed)


  type_of_target_y = type_of_target(y)


Fold 0: trn 5780 | val 1460 | elapsed 16.3s


Fold 1: trn 5791 | val 1449 | elapsed 16.3s


Fold 2: trn 5789 | val 1451 | elapsed 16.3s


Fold 3: trn 5809 | val 1431 | elapsed 16.3s


Fold 4: trn 5791 | val 1449 | elapsed 16.3s


Saved folds.csv
Unique IDs per fold:
 fold
0    1053
1    1055
2    1038
3    1100
4    1066
Name: Id, dtype: int64
Done in 16.3s


In [6]:
# Pretrained embedding retrieval baseline (ConvNeXt-Base @384) + CV MAP@5 + submission
import os, time, math, gc, sys
from pathlib import Path
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import cv2
import timm

torch.backends.cudnn.benchmark = True
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)

train_df = pd.read_csv('train.csv')
folds_df = pd.read_csv('folds.csv')  # Image, Id, phash, dup_cluster, fold
train_df = train_df.merge(folds_df[['Image','fold']], on='Image', how='left')
test_files = sorted([p.name for p in Path('test').glob('*') if p.is_file()])

# Quick label info
id_counts = train_df['Id'].value_counts()
print('Num train images:', len(train_df), 'Num IDs:', id_counts.shape[0], 'new_whale in train:', ('new_whale' in id_counts.index))

IMG_SIZE = 384
MODEL_NAME = 'convnext_base.fb_in22k_ft_in1k'
BATCH_SIZE = 64
NUM_WORKERS = 4

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD  = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def read_image(path):
    data = np.fromfile(path, dtype=np.uint8)
    img = cv2.imdecode(data, cv2.IMREAD_COLOR)
    if img is None:
        raise RuntimeError(f'Failed to read {path}')
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return img

def preprocess(img, size, hflip=False):
    h, w = img.shape[:2]
    scale = min(size / h, size / w)
    nh, nw = int(round(h * scale)), int(round(w * scale))
    img_resized = cv2.resize(img, (nw, nh), interpolation=cv2.INTER_AREA)
    top = (size - nh) // 2
    bottom = size - nh - top
    left = (size - nw) // 2
    right = size - nw - left
    img_padded = cv2.copyMakeBorder(img_resized, top, bottom, left, right, borderType=cv2.BORDER_CONSTANT, value=(0,0,0))
    if hflip:
        img_padded = np.ascontiguousarray(img_padded[:, ::-1, :])
    x = img_padded.astype(np.float32) / 255.0
    x = (x - IMAGENET_MEAN) / IMAGENET_STD
    x = np.transpose(x, (2, 0, 1))  # CHW
    return torch.from_numpy(x)

class ImgDs(Dataset):
    def __init__(self, root_dir, img_names, size, hflip=False):
        self.root = Path(root_dir)
        self.names = img_names
        self.size = size
        self.hflip = hflip
    def __len__(self): return len(self.names)
    def __getitem__(self, i):
        fn = self.names[i]
        img = read_image(self.root / fn)
        tensor = preprocess(img, self.size, hflip=self.hflip)
        return fn, tensor

class FeatExtractor(nn.Module):
    def __init__(self, model_name):
        super().__init__()
        self.model = timm.create_model(model_name, pretrained=True, num_classes=0, global_pool='avg')
    def forward(self, x):
        with torch.cuda.amp.autocast(enabled=True):
            feats = self.model(x)
        return feats

@torch.no_grad()
def extract_embeddings(model, root_dir, img_names, size, hflip=False):
    ds = ImgDs(root_dir, img_names, size=size, hflip=hflip)
    dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
    dummy = torch.zeros(1,3,size,size, device=device)
    emb_dim = model(dummy).shape[1]
    embs = np.zeros((len(img_names), emb_dim), dtype=np.float32)
    order = []
    t0=time.time()
    seen=0
    for i,(fns, imgs) in enumerate(dl):
        imgs = imgs.to(device, non_blocking=True)
        feats = model(imgs)
        feats = nn.functional.normalize(feats, dim=1).float().cpu().numpy()
        embs[seen:seen+feats.shape[0]] = feats
        order.extend(fns)
        seen += feats.shape[0]
        if i%20==0:
            print(f'embed batch {i}, total {seen}/{len(img_names)}', flush=True)
    return order, embs

def mapk(truths, preds, k=5):
    score=0.0
    for t, ps in zip(truths, preds):
        for i,p in enumerate(ps[:k]):
            if p==t:
                score += 1.0/(i+1)
                break
    return score/len(truths)

def centroid_by_id(names, embs, ids):
    df = pd.DataFrame({'Image': names, 'Id': ids})
    id_to_idx = {}
    for idv, grp in df.groupby('Id').groups.items():
        idxs = np.array(list(grp), dtype=int)
        id_to_idx[idv] = idxs
    centroids = {}
    for idv, idxs in id_to_idx.items():
        centroids[idv] = nn.functional.normalize(torch.from_numpy(embs[idxs]).mean(0, keepdims=True), dim=1).numpy()[0].astype(np.float32)
    labels = list(centroids.keys())
    mat = np.stack([centroids[l] for l in labels], axis=0)
    return labels, mat

def topk_labels(query_embs, gallery_labels, gallery_centroids, k=5):
    sims = query_embs @ gallery_centroids.T
    k_eff = min(k, sims.shape[1])
    topk_idx = np.argpartition(-sims, kth=k_eff-1, axis=1)[:, :k_eff]
    rows = np.arange(sims.shape[0])[:,None]
    sorted_order = np.argsort(-sims[rows, topk_idx], axis=1)
    topk_idx_sorted = topk_idx[rows, sorted_order]
    pred_labels = [[gallery_labels[j] for j in row] for row in topk_idx_sorted]
    top_scores = np.take_along_axis(sims, topk_idx_sorted, axis=1)
    return pred_labels, top_scores

# Build model
model = FeatExtractor(MODEL_NAME).to(device).eval()

# Extract train embeddings (orig + hflip TTA)
train_names = train_df['Image'].tolist()
print('Extracting train embeddings ...')
order1, emb1 = extract_embeddings(model, 'train', train_names, IMG_SIZE, hflip=False)
order_flip, emb_flip = extract_embeddings(model, 'train', train_names, IMG_SIZE, hflip=True)
assert order1==order_flip==train_names, 'Embedding order mismatch'
train_emb = nn.functional.normalize(torch.from_numpy((emb1 + emb_flip)/2.0), dim=1).numpy().astype(np.float32)
del emb1, emb_flip; gc.collect()

# CV MAP@5 using centroids from train folds
n_splits = int(train_df['fold'].max())+1
all_scores = []
for fold in range(n_splits):
    t_start=time.time()
    trn_mask = train_df['fold'] != fold
    val_mask = train_df['fold'] == fold
    trn_names = train_df.loc[trn_mask, 'Image'].tolist()
    val_names = train_df.loc[val_mask, 'Image'].tolist()
    trn_ids = train_df.loc[trn_mask, 'Id'].tolist()
    val_ids = train_df.loc[val_mask, 'Id'].tolist()
    idx_map = {name:i for i,name in enumerate(train_names)}
    trn_idx = np.array([idx_map[n] for n in trn_names], dtype=int)
    val_idx = np.array([idx_map[n] for n in val_names], dtype=int)
    labels, centroids = centroid_by_id(trn_names, train_emb[trn_idx], trn_ids)
    preds, scores = topk_labels(train_emb[val_idx], labels, centroids, k=5)
    score = mapk(val_ids, preds, k=5)
    all_scores.append(score)
    print(f'Fold {fold} MAP@5: {score:.4f} elapsed {time.time()-t_start:.1f}s with {len(labels)} gallery IDs')
print('OOF MAP@5 (mean):', np.mean(all_scores).round(5))

# Extract test embeddings with TTA
print('Extracting test embeddings ...')
test_names = test_files
order_t1, test_emb1 = extract_embeddings(model, 'test', test_names, IMG_SIZE, hflip=False)
order_tflip, test_emb_flip = extract_embeddings(model, 'test', test_names, IMG_SIZE, hflip=True)
assert order_t1==order_tflip==test_names, 'Test order mismatch'
test_emb = nn.functional.normalize(torch.from_numpy((test_emb1 + test_emb_flip)/2.0), dim=1).numpy().astype(np.float32)

# Build final gallery centroids from full train
labels_all, centroids_all = centroid_by_id(train_names, train_emb, train_df['Id'].tolist())
preds_test, scores_test = topk_labels(test_emb, labels_all, centroids_all, k=5)

# Optional new_whale thresholding (simple v1): if max score < thr, place 'new_whale' at rank-1
thr = 0.70
has_new = 'new_whale' in labels_all
fallback_label = labels_all[0]
final_strs = []
for i, labs in enumerate(preds_test):
    smax = scores_test[i,0]
    out = labs[:5]
    if (smax < thr) and has_new and ('new_whale' not in out):
        out = ['new_whale'] + out[:4]
    # ensure 5 unique labels
    seen=set()
    uniq=[]
    for l in out:
        if l not in seen:
            uniq.append(l); seen.add(l)
        if len(uniq)==5: break
    while len(uniq)<5:
        if has_new and 'new_whale' not in seen:
            uniq.append('new_whale'); seen.add('new_whale')
        else:
            uniq.append(fallback_label)
    final_strs.append(' '.join(uniq))

sub = pd.DataFrame({'Image': test_names, 'Id': final_strs})
sub.to_csv('submission.csv', index=False)
print('Saved submission.csv; head:\n', sub.head())

Device: cuda
Num train images: 7240 Num IDs: 4029 new_whale in train: True


Extracting train embeddings ...


  with torch.cuda.amp.autocast(enabled=True):


embed batch 0, total 64/7240


embed batch 20, total 1344/7240


embed batch 40, total 2624/7240


embed batch 60, total 3904/7240


embed batch 80, total 5184/7240


embed batch 100, total 6464/7240


embed batch 0, total 64/7240


embed batch 20, total 1344/7240


embed batch 40, total 2624/7240


embed batch 60, total 3904/7240


embed batch 80, total 5184/7240


embed batch 100, total 6464/7240


Fold 0 MAP@5: 0.0815 elapsed 0.2s with 3371 gallery IDs


Fold 1 MAP@5: 0.0794 elapsed 0.2s with 3385 gallery IDs


Fold 2 MAP@5: 0.0777 elapsed 0.3s with 3399 gallery IDs


Fold 3 MAP@5: 0.0822 elapsed 0.3s with 3363 gallery IDs


Fold 4 MAP@5: 0.0950 elapsed 0.2s with 3395 gallery IDs
OOF MAP@5 (mean): 0.08314
Extracting test embeddings ...


  with torch.cuda.amp.autocast(enabled=True):


embed batch 0, total 64/2610


embed batch 20, total 1344/2610


embed batch 40, total 2610/2610


embed batch 0, total 64/2610


embed batch 20, total 1344/2610


embed batch 40, total 2610/2610


Saved submission.csv; head:
           Image                                                 Id
0  00087b01.jpg  w_2957331 w_6e47e0e w_0f54cdf w_9df0865 w_da2efe0
1  0014cfdf.jpg  w_5982299 w_f3b63ba w_790c2aa w_adffd9c w_9874f0d
2  0035632e.jpg  w_3d0bc7a w_b729b1f w_4c9d3df w_4659acf w_38e4aae
3  004c5fb9.jpg  w_1a5e7a2 w_17ee910 w_bb2d34d w_d7de1ee w_f4e0748
4  00863b8c.jpg  w_e09e886 w_7311fe4 w_2c68b75 w_2db01d5 new_whale


In [7]:
# FAISS per-image retrieval + temperature voting + new_whale calibration; save embs
import numpy as np, pandas as pd, time, faiss, torch, math, gc, random

# Sanity: expect train_emb and train_names from previous cell
assert 'train_emb' in globals() and 'train_names' in globals(), 'Run cell 4 first to build train embeddings'
train_df_full = pd.read_csv('train.csv')
folds_df = pd.read_csv('folds.csv')
train_df_full = train_df_full.merge(folds_df[['Image','fold']], on='Image', how='left')

# Exclude new_whale from gallery always
is_new = (train_df_full['Id'] == 'new_whale').values

# Cache embeddings for reuse
np.save('emb_train_avg_tta.npy', train_emb)
pd.Series(train_names).to_csv('emb_train_names.txt', index=False, header=False)
print('Saved train embeddings to disk')

def softmax_temp(x, T=0.07):
    x = x / max(T, 1e-6)
    x = x - x.max(axis=1, keepdims=True)
    ex = np.exp(x)
    return ex / (ex.sum(axis=1, keepdims=True) + 1e-9)

def vote_labels(nei_idx, nei_sim, labels, T=0.07, topk=5):
    # nei_idx: [N, K], nei_sim: [N, K], labels: list length G
    w = softmax_temp(nei_sim, T=T)  # [N,K]
    N, K = nei_idx.shape
    out = []
    out_scores = []
    for i in range(N):
        agg = {}
        for j in range(K):
            lab = labels[nei_idx[i, j]]
            agg[lab] = agg.get(lab, 0.0) + w[i, j]
        # sort by agg desc
        items = sorted(agg.items(), key=lambda t: -t[1])[:topk]
        out.append([k for k,_ in items])
        out_scores.append([v for _,v in items])
    return out, out_scores

def mapk(truths, preds, k=5):
    score=0.0
    for t, ps in zip(truths, preds):
        for i,p in enumerate(ps[:k]):
            if p==t:
                score += 1.0/(i+1); break
    return score/len(truths)

def build_index(emb):
    d = emb.shape[1]
    index = faiss.IndexFlatIP(d)
    return index

def l2norm(a):
    na = np.linalg.norm(a, axis=1, keepdims=True) + 1e-9
    return a / na

def oof_eval_faiss(train_emb, train_names, train_df, T=0.07, K=50, tune_thr=True, unknown_frac=0.3, seed=42):
    rng = random.Random(seed)
    n_splits = int(train_df['fold'].max())+1
    all_scores = []
    best_thrs = []
    for fold in range(n_splits):
        t0=time.time()
        trn_mask = (train_df['fold'] != fold) & (train_df['Id'] != 'new_whale')
        val_mask = (train_df['fold'] == fold)
        trn_df = train_df[trn_mask].reset_index(drop=True)
        val_df = train_df[val_mask].reset_index(drop=True)
        idx_map = {name:i for i,name in enumerate(train_names)}
        trn_idx = np.array([idx_map[n] for n in trn_df['Image']], dtype=int)
        val_idx = np.array([idx_map[n] for n in val_df['Image']], dtype=int)
        G = train_emb[trn_idx].astype(np.float32)
        Q = train_emb[val_idx].astype(np.float32)
        # build labels list and faiss index
        gallery_labels = trn_df['Id'].tolist()
        G = l2norm(G); Q = l2norm(Q)
        index = build_index(G)
        index.add(G)
        D, I = index.search(Q, min(K, len(G)))  # cosine via IP
        # temperature voting
        pred_top, pred_scores = vote_labels(I, D, gallery_labels, T=T, topk=5)
        # Unknowns calibration: hold out some whole IDs from gallery to simulate new_whale
        thr_best = 0.7
        if tune_thr:
            ids_train = list(set(trn_df['Id'].tolist()))
            rng.shuffle(ids_train)
            hold_n = max(1, int(len(ids_train)*unknown_frac))
            hold_ids = set(ids_train[:hold_n])
            # mask held-out IDs in voting by zeroing their votes
            held_mask = np.array([1.0 if lab not in hold_ids else 0.0 for lab in gallery_labels], dtype=np.float32)
            # recompute per-query label scores with held-out masked
            w = softmax_temp(D, T=T)  # [Nv, K]
            Nv, Kk = I.shape
            tuned_preds = []
            tuned_max = []
            for i in range(Nv):
                agg = {}
                mx = 0.0
                for j in range(Kk):
                    lab = gallery_labels[I[i, j]]
                    weight = w[i, j] * (1.0 if lab not in hold_ids else 0.0)
                    if weight<=0: continue
                    agg[lab] = agg.get(lab, 0.0) + weight
                    if agg[lab] > mx: mx = agg[lab]
                items = sorted(agg.items(), key=lambda t: -t[1])[:5]
                tuned_preds.append([k for k,_ in items])
                tuned_max.append(mx)
            tuned_max = np.array(tuned_max, dtype=np.float32)
            # grid search threshold
            truths = val_df['Id'].tolist()
            grid = [round(x,2) for x in np.arange(0.60, 0.82, 0.02)]
            best_score = -1.0
            for thr in grid:
                preds_thr = []
                for i, labs in enumerate(tuned_preds):
                    out = labs[:5]
                    if tuned_max[i] < thr:
                        out = ['new_whale'] + out[:4]
                    # ensure unique and fill with new_whale
                    seen=set(); uniq=[]
                    for l in out:
                        if l not in seen: uniq.append(l); seen.add(l)
                        if len(uniq)==5: break
                    while len(uniq)<5:
                        if 'new_whale' not in seen: uniq.append('new_whale'); seen.add('new_whale')
                        else: uniq.append(gallery_labels[0])
                    preds_thr.append(uniq)
                score = mapk(truths, preds_thr, k=5)
                if score > best_score:
                    best_score = score; thr_best = thr
        # Apply best threshold to original predictions
        truths = val_df['Id'].tolist()
        preds_final = []
        # compute max aggregated weight per query for thresholding
        w_all = softmax_temp(D, T=T)
        Nv, Kk = I.shape
        q_max = np.zeros(Nv, dtype=np.float32)
        for i in range(Nv):
            agg = {}
            mx = 0.0
            for j in range(Kk):
                lab = gallery_labels[I[i, j]]
                agg[lab] = agg.get(lab, 0.0) + w_all[i, j]
                if agg[lab] > mx: mx = agg[lab]
            out = sorted(agg.items(), key=lambda t: -t[1])[:5]
            labs = [k for k,_ in out]
            if (mx < thr_best):
                labs = ['new_whale'] + labs[:4]
            # ensure 5 unique labels
            seen=set(); uniq=[]
            for l in labs:
                if l not in seen: uniq.append(l); seen.add(l)
                if len(uniq)==5: break
            while len(uniq)<5:
                if 'new_whale' not in seen: uniq.append('new_whale'); seen.add('new_whale')
                else: uniq.append(gallery_labels[0])
            preds_final.append(uniq)
        score = mapk(truths, preds_final, k=5)
        print(f'[FAISS] Fold {fold} MAP@5: {score:.4f} | best_thr {thr_best:.2f} | T {T} | K {K} | elapsed {time.time()-t0:.1f}s')
        all_scores.append(score); best_thrs.append(thr_best)
    return float(np.mean(all_scores)), float(np.mean(best_thrs))

# Run OOF FAISS eval with voting and threshold tuning
mean_oof, avg_thr = oof_eval_faiss(train_emb, train_names, train_df_full, T=0.07, K=50, tune_thr=True, unknown_frac=0.3, seed=42)
print('OOF (FAISS) MAP@5 mean:', round(mean_oof,5), 'avg tuned thr:', round(avg_thr,2))

# Build final gallery on full train excluding new_whale and generate submission via FAISS voting
gallery_mask = (train_df_full['Id'] != 'new_whale')
gallery_idx = np.array([i for i,n in enumerate(train_names) if gallery_mask.iloc[i]], dtype=int)
G_all = l2norm(train_emb[gallery_idx].astype(np.float32))
labels_all = train_df_full.loc[gallery_mask, 'Id'].tolist()
index_all = build_index(G_all); index_all.add(G_all)

# If test embeddings exist from previous cell, reuse; else, skip to only evaluate OOF
if 'test_emb' not in globals():
    print('Warning: test_emb not found; run extraction cell 4 for test to write submission.')
else:
    Q = l2norm(test_emb.astype(np.float32))
    K = min(50, len(G_all))
    D, I = index_all.search(Q, K)
    # aggregate votes
    w = softmax_temp(D, T=0.07)
    final_preds = []
    for i in range(Q.shape[0]):
        agg = {}
        for j in range(K):
            lab = labels_all[I[i, j]]
            agg[lab] = agg.get(lab, 0.0) + w[i, j]
        items = sorted(agg.items(), key=lambda t: -t[1])[:5]
        labs = [k for k,_ in items]
        mx = items[0][1] if items else 0.0
        if mx < avg_thr:
            labs = ['new_whale'] + labs[:4]
        seen=set(); uniq=[]
        for l in labs:
            if l not in seen: uniq.append(l); seen.add(l)
            if len(uniq)==5: break
        while len(uniq)<5:
            if 'new_whale' not in seen: uniq.append('new_whale'); seen.add('new_whale')
            else: uniq.append(labels_all[0])
        final_preds.append(' '.join(uniq))
    sub = pd.DataFrame({'Image': pd.read_csv('sample_submission.csv')['Image'], 'Id': final_preds})
    sub.to_csv('submission.csv', index=False)
    print('Saved submission.csv (FAISS voting)')

gc.collect()

Saved train embeddings to disk


[FAISS] Fold 0 MAP@5: 0.1498 | best_thr 0.72 | T 0.07 | K 50 | elapsed 0.7s


[FAISS] Fold 1 MAP@5: 0.1478 | best_thr 0.60 | T 0.07 | K 50 | elapsed 0.7s


[FAISS] Fold 2 MAP@5: 0.1445 | best_thr 0.60 | T 0.07 | K 50 | elapsed 0.8s


[FAISS] Fold 3 MAP@5: 0.1347 | best_thr 0.60 | T 0.07 | K 50 | elapsed 0.7s


[FAISS] Fold 4 MAP@5: 0.1431 | best_thr 0.60 | T 0.07 | K 50 | elapsed 0.7s
OOF (FAISS) MAP@5 mean: 0.14397 avg tuned thr: 0.62


Saved submission.csv (FAISS voting)


0

In [None]:
# ArcFace metric-learning smoke run (fold 0) with ConvNeXt-Base @384; extract embs + FAISS submission
import os, time, math, gc, random, sys
from pathlib import Path
import numpy as np
import pandas as pd
from PIL import Image, ImageOps
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import timm
import faiss

torch.backends.cudnn.benchmark = True
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)

# Config
IMG_SIZE = 384
MODEL_NAME = 'convnext_base.fb_in22k_ft_in1k'
BATCH_SIZE_TR = 32
BATCH_SIZE_INF = 96
NUM_WORKERS = 6
EMB_DIM = 512
EPOCHS_WARM = 1   # smoke
EPOCHS_FT = 2     # smoke
LR_BACKBONE = 2e-4
LR_HEAD = 1e-3
WD = 5e-2
MARGIN = 0.35
SCALE = 30.0
FOLD_TO_RUN = 0

IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3,1,1)
IMAGENET_STD  = torch.tensor([0.229, 0.224, 0.225]).view(3,1,1)

def read_image_pil(path):
    img = Image.open(path)
    try:
        img = ImageOps.exif_transpose(img)
    except Exception:
        pass
    return img.convert('RGB')

class TrainDs(Dataset):
    def __init__(self, df, root, size, aug=True):
        self.df = df.reset_index(drop=True)
        self.root = Path(root)
        self.size = size
        self.aug = aug
    def __len__(self): return len(self.df)
    def __getitem__(self, i):
        row = self.df.iloc[i]
        img = read_image_pil(self.root / row.Image)
        # RandomResizedCrop-like
        if self.aug:
            scale = random.uniform(0.9, 1.0)
        else:
            scale = 1.0
        w,h = img.size
        nw, nh = int(w*scale), int(h*scale)
        img = img.resize((nw, nh), Image.BICUBIC)
        # pad to square
        pad_w = max(0, self.size - nw); pad_h = max(0, self.size - nh)
        pad_left = pad_w//2; pad_right = pad_w - pad_left
        pad_top = pad_h//2; pad_bottom = pad_h - pad_top
        if pad_w>0 or pad_h>0:
            img = ImageOps.expand(img, border=(pad_left,pad_top,pad_right,pad_bottom), fill=(0,0,0))
        img = img.resize((self.size, self.size), Image.BICUBIC)
        if self.aug and random.random()<0.5:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
        if self.aug:
            # small rotation
            angle = random.uniform(-10,10)
            img = img.rotate(angle, resample=Image.BICUBIC, fillcolor=(0,0,0))
        x = torch.from_numpy(np.array(img)).permute(2,0,1).float()/255.0
        x = (x - IMAGENET_MEAN) / IMAGENET_STD
        y = int(row['label'])
        return x, y

class InferDs(Dataset):
    def __init__(self, names, root, size, hflip=False):
        self.names = names
        self.root = Path(root)
        self.size = size
        self.hflip = hflip
    def __len__(self): return len(self.names)
    def __getitem__(self, i):
        fn = self.names[i]
        img = read_image_pil(self.root / fn)
        img = img.resize((self.size,self.size), Image.BICUBIC)
        if self.hflip:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
        x = torch.from_numpy(np.array(img)).permute(2,0,1).float()/255.0
        x = (x - IMAGENET_MEAN) / IMAGENET_STD
        return fn, x

class GeM(nn.Module):
    def __init__(self, p=3.0, eps=1e-6):
        super().__init__(); self.p = nn.Parameter(torch.ones(1)*p); self.eps=eps
    def forward(self, x):
        x = torch.clamp(x, min=self.eps).pow(self.p)
        x = F.avg_pool2d(x, (x.size(-2), x.size(-1))).pow(1.0/self.p)
        return x

class Backbone(nn.Module):
    def __init__(self, name, emb_dim):
        super().__init__()
        self.backbone = timm.create_model(name, pretrained=True, num_classes=0, global_pool='')
        self.pool = GeM()
        in_ch = self.backbone.num_features
        self.head = nn.Sequential(
            nn.Linear(in_ch, emb_dim, bias=False),
            nn.BatchNorm1d(emb_dim),
        )
    def forward(self, x):
        feat = self.backbone.forward_features(x)
        feat = self.pool(feat).view(feat.size(0), -1)
        emb = self.head(feat)
        emb = F.normalize(emb, p=2, dim=1)
        return emb

class ArcMarginProduct(nn.Module):
    def __init__(self, in_features, out_features, s=30.0, m=0.35, easy_margin=False):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s; self.m = m; self.easy_margin = easy_margin
        self.weight = nn.Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)
        self.cos_m = math.cos(m); self.sin_m = math.sin(m)
        self.th = math.cos(math.pi - m); self.mm = math.sin(math.pi - m) * m
    def forward(self, input, label):
        # input: [B, in_features] L2-normalized; weight normalized
        W = F.normalize(self.weight, p=2, dim=1)
        cosine = F.linear(input, W)  # [B, C]
        sine = torch.sqrt(torch.clamp(1.0 - cosine**2, min=1e-9))
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, label.view(-1,1), 1.0)
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output = output * self.s
        return output

def build_loaders(train_df, fold, label_map):
    trn = train_df[(train_df.fold!=fold) & (train_df.Id!='new_whale')].copy()
    val = train_df[(train_df.fold==fold) & (train_df.Id!='new_whale')].copy()
    trn['label'] = trn['Id'].map(label_map).astype(int)
    val['label'] = val['Id'].map(label_map).astype(int)
    # class-balanced sampler
    cls_counts = trn['label'].value_counts().to_dict()
    weights = trn['label'].map(lambda x: 1.0/cls_counts[x]).values
    sampler = WeightedRandomSampler(weights=torch.DoubleTensor(weights), num_samples=len(weights), replacement=True)
    ds_tr = TrainDs(trn[['Image','label']], 'train', IMG_SIZE, aug=True)
    ds_va = TrainDs(val[['Image','label']], 'train', IMG_SIZE, aug=False)
    dl_tr = DataLoader(ds_tr, batch_size=BATCH_SIZE_TR, sampler=sampler, num_workers=NUM_WORKERS, pin_memory=True)
    dl_va = DataLoader(ds_va, batch_size=BATCH_SIZE_INF, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
    return dl_tr, dl_va, trn, val

@torch.no_grad()
def extract_features(model, names, root='train', hflip=False):
    ds = InferDs(names, root, IMG_SIZE, hflip=hflip)
    dl = DataLoader(ds, batch_size=BATCH_SIZE_INF, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
    embs = np.zeros((len(names), EMB_DIM), dtype=np.float32)
    order = []
    seen=0; t0=time.time()
    for it,(fns, imgs) in enumerate(dl):
        imgs = imgs.to(device, non_blocking=True)
        with torch.amp.autocast(device_type='cuda', enabled=True):
            e = model(imgs)
        e = F.normalize(e, dim=1).float().cpu().numpy()
        embs[seen:seen+e.shape[0]] = e
        order.extend(fns); seen += e.shape[0]
        if it%20==0: print(f'feat {it} {seen}/{len(names)}', flush=True)
    return order, embs

def l2norm(a):
    na = np.linalg.norm(a, axis=1, keepdims=True) + 1e-9
    return a / na

def softmax_temp(x, T=0.07):
    x = x / max(T, 1e-6); x = x - x.max(axis=1, keepdims=True); ex = np.exp(x); return ex/(ex.sum(axis=1, keepdims=True)+1e-9)

def vote_labels(nei_idx, nei_sim, labels, T=0.07, topk=5):
    w = softmax_temp(nei_sim, T=T); out=[]
    for i in range(nei_idx.shape[0]):
        agg={}
        for j in range(nei_idx.shape[1]):
            lab = labels[nei_idx[i,j]]; agg[lab]=agg.get(lab,0.0)+w[i,j]
        items = sorted(agg.items(), key=lambda t: -t[1])[:topk]
        out.append([k for k,_ in items])
    return out

# Load data and folds
train_df = pd.read_csv('train.csv')
folds_df = pd.read_csv('folds.csv')
df = train_df.merge(folds_df[['Image','fold']], on='Image', how='left')
ids = sorted(df.loc[df.Id!='new_whale','Id'].unique().tolist())
label_map = {idv:i for i,idv in enumerate(ids)}
n_classes = len(label_map)
print('Classes (excluding new_whale):', n_classes)

# Build model
model = Backbone(MODEL_NAME, EMB_DIM).to(device)
margin_head = ArcMarginProduct(EMB_DIM, n_classes, s=SCALE, m=MARGIN).to(device)

# Optimizers
def param_groups(model):
    decay, no_decay = [], []
    for n,p in model.named_parameters():
        if not p.requires_grad: continue
        if p.ndim==1 or n.endswith('bias') or 'bn' in n.lower(): no_decay.append(p)
        else: decay.append(p)
    return [ {'params': decay, 'weight_decay': WD}, {'params': no_decay, 'weight_decay': 0.0} ]

dl_tr, dl_va, trn_sub, val_sub = build_loaders(df, FOLD_TO_RUN, label_map)

scaler = torch.cuda.amp.GradScaler(enabled=True)

def train_one_epoch(epoch, warmup=False):
    model.train(); margin_head.train()
    if warmup:
        for p in model.backbone.parameters(): p.requires_grad=False
        opt = torch.optim.AdamW(list(model.head.parameters())+list(margin_head.parameters()), lr=LR_HEAD, weight_decay=WD)
    else:
        for p in model.backbone.parameters(): p.requires_grad=True
        opt = torch.optim.AdamW([{'params': model.parameters(), 'lr': LR_BACKBONE, 'weight_decay': WD},
                                {'params': margin_head.parameters(), 'lr': LR_HEAD, 'weight_decay': WD}],
                                lr=LR_BACKBONE, weight_decay=WD)
    running=0.0; n=0; t0=time.time()
    for it,(imgs, labels) in enumerate(dl_tr):
        imgs = imgs.to(device, non_blocking=True); labels = labels.to(device, non_blocking=True)
        opt.zero_grad(set_to_none=True)
        with torch.amp.autocast(device_type='cuda', enabled=True):
            emb = model(imgs)
            logits = margin_head(emb, labels)
            loss = F.cross_entropy(logits, labels)
        scaler.scale(loss).backward()
        scaler.step(opt); scaler.update()
        running += loss.item()*imgs.size(0); n += imgs.size(0)
        if it%50==0: print(f'ep{epoch} it{it} loss {running/max(1,n):.4f}', flush=True)
    print(f'ep{epoch} done loss {running/max(1,n):.4f} elapsed {time.time()-t0:.1f}s')

# Train (smoke)
for e in range(EPOCHS_WARM):
    train_one_epoch(e, warmup=True)
for e in range(EPOCHS_WARM, EPOCHS_WARM+EPOCHS_FT):
    train_one_epoch(e, warmup=False)

# Save checkpoint
ckpt_path = f'ckpt_convnext_base_fold{FOLD_TO_RUN}.pt'
torch.save({'model': model.state_dict(), 'head': margin_head.state_dict(), 'label_map': label_map}, ckpt_path)
print('Saved', ckpt_path)

# Extract train/test embeddings using trained model (orig + hflip)
train_names = df['Image'].tolist()
order_tr1, tr_e1 = extract_features(model, train_names, root='train', hflip=False)
order_tr2, tr_e2 = extract_features(model, train_names, root='train', hflip=True)
assert order_tr1==order_tr2==train_names, 'order mismatch'
train_emb_ml = l2norm(((tr_e1 + tr_e2)/2.0).astype(np.float32))
np.save(f'emb_train_ml_fold{FOLD_TO_RUN}.npy', train_emb_ml)

test_names = pd.read_csv('sample_submission.csv')['Image'].tolist()
order_te1, te_e1 = extract_features(model, test_names, root='test', hflip=False)
order_te2, te_e2 = extract_features(model, test_names, root='test', hflip=True)
assert order_te1==order_te2==test_names, 'test order mismatch'
test_emb_ml = l2norm(((te_e1 + te_e2)/2.0).astype(np.float32))
np.save(f'emb_test_ml_fold{FOLD_TO_RUN}.npy', test_emb_ml)

# Build FAISS gallery (exclude new_whale) and predict test with temperature voting
gallery_mask = (df['Id']!='new_whale')
gallery_labels = df.loc[gallery_mask, 'Id'].tolist()
gallery_idx = np.where(gallery_mask.values)[0]
G = train_emb_ml[gallery_idx].astype(np.float32)
index = faiss.IndexFlatIP(G.shape[1]); index.add(G)
K = min(50, len(G))
D, I = index.search(test_emb_ml, K)
def finalize_preds(D, I, labels_all, thr=0.62, T=0.07):
    w = softmax_temp(D, T=T)
    out=[]
    for i in range(I.shape[0]):
        agg={}
        for j in range(I.shape[1]):
            lab = labels_all[I[i,j]]; agg[lab]=agg.get(lab,0.0)+w[i,j]
        items = sorted(agg.items(), key=lambda t: -t[1])[:5]
        labs = [k for k,_ in items]
        mx = items[0][1] if items else 0.0
        if mx < thr: labs = ['new_whale'] + labs[:4]
        seen=set(); uniq=[]
        for l in labs:
            if l not in seen: uniq.append(l); seen.add(l)
            if len(uniq)==5: break
        while len(uniq)<5:
            if 'new_whale' not in seen: uniq.append('new_whale'); seen.add('new_whale')
            else: uniq.append(labels_all[0])
        out.append(' '.join(uniq))
    return out

pred_strs = finalize_preds(D, I, gallery_labels, thr=0.62, T=0.07)
sub = pd.DataFrame({'Image': test_names, 'Id': pred_strs})
sub.to_csv('submission.csv', index=False)
print('Wrote submission.csv (ArcFace smoke run)')
gc.collect()
print('Done.')

Device: cuda
Classes (excluding new_whale): 4028


  scaler = torch.cuda.amp.GradScaler(enabled=True)


ep0 it0 loss 19.0334
