In [3]:
import scanpy as sc
import functools
import os
import sys
import traceback
from typing import Dict, Literal, Optional, Tuple
import cfp
import scanpy as sc
import numpy as np
import functools
from ott.solvers import utils as solver_utils
import optax
from omegaconf import OmegaConf
from typing import NamedTuple, Any
import hydra
import wandb
import anndata as ad
import pandas as pd
import os
from cfp.training import ComputationCallback
from cfp.preprocessing import transfer_labels, compute_wknn
from cfp.training import ComputationCallback
from numpy.typing import ArrayLike
from cfp.metrics import compute_r_squared, compute_e_distance


In [2]:
# Location of the preprocessed zebrafish AnnData file. Set the ZEBRAFISH_H5AD
# environment variable to run this notebook outside the original cluster layout.
DATA_PATH = os.environ.get(
    "ZEBRAFISH_H5AD",
    "/lustre/groups/ml01/workspace/ot_perturbation/data/zebrafish_new/zebrafish_processed.h5ad",
)
adata = sc.read_h5ad(DATA_PATH)

In [20]:
def _transferred_fractions(adata_ref, adata_query, n_neighbors: int, cell_type_col: str, rep_key: str) -> pd.Series:
    """Transfer ``cell_type_col`` labels from ``adata_ref`` onto ``adata_query`` and return their relative frequencies.

    Side effect: ``transfer_labels`` writes ``f"{cell_type_col}_transfer"`` into ``adata_query.obs``.
    """
    compute_wknn(
        ref_adata=adata_ref,
        query_adata=adata_query,
        n_neighbors=n_neighbors,
        ref_rep_key=rep_key,
        query_rep_key=rep_key,
    )
    transfer_labels(query_adata=adata_query, ref_adata=adata_ref, label_key=cell_type_col)
    return adata_query.obs[f"{cell_type_col}_transfer"].value_counts(normalize=True)


def _subset_fraction_errors(frac_true: pd.Series, frac_pred: pd.Series, shared: pd.Index, subsets: dict) -> dict:
    """Cell-type fraction error after renormalising fractions within each named subset of cell types.

    Subsets with no cell type in ``shared`` are skipped (nothing to compare).
    """
    errors = {}
    for name, subset in subsets.items():
        cell_types = shared.intersection(list(subset))
        if len(cell_types) == 0:
            continue
        true_sub = frac_true.loc[cell_types]
        pred_sub = frac_pred.loc[cell_types]
        errors[name] = float((true_sub / true_sub.sum() - pred_sub / pred_sub.sum()).abs().sum())
    return errors


def _per_cell_type_metrics(adata_true, adata_pred, cell_type_col: str, rep_key: str, min_cells: int):
    """Per-cell-type ``compute_r_squared`` / ``compute_e_distance`` between true and predicted embeddings.

    Only cell types with strictly more than ``min_cells`` true cells (by their annotated
    label) are considered. Predicted cells are grouped by their *transferred* label; a
    cell type with no predicted cells is skipped and does not count as covered.

    Returns ``(r_sq, e_distance, fraction_covered)``. ``fraction_covered`` is NaN when
    no cell type is eligible.
    """
    true_labels = adata_true.obs[cell_type_col]
    pred_labels = adata_pred.obs[f"{cell_type_col}_transfer"]
    counts = true_labels.value_counts()
    eligible = list(counts[counts > min_cells].index)

    r_sq, e_distance = {}, {}
    for cell_type in eligible:
        dist_true = adata_true[true_labels == cell_type].obsm[rep_key]
        dist_pred = adata_pred[pred_labels == cell_type].obsm[rep_key]
        if len(dist_pred) == 0:
            continue
        r_sq[f"r_squared_{cell_type}"] = compute_r_squared(dist_true, dist_pred)
        e_distance[f"e_distance_{cell_type}"] = compute_e_distance(dist_true, dist_pred)

    fraction_covered = len(r_sq) / len(eligible) if eligible else float("nan")
    return r_sq, e_distance, fraction_covered


def _nan_stat(stat, values):
    """Apply ``stat`` to ``values``; NaN (without numpy's empty-slice warning) when ``values`` is empty."""
    values = list(values)
    return stat(values) if values else np.nan


def compute_metrics(
    adata_ref: ad.AnnData,
    adata_pred: ad.AnnData,
    adata_ood_true: ad.AnnData,
    cell_types_subsets: Optional[dict[str, list]] = None,
    n_neighbors: int = 1,
    cell_type_col: str = "cell_type_broad",
    min_cells_for_dist_metrics: int = 50,
    rep_key: str = "X_aligned",
) -> dict:
    """Score predicted cells against the true cells of a held-out (OOD) condition.

    Labels in ``adata_ref.obs[cell_type_col]`` are transferred onto both
    ``adata_ood_true`` and ``adata_pred`` with ``compute_wknn`` + ``transfer_labels``
    in the ``obsm[rep_key]`` embedding. This adds an ``f"{cell_type_col}_transfer"``
    column to ``.obs`` of both query objects, so pass copies rather than AnnData views.

    Parameters
    ----------
    adata_ref
        Reference cells providing the labels to transfer.
    adata_pred
        Predicted cells; grouped by their transferred label for the distance metrics.
    adata_ood_true
        Ground-truth cells; grouped by their annotated ``cell_type_col`` label for the
        distance metrics.
    cell_types_subsets
        Optional ``{name: [cell types]}``. For each entry the cell-type fraction error is
        recomputed after renormalising within that subset and logged under ``name``.
    n_neighbors
        Neighbours passed to ``compute_wknn``.
    cell_type_col
        ``.obs`` column holding cell-type labels.
    min_cells_for_dist_metrics
        A cell type enters the per-cell-type metrics only if it has strictly more than
        this many true cells.
    rep_key
        ``.obsm`` key of the embedding used for the kNN graph and the metrics.

    Returns
    -------
    dict
        ``fraction_cell_types_covered``, ``cell_type_fraction_error`` (summed absolute
        difference of transferred-label fractions over cell types present in both),
        mean/median of the per-cell-type R² and E-distance, the subset errors, and one
        ``r_squared_<ct>`` / ``e_distance_<ct>`` entry per covered cell type.
        Statistics over an empty set are NaN.
    """
    subsets = {} if cell_types_subsets is None else cell_types_subsets

    frac_true = _transferred_fractions(adata_ref, adata_ood_true, n_neighbors, cell_type_col, rep_key)
    frac_pred = _transferred_fractions(adata_ref, adata_pred, n_neighbors, cell_type_col, rep_key)
    # Only cell types that appear in both transferred label sets are compared.
    shared = frac_true.index.intersection(frac_pred.index)
    cell_type_fraction_error = float((frac_true.loc[shared] - frac_pred.loc[shared]).abs().sum())

    subset_errors = _subset_fraction_errors(frac_true, frac_pred, shared, subsets)
    r_sq, e_distance, fraction_covered = _per_cell_type_metrics(
        adata_ood_true, adata_pred, cell_type_col, rep_key, min_cells_for_dist_metrics
    )

    # Insertion order matters: it becomes the row order of the exported CSV.
    dict_to_log = {
        "fraction_cell_types_covered": fraction_covered,
        "cell_type_fraction_error": cell_type_fraction_error,
        "mean_r_sq_per_cell_type": _nan_stat(np.mean, r_sq.values()),
        "mean_e_distance_per_cell_type": _nan_stat(np.mean, e_distance.values()),
        "median_r_sq_per_cell_type": _nan_stat(np.median, r_sq.values()),
        "median_e_distance_per_cell_type": _nan_stat(np.median, e_distance.values()),
    }
    dict_to_log.update(subset_errors)
    dict_to_log.update(r_sq)
    dict_to_log.update(e_distance)
    return dict_to_log
    

In [21]:
# Held-out conditions to evaluate; the token after the last "_" is parsed as the timepoint.
ood_conds = ['epha4a_control_36']
# Optional {name: [cell types]} groups for renormalised fraction errors (none by default).
cell_types_subsets = {}
ood_cond_results = {}

for ood_cond in ood_conds:
    # .copy(): compute_metrics writes transferred labels into .obs, and writing into an
    # AnnData view triggers an implicit "initialize view as actual" copy plus a warning.
    adata_ood_true = adata[adata.obs["condition"] == ood_cond].copy()
    tp = int(ood_cond.split("_")[-1])
    # Baseline prediction: the control cells from the same timepoint.
    adata_ood_pred = adata[(adata.obs["is_control"]) & (adata.obs["timepoint"] == tp)].copy()
    assert adata_ood_true.n_obs > 0, f"no cells found for condition {ood_cond!r}"
    assert adata_ood_pred.n_obs > 0, f"no control cells found at timepoint {tp}"
    ood_cond_results[ood_cond] = compute_metrics(
        adata_ref=adata,
        adata_pred=adata_ood_pred,
        adata_ood_true=adata_ood_true,
        cell_types_subsets=cell_types_subsets,
    )

  query_adata.obs[f"{label_key}_transfer"] = scores.idxmax(1)
  query_adata.obs[f"{label_key}_transfer"] = scores.idxmax(1)


In [34]:

out_dir = "./"
os.makedirs(out_dir, exist_ok=True)
# Iterate over the results instead of relying on `ood_cond` leaking out of the
# evaluation loop, which would export only the last evaluated condition.
for cond_name, metrics in ood_cond_results.items():
    metrics_df = pd.DataFrame.from_dict(metrics, columns=[cond_name], orient="index")
    metrics_df.to_csv(os.path.join(out_dir, f"{cond_name}_identity.csv"))
