# Notebook 1 for in-silico perturbation 

##### Intialisations

In [1]:
# Task 1 — In-Silico Perturbation Workflow (Geneformer V2)

import os, sys, gc, glob, pickle, json
from pathlib import Path
import numpy as np
import pandas as pd
import scanpy as sc

from geneformer import TranscriptomeTokenizer, InSilicoPerturber

# --- project root ---
PROJECT = Path(r"C:\Users\ratne\Downloads\Helical_Challenge")

# --- data paths ---
DATA      = PROJECT / "data"
RAW_H5AD  = DATA / "counts_combined_filtered_BA4_sALS_PN.h5ad"  # raw counts .h5ad

# put the prepped file in its own folder so tokenization only sees the right .h5ad
PREP_DIR  = DATA / "prepped"
PREP_H5AD = PREP_DIR / "ALS_snRNA_raw_prepped.h5ad"

# tokenized dataset + ISP outputs (keep them in your project, not in the Geneformer repo)
TOK       = DATA / "tokenized"          # Geneformer .dataset will be saved here
ISP       = PROJECT / "results" / "isp" # ISP outputs here
DATASET   = TOK / "ALS.dataset"

# Geneformer V2 model path:
# EITHER: local checkpoint folder...
MODEL_DIR = str(PROJECT / "Geneformer" / "Geneformer-V2-104M")
# ...OR, alternatively, the HF hub id:
# MODEL_DIR = "ctheodoris/Geneformer/gf-12L-95M-i4096"

# performance knobs
NPROC     = 4
FWD_BATCH = 16
MODEL_VER = "V2"

# make output dirs
for p in [PREP_DIR, TOK, ISP]:
    p.mkdir(parents=True, exist_ok=True)

print("PROJECT:", PROJECT)
print("RAW_H5AD exists:", RAW_H5AD.exists())
print("PREP_H5AD will be written to:", PREP_H5AD)
print("TOKENIZED dir:", TOK)
print("ISP dir:", ISP)
print("MODEL_DIR:", MODEL_DIR)


  from .autonotebook import tqdm as notebook_tqdm


PROJECT: C:\Users\ratne\Downloads\Helical_Challenge
RAW_H5AD exists: True
PREP_H5AD will be written to: C:\Users\ratne\Downloads\Helical_Challenge\data\prepped\ALS_snRNA_raw_prepped.h5ad
TOKENIZED dir: C:\Users\ratne\Downloads\Helical_Challenge\data\tokenized
ISP dir: C:\Users\ratne\Downloads\Helical_Challenge\results\isp
MODEL_DIR: C:\Users\ratne\Downloads\Helical_Challenge\Geneformer\Geneformer-V2-104M


##### Preprocess

In [2]:
# ---------- 1) Load raw data ----------
assert RAW_H5AD.exists(), f"Raw .h5ad not found at {RAW_H5AD}"
adata = sc.read_h5ad(RAW_H5AD)

In [3]:
# starting from your full adata *before* writing PREP_H5AD
import numpy as np

np.random.seed(0)
n_cells = 1000   # or 2000, something smallish
idx = np.random.choice(adata.obs_names, size=n_cells, replace=False)
adata_small = adata[idx].copy()

print("Subsampled:", adata_small.n_obs, "cells")

# then continue preprocessing *on adata_small* instead of adata:
adata = adata_small


Subsampled: 1000 cells


In [4]:
import numpy as np
import pandas as pd
import re
import pickle

# ---------- 1) sanity check on counts ----------
mtx_max = adata.X.max() if not hasattr(adata.X, "A") else adata.X.A.max()
if float(mtx_max) < 50:
    print("NOTE: Matrix max is <50. Ensure this is RAW counts, not log/CPM.")

# ---------- 2) Ensure Ensembl IDs in adata.var["ensembl_id"] ----------

gene_name_id_path = PROJECT / "Geneformer" / "geneformer" / "gene_name_id_dict_gc104M.pkl"
if not gene_name_id_path.exists():
    raise FileNotFoundError(f"Mapping dict not found at {gene_name_id_path}")

print("Loading gene name→Ensembl mapping:", gene_name_id_path)
with open(gene_name_id_path, "rb") as f:
    gene_name_id_dict = pickle.load(f)

# Start from existing 'ensembl_id' if present, otherwise from gene symbols in var.index
if "ensembl_id" in adata.var.columns:
    ens = adata.var["ensembl_id"].astype(object).to_numpy()
    print("Found existing 'ensembl_id' column; will keep valid ENSG IDs and fill the rest.")
else:
    ens = adata.var.index.astype(str).to_numpy().astype(object)
    print("No 'ensembl_id' column; starting from var index as symbols.")

symbols = adata.var.index.astype(str).to_numpy()

converted = 0
skipped_existing = 0
pattern = re.compile(r"^ENSG\d+")

for i in range(len(ens)):
    val = ens[i]
    s = "" if val is None else str(val)

    # If already a valid Ensembl ID, keep it
    if pattern.match(s):
        skipped_existing += 1
        continue

    # Otherwise, map from gene symbol (var index)
    sym = symbols[i]
    mapped = gene_name_id_dict.get(sym)
    if mapped is not None:
        ens[i] = mapped
        converted += 1
    else:
        ens[i] = None  # mark unmapped

adata.var["ensembl_id"] = ens

n_missing = pd.isna(adata.var["ensembl_id"]).sum()
print(f"-> Kept {skipped_existing} existing Ensembl IDs.")
print(f"-> Converted {converted} symbols to Ensembl IDs via mapping.")
print(f"-> {n_missing} genes still lack Ensembl IDs.")
print("Example ensembl_id values:", adata.var['ensembl_id'].head().tolist())

# ---------- 3) Required obs fields ----------
# Geneformer needs per-cell total counts in 'n_counts'
if "total_counts" in adata.obs.columns:
    adata.obs["n_counts"] = adata.obs["total_counts"].astype(np.float64)
else:
    if hasattr(adata.X, "A"):
        adata.obs["n_counts"] = np.array(adata.X.sum(axis=1)).ravel()
    else:
        adata.obs["n_counts"] = adata.X.sum(axis=1)

adata.obs["filter_pass"] = True

print(f"Genes: {adata.n_vars}, cells: {adata.n_obs}")
print("Has 'ensembl_id' in var:", "ensembl_id" in adata.var.columns)
print("Added 'n_counts' from total_counts" if "total_counts" in adata.obs.columns else "Computed 'n_counts' by summing X")


Loading gene name→Ensembl mapping: C:\Users\ratne\Downloads\Helical_Challenge\Geneformer\geneformer\gene_name_id_dict_gc104M.pkl
No 'ensembl_id' column; starting from var index as symbols.
-> Kept 0 existing Ensembl IDs.
-> Converted 22831 symbols to Ensembl IDs via mapping.
-> 1 genes still lack Ensembl IDs.
Example ensembl_id values: ['ENSG00000000003', 'ENSG00000000005', 'ENSG00000000419', 'ENSG00000000457', 'ENSG00000000938']
Genes: 22832, cells: 1000
Has 'ensembl_id' in var: True
Added 'n_counts' from total_counts


In [5]:
mask = adata.var["ensembl_id"].notna()
print("Dropping", (~mask).sum(), "genes with no Ensembl ID.")
adata = adata[:, mask].copy()

Dropping 1 genes with no Ensembl ID.


In [6]:
# obs metadata to carry into the .dataset file

obs_to_keep = [
    "Sample_ID", "Donor", "Region", "Sex",
    "Condition", "Group", "C9_pos",
    "CellClass", "CellType", "SubType",
    "Cellstates_LVL1", "Cellstates_LVL2", "Cellstates_LVL3",
    "total_counts", "log1p_total_counts",
    "total_counts_mt", "log1p_total_counts_mt", "pct_counts_mt",
    "n_genes",
    "split",
    # you already added:
    "n_counts"
]

present = [k for k in obs_to_keep if k in adata.obs.columns]
missing = [k for k in obs_to_keep if k not in adata.obs.columns]
print("Will keep obs columns:", present)
if missing:
    print("Missing (ignored):", missing)

custom_attr_name_dict = {k: k for k in present}

# cast non-numeric obs to string so they survive round-trips nicely
for k in present:
    if adata.obs[k].dtype.kind not in {"i", "f"}:
        adata.obs[k] = adata.obs[k].astype(str)

# Save prepped AnnData
PREP_H5AD.parent.mkdir(parents=True, exist_ok=True)
adata.write_h5ad(PREP_H5AD)
print("Wrote prepped AnnData to:", PREP_H5AD)


Will keep obs columns: ['Sample_ID', 'Donor', 'Region', 'Sex', 'Condition', 'Group', 'C9_pos', 'CellClass', 'CellType', 'SubType', 'Cellstates_LVL1', 'Cellstates_LVL2', 'Cellstates_LVL3', 'total_counts', 'log1p_total_counts', 'total_counts_mt', 'log1p_total_counts_mt', 'pct_counts_mt', 'n_genes', 'split', 'n_counts']
Wrote prepped AnnData to: C:\Users\ratne\Downloads\Helical_Challenge\data\prepped\ALS_snRNA_raw_prepped.h5ad


In [None]:
after_adata = sc.read_h5ad(PREP_H5AD)

# checking the ensembl_ids
print("Ensembl IDs in var:", after_adata.var["ensembl_id"].head(100))
print("\n")
# checking the obs
print("Obs fields:", after_adata.obs.head(100))


Ensembl IDs in var: Gene
TSPAN6    ENSG00000000003
TNMD      ENSG00000000005
DPM1      ENSG00000000419
SCYL3     ENSG00000000457
FGR       ENSG00000000938
               ...       
CROT      ENSG00000005469
ABCB4     ENSG00000005471
KMT2E     ENSG00000005483
RHBDD2    ENSG00000005486
SOX8      ENSG00000005513
Name: ensembl_id, Length: 100, dtype: object


Obs fields:                                        Sample_ID Donor Region Sex Condition  \
Barcode                                                                       
GTCGTTCTCTGTGCAA-118MCX  191112_ALS_118_snRNA-C4   118    BA4   M       ALS   
GGGCTCATCTGGGAGA-126MCX  191112_ALS_126_snRNA-A6   126    BA4   M       ALS   
GTAACACTCGTCCATC-303MCX   191114_PN_303_snRNA-E8   303    BA4   F        PN   
AGCCAGCGTCGTTCAA-116MCX  191112_ALS_116_snRNA-C2   116    BA4   M       ALS   
ATTCGTTTCAAGCTGT-309MCX   191114_PN_309_snRNA-F1   309    BA4   F        PN   
...                                          ...   ...    ...  ..       ...  

In [8]:
import torch
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("Device:", torch.cuda.get_device_name(0))


CUDA available: True
Device: NVIDIA GeForce RTX 3050 Ti Laptop GPU


##### Tokenise the data

In [9]:
# Cell 4 — Tokenize with Geneformer V2

TOKENIZED_DIR = DATA / "tokenized"
TOKENIZED_DIR.mkdir(parents=True, exist_ok=True)

print("Tokenizing with Geneformer V2…")
print("Using data directory:", PREP_DIR)

tk = TranscriptomeTokenizer(
    custom_attr_name_dict=custom_attr_name_dict,
    nproc=NPROC,
    model_version=MODEL_VER
)

tk.tokenize_data(
    data_directory=str(PREP_DIR),   # folder containing only the prepped .h5ad
    output_directory=str(TOKENIZED_DIR),
    output_prefix="ALS",
    file_format="h5ad"
)

DATASET = TOKENIZED_DIR / "ALS.dataset"
print("Done. Dataset exists:", DATASET.exists(), "at", DATASET)


Tokenizing with Geneformer V2…
Using data directory: C:\Users\ratne\Downloads\Helical_Challenge\data\prepped
Tokenizing C:\Users\ratne\Downloads\Helical_Challenge\data\prepped\ALS_snRNA_raw_prepped.h5ad


  for i in adata.var["ensembl_id_collapsed"][coding_miRNA_loc]
  coding_miRNA_ids = adata.var["ensembl_id_collapsed"][coding_miRNA_loc]


Creating dataset.
Done. Dataset exists: True at C:\Users\ratne\Downloads\Helical_Challenge\data\tokenized\ALS.dataset


##### Helpers

In [10]:
# Cell 5 — In-silico perturbation helpers (knock-down / knock-up)

import gc

# ISP output directory
ISP = PROJECT / "results" / "isp"
ISP.mkdir(parents=True, exist_ok=True)

# genes present in the (prepped) matrix
genes_in_matrix = set(adata.var["ensembl_id"].astype(str).tolist())

def verify_genes_present(ensg_list):
    """Split requested Ensembl IDs into present vs missing."""
    ensg_list = [str(x) for x in ensg_list]
    present = [g for g in ensg_list if g in genes_in_matrix]
    missing = [g for g in ensg_list if g not in genes_in_matrix]
    return present, missing

def run_isp(
    genes_to_perturb,
    perturb_type,                 # "delete" (KD) or "overexpress" (KU)
    out_prefix,
    model_directory=MODEL_DIR,
    dataset_file=DATASET,
    emb_mode="cls",
    emb_layer=-1,
    forward_batch_size=FWD_BATCH,
    nproc=NPROC,
    model_version=MODEL_VER,
    combos=0,
    anchor_gene=None
):
    """Run Geneformer InSilicoPerturber for a list of Ensembl IDs."""
    present, missing = verify_genes_present(genes_to_perturb)
    if missing:
        print(f"[WARN] {len(missing)} gene(s) not in matrix (skipped):", missing[:10], "...")
    if not present:
        print("[SKIP] no valid genes to perturb in this call.")
        return

    isp = InSilicoPerturber(
        perturb_type=perturb_type,           # "delete" / "overexpress"
        genes_to_perturb=present,            # list of Ensembl IDs
        combos=combos,
        anchor_gene=anchor_gene,
        model_type="Pretrained",
        emb_mode=emb_mode,
        emb_layer=emb_layer,
        forward_batch_size=forward_batch_size,
        nproc=nproc,
        model_version=model_version
    )
    isp.perturb_data(
        model_directory=str(model_directory),
        input_data_file=str(dataset_file),
        output_directory=str(ISP),
        output_prefix=out_prefix
    )
    print(f"[ISP] {perturb_type}: {len(present)} gene(s) -> {ISP} (prefix={out_prefix})")

def chunked(lst, size):
    for i in range(0, len(lst), size):
        yield lst[i:i+size]

def run_isp_batched(genes, perturb_type, prefix, batch_size=25):
    """Scale to many genes by chunking."""
    for i, gbatch in enumerate(chunked(genes, batch_size), 1):
        run_isp(gbatch, perturb_type, f"{prefix}_b{i:03d}")
        gc.collect()
    print(f"[ISP] completed all batches for {prefix}.")


In [11]:
# Cell 6 — Helper to load ISP outputs and summarize per gene

def load_isp_batches(prefix):
    """Read all batch .pkl files starting with prefix into a DataFrame."""
    rows = []
    for f in sorted(ISP.glob(f"{prefix}*batch*.pkl")):
        with open(f, "rb") as fh:
            b = pickle.load(fh)
        n = len(b.get("cell_ids", []))
        gene_ids = b.get("gene_id", b.get("gene_ids", ["?"] * n))
        cos = b.get("cosine_shift", [np.nan] * n)
        for i in range(n):
            rows.append({
                "batch_file": f.name,
                "cell_id": b["cell_ids"][i],
                "gene_id": gene_ids[i],
                "cosine_shift": float(cos[i]) if not isinstance(cos, float) else float(cos),
            })
    return pd.DataFrame(rows)

def summarize_isp(prefix, topn=15):
    df = load_isp_batches(prefix)
    if df.empty:
        print("No ISP batches found for", prefix)
        return df, None
    summary = (df.groupby("gene_id")["cosine_shift"]
                 .median()
                 .sort_values(ascending=False)
                 .rename("median_cosine_shift")
                 .reset_index())
    display(summary.head(topn))
    return df, summary


##### Quick sample test: single-gene KD + KU

In [12]:
# Cell 7 — Smoke test: perturb a single gene

# Pick an Ensembl ID you know is present in your data:
print(adata.var["ensembl_id"].head())

TEST_GENE = adata.var["ensembl_id"].iloc[0]   # quick hack: just take the first one
print("Using test gene:", TEST_GENE)

# Knock-down (delete)
run_isp([TEST_GENE], "delete", "SMOKE_KD")

# Knock-up (overexpress)
run_isp([TEST_GENE], "overexpress", "SMOKE_KU")

# Summaries (if any batches were written)
_, kd_summary = summarize_isp("SMOKE_KD")
_, ku_summary = summarize_isp("SMOKE_KU")


Gene
TSPAN6    ENSG00000000003
TNMD      ENSG00000000005
DPM1      ENSG00000000419
SCYL3     ENSG00000000457
FGR       ENSG00000000938
Name: ensembl_id, dtype: object
Using test gene: ENSG00000000003


100%|██████████| 1/1 [01:56<00:00, 116.38s/it]


[ISP] delete: 1 gene(s) -> C:\Users\ratne\Downloads\Helical_Challenge\results\isp (prefix=SMOKE_KD)


100%|██████████| 63/63 [1:38:01<00:00, 93.36s/it]   

[ISP] overexpress: 1 gene(s) -> C:\Users\ratne\Downloads\Helical_Challenge\results\isp (prefix=SMOKE_KU)
No ISP batches found for SMOKE_KD
No ISP batches found for SMOKE_KU





In [None]:
import pickle
import numpy as np
from pathlib import Path

print("ISP directory:", ISP)
for f in sorted(ISP.glob("*.pickle")):
    print(" -", f.name)

def load_cell_embs(path):
    """Load Geneformer ISP cell embeddings from a cell_embs_dict file."""
    with open(path, "rb") as fh:
        d = pickle.load(fh)

    print(f"\n{Path(path).name}:")
    print("  raw keys:", list(d.keys()))

    # Find the (layer_idx, 'cell_emb') key
    key = None
    for k in d.keys():
        if isinstance(k, tuple) and "cell_emb" in k[1]:
            key = k
            break

    if key is None:
        raise ValueError("Could not find a 'cell_emb' key in this dict.")

    raw = d[key]           # list/array of length N_cells, each element is a vector
    print("  number of cells:", len(raw))
    print("  type of first element:", type(raw[0]))

    # Stack into a 2D array: (n_cells, emb_dim)
    embs = np.stack(raw, axis=0)
    print("  embedding array shape:", embs.shape)

    return embs

# Load both KD and KU results
kd_path = next(ISP.glob("*SMOKE_KD*.pickle"))
ku_path = next(ISP.glob("*SMOKE_KU*.pickle"))

kd_embs = load_cell_embs(kd_path)
ku_embs = load_cell_embs(ku_path)


ISP directory: C:\Users\ratne\Downloads\Helical_Challenge\results\isp
 - in_silico_delete_SMOKE_KD_cell_embs_dict_[4]_raw.pickle
 - in_silico_overexpress_SMOKE_KU_cell_embs_dict_[4]_raw.pickle

in_silico_delete_SMOKE_KD_cell_embs_dict_[4]_raw.pickle:
  raw keys: [(4, 'cell_emb')]
  number of cells: 12
  type of first element: <class 'float'>
  embedding array shape: (12,)

in_silico_overexpress_SMOKE_KU_cell_embs_dict_[4]_raw.pickle:
  raw keys: [(4, 'cell_emb')]
  number of cells: 1000
  type of first element: <class 'float'>
  embedding array shape: (1000,)
