In [1]:
1

1

In [2]:
import functools
import os
import sys
import traceback
from typing import Dict, Literal, Optional, Tuple

import cfp
import scanpy as sc
import numpy as np
import functools
from ott.solvers import utils as solver_utils
import optax
from omegaconf import OmegaConf
from typing import NamedTuple, Any
import hydra
import wandb

In [2]:
import yaml

with open("/home/icb/dominik.klein/git_repos/ot_pert_new/runs_otfm/conf/model/combosciplex.yaml") as stream:

    model_config = yaml.safe_load(stream)

with open("/home/icb/dominik.klein/git_repos/ot_pert_new/runs_otfm/conf/dataset/combosciplex.yaml") as stream:

    dataset_config = yaml.safe_load(stream)

with open("/home/icb/dominik.klein/git_repos/ot_pert_new/runs_otfm/conf/training/combosciplex.yaml") as stream:

    training_config = yaml.safe_load(stream)

In [3]:
config_dict = {}
config_dict["model"] = model_config
config_dict["dataset"] = dataset_config
config_dict["training"] = training_config

In [3]:
def prepare_data(adata_train, adata_test, adata_ood):
    
    adata_tmp =  adata_train[adata_train.obs["Drug1"].drop_duplicates().index]
    ecfp_dict = {drug: adata_tmp[adata_tmp.obs["Drug1"]==drug].obsm["ecfp_drug_1"] for drug in adata_tmp.obs["Drug1"]}

    adata_tmp =  adata_train[adata_train.obs["Drug2"].drop_duplicates().index]
    ecfp_dict.update({drug: adata_tmp[adata_tmp.obs["Drug2"]==drug].obsm["ecfp_drug_2"] for drug in adata_tmp.obs["Drug2"]})

    adata_tmp =  adata_ood[adata_ood.obs["Drug1"].drop_duplicates().index]
    ecfp_dict.update({drug: adata_tmp[adata_tmp.obs["Drug1"]==drug].obsm["ecfp_drug_1"] for drug in adata_tmp.obs["Drug1"]})

    adata_tmp =  adata_ood[adata_ood.obs["Drug2"].drop_duplicates().index]
    ecfp_dict.update({drug: adata_tmp[adata_tmp.obs["Drug2"]==drug].obsm["ecfp_drug_2"] for drug in adata_tmp.obs["Drug2"]})

        
    adata_train.uns['ecfp_rep'] = ecfp_dict
    adata_test.uns['ecfp_rep'] = ecfp_dict
    adata_ood.uns['ecfp_rep'] = ecfp_dict
    return adata_train, adata_test, adata_ood


In [4]:


def run(config):
    config_dict  = config#OmegaConf.to_container(config, resolve=True)
    #split = config_dict["dataset"]["split"]
    split = 2
    adata_train_path = f"/lustre/groups/ml01/workspace/ot_perturbation/data/combosciplex/adata_train_{split}.h5ad"
    adata_test_path = f"/lustre/groups/ml01/workspace/ot_perturbation/data/combosciplex/adata_test_{split}.h5ad"
    adata_ood_path = f"/lustre/groups/ml01/workspace/ot_perturbation/data/combosciplex/adata_ood_{split}.h5ad"
    adata_train = sc.read_h5ad(adata_train_path)
    adata_test = sc.read_h5ad(adata_test_path)
    adata_ood = sc.read_h5ad(adata_ood_path)
    adata_train, adata_test, adata_ood = prepare_data(adata_train, adata_test, adata_ood)

    
    cf = cfp.model.CellFlow(adata_train, solver="otfm")

    # Prepare the training data and perturbation conditions
    perturbation_covariates = {k: tuple(v) for k, v in config_dict["dataset"]["perturbation_covariates"].items()}
    cf.prepare_data(
        sample_rep="X_pca",
        control_key="control",
        perturbation_covariates=perturbation_covariates,
        perturbation_covariate_reps=dict(config_dict["dataset"]["perturbation_covariate_reps"]),
    )

    match_fn = functools.partial(
        solver_utils.match_linear,
        epsilon=config_dict["model"]["epsilon"],
        scale_cost="mean",
        tau_a=config_dict["model"]["tau_a"],
        tau_b=config_dict["model"]["tau_b"]
    )
    optimizer = optax.MultiSteps(optax.adam(config_dict["model"]["learning_rate"]), config_dict["model"]["multi_steps"])
    flow = {config_dict["model"]["flow_type"]: config_dict["model"]["flow_noise"]}

    layers_before_pool = config_dict["model"]["layers_before_pool"]
    layers_after_pool = config_dict["model"]["layers_after_pool"]


    # Prepare the model
    cf.prepare_model(
        encode_conditions=True,
        condition_embedding_dim=config_dict["model"]["condition_embedding_dim"],
        time_encoder_dims=config_dict["model"]["time_encoder_dims"],
        time_encoder_dropout=config_dict["model"]["time_encoder_dropout"],
        hidden_dims=config_dict["model"]["hidden_dims"],
        hidden_dropout=config_dict["model"]["hidden_dropout"],
        decoder_dims=config_dict["model"]["decoder_dims"],
        decoder_dropout=config_dict["model"]["decoder_dropout"],
        pooling=config_dict["model"]["pooling"],
        layers_before_pool=layers_before_pool,
        layers_after_pool=layers_after_pool,
        cond_output_dropout=config_dict["model"]["cond_output_dropout"],
        time_freqs=config_dict["model"]["time_freqs"],
        match_fn=match_fn,
        optimizer=optimizer,
        flow=flow,
    )

    cf.prepare_validation_data(
        adata_test,
        name="test",
        n_conditions_on_log_iteration=config_dict["training"]["test_n_conditions_on_log_iteration"],
        n_conditions_on_train_end=config_dict["training"]["test_n_conditions_on_log_iteration"],
    )

    cf.prepare_validation_data(
        adata_ood,
        name="ood",
        n_conditions_on_log_iteration=config_dict["training"]["ood_n_conditions_on_log_iteration"],
        n_conditions_on_train_end=config_dict["training"]["ood_n_conditions_on_log_iteration"],
    )

    metrics_callback = cfp.training.Metrics(metrics=["r_squared", "mmd", "e_distance"])
    decoded_metrics_callback = cfp.training.PCADecodedMetrics(ref_adata=adata_train, metrics=["r_squared", "mmd", "e_distance"])
    wandb_callback = cfp.training.WandbLogger(project="cfp_combosciplex_otfm", out_dir="/home/icb/dominik.klein/tmp", config=config_dict)

    callbacks = [metrics_callback, decoded_metrics_callback, wandb_callback]

    cf.train(
        num_iterations=240_000,#num_iterations=config_dict["training"]["num_iterations"],
        batch_size=config_dict["training"]["batch_size"],
        callbacks=callbacks,
        valid_freq=config_dict["training"]["valid_freq"],
    )

    if config_dict["training"]["save_model"]:
        cf.save(config_dict["training"]["out_dir"], file_prefix=wandb.run.name)



In [5]:
run(config_dict)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  _covariate_data["cell_index"] = _covariate_data.index
100%|████████████████████████████████████████████████████████████████████████████████████| 25/25 [00:00<00:00, 48.10it/s]
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  _covariate_data["cell_index"] = _covariate_data.index
100%|███████████████████████████████████████████████████████████████████████████████████| 25/25 [00:00<00:00, 222.90it/s]
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value ins

100%|██████████████████████████████████████████████████████████████| 240000/240000 [1:56:52<00:00, 34.22it/s, loss=0.275]


In [9]:
model_name = wandb.run.name

In [4]:
model_name = 'swept-bush-84'

In [5]:
cf = cfp.model.CellFlow.load(f"/lustre/groups/ml01/workspace/ot_perturbation/models/otfm/combosciplex/{model_name}_CellFlow.pkl")

In [6]:
cf.trainer.training_logs.keys()

dict_keys(['loss', 'test_r_squared_mean', 'ood_r_squared_mean', 'test_mmd_mean', 'ood_mmd_mean', 'test_e_distance_mean', 'ood_e_distance_mean', 'pca_decoded_ood_r_squared_mean', 'pca_decoded_test_r_squared_mean', 'pca_decoded_ood_mmd_mean', 'pca_decoded_test_mmd_mean', 'pca_decoded_ood_e_distance_mean', 'pca_decoded_test_e_distance_mean'])

In [7]:
cf.trainer.training_logs['pca_decoded_ood_r_squared_mean']

[0.8854240694360506, 0.9720600279184204, 0.9884586794405363]

In [8]:
split = 2
adata_train_path = f"/lustre/groups/ml01/workspace/ot_perturbation/data/combosciplex/adata_train_{split}.h5ad"
adata_test_path = f"/lustre/groups/ml01/workspace/ot_perturbation/data/combosciplex/adata_test_{split}.h5ad"
adata_ood_path = f"/lustre/groups/ml01/workspace/ot_perturbation/data/combosciplex/adata_train_{split}.h5ad"
adata_train = sc.read_h5ad(adata_train_path)
adata_test = sc.read_h5ad(adata_test_path)
adata_ood = sc.read_h5ad(adata_ood_path)
adata_train, adata_test, adata_ood = prepare_data(adata_train, adata_test, adata_ood)

In [9]:
adata_ood_ctrl = adata_ood[adata_ood.obs["condition"]=="control"]
adata_ood_ctrl.obs["control"] = True
covariate_data_ood = adata_ood[adata_ood.obs["condition"]!="control"].obs.drop_duplicates(subset=["condition"])

  adata_ood_ctrl.obs["control"] = True


In [10]:
covariate_data_ood 

Unnamed: 0_level_0,sample,Size_Factor,n.umi,RT_well,Drug1,Drug2,Well,n_genes,n_genes_by_counts,total_counts,...,split,control,cell_type,cell_line,smiles_drug_1,smiles_drug_2,ood_1,ood_2,ood_3,ood_4
Cell,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
A01_A02_RT_BC_13_Lig_BC_1,sciPlex_theis,1.083277,2908,RT_13,Cediranib,PCI-34051,B1,1896,1895,2907.0,...,train,0,A549,A549,CC1=CC2=C(N1)C=CC(=C2F)OC3=NC=NC4=CC(=C(C=C43)...,COC1=CC=C(C=C1)CN2C=CC3=C2C=C(C=C3)C(=O)NO,not ood,not ood,not ood,Cediranib+PCI-34051
A01_A02_RT_BC_16_Lig_BC_2,sciPlex_theis,1.298964,3487,RT_16,Givinostat,SRT1720,B4,2154,2152,3485.0,...,train,0,A549,A549,CCN(CC)CC1=CC2=C(C=C1)C=C(C=C2)COC(=O)NC3=CC=C...,C1CN(CCN1)CC2=CSC3=NC(=CN23)C4=CC=CC=C4NC(=O)C...,Givinostat+SRT1720,not ood,not ood,not ood
A01_A02_RT_BC_1_Lig_BC_3,sciPlex_theis,2.151281,5775,RT_1,Panobinostat,PCI-34051,A1,3139,3138,5774.0,...,train,0,A549,A549,CC1=C(C2=CC=CC=C2N1)CCNCC3=CC=C(C=C3)/C=C/C(=O)NO,COC1=CC=C(C=C1)CN2C=CC3=C2C=C(C=C3)C(=O)NO,Panobinostat+PCI-34051,not ood,not ood,not ood
A01_A02_RT_BC_22_Lig_BC_17,sciPlex_theis,0.899999,2416,RT_22,control,Dacinostat,B10,1572,1569,2412.0,...,train,0,A549,A549,,C1=CC=C2C(=C1)C(=CN2)CCN(CCO)CC3=CC=C(C=C3)/C=...,not ood,not ood,control+Dacinostat,not ood
A01_A02_RT_BC_28_Lig_BC_12,sciPlex_theis,1.136175,3050,RT_28,Panobinostat,Sorafenib,C4,2015,2015,3050.0,...,train,0,A549,A549,CC1=C(C2=CC=CC=C2N1)CCNCC3=CC=C(C=C3)/C=C/C(=O)NO,CNC(=O)C1=NC=CC(=C1)OC2=CC=C(C=C2)NC(=O)NC3=CC...,not ood,not ood,Panobinostat+Sorafenib,not ood
A01_A02_RT_BC_34_Lig_BC_21,sciPlex_theis,1.1548,3100,RT_34,control,Givinostat,C10,1987,1985,3098.0,...,train,0,A549,A549,,CCN(CC)CC1=CC2=C(C=C1)C=C(C=C2)COC(=O)NC3=CC=C...,not ood,not ood,not ood,control+Givinostat
A01_A02_RT_BC_37_Lig_BC_14,sciPlex_theis,0.921233,2473,RT_37,Dacinostat,Danusertib,D1,1657,1656,2472.0,...,train,0,A549,A549,C1=CC=C2C(=C1)C(=CN2)CCN(CCO)CC3=CC=C(C=C3)/C=...,CN1CCN(CC1)C2=CC=C(C=C2)C(=O)NC3=NNC4=C3CN(C4)...,not ood,not ood,not ood,not ood
A01_A02_RT_BC_40_Lig_BC_18,sciPlex_theis,2.64263,7094,RT_40,Panobinostat,Dasatinib,D4,3767,3760,7087.0,...,train,0,A549,A549,CC1=C(C2=CC=CC=C2N1)CCNCC3=CC=C(C=C3)/C=C/C(=O)NO,CC1=C(C(=CC=C1)Cl)NC(=O)C2=CN=C(S2)NC3=CC(=NC(...,Panobinostat+Dasatinib,not ood,not ood,not ood
A01_A02_RT_BC_43_Lig_BC_4,sciPlex_theis,1.156663,3105,RT_43,Givinostat,Tanespimycin,D7,2013,2011,3103.0,...,train,0,A549,A549,CCN(CC)CC1=CC2=C(C=C1)C=C(C=C2)COC(=O)NC3=CC=C...,C[C@H]1C[C@@H]([C@@H]([C@H](/C=C(/[C@@H]([C@H]...,not ood,not ood,not ood,not ood
A01_A02_RT_BC_46_Lig_BC_22,sciPlex_theis,1.234519,3314,RT_46,Givinostat,Carmofur,D10,2123,2123,3314.0,...,train,0,A549,A549,CCN(CC)CC1=CC2=C(C=C1)C=C(C=C2)COC(=O)NC3=CC=C...,CCCCCCNC(=O)N1C=C(C(=O)NC1=O)F,not ood,not ood,not ood,not ood


In [5]:
preds_ood = cf.predict(adata=adata_ood_ctrl, sample_rep="X_pca", condition_id_key="condition", covariate_data=covariate_data_ood)

NameError: name 'cf' is not defined