# Plant Pathology 2021 - FGVC8: Plan

Objectives:
- Verify GPU availability and environment.
- Inspect data schema (train.csv, sample_submission.csv) and image directories.
- Establish a robust CV scheme and a baseline image model (torch/TF), evaluated with multilabel micro-F1.
- Produce a working submission quickly, then iterate on aug/resolution/architectures and ensembling.

Initial Plan:
1) Env check (GPU).
2) Load train.csv/test images; verify multilabel format and submission format.
3) Quick EDA: label distribution, any leaks, filename mapping.
4) Baseline: torchvision + timm (e.g., tf_efficientnet_b0/b3, resnet50), resolution 512, strong aug (albumentations), BCEWithLogitsLoss, sigmoid threshold tuning on CV.
5) CV: multilabel-stratified K-fold via iterative stratification (skmultilearn or iterstrat); use GroupKFold instead if groups exist.
6) Logging: per-fold micro-F1, OOF saving, test-time augmentation, threshold calibration.
7) Submit baseline; iterate with higher res, different backbones, seeds, and blends.

Checkpoints for Expert Review:
- After env/EDA, after baseline CV setup, after first baseline OOF, after improvements/ensembles.

In [None]:
import os, subprocess, time, glob, sys
import pandas as pd

def run(cmd):
    """Echo *cmd*, execute it, and print its merged stdout/stderr.

    *cmd* is a list of strings handed to ``subprocess.run`` without a shell.
    A non-zero exit status is not an error (``check=False``); any exception
    raised while running or printing is reported on stdout instead of being
    propagated, so a probing cell keeps going.
    """
    print("$", " ".join(cmd), flush=True)
    try:
        completed = subprocess.run(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,  # interleave stderr with stdout
            text=True,
            check=False,
        )
        print(completed.stdout)
    except Exception as exc:
        print("Command failed:", exc)

# --- Environment / data probe ---------------------------------------------
# Prints GPU status, the working-directory listing, the head of both CSVs and
# the number of files found in the image folders.

print("Env check: nvidia-smi")
# "|| true" makes the shell command exit with status 0 even if nvidia-smi fails.
run(['bash', '-lc', 'nvidia-smi || true'])

base = os.getcwd()
print("CWD:", base)
print("Listing top-level files:")
for p in sorted(os.listdir(base)):
    print(" -", p)

train_csv = 'train.csv'
# sample_csv stays None when no sample submission is present, which skips the
# optional load further down.
if os.path.exists('sample_submission.csv'):
    sample_csv = 'sample_submission.csv'
else:
    sample_csv = None

print("\nLoading CSVs...")
train_df = pd.read_csv(train_csv)
print("train.csv shape:", train_df.shape)
print(train_df.head(3))
if sample_csv:
    sample_df = pd.read_csv(sample_csv)
    print("sample_submission.csv shape:", sample_df.shape)
    print(sample_df.head(3))

# glob returns an empty list for a missing directory, so the counts below are
# simply 0 in that case.
train_images_dir = 'train_images'
test_images_dir = 'test_images'
train_imgs = glob.glob(os.path.join(train_images_dir, '*'))
test_imgs = glob.glob(os.path.join(test_images_dir, '*'))
print(f"Train images: {len(train_imgs)} | Test images: {len(test_imgs)}")
if train_imgs:
    print("Sample train images:", [os.path.basename(p) for p in train_imgs[:5]])
if test_imgs:
    print("Sample test images:", [os.path.basename(p) for p in test_imgs[:5]])

print("\nColumns in train.csv:", list(train_df.columns))
print("Done env/data probe.")

In [1]:
import os, sys, subprocess, shutil
from pathlib import Path

def pip(*args):
    """Run ``<current interpreter> -m pip <args...>``, echoing the arguments first.

    Raises:
        subprocess.CalledProcessError: if pip exits with a non-zero status
            (``check=True``).
    """
    print('>', *args, flush=True)
    command = [sys.executable, '-m', 'pip']
    command.extend(args)
    subprocess.run(command, check=True)

# 0) Uninstall any preinstalled torch stack to avoid mismatches
# check=False: a package that is not installed must not abort the cell.
for pkg in ('torch','torchvision','torchaudio'):
    subprocess.run([sys.executable, '-m', 'pip', 'uninstall', '-y', pkg], check=False)

# Clean stray site dirs (idempotent)
# Only directories that exist are removed; ignore_errors=True swallows any
# failure during removal.
for d in (
    '/app/.pip-target/torch',
    '/app/.pip-target/torchvision',
    '/app/.pip-target/torchaudio',
    '/app/.pip-target/torchgen',
    '/app/.pip-target/functorch',
):
    if os.path.exists(d):
        print('Removing', d)
        shutil.rmtree(d, ignore_errors=True)

# 1) Install exact CUDA 12.1 torch stack
# The PyTorch cu121 index is the primary index; PyPI is added as an extra
# index so that non-torch dependencies can still be resolved.
pip('install',
    '--index-url', 'https://download.pytorch.org/whl/cu121',
    '--extra-index-url', 'https://pypi.org/simple',
    'torch==2.4.1', 'torchvision==0.19.1', 'torchaudio==2.4.1')

# 2) Freeze versions
# The constraints file pins version numbers only (no "+cu121" local tag).
Path('constraints.txt').write_text('torch==2.4.1\ntorchvision==0.19.1\ntorchaudio==2.4.1\n')

# 3) Install non-torch deps honoring constraints
# NOTE(review): the captured output of this cell shows this step collecting and
# downloading torch/torchvision again (wheel names without the "+cu121" tag).
# The sanity check below still printed 2.4.1+cu121 in that run, but confirm
# which build actually ends up being imported.
pip('install', '-c', 'constraints.txt',
    'timm==1.0.9',
    'albumentations==1.4.14',
    'opencv-python-headless',
    'iterative-stratification',
    'scikit-learn',
    'pandas', 'numpy',
    '--upgrade-strategy', 'only-if-needed')

# 4) Sanity check GPU
# torch is imported only here, after the (re)install above has finished.
import torch
print('torch:', torch.__version__, 'built CUDA:', getattr(torch.version, 'cuda', None))
print('CUDA available:', torch.cuda.is_available())
# Fail fast when the installed wheel is not a CUDA 12.1 build or no GPU is usable.
assert str(getattr(torch.version,'cuda','')).startswith('12.1'), f'Wrong CUDA build: {torch.version.cuda}'
assert torch.cuda.is_available(), 'CUDA not available'
print('GPU:', torch.cuda.get_device_name(0))

Found existing installation: torch 2.4.1+cu121


Uninstalling torch-2.4.1+cu121:
  Successfully uninstalled torch-2.4.1+cu121


Found existing installation: torchvision 0.19.1+cu121
Uninstalling torchvision-0.19.1+cu121:
  Successfully uninstalled torchvision-0.19.1+cu121


Found existing installation: torchaudio 2.4.1+cu121
Uninstalling torchaudio-2.4.1+cu121:
  Successfully uninstalled torchaudio-2.4.1+cu121
> install --index-url https://download.pytorch.org/whl/cu121 --extra-index-url https://pypi.org/simple torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1


Looking in indexes: https://download.pytorch.org/whl/cu121, https://pypi.org/simple


Collecting torch==2.4.1
  Downloading https://download.pytorch.org/whl/cu121/torch-2.4.1%2Bcu121-cp311-cp311-linux_x86_64.whl (799.0 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 799.0/799.0 MB 313.2 MB/s eta 0:00:00


Collecting torchvision==0.19.1
  Downloading https://download.pytorch.org/whl/cu121/torchvision-0.19.1%2Bcu121-cp311-cp311-linux_x86_64.whl (7.1 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 7.1/7.1 MB 455.5 MB/s eta 0:00:00


Collecting torchaudio==2.4.1
  Downloading https://download.pytorch.org/whl/cu121/torchaudio-2.4.1%2Bcu121-cp311-cp311-linux_x86_64.whl (3.4 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 3.4/3.4 MB 494.9 MB/s eta 0:00:00
Collecting typing-extensions>=4.8.0


  Downloading typing_extensions-4.15.0-py3-none-any.whl (44 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 44.6/44.6 KB 4.0 MB/s eta 0:00:00
Collecting networkx
  Downloading networkx-3.5-py3-none-any.whl (2.0 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.0/2.0 MB 83.6 MB/s eta 0:00:00


Collecting nvidia-nvtx-cu12==12.1.105
  Downloading nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (99 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 99.1/99.1 KB 234.6 MB/s eta 0:00:00


Collecting nvidia-cuda-nvrtc-cu12==12.1.105
  Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 23.7/23.7 MB 165.3 MB/s eta 0:00:00


Collecting nvidia-cuda-runtime-cu12==12.1.105
  Downloading nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 823.6/823.6 KB 243.4 MB/s eta 0:00:00
Collecting sympy
  Downloading sympy-1.14.0-py3-none-any.whl (6.3 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 6.3/6.3 MB 306.4 MB/s eta 0:00:00


Collecting nvidia-cusolver-cu12==11.4.5.107
  Downloading nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl (124.2 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 124.2/124.2 MB 133.7 MB/s eta 0:00:00


Collecting filelock
  Downloading filelock-3.19.1-py3-none-any.whl (15 kB)


Collecting nvidia-cuda-cupti-cu12==12.1.105
  Downloading nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 14.1/14.1 MB 190.9 MB/s eta 0:00:00


Collecting nvidia-nccl-cu12==2.20.5
  Downloading nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl (176.2 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 176.2/176.2 MB 188.7 MB/s eta 0:00:00


Collecting fsspec
  Downloading fsspec-2025.9.0-py3-none-any.whl (199 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 199.3/199.3 KB 490.2 MB/s eta 0:00:00


Collecting triton==3.0.0
  Downloading triton-3.0.0-1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (209.4 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 209.4/209.4 MB 193.4 MB/s eta 0:00:00


Collecting nvidia-cublas-cu12==12.1.3.1
  Downloading nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 410.6/410.6 MB 192.1 MB/s eta 0:00:00


Collecting nvidia-curand-cu12==10.3.2.106
  Downloading nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl (56.5 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 56.5/56.5 MB 150.9 MB/s eta 0:00:00


Collecting nvidia-cudnn-cu12==9.1.0.70
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl (664.8 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 664.8/664.8 MB 158.2 MB/s eta 0:00:00


Collecting jinja2
  Downloading jinja2-3.1.6-py3-none-any.whl (134 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 134.9/134.9 KB 384.6 MB/s eta 0:00:00
Collecting nvidia-cufft-cu12==11.0.2.54
  Downloading nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl (121.6 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 121.6/121.6 MB 151.1 MB/s eta 0:00:00


Collecting nvidia-cusparse-cu12==12.1.0.106
  Downloading nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl (196.0 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 196.0/196.0 MB 139.6 MB/s eta 0:00:00


Collecting numpy
  Downloading numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.3 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 18.3/18.3 MB 190.2 MB/s eta 0:00:00


Collecting pillow!=8.3.*,>=5.3.0
  Downloading pillow-11.3.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (6.6 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 6.6/6.6 MB 247.5 MB/s eta 0:00:00


Collecting nvidia-nvjitlink-cu12
  Downloading nvidia_nvjitlink_cu12-12.9.86-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl (39.7 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 39.7/39.7 MB 177.1 MB/s eta 0:00:00


Collecting MarkupSafe>=2.0
  Downloading MarkupSafe-3.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (23 kB)


Collecting mpmath<1.4,>=1.1.0
  Downloading mpmath-1.3.0-py3-none-any.whl (536 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 536.2/536.2 KB 501.5 MB/s eta 0:00:00


Installing collected packages: mpmath, typing-extensions, sympy, pillow, nvidia-nvtx-cu12, nvidia-nvjitlink-cu12, nvidia-nccl-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, numpy, networkx, MarkupSafe, fsspec, filelock, triton, nvidia-cusparse-cu12, nvidia-cudnn-cu12, jinja2, nvidia-cusolver-cu12, torch, torchvision, torchaudio


Successfully installed MarkupSafe-3.0.2 filelock-3.19.1 fsspec-2025.9.0 jinja2-3.1.6 mpmath-1.3.0 networkx-3.5 numpy-1.26.4 nvidia-cublas-cu12-12.1.3.1 nvidia-cuda-cupti-cu12-12.1.105 nvidia-cuda-nvrtc-cu12-12.1.105 nvidia-cuda-runtime-cu12-12.1.105 nvidia-cudnn-cu12-9.1.0.70 nvidia-cufft-cu12-11.0.2.54 nvidia-curand-cu12-10.3.2.106 nvidia-cusolver-cu12-11.4.5.107 nvidia-cusparse-cu12-12.1.0.106 nvidia-nccl-cu12-2.20.5 nvidia-nvjitlink-cu12-12.9.86 nvidia-nvtx-cu12-12.1.105 pillow-11.3.0 sympy-1.14.0 torch-2.4.1+cu121 torchaudio-2.4.1+cu121 torchvision-0.19.1+cu121 triton-3.0.0 typing-extensions-4.15.0




> install -c constraints.txt timm==1.0.9 albumentations==1.4.14 opencv-python-headless iterative-stratification scikit-learn pandas numpy --upgrade-strategy only-if-needed


Collecting timm==1.0.9
  Downloading timm-1.0.9-py3-none-any.whl (2.3 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.3/2.3 MB 55.7 MB/s eta 0:00:00
Collecting albumentations==1.4.14
  Downloading albumentations-1.4.14-py3-none-any.whl (177 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 178.0/178.0 KB 266.8 MB/s eta 0:00:00


Collecting opencv-python-headless
  Downloading opencv_python_headless-4.12.0.88-cp37-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (54.0 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 54.0/54.0 MB 193.3 MB/s eta 0:00:00
Collecting iterative-stratification
  Downloading iterative_stratification-0.1.9-py3-none-any.whl (8.5 kB)


Collecting scikit-learn
  Downloading scikit_learn-1.7.2-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (9.7 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 9.7/9.7 MB 115.5 MB/s eta 0:00:00


Collecting pandas
  Downloading pandas-2.3.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (12.4 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 12.4/12.4 MB 34.2 MB/s eta 0:00:00


Collecting numpy
  Downloading numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.3 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 18.3/18.3 MB 228.1 MB/s eta 0:00:00
Collecting torch
  Downloading torch-2.4.1-cp311-cp311-manylinux1_x86_64.whl (797.1 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 797.1/797.1 MB 72.4 MB/s eta 0:00:00


Collecting huggingface_hub
  Downloading huggingface_hub-0.35.1-py3-none-any.whl (563 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 563.3/563.3 KB 378.2 MB/s eta 0:00:00
Collecting safetensors
  Downloading safetensors-0.6.2-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (485 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 485.8/485.8 KB 437.0 MB/s eta 0:00:00


Collecting pyyaml
  Downloading pyyaml-6.0.3-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl (806 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 806.6/806.6 KB 527.2 MB/s eta 0:00:00
Collecting torchvision
  Downloading torchvision-0.19.1-cp311-cp311-manylinux1_x86_64.whl (7.0 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 7.0/7.0 MB 207.6 MB/s eta 0:00:00


Collecting scipy>=1.10.0
  Downloading scipy-1.16.2-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (35.9 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 35.9/35.9 MB 135.0 MB/s eta 0:00:00


Collecting pydantic>=2.7.0
  Downloading pydantic-2.11.9-py3-none-any.whl (444 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 444.9/444.9 KB 510.7 MB/s eta 0:00:00
Collecting albucore>=0.0.13
  Downloading albucore-0.0.33-py3-none-any.whl (18 kB)
Collecting scikit-image>=0.21.0
  Downloading scikit_image-0.25.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (14.8 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 14.8/14.8 MB 54.5 MB/s eta 0:00:00
Collecting typing-extensions>=4.9.0
  Downloading typing_extensions-4.15.0-py3-none-any.whl (44 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 44.6/44.6 KB 415.0 MB/s eta 0:00:00
Collecting eval-type-backport
  Downloading eval_type_backport-0.2.2-py3-none-any.whl (5.8 kB)
Collecting opencv-python-headless
  Downloading opencv_python_headless-4.11.0.86-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (50.0 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 50.0/50.0 MB 222.9 MB/s eta 0:00:00
Collecting joblib>=1.2.0
  Downloading joblib-1.5.2-py3-none-any.whl (308 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 308.4/308.4 KB 507.2 MB/s eta 0:00:00


Collecting threadpoolctl>=3.1.0
  Downloading threadpoolctl-3.6.0-py3-none-any.whl (18 kB)
Collecting python-dateutil>=2.8.2
  Downloading python_dateutil-2.9.0.post0-py2.py3-none-any.whl (229 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 229.9/229.9 KB 473.8 MB/s eta 0:00:00
Collecting tzdata>=2022.7
  Downloading tzdata-2025.2-py2.py3-none-any.whl (347 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 347.8/347.8 KB 505.0 MB/s eta 0:00:00
Collecting pytz>=2020.1
  Downloading pytz-2025.2-py2.py3-none-any.whl (509 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 509.2/509.2 KB 493.3 MB/s eta 0:00:00


Collecting simsimd>=5.9.2
  Downloading simsimd-6.5.3-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (1.1 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.1/1.1 MB 137.3 MB/s eta 0:00:00


Collecting stringzilla>=3.10.4
  Downloading stringzilla-4.0.14-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl (496 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 496.5/496.5 KB 408.2 MB/s eta 0:00:00
Collecting annotated-types>=0.6.0
  Downloading annotated_types-0.7.0-py3-none-any.whl (13 kB)


Collecting pydantic-core==2.33.2
  Downloading pydantic_core-2.33.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.0 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.0/2.0 MB 520.0 MB/s eta 0:00:00
Collecting typing-inspection>=0.4.0
  Downloading typing_inspection-0.4.1-py3-none-any.whl (14 kB)
Collecting six>=1.5
  Downloading six-1.17.0-py2.py3-none-any.whl (11 kB)
Collecting lazy-loader>=0.4
  Downloading lazy_loader-0.4-py3-none-any.whl (12 kB)


Collecting pillow>=10.1
  Downloading pillow-11.3.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (6.6 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 6.6/6.6 MB 173.1 MB/s eta 0:00:00
Collecting tifffile>=2022.8.12
  Downloading tifffile-2025.9.20-py3-none-any.whl (230 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 230.1/230.1 KB 460.3 MB/s eta 0:00:00
Collecting packaging>=21
  Downloading packaging-25.0-py3-none-any.whl (66 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 66.5/66.5 KB 392.1 MB/s eta 0:00:00
Collecting imageio!=2.35.0,>=2.33
  Downloading imageio-2.37.0-py3-none-any.whl (315 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 315.8/315.8 KB 452.4 MB/s eta 0:00:00


Collecting networkx>=3.0
  Downloading networkx-3.5-py3-none-any.whl (2.0 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.0/2.0 MB 490.3 MB/s eta 0:00:00
Collecting fsspec>=2023.5.0
  Downloading fsspec-2025.9.0-py3-none-any.whl (199 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 199.3/199.3 KB 418.2 MB/s eta 0:00:00
Collecting requests


  Downloading requests-2.32.5-py3-none-any.whl (64 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 64.7/64.7 KB 320.3 MB/s eta 0:00:00
Collecting filelock
  Downloading filelock-3.19.1-py3-none-any.whl (15 kB)
Collecting hf-xet<2.0.0,>=1.1.3
  Downloading hf_xet-1.1.10-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.2 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 3.2/3.2 MB 424.8 MB/s eta 0:00:00
Collecting tqdm>=4.42.1
  Downloading tqdm-4.67.1-py3-none-any.whl (78 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 78.5/78.5 KB 431.5 MB/s eta 0:00:00


Collecting sympy
  Downloading sympy-1.14.0-py3-none-any.whl (6.3 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 6.3/6.3 MB 292.9 MB/s eta 0:00:00
Collecting nvidia-cudnn-cu12==9.1.0.70
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl (664.8 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 664.8/664.8 MB 181.6 MB/s eta 0:00:00


Collecting nvidia-cusolver-cu12==11.4.5.107
  Downloading nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl (124.2 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 124.2/124.2 MB 245.6 MB/s eta 0:00:00
Collecting nvidia-curand-cu12==10.3.2.106
  Downloading nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl (56.5 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 56.5/56.5 MB 238.6 MB/s eta 0:00:00
Collecting nvidia-nccl-cu12==2.20.5
  Downloading nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl (176.2 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 176.2/176.2 MB 562.1 MB/s eta 0:00:00
Collecting nvidia-cusparse-cu12==12.1.0.106
  Downloading nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl (196.0 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 196.0/196.0 MB 262.3 MB/s eta 0:00:00
Collecting nvidia-cuda-cupti-cu12==12.1.105
  Downloading nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 14.1/14.1 MB 253.2 MB/s eta 0:00:00
Collecting nvidia-cuda-nvrtc-cu12==12.1.105
  Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 23.7/23.7 MB 217.6 MB/s eta 0:00:00
Collecting nvidia-cublas-cu12==12.1.3.1
  Downloading nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 410.6/410.6 MB 217.1 MB/s eta 0:00:00


Collecting nvidia-nvtx-cu12==12.1.105
  Downloading nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (99 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 99.1/99.1 KB 423.0 MB/s eta 0:00:00
Collecting jinja2
  Downloading jinja2-3.1.6-py3-none-any.whl (134 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 134.9/134.9 KB 451.5 MB/s eta 0:00:00
Collecting triton==3.0.0
  Downloading triton-3.0.0-1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (209.4 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 209.4/209.4 MB 37.5 MB/s eta 0:00:00
Collecting nvidia-cuda-runtime-cu12==12.1.105
  Downloading nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 823.6/823.6 KB 508.3 MB/s eta 0:00:00
Collecting nvidia-cufft-cu12==11.0.2.54
  Downloading nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl (121.6 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 121.6/121.6 MB 180.6 MB/s eta 0:00:00
Collecting nvidia-nvjitlink-cu12
  Downloading nvidia_nvjitlink_cu12-12.9.86-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl (39.7 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 39.7/39.7 MB 233.5 MB/s eta 0:00:00


Collecting MarkupSafe>=2.0
  Downloading MarkupSafe-3.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (23 kB)
Collecting idna<4,>=2.5
  Downloading idna-3.10-py3-none-any.whl (70 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 70.4/70.4 KB 400.1 MB/s eta 0:00:00
Collecting urllib3<3,>=1.21.1
  Downloading urllib3-2.5.0-py3-none-any.whl (129 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 129.8/129.8 KB 477.5 MB/s eta 0:00:00
Collecting charset_normalizer<4,>=2


  Downloading charset_normalizer-3.4.3-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl (150 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 150.3/150.3 KB 492.9 MB/s eta 0:00:00
Collecting certifi>=2017.4.17
  Downloading certifi-2025.8.3-py3-none-any.whl (161 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 161.2/161.2 KB 461.5 MB/s eta 0:00:00
Collecting mpmath<1.4,>=1.1.0
  Downloading mpmath-1.3.0-py3-none-any.whl (536 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 536.2/536.2 KB 519.1 MB/s eta 0:00:00


Installing collected packages: simsimd, pytz, mpmath, urllib3, tzdata, typing-extensions, tqdm, threadpoolctl, sympy, stringzilla, six, safetensors, pyyaml, pillow, packaging, nvidia-nvtx-cu12, nvidia-nvjitlink-cu12, nvidia-nccl-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, numpy, networkx, MarkupSafe, joblib, idna, hf-xet, fsspec, filelock, eval-type-backport, charset_normalizer, certifi, annotated-types, typing-inspection, triton, tifffile, scipy, requests, python-dateutil, pydantic-core, opencv-python-headless, nvidia-cusparse-cu12, nvidia-cudnn-cu12, lazy-loader, jinja2, imageio, scikit-learn, scikit-image, pydantic, pandas, nvidia-cusolver-cu12, huggingface_hub, albucore, torch, iterative-stratification, albumentations, torchvision, timm


Successfully installed MarkupSafe-3.0.2 albucore-0.0.33 albumentations-1.4.14 annotated-types-0.7.0 certifi-2025.8.3 charset_normalizer-3.4.3 eval-type-backport-0.2.2 filelock-3.19.1 fsspec-2025.9.0 hf-xet-1.1.10 huggingface_hub-0.35.1 idna-3.10 imageio-2.37.0 iterative-stratification-0.1.9 jinja2-3.1.6 joblib-1.5.2 lazy-loader-0.4 mpmath-1.3.0 networkx-3.5 numpy-1.26.4 nvidia-cublas-cu12-12.1.3.1 nvidia-cuda-cupti-cu12-12.1.105 nvidia-cuda-nvrtc-cu12-12.1.105 nvidia-cuda-runtime-cu12-12.1.105 nvidia-cudnn-cu12-9.1.0.70 nvidia-cufft-cu12-11.0.2.54 nvidia-curand-cu12-10.3.2.106 nvidia-cusolver-cu12-11.4.5.107 nvidia-cusparse-cu12-12.1.0.106 nvidia-nccl-cu12-2.20.5 nvidia-nvjitlink-cu12-12.9.86 nvidia-nvtx-cu12-12.1.105 opencv-python-headless-4.11.0.86 packaging-25.0 pandas-2.3.2 pillow-11.3.0 pydantic-2.11.9 pydantic-core-2.33.2 python-dateutil-2.9.0.post0 pytz-2025.2 pyyaml-6.0.3 requests-2.32.5 safetensors-0.6.2 scikit-image-0.25.2 scikit-learn-1.7.2 scipy-1.16.2 simsimd-6.5.3 six-1.1



torch: 2.4.1+cu121 built CUDA: 12.1
CUDA available: True
GPU: NVIDIA A10-24Q


In [None]:
import os, json, numpy as np, pandas as pd
from collections import Counter
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold

# Reload train to be safe in this cell context
train_df = pd.read_csv('train.csv')

# Single source of truth for the fold count: used both to build the splitter
# and to iterate over folds in the sanity report at the end of this cell.
N_SPLITS = 5

# Parse space-delimited multilabels
# NOTE(review): astype(str) turns a missing value into the token 'nan', which
# would silently become its own class - confirm train.csv has no empty labels.
label_lists = train_df['labels'].astype(str).str.strip().str.split()
all_labels = [lab for labs in label_lists for lab in labs]
label_counts = Counter(all_labels)
classes = sorted(label_counts.keys())  # preserve alpha order; alternative: sort by freq
print('Num classes:', len(classes))
print('Classes:', classes)
print('Top counts:', label_counts.most_common(10))

# Multi-hot encode: y[i, j] == 1 iff row i carries class j
cls2id = {c:i for i,c in enumerate(classes)}
y = np.zeros((len(train_df), len(classes)), dtype=np.uint8)
for i, labs in enumerate(label_lists):
    for lab in labs:
        y[i, cls2id[lab]] = 1

# Save class list for reuse (defines the column order of y for later cells)
with open('classes.json','w') as f:
    json.dump({'classes': classes}, f)
print('Saved classes.json')

# Multilabel Stratified K-fold CV (N_SPLITS folds)
mskf = MultilabelStratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=42)
# -1 marks "not yet assigned"; each row must land in exactly one validation fold.
folds = np.full(len(train_df), -1, dtype=int)
for fold, (_, val_idx) in enumerate(mskf.split(train_df['image'].values, y)):
    folds[val_idx] = fold
assert (folds>=0).all(), 'Fold assignment failed'

train_folds = train_df.copy()
train_folds['fold'] = folds
train_folds.to_csv('train_folds.csv', index=False)
print('Saved train_folds.csv with fold distribution:')
print(train_folds['fold'].value_counts().sort_index())

# Basic sanity: label distribution per fold
for f in range(N_SPLITS):
    idx = (folds==f)
    cnt = y[idx].sum(axis=0)
    print(f'Fold {f}: n={idx.sum()} | positive labels total={int(cnt.sum())}')
print('CV setup complete.')

In [None]:
import os, time, math, json, random, gc
from pathlib import Path
import numpy as np
import pandas as pd
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import f1_score
import timm
from timm.utils import ModelEmaV2
from timm.loss import AsymmetricLossMultiLabel
from torch.optim.lr_scheduler import SequentialLR, LinearLR, CosineAnnealingLR

SEED = 42
# Seed every RNG this notebook draws from: Python, NumPy, PyTorch (CPU and all CUDA devices).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Let cuDNN benchmark convolution algorithms (favours speed over run-to-run determinism).
torch.backends.cudnn.benchmark = True
if hasattr(torch.backends, 'cuda') and hasattr(torch.backends.cuda, 'matmul'):
    torch.backends.cuda.matmul.allow_tf32 = True
    # The reduced-precision-reduction switch is a cuda.matmul setting. Assigning
    # it on torch.backends.cudnn (as before) only created an unused attribute.
    torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
    torch.backends.cudnn.allow_tf32 = True

# Keep a non-empty allocator config from the environment; otherwise opt in to expandable segments.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = os.environ.get('PYTORCH_CUDA_ALLOC_CONF','') or 'expandable_segments:True'
# Turn off OpenCV's own threading - presumably to avoid oversubscription with DataLoader workers.
cv2.setNumThreads(0)

# Config
# Module-level hyperparameters and paths. Only IMG_SIZE, MEAN, STD, INTERP and
# TRAIN_DIR are consumed in the code visible around this cell; the remaining
# values are presumably read by the training loop further down - TODO confirm.
IMG_SIZE = 512  # side length of the square image produced by the transform
BATCH_SIZE = 10
EPOCHS = 15  # per expert, train up to 15 with early stopping
LR = 3e-4  # learning rate
WD = 1e-2  # presumably weight decay - confirm against the optimizer setup
NUM_FOLDS = 5  # full CV
MODEL_NAME = 'tf_efficientnetv2_m.in21k'  # timm model identifier
DROP_PATH = 0.2
DROP_RATE = 0.05
TRAIN_DIR = 'train_images'  # default image directory of PlantDataset
TEST_DIR = 'test_images'
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'  # falls back to CPU when no GPU is visible
DO_TRAIN = True  # train full 5-folds

# Performance/stability toggles
USE_CHANNELS_LAST = True
USE_BF16_AMP = True
USE_GRAD_CKPT = False
USE_EMA = True
EMA_DECAY = 0.9998

# Mixup config (manual for multilabel)
MIXUP_ALPHA = 0.4
MIXUP_PROB = 0.5  # will be turned off for last 2 epochs

# Defaults; will be overridden per model via timm.data.resolve_model_data_config(model)
MEAN = (0.485, 0.456, 0.406)  # per-channel mean, applied after scaling pixels to [0, 1]
STD = (0.229, 0.224, 0.225)  # per-channel std, applied after scaling pixels to [0, 1]
INTERP = cv2.INTER_CUBIC  # interpolation used for resize and rotation in the transform

# Load metadata: fold assignments, raw labels, and the saved class order.
train_folds = pd.read_csv('train_folds.csv')
train_df = pd.read_csv('train.csv')
with open('classes.json') as f:
    classes = json.load(f)['classes']
# C is the number of classes; cls2id maps each class name to its column index.
C = len(classes)
cls2id = dict(zip(classes, range(C)))
# Parse labels to multi-hot
def labels_to_multi_hot(s, class_to_id=None):
    """Encode a whitespace-delimited label string as a float32 multi-hot vector.

    Args:
        s: Label string such as ``"a b"``. It is passed through ``str()``
            first, so non-string values do not raise.
        class_to_id: Optional dict mapping class name -> column index. When
            omitted, the module-level ``cls2id`` (with vector length ``C``)
            is used, which is the original behaviour.

    Returns:
        1-D ``float32`` array with 1.0 at the index of every known label.
        Tokens missing from the mapping are ignored.
    """
    if class_to_id is None:
        class_to_id, width = cls2id, C
    else:
        width = len(class_to_id)
    arr = np.zeros(width, dtype=np.float32)
    for lab in str(s).strip().split():
        idx = class_to_id.get(lab)
        if idx is not None:
            arr[idx] = 1.0
    return arr
# Multi-hot targets for every row of train_df, stacked row-wise into one 2-D
# float32 array (one row per sample, one column per class).
y_all = np.stack(train_df.labels.apply(labels_to_multi_hot).values)

# Simple CV2-based transforms with RRC-like crop and Random Erasing
class SimpleTransform:
    """Callable OpenCV/NumPy image pipeline: train-time augmentation, resize, normalize.

    The input is expected to be an HxWx3 RGB image (the HSV jitter converts
    with ``COLOR_RGB2HSV``); uint8 pixel values are assumed because the image
    is scaled by 1/255 before normalization. The output is a float32
    ``img_size`` x ``img_size`` x 3 array normalized per channel.

    All randomness is drawn from Python's ``random`` module.

    Args:
        train: Apply the random augmentations when True; otherwise only
            resize and normalize.
        img_size: Side length of the square output.
        mean, std: Per-channel statistics applied after scaling to [0, 1].
        interp: OpenCV interpolation flag used for resize and rotation.
        rrc_scale: (min, max) fraction of the source area kept by the crop.
        rrc_ratio: (min, max) aspect ratio (w/h) of the crop.
        erase_p: Probability of applying random erasing.
        erase_area: (min, max) fraction of the image area that is erased.
    """

    def __init__(self, train=True, img_size=448, mean=(0.485,0.456,0.406), std=(0.229,0.224,0.225), interp=cv2.INTER_CUBIC,
                 rrc_scale=(0.8, 1.0), rrc_ratio=(0.75, 1.3333),
                 erase_p=0.2, erase_area=(0.02, 0.2)):
        self.train = train
        self.img_size = img_size
        self.mean = np.array(mean, dtype=np.float32)
        self.std = np.array(std, dtype=np.float32)
        self.interp = interp
        self.rrc_scale = rrc_scale
        self.rrc_ratio = rrc_ratio
        self.erase_p = erase_p
        self.erase_area = erase_area

    def random_resized_crop(self, img):
        """Return a random crop of *img* (a view, not a copy).

        Up to 10 attempts are made to sample a crop whose area fraction and
        aspect ratio fall within ``rrc_scale`` / ``rrc_ratio`` and that fits
        inside the image; otherwise the centered square crop is returned.
        The crop is not resized here.
        """
        h, w = img.shape[:2]
        area = h * w
        for _ in range(10):
            target_area = area * random.uniform(self.rrc_scale[0], self.rrc_scale[1])
            aspect = random.uniform(self.rrc_ratio[0], self.rrc_ratio[1])
            new_w = int(round(math.sqrt(target_area * aspect)))
            new_h = int(round(math.sqrt(target_area / aspect)))
            if 0 < new_w <= w and 0 < new_h <= h:
                x0 = random.randint(0, w - new_w)
                y0 = random.randint(0, h - new_h)
                return img[y0:y0+new_h, x0:x0+new_w]
        # Fallback to center crop
        min_side = min(h, w)
        y0 = (h - min_side) // 2
        x0 = (w - min_side) // 2
        return img[y0:y0+min_side, x0:x0+min_side]

    def random_erasing(self, img):
        """With probability ``erase_p``, zero out one random rectangle of *img* in place.

        *img* is the float32, already-normalized HWC array. Returns the same
        array object (modified or not).
        """
        if random.random() >= self.erase_p:
            return img
        h, w = img.shape[:2]
        area = h * w
        for _ in range(10):
            erase_area = area * random.uniform(self.erase_area[0], self.erase_area[1])
            aspect = random.uniform(0.3, 3.3)
            eh = int(round(math.sqrt(erase_area / aspect)))
            ew = int(round(math.sqrt(erase_area * aspect)))
            if 0 < eh <= h and 0 < ew <= w:
                y0 = random.randint(0, h - eh)
                x0 = random.randint(0, w - ew)
                # The image is already normalized, so filling with 0 paints the
                # patch with the per-channel mean colour.
                img[y0:y0+eh, x0:x0+ew, :] = 0
                return img
        return img

    @staticmethod
    def _jitter_hsv(img):
        """Shift hue by up to +/-5 and saturation/value by up to +/-10 (OpenCV 8-bit HSV units)."""
        hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV).astype(np.int32)
        # Hue is circular in OpenCV's 8-bit HSV (0..179), so wrap instead of
        # clipping; clipping would pile shifted reds up at 0 or 179.
        hsv[..., 0] = (hsv[..., 0] + random.randint(-5, 5)) % 180
        hsv[..., 1] = np.clip(hsv[..., 1] + random.randint(-10, 10), 0, 255)
        hsv[..., 2] = np.clip(hsv[..., 2] + random.randint(-10, 10), 0, 255)
        return cv2.cvtColor(hsv.astype(np.uint8), cv2.COLOR_HSV2RGB)

    def __call__(self, img):
        """Transform one image and return it as a normalized float32 HWC array."""
        if self.train:
            img = self.random_resized_crop(img)
            if random.random() < 0.5:
                img = cv2.flip(img, 1)  # horizontal flip
            if random.random() < 0.2:
                img = cv2.flip(img, 0)  # vertical flip
            if random.random() < 0.3:
                # Small rotation about the image centre; borders are filled by reflection.
                angle = random.uniform(-15, 15)
                M = cv2.getRotationMatrix2D((img.shape[1]/2, img.shape[0]/2), angle, 1.0)
                img = cv2.warpAffine(img, M, (img.shape[1], img.shape[0]), flags=self.interp, borderMode=cv2.BORDER_REFLECT_101)
            if random.random() < 0.2:
                img = self._jitter_hsv(img)
            if random.random() < 0.1:
                k = random.choice([3,5])
                img = cv2.GaussianBlur(img, (k,k), 0)
        # resize, normalize
        img = cv2.resize(img, (self.img_size, self.img_size), interpolation=self.interp)
        img = img.astype(np.float32) / 255.0
        img = (img - self.mean) / self.std
        # random erasing after norm
        if self.train:
            img = self.random_erasing(img)
        return img

def get_transforms(train=True):
    """Build a SimpleTransform configured from the module-level IMG_SIZE, MEAN, STD and INTERP."""
    settings = dict(img_size=IMG_SIZE, mean=MEAN, std=STD, interp=INTERP)
    return SimpleTransform(train=train, **settings)

class PlantDataset(Dataset):
    """Dataset that reads images named in df['image'] from a directory.

    With `labels` given, items are (image_tensor, float_target_tensor);
    without labels, items are (image_tensor, image_file_name).
    """

    def __init__(self, df, labels=None, img_dir=TRAIN_DIR, transform=None):
        # Positional access below relies on a clean 0..n-1 index.
        self.df = df.reset_index(drop=True)
        self.labels = labels
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        image_name = self.df.iloc[idx]['image']
        full_path = os.path.join(self.img_dir, image_name)

        # cv2.imread signals a missing/corrupt file by returning None.
        pixels = cv2.imread(full_path)
        if pixels is None:
            raise FileNotFoundError(full_path)
        pixels = cv2.cvtColor(pixels, cv2.COLOR_BGR2RGB)

        if self.transform:
            pixels = self.transform(pixels)
        # Arrays are channels-last here; the tensor is made channels-first.
        if isinstance(pixels, np.ndarray):
            pixels = torch.from_numpy(pixels.transpose(2,0,1)).float()

        if self.labels is None:
            return pixels, image_name
        return pixels, torch.from_numpy(self.labels[idx]).float()

def _binarize_with_label_rules(y_prob, thr):
    """Threshold probabilities and apply the label post-processing rules.

    `thr` is a scalar or an array broadcastable against `y_prob`.
    Rules: (1) a row with no label above threshold gets its argmax class;
    (2) if a 'healthy' class exists, it is cleared on rows where any other
    class is predicted.
    """
    y_pred = (y_prob >= thr).astype(np.int32)
    empty_rows = np.where(y_pred.sum(axis=1) == 0)[0]
    if len(empty_rows):
        y_pred[empty_rows, y_prob[empty_rows].argmax(axis=1)] = 1
    if 'healthy' in cls2id:
        h = cls2id['healthy']
        disease_idx = [i for i, c in enumerate(classes) if c != 'healthy']
        disease_on = (y_pred[:, disease_idx].sum(axis=1) > 0)
        y_pred[disease_on, h] = 0
    return y_pred


def _flat_micro_f1(y_true, y_pred):
    """Micro-averaged F1 over all (sample, class) cells of a multilabel matrix.

    Pooling TP/FP/FN over every cell is the same as the binary F1 of the
    flattened arrays with 1 as the positive label. `average='micro'` on the
    flattened arrays must NOT be used: it also counts label 0 as a class, so
    the score collapses to plain element-wise accuracy.
    """
    return f1_score(y_true.ravel(), y_pred.ravel(), average='binary')


def micro_f1(y_true, y_prob, thresh=0.3):
    """Micro-F1 of `y_prob` binarized with one global threshold.

    Args:
        y_true: (N, C) array of 0/1 targets.
        y_prob: (N, C) array of predicted probabilities.
        thresh: threshold applied to every class.
    """
    y_pred = _binarize_with_label_rules(y_prob, thresh)
    return _flat_micro_f1(y_true, y_pred)


def micro_f1_vec(y_true, y_prob, thrs):
    """Micro-F1 of `y_prob` binarized with one threshold per class.

    Args:
        y_true: (N, C) array of 0/1 targets.
        y_prob: (N, C) array of predicted probabilities.
        thrs: length-C array of per-class thresholds.
    """
    y_pred = _binarize_with_label_rules(y_prob, thrs[None, :])
    return _flat_micro_f1(y_true, y_pred)

def tune_global_threshold(y_true, y_prob, grid=None):
    """Grid-search a single threshold that maximizes `micro_f1`.

    Args:
        y_true: target matrix passed through to `micro_f1`.
        y_prob: probability matrix passed through to `micro_f1`.
        grid: iterable of candidate thresholds; defaults to 12 evenly spaced
            values in [0.05, 0.6].

    Returns:
        (best_threshold, best_score). If no candidate scores above -1 the
        fallback (0.3, -1) is returned. Ties keep the earliest candidate.
    """
    candidates = np.linspace(0.05, 0.6, 12) if grid is None else grid
    top_thr, top_score = 0.3, -1
    for cand in candidates:
        score = micro_f1(y_true, y_prob, cand)
        if score > top_score:
            top_thr, top_score = cand, score
    return top_thr, top_score

def tune_thresholds_coordinate_descent(y_true, y_prob, base_thr=0.5, grid=None, iters=2):
    """Tune one threshold per class by coordinate descent on `micro_f1_vec`.

    Every class starts at `base_thr`. In each sweep the classes are visited
    in order; for each class all grid values are tried with the other
    thresholds held fixed, and the class threshold is replaced only by a
    value that strictly improves the score. Sweeps stop early once a full
    pass changes nothing.

    Args:
        y_true: target matrix passed through to `micro_f1_vec`.
        y_prob: (N, C) probability matrix.
        base_thr: starting threshold for every class.
        grid: candidate thresholds; defaults to 31 values in [0.05, 0.8].
        iters: maximum number of sweeps over all classes.

    Returns:
        (thresholds, best_score): float32 thresholds clipped to [0.05, 0.80]
        and the best score reached during the search (before clipping).
    """
    candidates = grid if grid is not None else np.linspace(0.05, 0.8, 31)  # ~0.025 step
    n_classes = y_prob.shape[1]
    thrs = np.full(n_classes, base_thr, dtype=np.float32)
    best = micro_f1_vec(y_true, y_prob, thrs)

    for _sweep in range(iters):
        any_change = False
        for cls in range(n_classes):
            winner_thr, winner_f1 = thrs[cls], best
            for cand in candidates:
                trial = thrs.copy()
                trial[cls] = cand
                score = micro_f1_vec(y_true, y_prob, trial)
                if score > winner_f1:
                    winner_f1, winner_thr = score, cand
            if winner_thr != thrs[cls]:
                thrs[cls] = winner_thr
                best = winner_f1
                any_change = True
        if not any_change:
            break

    # Keep thresholds inside a sane band to limit overfitting to the tuning set.
    thrs = np.clip(thrs, 0.05, 0.80).astype(np.float32)
    return thrs, best

def _predict_val_probs(net, loader):
    """Run `net` over a loader of (images, targets) batches without gradients.

    The caller is responsible for putting `net` in eval mode.

    Returns:
        (y_prob, y_true): sigmoid probabilities and targets, concatenated
        over batches in loader order.
    """
    preds, gts = [], []
    with torch.no_grad():
        for imgs, targets in loader:
            if USE_CHANNELS_LAST:
                imgs = imgs.to(DEVICE, non_blocking=True).to(memory_format=torch.channels_last)
            else:
                imgs = imgs.to(DEVICE, non_blocking=True)
            if USE_BF16_AMP and DEVICE=='cuda':
                with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
                    logits = net(imgs)
            else:
                logits = net(imgs)
            # Upcast before the sigmoid so bf16 logits do not lose precision.
            preds.append(torch.sigmoid(logits.float()).cpu().numpy())
            gts.append(targets.cpu().numpy())
    return np.concatenate(preds, axis=0), np.concatenate(gts, axis=0)


def train_one_fold(fold):
    """Train one CV fold and return its out-of-fold predictions.

    Rows of `train_folds` whose 'fold' equals `fold` form the validation set;
    all other rows are used for training. The best epoch (by validation
    micro-F1 at a tuned global threshold) is checkpointed to
    'model_fold{fold}.pt' and reloaded for the final validation pass.

    Side effects: updates the module-level MEAN, STD and INTERP from the
    model's timm data config so that transforms built afterwards match it.

    Args:
        fold: fold id to hold out.

    Returns:
        (y_prob, y_true, thr): validation probabilities from the best
        checkpoint, the matching targets, and the threshold stored with it.
    """
    global MEAN, STD, INTERP
    t0 = time.time()
    print(f'===== Fold {fold} start =====')
    trn_idx = train_folds.index[train_folds['fold'] != fold].values
    val_idx = train_folds.index[train_folds['fold'] == fold].values
    df_trn = train_folds.iloc[trn_idx][['image']].reset_index(drop=True)
    df_val = train_folds.iloc[val_idx][['image']].reset_index(drop=True)
    y_trn = y_all[trn_idx]
    y_val = y_all[val_idx]

    # Model & data config
    model = timm.create_model(MODEL_NAME, pretrained=True, num_classes=C, drop_path_rate=DROP_PATH, drop_rate=DROP_RATE)
    data_cfg = timm.data.resolve_model_data_config(model)
    MEAN, STD = tuple(data_cfg.get('mean', MEAN)), tuple(data_cfg.get('std', STD))
    interp_name = str(data_cfg.get('interpolation', 'bicubic')).lower()
    INTERP = cv2.INTER_CUBIC if 'bicubic' in interp_name else cv2.INTER_LINEAR

    if USE_GRAD_CKPT and hasattr(model, 'set_grad_checkpointing'):
        try:
            model.set_grad_checkpointing(True)
            print('Enabled gradient checkpointing')
        except Exception as e:
            # Best effort: training continues without checkpointing, but say so.
            print(f'Gradient checkpointing not enabled: {e}')
    model.to(DEVICE)
    if USE_CHANNELS_LAST:
        model.to(memory_format=torch.channels_last)

    ema = ModelEmaV2(model, decay=EMA_DECAY, device=DEVICE) if USE_EMA else None

    # Datasets now that MEAN/STD/INTERP are set
    train_ds = PlantDataset(df_trn, y_trn, img_dir=TRAIN_DIR, transform=get_transforms(True))
    val_ds = PlantDataset(df_val, y_val, img_dir=TRAIN_DIR, transform=get_transforms(False))
    nw = min(8, os.cpu_count() or 4)
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=nw, pin_memory=True, drop_last=True, persistent_workers=True, prefetch_factor=2)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE*2, shuffle=False, num_workers=nw, pin_memory=True, persistent_workers=True, prefetch_factor=2)

    # One epoch of linear warmup followed by cosine decay, stepped per iteration.
    opt = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WD, eps=1e-8, betas=(0.9,0.999))
    warmup_steps = max(1, len(train_loader))
    total_steps = max(warmup_steps+1, EPOCHS * len(train_loader))
    sched_warm = LinearLR(opt, start_factor=0.01, total_iters=warmup_steps)
    sched_cos = CosineAnnealingLR(opt, T_max=max(1, total_steps - warmup_steps))
    scheduler = SequentialLR(opt, schedulers=[sched_warm, sched_cos], milestones=[warmup_steps])
    criterion = AsymmetricLossMultiLabel(gamma_neg=4.0, gamma_pos=0.0, clip=0.05, eps=1e-8)

    best_f1 = -1.0
    best_path = f'model_fold{fold}.pt'
    patience = 3
    bad_epochs = 0
    # Validation always uses the EMA weights when EMA is enabled.
    eval_model = ema.module if ema is not None else model

    for epoch in range(EPOCHS):
        model.train()
        t_ep = time.time()
        # decay Mixup prob to 0 in last 2 epochs
        mixup_prob_now = MIXUP_PROB if epoch < EPOCHS - 2 else 0.0
        for it, (imgs, targets) in enumerate(train_loader):
            if USE_CHANNELS_LAST:
                imgs = imgs.to(DEVICE, non_blocking=True).to(memory_format=torch.channels_last)
            else:
                imgs = imgs.to(DEVICE, non_blocking=True)
            targets = targets.to(DEVICE, non_blocking=True)
            # Manual Mixup for multilabel
            if MIXUP_ALPHA > 0 and random.random() < mixup_prob_now:
                lam = float(np.random.beta(MIXUP_ALPHA, MIXUP_ALPHA))
                idx = torch.randperm(imgs.size(0), device=imgs.device)
                imgs = lam * imgs + (1.0 - lam) * imgs[idx]
                targets = lam * targets + (1.0 - lam) * targets[idx]
            opt.zero_grad(set_to_none=True)
            if USE_BF16_AMP and DEVICE=='cuda':
                with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
                    logits = model(imgs)
                    loss = criterion(logits.float(), targets)
            else:
                logits = model(imgs)
                loss = criterion(logits, targets)

            if not torch.isfinite(loss):
                print(f'Non-finite loss detected at iter {it}: {loss.item()} -> skipping step')
                opt.zero_grad(set_to_none=True)
                continue

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            scheduler.step()
            if ema is not None:
                ema.update(model)
            if it % 50 == 0:
                cur_lr = scheduler.get_last_lr()[0] if hasattr(scheduler, 'get_last_lr') else opt.param_groups[0]['lr']
                print(f'Fold {fold} Epoch {epoch} Iter {it}/{len(train_loader)} loss={loss.item():.4f} lr={cur_lr:.6f}')

        model.eval()
        y_prob, y_true = _predict_val_probs(eval_model, val_loader)
        t_opt, f1_opt = tune_global_threshold(y_true, y_prob)
        print(f'Epoch {epoch} val micro-F1={f1_opt:.5f} @thr={t_opt:.3f} | time {time.time()-t_ep:.1f}s')
        improved = f1_opt > best_f1 + 1e-5
        if improved and np.isfinite(f1_opt):
            best_f1 = f1_opt
            # Store the threshold as a plain float: a NumPy scalar in the
            # checkpoint is rejected by torch.load when weights_only=True.
            torch.save({'model': eval_model.state_dict(), 'thr': float(t_opt)}, best_path)
            bad_epochs = 0
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                print('Early stopping due to no improvement')
                break

    # Reload the best checkpoint and produce the final out-of-fold predictions.
    ckpt = torch.load(best_path, map_location=DEVICE)
    eval_model.load_state_dict(ckpt['model'])
    thr = ckpt.get('thr', 0.3)
    eval_model.eval()
    y_prob, y_true = _predict_val_probs(eval_model, val_loader)
    print(f'Fold {fold} done in {time.time()-t0:.1f}s, best_f1={best_f1:.5f}, thr={thr:.3f}')
    return y_prob, y_true, thr

def _tta_view_logits(model, loader, flip_dims):
    """Return float32 logits for every item in `loader` for one TTA view.

    `flip_dims` is a tuple of tensor dims to flip before the forward pass;
    an empty tuple means the unmodified image.
    """
    outs = []
    with torch.no_grad():
        for imgs, _names in loader:
            if flip_dims:
                imgs = imgs.flip(flip_dims)
            if USE_CHANNELS_LAST:
                imgs = imgs.to(DEVICE, non_blocking=True).to(memory_format=torch.channels_last)
            else:
                imgs = imgs.to(DEVICE, non_blocking=True)
            if USE_BF16_AMP and DEVICE=='cuda':
                with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
                    logits = model(imgs)
            else:
                logits = model(imgs)
            outs.append(logits.float().cpu().numpy())
    return np.concatenate(outs, axis=0)


def infer_test(models_paths, tta=2):
    """Predict test-set probabilities by averaging checkpoints and TTA views.

    Logits are averaged first over TTA views, then over models, and only then
    passed through a sigmoid.

    Side effects: updates the module-level MEAN, STD and INTERP from the
    model's timm data config before the test transform is built.

    Args:
        models_paths: checkpoint paths; each file must hold a dict with a
            'model' state dict.
        tta: number of views. 1 = original, 2 adds a horizontal flip, 3 adds
            a vertical flip, 4 adds both flips together. Values below 1
            behave like 1 and values above 4 like 4.

    Returns:
        (image_names, probs): the 'image' column of sample_submission.csv
        and the averaged probabilities, one row per image.

    Raises:
        ValueError: if `models_paths` is empty.
    """
    global MEAN, STD, INTERP
    models_paths = list(models_paths)
    if not models_paths:
        raise ValueError('infer_test needs at least one checkpoint path')

    # Ensure normalization matches model cfg even in fresh kernels
    tmp_model = timm.create_model(MODEL_NAME, pretrained=False, num_classes=C, drop_path_rate=DROP_PATH, drop_rate=DROP_RATE)
    data_cfg = timm.data.resolve_model_data_config(tmp_model)
    del tmp_model  # only needed to resolve the data config
    MEAN, STD = tuple(data_cfg.get('mean', MEAN)), tuple(data_cfg.get('std', STD))
    interp_name = str(data_cfg.get('interpolation', 'bicubic')).lower()
    INTERP = cv2.INTER_CUBIC if 'bicubic' in interp_name else cv2.INTER_LINEAR

    test_df = pd.read_csv('sample_submission.csv')[['image']].copy()
    ds = PlantDataset(test_df, labels=None, img_dir=TEST_DIR, transform=get_transforms(False))
    loader = DataLoader(ds, batch_size=BATCH_SIZE*2, shuffle=False, num_workers=min(8, os.cpu_count() or 4), pin_memory=True, persistent_workers=True, prefetch_factor=2)

    # View k (1-based) is used when tta >= k; the original view is always used.
    all_views = [(), (-1,), (-2,), (-1, -2)]  # original, hflip, vflip, hvflip
    views = [dims for k, dims in enumerate(all_views, start=1) if k == 1 or tta >= k]

    model_level_logits = []  # list of (N_test, C)
    for mp in models_paths:
        model = timm.create_model(MODEL_NAME, pretrained=False, num_classes=C, drop_path_rate=DROP_PATH, drop_rate=DROP_RATE)
        ckpt = torch.load(mp, map_location=DEVICE)
        model.load_state_dict(ckpt['model'])
        model.to(DEVICE)
        if USE_CHANNELS_LAST:
            model.to(memory_format=torch.channels_last)
        model.eval()
        view_logits = [_tta_view_logits(model, loader, dims) for dims in views]
        logits_avg = np.mean(np.stack(view_logits, axis=0), axis=0)  # (N_test, C)
        model_level_logits.append(logits_avg)
    logits = np.mean(np.stack(model_level_logits, axis=0), axis=0)  # (N_test, C)
    probs = 1/(1+np.exp(-logits))
    return test_df['image'].values, probs

# Orchestrate K-fold training or skip to inference.
# DO_TRAIN=True : train every fold, collect out-of-fold (OOF) probabilities, tune thresholds.
# DO_TRAIN=False: reuse saved OOF arrays if both files exist, otherwise fall back to a fixed threshold.
if DO_TRAIN:
    oof_probs = np.zeros((len(train_df), C), dtype=np.float32)
    oof_targets = y_all.copy()
    fold_thresholds = []
    for fold in range(NUM_FOLDS):
        t_fold = time.time()
        y_prob, y_true, thr = train_one_fold(fold)
        # NOTE(review): index labels are used as row positions here, which
        # assumes train_folds has a default 0..n-1 index -- confirm.
        val_idx = train_folds.index[train_folds['fold'] == fold].values
        oof_probs[val_idx] = y_prob
        fold_thresholds.append(thr)
        print(f'Fold {fold} completed in {time.time()-t_fold:.1f}s')
        # Release Python garbage and cached CUDA blocks before the next fold.
        gc.collect(); torch.cuda.empty_cache()

    np.save('oof_probs.npy', oof_probs)
    np.save('oof_targets.npy', oof_targets)
    print('Saved OOF probs/targets')
    # Only rows whose fold id is in range(NUM_FOLDS) received OOF predictions.
    mask = train_folds['fold'].isin(list(range(NUM_FOLDS))).values
    t_best, f1_best = tune_global_threshold(oof_targets[mask], oof_probs[mask])
    # Per-class threshold tuning and save (search starts from the global optimum)
    thrs_vec, f1_best_vec = tune_thresholds_coordinate_descent(oof_targets[mask], oof_probs[mask], base_thr=t_best, grid=np.linspace(0.05,0.8,31), iters=3)
    # Same bounds the tuner already applies; kept as a safety net.
    thrs_vec = np.clip(thrs_vec, 0.05, 0.80).astype(np.float32)
    np.save('thr_per_class.npy', thrs_vec)
    print(f'OOF (folds< {NUM_FOLDS}) micro-F1={f1_best:.5f} @thr={t_best:.3f}; per-class tuned micro-F1={f1_best_vec:.5f} | n={mask.sum()}')
else:
    if os.path.exists('oof_probs.npy') and os.path.exists('oof_targets.npy'):    
        # Re-tune thresholds from previously saved OOF arrays; this overwrites thr_per_class.npy.
        oof_probs = np.load('oof_probs.npy')
        oof_targets = np.load('oof_targets.npy')
        mask = train_folds['fold'].isin(list(range(NUM_FOLDS))).values
        t_best, f1_best = tune_global_threshold(oof_targets[mask], oof_probs[mask])
        thrs_vec, f1_best_vec = tune_thresholds_coordinate_descent(oof_targets[mask], oof_probs[mask], base_thr=t_best, grid=np.linspace(0.05,0.8,31), iters=3)
        thrs_vec = np.clip(thrs_vec, 0.05, 0.80).astype(np.float32)
        np.save('thr_per_class.npy', thrs_vec)
        print(f'Loaded OOF (folds< {NUM_FOLDS}); tuned global thr={t_best:.3f} (micro-F1={f1_best:.5f}); per-class tuned micro-F1={f1_best_vec:.5f} | n={mask.sum()}')
    else:
        # No OOF data available: use one fixed threshold for every class.
        t_best = 0.6
        thrs_vec = np.full(C, t_best, dtype=np.float32)
        print('OOF not found; using default thr=0.6')

# Inference on test using best fold checkpoints.
# Every model_fold{i}.pt for i in range(NUM_FOLDS) is handed to infer_test; 4 TTA views per model.
model_paths = [f'model_fold{i}.pt' for i in range(NUM_FOLDS)]
names, test_probs = infer_test(model_paths, tta=4)

def probs_to_labels_row(p, thr_or_vec):
    """Turn one row of class probabilities into a space-separated label string.

    Args:
        p: 1-D array of per-class probabilities.
        thr_or_vec: a scalar threshold, or an array of per-class thresholds.

    Returns:
        Names of the selected classes joined by single spaces. The argmax
        class is selected when nothing passes the threshold, and 'healthy'
        (if it is a known class) is dropped whenever any other class is selected.
    """
    cutoff = float(thr_or_vec) if np.ndim(thr_or_vec) == 0 else thr_or_vec
    selected = (p >= cutoff)

    # Never emit an empty prediction.
    if not selected.any():
        selected[p.argmax()] = 1

    if 'healthy' in cls2id:
        healthy_pos = cls2id['healthy']
        non_healthy = [i for i, c in enumerate(classes) if c != 'healthy']
        if selected[non_healthy].any():
            selected[healthy_pos] = 0

    chosen = [classes[i] for i in np.flatnonzero(selected)]
    return ' '.join(chosen)

# Prefer per-class thresholds saved on disk; otherwise use the global threshold.
if os.path.exists('thr_per_class.npy'):
    thr_to_use = np.load('thr_per_class.npy')
else:
    thr_to_use = t_best

# One label string per test image, in the same order as `names`.
labels_out = [probs_to_labels_row(row_probs, thr_to_use) for row_probs in test_probs]
sub = pd.DataFrame({'image': names, 'labels': labels_out})
sub.to_csv('submission.csv', index=False)
print('Wrote submission.csv with shape', sub.shape)

===== Fold 0 start =====


Fold 0 Epoch 0 Iter 0/1192 loss=47.3398 lr=0.000003


In [6]:
import sys, subprocess, os
def pip(*args):
    """Echo and run `python -m pip <args>` with the current interpreter.

    Raises subprocess.CalledProcessError if pip exits with a non-zero status.
    """
    print('>', *args, flush=True)
    command = [sys.executable, '-m', 'pip']
    command.extend(args)
    subprocess.run(command, check=True)

# Fix albumentations/albucore mismatch by pinning albumentations<1.4 (no albucore dependency)
# The uninstall uses check=False so a package that is not installed does not abort the cell.
subprocess.run([sys.executable, '-m', 'pip', 'uninstall', '-y', 'albumentations', 'albucore'], check=False)
# 'only-if-needed' avoids upgrading dependencies that already satisfy the requirements.
pip('install', 'albumentations==1.3.1', 'opencv-python-headless', '--upgrade-strategy', 'only-if-needed')
import albumentations as A
# Report where the module was loaded from as a sanity check; falls back to 'unknown' if __file__ is absent.
print('Albumentations loaded from:', getattr(A, '__file__', 'unknown'))

Found existing installation: albumentations 1.3.1
Uninstalling albumentations-1.3.1:
  Successfully uninstalled albumentations-1.3.1
> install albumentations==1.3.1 opencv-python-headless --upgrade-strategy only-if-needed




Collecting albumentations==1.3.1
  Downloading albumentations-1.3.1-py3-none-any.whl (125 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 125.7/125.7 KB 5.9 MB/s eta 0:00:00
Collecting opencv-python-headless
  Downloading opencv_python_headless-4.12.0.88-cp37-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (54.0 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 54.0/54.0 MB 199.3 MB/s eta 0:00:00
Collecting PyYAML
  Downloading pyyaml-6.0.3-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl (806 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 806.6/806.6 KB 504.1 MB/s eta 0:00:00


Collecting numpy>=1.11.1
  Downloading numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.3 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 18.3/18.3 MB 401.0 MB/s eta 0:00:00
Collecting qudida>=0.0.4
  Downloading qudida-0.0.4-py3-none-any.whl (3.5 kB)


Collecting scipy>=1.1.0
  Downloading scipy-1.16.2-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (35.9 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 35.9/35.9 MB 138.2 MB/s eta 0:00:00
Collecting scikit-image>=0.16.1
  Downloading scikit_image-0.25.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (14.8 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 14.8/14.8 MB 209.7 MB/s eta 0:00:00
Collecting opencv-python-headless
  Downloading opencv_python_headless-4.11.0.86-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (50.0 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 50.0/50.0 MB 176.7 MB/s eta 0:00:00
Collecting typing-extensions
  Downloading typing_extensions-4.15.0-py3-none-any.whl (44 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 44.6/44.6 KB 300.4 MB/s eta 0:00:00


Collecting scikit-learn>=0.19.1
  Downloading scikit_learn-1.7.2-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (9.7 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 9.7/9.7 MB 190.4 MB/s eta 0:00:00
Collecting tifffile>=2022.8.12
  Downloading tifffile-2025.9.20-py3-none-any.whl (230 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 230.1/230.1 KB 389.1 MB/s eta 0:00:00
Collecting networkx>=3.0
  Downloading networkx-3.5-py3-none-any.whl (2.0 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.0/2.0 MB 432.7 MB/s eta 0:00:00


Collecting lazy-loader>=0.4
  Downloading lazy_loader-0.4-py3-none-any.whl (12 kB)
Collecting imageio!=2.35.0,>=2.33
  Downloading imageio-2.37.0-py3-none-any.whl (315 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 315.8/315.8 KB 524.2 MB/s eta 0:00:00


Collecting pillow>=10.1
  Downloading pillow-11.3.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (6.6 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 6.6/6.6 MB 334.6 MB/s eta 0:00:00
Collecting packaging>=21
  Downloading packaging-25.0-py3-none-any.whl (66 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 66.5/66.5 KB 364.6 MB/s eta 0:00:00


Collecting threadpoolctl>=3.1.0
  Downloading threadpoolctl-3.6.0-py3-none-any.whl (18 kB)
Collecting joblib>=1.2.0
  Downloading joblib-1.5.2-py3-none-any.whl (308 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 308.4/308.4 KB 530.1 MB/s eta 0:00:00


Successfully installed PyYAML-6.0.3 albumentations-1.3.1 imageio-2.37.0 joblib-1.5.2 lazy-loader-0.4 networkx-3.5 numpy-1.26.4 opencv-python-headless-4.11.0.86 packaging-25.0 pillow-11.3.0 qudida-0.0.4 scikit-image-0.25.2 scikit-learn-1.7.2 scipy-1.16.2 threadpoolctl-3.6.0 tifffile-2025.9.20 typing-extensions-4.15.0




Albumentations loaded from: None


In [None]:
# Quick near-duplicate scan via pHash (run separately from training).
import os, math, time, itertools
import numpy as np
import pandas as pd
import cv2

def phash64(img_bgr):
    """Compute a 64-bit DCT-based perceptual hash of a BGR image.

    The image is converted to grayscale, shrunk to 32x32 and transformed with
    a DCT. The top-left 8x8 coefficients (DC term zeroed) are compared with
    their median, and the 64 resulting bits are packed row-major, first bit
    most significant.

    Returns:
        np.uint64 hash; np.uint64(0) if any step raises.
    """
    try:
        gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
        small = cv2.resize(gray, (32, 32), interpolation=cv2.INTER_AREA)
        coeffs = cv2.dct(np.float32(small))
        low_freq = coeffs[:8, :8].copy()
        low_freq[0,0] = 0.0  # remove DC
        bits = (low_freq > np.median(low_freq)).astype(np.uint8).reshape(-1)
        # Bit k of the flattened block lands at position 63-k of the integer.
        value = sum(int(bit) << (63 - pos) for pos, bit in enumerate(bits))
        return np.uint64(value)
    except Exception:
        # Best effort: an unhashable image maps to the all-zero hash.
        return np.uint64(0)

def hamming64(a, b):
    """Return the number of differing bits between two integer hashes.

    Both operands are converted to Python ints before the XOR, so any mix of
    NumPy integer scalars and plain ints is accepted (XOR between a signed
    and an unsigned 64-bit NumPy scalar is not defined by NumPy).
    """
    return bin(int(a) ^ int(b)).count('1')

def run_phash_scan(images_dir='train_images', max_bucket_size=200, prefix_bits=16, ham_thresh=5, sample=None):
    """Scan a directory for near-duplicate images using 64-bit perceptual hashes.

    Images are bucketed by the top `prefix_bits` bits of their hash and only
    pairs inside the same bucket are compared, so near-duplicates whose
    hashes differ in those leading bits are not found.

    Writes 'near_duplicate_pairs.csv' (image_a, image_b, hamming). If
    'train_folds.csv' exists, also writes 'near_duplicate_pairs_with_folds.csv'
    and prints the share of pairs whose two images sit in different folds.

    Args:
        images_dir: directory whose entries are read as images.
        max_bucket_size: buckets larger than this are skipped to bound runtime.
        prefix_bits: number of leading hash bits used as the bucket key.
        ham_thresh: maximum Hamming distance for a pair to be reported.
        sample: if set, only the first `sample` file names (sorted) are scanned.
    """
    t0 = time.time()
    candidates = sorted(os.listdir(images_dir))
    if sample is not None and sample < len(candidates):
        candidates = candidates[:sample]

    # Files that cannot be decoded are left out entirely: giving them a
    # placeholder hash would report them as duplicates of each other.
    imgs, hash_list, unreadable = [], [], 0
    for i, name in enumerate(candidates):
        im = cv2.imread(os.path.join(images_dir, name))
        if im is None:
            unreadable += 1
        else:
            imgs.append(name)
            hash_list.append(phash64(im))
        if i % 1000 == 0:
            print(f'pHash {i}/{len(candidates)} processed')
    hashes = np.array(hash_list, dtype=np.uint64)
    if unreadable:
        print(f'Skipped {unreadable} unreadable files')
    print('Computed hashes in', f'{time.time()-t0:.1f}s')

    # Exact duplicates (identical pHash)
    df = pd.DataFrame({'image': imgs, 'phash': hashes})
    dup_groups = df.groupby('phash').filter(lambda x: len(x) > 1)
    if len(dup_groups) > 0:
        print('Exact-duplicate pHash groups:', dup_groups.groupby('phash').size().shape[0])
    else:
        print('No exact-duplicate pHash groups found')

    # Approximate duplicates by prefix bucketing to limit pairwise work
    prefix_shift = 64 - prefix_bits
    prefixes = (hashes >> np.uint64(prefix_shift)).astype(np.uint64)
    buckets = {}
    for idx, pref in enumerate(prefixes):
        buckets.setdefault(int(pref), []).append(idx)
    print('Buckets:', len(buckets))

    pairs = []  # (img_a, img_b, ham)
    checked = 0
    for pref, idxs in buckets.items():
        if len(idxs) <= 1:
            continue
        if len(idxs) > max_bucket_size:
            # skip giant buckets to keep runtime bounded
            continue
        for i, j in itertools.combinations(idxs, 2):
            ham = hamming64(hashes[i], hashes[j])
            if ham <= ham_thresh:
                pairs.append((imgs[i], imgs[j], ham))
        checked += 1
        if checked % 200 == 0:
            print(f'Checked {checked}/{len(buckets)} buckets; pairs so far={len(pairs)}')

    dup_df = pd.DataFrame(pairs, columns=['image_a','image_b','hamming'])
    dup_df.to_csv('near_duplicate_pairs.csv', index=False)
    print('Saved near_duplicate_pairs.csv with', len(dup_df), 'pairs; total time', f'{time.time()-t0:.1f}s')

    # If folds exist, summarize cross-fold duplicates
    if os.path.exists('train_folds.csv'):
        folds = pd.read_csv('train_folds.csv')[['image','fold']]
        m = dup_df.merge(folds.rename(columns={'image':'image_a'}), on='image_a', how='left')
        m = m.merge(folds.rename(columns={'image':'image_b','fold':'fold_b'}), on='image_b', how='left')
        m = m.rename(columns={'fold':'fold_a'})
        m.to_csv('near_duplicate_pairs_with_folds.csv', index=False)
        if len(m):
            # Images absent from train_folds.csv get NaN folds after the left
            # merges; NaN != NaN is True, so such pairs must be excluded or
            # they would all be counted as cross-fold.
            known = m.dropna(subset=['fold_a', 'fold_b'])
            if len(known) < len(m):
                print(f'Ignoring {len(m) - len(known)} pairs with an image missing from train_folds.csv')
            if len(known):
                cross = (known['fold_a'] != known['fold_b']).mean()
                print(f'Cross-fold duplicate rate: {cross:.3f} over {len(known)} pairs')
            else:
                print('No near-duplicate pairs with known folds to summarize')
        else:
            print('No near-duplicate pairs to summarize')

# Usage hint only: this cell defines the scan but does not call run_phash_scan itself.
print('To run: run_phash_scan(images_dir="train_images", sample=None)')