# Task 3 (ISP): Interpret Cosine Shifts from Geneformer InSilicoPerturber

Aggregates cosine-similarity outputs written by ISP and ranks perturbations.



In [None]:
# Locate ISP outputs and list files
from pathlib import Path
import sys, os, pickle
import numpy as np
import pandas as pd
# Resolve the data layout relative to the parent of the current working directory.
repo_root = Path.cwd().parent
DATA_DIR = repo_root.joinpath('als-perturb-geneformer', 'als-perturb-geneformer', 'data')
ISP_DIR = DATA_DIR.joinpath('isp')
print('ISP dir:', ISP_DIR)

# Preview at most five entries per condition; globbing a missing directory
# simply yields an empty preview.
_als_preview = sorted(ISP_DIR.joinpath('als_down').glob('*'))[:5]
_healthy_preview = sorted(ISP_DIR.joinpath('healthy_up').glob('*'))[:5]
print('ALS_down files:', _als_preview, '...')
print('Healthy_up files:', _healthy_preview, '...')



In [None]:
# Aggregate cosine similarity dictionaries written by ISP
import re
from collections import defaultdict

def read_pickles(dir_path: Path):
    """Load every ``*.pkl`` file found directly inside *dir_path*.

    Files are visited in sorted path order. Loading is best-effort: a file
    that cannot be opened or unpickled is skipped, but a ``RuntimeWarning``
    naming the file and the error is emitted so that partial results are
    never produced silently.

    NOTE: ``pickle.load`` can execute arbitrary code embedded in the file.
    Only point this at output directories you produced yourself or trust.

    Args:
        dir_path: Directory to scan (non-recursive) for ``*.pkl`` files.
            A directory with no matches yields an empty list.

    Returns:
        A list of the unpickled objects, in sorted file-path order.
    """
    import warnings  # local import keeps this cell independent of the file's import block

    vals = []
    for f in sorted(dir_path.glob('*.pkl')):
        try:
            with open(f, 'rb') as fh:
                obj = pickle.load(fh)
        except Exception as exc:
            # Unpickling can raise almost anything (EOFError, UnpicklingError,
            # ImportError, ...), so the catch stays broad, but the skip is reported.
            warnings.warn(
                f'Skipping unreadable pickle {f}: {exc!r}',
                RuntimeWarning,
                stacklevel=2,
            )
            continue
        vals.append(obj)
    return vals

# Heuristic: for every key whose second element is 'cell_emb', take the mean of its
# values within each file, then average those per-file means for each perturbation.
# ISP writes dicts keyed like (gene(s), 'cell_emb') -> list[float]

def aggregate_cell_emb_cosines(pkls):
    """Average 'cell_emb' cosine values per perturbation across result dicts.

    For each mapping in *pkls*, only entries whose key is a 2-tuple with
    ``'cell_emb'`` as its second element are used. The values of such an
    entry are converted to a float array and reduced to their mean; empty
    value lists are ignored. The per-mapping means collected for the same
    first key element are then averaged again.

    NOTE(review): the final figure is an unweighted mean of per-mapping
    means, so mappings holding few values count as much as mappings holding
    many. Confirm this is intended if the inputs differ in size.

    Args:
        pkls: Iterable of mappings (objects exposing ``.items()``).

    Returns:
        A ``pandas.DataFrame`` with columns ``'perturbed'`` (``str()`` of the
        key's first element) and ``'mean_cosine'`` (float). Both columns are
        present even when no entry matched, so callers can sort or select by
        column name without a ``KeyError``.
    """
    columns = ['perturbed', 'mean_cosine']
    agg = defaultdict(list)
    for d in pkls:
        for k, v in d.items():
            if not (isinstance(k, tuple) and len(k) == 2 and k[1] == 'cell_emb'):
                continue
            arr = np.array(v, dtype=float)
            if arr.size > 0:
                agg[k[0]].append(arr.mean())
    rows = [
        {'perturbed': str(k), 'mean_cosine': float(np.mean(means))}
        for k, means in agg.items()
    ]
    # Passing the column names explicitly keeps the schema stable for empty input.
    return pd.DataFrame(rows, columns=columns)

# Load the raw ISP dictionaries for both conditions.
als_pkls = read_pickles(ISP_DIR / 'als_down')
h_pkls = read_pickles(ISP_DIR / 'healthy_up')

# Collapse to one mean cosine per perturbation, ranked ascending: a lower
# cosine means the embedding moved further from the original. The value
# carries magnitude only, not the direction of the shift.
als_df = aggregate_cell_emb_cosines(als_pkls).sort_values('mean_cosine')
h_df = aggregate_cell_emb_cosines(h_pkls).sort_values('mean_cosine')

# Show the ten strongest shifts for each condition.
print('ALS_down (lower cosine = bigger change):')
display(als_df.head(10))
print('\nHealthy_up (lower cosine = bigger change):')
display(h_df.head(10))

