In [1]:
%load_ext autoreload
%autoreload 2


import numpy as np
import pandas as pd
import anndata as ad
import numpy as np
import pandas as pd
import jax.numpy as jnp
import jax
# Make the first entry of jax.devices() the default device for JAX.
jax.config.update("jax_default_device", jax.devices()[0])

In [None]:
def create_synthetic_data(
    n_genes=50,
    n_control_cells=40,
    n_drugs=6,
    dosages_per_drug=3,
    cells_per_condition=50,
    n_cell_types=30,
    cell_type_embed_dim=50,
    drug_embed_dim=50,
    seed=None,
):
    """
    Creates a synthetic AnnData object with multiple dosages per drug.

    Every cell type is assigned to exactly one (randomly chosen) batch and
    contains ``n_control_cells`` control cells followed by
    ``cells_per_condition`` cells for each drug/dosage combination.
    Expression values and all embeddings are drawn from a standard normal
    distribution; the ``"control"`` drug embedding is all zeros.

    Parameters
    ----------
    n_genes : int
        Number of genes to simulate
    n_control_cells : int
        Number of control cells per cell type
    n_drugs : int
        Total number of distinct drugs
    dosages_per_drug : int
        Number of different dosages for each drug
    cells_per_condition : int
        Number of cells per drug-dosage condition
    n_cell_types : int
        Number of cell types (also the number of batches)
    cell_type_embed_dim : int
        Embedding dimension for cell types (also used for batch embeddings)
    drug_embed_dim : int
        Embedding dimension for drugs
    seed : int or None
        If given, all random draws come from ``np.random.RandomState(seed)``
        so the result is reproducible. If ``None`` (default), the global
        ``np.random`` state is used.

    Returns
    -------
    dict
        Dictionary containing all DataManager parameters: ``adata``,
        ``sample_rep``, ``control_key``, ``split_covariates``,
        ``perturbation_covariates``, ``perturbation_covariate_reps``,
        ``sample_covariates`` and ``sample_covariate_reps``.
    """
    # The np.random module and a RandomState instance both expose
    # permutation() and normal(), so one handle serves both cases.
    rng = np.random if seed is None else np.random.RandomState(seed)

    # One batch per cell type.
    # NOTE: chr(97 + i) leaves the a-z range once i >= 26, so with more than
    # 26 cell types the later names end in '{', '|', ... (still unique).
    n_batches = n_cell_types
    cell_type_names = [f"cell_line_{chr(97 + i)}" for i in range(n_cell_types)]
    batch_names = [f"batch_{i+1}" for i in range(n_batches)]

    # Cell counts
    total_conditions = n_drugs * dosages_per_drug  # conditions excluding control
    n_perturbed_per_type = total_conditions * cells_per_condition
    total_cells_per_type = n_control_cells + n_perturbed_per_type
    n_cells = n_cell_types * total_cells_per_type

    # Shuffle so the cell type -> batch assignment is random.
    batch_names = rng.permutation(batch_names)

    # The perturbed part of the layout is identical for every cell type, so
    # build it once instead of once per cell type.
    perturbed_drugs = []
    perturbed_dosages = []
    for drug_idx in range(1, n_drugs + 1):
        for dosage_idx in range(1, dosages_per_drug + 1):
            # Dosages are evenly spaced in (0, 1]: 1/k, 2/k, ..., 1.0
            dosage_value = dosage_idx / dosages_per_drug
            perturbed_drugs.extend([f"drug{drug_idx}"] * cells_per_condition)
            perturbed_dosages.extend([dosage_value] * cells_per_condition)

    type_controls = [True] * n_control_cells + [False] * n_perturbed_per_type
    type_drugs = ["control"] * n_control_cells + perturbed_drugs
    type_dosages = [0.0] * n_control_cells + perturbed_dosages

    # Observation columns, filled one cell type at a time
    cell_type_list = []
    control_list = []
    drug_list = []
    dosage_list = []
    batch_list = []
    for i, cell_type in enumerate(cell_type_names):
        cell_type_list.extend([cell_type] * total_cells_per_type)
        batch_list.extend([batch_names[i]] * total_cells_per_type)
        control_list.extend(type_controls)
        drug_list.extend(type_drugs)
        dosage_list.extend(type_dosages)

    # Generate random expression data
    X = rng.normal(size=(n_cells, n_genes))

    # Create observation DataFrame
    obs = pd.DataFrame(
        {
            "control": control_list,
            "cell_type": pd.Categorical(cell_type_list),
            "drug": pd.Categorical(drug_list),
            "dosage": dosage_list,
            "batch": pd.Categorical(batch_list),
        }
    )

    # Create AnnData object
    adata = ad.AnnData(X, obs=obs)

    # Add representations to uns (for covariate embeddings)
    adata.uns["drug"] = {
        "control": np.zeros(drug_embed_dim),
    }
    for i in range(1, n_drugs + 1):
        adata.uns["drug"][f"drug{i}"] = rng.normal(size=(drug_embed_dim,))

    adata.uns["cell_type"] = {}
    adata.uns["batch"] = {}
    for cell_type in cell_type_names:
        adata.uns["cell_type"][cell_type] = rng.normal(
            size=(cell_type_embed_dim,)
        )
    for batch in batch_names:
        # Batch embeddings share the cell-type embedding dimension.
        adata.uns["batch"][batch] = rng.normal(
            size=(cell_type_embed_dim,)
        )

    # Define parameters for DataManager
    sample_rep = "X"
    control_key = "control"
    split_covariates = ["cell_type", "batch"]

    # Every name below refers to a column created in `obs` above and, for the
    # *_reps mappings, to a key created in `adata.uns` above.
    perturbation_covariates = {"drug": ["drug"], "dosage": ["dosage"]}
    perturbation_covariate_reps = {"drug": "drug"}
    sample_covariates = ["cell_type", "batch"]
    sample_covariate_reps = {"cell_type": "cell_type", "batch": "batch"}

    # Return a dictionary with all required parameters
    return {
        "adata": adata,
        "sample_rep": sample_rep,
        "control_key": control_key,
        "split_covariates": split_covariates,
        "perturbation_covariates": perturbation_covariates,
        "perturbation_covariate_reps": perturbation_covariate_reps,
        "sample_covariates": sample_covariates,
        "sample_covariate_reps": sample_covariate_reps,
    }


In [3]:
# Generate the synthetic dataset, then pass every returned entry to
# DataManager as a keyword argument.
from cellflow.data._datamanager import DataManager

dm_args = create_synthetic_data()
dm = DataManager(**dm_args)



In [4]:
from cellflow.model._cellflow import CellFlow

# Build the CellFlow model on the AnnData object stored in dm_args.
cf = CellFlow(adata=dm_args["adata"])

In [5]:
import functools

import cellflow
import scanpy as sc
import numpy as np
import functools
from ott.solvers import utils as solver_utils
import optax
import anndata as ad

In [6]:
# Display the parameter dictionary returned by create_synthetic_data().
dm_args

{'adata': AnnData object with n_obs × n_vars = 28200 × 50
     obs: 'control', 'cell_type', 'drug', 'dosage', 'batch'
     uns: 'drug', 'cell_type', 'batch',
 'sample_rep': 'X',
 'control_key': 'control',
 'split_covariates': ['cell_type', 'batch'],
 'perturbation_covariates': {'drug': ['drug'], 'dosage': ['dosage']},
 'perturbation_covariate_reps': {'drug': 'drug'},
 'sample_covariates': ['cell_type', 'batch'],
 'sample_covariate_reps': {'cell_type': 'cell_type', 'batch': 'batch'}}

In [7]:
# Pull the covariate settings out of dm_args so later cells can use them by name.
(
    perturbation_covariates,
    perturbation_covariate_reps,
    sample_covariates,
    sample_covariate_reps,
) = (
    dm_args["perturbation_covariates"],
    dm_args["perturbation_covariate_reps"],
    dm_args["sample_covariates"],
    dm_args["sample_covariate_reps"],
)


In [8]:
# Short alias for the synthetic AnnData object; reused by the validation cells.
adata = dm_args["adata"]

In [9]:
# Register the data layout with the model: which representation to read,
# which obs column flags controls, and which covariates (with their
# embedding keys) describe perturbations, samples and splits.
cf.prepare_data(
    sample_rep="X",
    control_key="control",
    perturbation_covariates=perturbation_covariates,
    perturbation_covariate_reps=perturbation_covariate_reps,
    sample_covariates=sample_covariates,
    sample_covariate_reps=sample_covariate_reps,
    split_covariates=dm_args["split_covariates"],
)
print("Finished preparing data")

[########################################] | 100% Completed | 108.06 ms
[########################################] | 100% Completed | 104.61 ms
[########################################] | 100% Completed | 1.37 sms
Finished preparing data


In [10]:
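# Reference hyperparameters in YAML form; the next cell sets the same values in Python.
# Note: the covariate keys used there are drug/dosage/cell_type/batch rather than
# the drugs/dose/cell_line names listed below.
# condition_embedding_dim: 256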
# time_encoder_dims: [2048, 2048, 2048]
# time_encoder_dropout: 0.0
# hidden_dims: [2048, 2048, 2048]
# hidden_dropout: 0.0
# decoder_dims: [4096, 4096, 4096]
# decoder_dropout: 0.2
# pooling: "mean"
# layers_before_pool: 
#   drugs:
#     layer_type: mlp
#     dims: [1024, 1024]
#     dropout_rate: 0.5
#   dose:
#     layer_type: mlp
#     dims: [256, 256]
#     dropout_rate: 0.0
#   cell_line:
#     layer_type: mlp
#     dims: [1024, 1024]
#     dropout_rate: 0.2
# layers_after_pool:
#   layer_type: mlp
#   dims: [1024, 1024]
#   dropout_rate: 0.2
# cond_output_dropout: 0.9
# time_freqs: 1024
# flow_noise: 1.0
# learning_rate: 0.00005
# multi_steps: 20
# epsilon: 1.0
# tau_a: 1.0
# tau_b: 1.0
# flow_type: "constant_noise"
# linear_projection_before_concatenation: False
# layer_norm_before_concatenation: False

In [11]:



# Matching function handed to the model: solver_utils.match_linear with its
# keyword settings pre-bound.
match_fn = functools.partial(
    solver_utils.match_linear,
    epsilon=1.0,
    scale_cost="mean",
    tau_a=1.0,
    tau_b=1.0,
)

# Adam (learning rate 5e-5) wrapped in MultiSteps with k=20, i.e. gradients
# are accumulated over 20 mini-steps before each parameter update.
optimizer = optax.MultiSteps(optax.adam(0.00005), 20)

flow = {"constant_noise": 1.0}

# Per-covariate encoder settings; the keys match the covariate names used in
# perturbation_covariates / sample_covariates.
layers_before_pool = {
    "drug": {
        "layer_type": "mlp",
        "dims": [1024, 1024],
        "dropout_rate": 0.5,
    },
    "dosage": {
        "layer_type": "mlp",
        "dims": [256, 256],
        "dropout_rate": 0.0,
    },
    "cell_type": {
        "layer_type": "mlp",
        "dims": [1024, 1024],
        "dropout_rate": 0.2,
    },
    "batch": {
        "layer_type": "mlp",
        "dims": [1024, 1024],
        "dropout_rate": 0.2,
    },
}
layers_after_pool = {
    "layer_type": "mlp",
    "dims": [1024, 1024],
    "dropout_rate": 0.2,
}

condition_embedding_dim = 256
pooling = "mean"
time_encoder_dims = [2048, 2048, 2048]
time_encoder_dropout = 0.0
hidden_dims = [2048, 2048, 2048]
hidden_dropout = 0.0
decoder_dims = [4096, 4096, 4096]
decoder_dropout = 0.2
cond_output_dropout = 0.9
time_freqs = 1024
layer_norm_before_concatenation = False
linear_projection_before_concatenation = False

# Prepare the model
print("Preparing model...")
cf.prepare_model(
    encode_conditions=True,
    condition_embedding_dim=condition_embedding_dim,
    pooling=pooling,
    time_encoder_dims=time_encoder_dims,
    time_encoder_dropout=time_encoder_dropout,
    hidden_dims=hidden_dims,
    hidden_dropout=hidden_dropout,
    decoder_dims=decoder_dims,
    decoder_dropout=decoder_dropout,
    layers_before_pool=layers_before_pool,
    layers_after_pool=layers_after_pool,
    cond_output_dropout=cond_output_dropout,
    time_freqs=time_freqs,
    match_fn=match_fn,
    optimizer=optimizer,
    flow=flow,
    # Pass the variables defined above rather than repeating the literals,
    # so that editing them actually changes the model.
    layer_norm_before_concatenation=layer_norm_before_concatenation,
    linear_projection_before_concatenation=linear_projection_before_concatenation,
)
print("Begin training")


Preparing model...
Begin training


In [12]:

# Register a validation set named "train", limited to 10 conditions both on
# log iterations and at the end of training.
cf.prepare_validation_data(
    adata,
    name="train",
    n_conditions_on_log_iteration=10,
    n_conditions_on_train_end=10,
)

# Register a validation set named "test" with no limit set (None) on the
# number of conditions.
# NOTE(review): this passes the same `adata` object as the "train" set above,
# so it does not look like held-out data — confirm this is intended.
cf.prepare_validation_data(
    adata,
    name="test",
    n_conditions_on_log_iteration=None,
    n_conditions_on_train_end=None,
)

[########################################] | 100% Completed | 110.13 ms
[########################################] | 100% Completed | 1.27 sms
[########################################] | 100% Completed | 103.69 ms
[########################################] | 100% Completed | 1.16 sms


In [13]:
import cellflow

In [14]:
# Metrics callback configured with three metrics: r_squared, mmd and e_distance.
metrics_callback = cellflow.training.Metrics(metrics=["r_squared", "mmd", "e_distance"])

# we don't pass the wandb_callback as it requires a user-specific account, but recommend setting it up
callbacks = [metrics_callback]

In [None]:
# Smoke-test run: a single training iteration with the metrics callback attached.
# NOTE(review): valid_freq=100 is larger than num_iterations=1, so presumably
# no validation is triggered on a log iteration during this run — confirm.
cf.train(
    num_iterations=1,
    batch_size=100,
    callbacks=callbacks,
    valid_freq=100,
    prefetch_factor=3,
    num_workers=8,
)

100%|██████████| 1/1 [00:03<00:00,  3.45s/it]


In [None]:
# Debugging aid: bind the training data's DataManager to a global called
# `self`; the next cell reads its private attributes through this name.
# NOTE(review): a global named `self` is easy to misread; consider a
# descriptive name (and update the following cell) before sharing.
self = cf.train_data.data_manager

In [None]:
# Sorted, de-duplicated union of the DataManager's split covariates and
# perturbation covariate keys (both read from private attributes).
sorted(set(self._split_covariates + self._perturb_covar_keys))

['batch', 'cell_type', 'dosage', 'drug']