In [1]:
# Load IPython's autoreload extension and switch it to mode 2, so edits to
# imported modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

**Adapted from** [https://github.com/theislab/ot_pert_reproducibility/blob/4944509fc4820bde50d0c1da19c33a337ae0df0d/runs_otfm/train_sciplex.py](https://github.com/theislab/ot_pert_reproducibility/blob/4944509fc4820bde50d0c1da19c33a337ae0df0d/runs_otfm/train_sciplex.py)\
The following dependencies need to be installed first:
```bash
pip install hydra-core --upgrade
pip install wandb
pip install hydra-submitit-launcher --upgrade
```

In [2]:
import functools
import os
import sys
import traceback
from typing import Dict, Literal, Optional, Tuple

import cfp
import anndata
import scanpy as sc
import numpy as np
import functools
from ott.solvers import utils as solver_utils
import optax
from omegaconf import OmegaConf
from typing import NamedTuple, Any
import hydra
import wandb

In [3]:
from hydra import compose, initialize

# Every config group is pointed at its "norman" variant.
norman_overrides = [f"{group}=norman" for group in ("dataset", "model", "training", "logger")]

# Compose the "train" config from the local `conf` directory.
with initialize(version_base=None, config_path="conf"):
    config = compose(config_name="train", overrides=norman_overrides)

# Convert the composed config into plain Python containers (resolve=True)
# and show it for inspection.
config_dict = OmegaConf.to_container(config, resolve=True)
display(config_dict)

{'dataset': {'split': 0,
  'sample_rep': 'X_pca',
  'perturbation_covariates': {'target_gene': ['gene_1', 'gene_2']},
  'perturbation_covariate_reps': {'target_gene': 'esm2'},
  'wandb_project': 'otfm_norman'},
 'model': {'condition_embedding_dim': 1024,
  'time_encoder_dims': [2048, 2048, 2048],
  'time_encoder_dropout': 0.0,
  'hidden_dims': [4096, 4096, 4096],
  'hidden_dropout': 0.0,
  'decoder_dims': [4096, 4096, 4096],
  'decoder_dropout': 0.2,
  'pooling': 'attention_token',
  'layers_before_pool': {'target_gene': {'layer_type': 'mlp',
    'dims': [1024, 1024],
    'dropout_rate': 0.5}},
  'layers_after_pool': {'layer_type': 'mlp',
   'dims': [1024, 1024],
   'dropout_rate': 0.2},
  'cond_output_dropout': 0.9,
  'time_freqs': 1024,
  'flow_noise': 1.0,
  'learning_rate': 5e-05,
  'multi_steps': 50,
  'epsilon': 0.1,
  'tau_a': 1.0,
  'tau_b': 1.0,
  'flow_type': 'constant_noise',
  'linear_projection_before_concatenation': False,
  'layer_norm_before_concatenation': False},
 'lo

In [4]:
# Directory holding the preprocessed Norman AnnData splits.
# Set the NORMAN_DATA_DIR environment variable to run on another machine;
# when it is unset or empty, the original cluster location is used.
norman_data_dir = (
    os.environ.get("NORMAN_DATA_DIR")
    or "/home/haicu/soeren.becker/repos/ot_pert_reproducibility/norman2019/norman_preprocessed_adata"
)

# Index of the data split to load, taken from the dataset config.
split = config_dict["dataset"]["split"]

# NOTE: the variable names are shifted relative to the file names:
# `adata_test` is read from the *val* file and `adata_ood` from the *test* file.
adata_train_path = os.path.join(norman_data_dir, f"adata_train_split_{split}.h5ad")
adata_test_path = os.path.join(norman_data_dir, f"adata_val_split_{split}.h5ad")
adata_ood_path = os.path.join(norman_data_dir, f"adata_test_split_{split}.h5ad")

adata_train = sc.read_h5ad(adata_train_path)
adata_test = sc.read_h5ad(adata_test_path)
adata_ood = sc.read_h5ad(adata_ood_path)

In [5]:
# Create the CellFlow model object on the training AnnData, selecting the "otfm" solver.
cf = cfp.model.CellFlow(adata_train, solver="otfm")

In [6]:
# Perturbation covariates for the training data: same keys as in the dataset
# config, with each list of entries converted to a tuple.
raw_covariates = config_dict["dataset"]["perturbation_covariates"]
perturbation_covariates = dict(zip(raw_covariates, map(tuple, raw_covariates.values())))

In [7]:
# Register the training data with CellFlow.
# `sample_rep` is read from the dataset config instead of being hard-coded, so
# the `dataset.sample_rep` setting actually takes effect (it is "X_pca" in the
# config displayed above, i.e. the same value as before).
cf.prepare_data(
    sample_rep=config_dict["dataset"]["sample_rep"],
    control_key="control",
    perturbation_covariates=perturbation_covariates,
    perturbation_covariate_reps=dict(config_dict["dataset"]["perturbation_covariate_reps"]),
    # No sample-level or split covariates are used in this run.
    sample_covariates=None,
    sample_covariate_reps=None,
    split_covariates=None
)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  _covariate_data["cell_index"] = _covariate_data.index
100%|██████████| 139/139 [00:00<00:00, 231.87it/s]


In [8]:
# Pre-bind the keyword arguments of ott's `match_linear`: epsilon, tau_a and
# tau_b come from the model config, the cost scaling is fixed to "mean".
model_cfg = config_dict["model"]
match_fn = functools.partial(
    solver_utils.match_linear,
    scale_cost="mean",
    **{name: model_cfg[name] for name in ("epsilon", "tau_a", "tau_b")},
)

In [9]:
model_cfg = config_dict["model"]

# Adam wrapped in optax.MultiSteps; the learning rate and the MultiSteps
# argument are both taken from the model config.
base_optimizer = optax.adam(model_cfg["learning_rate"])
optimizer = optax.MultiSteps(base_optimizer, model_cfg["multi_steps"])

# Single-entry mapping: configured flow type -> configured flow noise.
flow = {model_cfg["flow_type"]: model_cfg["flow_noise"]}

# Layer specifications applied before / after pooling, passed on unchanged.
layers_before_pool = model_cfg["layers_before_pool"]
layers_after_pool = model_cfg["layers_after_pool"]

In [10]:
# Build the model. The hyperparameters listed in `passthrough_keys` are handed
# to `prepare_model` under the same name they have in the model config; the
# remaining arguments are the objects constructed in the preceding cells.
model_cfg = config_dict["model"]
passthrough_keys = (
    "condition_embedding_dim",
    "pooling",
    "time_encoder_dims",
    "time_encoder_dropout",
    "hidden_dims",
    "hidden_dropout",
    "decoder_dims",
    "decoder_dropout",
    "cond_output_dropout",
    "time_freqs",
    "layer_norm_before_concatenation",
    "linear_projection_before_concatenation",
)
cf.prepare_model(
    encode_conditions=True,
    layers_before_pool=layers_before_pool,
    layers_after_pool=layers_after_pool,
    match_fn=match_fn,
    optimizer=optimizer,
    flow=flow,
    **{key: model_cfg[key] for key in passthrough_keys},
)

In [11]:
# Register `adata_test` as validation data under the name "test".
# NOTE(review): both condition counts are read from
# `test_n_conditions_on_log_iteration`; confirm that reusing the log-iteration
# value at train end is intended.
n_test_conditions = config_dict["training"]["test_n_conditions_on_log_iteration"]
cf.prepare_validation_data(
    adata_test,
    name="test",
    n_conditions_on_log_iteration=n_test_conditions,
    n_conditions_on_train_end=n_test_conditions,
)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  _covariate_data["cell_index"] = _covariate_data.index
100%|██████████| 32/32 [00:00<00:00, 599.00it/s]


In [12]:
# Register `adata_ood` as validation data under the name "ood".
# The number of conditions is pinned to 2 for both settings; the lookup of
# config_dict["training"]["ood_n_conditions_on_log_iteration"] is disabled.
n_ood_conditions = 2
cf.prepare_validation_data(
    adata_ood,
    name="ood",
    n_conditions_on_log_iteration=n_ood_conditions,
    n_conditions_on_train_end=n_ood_conditions,
)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  _covariate_data["cell_index"] = _covariate_data.index
100%|██████████| 108/108 [00:00<00:00, 426.50it/s]


In [13]:
# Output directory handed to the W&B logger. Honour the WANDB_DIR environment
# variable when it is set; otherwise fall back to the original location so the
# notebook behaves as before on the machine it was written for.
wandb_out_dir = os.environ.get("WANDB_DIR") or "/home/icb/dominik.klein/tmp"

# The same three metrics are computed by both metric callbacks;
# PCADecodedMetrics additionally receives the training AnnData as reference.
metrics_callback = cfp.training.Metrics(metrics=["r_squared", "mmd", "e_distance"])
decoded_metrics_callback = cfp.training.PCADecodedMetrics(ref_adata=adata_train, metrics=["r_squared", "mmd", "e_distance"])

# NOTE(review): the project name is fixed here and differs from
# config_dict["dataset"]["wandb_project"] ("otfm_norman" in the displayed
# config) - confirm which one is intended.
wandb_callback = cfp.training.WandbLogger(project="cfp_otfm_norman", out_dir=wandb_out_dir, config=config_dict)

callbacks = [metrics_callback, decoded_metrics_callback, wandb_callback]

In [14]:
# Run training. Iteration count, batch size and validation frequency are
# taken from the training config; the callbacks are those assembled above.
training_cfg = config_dict["training"]
cf.train(
    callbacks=callbacks,
    **{key: training_cfg[key] for key in ("num_iterations", "batch_size", "valid_freq")},
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33msab[0m. Use [1m`wandb login --relogin`[0m to force relogin


100%|██████████| 10000/10000 [22:11<00:00,  7.51it/s, loss=0.406]  
