In [87]:
import anndata
import numpy as np
import pandas as pd

n_obs = 500
n_vars = 50
n_pca = 10

X_data = np.random.rand(n_obs, n_vars)

my_counts = np.random.rand(n_obs, n_vars)

X_pca = np.random.rand(n_obs, n_pca)

cell_lines = np.random.choice(["cell_line_a", "cell_line_b", "cell_line_c"], n_obs)
dosages = np.random.choice([10.0, 100.0, 1000.0], n_obs)
drugs = ["drug_a", "drug_b", "drug_c"]
drug1 = np.random.choice(drugs, n_obs)
drug2 = np.random.choice(drugs, n_obs)
drug3 = np.random.choice(drugs, n_obs)


obs_data = pd.DataFrame({"cell_type": cell_lines, "dosage": dosages, "drug1": drug1, "drug2": drug2, "drug3": drug3})

# Create an AnnData object
adata = anndata.AnnData(X=X_data, obs=obs_data)

adata.uns["cell_flow_conditions"] = {}

# Add the random data to .layers and .obsm
adata.layers["my_counts"] = my_counts
adata.obsm["X_pca"] = X_pca


control_idcs = np.random.choice(n_obs, n_obs // 10, replace=False)
for col in ["drug1", "drug2", "drug3"]:
    adata.obs.loc[[str(idx) for idx in control_idcs], col] = "control"


for col in adata.obs.columns:
    adata.obs[col] = adata.obs[col].astype("category")

drug_emb = {}
for drug in adata.obs["drug1"].cat.categories:
    drug_emb[drug] = np.random.randn(5, 1)
adata.uns[UNS_KEY_CONDITIONS]["drug"] = drug_emb

for drug in adata.obs["cell_type"].cat.categories:
    drug_emb[drug] = np.random.randn(3, 1)
adata.uns[UNS_KEY_CONDITIONS]["cell_type"] = drug_emb



In [94]:
import itertools
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from typing import Any, Literal

import anndata
import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
import scipy.sparse as sp
from tqdm import tqdm

from cfp._constants import CONTROL_HELPER, UNS_KEY_CONDITIONS


@dataclass
class PerturbationData:
    cell_data: jax.Array  # (n_cells, n_features)
    split_covariates_mask: jax.Array  # (n_cells,), which cell assigned to which source distribution
    split_covariates_to_idx: dict[str, int]  # (n_sources,) dictionary explaining split_covariates_mask
    perturbation_covariates_mask: jax.Array  # (n_cells,), which cell assigned to which target distribution
    perturbation_covariates_to_idx: dict[str, int]  # (n_targets,), dictionary explaining perturbation_covariates_mask
    condition_data: jax.Array  # (n_targets,) all embeddings for conditions
    control_to_perturbation: dict[int, jax.Array]  # mapping from control idx to target distribution idcs

    @property
    def n_controls(self) -> int:
        """Returns the number of control covariate values."""
        return len(self.split_covariates_to_idx)

    @property
    def n_perturbed(self) -> int:
        """Returns the number of perturbation covariate combinations."""
        return len(self.perturbation_covariates_to_idx)

    def _format_params(self, fmt: Callable[[Any], str]) -> str:
        params = {"n_controls": self.n_controls, "n_perturbed": self.n_perturbed}
        return ", ".join(f"{name}={fmt(val)}" for name, val in params.items())

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}[{self._format_params(repr)}]"


def _get_cell_data(adata: anndata.AnnData, cell_data: Literal["X"] | dict[str, str]) -> jax.Array:
    if cell_data == "X":
        cell_data = adata.X
        if isinstance(cell_data, sp.csr_matrix):
            cell_data = jnp.asarray(cell_data.toarray())
        else:
            cell_data = jnp.asarray(cell_data)
    else:
        assert isinstance(cell_data, dict)
        assert "attr" in cell_data
        assert "key" in cell_data
        cell_data = jnp.asarray(getattr(adata, cell_data["attr"])[cell_data["key"]])
    return cell_data


def _verify_control_data(adata: anndata.AnnData, data: tuple[str, Any]):
    assert isinstance(data, tuple)
    assert len(data) == 2
    if data[0] not in adata.obs:
        raise ValueError(f"Control column {data[0]} not found in adata.obs.")
    assert data[0] in adata.obs
    assert isinstance(adata.obs[data[0]].dtype, pd.CategoricalDtype)
    if data[1] not in adata.obs[data[0]].cat.categories:
        raise ValueError(f"Control value {data[1]} not found in {data[0]}.")


def _check_shape(arr: float | np.ndarray) -> np.ndarray:
    if not hasattr(arr, "shape") or len(arr.shape) == 0:
        return np.ones((1, 1)) * arr

    if arr.ndim == 1:
        return arr[:, None]
    elif arr.ndim == 2:
        if arr.shape[0] == 1:
            return arr
        if arr.shape[1] == 1:
            return np.transpose(arr)
        raise ValueError("TODO, wrong shape.")
    elif arr.ndim > 2:
        raise ValueError("TODO. Too many dimensions.")

    raise ValueError("TODO. wrong data for embedding.")


def _get_perturbation_covariates(
    adata: anndata.AnnData,
    embedding_dict: dict[str, dict[str, np.ndarray]],
    obs_perturbation_covariates: Any,
    uns_perturbation_covariates: Any,
    max_combination_length: int,
) -> jax.Array:
    embeddings_no_combination = []
    embeddings_combinations = []
    for obs_group in obs_perturbation_covariates:
        obs_group_emb = []
        for obs_col in obs_group:
            values = list(adata.obs[obs_col].unique())
            if len(values) != 1:
                raise ValueError("Too many categories within distribution found")
            arr = jnp.asarray(adata.obs[obs_col].values[0])
            arr = _check_shape(arr)
            obs_group_emb.append(arr)
        if len(obs_group) == 1:
            embeddings_no_combination.append(obs_group_emb[0])
        else:
            embeddings_combinations.append(jnp.concatenate(obs_group_emb, axis=0))

    for uns_key, uns_group in uns_perturbation_covariates.items():
        uns_group_emb = []
        for obs_col in uns_group:
            values = list(adata.obs[obs_col].unique())
            if len(values) != 1:
                raise ValueError("Too many categories within distribution found")
            assert uns_key in embedding_dict
            assert isinstance(adata.uns[UNS_KEY_CONDITIONS][uns_key], dict)
            assert values[0] in embedding_dict[uns_key]
            arr = jnp.asarray(embedding_dict[uns_key][values[0]])
            arr = _check_shape(arr)
            uns_group_emb.append(arr)
        if len(uns_group) == 1:
            embeddings_no_combination.append(uns_group_emb[0])
        else:
            embeddings_combinations.append(jnp.concatenate(uns_group_emb, axis=0))

    conds_no_combination = jnp.tile(jnp.concatenate(embeddings_no_combination, axis=-1), (1, max_combination_length, 1))

    to_concat = []
    if len(conds_no_combination) > 0:
        to_concat.append(conds_no_combination)
    if len(embeddings_combinations) > 0:
        to_concat.append(jnp.array(embeddings_combinations))
    conds = jnp.concatenate(to_concat, axis=-1)
    return conds


def load_from_adata(
    adata: anndata.AnnData,
    cell_data: Literal["X"] | dict[str, str],
    control_data: tuple[str, Any],
    split_covariates: Sequence[str],
    obs_perturbation_covariates: Sequence[tuple[str, ...]],
    uns_perturbation_covariates: Sequence[dict[str, tuple[str, ...]]],
) -> PerturbationData:
    """Load cell data from an AnnData object.

    Args:
        adata: An :class:`~anndata.AnnData` object.
        cell_data: Where to read the cell data from. If of type :class:`dict`, the key
            "attr" should be present and the value should be an attribute of :class:`~anndata.AnnData`.
            The key `key` should be present and the value should be the key in the respective attribute
        control_data: Tuple of length 2 with first element defining the column in :class:`~anndata.AnnData`
          and second element defining the value in `adata.obs[control_data[0]]` used to define all control cells.
        split_covariates: Covariates in adata.obs to split all control cells into different control populations.
          The perturbed cells are also split according to these columns, but if an embedding for these covariates
          should be encoded in the model, the corresponding column should also be used in `obs_perturbation_covariates`
          or `uns_perturbation_covariates`.
        obs_perturbation_covariates: Tuples of covariates in adata.obs characterizing the perturbed cells (together
          with `split_covariates` and `uns_perturbation_covariates`) and encoded by the values as found in `adata.obs`. If a tuple contains more than
          one element, this is interpreted as a combination of covariates that should be treated as an unordered set.
        uns_perturbation_covariates: Dictionaries with keys in adata.uns[`UNS_KEY_CONDITION`] and values columns in adata.obs which characterize the perturbed cells (together
            with `split_covariates` and `obs_perturbation_covariates`) and encoded by the values as found in `adata.uns[`UNS_KEY_CONDITION`][uns_perturbation_covariates.keys()]`.
            If a value of the dictionary is a tuple with more than one element, this is interpreted as a combination of covariates that should be treated as an unordered set.

    Returns
    -------
        PerturbationData: Data container for the perturbation data.
    """
    # TODO(@MUCDK): add device to possibly only load to cpu
    if split_covariates is None or len(split_covariates) == 0:
        adata.obs[CONTROL_HELPER] = True
        adata.obs[CONTROL_HELPER] = adata.obs[CONTROL_HELPER].astype("category")
        split_covariates = [CONTROL_HELPER]
    _verify_control_data(adata, control_data)

    obs_combination_length = max(len(comb) for comb in obs_perturbation_covariates)
    uns_combination_length = max(len(comb) for comb in uns_perturbation_covariates.values())
    max_combination_length = max(obs_combination_length, uns_combination_length)

    if UNS_KEY_CONDITIONS not in adata.uns:
        adata.uns[UNS_KEY_CONDITIONS] = {}

    for covariate in split_covariates:
        assert covariate in adata.obs
        assert adata.obs[covariate].dtype.name == "category"

    src_dist = {covariate: adata.obs[covariate].cat.categories for covariate in split_covariates}
    tgt_dist_obs = {
        covariate: adata.obs[covariate].cat.categories for group in obs_perturbation_covariates for covariate in group
    }
    tgt_dist_uns = {
        covariate: adata.obs[covariate].cat.categories
        for emb_covariates in uns_perturbation_covariates.values()
        for covariate in emb_covariates
    }
    tgt_dist_obs.update(tgt_dist_uns)
    src_counter = 0
    tgt_counter = 0
    src_dists = list(itertools.product(*src_dist.values()))

    control_to_perturbation = {}
    cell_data = _get_cell_data(adata, cell_data)
    split_covariates_mask = np.full((adata.n_obs,), -1, dtype=jnp.int32)
    split_covariates_to_idx = {}
    perturbation_covariates_mask = np.full((adata.n_obs,), -1, dtype=jnp.int32)
    perturbation_covariates_to_idx = {}
    condition_data = []

    control_mask = (adata.obs[control_data[0]] == control_data[1]) == 1
    for src_combination in tqdm(src_dists):
        filter_dict = dict(zip(split_covariates, src_combination, strict=False))
        split_cov_mask = (adata.obs[list(filter_dict.keys())] == list(filter_dict.values())).all(axis=1)
        mask = split_cov_mask * control_mask
        if mask.sum() == 0:
            continue
        control_to_perturbation[src_counter] = []
        split_covariates_mask[mask] = src_counter
        split_covariates_to_idx[src_counter] = src_combination

        conditional_distributions = []
        for tgt_combination in itertools.product(*tgt_dist_obs.values()):
            mask = (
                (adata.obs[list(tgt_dist_obs.keys())] == list(tgt_combination)).all(axis=1)
                * (1 - control_mask)
                * split_cov_mask
            ) == 1
            if mask.sum() == 0:
                continue
            conditional_distributions.append(tgt_counter)
            perturbation_covariates_mask[mask] = tgt_counter
            perturbation_covariates_to_idx[tgt_counter] = tgt_combination
            control_to_perturbation[src_counter] = tgt_counter
            embedding = _get_perturbation_covariates(
                adata[mask],
                adata.uns[UNS_KEY_CONDITIONS],
                obs_perturbation_covariates,
                uns_perturbation_covariates,
                max_combination_length,
            )
            condition_data.append(embedding)
            tgt_counter += 1
        control_to_perturbation[src_counter] = np.array(conditional_distributions)
        src_counter += 1
    condition_data = jnp.array(condition_data)

    return PerturbationData(
        cell_data=cell_data,
        split_covariates_mask=split_covariates_mask,
        split_covariates_to_idx=split_covariates_to_idx,
        perturbation_covariates_mask=perturbation_covariates_mask,
        perturbation_covariates_to_idx=perturbation_covariates_to_idx,
        condition_data=condition_data,
        control_to_perturbation=control_to_perturbation,
    )

In [95]:
split_covariates = ["cell_type"]
control_data = ("drug1", "control")
obs_perturbation_covariates = [("dosage",)]
uns_perturbation_covariates = {"drug": ("drug1", "drug2")}

In [96]:
pdata = load_from_adata(
    adata,
    cell_data="X",
    control_data=control_data,
    split_covariates=split_covariates,
    obs_perturbation_covariates=obs_perturbation_covariates,
    uns_perturbation_covariates=uns_perturbation_covariates,
)

100%|██████████| 3/3 [00:00<00:00, 19.28it/s]


In [97]:
adata.obs["drug1"].cat.categories

Index(['control', 'drug_a', 'drug_b', 'drug_c'], dtype='object')

In [98]:
pdata

PerturbationData[n_controls=3, n_perturbed=81]

In [106]:
pdata.condition_data[0].shape

(1, 2, 6)

In [99]:
np.unique(pdata.perturbation_covariates_mask)

array([-1,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15,
       16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
       33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49,
       50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66,
       67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80],
      dtype=int32)

In [100]:
pdata.perturbation_covariates_to_idx

{0: (10.0, 'drug_a', 'drug_a'),
 1: (10.0, 'drug_a', 'drug_b'),
 2: (10.0, 'drug_a', 'drug_c'),
 3: (10.0, 'drug_b', 'drug_a'),
 4: (10.0, 'drug_b', 'drug_b'),
 5: (10.0, 'drug_b', 'drug_c'),
 6: (10.0, 'drug_c', 'drug_a'),
 7: (10.0, 'drug_c', 'drug_b'),
 8: (10.0, 'drug_c', 'drug_c'),
 9: (100.0, 'drug_a', 'drug_a'),
 10: (100.0, 'drug_a', 'drug_b'),
 11: (100.0, 'drug_a', 'drug_c'),
 12: (100.0, 'drug_b', 'drug_a'),
 13: (100.0, 'drug_b', 'drug_b'),
 14: (100.0, 'drug_b', 'drug_c'),
 15: (100.0, 'drug_c', 'drug_a'),
 16: (100.0, 'drug_c', 'drug_b'),
 17: (100.0, 'drug_c', 'drug_c'),
 18: (1000.0, 'drug_a', 'drug_a'),
 19: (1000.0, 'drug_a', 'drug_b'),
 20: (1000.0, 'drug_a', 'drug_c'),
 21: (1000.0, 'drug_b', 'drug_a'),
 22: (1000.0, 'drug_b', 'drug_b'),
 23: (1000.0, 'drug_b', 'drug_c'),
 24: (1000.0, 'drug_c', 'drug_a'),
 25: (1000.0, 'drug_c', 'drug_b'),
 26: (1000.0, 'drug_c', 'drug_c'),
 27: (10.0, 'drug_a', 'drug_a'),
 28: (10.0, 'drug_a', 'drug_b'),
 29: (10.0, 'drug_a', 'dru

In [111]:
class CFSampler:
    """Data sampler for :class:`~cfp.data.data.PerturbationData`.

    Parameters
    ----------
    data : PerturbationData
        The data object to sample from.
    batch_size : int
        The batch size.
    """

    def __init__(self, data: PerturbationData, batch_size: int = 64):
        self.data = data
        self.batch_size = batch_size
        self.n_source_dists = data.n_controls
        self.n_target_dists = data.n_perturbed

        def _sample(rng: jax.Array) -> Any:
            rng_1, rng_2, rng_3, rng_4 = jax.random.split(rng, 4)
            source_dist_idx = jax.random.randint(rng_1, [1], 0, self.n_source_dists).item()
            source_cells = self.data.cell_data[self.data.split_covariates_mask == source_dist_idx]
            source_batch = jax.random.choice(rng_2, source_cells, (self.batch_size,), replace=True)
            target_dist_idx = jax.random.randint(
                rng_3, [1], 0, self.data.control_to_perturbation[source_dist_idx].shape[0]
            ).item()
            target_cells = self.data.cell_data[self.data.perturbation_covariates_mask == target_dist_idx]
            target_batch = jax.random.choice(rng_4, target_cells, (self.batch_size,), replace=True)
            condition_batch = jnp.tile(self.data.condition_data[target_dist_idx], (self.batch_size, 1, 1))

            return {
                "src_lin": source_batch,
                "tgt_lin": target_batch,
                "src_condition": condition_batch,
            }

        self.sample = _sample

In [112]:
s = CFSampler(pdata, batch_size=64)

In [113]:
out = s.sample(jax.random.PRNGKey(42))

(1, 2, 6)


In [114]:
out["src_lin"].shape, out["tgt_lin"].shape, out["src_condition"].shape

((64, 50), (64, 50), (64, 2, 6))

Now with combinations of drugs

In [93]:
control_covariates = ["cell_type"]
control_data = ("drug1", "Vehicle")
obs_perturbation_covariates = [("drug1", "drug"), ("drug2", "drug"), ("cell_type", "cell_type")]
uns_perturbation_covariates = [["drug1", "drug2"]]

In [94]:
pdata = load_from_adata(
    adata,
    cell_data="X",
    control_data=control_data,
    split_covariates=control_covariates,
    obs_perturbation_covariates=obs_perturbation_covariates,
    uns_perturbation_covariates=uns_perturbation_covariates,
)

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]


TypeError: Value 'drug_a' with dtype <U6 is not a valid JAX array type. Only arrays of numeric types are supported by JAX.

pdata = load_from_adata(adata, cell_data="X", control_data=control_data, split_covariates = control_covariates, perturbation_covariates=perturbation_covariates, perturbation_covariate_combinations=perturbation_covariate_combinations)