# STATE Transition Model Training on SCP1064

This notebook adapts the official STATE training notebook  for the custom Perturb-CITE-seq dataset (SCP1064).

The workflow is as follows:
1.  **Setup**: Define file paths and import libraries.
2.  **Data Preprocessing**:
    * Load RNA expression and metadata from your CSVs.
    * Create a single `AnnData` object.
    * Apply your custom metadata processing (filling "CTRL", creating 'perturbation' column).
    * Apply the standard preprocessing from the reference: normalize to 10k counts and log-transform .
    * Calculate and store 2000 Highly Variable Genes (HVGs).
    * Save the final processed data as a `.h5ad` file.
3.  **Create TOML Config**:
    * Define a train/val/test split by holding out a subset of perturbations from one cellular context, similar to the reference's 'few-shot' setup[cite: 252, 281].
    * This split is saved to a `.toml` file.
4.  **Install & Train**:
    * Change to your local `state` repo directory.
    * Install dependencies using `uv`.
    * Run the `state tx train` command, pointing to our new data and TOML config.
5.  **Predict**:
    * Run inference and evaluation on the holdout test set[cite: 449].

In [1]:
import os
import pandas as pd
import scanpy as sc
import anndata as ad
import numpy as np
from scipy import sparse
import toml
from sklearn.model_selection import train_test_split

# --- User-defined Paths ---
# Directory for your data
data_dir = "/home/nebius/cellian/data/perturb-cite-seq/SCP1064"

# Input files
meta_path = f"{data_dir}/metadata/RNA_metadata.csv"
rna_csv = f"{data_dir}/expression/RNA_expression_subset1a.csv"
protein_csv = f"{data_dir}/expression/Protein_expression.csv" # Note: The reference ST model uses RNA profiles [cite: 14]

# Path to your local clone of the STATE repo
state_repo_dir = "/home/nebius/ST-Tahoe"

# Output files
processed_data_dir = f"{data_dir}/processed"
os.makedirs(processed_data_dir, exist_ok=True)
output_adata_path = f"{processed_data_dir}/scp1064_processed.h5ad"
output_toml_path = f"{state_repo_dir}/scp1064_split.toml" # Save TOML config in the repo dir for easy access

CELL_TYPE_COLUMN = 'condition' 
PERTURBATION_COLUMN = 'perturbation'
CONTROL_LABEL = 'CTRL'

In [2]:
print(f"Loading metadata from {meta_path}...")
meta_df = pd.read_csv(meta_path, index_col=0)

print(f"Loading RNA expression from {rna_csv}...")
# Genes as index, cells as columns
rna_df = pd.read_csv(rna_csv, index_col=0)

# Transpose RNA data: AnnData expects (observations x variables), i.e., (cells x genes)
print("Transposing RNA matrix to (cells x genes)...")
rna_df_t = rna_df.T

# Align metadata and expression data
print("Aligning metadata and expression data...")
common_cells = rna_df_t.index.intersection(meta_df.index)
if len(common_cells) == 0:
    raise ValueError("No common cells found between RNA expression and metadata. Check cell identifiers.")

print(f"Found {len(common_cells)} common cells.")
rna_df_t = rna_df_t.loc[common_cells]
meta_df = meta_df.loc[common_cells]

# Create AnnData object
adata = ad.AnnData(X=rna_df_t.values, obs=meta_df)
adata.var_names = rna_df.index.tolist()
del rna_df, rna_df_t, meta_df

Loading metadata from /home/nebius/cellian/data/perturb-cite-seq/SCP1064/metadata/RNA_metadata.csv...
Loading RNA expression from /home/nebius/cellian/data/perturb-cite-seq/SCP1064/expression/RNA_expression_subset1a.csv...


  meta_df = pd.read_csv(meta_path, index_col=0)


Transposing RNA matrix to (cells x genes)...
Aligning metadata and expression data...
Found 27291 common cells.


In [3]:
print("Applying custom metadata processing...")
# 1. Fill NA in 'sgRNA' with 'CTRL'
adata.obs['sgRNA'] = adata.obs['sgRNA'].fillna(CONTROL_LABEL)
# 2. Create 'perturbation' column by stripping sgRNA number
adata.obs[PERTURBATION_COLUMN] = adata.obs['sgRNA'].str.replace(r"(_\d+)$", "", regex=True)

print(f"Unique perturbations created. Example: {adata.obs[PERTURBATION_COLUMN].unique()[:5]}")

Applying custom metadata processing...
Unique perturbations created. Example: ['HLA-B' 'CTRL' 'IFNGR1' 'CDKN1A' 'EMP1']


In [6]:
print("Applying standard preprocessing (Normalize total, log1p)...")
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)

# --- Compute and Store Highly Variable Genes (HVGs) ---
# The reference notebook computes 2000 HVGs.
print("Calculating 2000 Highly Variable Genes (HVGs)...")
sc.pp.highly_variable_genes(adata, n_top_genes=2000, flavor='seurat_v3')

# Store HVG data in .obsm as a dense array, just like the reference [cite: 228-230]
print("Storing HVG array in .obsm['X_hvg']...")
hvg_data = adata.X[:, adata.var['highly_variable']]
if sparse.issparse(hvg_data):
    hvg_data = hvg_data.toarray()
adata.obsm['X_hvg'] = hvg_data
print(f"Saving processed AnnData object to {output_adata_path}...")
adata.write(output_adata_path)

Applying standard preprocessing (Normalize total, log1p)...


Calculating 2000 Highly Variable Genes (HVGs)...


ImportError: Please install skmisc package via `pip install --user scikit-misc

In [None]:
all_contexts = adata.obs[CELL_TYPE_COLUMN].unique().tolist()
if not all_contexts:
    raise ValueError(f"No contexts found in column '{CELL_TYPE_COLUMN}'. Please check the column name.")

print(f"Found {len(all_contexts)} cellular contexts: {all_contexts}")

# --- Define the Split ---
# We pick ONE context to split for validation and testing
# All other contexts will be used entirely for training.
context_to_split = all_contexts[0]
print(f"Selected context '{context_to_split}' to create val/test splits from.")

# Get all perturbations present in this context, excluding the control
perts_in_context = adata[adata.obs[CELL_TYPE_COLUMN] == context_to_split].obs[PERTURBATION_COLUMN].unique()
perts_in_context = [p for p in perts_in_context if p != CONTROL_LABEL]
print(f"Found {len(perts_in_context)} perturbations in '{context_to_split}'.")

# Split these perturbations into validation and test sets (e.g., 50/50 split)
val_perts, test_perts = train_test_split(perts_in_context, test_size=0.5, random_state=42)

print(f"Split: {len(val_perts)} validation perts, {len(test_perts)} test perts.")

# --- Build the TOML Config Dictionary ---
# This structure follows the reference notebook [cite: 265-281]
DATASET_NAME = "scp1064_dataset" # An internal name for this dataset

config = {
    # Map the dataset name to the *directory* containing the .h5ad file
    "datasets": {
        DATASET_NAME: processed_data_dir
    },
    
    # Specify which datasets to use for training
    "training": {
        DATASET_NAME: "train"
    },
    
    # Define zero-shot holdouts (none in this setup)
    "zeroshot": {},
    
    # Define few-shot holdouts
    "fewshot": {
        # The key is "dataset_name.context_name"
        f"{DATASET_NAME}.{context_to_split}": {
            "val": list(val_perts),
            "test": list(test_perts)
        }
    }
}

# --- Save the TOML File ---
print(f"Saving TOML config to {output_toml_path}...")
with open(output_toml_path, 'w') as f:
    toml.dump(config, f)

print("\nTOML config generation complete.")
print(f"--- Config Preview (first 5 perts) ---")
print(f"[fewshot.{DATASET_NAME}.{context_to_split}]")
print(f"val = {list(val_perts)[:5]}...")
print(f"test = {list(test_perts)[:5]}...")
print("---------------------------------")

In [None]:
%cd {state_repo_dir}

In [None]:
!uv sync

In [None]:
!uv add "cell-load>=0.7.11" toml pandas scikit-learn scanpy anndata

In [None]:
!uv run state tx train \
    data.kwargs.toml_config_path="{output_toml_path}" \
    data.kwargs.num_workers=4 \
    data.kwargs.output_space="gene" \
    data.kwargs.batch_col="gem_group" \
    data.kwargs.pert_col="{PERTURBATION_COLUMN}" \
    data.kwargs.cell_type_key="{CELL_TYPE_COLUMN}" \
    data.kwargs.control_pert="{CONTROL_LABEL}" \
    training.max_steps=80000 \
    training.ckpt_every_n_steps=2000 \
    training.batch_size=64 \
    training.lr=1e-3 \
    model.kwargs.cell_set_len=64 \
    model.kwargs.hidden_dim=128 \
    model.kwargs.batch_encoder=True \
    model=state \
    wandb.entity="arcinstitute" \
    wandb.tags="[scp1064_run]" \
    output_dir="test_scp1064" \
    name="scp1064_holdout"

In [None]:
!uv run state tx predict \
    --output_dir "test_scp1064/scp1064_holdout" \
    --checkpoint "last.ckpt"