In [1]:
%load_ext autoreload
%autoreload 2

In [8]:
import anndata as ad
import numpy as np
import pandas as pd
import random

from scaleflow.data import DataManager, AnnDataLocation


In [9]:
def adata_test(seed=None):
    """Build a small, fully crossed synthetic AnnData for DataManager tests.

    Every combination of drug x gene x cell_line x batch x plate x day is
    generated. Treated cells (drug != "control") appear once per dose in
    ``doses``; control-drug cells appear once with dose 0.0. Only cells whose
    drug AND gene are both "control" get ``obs["control"] = True``.

    Rows are shuffled, and ``obsm["X_pca"][:, 0]`` stores each cell's final
    row position so cells can be traced through downstream reordering.

    Parameters
    ----------
    seed : int or None, optional
        Seed for the row shuffle and for ``X``. ``None`` (default) produces a
        different dataset on every call; pass an int for reproducible output.

    Returns
    -------
    anndata.AnnData
        ``n_vars = 20`` variables, string obs names "0", "1", ..., categorical
        covariate columns, a 10-dim ``obsm["X_pca"]`` and per-covariate
        embedding dicts in ``uns``.
    """
    import itertools

    drugs = ["control", "drug_A", "drug_B"]
    genes = ["control", "gene_A", "gene_B"]
    cell_lines = ["cell_line_A", "cell_line_B"]
    batches = ["batch_1", "batch_2", "batch_3"]
    plates = ["plate_1", "plate_2", "plate_3"]
    days = ["day_1", "day_2", "day_3"]
    doses = [1.0, 10.0, 100.0]

    rows = []
    # product() iterates the rightmost sequence fastest, i.e. the same order as nested loops.
    for drug, gene, cell_line, batch, plate, day in itertools.product(
        drugs, genes, cell_lines, batches, plates, days
    ):
        is_control_drug = drug == "control"
        # Control-drug cells have a single zero dose; treated cells span every dose.
        for dose in ([0.0] if is_control_drug else doses):
            rows.append(
                {
                    "drug": drug,
                    "gene": gene,
                    "cell_line": cell_line,
                    "batch": batch,
                    "plate": plate,
                    "day": day,
                    "dose": dose,
                    # Only fully unperturbed cells (control drug AND control gene) are flagged.
                    "control": is_control_drug and gene == "control",
                }
            )

    n_obs = len(rows)
    n_vars = 20
    n_pca = 10

    # A single generator drives all randomness, so one seed reproduces the whole dataset.
    rng = np.random.default_rng(seed)
    obs = pd.DataFrame([rows[i] for i in rng.permutation(n_obs)])
    # AnnData needs string obs names; set them explicitly rather than letting AnnData
    # convert the RangeIndex implicitly (which emits a warning).
    obs.index = obs.index.astype(str)
    obs = obs.astype(
        {col: "category" for col in ["cell_line", "drug", "gene", "batch", "plate", "day"]}
    )

    # Simple X matrix (not really used in tests, just needs to exist)
    X = rng.standard_normal((n_obs, n_vars), dtype=np.float32)

    # X_pca: column 0 holds each cell's row position for easy tracing; other columns are zero.
    X_pca = np.zeros((n_obs, n_pca), dtype=np.float32)
    X_pca[:, 0] = np.arange(n_obs, dtype=np.float32)

    adata = ad.AnnData(X=X, obs=obs)
    adata.obsm["X_pca"] = X_pca

    # Simple embeddings, keyed by category value (referenced via DataManager rep_keys).
    adata.uns["cell_line_embeddings"] = {
        "cell_line_A": np.array([1.0, 0.0], dtype=np.float32),
        "cell_line_B": np.array([0.0, 1.0], dtype=np.float32),
    }

    adata.uns["drug_embeddings"] = {
        "drug_A": np.array([1.0, 0.0, 0.0], dtype=np.float32),
        "drug_B": np.array([0.0, 1.0, 0.0], dtype=np.float32),
        "control": np.array([0.0, 0.0, 0.0], dtype=np.float32),
    }

    adata.uns["gene_embeddings"] = {
        "gene_A": np.array([1.0, 0.0], dtype=np.float32),
        "gene_B": np.array([0.0, 1.0], dtype=np.float32),
        "control": np.array([0.0, 0.0], dtype=np.float32),
    }

    return adata


In [10]:
# Build the synthetic test dataset. This call passes no seed, so the shuffled row
# order (and X) differ between kernel restarts.
adata = adata_test()

  return dispatch(args[0].__class__)(*args, **kw)


In [11]:
# Show the AnnData summary: dimensions plus the obs / uns / obsm keys it carries.
adata

AnnData object with n_obs × n_vars = 1134 × 20
    obs: 'drug', 'gene', 'cell_line', 'batch', 'plate', 'day', 'dose', 'control'
    uns: 'cell_line_embeddings', 'drug_embeddings', 'gene_embeddings'
    obsm: 'X_pca'

In [12]:
"""Test that prepare_data works and returns correct structure."""
adl = AnnDataLocation()

dm = DataManager(
    dist_flag_key="control",
    src_dist_keys=["cell_line"],
    tgt_dist_keys=["drug", "gene"],
    rep_keys={
        "cell_line": "cell_line_embeddings",
        "drug": "drug_embeddings",
        "gene": "gene_embeddings",
    },
    data_location=adl.obsm["X_pca"],
)

In [13]:
# Prepare grouped distributions, then sanity-check the basic structure:
# old_obs_index has one entry per cell, and each entry is an obs name of `adata`.
gd = dm.prepare_data(adata, verbose=False)

assert len(gd.annotation.old_obs_index) == adata.n_obs, "expected one old_obs_index entry per cell"
assert np.isin(gd.annotation.old_obs_index, adata.obs.index.to_numpy()).all(), (
    "old_obs_index should only contain obs names of adata"
)

In [16]:
# Per-cell metadata; rows appear in the shuffled order produced by adata_test().
adata.obs

Unnamed: 0,drug,gene,cell_line,batch,plate,day,dose,control
0,control,gene_B,cell_line_A,batch_1,plate_1,day_1,0.0,False
1,drug_A,gene_A,cell_line_B,batch_2,plate_1,day_1,100.0,False
2,drug_A,gene_B,cell_line_B,batch_3,plate_2,day_1,1.0,False
3,control,control,cell_line_A,batch_2,plate_1,day_2,0.0,True
4,drug_B,gene_A,cell_line_B,batch_1,plate_3,day_3,10.0,False
...,...,...,...,...,...,...,...,...
1129,drug_B,control,cell_line_A,batch_3,plate_1,day_1,100.0,False
1130,control,gene_A,cell_line_A,batch_3,plate_1,day_2,0.0,False
1131,drug_B,gene_B,cell_line_A,batch_3,plate_3,day_2,1.0,False
1132,control,gene_B,cell_line_B,batch_1,plate_1,day_1,0.0,False


In [17]:
def test_ordering_reconstruction_after_shuffle(adata, seed=None):
    """Check that prepare_data's old_obs_index maps back to the input AnnData.

    A shuffled copy of ``adata`` is passed to ``DataManager.prepare_data``. The
    resulting ``gd.annotation.old_obs_index`` must be a permutation of the copy's
    obs names. Following it into the copy must also reach the same cells
    (compared via ``obsm["X_pca"][:, 0]``) as looking the names up in ``adata``.

    Parameters
    ----------
    adata : anndata.AnnData
        Data with unique obs names, ``obsm["X_pca"]`` and the obs/uns keys used
        by the DataManager below (e.g. the output of ``adata_test()``). It is not
        modified; a shuffled copy is used instead.
    seed : int or None, optional
        Seed for the shuffle permutation. ``None`` (default) is unseeded.

    Raises
    ------
    AssertionError
        If the copy was not reordered, or any ordering check fails.
    """
    # Store original order information
    original_names = adata.obs.index.to_numpy().copy()
    original_X_pca = adata.obsm["X_pca"].copy()

    # Shuffle a copy so the caller's adata is left untouched.
    shuffle_idx = np.random.default_rng(seed).permutation(len(adata))
    adata_shuffled = adata[shuffle_idx].copy()

    # A random permutation is (for reasonable sizes) essentially never the identity.
    assert not np.array_equal(
        adata_shuffled.obs.index.to_numpy(),
        original_names
    ), "Data should be shuffled"

    # Create DataManager and prepare data
    adl = AnnDataLocation()
    dm = DataManager(
        dist_flag_key="control",
        src_dist_keys=["cell_line"],
        tgt_dist_keys=["drug", "gene"],
        rep_keys={
            "cell_line": "cell_line_embeddings",
            "drug": "drug_embeddings",
            "gene": "gene_embeddings",
        },
        data_location=adl.obsm["X_pca"],
    )

    gd = dm.prepare_data(adata_shuffled)
    old_obs_index = gd.annotation.old_obs_index

    # Test 1: one entry per cell, each a valid obs name of the shuffled adata.
    assert len(old_obs_index) == len(adata_shuffled)
    # get_indexer returns -1 for unknown names; vectorized replacement for per-name np.where scans.
    shuffled_pos = adata_shuffled.obs.index.get_indexer(old_obs_index)
    assert np.all(shuffled_pos >= 0), "old_obs_index should contain valid indices from shuffled adata"

    # Test 2: following old_obs_index into the shuffled adata must land on the same
    # cells (same X_pca[:, 0] tag) as looking the names up in the unshuffled data.
    original_pos = pd.Index(original_names).get_indexer(old_obs_index)
    np.testing.assert_array_equal(
        adata_shuffled.obsm["X_pca"][shuffled_pos, 0],
        original_X_pca[original_pos, 0],
    )

    # Test 3: no repeated names, so together with Test 1 old_obs_index is a permutation
    # covering every cell of the shuffled adata.
    assert pd.Index(old_obs_index).is_unique, "old_obs_index should not repeat cells"

    print("✓ Ordering preservation test passed!")
    print(f"  - Successfully shuffled {len(adata)} cells")
    print(f"  - Created {len(gd.data.src_data)} source distributions")
    print(f"  - Created {len(gd.data.tgt_data)} target distributions")
    print("  - Successfully reconstructed original ordering via old_obs_index")

In [18]:
# Run the ordering check on the synthetic dataset. The test shuffles a copy
# (adata[...].copy()), so `adata` itself is not reordered.
test_ordering_reconstruction_after_shuffle(adata)

✓ Ordering preservation test passed!
  - Successfully shuffled 1134 cells
  - Created 2 source distributions
  - Created 16 target distributions
  - Successfully reconstructed original ordering via old_obs_index


In [19]:
# Unique (drug, gene) combinations in obs, each labelled by the index of its first occurrence.
# NOTE(review): the saved output below lists index 0 as (control, control), but the
# adata.obs output above shows row 0 as (control, gene_B). The two outputs seem to come
# from different adata_test() runs (stale output); Restart & Run All to refresh them.
adata.obs[['drug', 'gene']].drop_duplicates()

Unnamed: 0,drug,gene
0,control,control
54,control,gene_A
108,control,gene_B
162,drug_A,control
324,drug_A,gene_A
486,drug_A,gene_B
648,drug_B,control
810,drug_B,gene_A
972,drug_B,gene_B


In [24]:
# Obs names of the AnnData passed to prepare_data. Per the ordering test above, this
# maps each position in the GroupedDistribution's internal order back to an original obs name.
# NOTE(review): this cell ran as In[24], before In[21] below; re-run top to bottom.
gd.annotation.old_obs_index

array(['0', '1', '2', ..., '1131', '1132', '1133'],
      shape=(1134,), dtype=object)

In [21]:
# One row per target distribution: its tgt_dist_idx, the src_dist_idx it is paired with,
# and the cell_line / drug / gene values that define the pairing.
gd.annotation.src_tgt_dist_df

Unnamed: 0,src_dist_idx,tgt_dist_idx,cell_line,drug,gene
54,0,0,cell_line_A,control,gene_A
108,0,1,cell_line_A,control,gene_B
162,0,2,cell_line_A,drug_A,control
324,0,3,cell_line_A,drug_A,gene_A
486,0,4,cell_line_A,drug_A,gene_B
648,0,5,cell_line_A,drug_B,control
810,0,6,cell_line_A,drug_B,gene_A
972,0,7,cell_line_A,drug_B,gene_B
81,1,8,cell_line_B,control,gene_A
135,1,9,cell_line_B,control,gene_B
