In [37]:
import cfp
import scanpy as sc
import numpy as np
import pandas as pd
import anndata as ad
import os 
import pickle
import yaml
import functools
from ott.solvers import utils as solver_utils
import optax

In [29]:
# Load the model and dataset settings for the zebrafish run from their YAML files.
# The shared conf directory is named once so both paths stay in sync, and the files
# are opened as UTF-8 explicitly rather than with the locale-dependent default.
_CONF_DIR = '/home/icb/alejandro.tejada/ot_pert_reproducibility/runs_otfm/conf'

with open(f'{_CONF_DIR}/model/zebrafish.yaml', 'r', encoding='utf-8') as file:
    model = yaml.safe_load(file)

with open(f'{_CONF_DIR}/dataset/zebrafish.yaml', 'r', encoding='utf-8') as file:
    dataset = yaml.safe_load(file)

In [2]:
def prepare_data(adata_train, adata_test, adata_ood, path_dict):
    """Annotate the three splits with control flags, log-timepoints and a shared ``uns`` entry.

    For each of the three inputs, two ``obs`` columns are written:

    * ``control`` -- ``True`` for rows whose ``gene1+gene2`` equals
      ``'negative+negative'`` and whose ``timepoint`` equals the string ``'18'``,
      ``False`` for every other row.
    * ``logtimepoint`` -- natural logarithm of ``timepoint`` after conversion with
      ``pd.to_numeric``.

    The control rows of ``adata_test`` are then appended to ``adata_ood`` via
    ``ad.concat``, and the object unpickled from ``path_dict`` is stored under
    ``uns['crispr_rep']`` of all three returned objects.

    Parameters
    ----------
    adata_train, adata_test, adata_ood
        Objects exposing ``obs`` (with ``timepoint`` and ``gene1+gene2`` columns)
        and ``uns``. Their ``obs`` is modified in place; ``adata_train`` and
        ``adata_test`` also receive the ``uns`` entry in place.
    path_dict
        Path of the pickle file whose content is stored as ``uns['crispr_rep']``.

    Returns
    -------
    tuple
        ``(adata_train, adata_test, adata_ood)`` where ``adata_ood`` is the new
        object returned by ``ad.concat``.
    """
    # Identical annotation for every split; done in a loop instead of three copies.
    for adata in (adata_train, adata_test, adata_ood):
        adata.obs['control'] = False
        adata.obs['logtimepoint'] = np.log(pd.to_numeric(adata.obs['timepoint']))

        # `timepoint` is compared as the string '18', exactly as stored in `obs`.
        is_control = (
            (adata.obs['gene1+gene2'] == 'negative+negative')
            & (adata.obs['timepoint'] == '18')
        )
        adata.obs.loc[is_control, 'control'] = True

    # Add the test-split controls to the OOD split.
    # NOTE(review): with its default `merge=None`, `ad.concat` is documented not to
    # keep var-aligned fields such as `varm` -- confirm nothing downstream needs
    # `adata_ood.varm`.
    # NOTE(review): the notebook run shows anndata's duplicate-`obs`-names warning
    # after this step; consider `obs_names_make_unique()` if unique names matter.
    adata_ood = ad.concat((adata_ood, adata_test[adata_test.obs['control']]))

    # SECURITY: `pickle.load` can execute arbitrary code; only pass trusted files.
    with open(path_dict, 'rb') as file:
        gene_dict = pickle.load(file)

    # The same unpickled object is shared by all three splits.
    for adata in (adata_train, adata_test, adata_ood):
        adata.uns['crispr_rep'] = gene_dict

    return adata_train, adata_test, adata_ood

In [3]:
# Directory that holds the three zebrafish `.h5ad` splits and the embeddings pickle.
_ZEBRAFISH_DIR = "/lustre/groups/ml01/workspace/ot_perturbation/data/zebrafish"

adata_train_path = f"{_ZEBRAFISH_DIR}/adata_train.h5ad"
adata_test_path = f"{_ZEBRAFISH_DIR}/adata_test.h5ad"
adata_ood_path = f"{_ZEBRAFISH_DIR}/adata_ood.h5ad"
# Despite its name, `embedding_dict` is the *path* of the pickle file, not a dict.
embedding_dict = f"{_ZEBRAFISH_DIR}/ESM2_embeddings.pkl"

In [None]:
# Read the three splits from disk.
adata_train = sc.read_h5ad(adata_train_path)
adata_test = sc.read_h5ad(adata_test_path)
adata_ood = sc.read_h5ad(adata_ood_path)

# Apply the same clean-up to every split.
for _split in (adata_train, adata_test, adata_ood):
    # Drop the stored `emb_1` / `emb_2` entries from `obsm`.
    del _split.obsm['emb_1'], _split.obsm['emb_2']

    # Re-key the `X_train_mean` entry of `varm` as `X_mean`.
    _split.varm["X_mean"] = _split.varm["X_train_mean"]
    del _split.varm["X_train_mean"]
    

In [9]:
# Annotate the splits; `embedding_dict` is the path of the embeddings pickle.
# NOTE: this cell is not idempotent -- `prepare_data` appends the test controls to
# `adata_ood`, so re-running it without reloading the splits appends them again.
adata_train, adata_test, adata_ood = prepare_data(
    adata_train, adata_test, adata_ood, embedding_dict
)

# Expose the PCA coordinates under the `X_pca_use` key used as `sample_rep` below.
for _split in (adata_train, adata_test, adata_ood):
    _split.obsm["X_pca_use"] = _split.obsm["X_pca"]


  utils.warn_names_duplicates("obs")


In [21]:
# Build the CellFlow model on the training split, selecting the "otfm" solver.
cf = cfp.model.CellFlow(adata_train, solver="otfm")


In [30]:
# Prepare the training data and perturbation conditions.
# Every covariate group from the dataset config is turned into a tuple of its
# members (gene1, gene2 per the original note).
perturbation_covariates = {
    group: tuple(members)
    for group, members in dataset["perturbation_covariates"].items()
}

cf.prepare_data(
    sample_rep="X_pca_use",
    control_key="control",
    perturbation_covariates=perturbation_covariates,
    perturbation_covariate_reps=dict(dataset["perturbation_covariate_reps"]),
)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  _covariate_data["cell_index"] = _covariate_data.index
100%|██████████| 21/21 [00:03<00:00,  5.93it/s]


In [38]:
# Pre-bind the matching settings on `solver_utils.match_linear`: `epsilon`, `tau_a`
# and `tau_b` come from the model config, and the cost is scaled by its mean.
match_fn = functools.partial(
    solver_utils.match_linear,
    scale_cost="mean",
    **{name: model[name] for name in ("epsilon", "tau_a", "tau_b")},
)

In [39]:
# Adam with the configured learning rate, wrapped in `optax.MultiSteps` with the
# configured `multi_steps` value as its step schedule.
base_optimizer = optax.adam(model["learning_rate"])
optimizer = optax.MultiSteps(base_optimizer, every_k_schedule=model["multi_steps"])

# Single-entry mapping from the configured flow type to its noise setting.
flow = {model["flow_type"]: model["flow_noise"]}

# Pulled out of the config as named variables for the `prepare_model` call.
layers_before_pool, layers_after_pool = (
    model["layers_before_pool"],
    model["layers_after_pool"],
)

In [41]:
# Prepare the model.
# These hyperparameters are forwarded unchanged from the model config under the
# same keyword name.
_config_keys = (
    "condition_embedding_dim",
    "time_encoder_dims",
    "time_encoder_dropout",
    "hidden_dims",
    "hidden_dropout",
    "decoder_dims",
    "decoder_dropout",
    "pooling",
    "cond_output_dropout",
    "time_freqs",
)

cf.prepare_model(
    encode_conditions=True,
    layers_before_pool=layers_before_pool,
    layers_after_pool=layers_after_pool,
    match_fn=match_fn,
    optimizer=optimizer,
    flow=flow,
    **{key: model[key] for key in _config_keys},
)

In [46]:
# Checkpoint file passed to `cf.load` in the next cell.
ckpt = "/lustre/groups/ml01/workspace/cell_flow_perturbation/zebrafish/CellFlow.pkl"

In [73]:
# Rebind `cf` to whatever `load` returns for the checkpoint.
# NOTE(review): this replaces the `cf` configured in the preceding cells --
# presumably the checkpoint carries its own prepared/trained state; confirm.
# NOTE(review): the `.pkl` extension suggests a pickle; only load trusted files.
cf = cf.load(ckpt)

In [63]:
# Select the training cells that are 'negative+negative' at timepoint '18'.
_is_negative_pair = adata_train.obs['gene1+gene2'] == 'negative+negative'
_is_timepoint_18 = adata_train.obs.timepoint == '18'
adata_ctrl = adata_train[_is_negative_pair & _is_timepoint_18]

In [84]:
# One condition row per control cell; every row carries the same target condition.
m = adata_ctrl.n_obs

# NOTE(review): 3.871201 is approximately ln(48) -- presumably timepoint 48 on the
# `logtimepoint` scale; confirm.
_target_condition = {
    'gene1': 'negative',
    'gene2': 'negative',
    'logtimepoint': 3.871201,
}

test_condition_df = pd.DataFrame(
    {column: [value] * m for column, value in _target_condition.items()}
)


In [85]:
# Make predictions: run the model on the control cells, with `test_condition_df`
# supplying the covariates and "X_pca_use" naming the cell representation.
# The result is indexed by condition tuple in the cells below.
X_pca_pred = cf.predict(
    adata_ctrl,
    covariate_data=test_condition_df,
    sample_rep="X_pca_use",
)

100%|██████████| 1/1 [00:00<00:00, 232.00it/s]


In [86]:
# List the condition keys of the prediction mapping; the values are not needed
# here, so iterate over the keys directly instead of unpacking `.items()`.
for key in X_pca_pred:
    print(key)

('negative', 'negative')


In [82]:
# Show the predicted array for the ('negative', 'negative') condition.
X_pca_pred[('negative', 'negative')]

array([[-1.38737932e-02, -6.97984397e-02, -6.22924685e-01, ...,
         3.24704319e-01, -2.05856562e-01,  4.68166731e-02],
       [ 2.96986639e-01, -1.32533297e-01, -7.01308668e-01, ...,
         1.09518096e-01,  3.01962703e-01, -3.95097733e-01],
       [ 9.71315444e-01, -3.21484655e-01, -4.55359876e-01, ...,
         2.24381387e-01, -2.26696627e-03,  1.69605568e-01],
       ...,
       [ 1.03912401e+00, -5.24812639e-01,  2.18347326e-01, ...,
        -1.00334153e-01, -5.24305217e-02, -7.42635578e-02],
       [-7.29464531e-01, -1.31363618e+00,  7.38438249e-01, ...,
         1.37648329e-01,  1.00623734e-01, -2.88038373e-01],
       [-2.12585020e+00, -2.45961952e+00,  3.79014421e+00, ...,
         2.32681543e-01, -6.30738679e-03,  8.80552903e-02]], dtype=float32)

In [87]:
# Same lookup as the previous cell, re-executed; the displayed values are identical.
X_pca_pred[('negative', 'negative')]

array([[-1.38737932e-02, -6.97984397e-02, -6.22924685e-01, ...,
         3.24704319e-01, -2.05856562e-01,  4.68166731e-02],
       [ 2.96986639e-01, -1.32533297e-01, -7.01308668e-01, ...,
         1.09518096e-01,  3.01962703e-01, -3.95097733e-01],
       [ 9.71315444e-01, -3.21484655e-01, -4.55359876e-01, ...,
         2.24381387e-01, -2.26696627e-03,  1.69605568e-01],
       ...,
       [ 1.03912401e+00, -5.24812639e-01,  2.18347326e-01, ...,
        -1.00334153e-01, -5.24305217e-02, -7.42635578e-02],
       [-7.29464531e-01, -1.31363618e+00,  7.38438249e-01, ...,
         1.37648329e-01,  1.00623734e-01, -2.88038373e-01],
       [-2.12585020e+00, -2.45961952e+00,  3.79014421e+00, ...,
         2.32681543e-01, -6.30738679e-03,  8.80552903e-02]], dtype=float32)

In [3]:
# Read one parquet file of tokens into a DataFrame.
# NOTE(review): this cell and the ones after it have restarted execution counts and
# read from a different project directory -- they look unrelated to the CellFlow
# workflow above; confirm whether they belong in this notebook.
pq = pd.read_parquet('/lustre/groups/ml01/projects/2023_nicheformer/attention/male/train/tokens-0.parquet')

In [2]:
import pandas as pd

In [4]:
# Display the loaded frame; the rendered output elides the middle rows.
pq

Unnamed: 0,assay,specie,modality,CCF_acronym,X
0,7,6,4,LPO,"[12262, 11729, 10645, 13432, 8170, 8616, 3194,..."
1,7,6,4,OT,"[12262, 5037, 13961, 3483, 13262, 16091, 8286,..."
2,7,6,4,,"[630, 12434, 16016, 7885, 7554, 4601, 13961, 1..."
3,7,6,4,root,"[10073, 45, 12262, 7915, 394, 6009, 9627, 1234..."
4,7,6,4,MPO,"[13961, 9627, 10073, 6009, 394, 10146, 12738, ..."
...,...,...,...,...,...
8316,7,6,4,ADP,"[9213, 1669, 5610, 10889, 686, 14879, 5543, 48..."
8317,7,6,4,root,"[5543, 1669, 12464, 10573, 3457, 8532, 9239, 1..."
8318,7,6,4,aco,"[12262, 5543, 394, 6009, 362, 13944, 15011, 67..."
8319,7,6,4,SI,"[5543, 9213, 12235, 3194, 1332, 13111, 12738, ..."
