In [1]:
%load_ext autoreload
%autoreload 2


import numpy as np
import pandas as pd
import anndata as ad
import numpy as np
import pandas as pd
import jax.numpy as jnp
import jax
jax.config.update("jax_default_device", jax.devices()[0])

In [2]:
def _alpha_label(index):
    """Return a spreadsheet-style lowercase label: 0 -> 'a', 25 -> 'z', 26 -> 'aa', 27 -> 'ab', ..."""
    label = ""
    index += 1
    while index > 0:
        index, remainder = divmod(index - 1, 26)
        label = chr(97 + remainder) + label
    return label


def create_synthetic_data(
    n_genes=50,
    n_control_cells=40,
    n_drugs=6,
    dosages_per_drug=3,
    cells_per_condition=50,
    n_cell_types=30,
    cell_type_embed_dim=50,
    drug_embed_dim=50,
    seed=None,
):
    """
    Create a synthetic AnnData object with multiple dosages per drug, plus the
    matching DataManager keyword arguments.

    Every cell type receives ``n_control_cells`` control cells followed by
    ``cells_per_condition`` cells for each of the ``n_drugs * dosages_per_drug``
    drug/dosage combinations. Expression values and all embeddings are drawn
    from a standard normal distribution, so they carry no biological signal.

    Parameters
    ----------
    n_genes : int
        Number of genes to simulate.
    n_control_cells : int
        Number of control cells per cell type.
    n_drugs : int
        Total number of distinct drugs (named ``drug1`` .. ``drug{n_drugs}``).
    dosages_per_drug : int
        Number of different dosages for each drug.
    cells_per_condition : int
        Number of cells per (cell type, drug, dosage) condition.
    n_cell_types : int
        Number of cell types. They are named ``cell_line_a`` .. ``cell_line_z``,
        then ``cell_line_aa``, ``cell_line_ab``, ... so names stay readable and
        unique beyond 26.
    cell_type_embed_dim : int
        Embedding dimension for cell types.
    drug_embed_dim : int
        Embedding dimension for drugs.
    seed : int or None
        Seed for a dedicated ``numpy.random.Generator``, for reproducible data.
        ``None`` (default) draws from numpy's global random state.

    Returns
    -------
    dict
        DataManager keyword arguments: ``adata``, ``sample_rep``, ``control_key``,
        ``split_covariates``, ``perturbation_covariates``,
        ``perturbation_covariate_reps``, ``sample_covariates`` and
        ``sample_covariate_reps``.
    """
    # The np.random module exposes the same ``normal`` API as a Generator, so
    # seed=None keeps drawing from the global state.
    rng = np.random if seed is None else np.random.default_rng(seed)

    # chr(97 + i) alone would run past 'z' into '{', '|', '}', '~' for i >= 26.
    cell_type_names = [f"cell_line_{_alpha_label(i)}" for i in range(n_cell_types)]

    # Cell counts: the control block plus one block per drug/dosage condition.
    total_conditions = n_drugs * dosages_per_drug  # perturbed conditions only
    perturbed_cells_per_type = total_conditions * cells_per_condition
    total_cells_per_type = n_control_cells + perturbed_cells_per_type
    n_cells = n_cell_types * total_cells_per_type

    cell_type_list = []
    control_list = []
    drug_list = []
    dosage_list = []

    for cell_type in cell_type_names:
        cell_type_list.extend([cell_type] * total_cells_per_type)

        # Control cells come first within each cell type.
        control_list.extend([True] * n_control_cells)
        drug_list.extend(["control"] * n_control_cells)
        dosage_list.extend([0.0] * n_control_cells)

        # Then one block of ``cells_per_condition`` cells per drug/dosage pair.
        control_list.extend([False] * perturbed_cells_per_type)
        for drug_idx in range(1, n_drugs + 1):
            for dosage_idx in range(1, dosages_per_drug + 1):
                # Evenly spaced in (0, 1], e.g. 1/3, 2/3, 1.0 for three dosages.
                dosage_value = dosage_idx / dosages_per_drug
                drug_list.extend([f"drug{drug_idx}"] * cells_per_condition)
                dosage_list.extend([dosage_value] * cells_per_condition)

    # Random expression data.
    X = rng.normal(size=(n_cells, n_genes))

    obs = pd.DataFrame(
        {
            "control": control_list,
            "cell_type": pd.Categorical(cell_type_list),
            "drug": pd.Categorical(drug_list),
            "dosage": dosage_list,
        }
    )
    adata = ad.AnnData(X, obs=obs)

    # Covariate embeddings live in ``uns``, keyed by category value; the control
    # "drug" gets an all-zero vector.
    adata.uns["drug"] = {"control": np.zeros(drug_embed_dim)}
    for drug_idx in range(1, n_drugs + 1):
        adata.uns["drug"][f"drug{drug_idx}"] = rng.normal(size=(drug_embed_dim,))

    adata.uns["cell_type"] = {
        cell_type: rng.normal(size=(cell_type_embed_dim,))
        for cell_type in cell_type_names
    }

    return {
        "adata": adata,
        "sample_rep": "X",
        "control_key": "control",
        "split_covariates": ["cell_type"],
        # A single drug column and a single dosage column.
        "perturbation_covariates": {"drug": ["drug"], "dosage": ["dosage"]},
        "perturbation_covariate_reps": {"drug": "drug"},
        "sample_covariates": ["cell_type"],
        "sample_covariate_reps": {"cell_type": "cell_type"},
    }


In [3]:
# Generate the synthetic dataset and hand its keyword arguments straight to DataManager.
from cellflow.data._datamanager import DataManager

dm_args = create_synthetic_data()
dm = DataManager(**dm_args)



In [4]:
# Build the CellFlow model around the same synthetic AnnData (the raw AnnData is
# passed here, not the `dm` DataManager from the previous cell).
from cellflow.model._cellflow import CellFlow

cf = CellFlow(adata=dm_args["adata"])

In [6]:
import functools

import cellflow
import scanpy as sc
import numpy as np
import functools
from ott.solvers import utils as solver_utils
import optax
import anndata as ad

In [7]:
# Display the DataManager keyword arguments returned by create_synthetic_data.
dm_args

{'adata': AnnData object with n_obs × n_vars = 28200 × 50
     obs: 'control', 'cell_type', 'drug', 'dosage'
     uns: 'drug', 'cell_type',
 'sample_rep': 'X',
 'control_key': 'control',
 'split_covariates': ['cell_type'],
 'perturbation_covariates': {'drug': ['drug'], 'dosage': ['dosage']},
 'perturbation_covariate_reps': {'drug': 'drug'},
 'sample_covariates': ['cell_type'],
 'sample_covariate_reps': {'cell_type': 'cell_type'}}

In [8]:
# Unpack the covariate specification so the CellFlow calls below can pass each piece explicitly.
perturbation_covariates, perturbation_covariate_reps, sample_covariates, sample_covariate_reps = (
    dm_args[key]
    for key in (
        "perturbation_covariates",
        "perturbation_covariate_reps",
        "sample_covariates",
        "sample_covariate_reps",
    )
)


In [9]:
# Direct handle on the synthetic AnnData, reused when preparing validation data.
adata = dm_args["adata"]

In [10]:
# Register the synthetic data with CellFlow. Every argument comes from dm_args so
# this call cannot drift from what was passed to DataManager above.
cf.prepare_data(
    sample_rep=dm_args["sample_rep"],
    control_key=dm_args["control_key"],
    perturbation_covariates=perturbation_covariates,
    perturbation_covariate_reps=perturbation_covariate_reps,
    sample_covariates=sample_covariates,
    sample_covariate_reps=sample_covariate_reps,
    split_covariates=dm_args["split_covariates"],
)
print("Finished preparing data")

[########################################] | 100% Completed | 105.84 ms
[########################################] | 100% Completed | 105.85 ms
[########################################] | 100% Completed | 412.98 ms
Finished preparing data


In [11]:
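# Reference hyperparameter config (YAML) that the next cell translates into
# CellFlow arguments. Condition-encoder keys are renamed to this dataset's
# covariates (drugs -> drug, dose -> dosage, cell_line -> cell_type);
# flow_type/flow_noise become flow={"constant_noise": 1.0}, and
# learning_rate/multi_steps configure optax.MultiSteps(optax.adam(...), ...).
# condition_embedding_dim: 256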
# time_encoder_dims: [2048, 2048, 2048]
# time_encoder_dropout: 0.0
# hidden_dims: [2048, 2048, 2048]
# hidden_dropout: 0.0
# decoder_dims: [4096, 4096, 4096]
# decoder_dropout: 0.2
# pooling: "mean"
# layers_before_pool: 
#   drugs:
#     layer_type: mlp
#     dims: [1024, 1024]
#     dropout_rate: 0.5
#   dose:
#     layer_type: mlp
#     dims: [256, 256]
#     dropout_rate: 0.0
#   cell_line:
#     layer_type: mlp
#     dims: [1024, 1024]
#     dropout_rate: 0.2
# layers_after_pool:
#   layer_type: mlp
#   dims: [1024, 1024]
#   dropout_rate: 0.2
# cond_output_dropout: 0.9
# time_freqs: 1024
# flow_noise: 1.0
# learning_rate: 0.00005
# multi_steps: 20
# epsilon: 1.0
# tau_a: 1.0
# tau_b: 1.0
# flow_type: "constant_noise"
# linear_projection_before_concatenation: False
# layer_norm_before_concatenation: False

In [12]:



# Hyperparameters mirror the commented reference config in the previous cell.
OT_EPSILON = 1.0
OT_TAU_A = 1.0
OT_TAU_B = 1.0
LEARNING_RATE = 0.00005
MULTI_STEPS = 20  # optax.MultiSteps: accumulate gradients, apply an update every MULTI_STEPS steps

# Linear OT matching function (ott `match_linear`) handed to the solver.
match_fn = functools.partial(
    solver_utils.match_linear,
    epsilon=OT_EPSILON,
    scale_cost="mean",
    tau_a=OT_TAU_A,
    tau_b=OT_TAU_B,
)
optimizer = optax.MultiSteps(optax.adam(LEARNING_RATE), MULTI_STEPS)
flow = {"constant_noise": 1.0}

# Per-covariate encoders applied before pooling; keys match this dataset's covariates.
layers_before_pool = {
    "drug": {
        "layer_type": "mlp",
        "dims": [1024, 1024],
        "dropout_rate": 0.5,
    },
    "dosage": {
        "layer_type": "mlp",
        "dims": [256, 256],
        "dropout_rate": 0.0,
    },
    "cell_type": {
        "layer_type": "mlp",
        "dims": [1024, 1024],
        "dropout_rate": 0.2,
    },
}
layers_after_pool = {
    "layer_type": "mlp",
    "dims": [1024, 1024],
    "dropout_rate": 0.2,
}
condition_embedding_dim = 256
pooling = "mean"
time_encoder_dims = [2048, 2048, 2048]
time_encoder_dropout = 0.0
hidden_dims = [2048, 2048, 2048]
hidden_dropout = 0.0
decoder_dims = [4096, 4096, 4096]
decoder_dropout = 0.2
cond_output_dropout = 0.9
time_freqs = 1024
layer_norm_before_concatenation = False
linear_projection_before_concatenation = False

print("Preparing model...")
cf.prepare_model(
    encode_conditions=True,
    condition_embedding_dim=condition_embedding_dim,
    pooling=pooling,
    time_encoder_dims=time_encoder_dims,
    time_encoder_dropout=time_encoder_dropout,
    hidden_dims=hidden_dims,
    hidden_dropout=hidden_dropout,
    decoder_dims=decoder_dims,
    decoder_dropout=decoder_dropout,
    layers_before_pool=layers_before_pool,
    layers_after_pool=layers_after_pool,
    cond_output_dropout=cond_output_dropout,
    time_freqs=time_freqs,
    match_fn=match_fn,
    optimizer=optimizer,
    flow=flow,
    # Pass the named settings defined above rather than duplicating the literals.
    layer_norm_before_concatenation=layer_norm_before_concatenation,
    linear_projection_before_concatenation=linear_projection_before_concatenation,
)
# NOTE: this cell only builds the model; training itself starts in the `cf.train` cell.
print("Begin training")


Preparing model...
Begin training


In [13]:

# Register two validation sets for CellFlow's metric callbacks.
# "train": presumably evaluates 10 conditions at each log iteration and at train end
#          (the printed output of the training cell lists 10 condition keys).
# "test":  n_conditions_* = None; the length check below reports 540 conditions,
#          i.e. presumably all of them.
# NOTE(review): both sets are built from the same full `adata` used for training, so
# "test" is not a held-out split here; use a disjoint subset for real evaluation.
cf.prepare_validation_data(
    adata,
    name="train",
    n_conditions_on_log_iteration=10,
    n_conditions_on_train_end=10,
)

cf.prepare_validation_data(
    adata,
    name="test",
    n_conditions_on_log_iteration=None,
    n_conditions_on_train_end=None,
)

[########################################] | 100% Completed | 105.77 ms
[########################################] | 100% Completed | 309.39 ms
[########################################] | 100% Completed | 104.98 ms
[########################################] | 100% Completed | 399.56 ms


In [14]:
# Number of conditions in the "test" validation set. With the defaults this is
# 30 cell types x 6 drugs x 3 dosages = 540, matching the recorded output.
# NOTE(review): reads the private `_validation_data`; later cells use the public
# `cf.validation_data`, which may be equivalent — confirm and prefer the public API.
len(cf._validation_data["test"].perturbation_idx_to_covariates)

540

In [15]:
import cellflow

In [16]:
# Evaluate R^2, MMD and energy distance on the validation sets during training.
metrics_callback = cellflow.training.Metrics(
    metrics=["r_squared", "mmd", "e_distance"],
)

# A Weights & Biases callback is deliberately omitted because it needs a user-specific
# account; adding one is recommended for real runs.
callbacks = [metrics_callback]

In [17]:
# Check which device one condition-encoder parameter lives on (the recorded output shows CPU).
cf.solver.vf_state.params['condition_encoder']['after_pool_modules_mean_0']['kernel'].devices()

{CpuDevice(id=0)}

In [23]:
# Host-side sampler over CellFlow's prepared training data; batch_size=100 matches the
# 100-row 'src_cell_data' array in the sampled batch shown further down.
from cellflow.data._dataloader import CpuTrainSampler

dataloader = CpuTrainSampler(
    data=cf.train_data,
    batch_size=100,
)

In [65]:
import queue
import threading
import time  # Add this import at the top of the file
from collections.abc import Sequence
from typing import Any, Literal
import jax.numpy as jnp
import jax
import numpy as np
from numpy.typing import ArrayLike
from tqdm import tqdm

from cellflow.data._dataloader import CpuTrainSampler, ValidationSampler
from cellflow.solvers import _genot, _otfm
from cellflow.training._callbacks import BaseCallback, CallbackRunner


def prefetch_to_device(sampler, num_iterations, prefetch_factor=2, num_workers=4):
    seed = 42  # Set a fixed seed for reproducibility
    seq = np.random.SeedSequence(seed)
    random_generators = [np.random.default_rng(s) for s in seq.spawn(num_workers)]

    q = queue.Queue(maxsize=prefetch_factor*num_workers)
    sem = threading.Semaphore(num_iterations)
    stop_event = threading.Event()
    def worker(rng):
        while not stop_event.is_set() and sem.acquire(blocking=False):
            batch = sampler.sample(rng)
            batch = jax.device_put(batch, jax.devices()[0], donate=True)
            jax.block_until_ready(batch)
            while not stop_event.is_set():
                try:
                    q.put(batch, timeout=1.0)
                    break  # Batch successfully put into the queue; break out of retry loop
                except queue.Full:
                    continue

        return

    # Start multiple worker threads
    ts = []
    for i in range(num_workers):
        t = threading.Thread(target=worker, daemon=True, name=f"worker-{i}", args=(random_generators[i], ))
        t.start()
        ts.append(t)

    try:
        for _ in range(num_iterations):
            # Yield batches from the queue; will block waiting for available batch
            yield q.get()
    finally:
        # When the generator is closed or garbage collected, clean up the worker threads
        stop_event.set()  # Signal all workers to exit
        for t in ts:
            t.join()  # Wait for all worker threads to finish




In [None]:
# prefetch_to_device is a generator: calling it alone starts no threads and samples
# nothing. Drain it so the workers actually run, and report how many batches arrived.
n_prefetched = sum(
    1 for _ in prefetch_to_device(dataloader, num_iterations=100, prefetch_factor=2, num_workers=32)
)
n_prefetched

In [38]:
seed = 1
# Generator seeded from a SeedSequence offset by 10000 from `seed`.
rng = np.random.default_rng(seed=np.random.SeedSequence(10000 + seed))

In [64]:
# Draw a single training batch with the seeded generator. The recorded output is a dict
# with a (100, 50) 'src_cell_data' array and a 'tgt_cell_data' array.
dataloader.sample(rng)

{'src_cell_data': array([[ 2.1629744 ,  0.2553238 , -0.37207523, ..., -0.5949909 ,
          1.7239684 ,  0.58037424],
        [-0.21306144,  1.1288933 , -0.28751445, ..., -0.435637  ,
         -1.5042096 ,  1.4034455 ],
        [-0.21306144,  1.1288933 , -0.28751445, ..., -0.435637  ,
         -1.5042096 ,  1.4034455 ],
        ...,
        [-0.41723475, -0.1516465 , -1.0242379 , ..., -0.55732745,
          0.2977195 ,  0.6932218 ],
        [-0.2725472 ,  0.29985598,  0.515528  , ..., -0.830595  ,
         -0.44936737, -0.1041475 ],
        [-0.6362564 , -0.7889936 ,  0.09297998, ..., -0.0373694 ,
         -0.7823192 , -0.21801941]], shape=(100, 50), dtype=float32),
 'tgt_cell_data': array([[ 0.25597256, -0.09510171, -0.51105815, ...,  0.09576876,
          1.027608  , -1.3337334 ],
        [ 0.70467055, -0.7783663 ,  1.755893  , ...,  0.26792753,
          2.07906   ,  0.3890477 ],
        [ 1.0986998 , -1.2486377 , -0.17665577, ...,  0.3352405 ,
          0.36951387, -0.71817696],
 

In [20]:
# Wrap every registered validation set in a ValidationSampler, keyed by set name.
from cellflow.data._dataloader import ValidationSampler

validation_loaders = dict(
    (name, ValidationSampler(data)) for name, data in cf.validation_data.items()
)

In [22]:
# Inspect the sampler built for the "train" validation set.
validation_loaders['train']

<cellflow.data._dataloader.ValidationSampler at 0x33f4254d0>

In [49]:
# Short smoke-test training run (10 iterations).
# NOTE(review): valid_freq=100 exceeds num_iterations=10, so presumably no periodic
# validation happens mid-run; only the end-of-training callbacks appear in the output.
# Lower valid_freq if mid-run metrics are wanted.
# NOTE(review): batch_size=1024 is far larger than the 40 control cells per cell type.
# The repeated rows in the sampled batch above suggest sampling with replacement — confirm.
cf.train(
    num_iterations=10,
    batch_size=1024,
    callbacks=callbacks,
    valid_freq=100,
    prefetch_factor=3,
    num_workers=8,
)

100%|██████████| 10/10 [00:01<00:00,  5.57it/s]


HERE IT IS DONE
dict_keys([('drug1', 0.3333333333333333, 'cell_line_i'), ('drug1', 0.6666666666666666, 'cell_line_z'), ('drug1', 1.0, 'cell_line_j'), ('drug1', 1.0, 'cell_line_p'), ('drug2', 0.3333333333333333, 'cell_line_b'), ('drug2', 0.6666666666666666, 'cell_line_c'), ('drug2', 0.6666666666666666, 'cell_line_f'), ('drug3', 0.6666666666666666, 'cell_line_y'), ('drug3', 1.0, 'cell_line_a'), ('drug5', 1.0, 'cell_line_s')])
dict_keys([('drug1', 0.3333333333333333, 'cell_line_i'), ('drug1', 0.6666666666666666, 'cell_line_z'), ('drug1', 1.0, 'cell_line_j'), ('drug1', 1.0, 'cell_line_p'), ('drug2', 0.3333333333333333, 'cell_line_b'), ('drug2', 0.6666666666666666, 'cell_line_c'), ('drug2', 0.6666666666666666, 'cell_line_f'), ('drug3', 0.6666666666666666, 'cell_line_y'), ('drug3', 1.0, 'cell_line_a'), ('drug5', 1.0, 'cell_line_s')])
(400, 50)
(1, 1010)
dict_keys([('drug1', 0.3333333333333333, 'cell_line_a'), ('drug1', 0.3333333333333333, 'cell_line_b'), ('drug1', 0.3333333333333333, 'cell_l