# Task 2 — Apply perturbations to ALS genes + embed in latent space

#### Initialisation

In [2]:
# imports
import os, gc, pickle
from pathlib import Path

import numpy as np
import pandas as pd
from datasets import load_from_disk

from geneformer import InSilicoPerturber, EmbExtractor

# --- project root must match Notebook 1 ---
PROJECT = Path(r"C:\Users\ratne\Downloads\Helical_Challenge")

DATA           = PROJECT / "data"
PREP_DIR       = DATA / "prepped"
TOKENIZED_DIR  = DATA / "tokenized"
DATASET        = TOKENIZED_DIR / "ALS.dataset"

RESULTS_DIR    = PROJECT / "results"
ISP_DIR        = RESULTS_DIR / "isp_als"        # separate from smoke-test ISP
EMB_DIR        = RESULTS_DIR / "embeddings"     # where EmbExtractor will write

MODEL_DIR      = str(PROJECT / "Geneformer" / "Geneformer-V2-104M")
MODEL_VER      = "V2"
NPROC          = 4
FWD_BATCH      = 16  # adjust if you have GPU + memory

for p in [ISP_DIR, EMB_DIR]:
    p.mkdir(parents=True, exist_ok=True)

print("DATASET exists:", DATASET.exists(), "->", DATASET)
print("MODEL_DIR:", MODEL_DIR)
print("ISP_DIR:", ISP_DIR)
print("EMB_DIR:", EMB_DIR)


  from .autonotebook import tqdm as notebook_tqdm


DATASET exists: True -> C:\Users\ratne\Downloads\Helical_Challenge\data\tokenized\ALS.dataset
MODEL_DIR: C:\Users\ratne\Downloads\Helical_Challenge\Geneformer\Geneformer-V2-104M
ISP_DIR: C:\Users\ratne\Downloads\Helical_Challenge\results\isp_als
EMB_DIR: C:\Users\ratne\Downloads\Helical_Challenge\results\embeddings


#### Load AnnData metadata + genes present in dataset

In [2]:
import scanpy as sc

PREP_H5AD = PREP_DIR / "ALS_snRNA_raw_prepped.h5ad"  # same name as Notebook 1
assert PREP_H5AD.exists(), f"Prepped AnnData not found at {PREP_H5AD}"

adata = sc.read_h5ad(PREP_H5AD)
print(adata)

# Make sure ensembl_id exists
assert "ensembl_id" in adata.var.columns, "No 'ensembl_id' in adata.var; check Notebook 1."

genes_in_matrix = set(adata.var["ensembl_id"].astype(str).tolist())
print("Number of unique Ensembl IDs in matrix:", len(genes_in_matrix))

# Quick peek at obs metadata columns
print("obs columns:", list(adata.obs.columns))


AnnData object with n_obs × n_vars = 1000 × 22831
    obs: 'Sample_ID', 'Donor', 'Region', 'Sex', 'Condition', 'Group', 'C9_pos', 'CellClass', 'CellType', 'SubType', 'full_label', 'DGE_Group', 'Bakken_M1', 'data_merge_id', 'data_sample_id', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'Cellstates_LVL1', 'Cellstates_LVL2', 'Cellstates_LVL3', 'total_counts_mt', 'log1p_total_counts_mt', 'pct_counts_mt', 'n_genes', 'split', 'n_counts', 'filter_pass'
    var: 'Biotype', 'Chromosome', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts', 'ENSID', 'mt', 'n_cells', 'biotype', 'ensembl_id'
Number of unique Ensembl IDs in matrix: 22831
obs columns: ['Sample_ID', 'Donor', 'Region', 'Sex', 'Condition', 'Group', 'C9_pos', 'CellClass', 'CellType', 'SubType', 'full_label', 'DGE_Group', 'Bakken_M1', 'data_merge_id', 'data_sample_id', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts',

In [3]:
# Task 2: ALS disease genes (Gene symbol -> Ensembl ID, human)
# These are well-established ALS genes.
ALS_GENES_ENSEMBL = [
    "ENSG00000142168",  # SOD1 
    "ENSG00000120948",  # TARDBP (TDP-43) 
    "ENSG00000089280",  # FUS 
    "ENSG00000147894",  # C9orf72 
    "ENSG00000123240",  # OPTN 
    "ENSG00000183735",  # TBK1 
]

# sanity: allow empty first, but warn (now it shouldn't be empty)
if not ALS_GENES_ENSEMBL:
    print("⚠️ ALS_GENES_ENSEMBL is empty. Fill this with Ensembl IDs for ALS genes you want to perturb.")

def verify_genes_present(ensg_list):
    """Split requested Ensembl IDs into present vs missing."""
    ensg_list = [str(x) for x in ensg_list]
    present = [g for g in ensg_list if g in genes_in_matrix]
    missing = [g for g in ensg_list if g not in genes_in_matrix]
    return present, missing

present, missing = verify_genes_present(ALS_GENES_ENSEMBL)
print("ALS genes present in matrix:", present)
print("ALS genes missing from matrix:", missing)


ALS genes present in matrix: ['ENSG00000142168', 'ENSG00000120948', 'ENSG00000089280', 'ENSG00000147894', 'ENSG00000123240', 'ENSG00000183735']
ALS genes missing from matrix: []


#### ISP helpers for ALS genes (Same idea as notebook 1)

In [5]:
# ISP helpers for ALS genes 
def run_isp(
    genes_to_perturb,
    perturb_type,                 # "delete" (KD) or "overexpress" (KU)
    out_prefix,
    model_directory=MODEL_DIR,
    dataset_file=DATASET,
    emb_mode="cls",
    emb_layer=-1,
    forward_batch_size=FWD_BATCH,
    nproc=NPROC,
    model_version=MODEL_VER,
    combos=0,
    anchor_gene=None
):
    """Run Geneformer InSilicoPerturber for a list of Ensembl IDs."""
    present, missing = verify_genes_present(genes_to_perturb)
    if missing:
        print(f"[WARN] {len(missing)} gene(s) not in matrix (skipped):", missing[:10], "...")
    if not present:
        print("[SKIP] no valid genes to perturb in this call.")
        return

    isp = InSilicoPerturber(
        perturb_type=perturb_type,           # "delete" / "overexpress"
        genes_to_perturb=present,
        combos=combos,
        anchor_gene=anchor_gene,
        model_type="Pretrained",
        emb_mode=emb_mode,
        emb_layer=emb_layer,
        forward_batch_size=forward_batch_size,
        nproc=nproc,
        model_version=model_version,
    )

    isp.perturb_data(
        model_directory=str(model_directory),
        input_data_file=str(dataset_file),
        output_directory=str(ISP_DIR),
        output_prefix=out_prefix,
    )
    print(f"[ISP] {perturb_type}: {len(present)} gene(s) -> {ISP_DIR} (prefix={out_prefix})")


def chunked(lst, size):
    for i in range(0, len(lst), size):
        yield lst[i:i+size]


def run_isp_batched(genes, perturb_type, prefix, batch_size=5):
    """Scale to many ALS genes by chunking."""
    for i, gbatch in enumerate(chunked(genes, batch_size), 1):
        run_isp(gbatch, perturb_type, f"{prefix}_b{i:03d}")
        gc.collect()
    print(f"[ISP] completed all batches for {prefix}.")


#### Apply Perturbation

In [6]:
# Apply KD (delete) to ALS disease genes
if ALS_GENES_ENSEMBL:
    run_isp_batched(ALS_GENES_ENSEMBL, perturb_type="delete",      prefix="ALS_KD")
    run_isp_batched(ALS_GENES_ENSEMBL, perturb_type="overexpress", prefix="ALS_KU")
else:
    print("⚠️ ALS_GENES_ENSEMBL is empty. Fill it with ALS Ensembl IDs before running KD/KU.")

100%|██████████| 1/1 [00:55<00:00, 55.90s/it]


[ISP] delete: 5 gene(s) -> C:\Users\ratne\Downloads\Helical_Challenge\results\isp_als (prefix=ALS_KD_b001)


100%|██████████| 23/23 [41:11<00:00, 107.44s/it]


[ISP] delete: 1 gene(s) -> C:\Users\ratne\Downloads\Helical_Challenge\results\isp_als (prefix=ALS_KD_b002)
[ISP] completed all batches for ALS_KD.


100%|██████████| 63/63 [1:46:08<00:00, 101.08s/it]  


[ISP] overexpress: 5 gene(s) -> C:\Users\ratne\Downloads\Helical_Challenge\results\isp_als (prefix=ALS_KU_b001)


100%|██████████| 63/63 [2:03:49<00:00, 117.92s/it]  


[ISP] overexpress: 1 gene(s) -> C:\Users\ratne\Downloads\Helical_Challenge\results\isp_als (prefix=ALS_KU_b002)
[ISP] completed all batches for ALS_KU.


#### Extract Embeddings

In [8]:
# Extract cell embeddings for the unperturbed ALS dataset using Geneformer V2

emb_prefix = "ALS_unperturbed"

emb_extractor = EmbExtractor(
    model_type="Pretrained",
    num_classes=1,          # <- was None; must be an int
    emb_mode="cls",         # CLS token embeddings
    emb_layer=-1,           # last layer
    forward_batch_size=FWD_BATCH,
    nproc=NPROC,
    model_version=MODEL_VER,
)

print("Running EmbExtractor on ALS.dataset…")

emb_extractor.extract_embs(
    model_directory=MODEL_DIR,
    input_data_file=str(DATASET),
    output_directory=str(EMB_DIR),
    output_prefix=emb_prefix,
)

print("Done. Check EMB_DIR for output files.")


Running EmbExtractor on ALS.dataset…


100%|██████████| 63/63 [49:27<00:00, 47.11s/it]


Done. Check EMB_DIR for output files.


#### Load unperturbed embeddings into a NumPy array + optional DataFrame

##### Method 1 - Need to fix

In [9]:
# Load unperturbed embeddings into a NumPy array + optional DataFrame

def load_emb_extractor_output(prefix, emb_dir=EMB_DIR):
    files = sorted(emb_dir.glob(f"{prefix}*cell_embs_dict*pickle"))
    if not files:
        print(f"No embedding files found for prefix '{prefix}' in {emb_dir}")
        return None

    print("Found embedding files:")
    for f in files:
        print(" -", f.name)

    # Use the first file (adjust if you know there are multiple splits)
    path = files[0]
    with open(path, "rb") as fh:
        d = pickle.load(fh)

    print("\nLoading from:", path.name)
    print("Raw keys:", list(d.keys()))

    key = None
    for k in d.keys():
        if isinstance(k, tuple) and "cell_emb" in k[1]:
            key = k
            break

    if key is None:
        raise ValueError("Could not find a 'cell_emb' key in this dict.")

    raw = d[key]  # list/array of per-cell embeddings OR cosine shifts
    print("Number of elements:", len(raw))
    print("Type of first element:", type(raw[0]))

    # If each element is a vector -> stack into (n_cells, emb_dim)
    if hasattr(raw[0], "__len__") and not isinstance(raw[0], (float, int)):
        embs = np.stack(raw, axis=0)
        print("Embedding matrix shape:", embs.shape)
        return embs
    else:
        # It’s a scalar per cell (e.g. cosine shift), not full emb vector
        arr = np.array(raw)
        print("Scalar per cell; array shape:", arr.shape)
        return arr

ALS_EMBS = load_emb_extractor_output("ALS_unperturbed")


No embedding files found for prefix 'ALS_unperturbed' in C:\Users\ratne\Downloads\Helical_Challenge\results\embeddings


##### Working alternat method

In [10]:
import pandas as pd

csv_path = EMB_DIR / "ALS_unperturbed.csv"
print("CSV exists:", csv_path.exists(), "->", csv_path)

ALS_EMBS_DF = None
if csv_path.exists():
    ALS_EMBS_DF = pd.read_csv(csv_path, index_col=0)
    print("ALS_unperturbed.csv shape:", ALS_EMBS_DF.shape)
    display(ALS_EMBS_DF.head())
else:
    print("No ALS_unperturbed.csv found in", EMB_DIR)


CSV exists: True -> C:\Users\ratne\Downloads\Helical_Challenge\results\embeddings\ALS_unperturbed.csv
ALS_unperturbed.csv shape: (1000, 768)


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,758,759,760,761,762,763,764,765,766,767
0,0.024868,-0.118322,-0.104733,-0.228727,0.196272,-0.019281,-0.123582,0.453015,-0.310267,0.132588,...,-0.528719,-0.137115,-0.114696,0.110082,-0.07881,-0.241094,0.440791,-0.215314,-0.055313,0.032509
1,0.04346,-0.142153,-0.17509,-0.25471,0.208086,-0.042378,-0.075468,0.449917,-0.334752,0.167347,...,-0.470818,-0.211254,-0.121203,0.100071,-0.06472,-0.207777,0.470039,-0.184062,-0.06832,0.024948
2,0.038716,-0.10206,-0.108188,-0.231725,0.19146,-0.006341,-0.122659,0.433794,-0.295939,0.164062,...,-0.489459,-0.151722,-0.086471,0.100692,-0.091844,-0.229847,0.489524,-0.214918,-0.053846,0.075414
3,0.034691,-0.128403,-0.140141,-0.236448,0.208852,-0.02278,-0.093658,0.4521,-0.331103,0.16072,...,-0.470175,-0.187985,-0.10419,0.077863,-0.083059,-0.215005,0.463036,-0.218762,-0.046391,0.048188
4,0.059822,-0.109191,-0.136661,-0.223937,0.191714,-0.019595,-0.089921,0.465615,-0.279892,0.133219,...,-0.439917,-0.166989,-0.095249,0.100612,-0.098669,-0.227109,0.474205,-0.201083,-0.051931,0.044647


#### Build a simple DataFrame linking cell embeddings to metadata (for Task 3)

In [11]:
# link embeddings to metadata in a DataFrame
if isinstance(ALS_EMBS, np.ndarray) and ALS_EMBS.ndim == 2:
    n_cells, emb_dim = ALS_EMBS.shape
    print("Embeddings:", n_cells, "cells,", emb_dim, "dims")

    # Make sure adata cells are in the same order as the dataset; if not, you might
    # need to align using cell_ids from the HF dataset. For now we assume same order.
    meta = adata.obs.copy()
    if meta.shape[0] != n_cells:
        print("⚠️ mismatch between adata.obs rows and embedding rows!")
    else:
        emb_df = pd.DataFrame(ALS_EMBS, index=meta.index)
        emb_df.columns = [f"emb_{i}" for i in range(emb_dim)]

        # Concatenate some key metadata
        meta_cols = ["Condition", "Group", "CellClass", "CellType", "SubType", "split"]
        meta_cols = [c for c in meta_cols if c in meta.columns]

        combined = pd.concat([meta[meta_cols], emb_df], axis=1)
        print("Combined DF shape:", combined.shape)
        display(combined.head())
else:
    print("ALS_EMBS is not a 2D embedding matrix; adjust loader or check EmbExtractor output.")


ALS_EMBS is not a 2D embedding matrix; adjust loader or check EmbExtractor output.


In [12]:
import numpy as np
import pandas as pd

# Make sure we have the embeddings DataFrame from the CSV
csv_path = EMB_DIR / "ALS_unperturbed.csv"
ALS_EMBS_DF = pd.read_csv(csv_path, index_col=0)
print("ALS_EMBS_DF shape:", ALS_EMBS_DF.shape)

# Turn it into a NumPy array
ALS_EMBS = ALS_EMBS_DF.values
n_cells, emb_dim = ALS_EMBS.shape
print("Embeddings:", n_cells, "cells,", emb_dim, "dims")

# Sanity check: should match number of cells in your AnnData
print("AnnData cells:", adata.n_obs)
if n_cells != adata.n_obs:
    print("⚠️ mismatch between embeddings rows and AnnData cells!")

# Build a DataFrame that combines metadata + embeddings
meta = adata.obs.copy()

emb_cols = [f"emb_{i}" for i in range(emb_dim)]
emb_df = pd.DataFrame(ALS_EMBS, index=meta.index, columns=emb_cols)

# pick some metadata columns you care about
meta_cols = [
    "Condition", "Group", "CellClass", "CellType", "SubType", "split"
]
meta_cols = [c for c in meta_cols if c in meta.columns]

combined = pd.concat([meta[meta_cols], emb_df], axis=1)
print("Combined DF shape:", combined.shape)
display(combined.head())


ALS_EMBS_DF shape: (1000, 768)
Embeddings: 1000 cells, 768 dims
AnnData cells: 1000
Combined DF shape: (1000, 774)


Unnamed: 0_level_0,Condition,Group,CellClass,CellType,SubType,split,emb_0,emb_1,emb_2,emb_3,...,emb_758,emb_759,emb_760,emb_761,emb_762,emb_763,emb_764,emb_765,emb_766,emb_767
Barcode,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
GTCGTTCTCTGTGCAA-118MCX,ALS,SALS,Ex,L6,TLE4_SEMA3D,train,0.024868,-0.118322,-0.104733,-0.228727,...,-0.528719,-0.137115,-0.114696,0.110082,-0.07881,-0.241094,0.440791,-0.215314,-0.055313,0.032509
GGGCTCATCTGGGAGA-126MCX,ALS,SALS,In,PV,PVALB_CEMIP,train,0.04346,-0.142153,-0.17509,-0.25471,...,-0.470818,-0.211254,-0.121203,0.100071,-0.06472,-0.207777,0.470039,-0.184062,-0.06832,0.024948
GTAACACTCGTCCATC-303MCX,PN,PN,Ex,L5_L6,THEMIS_TMEM233,train,0.038716,-0.10206,-0.108188,-0.231725,...,-0.489459,-0.151722,-0.086471,0.100692,-0.091844,-0.229847,0.489524,-0.214918,-0.053846,0.075414
AGCCAGCGTCGTTCAA-116MCX,ALS,SALS,Ex,L6,TLE4_MEGF11,train,0.034691,-0.128403,-0.140141,-0.236448,...,-0.470175,-0.187985,-0.10419,0.077863,-0.083059,-0.215005,0.463036,-0.218762,-0.046391,0.048188
ATTCGTTTCAAGCTGT-309MCX,PN,PN,Ex,L2_L3,CUX2_RASGRF2,train,0.059822,-0.109191,-0.136661,-0.223937,...,-0.439917,-0.166989,-0.095249,0.100612,-0.098669,-0.227109,0.474205,-0.201083,-0.051931,0.044647


In [13]:
#save combined DataFrame for later use
out_csv = EMB_DIR / "ALS_unperturbed_with_meta.csv"
combined.to_csv(out_csv)
print("Saved:", out_csv)


Saved: C:\Users\ratne\Downloads\Helical_Challenge\results\embeddings\ALS_unperturbed_with_meta.csv


##### Debugging

In [3]:
from pathlib import Path
import pickle

kd_file = next(ISP_DIR.glob("in_silico_delete_ALS_KD_b001_cell_embs_dict_[8090, 5072, 1825, 8907, 5307]_raw.pickle"))
with open(kd_file, "rb") as fh:
    kd_dict = pickle.load(fh)

print("KD keys:", kd_dict.keys())
for k in kd_dict.keys():
    print(k, type(kd_dict[k]), getattr(kd_dict[k], "shape", None))


StopIteration: 