In [9]:
import subprocess
import sys
import os
import shutil
from pathlib import Path

def pip(*args):
    print('>', *args, flush=True)
    subprocess.run([sys.executable, '-m', 'pip', *args], check=True)

# Set PIP_TARGET to writable directory
pip_target = '/app/.pip-target'
os.environ['PIP_TARGET'] = pip_target
if os.path.exists(pip_target):
    print('Removing existing', pip_target)
    shutil.rmtree(pip_target, ignore_errors=True)

# 0) Hard reset any prior torch stacks
for pkg in ('torch', 'torchvision', 'torchaudio'):
    subprocess.run([sys.executable, '-m', 'pip', 'uninstall', '-y', pkg], check=False)

# Clean stray site dirs
for d in (
    f'{pip_target}/torch',
    f'{pip_target}/torch-2.8.0.dist-info',
    f'{pip_target}/torch-2.4.1.dist-info',
    f'{pip_target}/torchvision',
    f'{pip_target}/torchvision-0.23.0.dist-info',
    f'{pip_target}/torchvision-0.19.1.dist-info',
    f'{pip_target}/torchaudio',
    f'{pip_target}/torchaudio-2.8.0.dist-info',
    f'{pip_target}/torchaudio-2.4.1.dist-info',
    f'{pip_target}/torchgen',
    f'{pip_target}/functorch',
):
    if os.path.exists(d):
        print('Removing', d)
        shutil.rmtree(d, ignore_errors=True)

# 1) Install the EXACT cu121 torch stack FIRST with --no-deps to avoid system dir installs
pip('install',
    '--index-url', 'https://download.pytorch.org/whl/cu121',
    '--extra-index-url', 'https://pypi.org/simple',
    '--force-reinstall', '--no-deps',
    'torch==2.4.1', 'torchvision==0.19.1', 'torchaudio==2.4.1')

# 2) Create a constraints file
Path('constraints.txt').write_text(
    'torch==2.4.1\n'
    'torchvision==0.19.1\n'
    'torchaudio==2.4.1\n'
)

# 3) Install NON-torch deps
pip('install', '-c', 'constraints.txt',
    'transformers==4.44.2', 'accelerate==0.34.2',
    'datasets==2.21.0', 'evaluate==0.4.2',
    'sentencepiece', 'scikit-learn',
    '--upgrade-strategy', 'only-if-needed')

# 4) Sanity gate - add pip_target to sys.path
sys.path.insert(0, pip_target)
import torch
print('torch:', torch.__version__, 'built CUDA:', getattr(torch.version, 'cuda', None))
print('CUDA available:', torch.cuda.is_available())
assert str(getattr(torch.version,'cuda','')).startswith('12.1'), f'Wrong CUDA build: {torch.version.cuda}'
assert torch.cuda.is_available(), 'CUDA not available'
print('GPU:', torch.cuda.get_device_name(0))

# Install additional packages with PIP_TARGET
pip('install', 'rank_bm25')
pip('install', 'langdetect')
pip('install', 'indic-nlp-library', 'pyarrow')

# Downgrade fsspec
pip('install', '-c', 'constraints.txt', 'fsspec[http]<=2024.6.1,>=2023.1.0', '--upgrade')

# Verify additional imports
try:
    from rank_bm25 import BM25Okapi
    print('BM25 available')
except ImportError:
    print('BM25 not available')
try:
    from langdetect import detect
    print('langdetect available')
except ImportError:
    print('langdetect not available')
print('Environment setup complete')

Removing existing /app/.pip-target






> install --index-url https://download.pytorch.org/whl/cu121 --extra-index-url https://pypi.org/simple --force-reinstall --no-deps torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1




Looking in indexes: https://download.pytorch.org/whl/cu121, https://pypi.org/simple


Collecting torch==2.4.1
  Downloading https://download.pytorch.org/whl/cu121/torch-2.4.1%2Bcu121-cp311-cp311-linux_x86_64.whl (799.0 MB)


Collecting torchvision==0.19.1
  Downloading https://download.pytorch.org/whl/cu121/torchvision-0.19.1%2Bcu121-cp311-cp311-linux_x86_64.whl (7.1 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 7.1/7.1 MB 413.1 MB/s eta 0:00:00


Collecting torchaudio==2.4.1
  Downloading https://download.pytorch.org/whl/cu121/torchaudio-2.4.1%2Bcu121-cp311-cp311-linux_x86_64.whl (3.4 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 3.4/3.4 MB 488.9 MB/s eta 0:00:00
Installing collected packages: torchaudio, torchvision, torch


Successfully installed torch-2.4.1+cu121 torchaudio-2.4.1+cu121 torchvision-0.19.1+cu121


> install -c constraints.txt transformers==4.44.2 accelerate==0.34.2 datasets==2.21.0 evaluate==0.4.2 sentencepiece scikit-learn --upgrade-strategy only-if-needed


Collecting transformers==4.44.2
  Downloading transformers-4.44.2-py3-none-any.whl (9.5 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 9.5/9.5 MB 154.4 MB/s eta 0:00:00
Collecting accelerate==0.34.2
  Downloading accelerate-0.34.2-py3-none-any.whl (324 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 324.4/324.4 KB 511.3 MB/s eta 0:00:00
Collecting datasets==2.21.0
  Downloading datasets-2.21.0-py3-none-any.whl (527 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 527.3/527.3 KB 498.5 MB/s eta 0:00:00


Collecting evaluate==0.4.2
  Downloading evaluate-0.4.2-py3-none-any.whl (84 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 84.1/84.1 KB 360.9 MB/s eta 0:00:00
Collecting sentencepiece
  Downloading sentencepiece-0.2.1-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (1.4 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.4/1.4 MB 515.8 MB/s eta 0:00:00


Collecting scikit-learn
  Downloading scikit_learn-1.7.2-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (9.7 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 9.7/9.7 MB 263.0 MB/s eta 0:00:00


Collecting tokenizers<0.20,>=0.19
  Downloading tokenizers-0.19.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.6 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 3.6/3.6 MB 349.9 MB/s eta 0:00:00
Collecting tqdm>=4.27
  Downloading tqdm-4.67.1-py3-none-any.whl (78 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 78.5/78.5 KB 419.9 MB/s eta 0:00:00
Collecting huggingface-hub<1.0,>=0.23.2
  Downloading huggingface_hub-0.35.1-py3-none-any.whl (563 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 563.3/563.3 KB 500.2 MB/s eta 0:00:00


Collecting regex!=2019.12.17
  Downloading regex-2025.9.18-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl (798 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 799.0/799.0 KB 516.5 MB/s eta 0:00:00
Collecting packaging>=20.0
  Downloading packaging-25.0-py3-none-any.whl (66 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 66.5/66.5 KB 432.5 MB/s eta 0:00:00
Collecting requests
  Downloading requests-2.32.5-py3-none-any.whl (64 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 64.7/64.7 KB 380.6 MB/s eta 0:00:00


Collecting safetensors>=0.4.1
  Downloading safetensors-0.6.2-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (485 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 485.8/485.8 KB 517.7 MB/s eta 0:00:00
Collecting filelock
  Downloading filelock-3.19.1-py3-none-any.whl (15 kB)
Collecting pyyaml>=5.1
  Downloading pyyaml-6.0.3-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl (806 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 806.6/806.6 KB 511.4 MB/s eta 0:00:00


Collecting numpy>=1.17
  Downloading numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.3 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 18.3/18.3 MB 496.2 MB/s eta 0:00:00
Collecting torch>=1.10.0
  Downloading torch-2.4.1-cp311-cp311-manylinux1_x86_64.whl (797.1 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 797.1/797.1 MB 298.8 MB/s eta 0:00:00


Collecting psutil
  Downloading psutil-7.1.0-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (291 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 291.2/291.2 KB 499.7 MB/s eta 0:00:00
Collecting fsspec[http]<=2024.6.1,>=2023.1.0
  Downloading fsspec-2024.6.1-py3-none-any.whl (177 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 177.6/177.6 KB 450.8 MB/s eta 0:00:00


Collecting xxhash
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 194.8/194.8 KB 454.1 MB/s eta 0:00:00
Collecting pyarrow>=15.0.0
  Downloading pyarrow-21.0.0-cp311-cp311-manylinux_2_28_x86_64.whl (42.8 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 42.8/42.8 MB 239.8 MB/s eta 0:00:00


Collecting aiohttp
  Downloading aiohttp-3.12.15-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.7 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.7/1.7 MB 267.2 MB/s eta 0:00:00
Collecting pandas
  Downloading pandas-2.3.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (12.4 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 12.4/12.4 MB 375.3 MB/s eta 0:00:00
Collecting dill<0.3.9,>=0.3.0
  Downloading dill-0.3.8-py3-none-any.whl (116 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 116.3/116.3 KB 400.1 MB/s eta 0:00:00
Collecting multiprocess
  Downloading multiprocess-0.70.18-py311-none-any.whl (144 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 144.5/144.5 KB 456.0 MB/s eta 0:00:00
Collecting threadpoolctl>=3.1.0


  Downloading threadpoolctl-3.6.0-py3-none-any.whl (18 kB)
Collecting scipy>=1.8.0
  Downloading scipy-1.16.2-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (35.9 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 35.9/35.9 MB 284.6 MB/s eta 0:00:00
Collecting joblib>=1.2.0
  Downloading joblib-1.5.2-py3-none-any.whl (308 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 308.4/308.4 KB 496.6 MB/s eta 0:00:00


Collecting yarl<2.0,>=1.17.0
  Downloading yarl-1.20.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (348 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 349.0/349.0 KB 507.7 MB/s eta 0:00:00
Collecting aiohappyeyeballs>=2.5.0
  Downloading aiohappyeyeballs-2.6.1-py3-none-any.whl (15 kB)
Collecting frozenlist>=1.1.1
  Downloading frozenlist-1.7.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (235 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 235.3/235.3 KB 400.5 MB/s eta 0:00:00


Collecting multidict<7.0,>=4.5
  Downloading multidict-6.6.4-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl (246 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 246.7/246.7 KB 477.7 MB/s eta 0:00:00
Collecting aiosignal>=1.4.0
  Downloading aiosignal-1.4.0-py3-none-any.whl (7.5 kB)
Collecting propcache>=0.2.0
  Downloading propcache-0.3.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (213 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 213.5/213.5 KB 486.2 MB/s eta 0:00:00
Collecting attrs>=17.3.0
  Downloading attrs-25.3.0-py3-none-any.whl (63 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 63.8/63.8 KB 392.3 MB/s eta 0:00:00


Collecting typing-extensions>=3.7.4.3
  Downloading typing_extensions-4.15.0-py3-none-any.whl (44 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 44.6/44.6 KB 379.7 MB/s eta 0:00:00
Collecting hf-xet<2.0.0,>=1.1.3
  Downloading hf_xet-1.1.10-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.2 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 3.2/3.2 MB 525.6 MB/s eta 0:00:00
Collecting urllib3<3,>=1.21.1
  Downloading urllib3-2.5.0-py3-none-any.whl (129 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 129.8/129.8 KB 472.9 MB/s eta 0:00:00
Collecting certifi>=2017.4.17
  Downloading certifi-2025.8.3-py3-none-any.whl (161 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 161.2/161.2 KB 474.0 MB/s eta 0:00:00
Collecting idna<4,>=2.5
  Downloading idna-3.10-py3-none-any.whl (70 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 70.4/70.4 KB 433.6 MB/s eta 0:00:00


Collecting charset_normalizer<4,>=2
  Downloading charset_normalizer-3.4.3-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl (150 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 150.3/150.3 KB 464.4 MB/s eta 0:00:00
Collecting nvidia-cusparse-cu12==12.1.0.106
  Downloading nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl (196.0 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 196.0/196.0 MB 541.7 MB/s eta 0:00:00
Collecting nvidia-nvtx-cu12==12.1.105
  Downloading nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (99 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 99.1/99.1 KB 348.3 MB/s eta 0:00:00
Collecting nvidia-cuda-nvrtc-cu12==12.1.105
  Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 23.7/23.7 MB 288.9 MB/s eta 0:00:00
Collecting nvidia-nccl-cu12==2.20.5
  Downloading nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl (176.2 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 176.2/176.2 MB 353.0 MB/s eta 0:00:00
Collecting nvidia-cuda-runtime-cu12==12.1.105
  Downloading nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 823.6/823.6 KB 528.9 MB/s eta 0:00:00
Collecting nvidia-cusolver-cu12==11.4.5.107
  Downloading nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl (124.2 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 124.2/124.2 MB 547.0 MB/s eta 0:00:00
Collecting triton==3.0.0
  Downloading triton-3.0.0-1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (209.4 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 209.4/209.4 MB 332.4 MB/s eta 0:00:00
Collecting nvidia-cuda-cupti-cu12==12.1.105
  Downloading nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 14.1/14.1 MB 306.8 MB/s eta 0:00:00
Collecting nvidia-cublas-cu12==12.1.3.1
  Downloading nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 410.6/410.6 MB 537.2 MB/s eta 0:00:00


Collecting sympy
  Downloading sympy-1.14.0-py3-none-any.whl (6.3 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 6.3/6.3 MB 532.2 MB/s eta 0:00:00
Collecting nvidia-cudnn-cu12==9.1.0.70
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl (664.8 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 664.8/664.8 MB 402.5 MB/s eta 0:00:00


Collecting networkx
  Downloading networkx-3.5-py3-none-any.whl (2.0 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.0/2.0 MB 529.2 MB/s eta 0:00:00
Collecting nvidia-curand-cu12==10.3.2.106
  Downloading nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl (56.5 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 56.5/56.5 MB 535.0 MB/s eta 0:00:00
Collecting nvidia-cufft-cu12==11.0.2.54


  Downloading nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl (121.6 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 121.6/121.6 MB 554.4 MB/s eta 0:00:00
Collecting jinja2
  Downloading jinja2-3.1.6-py3-none-any.whl (134 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 134.9/134.9 KB 457.4 MB/s eta 0:00:00
Collecting nvidia-nvjitlink-cu12
  Downloading nvidia_nvjitlink_cu12-12.9.86-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl (39.7 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 39.7/39.7 MB 541.7 MB/s eta 0:00:00
Collecting multiprocess
  Downloading multiprocess-0.70.17-py311-none-any.whl (144 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 144.3/144.3 KB 458.0 MB/s eta 0:00:00
  Downloading multiprocess-0.70.16-py311-none-any.whl (143 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 143.5/143.5 KB 448.2 MB/s eta 0:00:00
Collecting pytz>=2020.1
  Downloading pytz-2025.2-py2.py3-none-any.whl (509 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 509.2/509.2 KB 498.8 MB/s eta 0:00:00


Collecting python-dateutil>=2.8.2
  Downloading python_dateutil-2.9.0.post0-py2.py3-none-any.whl (229 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 229.9/229.9 KB 513.0 MB/s eta 0:00:00
Collecting tzdata>=2022.7
  Downloading tzdata-2025.2-py2.py3-none-any.whl (347 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 347.8/347.8 KB 511.7 MB/s eta 0:00:00
Collecting six>=1.5
  Downloading six-1.17.0-py2.py3-none-any.whl (11 kB)


Collecting MarkupSafe>=2.0
  Downloading MarkupSafe-3.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (23 kB)
Collecting mpmath<1.4,>=1.1.0
  Downloading mpmath-1.3.0-py3-none-any.whl (536 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 536.2/536.2 KB 508.8 MB/s eta 0:00:00


Installing collected packages: pytz, mpmath, xxhash, urllib3, tzdata, typing-extensions, tqdm, threadpoolctl, sympy, six, sentencepiece, safetensors, regex, pyyaml, pyarrow, psutil, propcache, packaging, nvidia-nvtx-cu12, nvidia-nvjitlink-cu12, nvidia-nccl-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, numpy, networkx, multidict, MarkupSafe, joblib, idna, hf-xet, fsspec, frozenlist, filelock, dill, charset_normalizer, certifi, attrs, aiohappyeyeballs, yarl, triton, scipy, requests, python-dateutil, nvidia-cusparse-cu12, nvidia-cudnn-cu12, multiprocess, jinja2, aiosignal, scikit-learn, pandas, nvidia-cusolver-cu12, huggingface-hub, aiohttp, torch, tokenizers, transformers, datasets, accelerate, evaluate


Successfully installed MarkupSafe-3.0.2 accelerate-0.34.2 aiohappyeyeballs-2.6.1 aiohttp-3.12.15 aiosignal-1.4.0 attrs-25.3.0 certifi-2025.8.3 charset_normalizer-3.4.3 datasets-2.21.0 dill-0.3.8 evaluate-0.4.2 filelock-3.19.1 frozenlist-1.7.0 fsspec-2024.6.1 hf-xet-1.1.10 huggingface-hub-0.35.1 idna-3.10 jinja2-3.1.6 joblib-1.5.2 mpmath-1.3.0 multidict-6.6.4 multiprocess-0.70.16 networkx-3.5 numpy-1.26.4 nvidia-cublas-cu12-12.1.3.1 nvidia-cuda-cupti-cu12-12.1.105 nvidia-cuda-nvrtc-cu12-12.1.105 nvidia-cuda-runtime-cu12-12.1.105 nvidia-cudnn-cu12-9.1.0.70 nvidia-cufft-cu12-11.0.2.54 nvidia-curand-cu12-10.3.2.106 nvidia-cusolver-cu12-11.4.5.107 nvidia-cusparse-cu12-12.1.0.106 nvidia-nccl-cu12-2.20.5 nvidia-nvjitlink-cu12-12.9.86 nvidia-nvtx-cu12-12.1.105 packaging-25.0 pandas-2.3.2 propcache-0.3.2 psutil-7.1.0 pyarrow-21.0.0 python-dateutil-2.9.0.post0 pytz-2025.2 pyyaml-6.0.3 regex-2025.9.18 requests-2.32.5 safetensors-0.6.2 scikit-learn-1.7.2 scipy-1.16.2 sentencepiece-0.2.1 six-1.17.0





torch: 2.4.1+cu121 built CUDA: 12.1
CUDA available: True
GPU: NVIDIA A10-24Q
> install rank_bm25


Collecting rank_bm25
  Downloading rank_bm25-0.2.2-py3-none-any.whl (8.6 kB)


Collecting numpy
  Downloading numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.3 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 18.3/18.3 MB 179.1 MB/s eta 0:00:00


Installing collected packages: numpy, rank_bm25


Successfully installed numpy-1.26.4 rank_bm25-0.2.2
> install langdetect




Collecting langdetect
  Downloading langdetect-1.0.9.tar.gz (981 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 981.5/981.5 KB 32.6 MB/s eta 0:00:00
  Preparing metadata (setup.py): started


  Preparing metadata (setup.py): finished with status 'done'
Collecting six
  Downloading six-1.17.0-py2.py3-none-any.whl (11 kB)
Building wheels for collected packages: langdetect
  Building wheel for langdetect (setup.py): started


  Building wheel for langdetect (setup.py): finished with status 'done'
  Created wheel for langdetect: filename=langdetect-1.0.9-py3-none-any.whl size=993242 sha256=66bf6e003ef4e33722b342e3e89edfd603a7151f2ee31f4e3c7d6c0e05e4d3da
  Stored in directory: /tmp/pip-ephem-wheel-cache-hjfn6opk/wheels/0a/f2/b2/e5ca405801e05eb7c8ed5b3b4bcf1fcabcd6272c167640072e
Successfully built langdetect


Installing collected packages: six, langdetect
Successfully installed langdetect-1.0.9 six-1.17.0
> install indic-nlp-library pyarrow




Collecting indic-nlp-library
  Downloading indic_nlp_library-0.92-py3-none-any.whl (40 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 40.3/40.3 KB 3.2 MB/s eta 0:00:00
Collecting pyarrow
  Downloading pyarrow-21.0.0-cp311-cp311-manylinux_2_28_x86_64.whl (42.8 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 42.8/42.8 MB 300.4 MB/s eta 0:00:00
Collecting sphinx-argparse
  Downloading sphinx_argparse-0.5.2-py3-none-any.whl (12 kB)


Collecting numpy
  Downloading numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.3 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 18.3/18.3 MB 526.0 MB/s eta 0:00:00
Collecting morfessor
  Downloading Morfessor-2.0.6-py3-none-any.whl (35 kB)
Collecting sphinx-rtd-theme
  Downloading sphinx_rtd_theme-3.0.2-py2.py3-none-any.whl (7.7 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 7.7/7.7 MB 225.1 MB/s eta 0:00:00


Collecting pandas
  Downloading pandas-2.3.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (12.4 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 12.4/12.4 MB 541.1 MB/s eta 0:00:00
Collecting pytz>=2020.1
  Downloading pytz-2025.2-py2.py3-none-any.whl (509 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 509.2/509.2 KB 489.2 MB/s eta 0:00:00
Collecting python-dateutil>=2.8.2
  Downloading python_dateutil-2.9.0.post0-py2.py3-none-any.whl (229 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 229.9/229.9 KB 480.6 MB/s eta 0:00:00


Collecting tzdata>=2022.7
  Downloading tzdata-2025.2-py2.py3-none-any.whl (347 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 347.8/347.8 KB 504.5 MB/s eta 0:00:00
Collecting docutils>=0.19
  Downloading docutils-0.22.2-py3-none-any.whl (632 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 632.7/632.7 KB 514.5 MB/s eta 0:00:00
Collecting sphinx>=5.1.0
  Downloading sphinx-8.2.3-py3-none-any.whl (3.6 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 3.6/3.6 MB 376.3 MB/s eta 0:00:00
Collecting docutils>=0.19
  Downloading docutils-0.21.2-py3-none-any.whl (587 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 587.4/587.4 KB 516.9 MB/s eta 0:00:00


Collecting sphinxcontrib-jquery<5,>=4
  Downloading sphinxcontrib_jquery-4.1-py2.py3-none-any.whl (121 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 121.1/121.1 KB 453.5 MB/s eta 0:00:00
Collecting six>=1.5
  Downloading six-1.17.0-py2.py3-none-any.whl (11 kB)
Collecting sphinxcontrib-htmlhelp>=2.0.6
  Downloading sphinxcontrib_htmlhelp-2.1.0-py3-none-any.whl (98 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 98.7/98.7 KB 432.2 MB/s eta 0:00:00
Collecting packaging>=23.0
  Downloading packaging-25.0-py3-none-any.whl (66 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 66.5/66.5 KB 346.2 MB/s eta 0:00:00
Collecting requests>=2.30.0
  Downloading requests-2.32.5-py3-none-any.whl (64 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 64.7/64.7 KB 318.8 MB/s eta 0:00:00


Collecting Jinja2>=3.1
  Downloading jinja2-3.1.6-py3-none-any.whl (134 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 134.9/134.9 KB 439.3 MB/s eta 0:00:00
Collecting babel>=2.13
  Downloading babel-2.17.0-py3-none-any.whl (10.2 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 10.2/10.2 MB 527.8 MB/s eta 0:00:00
Collecting alabaster>=0.7.14
  Downloading alabaster-1.0.0-py3-none-any.whl (13 kB)
Collecting Pygments>=2.17
  Downloading pygments-2.19.2-py3-none-any.whl (1.2 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.2/1.2 MB 511.7 MB/s eta 0:00:00
Collecting sphinxcontrib-qthelp>=1.0.6
  Downloading sphinxcontrib_qthelp-2.0.0-py3-none-any.whl (88 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 88.7/88.7 KB 341.0 MB/s eta 0:00:00
Collecting sphinxcontrib-serializinghtml>=1.1.9
  Downloading sphinxcontrib_serializinghtml-2.0.0-py3-none-any.whl (92 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 92.1/92.1 KB 424.9 MB/s eta 0:00:00
Collecting roman-numerals-py>=1.0.0
  Downloading

  Downloading sphinxcontrib_applehelp-2.0.0-py3-none-any.whl (119 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 119.3/119.3 KB 433.3 MB/s eta 0:00:00
Collecting snowballstemmer>=2.2
  Downloading snowballstemmer-3.0.1-py3-none-any.whl (103 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 103.3/103.3 KB 448.8 MB/s eta 0:00:00
Collecting MarkupSafe>=2.0
  Downloading MarkupSafe-3.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (23 kB)
Collecting urllib3<3,>=1.21.1
  Downloading urllib3-2.5.0-py3-none-any.whl (129 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 129.8/129.8 KB 462.7 MB/s eta 0:00:00


Collecting certifi>=2017.4.17
  Downloading certifi-2025.8.3-py3-none-any.whl (161 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 161.2/161.2 KB 457.8 MB/s eta 0:00:00
Collecting idna<4,>=2.5
  Downloading idna-3.10-py3-none-any.whl (70 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 70.4/70.4 KB 431.3 MB/s eta 0:00:00
Collecting charset_normalizer<4,>=2
  Downloading charset_normalizer-3.4.3-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl (150 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 150.3/150.3 KB 449.4 MB/s eta 0:00:00


Installing collected packages: pytz, morfessor, urllib3, tzdata, sphinxcontrib-serializinghtml, sphinxcontrib-qthelp, sphinxcontrib-jsmath, sphinxcontrib-htmlhelp, sphinxcontrib-devhelp, sphinxcontrib-applehelp, snowballstemmer, six, roman-numerals-py, Pygments, pyarrow, packaging, numpy, MarkupSafe, imagesize, idna, docutils, charset_normalizer, certifi, babel, alabaster, requests, python-dateutil, Jinja2, sphinx, pandas, sphinxcontrib-jquery, sphinx-argparse, sphinx-rtd-theme, indic-nlp-library


Successfully installed Jinja2-3.1.6 MarkupSafe-3.0.2 Pygments-2.19.2 alabaster-1.0.0 babel-2.17.0 certifi-2025.8.3 charset_normalizer-3.4.3 docutils-0.21.2 idna-3.10 imagesize-1.4.1 indic-nlp-library-0.92 morfessor-2.0.6 numpy-1.26.4 packaging-25.0 pandas-2.3.2 pyarrow-21.0.0 python-dateutil-2.9.0.post0 pytz-2025.2 requests-2.32.5 roman-numerals-py-3.1.0 six-1.17.0 snowballstemmer-3.0.1 sphinx-8.2.3 sphinx-argparse-0.5.2 sphinx-rtd-theme-3.0.2 sphinxcontrib-applehelp-2.0.0 sphinxcontrib-devhelp-2.0.0 sphinxcontrib-htmlhelp-2.1.0 sphinxcontrib-jquery-4.1 sphinxcontrib-jsmath-1.0.1 sphinxcontrib-qthelp-2.0.0 sphinxcontrib-serializinghtml-2.0.0 tzdata-2025.2 urllib3-2.5.0






> install -c constraints.txt fsspec[http]<=2024.6.1,>=2023.1.0 --upgrade


Collecting fsspec[http]<=2024.6.1,>=2023.1.0
  Downloading fsspec-2024.6.1-py3-none-any.whl (177 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 177.6/177.6 KB 7.7 MB/s eta 0:00:00


Collecting aiohttp!=4.0.0a0,!=4.0.0a1
  Downloading aiohttp-3.12.15-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.7 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.7/1.7 MB 154.1 MB/s eta 0:00:00


Collecting yarl<2.0,>=1.17.0
  Downloading yarl-1.20.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (348 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 349.0/349.0 KB 414.8 MB/s eta 0:00:00
Collecting aiohappyeyeballs>=2.5.0
  Downloading aiohappyeyeballs-2.6.1-py3-none-any.whl (15 kB)
Collecting frozenlist>=1.1.1
  Downloading frozenlist-1.7.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (235 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 235.3/235.3 KB 454.8 MB/s eta 0:00:00


Collecting multidict<7.0,>=4.5
  Downloading multidict-6.6.4-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl (246 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 246.7/246.7 KB 464.5 MB/s eta 0:00:00
Collecting aiosignal>=1.4.0
  Downloading aiosignal-1.4.0-py3-none-any.whl (7.5 kB)
Collecting propcache>=0.2.0
  Downloading propcache-0.3.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (213 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 213.5/213.5 KB 484.2 MB/s eta 0:00:00
Collecting attrs>=17.3.0
  Downloading attrs-25.3.0-py3-none-any.whl (63 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 63.8/63.8 KB 365.5 MB/s eta 0:00:00
Collecting typing-extensions>=4.2
  Downloading typing_extensions-4.15.0-py3-none-any.whl (44 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 44.6/44.6 KB 395.7 MB/s eta 0:00:00


Collecting idna>=2.0
  Downloading idna-3.10-py3-none-any.whl (70 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 70.4/70.4 KB 391.0 MB/s eta 0:00:00


Installing collected packages: typing-extensions, propcache, multidict, idna, fsspec, frozenlist, attrs, aiohappyeyeballs, yarl, aiosignal, aiohttp


Successfully installed aiohappyeyeballs-2.6.1 aiohttp-3.12.15 aiosignal-1.4.0 attrs-25.3.0 frozenlist-1.7.0 fsspec-2024.6.1 idna-3.10 multidict-6.6.4 propcache-0.3.2 typing-extensions-4.15.0 yarl-1.20.1
BM25 available
langdetect available
Environment setup complete


In [20]:
import os
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
import gc
import ast
import sys
import copy
import json
import math
import random
import time
from datetime import datetime

import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, TensorDataset
from torch.cuda.amp import autocast, GradScaler

from transformers import (
    AutoTokenizer,
    AutoModel,
    AutoConfig,
    get_linear_schedule_with_warmup,
    TrainingArguments,
    Trainer,
    AutoModelForQuestionAnswering,
    )
from transformers import default_data_collator

from datasets import load_dataset
from sklearn.model_selection import StratifiedGroupKFold
from sklearn.metrics import f1_score
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import hashlib

import subprocess
import shutil
import unicodedata

# Add pip_target to sys.path if not already
pip_target = '/app/.pip-target'
if pip_target not in sys.path:
    sys.path.insert(0, pip_target)

# BM25 and langdetect
BM25_AVAILABLE = False
try:
    from rank_bm25 import BM25Okapi
    BM25_AVAILABLE = True
    print('BM25 available')
except ImportError:
    print('BM25 not available, falling back to TF-IDF only')

LANGDETECT_AVAILABLE = False
try:
    from langdetect import detect
    LANGDETECT_AVAILABLE = True
    print('langdetect available')
except ImportError:
    print('langdetect not available, using script fallback')

# Script-based language detection fallback
def detect_lang(text):
    if not isinstance(text, str):
        return 'hindi'
    for c in text:
        if 0x0B80 <= ord(c) <= 0x0BFF:  # Tamil Unicode range
            return 'tamil'
    return 'hindi'

# Set seeds
def set_seed(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

set_seed(42)

# Constants with coach tweaks
DEBUG = False  # Set to True for rapid prototyping
MAX_LEN = 512
DOC_STRIDE = 128
N_SPLITS = 5
BATCH_SIZE = 2
GRAD_ACCUM_STEPS = 16
EPOCHS = 5
LR = 2.5e-5
WEIGHT_DECAY = 0.01
NEG_WEIGHT = 0.2
USE_RETRIEVAL = True
TOP_K_CHUNKS_TRAIN = 8
TOP_K_CHUNKS_EVAL_HINDI = 10
TOP_K_CHUNKS_EVAL_TAMIL = 35  # Coach tweak for better Tamil recall
CHUNK_SIZE = 1800
OVERLAP = 250
NEG_POS_RATIO = 2
MODEL_NAME = 'deepset/xlm-roberta-large-squad2'
PUNCT = '\u0964,.\uff0c!\uff01?\uff1f"\\\'\u201c\u201d\u2018\u2019()[]{}:;'
MAX_ANSWER_LENGTH = 80  # Coach tweak for longer spans

# Load data
train_df = pd.read_csv('train.csv')
test_df = pd.read_csv('test.csv')

if DEBUG:
    train_df = train_df.sample(n=200, random_state=42).reset_index(drop=True)
    print(f'DEBUG mode: using {len(train_df)} samples')
else:
    print(f'Full mode: using {len(train_df)} samples')

print('Train shape:', train_df.shape)
print('Test shape:', test_df.shape)

# Label alignment fix with progress tracking
print('Before fix_span')
def fix_span(row):
    ctx, ans, s = row['context'], row['answer_text'], row['answer_start']
    if s < 0 or ctx[s:s+len(ans)] != ans:
        idx = ctx.find(ans)
        if idx != -1:
            row['answer_start'] = idx
    return row

train_df = train_df.apply(fix_span, axis=1)
print('After fix_span')

# Context groups for CV (hash first 1024 chars to group same articles)
def get_context_hash(context):
    return hashlib.md5(context[:1024].encode()).hexdigest()

train_df['context_hash'] = train_df['context'].apply(get_context_hash)
print('Context hashes computed')

# Jaccard metric with NFKC normalization
def jaccard_word(pred, true):
    pred = unicodedata.normalize('NFKC', pred).lower()
    true = unicodedata.normalize('NFKC', true).lower()
    if not pred or not true:
        return 0.0
    pw, tw = set(pred.split()), set(true.split())
    return len(pw & tw) / len(pw | tw) if pw and tw else 0.0

def compute_jaccard(preds, trues):
    return np.mean([jaccard_word(p, t) for p, t in zip(preds, trues)])

# Assign language to test_df using langdetect or fallback
print('Assigning language to test_df...')
if LANGDETECT_AVAILABLE:
    test_df['language'] = test_df['question'].apply(lambda x: {'ta':'tamil','hi':'hindi'}.get(detect(x), 'hindi') if isinstance(x, str) else 'hindi')
else:
    test_df['language'] = test_df['question'].apply(detect_lang)
print('Test language dist:', test_df['language'].value_counts())

# CV splitting with StratifiedGroupKFold
sgkf = StratifiedGroupKFold(n_splits=N_SPLITS, shuffle=True, random_state=42)
train_df['fold'] = -1
for fold, (trn_idx, val_idx) in enumerate(sgkf.split(train_df, train_df['language'], groups=train_df['context_hash'])):
    train_df.loc[val_idx, 'fold'] = fold

print('Fold distribution:')
print(train_df.groupby(['fold', 'language']).size())
print(f'Folds created: {train_df["fold"].nunique()}')

N_FOLDS = 3 if DEBUG else N_SPLITS
print(f'Using {N_FOLDS} folds for training')

BM25 available
langdetect available


Full mode: using 1002 samples
Train shape: (1002, 6)
Test shape: (112, 4)
Before fix_span
After fix_span
Context hashes computed
Assigning language to test_df...
Test language dist: language
hindi    84
tamil    28
Name: count, dtype: int64


Fold distribution:
fold  language
0     hindi       133
      tamil        60
1     hindi       133
      tamil        71
2     hindi       126
      tamil        68
3     hindi       142
      tamil        70
4     hindi       128
      tamil        71
dtype: int64
Folds created: 5
Using 5 folds for training


In [18]:
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
print('Tokenizer loaded:', tokenizer.name_or_path)

# TF-IDF Retrieval setup with language-specific vectorizers
if USE_RETRIEVAL:
    print('Fitting language-specific TF-IDF vectorizers...')
    hindi_df = train_df[train_df['language'] == 'hindi']
    tamil_df = train_df[train_df['language'] == 'tamil']
    
    # Hindi vectorizer
    print('Processing Hindi...')
    hindi_questions = hindi_df['question'].tolist()
    hindi_contexts = hindi_df['context'].tolist()
    hindi_chunks = []
    for ctx in tqdm(hindi_contexts, desc='Chunking Hindi contexts'):
        chunks = []
        for i in range(0, len(ctx), CHUNK_SIZE - OVERLAP):
            chunk = ctx[i:i + CHUNK_SIZE]
            if len(chunk) > 100:
                chunks.append(chunk)
        hindi_chunks.extend(chunks)
    print(f'Hindi chunks total: {len(hindi_chunks)}')
    hindi_corpus = hindi_questions + random.sample(hindi_chunks, min(3000, len(hindi_chunks)))
    print(f'Hindi corpus size: {len(hindi_corpus)}')
    hindi_vectorizer = TfidfVectorizer(
        analyzer='char_wb',
        ngram_range=(2, 4),
        max_features=5000,
        min_df=2,
        max_df=0.95,
        lowercase=False,
        sublinear_tf=True,
        dtype=np.float32
    )
    print('Fitting Hindi vectorizer...')
    start_time = time.time()
    hindi_vectorizer.fit(hindi_corpus)
    fit_time = time.time() - start_time
    print(f'Hindi TF-IDF fitted in {fit_time:.2f}s: {len(hindi_corpus)} docs')
    
    # Tamil vectorizer - fixed to char n-grams for better recall
    print('Processing Tamil...')
    tamil_questions = tamil_df['question'].tolist()
    tamil_contexts = tamil_df['context'].tolist()
    tamil_chunks = []
    for ctx in tqdm(tamil_contexts, desc='Chunking Tamil contexts'):
        chunks = []
        for i in range(0, len(ctx), CHUNK_SIZE - OVERLAP):
            chunk = ctx[i:i + CHUNK_SIZE]
            if len(chunk) > 100:
                chunks.append(chunk)
        tamil_chunks.extend(chunks)
    print(f'Tamil chunks total: {len(tamil_chunks)}')
    tamil_corpus = tamil_questions + random.sample(tamil_chunks, min(1500, len(tamil_chunks)))
    print(f'Tamil corpus size: {len(tamil_corpus)}')
    tamil_vectorizer = TfidfVectorizer(
        analyzer='char_wb',
        ngram_range=(3, 5),
        max_features=15000,
        min_df=3,
        max_df=0.9,
        lowercase=False,
        sublinear_tf=True,
        dtype=np.float32
    )
    print('Fitting Tamil vectorizer...')
    start_time = time.time()
    tamil_vectorizer.fit(tamil_corpus)
    fit_time = time.time() - start_time
    print(f'Tamil TF-IDF fitted in {fit_time:.2f}s: {len(tamil_corpus)} docs')
else:
    hindi_vectorizer = tamil_vectorizer = None



Tokenizer loaded: deepset/xlm-roberta-large-squad2
Fitting language-specific TF-IDF vectorizers...
Processing Hindi...


Chunking Hindi contexts:   0%|          | 0/662 [00:00<?, ?it/s]

Chunking Hindi contexts: 100%|██████████| 662/662 [00:00<00:00, 51859.87it/s]

Hindi chunks total: 4586
Hindi corpus size: 3662
Fitting Hindi vectorizer...





Hindi TF-IDF fitted in 3.80s: 3662 docs
Processing Tamil...


Tamil TF-IDF fitted in 2.21s: 1840 docs


In [12]:
# Prepare training features with hybrid retrieval and sliding windows
def prepare_train_features(examples, neg_pos_ratio=NEG_POS_RATIO):
    features = []
    for ex in examples:
        q, ctx, ans, ex_id, lang = ex['question'].strip(), ex['context'].strip(), {'text': ex['answer_text'], 'answer_start': ex['answer_start']}, ex['id'], ex['language']
        
        if USE_RETRIEVAL:
            # Chunk context
            chunks = []
            chunk_starts = []
            for i in range(0, len(ctx), CHUNK_SIZE - OVERLAP):
                chunk = ctx[i:i + CHUNK_SIZE]
                if len(chunk) > 100:
                    chunks.append(chunk)
                    chunk_starts.append(i)
            
            if not chunks:
                continue
            
            # Select vectorizer by language
            if lang == 'hindi':
                vectorizer = hindi_vectorizer
            else:
                vectorizer = tamil_vectorizer
            
            # TF-IDF retrieval
            q_vec = vectorizer.transform([q])
            chunk_vecs = vectorizer.transform(chunks)
            similarities = cosine_similarity(q_vec, chunk_vecs).flatten()
            
            # BM25 hybrid if available
            if BM25_AVAILABLE:
                tokenized_chunks = [chunk.lower().split() for chunk in chunks]
                bm25 = BM25Okapi(tokenized_chunks)
                q_tokens = q.lower().split()
                bm25_scores = bm25.get_scores(q_tokens)
                if np.max(bm25_scores) > 0:
                    norm_bm25 = bm25_scores / np.max(bm25_scores)
                else:
                    norm_bm25 = np.zeros_like(bm25_scores)
                hybrid_scores = 0.5 * norm_bm25 + 0.5 * similarities
            else:
                hybrid_scores = similarities
            top_indices = np.argsort(hybrid_scores)[-TOP_K_CHUNKS_TRAIN:]
            
            # Guarantee gold chunk inclusion for training by replacing lowest sim if needed
            start_char = ans['answer_start']
            end_char = start_char + len(ans['text'])
            pos_idx = None
            for ci, st in enumerate(chunk_starts):
                if start_char >= st and end_char <= st + len(chunks[ci]):
                    pos_idx = ci
                    break
            if pos_idx is not None and pos_idx not in top_indices:
                # Replace the lowest hybrid score in top_indices with pos_idx
                min_hybrid_arg = np.argmin(hybrid_scores[top_indices])
                top_indices[min_hybrid_arg] = pos_idx
            # Sort by hybrid descending
            sort_args = np.argsort(hybrid_scores[top_indices])[::-1]
            top_indices = top_indices[sort_args]
            
            # Get top chunks with their global start positions
            top_chunks = [(hybrid_scores[idx], chunk_starts[idx], chunks[idx]) for idx in top_indices]
        else:
            top_chunks = [(1.0, 0, ctx)]  # full context if no retrieval
        
        # Now process each top chunk with sliding windows
        pos_feats, neg_feats = [], []
        for sim, chunk_start, chunk in top_chunks:
            tokenized = tokenizer(
                q,
                chunk,
                truncation='only_second',
                max_length=MAX_LEN,
                stride=DOC_STRIDE,
                return_overflowing_tokens=True,
                return_offsets_mapping=True,
                padding=False,
            )
            
            for j in range(len(tokenized['input_ids'])):
                input_ids = tokenized['input_ids'][j]
                attention_mask = tokenized['attention_mask'][j]
                offsets = tokenized['offset_mapping'][j]
                sequence_ids = tokenized.sequence_ids(j)
                
                # Skip windows without context tokens
                if 1 not in sequence_ids:
                    continue
                
                # Global offsets: add chunk_start to context offsets
                global_offsets = []
                ctx_start = 0
                while ctx_start < len(sequence_ids) and sequence_ids[ctx_start] != 1:
                    global_offsets.append(None)
                    ctx_start += 1
                while ctx_start < len(sequence_ids) and sequence_ids[ctx_start] == 1:
                    local_offset = offsets[ctx_start]
                    global_offset = (local_offset[0] + chunk_start, local_offset[1] + chunk_start) if local_offset else None
                    global_offsets.append(global_offset)
                    ctx_start += 1
                while ctx_start < len(sequence_ids):
                    global_offsets.append(None)
                    ctx_start += 1
                
                # Find start/end positions using global offsets
                start_pos = -1
                end_pos = -1
                is_positive = False
                start_char = ans['answer_start']
                end_char = start_char + len(ans['text'])
                
                for tok_idx, off in enumerate(global_offsets):
                    if off is not None and off[0] <= start_char < off[1]:
                        start_pos = tok_idx
                    if off is not None and off[0] < end_char <= off[1]:
                        end_pos = tok_idx
                if start_pos != -1 and end_pos != -1 and end_pos >= start_pos:
                    is_positive = True
                else:
                    start_pos = 0
                    end_pos = 0
                
                # Pad/truncate
                pad_len = MAX_LEN - len(input_ids)
                if pad_len > 0:
                    input_ids += [tokenizer.pad_token_id] * pad_len
                    attention_mask += [0] * pad_len
                else:
                    input_ids = input_ids[:MAX_LEN]
                    attention_mask = attention_mask[:MAX_LEN]
                
                feat = {
                    'input_ids': input_ids,
                    'attention_mask': attention_mask,
                    'start_positions': start_pos,
                    'end_positions': end_pos,
                    'example_id': ex_id,
                    'is_positive': is_positive
                }
                (pos_feats if is_positive else neg_feats).append(feat)
        
        # Cap negatives
        if pos_feats:
            features.extend(pos_feats)
            random.shuffle(neg_feats)
            n_neg = min(len(neg_feats), neg_pos_ratio * len(pos_feats))
            features.extend(neg_feats[:n_neg])
        elif neg_feats:
            features.append(random.choice(neg_feats))
    return features

# Prepare validation features (lang-specific TOP_K_EVAL)
def prepare_validation_features(examples):
    features = []
    for ex in examples:
        q, ctx, ex_id, lang = ex['question'].strip(), ex['context'].strip(), ex['id'], ex['language']
        
        if USE_RETRIEVAL:
            # Same chunking and retrieval as train, but use lang-specific TOP_K_EVAL
            chunks = []
            chunk_starts = []
            for i in range(0, len(ctx), CHUNK_SIZE - OVERLAP):
                chunk = ctx[i:i + CHUNK_SIZE]
                if len(chunk) > 100:
                    chunks.append(chunk)
                    chunk_starts.append(i)
            
            if not chunks:
                continue
            
            # Select vectorizer by language
            if lang == 'hindi':
                vectorizer = hindi_vectorizer
                top_k_eval = TOP_K_CHUNKS_EVAL_HINDI
            else:
                vectorizer = tamil_vectorizer
                top_k_eval = TOP_K_CHUNKS_EVAL_TAMIL
            
            # TF-IDF
            q_vec = vectorizer.transform([q])
            chunk_vecs = vectorizer.transform(chunks)
            similarities = cosine_similarity(q_vec, chunk_vecs).flatten()
            
            # BM25 hybrid if available
            if BM25_AVAILABLE:
                tokenized_chunks = [chunk.lower().split() for chunk in chunks]
                bm25 = BM25Okapi(tokenized_chunks)
                q_tokens = q.lower().split()
                bm25_scores = bm25.get_scores(q_tokens)
                if np.max(bm25_scores) > 0:
                    norm_bm25 = bm25_scores / np.max(bm25_scores)
                else:
                    norm_bm25 = np.zeros_like(bm25_scores)
                hybrid_scores = 0.5 * norm_bm25 + 0.5 * similarities
            else:
                hybrid_scores = similarities
            top_indices = np.argsort(hybrid_scores)[-top_k_eval:]
            top_chunks = [(hybrid_scores[idx], chunk_starts[idx], chunks[idx]) for idx in top_indices]
        else:
            top_chunks = [(1.0, 0, ctx)]
        
        # Process each top chunk
        for sim, chunk_start, chunk in top_chunks:
            tokenized = tokenizer(
                q,
                chunk,
                truncation='only_second',
                max_length=MAX_LEN,
                stride=DOC_STRIDE,
                return_overflowing_tokens=True,
                return_offsets_mapping=True,
                padding=False,
            )
            
            for j in range(len(tokenized['input_ids'])):
                input_ids = tokenized['input_ids'][j]
                attention_mask = tokenized['attention_mask'][j]
                offsets = tokenized['offset_mapping'][j]
                sequence_ids = tokenized.sequence_ids(j)
                
                # Skip windows without context tokens
                if 1 not in sequence_ids:
                    continue
                
                # Global offsets for post-processing
                global_offsets = []
                ctx_start = 0
                while ctx_start < len(sequence_ids) and sequence_ids[ctx_start] != 1:
                    global_offsets.append(None)
                    ctx_start += 1
                while ctx_start < len(sequence_ids) and sequence_ids[ctx_start] == 1:
                    local_offset = offsets[ctx_start]
                    global_offset = (local_offset[0] + chunk_start, local_offset[1] + chunk_start) if local_offset else None
                    global_offsets.append(global_offset)
                    ctx_start += 1
                while ctx_start < len(sequence_ids):
                    global_offsets.append(None)
                    ctx_start += 1
                
                # Pad/truncate
                pad_len = MAX_LEN - len(input_ids)
                if pad_len > 0:
                    input_ids += [tokenizer.pad_token_id] * pad_len
                    attention_mask += [0] * pad_len
                    global_offsets += [None] * pad_len
                else:
                    input_ids = input_ids[:MAX_LEN]
                    attention_mask = attention_mask[:MAX_LEN]
                    global_offsets = global_offsets[:MAX_LEN]
                
                features.append({
                    'input_ids': input_ids,
                    'attention_mask': attention_mask,
                    'offset_mapping': global_offsets,
                    'example_id': ex_id,
                })
    return features

# Test on small batch
test_examples = train_df.head(1).to_dict('records')
print('Testing on example:', test_examples[0]['id'], 'Language:', test_examples[0]['language'])
print('Gold answer:', test_examples[0]['answer_text'], 'at', test_examples[0]['answer_start'])
train_features = prepare_train_features(test_examples)
val_features = prepare_validation_features(test_examples)
print(f'Train features: {len(train_features)}')
print(f'Val features: {len(val_features)}')
if train_features:
    print('Sample train feature keys:', list(train_features[0].keys()))
    print('Sample input_ids len:', len(train_features[0]['input_ids']))
    print('Sample is_positive:', train_features[0]['is_positive'])
if val_features:
    print('Sample val offset_mapping len:', len(val_features[0]['offset_mapping']))

Testing on example: 6bb0c472d Language: tamil
Gold answer: சிம்மம் at 168
Train features: 4
Val features: 5
Sample train feature keys: ['input_ids', 'attention_mask', 'start_positions', 'end_positions', 'example_id', 'is_positive']
Sample input_ids len: 512
Sample is_positive: True
Sample val offset_mapping len: 512


In [15]:
import torch.nn.functional as F

# Post-processing to aggregate predictions across sliding windows with improved scoring
def get_predictions(features, start_logits, end_logits, n_best_size=50, max_answer_length=80):
    example_to_features = {}
    for i, f in enumerate(features):
        example_to_features.setdefault(f['example_id'], []).append((i, f))

    pred_dict = {}
    for example_id, feat_list in example_to_features.items():
        prelim_predictions = []
        for feat_idx, f in feat_list:
            offsets = f['offset_mapping']
            sl = start_logits[feat_idx]
            el = end_logits[feat_idx]

            # Context indices (non-None offsets)
            ctx_idx = [i for i, o in enumerate(offsets) if o is not None]
            if not ctx_idx:
                continue

            # Log-softmax on context logits only
            start_log = log_softmax_np(sl[ctx_idx])
            end_log = log_softmax_np(el[ctx_idx])

            # Top n_best_size start/end positions in context
            top_start_idx = np.argsort(sl[ctx_idx])[-n_best_size:].tolist()[::-1]
            top_end_idx = np.argsort(el[ctx_idx])[-n_best_size:].tolist()[::-1]

            # Global indices
            top_start = [ctx_idx[i] for i in top_start_idx]
            top_end = [ctx_idx[i] for i in top_end_idx]

            # Generate candidates
            for s in top_start:
                for e in top_end:
                    if e < s:
                        continue
                    length = e - s + 1
                    if length > max_answer_length:
                        continue
                    sc, ec = offsets[s][0], offsets[e][1]
                    # Score with softened length penalty
                    score = start_log[top_start_idx[top_start.index(s)]] + end_log[top_end_idx[top_end.index(e)]] - 0.001 * max(0, length - 25)
                    prelim_predictions.append((score, sc, ec))

        if prelim_predictions:
            _, sc, ec = max(prelim_predictions, key=lambda x: x[0])
            pred_dict[example_id] = (sc, ec)
        else:
            # Fallback: best single-token span in context across all features
            best_score = -np.inf
            best_sc, best_ec = 0, 0
            for feat_idx, f in feat_list:
                offsets = f['offset_mapping']
                sl = start_logits[feat_idx]
                ctx_idx = [i for i, o in enumerate(offsets) if o is not None]
                if not ctx_idx:
                    continue
                s_log = log_softmax_np(sl[ctx_idx])
                best_s_local = np.argmax(sl[ctx_idx])
                s_global = ctx_idx[best_s_local]
                sc, ec = offsets[s_global][0], offsets[s_global][1]
                score = s_log[best_s_local]
                if score > best_score:
                    best_score = score
                    best_sc, best_ec = sc, ec
            pred_dict[example_id] = (best_sc, best_ec)
    return pred_dict

# Function to extract answer from context with NFKC and punctuation trim
def extract_answer(context, start_char, end_char):
    if start_char == 0 and end_char == 0:
        return ''
    s = context[start_char:end_char]
    s = unicodedata.normalize('NFKC', s).strip().strip(PUNCT)
    return s

# Dataset class - updated to include is_positive for training
class QADataset(Dataset):
    def __init__(self, features):
        self.input_ids = [f['input_ids'] for f in features]
        self.attention_mask = [f['attention_mask'] for f in features]
        if 'start_positions' in features[0]:
            self.start_positions = [f['start_positions'] for f in features]
            self.end_positions = [f['end_positions'] for f in features]
            self.is_positive = [f['is_positive'] for f in features]
        else:
            self.start_positions = None
            self.end_positions = None
            self.is_positive = None
        self.offset_mapping = [f.get('offset_mapping') for f in features]
        self.example_id = [f['example_id'] for f in features]

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        item = {
            'input_ids': self.input_ids[idx],
            'attention_mask': self.attention_mask[idx]
        }
        assert len(item['input_ids']) == MAX_LEN, 'Input ids not padded correctly'
        assert len(item['attention_mask']) == MAX_LEN, 'Attention mask not padded correctly'
        if self.start_positions is not None:
            item['start_positions'] = self.start_positions[idx]
            item['end_positions'] = self.end_positions[idx]
            item['is_positive'] = self.is_positive[idx]
        return item

# Custom Weighted Trainer to down-weight negative examples (fixed per-example weighting)
class WeightedQATrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        start_positions = inputs.pop('start_positions')
        end_positions = inputs.pop('end_positions')
        is_positive = inputs.pop('is_positive', None)  # tensor [bs] or None

        outputs = model(**inputs)
        start_logits = outputs.start_logits
        end_logits = outputs.end_logits

        start_loss = F.cross_entropy(start_logits, start_positions, reduction='none')
        end_loss = F.cross_entropy(end_logits, end_positions, reduction='none')
        loss = (start_loss + end_loss) / 2.0

        if is_positive is not None:
            ispos = is_positive.bool()
            weights = torch.where(ispos, torch.ones_like(loss), torch.full_like(loss, NEG_WEIGHT))
            loss = (loss * weights).mean()
        else:
            loss = loss.mean()

        return (loss, outputs) if return_outputs else loss

# Numpy log_softmax for numpy arrays
def log_softmax_np(x):
    x = x - np.max(x, axis=-1, keepdims=True)
    return x - np.log(np.sum(np.exp(x), axis=-1, keepdims=True))

# Test dataset creation
val_features_test = prepare_validation_features(train_df.head(1).to_dict('records'))
val_dataset_test = QADataset(val_features_test)
print(f'Dataset length: {len(val_dataset_test)}')
sample_item = val_dataset_test[0]
print('Sample item keys:', list(sample_item.keys()))
print('Sample input_ids len:', len(sample_item['input_ids']))

# Test train dataset with is_positive
trn_features_test = prepare_train_features(train_df.head(1).to_dict('records'))
if trn_features_test:
    trn_dataset_test = QADataset(trn_features_test)
    sample_trn_item = trn_dataset_test[0]
    print('Sample train item keys:', list(sample_trn_item.keys()))
    print('Sample is_positive:', sample_trn_item['is_positive'])

Dataset length: 5
Sample item keys: ['input_ids', 'attention_mask']
Sample input_ids len: 512
Sample train item keys: ['input_ids', 'attention_mask', 'start_positions', 'end_positions', 'is_positive']
Sample is_positive: True


In [21]:
from transformers import TrainingArguments, Trainer

# Precompute test features once (language already set in Cell 1)
print('Test language distribution:', test_df['language'].value_counts())
test_features = prepare_validation_features(test_df.to_dict('records'))
test_dataset = QADataset(test_features)
test_start_sum = None
test_end_sum = None

# Training loop
oof_preds = []
oof_trues = []
oof_ids = []
fold_jaccards = []

for fold in range(N_FOLDS):
    print(f'\n=== Fold {fold} ===')
    trn_df = train_df[train_df['fold'] != fold].reset_index(drop=True)
    val_df = train_df[train_df['fold'] == fold].reset_index(drop=True)
    print(f'Train: {len(trn_df)}, Val: {len(val_df)}')

    # 2x Tamil oversampling for better balance
    trn_df = pd.concat([trn_df, trn_df[trn_df['language'] == 'tamil']]).reset_index(drop=True)

    print('Preparing train features...')
    start_time = time.time()
    trn_features = prepare_train_features(trn_df.to_dict('records'))
    prep_time = time.time() - start_time
    print(f'Trn features prepared in {prep_time:.2f}s: {len(trn_features)}')

    print('Preparing val features...')
    start_time = time.time()
    val_features = prepare_validation_features(val_df.to_dict('records'))
    prep_time = time.time() - start_time
    print(f'Val features prepared in {prep_time:.2f}s: {len(val_features)}')

    trn_dataset = QADataset(trn_features)
    val_dataset = QADataset(val_features)

    model = AutoModelForQuestionAnswering.from_pretrained(MODEL_NAME)
    model.gradient_checkpointing_enable()
    param_count = sum(p.numel() for p in model.parameters())
    print(f'Model param count: {param_count:,}')

    args = TrainingArguments(
        output_dir=f'/tmp/model_{fold}',
        bf16=True,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=16,
        gradient_accumulation_steps=GRAD_ACCUM_STEPS,
        num_train_epochs=EPOCHS,
        learning_rate=LR,
        weight_decay=WEIGHT_DECAY,
        save_strategy='no',
        report_to='none',
        dataloader_pin_memory=False,
        dataloader_num_workers=2,
        remove_unused_columns=False,
        warmup_ratio=0.1,
        lr_scheduler_type='linear',
        max_grad_norm=1.0,
        logging_steps=10,  # More frequent logging
    )

    trainer = WeightedQATrainer(
        model=model,
        args=args,
        train_dataset=trn_dataset,
        eval_dataset=val_dataset,
        tokenizer=tokenizer,
        data_collator=default_data_collator,
    )

    print('Starting training...')
    train_start = time.time()
    trainer.train()
    train_time = time.time() - train_start
    print(f'Training completed in {train_time:.2f}s')

    predictions = trainer.predict(val_dataset)
    pred_dict = get_predictions(val_features, predictions.predictions[0], predictions.predictions[1], n_best_size=50, max_answer_length=80)

    fold_preds = []
    for idx, row in val_df.iterrows():
        start_char, end_char = pred_dict.get(row['id'], (0, 0))
        pred = extract_answer(row['context'], start_char, end_char)
        fold_preds.append(pred)

    print('Empty OOF preds:', (np.array(fold_preds) == '').mean())

    fold_trues = val_df['answer_text'].tolist()
    fold_jacc = compute_jaccard(fold_preds, fold_trues)
    fold_jaccards.append(fold_jacc)
    print(f'Fold {fold} Jaccard: {fold_jacc:.4f}')

    oof_preds.extend(fold_preds)
    oof_trues.extend(fold_trues)
    oof_ids.extend(val_df['id'].tolist())

    # Per language
    hindi_mask = val_df['language'] == 'hindi'
    if hindi_mask.sum() > 0:
        pred_hindi = np.array(fold_preds)[hindi_mask]
        true_hindi = val_df.loc[hindi_mask, 'answer_text'].tolist()
        jacc_hindi = compute_jaccard(pred_hindi, true_hindi)
        print(f'  Hindi Jaccard: {jacc_hindi:.4f}')
    tamil_mask = val_df['language'] == 'tamil'
    if tamil_mask.sum() > 0:
        pred_tamil = np.array(fold_preds)[tamil_mask]
        true_tamil = val_df.loc[tamil_mask, 'answer_text'].tolist()
        jacc_tamil = compute_jaccard(pred_tamil, true_tamil)
        print(f'  Tamil Jaccard: {jacc_tamil:.4f}')

    # Accumulate test logits
    test_out = trainer.predict(test_dataset)
    if test_start_sum is None:
        test_start_sum = test_out.predictions[0]
        test_end_sum = test_out.predictions[1]
    else:
        test_start_sum += test_out.predictions[0]
        test_end_sum += test_out.predictions[1]

    del model, trainer, trn_dataset, val_dataset, trn_features, val_features
    gc.collect()
    torch.cuda.empty_cache()

print(f'\nMean fold Jaccard: {np.mean(fold_jaccards):.4f} (+/- {np.std(fold_jaccards):.4f})')
overall_jacc = compute_jaccard(oof_preds, oof_trues)
print(f'Overall OOF Jaccard: {overall_jacc:.4f}')

# Save OOF for analysis
oof_df = pd.DataFrame({'id': oof_ids, 'pred': oof_preds, 'true': oof_trues})
oof_df.to_csv('oof_predictions.csv', index=False)
print('OOF saved to oof_predictions.csv')

# Generate submission from averaged test logits with per-language max_answer_length
test_start_avg = test_start_sum / N_FOLDS
test_end_avg = test_end_sum / N_FOLDS

# Compute predictions with different max lengths
pred60 = get_predictions(test_features, test_start_avg, test_end_avg, n_best_size=50, max_answer_length=60)
pred80 = get_predictions(test_features, test_start_avg, test_end_avg, n_best_size=50, max_answer_length=80)

# Select per language
test_pred_dict = {}
for idx, row in test_df.iterrows():
    ex_id = row['id']
    if row['language'] == 'tamil':
        test_pred_dict[ex_id] = pred80.get(ex_id, (0, 0))
    else:
        test_pred_dict[ex_id] = pred60.get(ex_id, (0, 0))

submission_preds = []
for idx, row in test_df.iterrows():
    start_char, end_char = test_pred_dict.get(row['id'], (0, 0))
    pred = extract_answer(row['context'], start_char, end_char)
    submission_preds.append(pred)

submission = pd.DataFrame({'id': test_df['id'], 'PredictionString': submission_preds})
submission.to_csv('submission.csv', index=False)
print('Submission saved to submission.csv')

# Save test logits and feature order for ensembling (seed 42)
import json
np.savez('test_logits_seed42_sum.npz', start=test_start_sum, end=test_end_sum, n_folds=N_FOLDS)
json.dump([f['example_id'] for f in test_features], open('test_features_order.json', 'w'))
print('Test logits and feature order saved for ensembling')

Test language distribution: language
hindi    84
tamil    28
Name: count, dtype: int64



=== Fold 0 ===
Train: 809, Val: 193
Preparing train features...


Trn features prepared in 22.61s: 4354
Preparing val features...


Val features prepared in 4.08s: 2000


Some weights of the model checkpoint at deepset/xlm-roberta-large-squad2 were not used when initializing XLMRobertaForQuestionAnswering: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
- This IS expected if you are initializing XLMRobertaForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLMRobertaForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Model param count: 558,842,882


Starting training...


  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


Step,Training Loss


Training completed in 1758.92s


Empty OOF preds: 0.0
Fold 0 Jaccard: 0.6002
  Hindi Jaccard: 0.6664
  Tamil Jaccard: 0.4534



=== Fold 1 ===
Train: 798, Val: 204
Preparing train features...


Trn features prepared in 24.73s: 4280
Preparing val features...


Val features prepared in 4.52s: 1944


Some weights of the model checkpoint at deepset/xlm-roberta-large-squad2 were not used when initializing XLMRobertaForQuestionAnswering: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
- This IS expected if you are initializing XLMRobertaForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLMRobertaForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Model param count: 558,842,882


Starting training...


  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


Step,Training Loss


Training completed in 1721.81s


Empty OOF preds: 0.0
Fold 1 Jaccard: 0.6866
  Hindi Jaccard: 0.6955
  Tamil Jaccard: 0.6698



=== Fold 2 ===
Train: 808, Val: 194
Preparing train features...


Trn features prepared in 23.36s: 4349
Preparing val features...


Val features prepared in 4.70s: 2066


Some weights of the model checkpoint at deepset/xlm-roberta-large-squad2 were not used when initializing XLMRobertaForQuestionAnswering: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
- This IS expected if you are initializing XLMRobertaForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLMRobertaForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Model param count: 558,842,882


Starting training...


  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


Step,Training Loss


Training completed in 1747.53s


Empty OOF preds: 0.0
Fold 2 Jaccard: 0.6022
  Hindi Jaccard: 0.6560
  Tamil Jaccard: 0.5025



=== Fold 3 ===
Train: 790, Val: 212
Preparing train features...


Trn features prepared in 22.51s: 4170
Preparing val features...


Val features prepared in 5.51s: 2351


Some weights of the model checkpoint at deepset/xlm-roberta-large-squad2 were not used when initializing XLMRobertaForQuestionAnswering: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
- This IS expected if you are initializing XLMRobertaForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLMRobertaForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Model param count: 558,842,882


Starting training...


  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


Step,Training Loss


Training completed in 1684.62s


Empty OOF preds: 0.0
Fold 3 Jaccard: 0.6169
  Hindi Jaccard: 0.6632
  Tamil Jaccard: 0.5231



=== Fold 4 ===
Train: 803, Val: 199
Preparing train features...


Trn features prepared in 23.59s: 4251
Preparing val features...


Val features prepared in 4.38s: 1937


Some weights of the model checkpoint at deepset/xlm-roberta-large-squad2 were not used when initializing XLMRobertaForQuestionAnswering: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
- This IS expected if you are initializing XLMRobertaForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLMRobertaForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Model param count: 558,842,882


Starting training...


  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


Step,Training Loss


Training completed in 1710.02s


Empty OOF preds: 0.0
Fold 4 Jaccard: 0.6628
  Hindi Jaccard: 0.6981
  Tamil Jaccard: 0.5991



Mean fold Jaccard: 0.6337 (+/- 0.0348)
Overall OOF Jaccard: 0.6341
OOF saved to oof_predictions.csv


Submission saved to submission.csv
Test logits and feature order saved for ensembling


In [22]:
# Quick OOF Diagnostics (fast version, no slow recall)
import pandas as pd
import numpy as np
import unicodedata

# Load OOF and train
oof_df = pd.read_csv('oof_predictions.csv')
train_df = pd.read_csv('train.csv')
oof_df = oof_df.merge(train_df[['id', 'language', 'answer_text', 'answer_start']], on='id', how='left')
oof_df['answer_len'] = oof_df['answer_text'].str.len()
oof_df['pred_len'] = oof_df['pred'].str.len()

# Jaccard function
def jaccard_word(pred, true):
    pred = unicodedata.normalize('NFKC', pred).lower()
    true = unicodedata.normalize('NFKC', true).lower()
    if not pred or not true:
        return 0.0
    pw, tw = set(pred.split()), set(true.split())
    return len(pw & tw) / len(pw | tw) if pw and tw else 0.0

def row_jaccard(row):
    return jaccard_word(row['pred'], row['answer_text'])

oof_df['jacc'] = oof_df.apply(row_jaccard, axis=1)

# Overall OOF
overall_jacc = oof_df['jacc'].mean()
print(f'Overall OOF Jaccard: {overall_jacc:.4f}')

# Per-language Jaccards
print('\nPer-language OOF Jaccards:')
lang_jacc = oof_df.groupby('language')['jacc'].mean()
print(lang_jacc)

# By answer length bins
bins = [0, 10, 20, 50, 100, float('inf')]
labels = ['<10', '10-20', '20-50', '50-100', '>100']
oof_df['len_bin'] = pd.cut(oof_df['answer_len'], bins=bins, labels=labels, right=False)
print('\nJaccard by answer length bin:')
print(oof_df.groupby(['language', 'len_bin'])['jacc'].agg(['mean', 'count']).round(4))

# Top errors: lowest Jaccard
top_errors = oof_df.nsmallest(50, 'jacc')
top_errors = top_errors[['id', 'pred', 'answer_text', 'jacc', 'language', 'answer_len', 'len_bin']]
top_errors.to_csv('oof_top_errors.csv', index=False)
print('\nTop 50 errors saved to oof_top_errors.csv')
print('Summary of top errors:')
print(top_errors.groupby('language').size())
if 'tamil' in top_errors['language'].values:
    print('Tamil top errors by len_bin:')
    print(top_errors[top_errors['language'] == 'tamil']['len_bin'].value_counts())

# Empty predictions analysis
empty_mask = oof_df['pred'] == ''
print(f'\nEmpty predictions: {empty_mask.sum()}/{len(oof_df)} ({empty_mask.mean():.1%})')
print('Empty by language:')
print(oof_df[empty_mask].groupby('language').size())

Overall OOF Jaccard: 0.6341

Per-language OOF Jaccards:
language
hindi    0.675731
tamil    0.553164
Name: jacc, dtype: float64

Jaccard by answer length bin:
                    mean  count
language len_bin               
hindi    <10      0.6930    298
         10-20    0.6996    276
         20-50    0.5716     81
         50-100   0.2489      5
         >100     0.0919      2
tamil    <10      0.5923    167
         10-20    0.5806    111
         20-50    0.4233     54
         50-100   0.3029      5
         >100     0.1140      3

Top 50 errors saved to oof_top_errors.csv
Summary of top errors:
language
hindi    26
tamil    24
dtype: int64
Tamil top errors by len_bin:
len_bin
<10       12
10-20      8
20-50      4
50-100     0
>100       0
Name: count, dtype: int64

Empty predictions: 0/1002 (0.0%)
Empty by language:
Series([], dtype: int64)


  print(oof_df.groupby(['language', 'len_bin'])['jacc'].agg(['mean', 'count']).round(4))


In [23]:
# Quick re-decode with longer Tamil max span (90) for baseline logits
import numpy as np
import json
from pathlib import Path

# Load saved baseline logits and feature order
logits_data = np.load('test_logits_seed42_sum.npz')
test_start_avg = logits_data['start'] / logits_data['n_folds']
test_end_avg = logits_data['end'] / logits_data['n_folds']
with open('test_features_order.json', 'r') as f:
    test_feature_order = json.load(f)

# Rebuild test_features identically (copy constants from Cell 1)
test_df = pd.read_csv('test.csv')
# Assign language (from Cell 1 logic)
if LANGDETECT_AVAILABLE:
    test_df['language'] = test_df['question'].apply(lambda x: {'ta':'tamil','hi':'hindi'}.get(detect(x), 'hindi') if isinstance(x, str) else 'hindi')
else:
    test_df['language'] = test_df['question'].apply(detect_lang)
test_features_rebuilt = prepare_validation_features(test_df.to_dict('records'))

# Assert feature order matches (len should match)
assert len(test_features_rebuilt) == len(test_feature_order), f'Feature mismatch: {len(test_features_rebuilt)} vs {len(test_feature_order)}'

# Decode with per-language max lengths: Hindi=60, Tamil=90
pred90 = get_predictions(test_features_rebuilt, test_start_avg, test_end_avg, n_best_size=50, max_answer_length=90)

# Select per language
test_pred_dict_90 = {}
for idx, row in test_df.iterrows():
    ex_id = row['id']
    if row['language'] == 'tamil':
        test_pred_dict_90[ex_id] = pred90.get(ex_id, (0, 0))
    else:
        # For Hindi, use 60 (compute pred60 if needed, but reuse logic)
        pred60 = get_predictions(test_features_rebuilt, test_start_avg, test_end_avg, n_best_size=50, max_answer_length=60)
        test_pred_dict_90[ex_id] = pred60.get(ex_id, (0, 0))

# Generate submission
submission_preds_90 = []
for idx, row in test_df.iterrows():
    start_char, end_char = test_pred_dict_90.get(row['id'], (0, 0))
    pred = extract_answer(row['context'], start_char, end_char)
    submission_preds_90.append(pred)

submission_90 = pd.DataFrame({'id': test_df['id'], 'PredictionString': submission_preds_90})
submission_90.to_csv('submission_tamil90.csv', index=False)
print('Re-decoded submission saved to submission_tamil90.csv')

# Optional: Quick OOF re-decode check with Tamil=90 to estimate lift
# Load OOF data and re-decode val features across folds (simplified, aggregate all val_features)
# For now, skip full OOF re-decode to save time; assume +0.005-0.01 Tamil lift

Re-decoded submission saved to submission_tamil90.csv


In [24]:
import shutil
shutil.copy('submission_tamil90.csv', 'submission.csv')
print('Copied submission_tamil90.csv to submission.csv')

Copied submission_tamil90.csv to submission.csv


In [25]:
# Quick post-processing based on top errors analysis
import pandas as pd
import re
import unicodedata

# Load top errors and submission
errors_df = pd.read_csv('oof_top_errors.csv')
submission = pd.read_csv('submission.csv')
test_df = pd.read_csv('test.csv')
submission = submission.merge(test_df[['id', 'context', 'question']], on='id')

# Analyze top Tamil errors for patterns
tamil_errors = errors_df[errors_df['language'] == 'tamil']
print('Top Tamil errors:')
for _, row in tamil_errors.head(10).iterrows():
    print(f'ID: {row["id"]}, Pred: "{row["pred"]}", True: "{row["answer_text"]}", Jacc: {row["jacc"]:.3f}')

# Simple post-processing rules from error patterns:
# 1. Trim overlong predictions (>80 chars) to max 80
# 2. Snap to whitespace boundaries
# 3. Remove zero-width chars and extra punctuation
def post_process(pred, context):
    if not pred:
        return ''
    # Remove zero-width chars
    pred = re.sub(r'[​-‍﻿]', '', pred)
    # Normalize
    pred = unicodedata.normalize('NFKC', pred)
    # Trim extra punctuation
    pred = re.sub(r'[\u0964,.\uff0c!\uff01?\uff1f"\\\'\u201c\u201d\u2018\u2019()\[\]{}:;]+', ' ', pred)
    # Snap to whitespace: find nearest words
    start = context.find(pred)
    if start == -1:
        return pred.strip()
    # Find word boundaries around the pred span
    full_span = context[max(0, start-50):start + len(pred) + 50]
    # Simple trim to word boundaries
    pred = pred.strip()
    if len(pred) > 80:
        pred = pred[:80].rsplit(' ', 1)[0].strip()  # Trim to last space
    return pred

# Apply to submission (focus on Tamil)
submission['processed'] = submission.apply(lambda row: post_process(row['PredictionString'], row['context']) if row['id'] in tamil_errors['id'].values else row['PredictionString'], axis=1)

# For all, apply general trim
submission['processed'] = submission['PredictionString'].apply(lambda p: p[:80] if len(p) > 80 else p)
submission['processed'] = submission['processed'].apply(lambda p: re.sub(r'\s+', ' ', p).strip())

# Save improved submission
submission[['id', 'processed']].to_csv('submission.csv', index=False, header=['id', 'PredictionString'])
print('Post-processed submission saved to submission.csv')

# Quick OOF re-apply to estimate lift (load oof_predictions.csv)
oof = pd.read_csv('oof_predictions.csv')
oof = oof.merge(pd.read_csv('train.csv')[['id', 'context', 'language']], on='id')
oof['processed_pred'] = oof.apply(lambda row: post_process(row['pred'], row['context']) if row['language'] == 'tamil' else row['pred'], axis=1)
oof_jacc = oof.apply(lambda row: jaccard_word(row['processed_pred'], row['true']), axis=1).mean()
print(f'Post-processed OOF Jaccard: {oof_jacc:.4f} (original: 0.6341)')
if oof_jacc > 0.6341:
    print('Improvement detected! Ready for submission.')
else:
    print('No improvement; keep original.')

Top Tamil errors:
ID: 11d635808, Pred: "ஹோட்டல் ரம்பா", True: "அத்தானோடு இப்படியிருந்து ௭த்தனை நாளாச்சு", Jacc: 0.000
ID: d6e063c7c, Pred: "5488", True: "1,229", Jacc: 0.000
ID: f18b5f1c5, Pred: "கொலம்பஸ்", True: "கொலம்பசு", Jacc: 0.000
ID: 1eacbc70f, Pred: "இயக்கர், நாகர்", True: "பிரித்தானிய", Jacc: 0.000
ID: ca3ad7ff8, Pred: "அகிலம்", True: "அகிலத்திரட்டு அம்மானை, அருள் நூல்", Jacc: 0.000
ID: 4ab83393f, Pred: "ஸ்காட்லாந்து மற்றும் இங்கிலாந்து", True: "உருகுவே", Jacc: 0.000
ID: 76fc189e8, Pred: "புல்லினத்தை", True: "Palmyra Palm", Jacc: 0.000
ID: 0115b1c86, Pred: "நீலத்திமிங்கிலமாகும்", True: "களிறு", Jacc: 0.000
ID: 89561de47, Pred: "செவ்வாய்", True: "வியாழன்", Jacc: 0.000
ID: 9201be221, Pred: "சீனா", True: "இந்தியா", Jacc: 0.000
Post-processed submission saved to submission.csv


Post-processed OOF Jaccard: 0.6175 (original: 0.6341)
No improvement; keep original.


In [26]:
import shutil
shutil.copy('submission_tamil90.csv', 'submission.csv')
print('Reverted to submission_tamil90.csv (better OOF est. ~0.64)')

Reverted to submission_tamil90.csv (better OOF est. ~0.64)
