# ATTR-SSM Colab Quickstart

- Use Runtime → Change runtime type → GPU for best speed.
- The entry-point script, `main.py`, supports an optional `--device` flag.
  - `--device auto` prefers CUDA → MPS → CPU automatically.


In [None]:
# Report the interpreter / PyTorch versions and which accelerator is visible.
import platform
import torch

print('Python:', platform.python_version())
print('PyTorch:', torch.__version__)

# Probe CUDA once and reuse the answer for both the report and the branch.
has_cuda = torch.cuda.is_available()
print('CUDA available:', has_cuda)

if not has_cuda:
    # Older PyTorch builds have no `torch.backends.mps`, hence the getattr guard.
    mps_backend = getattr(torch.backends, 'mps', None)
    print('MPS available:', bool(mps_backend and mps_backend.is_available()))
else:
    print('CUDA device:', torch.cuda.get_device_name(0))


In [None]:
# Reinstall PyTorch as a matched torch / torchvision / torchaudio set built for CUDA 12.1.
# Step 1: remove whichever builds of the three packages are currently installed
# (presumably so a preinstalled build cannot be mixed with the pinned ones below).
%pip -q uninstall -y torch torchvision torchaudio
# Step 2: upgrade the packaging tools before installing anything else.
%pip -q install -U pip setuptools wheel
# Step 3: install the pinned trio from the PyTorch cu121 wheel index.
# NOTE(review): a kernel that has already imported torch keeps the previously loaded
# build until it is restarted; the `!python main.py ...` cells launch a separate
# interpreter process, so they should see the versions installed here.
%pip -q install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 --index-url https://download.pytorch.org/whl/cu121

# Install project dependencies with wheel-friendly versions for Colab (avoid building from source).
# Every package is either pinned exactly or bounded to a range; numpy and scipy are
# held below 2.0. The command is one logical line joined with trailing backslashes,
# so do not insert comments or blank lines between its continuation lines.
%pip -q install "numpy>=1.26,<2.0" "scipy>=1.11,<2.0" "pandas>=2.1,<2.3" "matplotlib>=3.8,<3.9" \
  "scikit-learn>=1.3,<1.5" seaborn==0.13.2 tqdm==4.67.1 einops==0.8.1 \
  transformers==4.46.3 huggingface-hub==0.34.4 tokenizers==0.20.3 safetensors==0.5.3 \
  PyYAML==6.0.2 requests==2.32.4 fsspec==2024.6.1 typing_extensions==4.12.2 joblib==1.4.2 \
  regex==2024.11.6 packaging==25.0 filelock==3.13.1 sympy==1.13.3 "networkx>=3,<4"

# Clone the repo and prepare folders
import os
if not os.path.exists('ATTR-SSM'):
    !git clone https://github.com/tripmat/ATTR-SSM.git
%cd ATTR-SSM
!mkdir -p results experiments


In [None]:
# Quick sanity run (uses sample data to generate figures).
# Run this after the setup cell: `main.py` is looked up in the current working
# directory, which the setup cell's `%cd ATTR-SSM` points at the repo checkout.
# `--device auto` prefers CUDA, then MPS, then CPU.
# NOTE(review): `--plot-only` presumably skips training — confirm against main.py.
!python main.py --device auto --plot-only


In [None]:
# Full training (can take time).
# Run this after the setup cell: `main.py` is looked up in the current working
# directory, which the setup cell's `%cd ATTR-SSM` points at the repo checkout.
# `--device auto` prefers CUDA, then MPS, then CPU.
# NOTE(review): `--train-only` presumably skips figure generation — confirm against main.py.
!python main.py --device auto --train-only


### Fallbacks
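- If CUDA 12.1 wheels are unavailable, try the CUDA 11.8 builds instead (run this in a notebook cell, not a terminal — `%pip` is an IPython magic):
  ```python
  %pip -q install torch==2.2.2+cu118 torchvision==0.17.2+cu118 -f https://download.pytorch.org/whl/torch_stable.html
  ```
- CPU-only (again in a notebook cell; the second line runs the quick sanity check on the CPU):
  ```python
  %pip -q install torch==2.2.2+cpu torchvision==0.17.2+cpu -f https://download.pytorch.org/whl/torch_stable.html
  !python main.py --device cpu --plot-only
  ```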


## Optional: Train on GPU
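To run the full training from a terminal instead of a notebook cell, execute the following from the root of the `ATTR-SSM` checkout (this can be time-consuming; `--device auto` uses the GPU when one is available):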

```bash
python main.py --device auto --train-only
```
