In [1]:
# Check GPU availability
import subprocess
import sys
result = subprocess.run(['nvidia-smi'], capture_output=True, text=True)
print(result.stdout if result.returncode == 0 else 'GPU not available')

Sun Sep 28 06:47:03 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.144.06             Driver Version: 550.144.06     CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A10-24Q                 On  |   00000002:00:00.0 Off |                    0 |
| N/A   N/A    P0             N/A /  N/A  |     182MiB /  24512MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [2]:
import subprocess
import sys
import os
import shutil
from pathlib import Path

def pip(*args):
    print('>', *args, flush=True)
    subprocess.run([sys.executable, '-m', 'pip', *args], check=True)

# 0) Hard reset any prior torch stacks
for pkg in ('torch', 'torchvision', 'torchaudio'):
    subprocess.run([sys.executable, '-m', 'pip', 'uninstall', '-y', pkg], check=False)

# Clean stray site dirs
for d in (
    '/app/.pip-target/torch',
    '/app/.pip-target/torch-2.8.0.dist-info',
    '/app/.pip-target/torch-2.4.1.dist-info',
    '/app/.pip-target/torchvision',
    '/app/.pip-target/torchvision-0.23.0.dist-info',
    '/app/.pip-target/torchvision-0.19.1.dist-info',
    '/app/.pip-target/torchaudio',
    '/app/.pip-target/torchaudio-2.8.0.dist-info',
    '/app/.pip-target/torchaudio-2.4.1.dist-info',
    '/app/.pip-target/torchgen',
    '/app/.pip-target/functorch',
):
    if os.path.exists(d):
        print('Removing', d)
        shutil.rmtree(d, ignore_errors=True)

# 1) Install the EXACT cu121 torch stack FIRST
pip('install',
    '--index-url', 'https://download.pytorch.org/whl/cu121',
    '--extra-index-url', 'https://pypi.org/simple',
    'torch==2.4.1', 'torchvision==0.19.1', 'torchaudio==2.4.1')

# 2) Create a constraints file
Path('constraints.txt').write_text(
    'torch==2.4.1\n'
    'torchvision==0.19.1\n'
    'torchaudio==2.4.1\n'
)

# 3) Install NON-torch deps
pip('install', '-c', 'constraints.txt',
    'transformers==4.44.2', 'accelerate==0.34.2',
    'datasets==2.21.0', 'evaluate==0.4.2',
    'sentencepiece', 'scikit-learn',
    '--upgrade-strategy', 'only-if-needed')

# 4) Sanity check
import torch
print('torch:', torch.__version__, 'built CUDA:', getattr(torch.version, 'cuda', None))
print('CUDA available:', torch.cuda.is_available())
assert str(getattr(torch.version, 'cuda', '')).startswith('12.1'), f'Wrong CUDA build: {torch.version.cuda}'
assert torch.cuda.is_available(), 'CUDA not available'
print('GPU:', torch.cuda.get_device_name(0))





> install --index-url https://download.pytorch.org/whl/cu121 --extra-index-url https://pypi.org/simple torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1




Looking in indexes: https://download.pytorch.org/whl/cu121, https://pypi.org/simple


Collecting torch==2.4.1
  Downloading https://download.pytorch.org/whl/cu121/torch-2.4.1%2Bcu121-cp311-cp311-linux_x86_64.whl (799.0 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 799.0/799.0 MB 537.3 MB/s eta 0:00:00


Collecting torchvision==0.19.1
  Downloading https://download.pytorch.org/whl/cu121/torchvision-0.19.1%2Bcu121-cp311-cp311-linux_x86_64.whl (7.1 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 7.1/7.1 MB 339.3 MB/s eta 0:00:00


Collecting torchaudio==2.4.1
  Downloading https://download.pytorch.org/whl/cu121/torchaudio-2.4.1%2Bcu121-cp311-cp311-linux_x86_64.whl (3.4 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 3.4/3.4 MB 377.0 MB/s eta 0:00:00


Collecting nvidia-cusparse-cu12==12.1.0.106
  Downloading nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl (196.0 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 196.0/196.0 MB 200.2 MB/s eta 0:00:00


Collecting nvidia-nccl-cu12==2.20.5
  Downloading nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl (176.2 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 176.2/176.2 MB 242.6 MB/s eta 0:00:00


Collecting sympy
  Downloading sympy-1.14.0-py3-none-any.whl (6.3 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 6.3/6.3 MB 142.4 MB/s eta 0:00:00
Collecting filelock
  Downloading filelock-3.19.1-py3-none-any.whl (15 kB)


Collecting triton==3.0.0
  Downloading triton-3.0.0-1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (209.4 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 209.4/209.4 MB 391.4 MB/s eta 0:00:00


Collecting nvidia-cuda-runtime-cu12==12.1.105
  Downloading nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 823.6/823.6 KB 470.6 MB/s eta 0:00:00


Collecting nvidia-cuda-cupti-cu12==12.1.105
  Downloading nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 14.1/14.1 MB 242.7 MB/s eta 0:00:00


Collecting typing-extensions>=4.8.0
  Downloading typing_extensions-4.15.0-py3-none-any.whl (44 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 44.6/44.6 KB 417.3 MB/s eta 0:00:00


Collecting nvidia-cuda-nvrtc-cu12==12.1.105
  Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 23.7/23.7 MB 261.8 MB/s eta 0:00:00


Collecting nvidia-cudnn-cu12==9.1.0.70
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl (664.8 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 664.8/664.8 MB 256.6 MB/s eta 0:00:00


Collecting nvidia-nvtx-cu12==12.1.105
  Downloading nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (99 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 99.1/99.1 KB 456.6 MB/s eta 0:00:00


Collecting networkx
  Downloading networkx-3.5-py3-none-any.whl (2.0 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.0/2.0 MB 344.5 MB/s eta 0:00:00


Collecting nvidia-cufft-cu12==11.0.2.54
  Downloading nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl (121.6 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 121.6/121.6 MB 162.0 MB/s eta 0:00:00


Collecting fsspec
  Downloading fsspec-2025.9.0-py3-none-any.whl (199 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 199.3/199.3 KB 499.0 MB/s eta 0:00:00
Collecting nvidia-cublas-cu12==12.1.3.1
  Downloading nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 410.6/410.6 MB 229.2 MB/s eta 0:00:00


Collecting nvidia-curand-cu12==10.3.2.106
  Downloading nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl (56.5 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 56.5/56.5 MB 225.9 MB/s eta 0:00:00


Collecting nvidia-cusolver-cu12==11.4.5.107
  Downloading nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl (124.2 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 124.2/124.2 MB 207.7 MB/s eta 0:00:00


Collecting jinja2
  Downloading jinja2-3.1.6-py3-none-any.whl (134 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 134.9/134.9 KB 437.5 MB/s eta 0:00:00


Collecting numpy
  Downloading numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.3 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 18.3/18.3 MB 238.1 MB/s eta 0:00:00


Collecting pillow!=8.3.*,>=5.3.0
  Downloading pillow-11.3.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (6.6 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 6.6/6.6 MB 234.5 MB/s eta 0:00:00


Collecting nvidia-nvjitlink-cu12
  Downloading nvidia_nvjitlink_cu12-12.9.86-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl (39.7 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 39.7/39.7 MB 255.2 MB/s eta 0:00:00


Collecting MarkupSafe>=2.0
  Downloading markupsafe-3.0.3-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl (22 kB)


Collecting mpmath<1.4,>=1.1.0
  Downloading mpmath-1.3.0-py3-none-any.whl (536 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 536.2/536.2 KB 490.1 MB/s eta 0:00:00


Installing collected packages: mpmath, typing-extensions, sympy, pillow, nvidia-nvtx-cu12, nvidia-nvjitlink-cu12, nvidia-nccl-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, numpy, networkx, MarkupSafe, fsspec, filelock, triton, nvidia-cusparse-cu12, nvidia-cudnn-cu12, jinja2, nvidia-cusolver-cu12, torch, torchvision, torchaudio


Successfully installed MarkupSafe-3.0.3 filelock-3.19.1 fsspec-2025.9.0 jinja2-3.1.6 mpmath-1.3.0 networkx-3.5 numpy-1.26.4 nvidia-cublas-cu12-12.1.3.1 nvidia-cuda-cupti-cu12-12.1.105 nvidia-cuda-nvrtc-cu12-12.1.105 nvidia-cuda-runtime-cu12-12.1.105 nvidia-cudnn-cu12-9.1.0.70 nvidia-cufft-cu12-11.0.2.54 nvidia-curand-cu12-10.3.2.106 nvidia-cusolver-cu12-11.4.5.107 nvidia-cusparse-cu12-12.1.0.106 nvidia-nccl-cu12-2.20.5 nvidia-nvjitlink-cu12-12.9.86 nvidia-nvtx-cu12-12.1.105 pillow-11.3.0 sympy-1.14.0 torch-2.4.1+cu121 torchaudio-2.4.1+cu121 torchvision-0.19.1+cu121 triton-3.0.0 typing-extensions-4.15.0


> install -c constraints.txt transformers==4.44.2 accelerate==0.34.2 datasets==2.21.0 evaluate==0.4.2 sentencepiece scikit-learn --upgrade-strategy only-if-needed


Collecting transformers==4.44.2
  Downloading transformers-4.44.2-py3-none-any.whl (9.5 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 9.5/9.5 MB 141.4 MB/s eta 0:00:00
Collecting accelerate==0.34.2
  Downloading accelerate-0.34.2-py3-none-any.whl (324 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 324.4/324.4 KB 506.8 MB/s eta 0:00:00
Collecting datasets==2.21.0


  Downloading datasets-2.21.0-py3-none-any.whl (527 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 527.3/527.3 KB 492.0 MB/s eta 0:00:00
Collecting evaluate==0.4.2
  Downloading evaluate-0.4.2-py3-none-any.whl (84 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 84.1/84.1 KB 423.7 MB/s eta 0:00:00
Collecting sentencepiece
  Downloading sentencepiece-0.2.1-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (1.4 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.4/1.4 MB 271.6 MB/s eta 0:00:00


Collecting scikit-learn
  Downloading scikit_learn-1.7.2-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (9.7 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 9.7/9.7 MB 329.1 MB/s eta 0:00:00


Collecting pyyaml>=5.1
  Downloading pyyaml-6.0.3-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl (806 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 806.6/806.6 KB 523.4 MB/s eta 0:00:00
Collecting requests
  Downloading requests-2.32.5-py3-none-any.whl (64 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 64.7/64.7 KB 461.9 MB/s eta 0:00:00


Collecting safetensors>=0.4.1
  Downloading safetensors-0.6.2-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (485 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 485.8/485.8 KB 543.0 MB/s eta 0:00:00
Collecting packaging>=20.0
  Downloading packaging-25.0-py3-none-any.whl (66 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 66.5/66.5 KB 442.0 MB/s eta 0:00:00
Collecting tqdm>=4.27
  Downloading tqdm-4.67.1-py3-none-any.whl (78 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 78.5/78.5 KB 473.9 MB/s eta 0:00:00
Collecting filelock
  Downloading filelock-3.19.1-py3-none-any.whl (15 kB)


Collecting regex!=2019.12.17
  Downloading regex-2025.9.18-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl (798 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 799.0/799.0 KB 304.8 MB/s eta 0:00:00


Collecting numpy>=1.17
  Downloading numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.3 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 18.3/18.3 MB 238.0 MB/s eta 0:00:00


Collecting tokenizers<0.20,>=0.19
  Downloading tokenizers-0.19.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.6 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 3.6/3.6 MB 251.0 MB/s eta 0:00:00
Collecting huggingface-hub<1.0,>=0.23.2
  Downloading huggingface_hub-0.35.1-py3-none-any.whl (563 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 563.3/563.3 KB 527.6 MB/s eta 0:00:00


Collecting torch>=1.10.0
  Downloading torch-2.4.1-cp311-cp311-manylinux1_x86_64.whl (797.1 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 797.1/797.1 MB 115.3 MB/s eta 0:00:00


Collecting psutil
  Downloading psutil-7.1.0-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (291 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 291.2/291.2 KB 554.4 MB/s eta 0:00:00


Collecting aiohttp
  Downloading aiohttp-3.12.15-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.7 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.7/1.7 MB 363.0 MB/s eta 0:00:00
Collecting fsspec[http]<=2024.6.1,>=2023.1.0
  Downloading fsspec-2024.6.1-py3-none-any.whl (177 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 177.6/177.6 KB 477.5 MB/s eta 0:00:00


Collecting pandas
  Downloading pandas-2.3.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (12.4 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 12.4/12.4 MB 308.0 MB/s eta 0:00:00


Collecting xxhash
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 194.8/194.8 KB 468.1 MB/s eta 0:00:00
Collecting pyarrow>=15.0.0
  Downloading pyarrow-21.0.0-cp311-cp311-manylinux_2_28_x86_64.whl (42.8 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 42.8/42.8 MB 304.5 MB/s eta 0:00:00
Collecting dill<0.3.9,>=0.3.0
  Downloading dill-0.3.8-py3-none-any.whl (116 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 116.3/116.3 KB 498.2 MB/s eta 0:00:00
Collecting multiprocess
  Downloading multiprocess-0.70.18-py311-none-any.whl (144 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 144.5/144.5 KB 462.1 MB/s eta 0:00:00


Collecting joblib>=1.2.0
  Downloading joblib-1.5.2-py3-none-any.whl (308 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 308.4/308.4 KB 527.1 MB/s eta 0:00:00
Collecting threadpoolctl>=3.1.0
  Downloading threadpoolctl-3.6.0-py3-none-any.whl (18 kB)


Collecting scipy>=1.8.0
  Downloading scipy-1.16.2-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (35.9 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 35.9/35.9 MB 233.3 MB/s eta 0:00:00


Collecting aiohappyeyeballs>=2.5.0
  Downloading aiohappyeyeballs-2.6.1-py3-none-any.whl (15 kB)
Collecting propcache>=0.2.0
  Downloading propcache-0.3.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (213 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 213.5/213.5 KB 486.2 MB/s eta 0:00:00
Collecting aiosignal>=1.4.0
  Downloading aiosignal-1.4.0-py3-none-any.whl (7.5 kB)
Collecting attrs>=17.3.0
  Downloading attrs-25.3.0-py3-none-any.whl (63 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 63.8/63.8 KB 393.9 MB/s eta 0:00:00


Collecting frozenlist>=1.1.1
  Downloading frozenlist-1.7.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (235 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 235.3/235.3 KB 504.7 MB/s eta 0:00:00


Collecting yarl<2.0,>=1.17.0
  Downloading yarl-1.20.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (348 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 349.0/349.0 KB 512.6 MB/s eta 0:00:00


Collecting multidict<7.0,>=4.5
  Downloading multidict-6.6.4-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl (246 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 246.7/246.7 KB 529.8 MB/s eta 0:00:00
Collecting hf-xet<2.0.0,>=1.1.3
  Downloading hf_xet-1.1.10-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.2 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 3.2/3.2 MB 123.7 MB/s eta 0:00:00
Collecting typing-extensions>=3.7.4.3
  Downloading typing_extensions-4.15.0-py3-none-any.whl (44 kB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 44.6/44.6 KB 401.8 MB/s eta 0:00:00
Collecting idna<4,>=2.5
  Downloading idna-3.10-py3-none-any.whl (70 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 70.4/70.4 KB 445.1 MB/s eta 0:00:00
Collecting charset_normalizer<4,>=2
  Downloading charset_normalizer-3.4.3-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl (150 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 150.3/150.3 KB 459.9 MB/s eta 0:00:00
Collecting certifi>=2017.4.17
  Downloading certifi-2025.8.3-py3-none-any.whl (161 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 161.2/161.2 KB 502.0 MB/s eta 0:00:00


Collecting urllib3<3,>=1.21.1
  Downloading urllib3-2.5.0-py3-none-any.whl (129 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 129.8/129.8 KB 482.7 MB/s eta 0:00:00
Collecting nvidia-cusparse-cu12==12.1.0.106
  Downloading nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl (196.0 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 196.0/196.0 MB 251.2 MB/s eta 0:00:00
Collecting nvidia-cuda-cupti-cu12==12.1.105
  Downloading nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 14.1/14.1 MB 237.9 MB/s eta 0:00:00
Collecting nvidia-curand-cu12==10.3.2.106
  Downloading nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl (56.5 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 56.5/56.5 MB 244.6 MB/s eta 0:00:00
Collecting nvidia-cuda-nvrtc-cu12==12.1.105
  Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 23.7/23.7 MB 263.6 MB/s eta 0:00:00
Collecting nvidia-nccl-cu12==2.20.5


  Downloading nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl (176.2 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 176.2/176.2 MB 234.1 MB/s eta 0:00:00
Collecting nvidia-cufft-cu12==11.0.2.54
  Downloading nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl (121.6 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 121.6/121.6 MB 253.3 MB/s eta 0:00:00
Collecting nvidia-nvtx-cu12==12.1.105
  Downloading nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (99 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 99.1/99.1 KB 473.0 MB/s eta 0:00:00
Collecting sympy
  Downloading sympy-1.14.0-py3-none-any.whl (6.3 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 6.3/6.3 MB 463.4 MB/s eta 0:00:00
Collecting jinja2
  Downloading jinja2-3.1.6-py3-none-any.whl (134 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 134.9/134.9 KB 510.0 MB/s eta 0:00:00
Collecting nvidia-cusolver-cu12==11.4.5.107


  Downloading nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl (124.2 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 124.2/124.2 MB 245.0 MB/s eta 0:00:00
Collecting nvidia-cudnn-cu12==9.1.0.70
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl (664.8 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 664.8/664.8 MB 218.3 MB/s eta 0:00:00


Collecting triton==3.0.0
  Downloading triton-3.0.0-1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (209.4 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 209.4/209.4 MB 238.6 MB/s eta 0:00:00
Collecting nvidia-cuda-runtime-cu12==12.1.105
  Downloading nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 823.6/823.6 KB 298.3 MB/s eta 0:00:00
Collecting nvidia-cublas-cu12==12.1.3.1
  Downloading nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 410.6/410.6 MB 243.3 MB/s eta 0:00:00


Collecting networkx
  Downloading networkx-3.5-py3-none-any.whl (2.0 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.0/2.0 MB 549.6 MB/s eta 0:00:00
Collecting nvidia-nvjitlink-cu12
  Downloading nvidia_nvjitlink_cu12-12.9.86-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl (39.7 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 39.7/39.7 MB 241.5 MB/s eta 0:00:00
Collecting multiprocess
  Downloading multiprocess-0.70.17-py311-none-any.whl (144 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 144.3/144.3 KB 432.4 MB/s eta 0:00:00
  Downloading multiprocess-0.70.16-py311-none-any.whl (143 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 143.5/143.5 KB 439.8 MB/s eta 0:00:00
Collecting tzdata>=2022.7
  Downloading tzdata-2025.2-py2.py3-none-any.whl (347 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 347.8/347.8 KB 532.0 MB/s eta 0:00:00


Collecting pytz>=2020.1
  Downloading pytz-2025.2-py2.py3-none-any.whl (509 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 509.2/509.2 KB 533.8 MB/s eta 0:00:00
Collecting python-dateutil>=2.8.2
  Downloading python_dateutil-2.9.0.post0-py2.py3-none-any.whl (229 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 229.9/229.9 KB 508.6 MB/s eta 0:00:00


Collecting six>=1.5
  Downloading six-1.17.0-py2.py3-none-any.whl (11 kB)
Collecting MarkupSafe>=2.0
  Downloading markupsafe-3.0.3-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl (22 kB)
Collecting mpmath<1.4,>=1.1.0
  Downloading mpmath-1.3.0-py3-none-any.whl (536 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 536.2/536.2 KB 534.6 MB/s eta 0:00:00


Installing collected packages: pytz, mpmath, xxhash, urllib3, tzdata, typing-extensions, tqdm, threadpoolctl, sympy, six, sentencepiece, safetensors, regex, pyyaml, pyarrow, psutil, propcache, packaging, nvidia-nvtx-cu12, nvidia-nvjitlink-cu12, nvidia-nccl-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, numpy, networkx, multidict, MarkupSafe, joblib, idna, hf-xet, fsspec, frozenlist, filelock, dill, charset_normalizer, certifi, attrs, aiohappyeyeballs, yarl, triton, scipy, requests, python-dateutil, nvidia-cusparse-cu12, nvidia-cudnn-cu12, multiprocess, jinja2, aiosignal, scikit-learn, pandas, nvidia-cusolver-cu12, huggingface-hub, aiohttp, torch, tokenizers, transformers, datasets, accelerate, evaluate


Successfully installed MarkupSafe-3.0.3 accelerate-0.34.2 aiohappyeyeballs-2.6.1 aiohttp-3.12.15 aiosignal-1.4.0 attrs-25.3.0 certifi-2025.8.3 charset_normalizer-3.4.3 datasets-2.21.0 dill-0.3.8 evaluate-0.4.2 filelock-3.19.1 frozenlist-1.7.0 fsspec-2024.6.1 hf-xet-1.1.10 huggingface-hub-0.35.1 idna-3.10 jinja2-3.1.6 joblib-1.5.2 mpmath-1.3.0 multidict-6.6.4 multiprocess-0.70.16 networkx-3.5 numpy-1.26.4 nvidia-cublas-cu12-12.1.3.1 nvidia-cuda-cupti-cu12-12.1.105 nvidia-cuda-nvrtc-cu12-12.1.105 nvidia-cuda-runtime-cu12-12.1.105 nvidia-cudnn-cu12-9.1.0.70 nvidia-cufft-cu12-11.0.2.54 nvidia-curand-cu12-10.3.2.106 nvidia-cusolver-cu12-11.4.5.107 nvidia-cusparse-cu12-12.1.0.106 nvidia-nccl-cu12-2.20.5 nvidia-nvjitlink-cu12-12.9.86 nvidia-nvtx-cu12-12.1.105 packaging-25.0 pandas-2.3.2 propcache-0.3.2 psutil-7.1.0 pyarrow-21.0.0 python-dateutil-2.9.0.post0 pytz-2025.2 pyyaml-6.0.3 regex-2025.9.18 requests-2.32.5 safetensors-0.6.2 scikit-learn-1.7.2 scipy-1.16.2 sentencepiece-0.2.1 six-1.17.0









torch: 2.4.1+cu121 built CUDA: 12.1
CUDA available: True
GPU: NVIDIA A10-24Q


In [5]:
import json

# Load and inspect train metadata structure
with open('train_metadata.json', 'r') as f:
    train_data = json.load(f)

print(f'train_data type: {type(train_data)}')
print(f'train_data keys: {list(train_data.keys())}')

if 'images' in train_data:
    print(f'Number of images: {len(train_data["images"])}')
    print('Sample image:', json.dumps(train_data['images'][0], indent=2))

if 'annotations' in train_data:
    print(f'Number of annotations: {len(train_data["annotations"])}')
    print('Sample annotation:', json.dumps(train_data['annotations'][0], indent=2))

if 'categories' in train_data:
    print(f'Number of categories: {len(train_data["categories"])}')
    print('Sample category:', json.dumps(train_data['categories'][0], indent=2))

train_data type: <class 'dict'>
train_data keys: ['annotations', 'categories', 'distances', 'genera', 'images', 'institutions', 'license']
Number of images: 665720
Sample image: {
  "file_name": "000/00/00000__001.jpg",
  "image_id": "00000__001",
  "license": 0
}
Number of annotations: 665720
Sample annotation: {
  "category_id": 0,
  "genus_id": 1,
  "image_id": "00000__001",
  "institution_id": 52
}
Number of categories: 15501
Sample category: {
  "authors": "(Douglas ex Loudon) J.Forbes",
  "category_id": 0,
  "family": "Pinaceae",
  "genus": "Abies",
  "scientificName": "Abies amabilis (Douglas ex Loudon) J.Forbes",
  "species": "amabilis"
}


In [7]:
import json
import pandas as pd
import numpy as np
from collections import Counter

# Load train metadata
with open('train_metadata.json', 'r') as f:
    train_data = json.load(f)

images_df = pd.DataFrame(train_data['images'])
annotations_df = pd.DataFrame(train_data['annotations'])
train_df = images_df.merge(annotations_df, on='image_id')

print(f'Train samples: {len(train_df)}')
print(f'Train columns: {train_df.columns.tolist()}')
print(f'Unique classes: {train_df["category_id"].nunique()}')

# Class distribution
class_counts = Counter(train_df['category_id'])
print(f'Min images per class: {min(class_counts.values())}')
print(f'Max images per class: {max(class_counts.values())}')
print(f'Mean images per class: {np.mean(list(class_counts.values())):.2f}')

# Top 10 classes
top_classes = class_counts.most_common(10)
print('Top 10 classes:', top_classes)

# Load test metadata
with open('test_metadata.json', 'r') as f:
    test_data = json.load(f)

test_df = pd.DataFrame(test_data)
print(f'Test samples: {len(test_df)}')
print(f'Test columns: {test_df.columns.tolist()}')

Train samples: 665720
Train columns: ['file_name', 'image_id', 'license', 'category_id', 'genus_id', 'institution_id']
Unique classes: 15501
Min images per class: 4
Max images per class: 64
Mean images per class: 42.95
Top 10 classes: [(19, 64), (123, 64), (207, 64), (230, 64), (231, 64), (297, 64), (364, 64), (572, 64), (573, 64), (578, 64)]


Test samples: 174052
Test columns: ['file_name', 'image_id', 'license']


## Initial Plan for Herbarium 2022 Competition

### Dataset Summary
- Train: 665,720 images, 15,501 classes (long-tail: 4-64 images/class, mean ~43)
- Test: 174,052 images
- Metric: Macro F1-score (emphasizes balanced performance across classes)
- Goal: Gold medal ≥0.846, Silver ≥0.754, Bronze ≥0.597

### Key Challenges
- Extreme class imbalance: Many rare classes with few images
- Large dataset: Need efficient training with GPU
- Image classification: Botanical images may require specific augmentations (e.g., rotations for leaves/flowers)

### Strategy Outline
1. **Cross-Validation**: Use StratifiedKFold(n_splits=5) to ensure balanced class representation in folds. Save fold indices to disk for reproducibility.
2. **Baseline Model**: EfficientNet-B3 from timm library (pre-trained on ImageNet). Use PyTorch for GPU acceleration.
3. **Loss Function**: Cross-entropy with class weights (inverse frequency) to handle imbalance. Consider focal loss later for better rare class handling.
4. **Data Loading & Augmentation**:
   - Use PyTorch DataLoader with num_workers=8.
   - Augmentations via Albumentations: RandomResizedCrop, HorizontalFlip, VerticalFlip, Rotate, ColorJitter, HueSaturationValue.
   - Image size: 300x300 (EfficientNet-B3 input).
   - Subsample for quick baselines: Stratified sample ~50k images (ensuring min 1 per class).
5. **Training Setup**:
   - Optimizer: AdamW, lr=1e-3, weight_decay=1e-4.
   - Scheduler: CosineAnnealingLR.
   - Early stopping on val macro F1.
   - Batch size: 32-64 depending on GPU memory.
   - Train on full dataset after baseline validation.
6. **Evaluation**: Compute macro F1 on OOF predictions. Use sklearn.metrics.f1_score(average='macro').
7. **Advanced Steps (post-baseline)**:
   - Ensemble: Blend EfficientNet variants (B0, B3, B5) + ResNet50.
   - Test-Time Augmentation (TTA).
   - Pseudo-labeling on high-confidence test predictions.
   - Feature engineering: Use genus_id or institution_id as additional features if applicable.

### Next Immediate Steps
- Install additional packages: timm, albumentations, scikit-learn (already partially installed).
- Create '01_baseline_modeling.ipynb' for data loading, CV setup, and training a quick subsample baseline.
- Request expert review on this plan, focusing on medal-winning strategies for imbalanced botanical image classification (e.g., specific augmentations, loss functions, ensemble techniques used in top solutions).

In [8]:
# Install additional packages for modeling
import subprocess
import sys
from pathlib import Path

def pip(*args):
    print('>', *args, flush=True)
    subprocess.run([sys.executable, '-m', 'pip', *args], check=True)

# Install timm, albumentations, opencv-python-headless
pip('install', '-c', 'constraints.txt',
    'timm', 'albumentations', 'opencv-python-headless',
    '--upgrade-strategy', 'only-if-needed')

> install -c constraints.txt timm albumentations opencv-python-headless --upgrade-strategy only-if-needed


Collecting timm
  Downloading timm-1.0.20-py3-none-any.whl (2.5 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.5/2.5 MB 56.9 MB/s eta 0:00:00
Collecting albumentations
  Downloading albumentations-2.0.8-py3-none-any.whl (369 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 369.4/369.4 KB 356.4 MB/s eta 0:00:00


Collecting opencv-python-headless
  Downloading opencv_python_headless-4.12.0.88-cp37-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (54.0 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 54.0/54.0 MB 309.6 MB/s eta 0:00:00


Collecting torch
  Downloading torch-2.4.1-cp311-cp311-manylinux1_x86_64.whl (797.1 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 797.1/797.1 MB 420.3 MB/s eta 0:00:00


Collecting safetensors
  Downloading safetensors-0.6.2-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (485 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 485.8/485.8 KB 514.2 MB/s eta 0:00:00
Collecting huggingface_hub
  Downloading huggingface_hub-0.35.1-py3-none-any.whl (563 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 563.3/563.3 KB 509.6 MB/s eta 0:00:00
Collecting torchvision
  Downloading torchvision-0.19.1-cp311-cp311-manylinux1_x86_64.whl (7.0 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 7.0/7.0 MB 401.5 MB/s eta 0:00:00


Collecting pyyaml
  Downloading pyyaml-6.0.3-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl (806 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 806.6/806.6 KB 416.5 MB/s eta 0:00:00


Collecting scipy>=1.10.0
  Downloading scipy-1.16.2-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (35.9 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 35.9/35.9 MB 282.9 MB/s eta 0:00:00


Collecting pydantic>=2.9.2
  Downloading pydantic-2.11.9-py3-none-any.whl (444 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 444.9/444.9 KB 503.0 MB/s eta 0:00:00
Collecting albucore==0.0.24
  Downloading albucore-0.0.24-py3-none-any.whl (15 kB)


Collecting numpy>=1.24.4
  Downloading numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.3 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 18.3/18.3 MB 295.3 MB/s eta 0:00:00


Collecting stringzilla>=3.10.4
  Downloading stringzilla-4.0.14-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl (496 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 496.5/496.5 KB 372.2 MB/s eta 0:00:00


Collecting simsimd>=5.9.2
  Downloading simsimd-6.5.3-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (1.1 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.1/1.1 MB 222.8 MB/s eta 0:00:00
Collecting opencv-python-headless
  Downloading opencv_python_headless-4.11.0.86-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (50.0 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 50.0/50.0 MB 174.8 MB/s eta 0:00:00
Collecting annotated-types>=0.6.0
  Downloading annotated_types-0.7.0-py3-none-any.whl (13 kB)
Collecting typing-inspection>=0.4.0
  Downloading typing_inspection-0.4.1-py3-none-any.whl (14 kB)


Collecting pydantic-core==2.33.2
  Downloading pydantic_core-2.33.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.0 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.0/2.0 MB 499.8 MB/s eta 0:00:00
Collecting typing-extensions>=4.12.2
  Downloading typing_extensions-4.15.0-py3-none-any.whl (44 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 44.6/44.6 KB 391.5 MB/s eta 0:00:00


Collecting hf-xet<2.0.0,>=1.1.3
  Downloading hf_xet-1.1.10-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.2 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 3.2/3.2 MB 490.0 MB/s eta 0:00:00
Collecting packaging>=20.9
  Downloading packaging-25.0-py3-none-any.whl (66 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 66.5/66.5 KB 400.1 MB/s eta 0:00:00
Collecting tqdm>=4.42.1
  Downloading tqdm-4.67.1-py3-none-any.whl (78 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 78.5/78.5 KB 426.6 MB/s eta 0:00:00
Collecting fsspec>=2023.5.0
  Downloading fsspec-2025.9.0-py3-none-any.whl (199 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 199.3/199.3 KB 452.0 MB/s eta 0:00:00
Collecting filelock
  Downloading filelock-3.19.1-py3-none-any.whl (15 kB)
Collecting requests
  Downloading requests-2.32.5-py3-none-any.whl (64 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 64.7/64.7 KB 428.0 MB/s eta 0:00:00


Collecting sympy
  Downloading sympy-1.14.0-py3-none-any.whl (6.3 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 6.3/6.3 MB 273.8 MB/s eta 0:00:00
Collecting nvidia-cusparse-cu12==12.1.0.106
  Downloading nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl (196.0 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 196.0/196.0 MB 235.7 MB/s eta 0:00:00
Collecting nvidia-curand-cu12==10.3.2.106
  Downloading nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl (56.5 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 56.5/56.5 MB 250.6 MB/s eta 0:00:00
Collecting nvidia-cublas-cu12==12.1.3.1
  Downloading nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 410.6/410.6 MB 183.1 MB/s eta 0:00:00


Collecting nvidia-cuda-runtime-cu12==12.1.105
  Downloading nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 823.6/823.6 KB 502.1 MB/s eta 0:00:00
Collecting nvidia-cufft-cu12==11.0.2.54
  Downloading nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl (121.6 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 121.6/121.6 MB 296.5 MB/s eta 0:00:00
Collecting nvidia-cudnn-cu12==9.1.0.70
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl (664.8 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 664.8/664.8 MB 364.6 MB/s eta 0:00:00


Collecting nvidia-cusolver-cu12==11.4.5.107
  Downloading nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl (124.2 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 124.2/124.2 MB 255.1 MB/s eta 0:00:00
Collecting triton==3.0.0
  Downloading triton-3.0.0-1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (209.4 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 209.4/209.4 MB 244.5 MB/s eta 0:00:00
Collecting nvidia-nvtx-cu12==12.1.105
  Downloading nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (99 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 99.1/99.1 KB 430.0 MB/s eta 0:00:00


Collecting networkx
  Downloading networkx-3.5-py3-none-any.whl (2.0 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.0/2.0 MB 528.3 MB/s eta 0:00:00
Collecting nvidia-cuda-nvrtc-cu12==12.1.105
  Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 23.7/23.7 MB 252.2 MB/s eta 0:00:00
Collecting jinja2
  Downloading jinja2-3.1.6-py3-none-any.whl (134 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 134.9/134.9 KB 433.6 MB/s eta 0:00:00
Collecting nvidia-cuda-cupti-cu12==12.1.105
  Downloading nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 14.1/14.1 MB 252.9 MB/s eta 0:00:00
Collecting nvidia-nccl-cu12==2.20.5
  Downloading nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl (176.2 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 176.2/176.2 MB 247.5 MB/s eta 0:00:00
Collecting nvidia-nvjitlink-cu12
  Downloading nvidia_nvjitlink_cu12-12.9.86-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl (39.7 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 39.7/39.7 MB 201.7 MB/s eta 0:00:00


Collecting pillow!=8.3.*,>=5.3.0
  Downloading pillow-11.3.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (6.6 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 6.6/6.6 MB 265.2 MB/s eta 0:00:00


Collecting MarkupSafe>=2.0
  Downloading markupsafe-3.0.3-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl (22 kB)
Collecting urllib3<3,>=1.21.1
  Downloading urllib3-2.5.0-py3-none-any.whl (129 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 129.8/129.8 KB 457.0 MB/s eta 0:00:00
Collecting certifi>=2017.4.17
  Downloading certifi-2025.8.3-py3-none-any.whl (161 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 161.2/161.2 KB 449.2 MB/s eta 0:00:00


Collecting charset_normalizer<4,>=2
  Downloading charset_normalizer-3.4.3-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl (150 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 150.3/150.3 KB 455.6 MB/s eta 0:00:00
Collecting idna<4,>=2.5
  Downloading idna-3.10-py3-none-any.whl (70 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 70.4/70.4 KB 372.2 MB/s eta 0:00:00
Collecting mpmath<1.4,>=1.1.0
  Downloading mpmath-1.3.0-py3-none-any.whl (536 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 536.2/536.2 KB 367.0 MB/s eta 0:00:00


Installing collected packages: simsimd, mpmath, urllib3, typing-extensions, tqdm, sympy, stringzilla, safetensors, pyyaml, pillow, packaging, nvidia-nvtx-cu12, nvidia-nvjitlink-cu12, nvidia-nccl-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, numpy, networkx, MarkupSafe, idna, hf-xet, fsspec, filelock, charset_normalizer, certifi, annotated-types, typing-inspection, triton, scipy, requests, pydantic-core, opencv-python-headless, nvidia-cusparse-cu12, nvidia-cudnn-cu12, jinja2, pydantic, nvidia-cusolver-cu12, huggingface_hub, albucore, torch, albumentations, torchvision, timm


ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
datasets 2.21.0 requires fsspec[http]<=2024.6.1,>=2023.1.0, but you have fsspec 2025.9.0 which is incompatible.


Successfully installed MarkupSafe-3.0.3 albucore-0.0.24 albumentations-2.0.8 annotated-types-0.7.0 certifi-2025.8.3 charset_normalizer-3.4.3 filelock-3.19.1 fsspec-2025.9.0 hf-xet-1.1.10 huggingface_hub-0.35.1 idna-3.10 jinja2-3.1.6 mpmath-1.3.0 networkx-3.5 numpy-1.26.4 nvidia-cublas-cu12-12.1.3.1 nvidia-cuda-cupti-cu12-12.1.105 nvidia-cuda-nvrtc-cu12-12.1.105 nvidia-cuda-runtime-cu12-12.1.105 nvidia-cudnn-cu12-9.1.0.70 nvidia-cufft-cu12-11.0.2.54 nvidia-curand-cu12-10.3.2.106 nvidia-cusolver-cu12-11.4.5.107 nvidia-cusparse-cu12-12.1.0.106 nvidia-nccl-cu12-2.20.5 nvidia-nvjitlink-cu12-12.9.86 nvidia-nvtx-cu12-12.1.105 opencv-python-headless-4.11.0.86 packaging-25.0 pillow-11.3.0 pydantic-2.11.9 pydantic-core-2.33.2 pyyaml-6.0.3 requests-2.32.5 safetensors-0.6.2 scipy-1.16.2 simsimd-6.5.3 stringzilla-4.0.14 sympy-1.14.0 timm-1.0.20 torch-2.4.1 torchvision-0.19.1 tqdm-4.67.1 triton-3.0.0 typing-extensions-4.15.0 typing-inspection-0.4.1 urllib3-2.5.0


In [9]:
# Reinstall CUDA-enabled PyTorch (previous install used CPU version)
import subprocess
import sys

def pip(*args):
    print('>', *args, flush=True)
    subprocess.run([sys.executable, '-m', 'pip', *args], check=True)

# Uninstall current torch packages
pip('uninstall', '-y', 'torch', 'torchvision', 'torchaudio')

# Install CUDA 12.1 versions
pip('install',
    '--index-url', 'https://download.pytorch.org/whl/cu121',
    '--extra-index-url', 'https://pypi.org/simple',
    'torch==2.4.1+cu121', 'torchvision==0.19.1+cu121', 'torchaudio==2.4.1+cu121')

# Update constraints.txt with +cu121
from pathlib import Path
Path('constraints.txt').write_text(
    'torch==2.4.1+cu121\n'
    'torchvision==0.19.1+cu121\n'
    'torchaudio==2.4.1+cu121\n'
)

# Sanity check
import torch
print('torch:', torch.__version__, 'built CUDA:', getattr(torch.version, 'cuda', None))
print('CUDA available:', torch.cuda.is_available())
assert str(getattr(torch.version, 'cuda', '')).startswith('12.1'), f'Wrong CUDA build: {torch.version.cuda}'
assert torch.cuda.is_available(), 'CUDA not available'
print('GPU:', torch.cuda.get_device_name(0))

> uninstall -y torch torchvision torchaudio


Found existing installation: torch 2.4.1+cu121


Uninstalling torch-2.4.1+cu121:
  Successfully uninstalled torch-2.4.1+cu121
Found existing installation: torchvision 0.19.1+cu121
Uninstalling torchvision-0.19.1+cu121:
  Successfully uninstalled torchvision-0.19.1+cu121


Found existing installation: torchaudio 2.4.1+cu121
Uninstalling torchaudio-2.4.1+cu121:
  Successfully uninstalled torchaudio-2.4.1+cu121
> install --index-url https://download.pytorch.org/whl/cu121 --extra-index-url https://pypi.org/simple torch==2.4.1+cu121 torchvision==0.19.1+cu121 torchaudio==2.4.1+cu121


Looking in indexes: https://download.pytorch.org/whl/cu121, https://pypi.org/simple


Collecting torch==2.4.1+cu121
  Downloading https://download.pytorch.org/whl/cu121/torch-2.4.1%2Bcu121-cp311-cp311-linux_x86_64.whl (799.0 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 799.0/799.0 MB 490.1 MB/s eta 0:00:00


Collecting torchvision==0.19.1+cu121
  Downloading https://download.pytorch.org/whl/cu121/torchvision-0.19.1%2Bcu121-cp311-cp311-linux_x86_64.whl (7.1 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 7.1/7.1 MB 460.0 MB/s eta 0:00:00


Collecting torchaudio==2.4.1+cu121
  Downloading https://download.pytorch.org/whl/cu121/torchaudio-2.4.1%2Bcu121-cp311-cp311-linux_x86_64.whl (3.4 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 3.4/3.4 MB 491.1 MB/s eta 0:00:00


Collecting typing-extensions>=4.8.0
  Downloading typing_extensions-4.15.0-py3-none-any.whl (44 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 44.6/44.6 KB 3.0 MB/s eta 0:00:00


Collecting nvidia-cuda-runtime-cu12==12.1.105
  Downloading nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 823.6/823.6 KB 53.3 MB/s eta 0:00:00
Collecting nvidia-cusolver-cu12==11.4.5.107
  Downloading nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl (124.2 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 124.2/124.2 MB 539.1 MB/s eta 0:00:00


Collecting nvidia-cuda-cupti-cu12==12.1.105
  Downloading nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 14.1/14.1 MB 566.5 MB/s eta 0:00:00


Collecting nvidia-cudnn-cu12==9.1.0.70
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl (664.8 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 664.8/664.8 MB 532.3 MB/s eta 0:00:00


Collecting triton==3.0.0
  Downloading triton-3.0.0-1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (209.4 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 209.4/209.4 MB 540.1 MB/s eta 0:00:00


Collecting nvidia-cusparse-cu12==12.1.0.106
  Downloading nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl (196.0 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 196.0/196.0 MB 562.7 MB/s eta 0:00:00


Collecting nvidia-curand-cu12==10.3.2.106
  Downloading nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl (56.5 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 56.5/56.5 MB 383.9 MB/s eta 0:00:00


Collecting nvidia-nvtx-cu12==12.1.105
  Downloading nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (99 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 99.1/99.1 KB 444.0 MB/s eta 0:00:00
Collecting networkx
  Downloading networkx-3.5-py3-none-any.whl (2.0 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.0/2.0 MB 506.1 MB/s eta 0:00:00


Collecting fsspec
  Downloading fsspec-2025.9.0-py3-none-any.whl (199 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 199.3/199.3 KB 476.2 MB/s eta 0:00:00


Collecting nvidia-cublas-cu12==12.1.3.1
  Downloading nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 410.6/410.6 MB 520.7 MB/s eta 0:00:00


Collecting nvidia-cuda-nvrtc-cu12==12.1.105
  Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 23.7/23.7 MB 532.7 MB/s eta 0:00:00
Collecting filelock
  Downloading filelock-3.19.1-py3-none-any.whl (15 kB)


Collecting nvidia-nccl-cu12==2.20.5
  Downloading nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl (176.2 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 176.2/176.2 MB 540.2 MB/s eta 0:00:00


Collecting sympy
  Downloading sympy-1.14.0-py3-none-any.whl (6.3 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 6.3/6.3 MB 538.0 MB/s eta 0:00:00


Collecting jinja2
  Downloading jinja2-3.1.6-py3-none-any.whl (134 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 134.9/134.9 KB 452.1 MB/s eta 0:00:00


Collecting nvidia-cufft-cu12==11.0.2.54
  Downloading nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl (121.6 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 121.6/121.6 MB 547.6 MB/s eta 0:00:00


Collecting numpy
  Downloading numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.3 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 18.3/18.3 MB 524.1 MB/s eta 0:00:00


Collecting pillow!=8.3.*,>=5.3.0
  Downloading pillow-11.3.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (6.6 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 6.6/6.6 MB 507.5 MB/s eta 0:00:00
Collecting nvidia-nvjitlink-cu12
  Downloading nvidia_nvjitlink_cu12-12.9.86-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl (39.7 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 39.7/39.7 MB 533.2 MB/s eta 0:00:00


Collecting MarkupSafe>=2.0
  Downloading markupsafe-3.0.3-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl (22 kB)


Collecting mpmath<1.4,>=1.1.0
  Downloading mpmath-1.3.0-py3-none-any.whl (536 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 536.2/536.2 KB 516.9 MB/s eta 0:00:00


Installing collected packages: mpmath, typing-extensions, sympy, pillow, nvidia-nvtx-cu12, nvidia-nvjitlink-cu12, nvidia-nccl-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, numpy, networkx, MarkupSafe, fsspec, filelock, triton, nvidia-cusparse-cu12, nvidia-cudnn-cu12, jinja2, nvidia-cusolver-cu12, torch, torchvision, torchaudio


ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
datasets 2.21.0 requires fsspec[http]<=2024.6.1,>=2023.1.0, but you have fsspec 2025.9.0 which is incompatible.


Successfully installed MarkupSafe-3.0.3 filelock-3.19.1 fsspec-2025.9.0 jinja2-3.1.6 mpmath-1.3.0 networkx-3.5 numpy-1.26.4 nvidia-cublas-cu12-12.1.3.1 nvidia-cuda-cupti-cu12-12.1.105 nvidia-cuda-nvrtc-cu12-12.1.105 nvidia-cuda-runtime-cu12-12.1.105 nvidia-cudnn-cu12-9.1.0.70 nvidia-cufft-cu12-11.0.2.54 nvidia-curand-cu12-10.3.2.106 nvidia-cusolver-cu12-11.4.5.107 nvidia-cusparse-cu12-12.1.0.106 nvidia-nccl-cu12-2.20.5 nvidia-nvjitlink-cu12-12.9.86 nvidia-nvtx-cu12-12.1.105 pillow-11.3.0 sympy-1.14.0 torch-2.4.1+cu121 torchaudio-2.4.1+cu121 torchvision-0.19.1+cu121 triton-3.0.0 typing-extensions-4.15.0




torch: 2.4.1+cu121 built CUDA: 12.1
CUDA available: True
GPU: NVIDIA A10-24Q
