In [1]:
%load_ext autoreload
%autoreload 2

In [17]:
import itertools

import anndata as ad
import numpy as np
import pandas as pd

from scaleflow.data import DataManager
from scaleflow.data._anndata_location import AnnDataLocation
from scaleflow.data._data_splitter import DataSplitter

def adata_test(seed: int = 0):
    """Build a small synthetic AnnData covering every perturbation combination.

    One observation is created per combination of drug, gene, cell line,
    batch, plate and day (drug varies slowest, day fastest). Non-control
    drugs are additionally expanded over three doses; the ``"control"`` drug
    gets a single row with ``dose=0.0``. The boolean ``obs["control"]`` flag
    is True only where both ``drug`` and ``gene`` are ``"control"``.

    Parameters
    ----------
    seed
        Seed for the random ``X`` matrix so that Restart & Run All
        reproduces identical data.

    Returns
    -------
    anndata.AnnData
        ``X`` is random float32 of shape ``(n_obs, 20)``. ``obsm["X_pca"]`` is
        float32 zeros of shape ``(n_obs, 10)``, except column 0, which holds
        each cell's row index so cells can be traced through downstream
        processing. ``uns`` holds embedding dicts for cell lines, drugs and
        genes. obs names are the string row positions ``"0", "1", ...``.
    """
    drugs = ["control", "drug_A", "drug_B"]
    genes = ["control", "gene_A", "gene_B"]
    cell_lines = ["cell_line_A", "cell_line_B"]
    batches = ["batch_1", "batch_2", "batch_3"]
    plates = ["plate_1", "plate_2", "plate_3"]
    days = ["day_1", "day_2", "day_3"]
    doses = [1.0, 10.0, 100.0]

    rows = []
    # product() yields tuples in the same order as the equivalent nested
    # loops (drug outermost, day innermost), so row order is unchanged.
    for drug, gene, cell_line, batch, plate, day in itertools.product(
        drugs, genes, cell_lines, batches, plates, days
    ):
        base = {
            "drug": drug,
            "gene": gene,
            "cell_line": cell_line,
            "batch": batch,
            "plate": plate,
            "day": day,
        }
        if drug == "control":
            # Undosed row; it is a true control only if the gene is also "control".
            rows.append({**base, "dose": 0.0, "control": gene == "control"})
        else:
            rows.extend({**base, "dose": dose, "control": False} for dose in doses)

    n_obs = len(rows)
    n_vars = 20
    n_pca = 10

    obs = pd.DataFrame(rows)

    # Convert to categorical
    for col in ["cell_line", "drug", "gene", "batch", "plate", "day"]:
        obs[col] = obs[col].astype("category")

    # AnnData converts a non-string obs index to str and emits an
    # ImplicitModificationWarning ("Transforming to str index."); do the
    # conversion explicitly so construction is warning-free.
    obs.index = obs.index.astype(str)

    # Simple X matrix (not really used in tests, just needs to exist); seeded
    # so repeated runs are reproducible.
    rng = np.random.default_rng(seed)
    X = rng.standard_normal((n_obs, n_vars), dtype=np.float32)

    # X_pca: put the cell index in column 0 for easy tracing (float32 is exact
    # for integers up to 2**24, far above n_obs here).
    X_pca = np.zeros((n_obs, n_pca), dtype=np.float32)
    X_pca[:, 0] = np.arange(n_obs, dtype=np.float32)

    adata = ad.AnnData(X=X, obs=obs)
    adata.obsm["X_pca"] = X_pca

    # Simple embeddings keyed by category label
    adata.uns["cell_line_embeddings"] = {
        "cell_line_A": np.array([1.0, 0.0], dtype=np.float32),
        "cell_line_B": np.array([0.0, 1.0], dtype=np.float32),
    }

    adata.uns["drug_embeddings"] = {
        "drug_A": np.array([1.0, 0.0, 0.0], dtype=np.float32),
        "drug_B": np.array([0.0, 1.0, 0.0], dtype=np.float32),
        "control": np.array([0.0, 0.0, 0.0], dtype=np.float32),
    }

    adata.uns["gene_embeddings"] = {
        "gene_A": np.array([1.0, 0.0], dtype=np.float32),
        "gene_B": np.array([0.0, 1.0], dtype=np.float32),
        "control": np.array([0.0, 0.0], dtype=np.float32),
    }

    return adata

In [18]:
adata = adata_test()

# Sanity-check the synthetic design:
#   control drug: 3 genes x 2 cell lines x 27 (batch/plate/day)           = 162 rows
#   dosed drugs:  2 drugs x 3 genes x 2 cell lines x 27 x 3 doses         = 972 rows
assert adata.n_obs == 162 + 972
# Only rows with drug == gene == "control" are flagged: 2 cell lines x 27 = 54.
assert int(adata.obs["control"].sum()) == 54
# Column 0 of X_pca encodes each cell's row index for tracing.
assert np.array_equal(adata.obsm["X_pca"][:, 0], np.arange(adata.n_obs))

  return dispatch(args[0].__class__)(*args, **kw)


In [19]:
# Configure the DataManager for the synthetic AnnData:
#   - dist_flag_key: the boolean obs["control"] column built by adata_test
#   - source distributions keyed by cell line, targets by (drug, gene)
#   - rep_keys: each key maps to the adata.uns embedding dict of that name
#   - data_location: cell representations taken from adata.obsm["X_pca"]
adl = AnnDataLocation()

dm = DataManager(
    dist_flag_key="control",
    src_dist_keys=["cell_line"],
    tgt_dist_keys=["drug", "gene"],
    rep_keys={
        "cell_line": "cell_line_embeddings",
        "drug": "drug_embeddings",
        "gene": "gene_embeddings",
    },
    data_location=adl.obsm["X_pca"],
)

In [20]:
# Run the configured DataManager on the synthetic AnnData; later cells read
# the resulting `gd.annotation`.
gd = dm.prepare_data(
    adata,
    verbose=False,
)

In [21]:
# Split the distributions of the single dataset named "test" into
# train/val/test with ratios 0.7 / 0.15 / 0.15 (one ratio list per dataset).
splitter = DataSplitter(
    annotations=[gd.annotation],
    dataset_names=["test"],
    split_ratios=[[0.7, 0.15, 0.15]],
)

In [22]:
# Run the split for every dataset registered on the splitter; the result is
# indexed by dataset name (e.g. res["test"]).
res = splitter.split_all()

In [30]:
res["test"]["metadata"]

{'split_type': 'random',
 'split_key': None,
 'split_ratios': [0.7, 0.15, 0.15],
 'random_state': 42,
 'test_random_state': 42,
 'val_random_state': 42,
 'hard_test_split': True,
 'train_distributions': 11,
 'val_distributions': 2,
 'test_distributions': 3}

In [8]:
# Display the source/target distribution table stored on the prepared annotation.
gd.annotation.src_tgt_dist_df

Unnamed: 0,src_dist_idx,tgt_dist_idx,cell_line,drug,gene
10,0,1,cell_line_A,control,gene_A
5,0,2,cell_line_A,drug_A,control
16,0,3,cell_line_A,drug_A,gene_A
15,1,5,cell_line_B,control,gene_B
12,1,6,cell_line_B,drug_B,control
18,1,7,cell_line_B,drug_B,gene_B
