In [3]:
import anndata as ad
from pathlib import Path

ew: 
but in any case de_df[["gene_col", "de_embedding"]].to_csv()

In [8]:
import numpy as np
import pandas as pd
from anndata import AnnData
from pdex import parallel_differential_expression
from tqdm import tqdm


def compute_de_labels(adata: AnnData,
                           perturb_col: str = "target_gene",
                           cell_type_col: str = "cell_type",
                           control_var: str = "non-targeting",
                           alpha: float = 0.05) -> None:
    """
    Precompute DE gene labels (+1, -1, 0) using pdex.
    Stores results in adata.uns['de_labels'].

    Parameters
    ----------
    adata : AnnData
        Input AnnData object.
    perturb_col : str
        Column in obs indicating perturbation/target gene.
    cell_type_col : str
        Column in obs indicating cell type.
    control_var : str
        Name of the control group in perturb_col.
    alpha : float
        FDR threshold for significance.
    """



    cell_types = adata.obs[cell_type_col].unique()
    target_genes = adata.obs[perturb_col].unique()

    result = {}

    # iterate over cell types
    for ct in tqdm(cell_types, desc="Computing DE labels with pdex"):
        # subset AnnData to current cell type
        adata_ct = adata[adata.obs[cell_type_col] == ct].copy()
        
        # Run pdex differential expression
        de_df = parallel_differential_expression(
            adata_ct,
            reference=control_var,
            groupby_key=perturb_col,
            metric="wilcoxon",  # default
            num_workers=20,
            batch_size=100,
        )

        # if "fdr" < 0.05 and "percent_change" > 0 set to 1, if "fdr" < 0.05 and "percent_change" < 0 set to -1, else 0
        de_df["de_embedding"] = de_df.apply(
            lambda row: 1 if row["fdr"] < alpha and row["percent_change"] > 0 else (-1 if row["fdr"] < alpha and row["percent_change"] < 0 else 0),
            axis=1
        )

        # Create a DataFrame from the DE results

        
        for target in target_genes:
            # filter for current target gene
            de_target = de_df[de_df["target"] == target]
            if not de_target.empty:
                result[f"{ct}_{target}"] = de_target["de_embedding"].values
            else:
                result[f"{ct}_{target}"] = 0

        
        #print(de_df[["gene_col", "de_embedding"]])
        

    adata.uns['de_labels'] = result
    print(f"Stored DE results in adata.uns['de_labels'] with {len(adata.uns['de_labels'])} entries.")

    

          


In [None]:
path="/buffer/ag_bsc/pmsb_workflows_2025/team4_ensemble_assembly/robin/arc-state/state/datasets/base_dataset/hepg2.h5"
adata = ad.read_h5ad(path)#we should be able to load this faster, by using the .csv instead! 

#if 'de_labels' not in adata.uns:
#    compute_de_labels(adata, perturb_col="target_gene", cell_type_col="cell_type", control_var="non-targeting", alpha=0.05)
#print(adata.uns[["gene_col", "de_embedding"]])

de_df = pd.DataFrame(adata.uns['de_labels']).T.reset_index()
de_df.columns = ['cell_type_target_gene', 'de_embedding']
print(de_df.head())

TypeError: unhashable type: 'list'

In [15]:

path_to_test="/buffer/ag_bsc/pmsb_workflows_2025/team4_ensemble_assembly/robin/arc-state/state/datasets/base_dataset/DES_test/"
path_to_compare="/buffer/ag_bsc/pmsb_workflows_2025/team4_ensemble_assembly/robin/arc-state/state/datasets/base_dataset/DES/"

# collect csv files from both directories
test_dir = Path(path_to_test)
compare_dir = Path(path_to_compare)

test_files = {p.name: p for p in test_dir.glob("*.csv")}
compare_files = {p.name: p for p in compare_dir.glob("*.csv")}

print(f"Test dir: {test_dir} -> {len(test_files)} CSV files")
print(f"Compare dir: {compare_dir} -> {len(compare_files)} CSV files")

test_names = set(test_files.keys())
compare_names = set(compare_files.keys())

def _normalize(fname: str) -> str:
    # remove leading "hepg2_" or "hepg_" if present
    if fname.startswith("hepg2_"):
        return fname[len("hepg2_"):]
    if fname.startswith("hepg_"):
        return fname[len("hepg_"):]
    return fname

# rebuild mappings keyed by normalized names (e.g. "ARPC2.csv")
_new_test_files = {}
for name, p in test_files.items():
    nk = _normalize(name)
    if nk in _new_test_files:
        print(f"Warning: duplicate normalized test name {nk}; keeping first ({_new_test_files[nk].name})")
    else:
        _new_test_files[nk] = p

_new_compare_files = {}
for name, p in compare_files.items():
    nk = _normalize(name)
    if nk in _new_compare_files:
        print(f"Warning: duplicate normalized compare name {nk}; keeping first ({_new_compare_files[nk].name})")
    else:
        _new_compare_files[nk] = p

# replace original dicts and name sets with normalized-key versions
test_files = _new_test_files
compare_files = _new_compare_files
test_names = set(test_files.keys())
compare_names = set(compare_files.keys())

print(f"After normalizing prefixes -> test files: {len(test_files)}, compare files: {len(compare_files)}")

common_names = sorted(test_names & compare_names)
only_in_test = sorted(test_names - compare_names)
only_in_compare = sorted(compare_names - test_names)

print(f"Common files: {len(common_names)}")
if only_in_test:
    print(f"Files only in test ({len(only_in_test)}): {only_in_test}")
if only_in_compare:
    print(f"Files only in compare ({len(only_in_compare)}): {only_in_compare}")

# compare contents for common files
for name in common_names:
    p_test = test_files[name]
    p_cmp = compare_files[name]

    # fast byte-wise check first
    try:
        if p_test.read_bytes() == p_cmp.read_bytes():
            print(f"{name}: IDENTICAL (byte-wise)")
            continue
    except Exception as e:
        print(f"{name}: error reading bytes: {e}")

    # fall back to reading as DataFrame and compare
    try:
        df_test = pd.read_csv(p_test)
        df_cmp = pd.read_csv(p_cmp)
    except Exception as e:
        print(f"{name}: ERROR reading CSVs as DataFrame: {e}")
        continue

    if df_test.equals(df_cmp):
        print(f"{name}: IDENTICAL (DataFrame.equals)")
        continue

    # summarize differences
    diffs = []
    if df_test.shape != df_cmp.shape:
        diffs.append(f"shape differs: test {df_test.shape} vs compare {df_cmp.shape}")
    if list(df_test.columns) != list(df_cmp.columns):
        diffs.append("columns differ")

    # compute a simple cell-wise diff for common columns and rows
    common_cols = [c for c in df_test.columns if c in df_cmp.columns]
    if common_cols:
        a = df_test[common_cols].reset_index(drop=True).fillna("__NA__").astype(str)
        b = df_cmp[common_cols].reset_index(drop=True).fillna("__NA__").astype(str)
        rows = min(len(a), len(b))
        if rows > 0:
            comp = (a.iloc[:rows] != b.iloc[:rows])
            n_diff_cells = int(comp.values.sum())
            diffs.append(f"{n_diff_cells} differing cells in first {rows} rows across {len(common_cols)} common columns")
        else:
            diffs.append("no overlapping rows to compare")
    else:
        diffs.append("no common columns to compare")

    print(f"{name}: DIFFER ({'; '.join(diffs)})")




Test dir: /buffer/ag_bsc/pmsb_workflows_2025/team4_ensemble_assembly/robin/arc-state/state/datasets/base_dataset/DES_test -> 70 CSV files
Compare dir: /buffer/ag_bsc/pmsb_workflows_2025/team4_ensemble_assembly/robin/arc-state/state/datasets/base_dataset/DES -> 68 CSV files
After normalizing prefixes -> test files: 70, compare files: 68
Common files: 68
Files only in test (2): ['non-targeting.csv', '{perties}.csv']
ARPC2.csv: DIFFER (shape differs: test (5, 12) vs compare (0, 6311); columns differ; no common columns to compare)
ATP6V0B.csv: DIFFER (shape differs: test (261, 12) vs compare (0, 5417); columns differ; no common columns to compare)
ATP6V0C.csv: DIFFER (shape differs: test (448, 12) vs compare (0, 5638); columns differ; no common columns to compare)
C1QBP.csv: DIFFER (shape differs: test (18, 12) vs compare (0, 4779); columns differ; no common columns to compare)
CAST.csv: DIFFER (shape differs: test (14, 12) vs compare (0, 4779); columns differ; no common columns to compare

In [13]:
import numpy as np
import pandas as pd
from anndata import AnnData
from tqdm import tqdm

def compute_expression_ranks(
    adata: AnnData,
    groupby_key: str = "target_gene",
    celltype_key: str = "cell_type",
) -> None:
    """
    Precompute gene expression ranks for each (cell_type, target_gene).
    Stores results in adata.uns['expr_ranks'].

    Each entry is a vector of length n_vars where
        rank[i] = rank of gene i's mean expression within that subset.
    """
    results = {}

    cell_types = adata.obs[celltype_key].unique()
    target_genes = adata.obs[groupby_key].unique()

    # single progress bar over all combinations
    combos = [(ct, tg) for ct in cell_types for tg in target_genes]
    for ct, tg in tqdm(combos, desc="Computing expression ranks", total=len(combos)):
        # subset to matching cells
        mask = (adata.obs[celltype_key] == ct) & (adata.obs[groupby_key] == tg)
        if mask.sum() == 0:
            continue

        subset = adata[mask]

        # mean expression across cells for each gene
        mean_expr = np.asarray(subset.X.mean(axis=0)).ravel()

        # compute ranks (highest expression = rank 1)
        ranks = pd.Series(mean_expr).rank(method="first", ascending=False).to_numpy()

        results[f"{ct}_{tg}"] = ranks

    adata.uns["rank_embedding"] = results
    print(f"Stored rank vectors in adata.uns['rank_embedding'] with {len(results)} entries.")


In [10]:
data_dir = Path("/raid/kreid/v_cell/competition_support_set")

In [11]:
anndata_paths = [file for file in data_dir.glob("*.h5")]

In [12]:
for path in anndata_paths:
    print(f"Processing {path.name}")
    adata = ad.read_h5ad(path)

    # Compute DE labels
    compute_de_labels(adata, perturb_col="target_gene", cell_type_col="cell_type", control_var="non-targeting", alpha=0.05)

    # Compute expression ranks
    compute_expression_ranks(adata, groupby_key="target_gene", celltype_key="cell_type")

    # Save updated AnnData
    adata.write_h5ad(path)  # Overwrite original file or save to a new path if needed

Processing hepg2.h5


Computing DE labels with pdex:   0%|          | 0/1 [00:00<?, ?it/s]INFO:pdex._single_cell:Precomputing masks for each target gene
Identifying target masks: 100%|██████████| 69/69 [00:00<00:00, 50044.44it/s]
INFO:pdex._single_cell:Precomputing variable indices for each feature
Identifying variable indices: 100%|██████████| 18080/18080 [00:00<00:00, 8518649.33it/s]
INFO:pdex._single_cell:Creating shared memory memory matrix for parallel computing
INFO:pdex._single_cell:Creating generator of all combinations: N=1247520
INFO:pdex._single_cell:Creating generator of all batches: N=12476
INFO:pdex._single_cell:Initializing parallel processing pool
INFO:pdex._single_cell:Processing batches
Processing batches: 100%|██████████| 12476/12476 [00:39<00:00, 318.27it/s]
INFO:pdex._single_cell:Flattening results
INFO:pdex._single_cell:Closing shared memory pool
Computing DE labels with pdex: 100%|██████████| 1/1 [00:44<00:00, 44.71s/it]


Stored DE results in adata.uns['de_labels'] with 69 entries.


Computing expression ranks: 100%|██████████| 69/69 [00:00<00:00, 377.24it/s]


Stored rank vectors in adata.uns['rank_embedding'] with 69 entries.
Processing k562.h5


Computing DE labels with pdex:   0%|          | 0/1 [00:00<?, ?it/s]INFO:pdex._single_cell:Precomputing masks for each target gene
Identifying target masks: 100%|██████████| 54/54 [00:00<00:00, 59028.52it/s]
INFO:pdex._single_cell:Precomputing variable indices for each feature
Identifying variable indices: 100%|██████████| 18080/18080 [00:00<00:00, 7914111.49it/s]
INFO:pdex._single_cell:Creating shared memory memory matrix for parallel computing
INFO:pdex._single_cell:Creating generator of all combinations: N=976320
INFO:pdex._single_cell:Creating generator of all batches: N=9764
INFO:pdex._single_cell:Initializing parallel processing pool
INFO:pdex._single_cell:Processing batches
Processing batches: 100%|██████████| 9764/9764 [00:53<00:00, 183.44it/s]
INFO:pdex._single_cell:Flattening results
INFO:pdex._single_cell:Closing shared memory pool
Computing DE labels with pdex: 100%|██████████| 1/1 [00:57<00:00, 57.99s/it]


Stored DE results in adata.uns['de_labels'] with 54 entries.


Computing expression ranks: 100%|██████████| 54/54 [00:00<00:00, 237.65it/s]


Stored rank vectors in adata.uns['rank_embedding'] with 54 entries.
Processing competition_train.h5


Computing DE labels with pdex:   0%|          | 0/1 [00:00<?, ?it/s]INFO:pdex._single_cell:Precomputing masks for each target gene
Identifying target masks: 100%|██████████| 151/151 [00:00<00:00, 36957.45it/s]
INFO:pdex._single_cell:Precomputing variable indices for each feature
Identifying variable indices: 100%|██████████| 18080/18080 [00:00<00:00, 6746709.64it/s]
INFO:pdex._single_cell:Creating shared memory memory matrix for parallel computing
INFO:pdex._single_cell:Creating generator of all combinations: N=2730080
INFO:pdex._single_cell:Creating generator of all batches: N=27301
INFO:pdex._single_cell:Initializing parallel processing pool
INFO:pdex._single_cell:Processing batches
Processing batches: 100%|██████████| 27301/27301 [28:05<00:00, 16.20it/s]
INFO:pdex._single_cell:Flattening results
INFO:pdex._single_cell:Closing shared memory pool
Computing DE labels with pdex: 100%|██████████| 1/1 [28:31<00:00, 1711.98s/it]


Stored DE results in adata.uns['de_labels'] with 151 entries.


Computing expression ranks: 100%|██████████| 151/151 [00:06<00:00, 25.14it/s]


Stored rank vectors in adata.uns['rank_embedding'] with 151 entries.
Processing k562_gwps.h5


Computing DE labels with pdex:   0%|          | 0/1 [00:00<?, ?it/s]INFO:pdex._single_cell:Precomputing masks for each target gene
Identifying target masks: 100%|██████████| 184/184 [00:00<00:00, 39028.62it/s]
INFO:pdex._single_cell:Precomputing variable indices for each feature
Identifying variable indices: 100%|██████████| 18080/18080 [00:00<00:00, 5793644.76it/s]
INFO:pdex._single_cell:Creating shared memory memory matrix for parallel computing
INFO:pdex._single_cell:Creating generator of all combinations: N=3326720
INFO:pdex._single_cell:Creating generator of all batches: N=33268
INFO:pdex._single_cell:Initializing parallel processing pool
INFO:pdex._single_cell:Processing batches
Processing batches: 100%|██████████| 33268/33268 [34:42<00:00, 15.97it/s]
INFO:pdex._single_cell:Flattening results
INFO:pdex._single_cell:Closing shared memory pool
Computing DE labels with pdex: 100%|██████████| 1/1 [35:06<00:00, 2106.88s/it]


Stored DE results in adata.uns['de_labels'] with 184 entries.


Computing expression ranks: 100%|██████████| 184/184 [00:01<00:00, 151.55it/s]


Stored rank vectors in adata.uns['rank_embedding'] with 184 entries.
Processing jurkat.h5


Computing DE labels with pdex:   0%|          | 0/1 [00:00<?, ?it/s]INFO:pdex._single_cell:Precomputing masks for each target gene
Identifying target masks: 100%|██████████| 69/69 [00:00<00:00, 52173.60it/s]
INFO:pdex._single_cell:Precomputing variable indices for each feature
Identifying variable indices: 100%|██████████| 18080/18080 [00:00<00:00, 5110730.31it/s]
INFO:pdex._single_cell:Creating shared memory memory matrix for parallel computing
INFO:pdex._single_cell:Creating generator of all combinations: N=1247520
INFO:pdex._single_cell:Creating generator of all batches: N=12476
INFO:pdex._single_cell:Initializing parallel processing pool
INFO:pdex._single_cell:Processing batches
Processing batches: 100%|██████████| 12476/12476 [01:18<00:00, 158.66it/s]
INFO:pdex._single_cell:Flattening results
INFO:pdex._single_cell:Closing shared memory pool
Computing DE labels with pdex: 100%|██████████| 1/1 [01:25<00:00, 85.09s/it]


Stored DE results in adata.uns['de_labels'] with 69 entries.


Computing expression ranks: 100%|██████████| 69/69 [00:00<00:00, 245.05it/s]


Stored rank vectors in adata.uns['rank_embedding'] with 69 entries.
Processing rpe1.h5


Computing DE labels with pdex:   0%|          | 0/1 [00:00<?, ?it/s]INFO:pdex._single_cell:Precomputing masks for each target gene
Identifying target masks: 100%|██████████| 69/69 [00:00<00:00, 44897.14it/s]
INFO:pdex._single_cell:Precomputing variable indices for each feature
Identifying variable indices: 100%|██████████| 18080/18080 [00:00<00:00, 4839993.38it/s]
INFO:pdex._single_cell:Creating shared memory memory matrix for parallel computing
INFO:pdex._single_cell:Creating generator of all combinations: N=1247520
INFO:pdex._single_cell:Creating generator of all batches: N=12476
INFO:pdex._single_cell:Initializing parallel processing pool
INFO:pdex._single_cell:Processing batches
Processing batches: 100%|██████████| 12476/12476 [01:18<00:00, 158.08it/s]
INFO:pdex._single_cell:Flattening results
INFO:pdex._single_cell:Closing shared memory pool
Computing DE labels with pdex: 100%|██████████| 1/1 [01:25<00:00, 85.50s/it]


Stored DE results in adata.uns['de_labels'] with 69 entries.


Computing expression ranks: 100%|██████████| 69/69 [00:00<00:00, 236.53it/s]


Stored rank vectors in adata.uns['rank_embedding'] with 69 entries.
