# This notebook shows how to run inference with a State Transition (ST) model pretrained on the Tahoe-100M dataset.

Because of storage and RAM requirements (the model checkpoint is about 3 GB and the held-out data file is about 50 GB), it is highly recommended to run this notebook on Colab Pro, or to download it and run it locally.

# Installation

In [None]:
! pip install -q anndata

import os, sys, pickle
import anndata as ad

os.environ['MPLBACKEND'] = 'agg'

# Download a pre-trained ST-Tahoe checkpoint

In [None]:
# 1) Install dependency
%pip install -q --upgrade huggingface_hub

# 2) Download the entire repository snapshot
from huggingface_hub import snapshot_download, HfApi
import re, os

repo_url = "https://huggingface.co/arcinstitute/ST-Tahoe/tree/main"

# Parse owner/name from the URL
m = re.match(r"https?://huggingface\.co/([^/]+)/([^/]+)", repo_url)
if not m:
    raise ValueError("Could not parse repo URL")
owner, name = m.group(1), m.group(2)
repo_id = f"{owner}/{name}"

# Detect repo type (model/dataset/space)
api = HfApi()
repo_type = None
for kind, probe in [("model", api.model_info), ("dataset", api.dataset_info), ("space", api.space_info)]:
    try:
        probe(repo_id)
        repo_type = kind
        break
    except Exception:
        pass
if repo_type is None:
    raise ValueError(f"Could not determine repo type for {repo_id}. Is it public?")

# Where to put the files
target_dir = name  # e.g., "ST-Tahoe"
os.makedirs(target_dir, exist_ok=True)
exclude = ["final_from_preprint.ckpt"]

# Download everything at the specified revision
local_dir = snapshot_download(
    repo_id=repo_id,
    repo_type=repo_type,
    revision="main",
    local_dir=target_dir,
    local_dir_use_symlinks=False,  # copy real files instead of symlinks
    max_workers=8,                 # adjust for more/less parallelism
    ignore_patterns=exclude
)

print(f"Downloaded to: {local_dir}")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
For more details, check out https://huggingface.co/docs/huggingface_hub/main/en/guides/download#download-files-to-local-folder.


Fetching 13 files:   0%|          | 0/13 [00:00<?, ?it/s]

README.md: 0.00B [00:00, ?B/s]

config.yaml: 0.00B [00:00, ?B/s]

batch_onehot_map.pkl:   0%|          | 0.00/16.0k [00:00<?, ?B/s]

MODEL_LICENSE.md: 0.00B [00:00, ?B/s]

.gitattributes: 0.00B [00:00, ?B/s]

cell_type_onehot_map.pkl:   0%|          | 0.00/518k [00:00<?, ?B/s]

final.ckpt:   0%|          | 0.00/3.01G [00:00<?, ?B/s]

var_dims.pkl:   0%|          | 0.00/206k [00:00<?, ?B/s]

pert_onehot_map.pt:   0%|          | 0.00/5.50M [00:00<?, ?B/s]

data_module.torch:   0%|          | 0.00/1.90k [00:00<?, ?B/s]

wandb_path.txt:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

MODEL_ACCEPTABLE_USE_POLICY.md: 0.00B [00:00, ?B/s]

LICENSE.md: 0.00B [00:00, ?B/s]

Downloaded to: /content/ST-Tahoe


# Fetch a held-out cell line from Hugging Face (large file: about 50 GB)

In [None]:
# Download a cell line that was held out from training the model (about 50 GB).
from huggingface_hub import hf_hub_download

# `local_dir_use_symlinks` is not passed: current huggingface_hub releases deprecate
# and ignore it (it only emits a warning). Files in `local_dir` are regular files.
file_path = hf_hub_download(
    repo_id="arcinstitute/State-Tahoe-Filtered",
    repo_type="dataset",
    filename="c37.h5ad",
    local_dir=".",  # downloads to current directory
)

print(f"Downloaded to: {file_path}")

For more details, check out https://huggingface.co/docs/huggingface_hub/main/en/guides/download#download-files-to-local-folder.


c37.h5ad:   0%|          | 0.00/49.7G [00:00<?, ?B/s]

Downloaded to: c37.h5ad


# Prepare the file for inference and downstream cell-eval usage

In [None]:
# Copy over the values in .obsm['X_hvg'] to .X, as cell-eval expects the data in .X

adata_holdout = ad.read_h5ad("c37.h5ad")

# Gene names for the X_hvg columns, read from the checkpoint's var_dims.pkl.
# NOTE(review): pickle.load can execute arbitrary code; only load this file from a
# source you trust (here, the ST-Tahoe repo downloaded above).
with open('ST-Tahoe/var_dims.pkl', 'rb') as fh:  # `with` closes the handle
    hvg_names = pickle.load(fh)['gene_names']

# The gene list must line up with the X_hvg columns that become .X. Check explicitly
# so a mismatch fails with a clear message instead of an opaque AnnData error.
n_hvg = adata_holdout.obsm['X_hvg'].shape[1]
if len(hvg_names) != n_hvg:
    raise ValueError(
        f"var_dims.pkl lists {len(hvg_names)} genes but .obsm['X_hvg'] has {n_hvg} columns"
    )

adata_holdout.var.index = hvg_names
adata_holdout.X = adata_holdout.obsm['X_hvg']

# Save it back out
adata_holdout.write_h5ad("c37_real.h5ad")

In [None]:
# Clean the old file and free memory
import gc
from pathlib import Path

# Guarded so re-running this cell does not raise NameError.
if "adata_holdout" in globals():
    del adata_holdout
gc.collect()  # also collect any reference cycles still holding the AnnData's arrays

# The raw download is no longer needed (c37_real.h5ad holds the prepared copy).
# missing_ok=True makes a re-run a no-op instead of an `rm` error.
Path("c37.h5ad").unlink(missing_ok=True)

# Now we are ready to run inference on the file:

In [None]:
# Run the State `tx infer` CLI on the prepared held-out file and write the simulated
# (model-predicted) data to c37_simulated.h5ad.
#   --model-dir / --checkpoint : the ST-Tahoe snapshot downloaded above
#   --pert-col / --batch-col   : presumably .obs columns with the perturbation and
#                                batch (plate) labels; verify against the h5ad
#   --control-pert             : the --pert-col value identifying control (DMSO) cells;
#                                must match the value passed to cell-eval below
# NOTE(review): `@main` is unpinned (cell-eval below is pinned to v0.5.42); pin a tag
# or commit so this cell gives reproducible results.
! uvx -q --from git+https://github.com/ArcInstitute/state@main state tx infer \
    --model-dir ST-Tahoe \
    --checkpoint ST-Tahoe/final.ckpt \
    --pert-col drugname_drugconc \
    --batch-col plate \
    --control-pert "[('DMSO_TF', 0.0, 'uM')]" \
    --adata c37_real.h5ad \
    --output c37_simulated.h5ad

==> STATE: tx infer (virtual experiment)
Loaded config: ST-Tahoe/config.yaml
Control perturbation: [('DMSO_TF', 0.0, 'uM')]
Grouping by cell type column: cell_name
StateTransitionPerturbationModel(
  (loss_fn): SamplesLoss()
  (gene_decoder): LatentToGeneDecoder(
    (decoder): Sequential(
      (0): Linear(in_features=2000, out_features=1024, bias=True)
      (1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (2): GELU(approximate='none')
      (3): Dropout(p=0.1, inplace=False)
      (4): Linear(in_features=1024, out_features=1024, bias=True)
      (5): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (6): GELU(approximate='none')
      (7): Dropout(p=0.1, inplace=False)
      (8): Linear(in_features=1024, out_features=512, bias=True)
      (9): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (10): GELU(approximate='none')
      (11): Dropout(p=0.1, inplace=False)
      (12): Linear(in_features=512, out_features=2000, bias=True)
      (13): ReLU

# Run cell-eval to compare observed data vs. State-simulated data

In [None]:
# Run cell-eval to compare the simulated anndata vs the observed anndata.

! uvx -q --from git+https://github.com/ArcInstitute/cell-eval@v0.5.42 cell-eval run \
    -ap c37_simulated.h5ad \
    -ar c37_real.h5ad \
    -o . \
    --control-pert "[('DMSO_TF', 0.0, 'uM')]" \
    --pert-col drugname_drugconc \
    --profile vcc \
    --celltype-col cell_name \
    --skip-metrics clustering_agreement,pearson_edistance \
    --batch-size 1024 \
    --num-threads 64

INFO:cell_eval._evaluator:Input is found to be log-normalized already - skipping transformation.
INFO:cell_eval._evaluator:Input is found to be log-normalized already - skipping transformation.
INFO:cell_eval._evaluator:Computing DE for real data
INFO:pdex._single_cell:Precomputing masks for each target gene
Identifying target masks: 100% 1137/1137 [00:01<00:00, 938.88it/s]
INFO:pdex._single_cell:Precomputing variable indices for each feature
Identifying variable indices: 100% 2000/2000 [00:00<00:00, 3133585.36it/s]
INFO:pdex._single_cell:Creating shared memory memory matrix for parallel computing


In [None]:
import pandas as pd

# cell-eval wrote its outputs to the working directory (`-o .` above). Read the file
# relative to it so this also works outside Colab, where /content does not exist.
# NOTE(review): the "NCI-H596" prefix is presumably the cell line in c37.h5ad; update
# it if you evaluate a different file.
results = pd.read_csv('NCI-H596_agg_NCI-H596_results.csv')
results[results.statistic == 'mean']

Unnamed: 0,statistic,overlap_at_N,mae,discrimination_score_l1
2,mean,0.000178,0.150929,0.501992
