In [1]:
import cfp
import scanpy as sc
import numpy as np
import optax
from omegaconf import DictConfig, OmegaConf
from functools import partial
from cfp.training import Metrics, PCADecoder, PCADecodedMetrics
from cfp.training._callbacks import WandbLogger
from ott.solvers import utils as solver_utils



In [2]:
# Directory holding the pre-split AnnData files. The three splits live side by
# side and differ only in the split tag of the file name, so the location is
# stated once here instead of being repeated inside every path.
DATASET_DIR = "/lustre/groups/ml01/workspace/ot_perturbation/data/satija/datasets/ood_cell_type/IFNG_ct-BXPC3_hvg-2000_pca-30_norm"

# Split name -> path of the matching .h5ad file (insertion order: train, test, ood).
dataset_path = {
    split: f"{DATASET_DIR}/adata_{split}_IFNG_BXPC3_embs.h5ad"
    for split in ("train", "test", "ood")
}

# Load every split into its own object; later cells refer to these three names.
adata_train = sc.read_h5ad(dataset_path["train"])
adata_test = sc.read_h5ad(dataset_path["test"])
adata_ood = sc.read_h5ad(dataset_path["ood"])




In [3]:
# Prepare model
#
# Architecture hyperparameters. Every entry except "match_kwargs" is forwarded
# to `cf.prepare_model` in a later cell; "match_kwargs" parameterises the
# matching function built further down in this cell.
model_config = {
    "hidden_dims": (1024, 1024, 1024),
    "decoder_dims": (2048, 2048, 2048),
    "condition_embedding_dim": 4096,
    "time_encoder_dims": (512, 512, 512),
    "flow": {"constant_noise": 1.0},
    "hidden_dropout": 0.0,
    "decoder_dropout": 0.0,
    "layers_after_pool": (
        {"layer_type": "mlp", "dims": (4096, 4096)},
    ),
    "match_kwargs": {
        "epsilon": 0.01,
        "tau_a": 1.0,
        "tau_b": 0.999,
    },
    # Currently disabled option, kept for reference:
    #   "covariates_not_pooled": ['cell_type', 'pathway', 'gene'],
}

# Optimisation / training-loop settings.
# NOTE(review): "n_test_samples", "n_ood_samples", "save_model" and
# "save_model_path" are not read by any cell visible in this notebook; they
# only end up in the config handed to the W&B logger. Confirm whether they
# should drive the validation / saving calls.
train_config = {
    "learning_rate": 5.0e-05,
    "batch_size": 1024,
    "multi_steps": 50,
    "num_iterations": 1_000_000,
    "valid_freq": 1000,
    "n_test_samples": 2,
    "n_ood_samples": 2,
    "save_model": True,
    "save_model_path": "/home/icb/lea.zimmermann/projects/cell_flow_perturbation/results",
}

# Adam wrapped in `optax.MultiSteps`, with the step count taken from
# train_config["multi_steps"].
optimizer = optax.MultiSteps(
    optax.adam(learning_rate=train_config["learning_rate"]),
    train_config["multi_steps"],
)

# `match_linear` with epsilon / tau_a / tau_b pre-bound from the model config.
match_fn = partial(
    solver_utils.match_linear,
    **model_config["match_kwargs"],
)


# Instantiate the model (only the "otfm" solver works for now)
cf = cfp.model.CellFlow(adata_train, solver="otfm")

# Prepare training data
cf.prepare_data(
    # Location of the cell data; can also be "X".
    sample_rep="X_pca",
    # Column identifying control cells: either a tuple of (column key, control
    # name) or simply the name of a boolean column.
    control_key="control",
    # Split the data based on these covariates in obs, such as "cell_type".
    sample_covariates=["pathway", "cell_type"],
    # Covariate name -> key of the representation used for it.
    sample_covariate_reps={
        "cell_type": "cell_type_emb",
        "pathway": "pathway_emb",
    },
    perturbation_covariates={"gene": ("gene",)},
    perturbation_covariate_reps={"gene": "gene_emb"},
)



A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  _covariate_data["cell_index"] = _covariate_data.index
2024-08-22 14:39:28.921586: W external/xla/xla/service/gpu/nvptx_compiler.cc:836] The NVIDIA driver's CUDA version is 12.3 which is older than the PTX compiler version (12.6.20). Because the driver is older than the PTX compiler version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
100%|██████████| 219/219 [00:02<00:00, 97.95it/s] 


In [9]:
# Sanity check: display the sample covariates held by `cf.dm`; expected to
# match the `sample_covariates` list passed to `cf.prepare_data`.
cf.dm.sample_covariates

['pathway', 'cell_type']

In [None]:
# Then prepare validation data from separate adata objects. The same call is
# issued once per validation set (first "test", then "ood"), so any further
# set only needs another (name, adata) pair in the tuple below.
# NOTE(review): the two counts are hard-coded to 2 while train_config also
# carries "n_test_samples" / "n_ood_samples" (both 2) -- presumably those were
# meant to be used here; confirm before wiring them in.
for validation_name, validation_adata in (("test", adata_test), ("ood", adata_ood)):
    cf.prepare_validation_data(
        validation_adata,
        name=validation_name,
        n_conditions_on_log_iteration=2,
        n_conditions_on_train_end=2,
    )

# Build the network from the architecture settings, together with the matching
# function and optimizer created in the configuration cell.
cf.prepare_model(
    hidden_dims=model_config["hidden_dims"],
    decoder_dims=model_config["decoder_dims"],
    condition_embedding_dim=model_config["condition_embedding_dim"],
    time_encoder_dims=model_config["time_encoder_dims"],
    flow=model_config["flow"],
    hidden_dropout=model_config["hidden_dropout"],
    decoder_dropout=model_config["decoder_dropout"],
    match_fn=match_fn,
    optimizer=optimizer,
    layers_after_pool=model_config["layers_after_pool"],
)

# Validation callbacks: r_squared as-is, plus r_squared computed through a
# decoder built from the PCs and training means stored in `adata_train.varm`
# (presumably evaluating in the decoded, pre-PCA space -- confirm against the
# cfp documentation).
metrics_callback = Metrics(metrics=["r_squared"])
pca_decoder = PCADecoder(
    pcs=adata_train.varm["PCs"],
    means=adata_train.varm["X_train_mean"],
)
decoded_metrics_callback = PCADecodedMetrics(
    metrics=["r_squared"],
    pca_decoder=pca_decoder,
)

# Weights & Biases logging: the complete run configuration (model, training
# and logger settings) is attached to the logger.
logger_config = {
    "project": "satija_ifng_otfm_test",
    "out_dir": "/lustre/groups/ml01/workspace/ot_perturbation/data/satija/out",
}

config = {
    "model": model_config,
    "train": train_config,
    "logger": logger_config,
}

wandb_callback = WandbLogger(
    project=logger_config["project"],
    out_dir=logger_config["out_dir"],
    config=config,
)

# Run training with all three callbacks attached.
cf.train(
    num_iterations=train_config["num_iterations"],
    callbacks=[metrics_callback, decoded_metrics_callback, wandb_callback],
    valid_freq=train_config["valid_freq"],
    batch_size=train_config["batch_size"],
)
