In [1]:
# Enable IPython's autoreload extension in mode 2 so imported modules are
# re-imported before each cell runs (useful while editing the library under test).
%load_ext autoreload
%autoreload 2

In [2]:
import anndata as ad
import numpy as np
import pandas as pd

from scaleflow.data import DataManager
from scaleflow.data._anndata_location import AnnDataLocation


def create_test_adata(seed: int = 0) -> ad.AnnData:
    """
    Create test AnnData with simple, traceable values for testing data splitting.

    Key design:
    - X_pca: cell index embedded at position [idx, 0] for easy tracing (cell 0 has value 0, cell 1 has value 1, etc.)
    - Simple names: cell_line_A, drug_A, gene_A, etc.
    - Multiple metadata columns (batch, plate, day) for testing different split strategies
    - Known perturbation combinations

    Parameters
    ----------
    seed
        Seed for the random generator that fills ``X``. ``X`` only needs to
        exist, but seeding it keeps the fixture identical across kernel
        restarts so the notebook is reproducible.

    Returns
    -------
    AnnData with 19 observations and 20 variables, an ``obsm["X_pca"]`` array of
    shape (19, 10), and three embedding dictionaries in ``uns``
    (``cell_line_embeddings``, ``drug_embeddings``, ``gene_embeddings``).
    """
    columns = ["control", "cell_line", "drug", "gene", "dose", "batch", "plate", "day"]

    # One tuple per cell, in `columns` order. Row position == cell index.
    records = [
        # Controls - cell_line_A
        (True, "cell_line_A", "control", "control", 0.0, "batch_1", "plate_1", "day_1"),
        (True, "cell_line_A", "control", "control", 0.0, "batch_1", "plate_1", "day_1"),
        (True, "cell_line_A", "control", "control", 0.0, "batch_2", "plate_2", "day_2"),
        # Controls - cell_line_B
        (True, "cell_line_B", "control", "control", 0.0, "batch_1", "plate_2", "day_1"),
        (True, "cell_line_B", "control", "control", 0.0, "batch_2", "plate_2", "day_3"),
        # cell_line_A + drug_A, low dose
        (False, "cell_line_A", "drug_A", "control", 1.0, "batch_1", "plate_1", "day_1"),
        (False, "cell_line_A", "drug_A", "control", 1.0, "batch_1", "plate_1", "day_2"),
        (False, "cell_line_A", "drug_A", "control", 1.0, "batch_2", "plate_2", "day_1"),
        # cell_line_A + drug_A, high dose
        (False, "cell_line_A", "drug_A", "control", 100.0, "batch_1", "plate_1", "day_1"),
        (False, "cell_line_A", "drug_A", "control", 100.0, "batch_2", "plate_3", "day_2"),
        # cell_line_A + gene_A knockout
        (False, "cell_line_A", "control", "gene_A", 0.0, "batch_1", "plate_1", "day_1"),
        (False, "cell_line_A", "control", "gene_A", 0.0, "batch_2", "plate_2", "day_3"),
        # cell_line_B + drug_B, mid dose
        (False, "cell_line_B", "drug_B", "control", 10.0, "batch_1", "plate_2", "day_1"),
        (False, "cell_line_B", "drug_B", "control", 10.0, "batch_1", "plate_2", "day_2"),
        (False, "cell_line_B", "drug_B", "control", 10.0, "batch_3", "plate_3", "day_1"),
        # cell_line_B + gene_B knockout
        (False, "cell_line_B", "control", "gene_B", 0.0, "batch_2", "plate_2", "day_1"),
        # Combination: cell_line_A + drug_A + gene_A
        (False, "cell_line_A", "drug_A", "gene_A", 10.0, "batch_1", "plate_1", "day_1"),
        (False, "cell_line_A", "drug_A", "gene_A", 10.0, "batch_2", "plate_2", "day_2"),
        # Combination: cell_line_B + drug_B + gene_B
        (False, "cell_line_B", "drug_B", "gene_B", 10.0, "batch_3", "plate_3", "day_1"),
    ]

    n_obs = len(records)
    n_vars = 20
    n_pca = 10

    obs = pd.DataFrame(records, columns=columns)

    # Convert the label columns to categorical; `control` stays bool and `dose` stays float.
    for col in ["cell_line", "drug", "gene", "batch", "plate", "day"]:
        obs[col] = obs[col].astype("category")

    # Simple X matrix (not really used in tests, just needs to exist).
    # A local seeded generator is used instead of the global NumPy RNG so the
    # values are reproducible and no global random state is touched.
    rng = np.random.default_rng(seed)
    X = rng.standard_normal(size=(n_obs, n_vars)).astype(np.float32)

    # X_pca: put the cell index in column 0 for easy tracing
    # (cell 0 has value 0, cell 1 has value 1, etc.); all other columns are 0.
    X_pca = np.zeros((n_obs, n_pca), dtype=np.float32)
    X_pca[:, 0] = np.arange(n_obs, dtype=np.float32)

    # Create AnnData
    adata = ad.AnnData(X=X, obs=obs)
    adata.obsm["X_pca"] = X_pca

    # Simple embeddings: one-hot vectors per label, all-zero for "control".
    adata.uns["cell_line_embeddings"] = {
        "cell_line_A": np.array([1.0, 0.0], dtype=np.float32),
        "cell_line_B": np.array([0.0, 1.0], dtype=np.float32),
    }

    adata.uns["drug_embeddings"] = {
        "drug_A": np.array([1.0, 0.0, 0.0], dtype=np.float32),
        "drug_B": np.array([0.0, 1.0, 0.0], dtype=np.float32),
        "control": np.array([0.0, 0.0, 0.0], dtype=np.float32),
    }

    adata.uns["gene_embeddings"] = {
        "gene_A": np.array([1.0, 0.0], dtype=np.float32),
        "gene_B": np.array([0.0, 1.0], dtype=np.float32),
        "control": np.array([0.0, 0.0], dtype=np.float32),
    }

    return adata


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Build the fixture once; the cells below all read from `adata`.
adata = create_test_adata()

  return dispatch(args[0].__class__)(*args, **kw)


In [4]:
# Display the per-cell metadata table (control flag, perturbation labels, dose, batch/plate/day).
adata.obs

Unnamed: 0,control,cell_line,drug,gene,dose,batch,plate,day
0,True,cell_line_A,control,control,0.0,batch_1,plate_1,day_1
1,True,cell_line_A,control,control,0.0,batch_1,plate_1,day_1
2,True,cell_line_A,control,control,0.0,batch_2,plate_2,day_2
3,True,cell_line_B,control,control,0.0,batch_1,plate_2,day_1
4,True,cell_line_B,control,control,0.0,batch_2,plate_2,day_3
5,False,cell_line_A,drug_A,control,1.0,batch_1,plate_1,day_1
6,False,cell_line_A,drug_A,control,1.0,batch_1,plate_1,day_2
7,False,cell_line_A,drug_A,control,1.0,batch_2,plate_2,day_1
8,False,cell_line_A,drug_A,control,100.0,batch_1,plate_1,day_1
9,False,cell_line_A,drug_A,control,100.0,batch_2,plate_3,day_2


In [5]:
"""Test that prepare_data works and returns correct structure."""
# Location handle pointing at adata.obsm["X_pca"], passed to DataManager as `data_location` below.
adl = AnnDataLocation()

# NOTE(review): argument meanings below are inferred from names and from the
# output of the following cells, not from the DataManager source — confirm.
dm = DataManager(
    # obs column holding the boolean control flag
    dist_flag_key="control",
    # obs column(s) presumably defining source distributions (src_dist_idx follows cell_line in the outputs below)
    src_dist_keys=["cell_line"],
    # obs column(s) presumably defining target distributions (tgt_dist_idx follows (drug, gene) in the outputs below)
    tgt_dist_keys=["drug", "gene"],
    # maps each obs column to the adata.uns key that stores its embedding dictionary
    rep_keys={
        "cell_line": "cell_line_embeddings",
        "drug": "drug_embeddings",
        "gene": "gene_embeddings",
    },
    data_location=adl.obsm["X_pca"],
)

In [6]:
# Run the DataManager on the fixture; the result is inspected in the cells below.
gd = dm.prepare_data(adata, verbose=False)

    control    cell_line     drug     gene  src_dist_idx  tgt_dist_idx
10    False  cell_line_A  control   gene_A             0             1
11    False  cell_line_A  control   gene_A             0             1
5     False  cell_line_A   drug_A  control             0             2
6     False  cell_line_A   drug_A  control             0             2
7     False  cell_line_A   drug_A  control             0             2
8     False  cell_line_A   drug_A  control             0             2
9     False  cell_line_A   drug_A  control             0             2
16    False  cell_line_A   drug_A   gene_A             0             3
17    False  cell_line_A   drug_A   gene_A             0             3
15    False  cell_line_B  control   gene_B             1             5
12    False  cell_line_B   drug_B  control             1             6
13    False  cell_line_B   drug_B  control             1             6
14    False  cell_line_B   drug_B  control             1             6
18    

In [48]:
# Cross-check: number the (drug, gene) groups directly with pandas over ALL cells,
# controls included, to compare against the tgt_dist_idx values printed by prepare_data.
# NOTE(review): the numbering here (0-6) does not line up with the printed
# tgt_dist_idx values (1, 2, 3, 5, 6, 7) — confirm whether that is intended.
adata.obs.groupby(["drug", "gene"], observed=True, sort=True).ngroup()

0     0
1     0
2     0
3     0
4     0
5     3
6     3
7     3
8     3
9     3
10    1
11    1
12    5
13    5
14    5
15    2
16    4
17    4
18    6
dtype: int64

In [37]:
# Target arrays keyed by tgt_dist_idx. Column 0 of every row is the originating
# cell index (see X_pca in create_test_adata), so group membership can be read off directly.
gd.data.tgt_data

{1: array([[10.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [11.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]], dtype=float32),
 2: array([[5., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [6., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [7., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [8., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [9., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32),
 3: array([[16.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [17.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]], dtype=float32),
 5: array([[15.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]], dtype=float32),
 6: array([[12.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [13.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [14.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]], dtype=float32),
 7: array([[18.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]], dtype=float32)}

In [38]:
# One row per (src_dist_idx, tgt_dist_idx) pair together with the cell_line / drug / gene labels that define it.
gd.annotation.src_tgt_dist_df

Unnamed: 0,src_dist_idx,tgt_dist_idx,cell_line,drug,gene
10,0,1,cell_line_A,control,gene_A
5,0,2,cell_line_A,drug_A,control
16,0,3,cell_line_A,drug_A,gene_A
15,1,5,cell_line_B,control,gene_B
12,1,6,cell_line_B,drug_B,control
18,1,7,cell_line_B,drug_B,gene_B
