In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import cfp
import scanpy as sc
import numpy as np

In [None]:
adata_train = sc.read_h5ad("/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex/adata_train_300_2.h5ad")
adata_test = sc.read_h5ad("/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex/adata_test_300_2.h5ad")
adata_ood = sc.read_h5ad("/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex/adata_ood_300_2.h5ad")

In [None]:
adata_train.obs["CTRL"] = adata_train.obs.apply(lambda x: True if x["drug"] == "Vehicle" else False, axis=1)
adata_test.obs["CTRL"] = adata_test.obs.apply(lambda x: True if x["drug"] == "Vehicle" else False, axis=1)
adata_ood.obs["CTRL"] = adata_ood.obs.apply(lambda x: True if x["drug"] == "Vehicle" else False, axis=1)

In [None]:
adata_tmp =  adata_train[adata_train.obs["drug"].drop_duplicates().index]
ecfp_dict = {drug: adata_tmp[adata_tmp.obs["drug"]==drug].obsm["ecfp"] for drug in adata_tmp.obs["drug"]}

adata_tmp =  adata_ood[adata_ood.obs["drug"].drop_duplicates().index]
ecfp_dict.update({drug: adata_tmp[adata_tmp.obs["drug"]==drug].obsm["ecfp"] for drug in adata_tmp.obs["drug"]})

adata_tmp =  adata_ood[adata_ood.obs["cell_line"].drop_duplicates().index]
cell_line_dict = {cell_line: adata_tmp[adata_tmp.obs["cell_line"]==cell_line].obsm["cell_line_emb"] for cell_line in adata_tmp.obs["cell_line"]}

In [None]:
adata_train.uns['ecfp_rep'] = ecfp_dict
adata_test.uns['ecfp_rep'] = ecfp_dict
adata_ood.uns['ecfp_rep'] = ecfp_dict
adata_train.uns['cl_rep'] = cell_line_dict
adata_test.uns['cl_rep'] = cell_line_dict
adata_ood.uns['cl_rep'] = cell_line_dict

In [None]:
prepare_config = {
    "sample_rep": "X_pca",
    "perturbation_covariates": {"drugs": ["drug"],
        "dose": ["logdose"],}
    "perturbation_covariate_reps": {
        "drugs": "ecfp_rep",
    },
    "sample_covariates": ["cell_line"],
    "sample_covariate_reps": {
        "cell_line": "cl_rep",
    } # or both sample_covaraites and sample_covariate_reps None,
    split_covariates=["cell_line"],

    
    
}

In [None]:
# Initialize CellFlow
cf = cfp.model.CellFlow(adata_train, solver="otfm")

# Prepare the training data and perturbation conditions
cf.prepare_data(
    sample_rep=prepare_config["sample_rep"],
    control_key="CTRL",
    perturbation_covariates=prepare_config["perturbation_covariates"],
    perturbation_covariate_reps=prepare_config["perturbation_covariate_reps"],
    sample_covariates=prepare_config["sample_covariates"],
    sample_covariate_reps=prepare_config["sample_covariate_reps"],
    split_covariates=split_covariates["sample_covariates"],
)







In [None]:
model_config = {
    "condition_embedding_dim" : 1024,
    "time_encoder_dims": (1024, 1024, 1024),
    "time_encoder_dropout": 0.0,
    "hidden_dims": (1024, 1024, 1024),
    "hidden_dropout": 0.0,
    "decoder_dims": (1024, 1024, 1024),
    "decoder_dropout": 0.0,
    "pooling": "mean",
    "layers_before_pool": [{"drug": ({"layer_type": "mlp", "dims": (1024, 512)}},), {"dose": ("mlp", {"dims": (256, 512)}),}],
    "layers_after_pool": ({"layer_type": "mlp", "dims": (512, 512)},),
    "cond_output_dropout": 0.0
    "time_freqs": 1024,
    "epsilon": 0.1,
    "tau_a": 1.0,
    "tau_b": 1.0
    
    
    
}

In [77]:
# Prepare the model
cf.prepare_model(
    encode_conditions=True,
    condition_embedding_dim=model_config["condition_embedding_dim"],
    time_encoder_dims = model_config["time_encoder_dims"],
    time_encoder_dropout = model_config["time_encoder_dropout"],
    hidden_dims=model_config["hidden_dims"],
    hidden_dropout = model_config["hidden_dropout"],
    decoder_dims=model_config["decoder_dims"],
    decoder_dropout=model_config["decoder_dropout"],
    pooling=model_config["pooling"],
    layers_before_pool=model_config["layers_before_pool"],
    layers_after_pool=model_config["layers_after_pool"],
    cond_output_dropout=model_config["cond_output_dropout"],
    time_freqs = model_config["time_freqs"],
    epsilon=model_config["epsilon"],
    tau_a=model_config["tau_a"],
    tau_b=model_config["tau_b"],
)



In [83]:
metrics_callback = cfp.training.ComputeMetrics(metrics=["r_squared", "mmd", "e_distance"])
decoded_metrics_callback = cfp.training.callbacks.PCADecodedMetrics(pca_decoder=(adata_train.obsm["X_pca"], adata_train.varm["X_train_mean"]), metrics=["r_squared", "mmd", "e_distance"])
wandb_callback = cfp.training.callbacks.WandbLogger(project="cfp", out_dir="/home/icb/dominik.klein/tmp", config={})


callbacks = [metrics_callback, decoded_metrics_callback, wandb_callback]

AttributeError: module 'cfp.training' has no attribute 'ComputeMetrics'

In [None]:
training_config = {
    "num_iterations": 100000,
    "batch_size": 1024,
    "valid_freq": 10000,
}

In [79]:
# Train the model
cf.train(
    num_iterations=training_config["num_iterations"],
    batch_size=training_config["batch_size"],
    callbacks=callbacks,
    valid_freq=training_config["valid_freq"],
)

100%|██████████| 100/100 [00:15<00:00,  6.49it/s]


In [18]:
condition_ood = adata_ood.obs.drop_duplicates(subset=["drug", "logdose", "cell_line"])

In [19]:
adata_ood_ctrl.obs["CTRL"] = True

  adata_ood_ctrl.obs["CTRL"] = True


In [21]:
adata_tmp =  adata_ood[adata_ood.obs["drug"].drop_duplicates().index]
ecfp_dict = {drug: adata_tmp[adata_tmp.obs["drug"]==drug].obsm["ecfp"] for drug in adata_tmp.obs["drug"]}

adata_tmp =  adata_ood[adata_ood.obs["cell_line"].drop_duplicates().index]
cell_line_dict = {cell_line: adata_tmp[adata_tmp.obs["cell_line"]==cell_line].obsm["cell_line_emb"] for cell_line in adata_tmp.obs["cell_line"]}

In [22]:
adata_ood_ctrl.uns['cl_rep'] = cell_line_dict
adata_ood_ctrl.uns['ecfp_rep'] = ecfp_dict

In [55]:
pred_data = cf.dm.get_prediction_data(
            adata_ood_ctrl,
            sample_rep="X_pca",
            covariate_data=condition_ood,
        )

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  _covariate_data["cell_index"] = _covariate_data.index
100%|██████████| 299/299 [00:00<00:00, 587.86it/s]
100%|██████████| 299/299 [00:00<00:00, 584.59it/s]
100%|██████████| 299/299 [00:00<00:00, 567.76it/s]


In [57]:
# Make predictions
X_pca_pred = cf.predict(
    adata = adata_ood_ctrl,
    sample_rep ="X_pca",
    covariate_data=condition_ood,
)



A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  _covariate_data["cell_index"] = _covariate_data.index
100%|██████████| 299/299 [00:00<00:00, 578.39it/s]
100%|██████████| 299/299 [00:00<00:00, 574.19it/s]
100%|██████████| 299/299 [00:00<00:00, 571.77it/s]


KeyboardInterrupt: 

In [60]:
cond_emb = cf.get_condition_embedding(condition_ood, rep_dict=adata_ood_ctrl.uns)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  _covariate_data["cell_index"] = _covariate_data.index
100%|██████████| 299/299 [00:00<00:00, 554.34it/s]
100%|██████████| 299/299 [00:00<00:00, 572.55it/s]
100%|██████████| 299/299 [00:00<00:00, 569.55it/s]


In [63]:
cond_emb[('Tie2_kinase_inhibitor',
  3.0,
  'A549')].shape

(1, 32)

In [70]:
import pandas as pd
pd.DataFrame.from_dict({k: v[0] for k,v in cond_emb.items()}).T

Unnamed: 0,Unnamed: 1,Unnamed: 2,0,1,2,3,4,5,6,7,8,9,...,22,23,24,25,26,27,28,29,30,31
Tie2_kinase_inhibitor,3.0,A549,0.874050,1.280500,-0.716889,-1.179713,-1.663216,-0.017025,-0.602069,-0.755869,0.589173,-0.060522,...,-1.087466,-0.220028,-0.983849,-2.438949,3.559757,0.796902,0.383854,0.331109,1.099895,0.848264
Alvespimycin_(17-DMAG)_HCl,1.0,A549,0.950771,1.037687,-0.711093,-1.408819,-1.517525,0.334236,0.060676,-0.858752,0.751628,0.097697,...,-1.659082,-0.404462,-1.079366,-2.281739,3.016140,0.877623,0.739793,-0.039125,1.152653,0.635560
JNJ-26854165_(Serdemetan),2.0,A549,0.632765,1.045042,-0.616878,-1.288092,-1.851291,0.161847,-0.355149,-0.940658,0.505082,0.064010,...,-1.349064,-0.308401,-1.294433,-2.862194,3.571977,0.960563,0.434199,-0.087776,1.162492,0.732873
Ruxolitinib_(INCB018424),1.0,A549,0.739389,1.353861,-0.493939,-0.751243,-1.772465,0.030624,-0.420219,-0.855654,0.457615,-0.090353,...,-1.630281,-0.234347,-1.026674,-3.062254,3.278475,0.877921,0.532786,-0.363705,1.203457,0.661306
PJ34,3.0,A549,0.859568,1.167635,-0.432446,-1.463552,-1.947932,0.156941,-0.463507,-0.675633,0.628918,0.122065,...,-1.417213,-0.410866,-1.101759,-2.478868,3.383358,0.713623,0.529841,-0.348254,1.121087,0.387284
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
JNJ-26854165_(Serdemetan),1.0,K562,-4.077663,3.203194,-1.156101,-0.283768,2.327652,2.615307,-3.327886,2.222895,2.281304,0.381888,...,2.625293,2.341361,2.521572,0.924272,-2.724225,1.634744,-1.590197,2.678515,0.446779,0.348494
Capecitabine_,2.0,K562,-4.016849,3.484368,-1.134743,0.012245,2.104794,2.830555,-3.187905,2.427397,2.453857,0.225519,...,2.262397,2.419325,2.571800,1.466806,-3.309803,1.479446,-1.166994,2.643369,0.620547,0.115589
Vehicle,0.0,A549,0.508402,1.289773,-0.452384,-1.043944,-2.052598,0.045271,-0.353646,-0.920775,0.644023,-0.089528,...,-1.771869,-0.469007,-1.372862,-2.400884,3.259788,0.944513,0.827397,-0.233044,1.330682,0.511168
Vehicle,0.0,MCF7,-1.904997,2.214967,0.141451,0.015483,2.739381,-1.437222,0.130280,-0.114577,0.968864,1.409276,...,-0.926935,-0.844184,-0.051661,1.578361,0.978406,0.903609,1.961412,1.814589,0.468807,0.670851


In [71]:
from typing import Any

import anndata as ad
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from adjustText import adjust_text

from cfp import _constants
from cfp.model import CellFlow
from cfp.plotting._utils import _get_colors, _input_to_adata, get_plotting_vars

ModuleNotFoundError: No module named 'adjustText'

In [None]:
df = pd.DataFrame(emb, columns=["dim1", "dim2"])
if labels is not None:
    if labels_name is None:
        labels_name = "labels"
    df[labels_name] = labels

fig = plt.figure(figsize=(fig_width, fig_height))
ax = plt.gca()

sns.despine(left=False, bottom=False, right=True)

if (col_dict is None) and labels is not None:
    col_dict = _get_colors(labels)

sns.scatterplot(
    x="dim1",
    y="dim2",
    hue=labels_name,
    palette=col_dict,
    alpha=circe_transparency,
    edgecolor="none",
    s=circle_size,
    data=df,
    ax=ax,
)

if show_lines:
    for i in range(len(emb)):
        if col_dict is None:
            ax.plot(
                [0, emb[i, 0]],
                [0, emb[i, 1]],
                alpha=line_transparency,
                linewidth=line_width,
                c=None,
            )
        else:
            ax.plot(
                [0, emb[i, 0]],
                [0, emb[i, 1]],
                alpha=line_transparency,
                linewidth=line_width,
                c=col_dict[labels[i]],
            )

if show_text and labels is not None:
    texts = []
    labels = np.array(labels)
    unique_labels = np.unique(labels)
    for label in unique_labels:
        idx_label = np.where(labels == label)[0]
        texts.append(
            ax.text(
                np.mean(emb[idx_label, 0]),
                np.mean(emb[idx_label, 1]),
                label,
                fontsize=fontsize,
            )
        )

    adjust_text(
        texts,
        arrowprops=dict(arrowstyle="-", color="black", lw=0.1),  # noqa: C408
        ax=ax,
    )

if axis_equal:
    ax.axis("equal")
    ax.axis("square")

if title:
    ax.set_title(title, fontsize=fontsize, fontweight="bold")

ax.set_xlabel("dim1", fontsize=fontsize)
ax.set_ylabel("dim2", fontsize=fontsize)
ax.xaxis.set_tick_params(labelsize=fontsize)
ax.yaxis.set_tick_params(labelsize=fontsize)
