# Multi-Dataset Splitting with ScaleFlow

This notebook demonstrates how to prepare and split multiple AnnData datasets at once using ScaleFlow's multi-dataset utilities.

In [1]:
import anndata as ad
import numpy as np
import pandas as pd

from scaleflow.data import (
    AnnDataLocation,
    DataManager,
    GroupedDistribution,
    ReservoirSampler,
    prepare_multiple_datasets,
    split_multiple_datasets,
    prepare_and_split_multiple_datasets,
)

  from .autonotebook import tqdm as notebook_tqdm


## 1. Create Sample Datasets

First, let's create some synthetic AnnData objects that mimic perturbation experiments.

In [2]:
def create_synthetic_adata(
    n_obs: int = 1000,
    n_vars: int = 50,
    n_pca: int = 20,
    n_drugs: int = 5,
    n_genes: int = 3,
    n_cell_lines: int = 3,
    seed: int = 42,
    emb_dim: int = 10,
) -> ad.AnnData:
    """Create a synthetic perturbation-style AnnData for demonstration.

    Every cell gets a random drug, gene perturbation and cell line. A cell is
    marked as a control only when both its drug and its gene are ``"control"``.

    Parameters
    ----------
    n_obs
        Number of cells (rows of ``X``).
    n_vars
        Number of features (columns of ``X``).
    n_pca
        Width of the synthetic ``obsm["X_pca"]`` matrix.
    n_drugs
        Number of non-control drugs; ``"control"`` is added on top.
    n_genes
        Number of non-control gene perturbations; ``"control"`` is added on top.
    n_cell_lines
        Number of cell lines.
    seed
        Seed for a generator local to this call; the global NumPy random
        state is left untouched.
    emb_dim
        Length of every condition embedding vector stored in ``.uns``.

    Returns
    -------
    anndata.AnnData
        ``obs`` holds categorical ``drug``, ``gene`` and ``cell_line`` columns
        plus a boolean ``control`` column; ``obsm["X_pca"]`` holds a float32
        ``(n_obs, n_pca)`` matrix; ``uns`` holds ``cell_line_emb``,
        ``drug_emb`` and ``gene_emb`` dicts mapping each category to a
        float32 vector of length ``emb_dim``.
    """
    # A local legacy RandomState yields exactly the stream np.random.seed(seed)
    # would, so the data is unchanged, without mutating global RNG state.
    rng = np.random.RandomState(seed)

    # Define categories
    drugs = ["control"] + [f"drug_{i}" for i in range(n_drugs)]
    genes = ["control"] + [f"gene_{i}" for i in range(n_genes)]
    cell_lines = [f"cell_line_{i}" for i in range(n_cell_lines)]

    # Generate obs data. A string index is given up front because AnnData
    # otherwise converts the default integer index itself and warns about it.
    obs = pd.DataFrame(
        {
            "drug": rng.choice(drugs, n_obs),
            "gene": rng.choice(genes, n_obs),
            "cell_line": rng.choice(cell_lines, n_obs),
        },
        index=[str(i) for i in range(n_obs)],
    )

    # Mark controls (both drug and gene are "control")
    obs["control"] = (obs["drug"] == "control") & (obs["gene"] == "control")

    # Convert condition columns to categorical
    for col in ["drug", "gene", "cell_line"]:
        obs[col] = obs[col].astype("category")

    # Generate expression data and a stand-in PCA embedding
    X = rng.randn(n_obs, n_vars).astype(np.float32)
    X_pca = rng.randn(n_obs, n_pca).astype(np.float32)

    adata = ad.AnnData(X=X, obs=obs)
    adata.obsm["X_pca"] = X_pca

    # One embedding vector per category, keyed by category name
    adata.uns["cell_line_emb"] = {
        cl: rng.randn(emb_dim).astype(np.float32) for cl in cell_lines
    }
    adata.uns["drug_emb"] = {
        d: rng.randn(emb_dim).astype(np.float32) for d in drugs
    }
    adata.uns["gene_emb"] = {
        g: rng.randn(emb_dim).astype(np.float32) for g in genes
    }

    return adata

In [3]:
# Three synthetic datasets of different sizes, each generated from its own seed
datasets = dict(
    pbmc=create_synthetic_adata(n_obs=2000, n_drugs=8, n_genes=4, seed=42),
    zebrafish=create_synthetic_adata(n_obs=1500, n_drugs=6, n_genes=3, seed=123),
    ineuron=create_synthetic_adata(n_obs=1000, n_drugs=5, n_genes=2, seed=456),
)

# Quick overview: size, number of perturbation categories, number of controls
for dataset_name, dataset in datasets.items():
    cell_meta = dataset.obs
    print(f"{dataset_name}: {dataset.n_obs} cells, {dataset.n_vars} genes")
    print(f"  Drugs: {cell_meta['drug'].nunique()}, Genes: {cell_meta['gene'].nunique()}")
    print(f"  Controls: {cell_meta['control'].sum()}")
    print()

pbmc: 2000 cells, 50 genes
  Drugs: 9, Genes: 5
  Controls: 51

zebrafish: 1500 cells, 50 genes
  Drugs: 7, Genes: 4
  Controls: 58

ineuron: 1000 cells, 50 genes
  Drugs: 6, Genes: 3
  Controls: 50



  return dispatch(args[0].__class__)(*args, **kw)
  return dispatch(args[0].__class__)(*args, **kw)
  return dispatch(args[0].__class__)(*args, **kw)


## 2. Configure the DataManager

Set up a single `DataManager` that will be used across all datasets.

In [4]:
# Define the data location (where the cell features are stored)
adl = AnnDataLocation()

# Create a DataManager with shared configuration
data_manager = DataManager(
    dist_flag_key="control",  # Boolean column marking control cells
    src_dist_keys=["cell_line"],  # Source distribution keys
    tgt_dist_keys=["drug", "gene"],  # Target distribution keys
    rep_keys={
        "cell_line": "cell_line_emb",
        "drug": "drug_emb",
        "gene": "gene_emb",
    },
    data_location=adl.obsm["X_pca"],  # Use PCA embeddings
)

print("DataManager configured!")

DataManager configured!


## 3. Prepare Multiple Datasets

Use `prepare_multiple_datasets` to convert all AnnData objects to `GroupedDistribution` objects.

In [5]:
# Prepare all datasets at once
grouped_distributions = prepare_multiple_datasets(
    datasets=datasets,
    data_manager=data_manager,
    verbose=True,
)

for name, gd in grouped_distributions.items():
    print(f"\n{name}:")
    print(f"  Source distributions: {len(gd.data.src_data)}")
    print(f"  Target distributions: {len(gd.data.tgt_data)}")
    print(f"  Conditions: {len(gd.data.conditions)}")

Sorting values...
Sorting values took 0.00 seconds.
Getting conditions...
Getting conditions took 0.00 seconds.
Getting source and target distribution data...
Getting source and target distribution data took 0.00 seconds.
Sorting values...
Sorting values took 0.00 seconds.
Getting conditions...
Getting conditions took 0.00 seconds.
Getting source and target distribution data...
Getting source and target distribution data took 0.00 seconds.
Sorting values...
Sorting values took 0.00 seconds.
Getting conditions...
Getting conditions took 0.00 seconds.
Getting source and target distribution data...
Getting source and target distribution data took 0.00 seconds.

pbmc:
  Source distributions: 3
  Target distributions: 132
  Conditions: 132

zebrafish:
  Source distributions: 3
  Target distributions: 81
  Conditions: 81

ineuron:
  Source distributions: 3
  Target distributions: 51
  Conditions: 51


## 4. Split Multiple Datasets

Use `split_multiple_datasets` to create train/val/test splits for each dataset.

In [7]:
# Split all datasets with the same configuration
all_splits = split_multiple_datasets(
    grouped_distributions=grouped_distributions,
    holdout_combinations=False,
    split_by=["drug", "gene"],  # Split by drug-gene combinations
    split_key="split",
    force_training_values={},  # No forced training values
    ratios=[0.6, 0.2, 0.2],  # 60% train, 20% val, 20% test
    random_state=42,
)

# Examine the splits
for dataset_name, splits in all_splits.items():
    print(f"\n{dataset_name}:")
    for split_name, gd in splits.items():
        n_src = len(gd.data.src_data)
        n_tgt = len(gd.data.tgt_data)
        print(f"  {split_name}: {n_src} source dists, {n_tgt} target dists")


pbmc:
  train: 3 source dists, 78 target dists
  val: 3 source dists, 27 target dists
  test: 3 source dists, 27 target dists

zebrafish:
  train: 3 source dists, 48 target dists
  val: 3 source dists, 15 target dists
  test: 3 source dists, 18 target dists

ineuron:
  train: 3 source dists, 30 target dists
  val: 3 source dists, 9 target dists
  test: 3 source dists, 12 target dists


## 5. One-Step: Prepare and Split Together

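Use `prepare_and_split_multiple_datasets` to run preparation and splitting (steps 3 and 4) in a single call. It returns the same nested `{dataset: {split: ...}}` dictionary as step 4, so the result below replaces `all_splits`.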

In [None]:
# Do everything in one step
all_splits = prepare_and_split_multiple_datasets(
    datasets=datasets, # dict like {"dataset_name": AnnData}
    data_manager=data_manager,
    holdout_combinations=False,
    split_by=["drug", "gene"],
    ratios=[0.6, 0.2, 0.2],
    random_state=42,
    verbose=True,
)

print("\nDatasets prepared and split:")
for dataset_name in all_splits:
    print(f"  - {dataset_name}: train/val/test")

Sorting values...
Sorting values took 0.00 seconds.
Getting conditions...
Getting conditions took 0.00 seconds.
Getting source and target distribution data...
Getting source and target distribution data took 0.00 seconds.
Sorting values...
Sorting values took 0.00 seconds.
Getting conditions...
Getting conditions took 0.00 seconds.
Getting source and target distribution data...
Getting source and target distribution data took 0.00 seconds.
Sorting values...
Sorting values took 0.00 seconds.
Getting conditions...
Getting conditions took 0.00 seconds.
Getting source and target distribution data...
Getting source and target distribution data took 0.00 seconds.

Datasets prepared and split:
  - pbmc: train/val/test
  - zebrafish: train/val/test
  - ineuron: train/val/test


## 6. Access Individual Splits

In [9]:
# Access a specific dataset and split
pbmc_train = all_splits["pbmc"]["train"]
zebrafish_val = all_splits["zebrafish"]["val"]

print("PBMC Training Data:")
print(f"  Source distributions: {len(pbmc_train.data.src_data)}")
print(f"  Target distributions: {len(pbmc_train.data.tgt_data)}")
print(f"  Distribution DataFrame shape: {pbmc_train.annotation.src_tgt_dist_df.shape}")

print("\nZebrafish Validation Data:")
print(f"  Source distributions: {len(zebrafish_val.data.src_data)}")
print(f"  Target distributions: {len(zebrafish_val.data.tgt_data)}")

PBMC Training Data:
  Source distributions: 3
  Target distributions: 78
  Distribution DataFrame shape: (78, 5)

Zebrafish Validation Data:
  Source distributions: 3
  Target distributions: 15


## 7. Use with ReservoirSampler

Create samplers for training on the split data.

In [10]:
# ReservoirSampler is imported in the top import cell.
BATCH_SIZE = 128  # cells per sampled source/target batch

# One sampler per dataset, built over that dataset's training split.
# The same seeded `rng` is passed to every init_sampler call and is reused by
# the sampling cell below, so re-running only that cell yields new batches.
train_samplers = {}
rng = np.random.default_rng(42)

for dataset_name, splits in all_splits.items():
    sampler = ReservoirSampler(
        data=splits["train"],
        batch_size=BATCH_SIZE,
    )
    sampler.init_sampler(rng)
    train_samplers[dataset_name] = sampler
    print(f"{dataset_name} sampler ready")

pbmc sampler ready
zebrafish sampler ready
ineuron sampler ready


In [11]:
# Sample from each dataset
for dataset_name, sampler in train_samplers.items():
    batch = sampler.sample(rng)
    print(f"\n{dataset_name} batch:")
    print(f"  Source cells: {batch['src_cell_data'].shape}")
    print(f"  Target cells: {batch['tgt_cell_data'].shape}")
    print(f"  Condition: {batch['condition'].shape}")

sampled source dist idx: 0 and target dist idx: 34
sampled source batch: (128, 20)
sampled target batch: (128, 20)

pbmc batch:
  Source cells: (128, 20)
  Target cells: (128, 20)
  Condition: (30,)
sampled source dist idx: 0 and target dist idx: 13
sampled source batch: (128, 20)
sampled target batch: (128, 20)

zebrafish batch:
  Source cells: (128, 20)
  Target cells: (128, 20)
  Condition: (30,)
sampled source dist idx: 2 and target dist idx: 34
sampled source batch: (128, 20)
sampled target batch: (128, 20)

ineuron batch:
  Source cells: (128, 20)
  Target cells: (128, 20)
  Condition: (30,)


## 8. Summary Statistics

In [12]:
# Create a summary table
summary_data = []

for dataset_name, splits in all_splits.items():
    for split_name, gd in splits.items():
        summary_data.append({
            "Dataset": dataset_name,
            "Split": split_name,
            "Source Dists": len(gd.data.src_data),
            "Target Dists": len(gd.data.tgt_data),
            "Unique Drug-Gene Combos": len(gd.annotation.src_tgt_dist_df.drop_duplicates(
                subset=["drug", "gene"]
            )),
        })

summary_df = pd.DataFrame(summary_data)
print(summary_df.to_string(index=False))

  Dataset Split  Source Dists  Target Dists  Unique Drug-Gene Combos
     pbmc train             3            78                       26
     pbmc   val             3            27                        9
     pbmc  test             3            27                        9
zebrafish train             3            48                       16
zebrafish   val             3            15                        5
zebrafish  test             3            18                        6
  ineuron train             3            30                       10
  ineuron   val             3             9                        3
  ineuron  test             3            12                        4


## 9. Force Training Values Example

Use `force_training_values` to pin specific drug or gene values to the training split, so that they never appear in the validation or test splits.

In [13]:
# Force drug_0 to always be in training
all_splits_forced = prepare_and_split_multiple_datasets(
    datasets=datasets,
    data_manager=data_manager,
    holdout_combinations=False,
    split_by=["drug"],
    force_training_values={"drug": "drug_0"},  # Ensure drug_0 is in training
    ratios=[0.6, 0.2, 0.2],
    random_state=42,
)

# Verify drug_0 is only in training
for dataset_name, splits in all_splits_forced.items():
    train_drugs = set(splits["train"].annotation.src_tgt_dist_df["drug"].unique())
    val_drugs = set(splits["val"].annotation.src_tgt_dist_df["drug"].unique())
    test_drugs = set(splits["test"].annotation.src_tgt_dist_df["drug"].unique())
    
    print(f"\n{dataset_name}:")
    print(f"  drug_0 in train: {'drug_0' in train_drugs}")
    print(f"  drug_0 in val: {'drug_0' in val_drugs}")
    print(f"  drug_0 in test: {'drug_0' in test_drugs}")


pbmc:
  drug_0 in train: True
  drug_0 in val: False
  drug_0 in test: False

zebrafish:
  drug_0 in train: True
  drug_0 in val: False
  drug_0 in test: False

ineuron:
  drug_0 in train: True
  drug_0 in val: False
  drug_0 in test: False
