In [2]:
# Enable IPython's autoreload extension; mode 2 re-imports modules before each
# cell runs, so edits to the local `cfp` package are picked up without a restart.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
# Shell out to nvidia-smi to record which GPU (driver, memory, utilisation) this kernel sees.
!nvidia-smi

Sat Nov 30 12:02:38 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.03              Driver Version: 560.35.03      CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla V100-SXM3-32GB           Off |   00000000:4C:00.0 Off |                    0 |
| N/A   29C    P0             48W /  350W |       1MiB /  32768MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [None]:
import os
import ast
import jax
import numpy as np
import pandas as pd
import anndata as ad
import scanpy as sc

from tqdm.auto import tqdm

import cfp.preprocessing as cfpp
from cfp.metrics import compute_metrics, compute_mean_metrics, compute_metrics_fast

In [None]:
def get_mask(x, y, var_names=None):
    """
    Keep only the columns of ``x`` whose gene name appears in ``y``.

    Args:
        x: 2-D array-like indexed as ``x[:, mask]``; its columns must be in
            the same order as ``var_names``.
        y: Iterable of gene names to keep.
        var_names: Gene name for each column of ``x``. Defaults to
            ``adata_train.var_names`` from the notebook namespace, which is
            what the two-argument call has always used.

    Returns:
        ``x`` restricted to the columns whose gene name is in ``y``.
    """
    if var_names is None:
        # Fall back to the notebook-level training AnnData so existing
        # two-argument callers (e.g. via jax.tree_util.tree_map) keep working.
        var_names = adata_train.var_names
    # Hash-based membership avoids rescanning ``y`` once per gene.
    wanted = set(y)
    return x[:, [gene in wanted for gene in var_names]]

In [None]:
def extract_metainfo(file_path):
    """
    Extracts the configuration dictionary, predictions file path, split index,
    and wandb run name from a training log file.

    The log is expected to contain:
        - the configuration dictionary as a Python literal on its third line,
        - a line starting with "split:" followed by an integer,
        - a line containing "Saving results at:" followed by the predictions path,
        - a line containing "🚀 View run", then the run name, then "at:".

    If several lines match "Saving results at:" or "🚀 View run", the last
    match is used; for "split:" the first matching line is used.

    Args:
        file_path (str): Path to the log file.

    Returns:
        dict: A dictionary containing the extracted information:
            - 'config': The configuration dictionary parsed from the log file.
            - 'path_predictions': The path of the results (predictions) file.
            - 'split_index': The split index as an integer.
            - 'wandb_run_name': The wandb run name as a string.

    Raises:
        ValueError: If the log has fewer than three lines, the configuration
            line cannot be parsed, or the results path, wandb run name or
            split index is missing or malformed.
    """
    # The log contains non-ASCII characters (the wandb rocket emoji), so decode
    # explicitly as UTF-8 rather than relying on the locale's default encoding.
    with open(file_path, "r", encoding="utf-8") as file:
        lines = file.readlines()

    # Initialize the metainfo dictionary
    metainfo = {
        "config": None,
        "path_predictions": None,
        "split_index": None,
        "wandb_run_name": None
    }

    # The config dictionary lives on the third line; fail with the same
    # exception type as every other parsing problem if that line is absent.
    if len(lines) < 3:
        raise ValueError("Log file is too short to contain the configuration dictionary.")

    # Parse the configuration dictionary
    try:
        metainfo["config"] = ast.literal_eval(lines[2].strip())
    except (SyntaxError, ValueError) as e:
        raise ValueError("Failed to parse configuration dictionary.") from e

    # Extract the results file path and wandb run name
    for line in lines:
        if "Saving results at:" in line:
            metainfo["path_predictions"] = line.split("Saving results at:")[-1].strip()
        if "🚀 View run" in line:
            # The run name sits between "View run" and "at:"; strip the ANSI
            # colour codes (yellow start / reset) that wandb wraps around it.
            raw_run_name = line.split("View run")[1].split("at:")[0].strip()
            metainfo["wandb_run_name"] = raw_run_name.replace("\x1b[33m", "").replace("\x1b[0m", "").strip()

    if not metainfo["path_predictions"]:
        raise ValueError("Results path not found in the log file.")
    if not metainfo["wandb_run_name"]:
        raise ValueError("wandb run name not found in the log file.")

    # Extract the split index from the first line that starts with "split:"
    for line in lines:
        if line.startswith("split:"):
            try:
                metainfo["split_index"] = int(line.split(":")[-1].strip())
            except ValueError as e:
                raise ValueError("Failed to parse the split index.") from e
            break

    if metainfo["split_index"] is None:
        raise ValueError("Split index not found in the log file.")

    return metainfo

# # Example usage
# log_file_path = "path_to_your_log_file.txt"
# metainfo = extract_metainfo(log_file_path)

# # Print the results
# print("Metainfo:", metainfo)


In [None]:
# Training log of the run to evaluate; the config, split index, predictions
# path and wandb run name used below are all parsed from this file.
path_log_file = "/home/haicu/soeren.becker/repos/ot_pert_reproducibility/runs_otfm/bash_scripts/h-otfm-norman_29797689.out"

In [None]:
# Parse the training log once, then unpack the pieces the rest of the notebook uses.
metainfo = extract_metainfo(path_log_file)
config, path_predictions = metainfo["config"], metainfo["path_predictions"]
split, wandb_run_name = metainfo["split_index"], metainfo["wandb_run_name"]

# One "label value" line per extracted field.
print(
    f"wandb_run_name {wandb_run_name}",
    f"split {split}",
    f"path_predictions {path_predictions}",
    f"config {config}",
    sep="\n",
)

# The split announced in the log must agree with the one stored in the config.
assert split == config["dataset"]["split"]


In [None]:
# Echo the wandb run name parsed from the log (it is used in the output file name).
wandb_run_name

In [None]:
DATA_DIR = "/home/haicu/soeren.becker/repos/ot_pert_reproducibility/norman2019/norman_preprocessed_adata"

# Build the three per-split file paths. Note the naming shift: the on-disk
# "val" file is this notebook's test set and the on-disk "test" file is its
# OOD set.
adata_train_path, adata_test_path, adata_ood_path = (
    os.path.join(DATA_DIR, f"adata_{stage}_pca_50_split_{split}.h5ad")
    for stage in ("train", "val", "test")
)

# load data splits (same order as the paths above)
adata_train, adata_test, adata_ood = (
    sc.read(split_path)
    for split_path in (adata_train_path, adata_test_path, adata_ood_path)
)

In [None]:
# Manual overrides kept for reference (uncomment to evaluate a specific prediction file
# instead of the one named in the log):
# adata_pred_ood = sc.read_h5ad(f"/lustre/groups/ml01/workspace/ot_perturbation/data/norman_soren/cellflow/out/solar-pine-515_adata_test_with_predictions_0.h5ad")
# path_predictions = f"/lustre/groups/ml01/workspace/ot_perturbation/data/norman_soren/cellflow/out/astral-water-224_adata_test_with_predictions_0.h5ad"

# Load the predictions written by the training run.
adata_pred_ood = sc.read_h5ad(path_predictions)
# Split "condition" on "+" and store the two parts as gene_1 / gene_2. `.values` drops the
# labels, so the assignment is positional and the preceding rename does not affect the result.
# NOTE(review): assumes the split yields exactly two columns — confirm for this dataset.
adata_pred_ood.obs.loc[:, ["gene_1", "gene_2"]] = adata_pred_ood.obs.condition.str.split("+", expand=True).rename({0: "gene_1", 1: "gene_2"}, axis=1).values
# Expose the "X_recon_pred" layer as .X so the later cells that read .X evaluate the
# predictions (presumably the reconstructed gene-space prediction — confirm with the training code).
adata_pred_ood.X = adata_pred_ood.layers['X_recon_pred']

In [None]:
# Show the maximum of .X for the predictions and the three data splits side by side
# (presumably a quick check that they are on a comparable scale).
adata_pred_ood.X.max(),  adata_ood.X.max(), adata_train.X.max(), adata_test.X.max()

In [None]:
# compute pca on full dataset
# Train, test and OOD cells are concatenated so the 10-component PCA is fitted on all of them;
# adata_all is then used as the reference when projecting predictions and OOD targets.
adata_all = ad.concat((adata_train, adata_test, adata_ood))
cfpp.centered_pca(adata_all, n_comps=10)

In [None]:
# Project both the predicted and the true OOD cells using adata_all as the PCA reference.
# NOTE(review): the obsm["X_pca"] entries read in later cells are presumably written here — confirm.
cfpp.project_pca(query_adata=adata_pred_ood, ref_adata=adata_all)
cfpp.project_pca(query_adata=adata_ood, ref_adata=adata_all)

In [None]:
def add_subgroup_annotations(adata_train, adata):
    """
    Annotate each cell of ``adata`` with a ``subgroup`` column in ``adata.obs``.

    Training conditions are reduced by removing the literal substrings
    "+ctrl" and "ctrl+"; each gene of a held-out double perturbation counts
    as "seen" when it appears among those reduced training conditions.

    Subgroups written to ``adata.obs["subgroup"]``:
        - "single": the condition contains "ctrl" (this includes a plain
          "ctrl" condition, should one be present).
        - "double_seen_0": no "ctrl" in the condition, neither gene seen.
        - "double_seen_1": no "ctrl" in the condition, exactly one gene seen.
        - "double_seen_2": no "ctrl" in the condition, both genes seen.

    Args:
        adata_train: Object whose ``obs`` has a string ``condition`` column
            (training set).
        adata: Object whose ``obs`` has ``condition``, ``gene_1`` and
            ``gene_2`` columns; modified in place.

    Returns:
        None.

    Raises:
        AssertionError: If a non-"ctrl" condition of ``adata`` also occurs
            among the reduced training conditions.
    """
    # regex=False: "+" is a regex metacharacter, and the default of `regex`
    # differs between pandas versions, so request literal replacement explicitly.
    train_conditions = (
        adata_train.obs.condition
        .str.replace("+ctrl", "", regex=False)
        .str.replace("ctrl+", "", regex=False)
        .unique()
    )

    conditions = adata.obs.condition

    # Held-out conditions must not have been seen during training. Checking the
    # obs column directly avoids slicing the whole AnnData just for this test.
    assert not conditions[conditions != "ctrl"].isin(train_conditions).any()

    # Compute every mask before writing anything to obs.
    has_ctrl = conditions.str.contains("ctrl")
    is_double = ~has_ctrl
    gene_1_seen = adata.obs.gene_1.isin(train_conditions)
    gene_2_seen = adata.obs.gene_2.isin(train_conditions)

    mask_single_perturbation = has_ctrl
    mask_double_perturbation_seen_0 = is_double & ~gene_1_seen & ~gene_2_seen
    # XOR: exactly one of the two genes was seen in training.
    mask_double_perturbation_seen_1 = is_double & (gene_1_seen ^ gene_2_seen)
    mask_double_perturbation_seen_2 = is_double & gene_1_seen & gene_2_seen

    adata.obs.loc[mask_single_perturbation, "subgroup"] = "single"
    adata.obs.loc[mask_double_perturbation_seen_0, "subgroup"] = "double_seen_0"
    adata.obs.loc[mask_double_perturbation_seen_1, "subgroup"] = "double_seen_1"
    adata.obs.loc[mask_double_perturbation_seen_2, "subgroup"] = "double_seen_2"

# Tag the true OOD cells and the predicted OOD cells with their subgroup,
# always relative to the conditions present in the training set.
add_subgroup_annotations(adata_train, adata=adata_ood)
add_subgroup_annotations(adata_train, adata=adata_pred_ood)

# Cells per subgroup: ground truth first, predictions second.
display(
    adata_ood.obs["subgroup"].value_counts(),
    adata_pred_ood.obs["subgroup"].value_counts(),
)

In [None]:
# Nested dicts {subgroup: {condition: array}} holding, per condition, the true
# and predicted OOD cells in PCA ("encoded") and gene ("decoded") space.
ood_data_target_encoded, ood_data_target_decoded = {}, {}
ood_data_target_encoded_predicted, ood_data_target_decoded_predicted = {}, {}

subgroups = ["single", "double_seen_0", "double_seen_1", "double_seen_2"]

for subgroup in tqdm(subgroups):

    ood_data_target_encoded[subgroup] = {}
    ood_data_target_decoded[subgroup] = {}
    ood_data_target_encoded_predicted[subgroup] = {}
    ood_data_target_decoded_predicted[subgroup] = {}

    for cond in adata_ood.obs["condition"].cat.categories:
        if cond == "ctrl":
            continue

        select = adata_ood.obs["condition"] == cond
        select_pred = adata_pred_ood.obs["condition"] == cond

        # "all" is not in `subgroups` above; the guard only matters if it is
        # added there, in which case no subgroup filter is applied.
        if subgroup != "all":
            select = select & (adata_ood.obs.subgroup == subgroup)
            select_pred = select_pred & (adata_pred_ood.obs.subgroup == subgroup)

        # Vectorised Series.any() instead of the builtin any(), which would
        # iterate over every cell in Python for each (subgroup, condition) pair.
        if not select.any():
            # the condition is not part of this subgroup
            continue

        # pca space
        ood_data_target_encoded[subgroup][cond] = adata_ood[select].obsm["X_pca"]
        ood_data_target_encoded_predicted[subgroup][cond] = adata_pred_ood[select_pred].obsm["X_pca"]

        # gene space: the ground truth .X is densified via .todense(), while
        # the predictions' .X is used as stored.
        ood_data_target_decoded[subgroup][cond] = np.asarray(adata_ood[select].X.todense())
        ood_data_target_decoded_predicted[subgroup][cond] = adata_pred_ood[select_pred].X

In [None]:
# Result containers, each keyed by subgroup.
ood_metrics_encoded = {}
mean_ood_metrics_encoded = {}
ood_metrics_decoded = {}
mean_ood_metrics_decoded = {}
deg_ood_metrics = {}
deg_mean_ood_metrics = {}
ood_deg_dict = {}
ood_deg_target_decoded_predicted = {}
ood_deg_target_decoded = {}

# Subgroups are processed in reverse order of the `subgroups` list.
for subgroup in tqdm(subgroups[::-1]):

    print(f"subgroup: {subgroup}")

    # --- encoded (= pca) space ------------------------------------------
    # Uses compute_metrics; compute_metrics_fast is the imported alternative.
    print("Computing ood_metrics_encoded")
    ood_metrics_encoded[subgroup] = jax.tree_util.tree_map(
        compute_metrics,
        ood_data_target_encoded[subgroup],
        ood_data_target_encoded_predicted[subgroup],
    )
    mean_ood_metrics_encoded[subgroup] = compute_mean_metrics(
        ood_metrics_encoded[subgroup], prefix="encoded_ood_"
    )

    # --- decoded (= gene) space -----------------------------------------
    # Uses compute_metrics_fast; compute_metrics is the imported alternative.
    print("Computing ood_metrics_decoded")
    ood_metrics_decoded[subgroup] = jax.tree_util.tree_map(
        compute_metrics_fast,
        ood_data_target_decoded[subgroup],
        ood_data_target_decoded_predicted[subgroup],
    )
    mean_ood_metrics_decoded[subgroup] = compute_mean_metrics(
        ood_metrics_decoded[subgroup], prefix="decoded_ood_"
    )

    # --- DEG lookup -------------------------------------------------------
    # Keep only the entries of the training set's "rank_genes_groups_cov_all"
    # whose key is a condition we have predictions for in this subgroup.
    ood_deg_dict[subgroup] = {
        condition_name: deg_genes
        for condition_name, deg_genes in adata_train.uns["rank_genes_groups_cov_all"].items()
        if condition_name in ood_data_target_decoded_predicted[subgroup]
    }

    # --- restrict gene-space arrays to the DEG columns ----------------------
    print("Apply DEG mask")
    ood_deg_target_decoded_predicted[subgroup] = jax.tree_util.tree_map(
        get_mask,
        ood_data_target_decoded_predicted[subgroup],
        ood_deg_dict[subgroup],
    )
    ood_deg_target_decoded[subgroup] = jax.tree_util.tree_map(
        get_mask,
        ood_data_target_decoded[subgroup],
        ood_deg_dict[subgroup],
    )

    # --- metrics on the DEG-restricted gene space ---------------------------
    # Uses compute_metrics; compute_metrics_fast is the imported alternative.
    print("Compute metrics on DEG subsetted decoded")
    deg_ood_metrics[subgroup] = jax.tree_util.tree_map(
        compute_metrics,
        ood_deg_target_decoded[subgroup],
        ood_deg_target_decoded_predicted[subgroup],
    )
    deg_mean_ood_metrics[subgroup] = compute_mean_metrics(
        deg_ood_metrics[subgroup], prefix="deg_ood_"
    )

In [None]:
# Bundle every OOD evaluation artefact together with the provenance of the
# run (predictions file and wandb run name) into one dictionary.
collected_results = dict(
    # ood
    ood_metrics_encoded=ood_metrics_encoded,
    mean_ood_metrics_encoded=mean_ood_metrics_encoded,
    ood_metrics_decoded=ood_metrics_decoded,
    mean_ood_metrics_decoded=mean_ood_metrics_decoded,
    deg_ood_metrics=deg_ood_metrics,
    deg_mean_ood_metrics=deg_mean_ood_metrics,
    ood_deg_dict=ood_deg_dict,
    ood_deg_target_decoded_predicted=ood_deg_target_decoded_predicted,
    ood_deg_target_decoded=ood_deg_target_decoded,
    # provenance
    path_predictions=path_predictions,
    wandb_run_name=wandb_run_name,
)

In [None]:
# Inspect the per-subgroup mean metrics computed on the DEG-restricted gene space.
collected_results["deg_mean_ood_metrics"]

In [None]:
OUT_DIR = "/lustre/groups/ml01/workspace/ot_perturbation/data/norman_soren/cellflow"
# Make sure the target directory exists so the results of the long-running
# metric computation above are not lost to a missing-directory error.
os.makedirs(OUT_DIR, exist_ok=True)

# The file name encodes both the data split and the wandb run name.
result_filename = os.path.join(OUT_DIR, f"cellflow_split_{split}_collected_results_{wandb_run_name}.pkl")
print(f"Saving results at: {result_filename}")
pd.to_pickle(collected_results, result_filename)

In [None]:
# Echo the wandb run name once more so it appears next to the saved file name.
wandb_run_name