In [None]:
%load_ext autoreload
%autoreload 2


import numpy as np
import pandas as pd
import anndata as ad
import numpy as np
import pandas as pd
import jax.numpy as jnp
import jax
jax.config.update("jax_default_device", jax.devices()[0])

In [2]:
def create_synthetic_data(
    n_genes=50,
    n_control_cells=40,
    n_drugs=6,
    dosages_per_drug=3,
    cells_per_condition=50,
    n_cell_types=30,
    cell_type_embed_dim=50,
    drug_embed_dim=50,
):
    """
    Creates a synthetic AnnData object with multiple dosages per drug.

    Parameters
    ----------
    n_genes : int
        Number of genes to simulate
    n_control_cells : int
        Number of control cells per cell type
    n_drugs : int
        Total number of distinct drugs
    dosages_per_drug : int
        Number of different dosages for each drug
    cells_per_condition : int
        Number of cells per drug-dosage condition
    n_cell_types : int
        Number of cell types
    cell_type_embed_dim : int
        Embedding dimension for cell types
    drug_embed_dim : int
        Embedding dimension for drugs

    Returns
    -------
    dict
        Dictionary containing all DataManager parameters
    """

    # Create cell type names
    n_batches = n_cell_types
    cell_type_names = [f"cell_line_{chr(97 + i)}" for i in range(n_cell_types)]
    batch_names = [f"batch_{i+1}" for i in range(n_batches)]  # New: create batch names

    # Calculate total cells
    total_conditions = n_drugs * dosages_per_drug  # Total conditions excluding control
    total_cells_per_type = n_control_cells + (total_conditions * cells_per_condition)
    n_cells = n_cell_types * total_cells_per_type

    # Initialize lists for observation data
    cell_type_list = []
    control_list = []
    drug_list = []
    dosage_list = []
    batch_list = []

    # Generate data for each cell type
    # shuffle batch_names
    batch_names = np.random.permutation(batch_names)

    for i,cell_type in enumerate(cell_type_names):
        # Add control cells for this cell type
        cell_type_list.extend([cell_type] * total_cells_per_type)
        # Controls
        control_list.extend([True] * n_control_cells)
        drug_list.extend(["control"] * n_control_cells)
        dosage_list.extend([0.0] * n_control_cells)
        # batch_assignments = np.random.choice(batch_names[i], size=total_cells_per_type)
        batch_list.extend([batch_names[i]] * total_cells_per_type)
        # ensure that at lease one type is present 


        # Add perturbed cells for each drug-dosage combination
        control_list.extend([False] * (total_conditions * cells_per_condition))

        # Add drug-dosage combinations
        for drug_idx in range(1, n_drugs + 1):
            for dosage_idx in range(1, dosages_per_drug + 1):
                # Calculate dosage value (e.g., 0.1, 0.5, 1.0)
                dosage_value = dosage_idx / dosages_per_drug  # Normalize to [0,1] range

                # Add this drug-dosage combination
                drug_list.extend([f"drug{drug_idx}"] * cells_per_condition)
                dosage_list.extend([dosage_value] * cells_per_condition)

    # Generate random expression data
    X = np.random.normal(size=(n_cells, n_genes))

    # Create observation DataFrame
    obs = pd.DataFrame(
        {
            "control": control_list,
            "cell_type": pd.Categorical(cell_type_list),
            "drug": pd.Categorical(drug_list),
            "dosage": dosage_list,
            # "batch": pd.Categorical(batch_list),
        }
    )



    # Create AnnData object
    adata = ad.AnnData(X, obs=obs)

    # Add representations to uns (for covariate embeddings)
    adata.uns["drug"] = {
        "control": np.zeros(drug_embed_dim),
    }

    # Add drug embeddings
    for i in range(1, n_drugs + 1):
        adata.uns["drug"][f"drug{i}"] = np.random.normal(size=(drug_embed_dim,))

    # Add cell type embeddings
    adata.uns["cell_type"] = {}
    # adata.uns["batch"] = {}
    for cell_type in cell_type_names:
        adata.uns["cell_type"][cell_type] = np.random.normal(
            size=(cell_type_embed_dim,)
        )
    # for batch in batch_names:
    #     adata.uns["batch"][batch] = np.random.normal(
    #         size=(cell_type_embed_dim,)
    #     )

    # Define parameters for DataManager
    sample_rep = "X"
    control_key = "control"
    split_covariates = ["cell_type",]

    # Here we use a simpler structure with just one drug and dosage column
    perturbation_covariates = {"drug": ["drug"], "dosage": ["dosage"]}
    perturbation_covariate_reps = {"drug": "drug"}
    sample_covariates = ["cell_type", ]
    sample_covariate_reps = {"cell_type": "cell_type", }



    # Return a dictionary with all required parameters
    return {
        "adata": adata,
        "sample_rep": sample_rep,
        "control_key": control_key,
        "split_covariates": split_covariates,
        "perturbation_covariates": perturbation_covariates,
        "perturbation_covariate_reps": perturbation_covariate_reps,
        "sample_covariates": sample_covariates,
        "sample_covariate_reps": sample_covariate_reps,
    }


In [3]:
# Now you can initialize the DataManager
from cfp.data._datamanager import DataManager

dm_args = create_synthetic_data()

dm = DataManager(
    **dm_args    
)

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
jax.config.update("jax_default_device", jax.devices()[0])


In [5]:
from cfp.model._cellflow import CellFlow

cf = CellFlow(adata=dm_args["adata"],) 

In [6]:
import functools

import cfp
import scanpy as sc
import numpy as np
import functools
from ott.solvers import utils as solver_utils
import optax
import anndata as ad

In [7]:
dm_args

{'adata': AnnData object with n_obs × n_vars = 28200 × 50
     obs: 'control', 'cell_type', 'drug', 'dosage'
     uns: 'drug', 'cell_type',
 'sample_rep': 'X',
 'control_key': 'control',
 'split_covariates': ['cell_type'],
 'perturbation_covariates': {'drug': ['drug'], 'dosage': ['dosage']},
 'perturbation_covariate_reps': {'drug': 'drug'},
 'sample_covariates': ['cell_type'],
 'sample_covariate_reps': {'cell_type': 'cell_type'}}

In [8]:
perturbation_covariates = dm_args["perturbation_covariates"]
perturbation_covariate_reps = dm_args["perturbation_covariate_reps"]
sample_covariates = dm_args["sample_covariates"]
sample_covariate_reps = dm_args["sample_covariate_reps"]


In [9]:
adata = dm_args["adata"]

In [10]:
cf.prepare_data(
    sample_rep="X",
    control_key="control",
    perturbation_covariates=perturbation_covariates,
    perturbation_covariate_reps=perturbation_covariate_reps,
    sample_covariates=sample_covariates,
    sample_covariate_reps=sample_covariate_reps,
    split_covariates=dm_args["split_covariates"],
)
print("Finished preparing data")

[########################################] | 100% Completed | 111.19 ms
[########################################] | 100% Completed | 102.89 ms
[##########                              ] | 26% Completed | 635.01 ms

IOStream.flush timed out


[########################################] | 100% Completed | 1.90 sms
Finished preparing data


In [11]:
# condition_embedding_dim: 256
# time_encoder_dims: [2048, 2048, 2048]
# time_encoder_dropout: 0.0
# hidden_dims: [2048, 2048, 2048]
# hidden_dropout: 0.0
# decoder_dims: [4096, 4096, 4096]
# decoder_dropout: 0.2
# pooling: "mean"
# layers_before_pool: 
#   drugs:
#     layer_type: mlp
#     dims: [1024, 1024]
#     dropout_rate: 0.5
#   dose:
#     layer_type: mlp
#     dims: [256, 256]
#     dropout_rate: 0.0
#   cell_line:
#     layer_type: mlp
#     dims: [1024, 1024]
#     dropout_rate: 0.2
# layers_after_pool:
#   layer_type: mlp
#   dims: [1024, 1024]
#   dropout_rate: 0.2
# cond_output_dropout: 0.9
# time_freqs: 1024
# flow_noise: 1.0
# learning_rate: 0.00005
# multi_steps: 20
# epsilon: 1.0
# tau_a: 1.0
# tau_b: 1.0
# flow_type: "constant_noise"
# linear_projection_before_concatenation: False
# layer_norm_before_concatenation: False

In [12]:



match_fn = functools.partial(
    solver_utils.match_linear,
    epsilon=1.0,
    scale_cost="mean",
    tau_a=1.0,
    tau_b=1.0,
)
optimizer = optax.MultiSteps(optax.adam(0.00005), 20)
flow = {
        "constant_noise": 1.0}
layers_before_pool = {
    "drug": {
        "layer_type": "mlp",
        "dims": [1024, 1024],
        "dropout_rate": 0.5,
    },
    "dosage": {
        "layer_type": "mlp",
        "dims": [256, 256],
        "dropout_rate": 0.0,
    },
    "cell_type": {
        "layer_type": "mlp",
        "dims": [1024, 1024],
        "dropout_rate": 0.2,
    },
}
layers_after_pool = {
    "layer_type": "mlp",
    "dims": [1024, 1024],
    "dropout_rate": 0.2,
}
condition_embedding_dim = 256
pooling = "mean"
time_encoder_dims = [2048, 2048, 2048]
time_encoder_dropout = 0.0
hidden_dims = [2048, 2048, 2048]
hidden_dropout = 0.0
decoder_dims = [4096, 4096, 4096]
decoder_dropout = 0.2
cond_output_dropout = 0.9
time_freqs = 1024
layer_norm_before_concatenation = False
linear_projection_before_concatenation = False
# Prepare the model
print("Preparing model...")
cf.prepare_model(
    encode_conditions=True,
    condition_embedding_dim=condition_embedding_dim,
    pooling=pooling,
    time_encoder_dims=time_encoder_dims,
    time_encoder_dropout=time_encoder_dropout,
    hidden_dims=hidden_dims,
    hidden_dropout=hidden_dropout,
    decoder_dims=decoder_dims,
    decoder_dropout=decoder_dropout,
    layers_before_pool=layers_before_pool,
    layers_after_pool=layers_after_pool,
    cond_output_dropout=cond_output_dropout,
    time_freqs=time_freqs,
    match_fn=match_fn,
    optimizer=optimizer,
    flow=flow,
    layer_norm_before_concatenation=False,
    linear_projection_before_concatenation=False,
)

print("Begin training")


Preparing model...
Begin training


In [13]:
# show the devices for each leaf
# scondition_encoder': {'after_pool_modules_0': {'bias

In [14]:
cf.train(
    num_iterations=1000,
    batch_size=1024,
    callbacks=[],
    valid_freq=100,
)

 13%|█▎        | 131/1000 [00:36<04:00,  3.61it/s, loss=1.99]


KeyboardInterrupt: 

In [None]:
cf.solver.vf_state.params['condition_encoder']['after_pool_modules_0']['kernel'].devices()

{CpuDevice(id=0)}

In [None]:
import jax
# check the speed of sample


{'condition': {'cell_type': Array([[[ 1.7382015 , -0.4042862 ,  0.5841826 , -0.06634207,
           -0.04769267, -0.15670413,  0.7732949 ,  0.4783209 ,
           -1.3973482 ,  0.10488927,  1.1956937 , -0.04145573,
           -0.80451876, -1.7779018 ,  2.0243518 ,  1.091451  ,
           -0.5055642 , -0.94782007, -1.0034313 , -0.85201305,
           -1.026559  ,  0.28274524,  0.7853366 ,  0.9337822 ,
            0.0797888 , -0.1512465 , -0.62987643, -0.6104926 ,
           -0.07386218, -1.6913621 ,  1.5826727 , -2.0818222 ,
           -0.65326655,  1.5192457 ,  0.8052003 , -0.566199  ,
            0.17454635,  0.68401533,  0.394697  ,  1.0484004 ,
           -2.017751  , -1.4909996 , -1.868666  , -0.81145805,
           -0.7812861 ,  0.34504193,  0.9844148 , -1.8931816 ,
           -0.8891171 , -1.2514101 ]]], dtype=float32),
  'dosage': Array([[[0.33333334]]], dtype=float32),
  'drug': Array([[[-1.7092513 ,  0.57189363, -0.50048774,  0.6205742 ,
            0.5630994 ,  0.19638546, -0

In [None]:
class IterativeSampler:
    def __init__(self, dataloader, rng, num_iterations):
        self.dataloader = dataloader
        self.rng = rng
        self.num_iterations = num_iterations


    def __iter__(self):
        for _ in range(self.num_iterations):
            self.rng, rng_data = jax.random.split(self.rng, 2)
            batch = self.dataloader.sample(rng_data)
            yield batch


In [None]:
iter_sample = IterativeSampler(dataloader=cf.dataloader, rng=rng, num_iterations=100000)


In [None]:
# check the speed of each iteration
import time
import tqdm


In [None]:
pbar = tqdm.tqdm(iter_sample)

for batch in pbar:
    pass

1704it [00:18, 94.02it/s]


KeyboardInterrupt: 

In [None]:
pbar = tqdm.tqdm(iter_sample)
for batch in pbar:
    batch = jax.device_put(batch)

0it [00:00, ?it/s]

337it [00:03, 96.11it/s]


KeyboardInterrupt: 

True

In [None]:
import dask.dataframe as dd




ddf = dd.from_pandas(df, npartitions=2)




ddf["global_control_comb"] = ddf[split_covariates].apply(lambda x: "_".join(x.astype(str)), axis=1, meta=("x", "object"))
ddf["global_pert_comb"] = ddf[perturbation_covariates_keys + split_covariates].apply(
    lambda x: "_".join(x.astype(str)), axis=1, meta=("x", "object")
)


# df.loc[~df.control, "control_comb"] = np.nan
# df.loc[df.control, "pert_comb"] = np.nan
ddf["control_comb"] = ddf["global_control_comb"].where(ddf["control"], np.nan)
ddf["pert_comb"] = ddf["global_pert_comb"].where(~ddf["control"], np.nan)
# ddf.assign(
#     control_comb=ddf["global_control_comb"].where(ddf["control"], np.nan),
#     pert_comb=ddf["global_pert_comb"].where(~ddf["control"], np.nan),
# )

ddf = ddf.categorize(columns=["global_control_comb", "global_pert_comb", "control_comb", "pert_comb"])

ddf["global_permutation_cov_mask"] = ddf["global_pert_comb"].cat.codes
ddf["global_split_covariates_mask"] = ddf["global_control_comb"].cat.codes
ddf["split_covariates_mask"] = ddf["control_comb"].cat.codes
ddf["perturbation_covariates_mask"] = ddf["pert_comb"].cat.codes


df = ddf.compute()


split_idx_to_covariates = (
    df[["global_split_covariates_mask", *split_covariates]]
    .groupby(["global_split_covariates_mask"])
    .first()
    .to_dict(orient="index")
)
split_idx_to_covariates = {
    k: tuple(v[s] for s in split_covariates) for k, v in split_idx_to_covariates.items()
}

perturbation_idx_to_covariates = (
    df[["global_permutation_cov_mask", *perturbation_covariates_keys, *split_covariates]]
    .groupby(["global_permutation_cov_mask"])
    .first()
    .to_dict(orient="index")
)
perturbation_idx_to_covariates = {
    k: [v[s] for s in [*perturbation_covariates_keys, *split_covariates]]
    for k, v in perturbation_idx_to_covariates.items()
}

control_to_perturbation = (
    df[~df.control].groupby(["global_split_covariates_mask"])["perturbation_covariates_mask"].unique()
)
control_to_perturbation = control_to_perturbation.to_dict()

split_covariates_mask = jnp.asarray(df["split_covariates_mask"].values)
perturbation_covariates_mask = jnp.asarray(df["perturbation_covariates_mask"].values)

NameError: name 'df' is not defined

In [None]:
ddf.compute()

Unnamed: 0,cell_type,batch,drug,dosage,control,global_control_comb,global_pert_comb,control_comb,pert_comb,global_permutation_cov_mask,global_split_covariates_mask,split_covariates_mask,perturbation_covariates_mask
0,cell_line_a,batch_20,control,0.000000,True,cell_line_a_batch_20,control_0.0_cell_line_a_batch_20,cell_line_a_batch_20,,0,0,0,-1
1,cell_line_a,batch_20,control,0.000000,True,cell_line_a_batch_20,control_0.0_cell_line_a_batch_20,cell_line_a_batch_20,,0,0,0,-1
10,cell_line_a,batch_20,control,0.000000,True,cell_line_a_batch_20,control_0.0_cell_line_a_batch_20,cell_line_a_batch_20,,0,0,0,-1
100,cell_line_a,batch_20,drug1,0.666667,False,cell_line_a_batch_20,drug1_0.6666666666666666_cell_line_a_batch_20,,drug1_0.6666666666666666_cell_line_a_batch_20,60,0,-1,30
1000,cell_line_b,batch_19,drug1,0.333333,False,cell_line_b_batch_19,drug1_0.3333333333333333_cell_line_b_batch_19,,drug1_0.3333333333333333_cell_line_b_batch_19,31,1,-1,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...
9995,cell_line_k,batch_10,drug4,1.000000,False,cell_line_k_batch_10,drug4_1.0_cell_line_k_batch_10,,drug4_1.0_cell_line_k_batch_10,370,10,-1,340
9996,cell_line_k,batch_10,drug4,1.000000,False,cell_line_k_batch_10,drug4_1.0_cell_line_k_batch_10,,drug4_1.0_cell_line_k_batch_10,370,10,-1,340
9997,cell_line_k,batch_10,drug4,1.000000,False,cell_line_k_batch_10,drug4_1.0_cell_line_k_batch_10,,drug4_1.0_cell_line_k_batch_10,370,10,-1,340
9998,cell_line_k,batch_10,drug4,1.000000,False,cell_line_k_batch_10,drug4_1.0_cell_line_k_batch_10,,drug4_1.0_cell_line_k_batch_10,370,10,-1,340


In [None]:
ddf.persist()


Unnamed: 0_level_0,cell_type,batch,drug,dosage,control,global_control_comb,global_pert_comb,control_comb,pert_comb,global_permutation_cov_mask,global_split_covariates_mask,split_covariates_mask,perturbation_covariates_mask
npartitions=2,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1
0,category[known],category[known],category[known],float64,boolean,category[known],category[known],category[known],category[known],int16,int8,int8,int16
22688,...,...,...,...,...,...,...,...,...,...,...,...,...
9999,...,...,...,...,...,...,...,...,...,...,...,...,...


In [None]:
ddf['global_permutation_cov_mask'].compute()

KeyError: 'global_control_comb'

In [None]:
split_idx_to_covariates = (
    ddf[["global_split_covariates_mask", *split_covariates]].compute()
    .groupby(["global_split_covariates_mask"])
    .first()
    .to_dict(orient="index")
)

KeyError: 'global_pert_comb'

In [None]:

split_idx_to_covariates = (
    ddf[["global_split_covariates_mask", *split_covariates]]
    .groupby(["global_split_covariates_mask"])
    .first().compute()
    .to_dict(orient="index")
)
split_idx_to_covariates = {
    k: tuple(v[s] for s in split_covariates) for k, v in split_idx_to_covariates.items()
}

perturbation_idx_to_covariates = (
    df[["global_permutation_cov_mask", *perturbation_covariates_keys, *split_covariates]]
    .groupby(["global_permutation_cov_mask"])
    .first()
    .to_dict(orient="index")
)
perturbation_idx_to_covariates = {
    k: [v[s] for s in [*perturbation_covariates_keys, *split_covariates]]
    for k, v in perturbation_idx_to_covariates.items()
}

control_to_perturbation = (
    df[~df.control].groupby(["global_split_covariates_mask"])["perturbation_covariates_mask"].unique()
)
control_to_perturbation = control_to_perturbation.to_dict()

split_covariates_mask = jnp.asarray(df["split_covariates_mask"].values)
perturbation_covariates_mask = jnp.asarray(df["perturbation_covariates_mask"].values)

NameError: name 'ddf' is not defined

In [None]:
import dask.dataframe as dd




ddf = dd.from_pandas(df, npartitions=2)




ddf["global_control_comb"] = ddf[split_covariates].apply(lambda x: "_".join(x.astype(str)), axis=1)
ddf["global_pert_comb"] = ddf[perturbation_covariates_keys + split_covariates].apply(
    lambda x: "_".join(x.astype(str)), axis=1
)

# df.loc[~df.control, "control_comb"] = np.nan
# df.loc[df.control, "pert_comb"] = np.nan
ddf["control_comb"] = ddf["global_control_comb"].where(ddf["control"], np.nan)
ddf["pert_comb"] = ddf["global_pert_comb"].where(~ddf["control"], np.nan)
# ddf.assign(
#     control_comb=ddf["global_control_comb"].where(ddf["control"], np.nan),
#     pert_comb=ddf["global_pert_comb"].where(~ddf["control"], np.nan),
# )

ddf = ddf.categorize(columns=["global_control_comb", "global_pert_comb", "control_comb", "pert_comb"])


ddf["global_pert_comb"] = ddf["pert_comb"].astype("category")
ddf["global_control_comb"] = ddf["global_control_comb"].astype("category")
df["control_comb"] = df["control_comb"].astype("category")
df["pert_comb"] = df["pert_comb"].astype("category")

df["global_permutation_cov_mask"] = df["global_pert_comb"].cat.codes
df["global_split_covariates_mask"] = df["global_control_comb"].cat.codes
df["split_covariates_mask"] = df["control_comb"].cat.codes
df["perturbation_covariates_mask"] = df["pert_comb"].cat.codes

split_idx_to_covariates = (
    df[["global_split_covariates_mask", *split_covariates]]
    .groupby(["global_split_covariates_mask"])
    .first()
    .to_dict(orient="index")
)
split_idx_to_covariates = {
    k: tuple(v[s] for s in split_covariates) for k, v in split_idx_to_covariates.items()
}

perturbation_idx_to_covariates = (
    df[["global_permutation_cov_mask", *perturbation_covariates_keys, *split_covariates]]
    .groupby(["global_permutation_cov_mask"])
    .first()
    .to_dict(orient="index")
)
perturbation_idx_to_covariates = {
    k: [v[s] for s in [*perturbation_covariates_keys, *split_covariates]]
    for k, v in perturbation_idx_to_covariates.items()
}

control_to_perturbation = (
    df[~df.control].groupby(["global_split_covariates_mask"])["perturbation_covariates_mask"].unique()
)
control_to_perturbation = control_to_perturbation.to_dict()

split_covariates_mask = jnp.asarray(df["split_covariates_mask"].values)
perturbation_covariates_mask = jnp.asarray(df["perturbation_covariates_mask"].values)

You did not provide metadata, so Dask is running your function on a small dataset to guess output types. It is possible that Dask will guess incorrectly.
To provide an explicit output types or to silence this message, please provide the `meta=` keyword, as described in the map or apply function that you are using.
  Before: .apply(func)
  After:  .apply(func, meta=(None, 'object'))

You did not provide metadata, so Dask is running your function on a small dataset to guess output types. It is possible that Dask will guess incorrectly.
To provide an explicit output types or to silence this message, please provide the `meta=` keyword, as described in the map or apply function that you are using.
  Before: .apply(func)
  After:  .apply(func, meta=(None, 'object'))



KeyError: 'pert_comb'

In [None]:
df = orig_df.copy()
split_covariates = dm.split_covariates
perturbation_covariates_keys = list(dm.perturbation_covariates.keys())

df['global_control_comb'] = df['cell_type'].astype(str) + '_' + df['batch'].astype(str)
df['global_control_comb'] = df['global_control_comb'].astype('category')
df['global_split_covariates_mask'] = df['global_control_comb'].cat.codes
df['pert_comb'] = df['cell_type'].astype(str) + '_' + df['batch'].astype(str) + '_' + df['drug'].astype(str) + '_' + df['dosage'].astype(str)
df.loc[~df.control, 'control_comb'] = np.nan
df.loc[df.control, 'pert_comb'] = np.nan
df['control_cov_comb'] = df['control_cov_comb'].astype('category')
df['pert_comb'] = df['pert_comb'].astype('category')
df['split_covariates_mask'] = df['control_cov_comb'].cat.codes
df['perturbation_covariates_mask'] = df['pert_comb'].cat.codes

split_idx_to_covariates = df[['split_covariates_mask','cell_type','batch']].groupby(['split_covariates_mask']).first().to_dict(orient='index')
del split_idx_to_covariates[-1]
split_idx_to_covariates = {k: tuple(v[s] for s in split_covariates) for k,v in split_idx_to_covariates.items()}

perturbation_idx_to_covariates = df[['perturbation_covariates_mask','drug', 'dosage', 'cell_type','batch']].groupby(['perturbation_covariates_mask']).first().to_dict(orient='index')
del perturbation_idx_to_covariates[-1]
perturbation_idx_to_covariates = {k: [v[s] for s in [*perturbation_covariates_keys,*split_covariates]] for k,v in perturbation_idx_to_covariates.items()}

control_to_perturbation = df[~df.control].groupby(['global_split_covariates_mask'])['perturbation_covariates_mask'].unique()
control_to_perturbation = control_to_perturbation.to_dict()

split_covariates_mask = jnp.asarray(df['split_covariates_mask'].values)
perturbation_covariates_mask = jnp.asarray(df['perturbation_covariates_mask'].values)

In [None]:
split_idx_to_covariates = df[['split_covariates_mask','cell_type','batch']].groupby(['split_covariates_mask']).first().to_dict(orient='index')
del split_idx_to_covariates[-1]
split_idx_to_covariates = {k: tuple(v[s] for s in split_covariates) for k,v in split_idx_to_covariates.items()}
split_idx_to_covariates == condition_data.split_idx_to_covariates

False

In [None]:
perturbation_idx_to_covariates = df[['perturbation_covariates_mask','drug', 'dosage', 'cell_type','batch']].groupby(['perturbation_covariates_mask']).first().to_dict(orient='index')
del perturbation_idx_to_covariates[-1]
perturbation_idx_to_covariates = {k: [v[s] for s in [*perturbation_covariates_keys,*split_covariates]] for k,v in perturbation_idx_to_covariates.items()}
# make assertions
for k in perturbation_idx_to_covariates.keys():
    for v1,v2 in zip(perturbation_idx_to_covariates[k],condition_data.perturbation_idx_to_covariates[k], strict=True):
        assert v1 == v2

In [None]:
split_idx_to_covariates

{}

In [None]:
for _,pert_cov in df[[*perturbation_covariates_keys,*split_covariates]].drop_duplicates().reset_index().iterrows():
    print(pert_cov)

index                  0
drug             control
dosage               0.0
cell_type    cell_line_a
batch            batch_9
Name: 0, dtype: object
index                 40
drug               drug1
dosage          0.333333
cell_type    cell_line_a
batch            batch_9
Name: 1, dtype: object
index                 90
drug               drug1
dosage          0.666667
cell_type    cell_line_a
batch            batch_9
Name: 2, dtype: object
index                140
drug               drug1
dosage               1.0
cell_type    cell_line_a
batch            batch_9
Name: 3, dtype: object
index                190
drug               drug2
dosage          0.333333
cell_type    cell_line_a
batch            batch_9
Name: 4, dtype: object
index                240
drug               drug2
dosage          0.666667
cell_type    cell_line_a
batch            batch_9
Name: 5, dtype: object
index                290
drug               drug2
dosage               1.0
cell_type    cell_line_a
batch       

In [None]:
# append the results of groupby to the original dataframe
control_to_perturbation = df[~df.control].groupby(['global_split_covariates_mask'])['perturbation_covariates_mask'].unique()
control_to_perturbation = control_to_perturbation.to_dict()
for k,v in control_to_perturbation.items():
    assert (v == condition_data.control_to_perturbation[k]).all()

NameError: name 'df' is not defined

In [None]:
# create the control to perturbation mapping for each control covariate
# from the dataframe
control_to_perturbation = {}


KeyError: 'cell_line_a_batch_23'

In [None]:
perturbation_idx_to_covariates

{0: ['drug1', 0.3333333333333333, 'cell_line_a', 'batch_9'],
 1: ['drug1', 0.6666666666666666, 'cell_line_a', 'batch_9'],
 2: ['drug1', 1.0, 'cell_line_a', 'batch_9'],
 3: ['drug2', 0.3333333333333333, 'cell_line_a', 'batch_9'],
 4: ['drug2', 0.6666666666666666, 'cell_line_a', 'batch_9'],
 5: ['drug2', 1.0, 'cell_line_a', 'batch_9'],
 6: ['drug3', 0.3333333333333333, 'cell_line_a', 'batch_9'],
 7: ['drug3', 0.6666666666666666, 'cell_line_a', 'batch_9'],
 8: ['drug3', 1.0, 'cell_line_a', 'batch_9'],
 9: ['drug4', 0.3333333333333333, 'cell_line_a', 'batch_9'],
 10: ['drug4', 0.6666666666666666, 'cell_line_a', 'batch_9'],
 11: ['drug4', 1.0, 'cell_line_a', 'batch_9'],
 12: ['drug5', 0.3333333333333333, 'cell_line_a', 'batch_9'],
 13: ['drug5', 0.6666666666666666, 'cell_line_a', 'batch_9'],
 14: ['drug5', 1.0, 'cell_line_a', 'batch_9'],
 15: ['drug6', 0.3333333333333333, 'cell_line_a', 'batch_9'],
 16: ['drug6', 0.6666666666666666, 'cell_line_a', 'batch_9'],
 17: ['drug6', 1.0, 'cell_line_

In [None]:
def _to_list(x):
    """Converts x to a list if it is not already a list or tuple."""
    if isinstance(x, (list | tuple)):
        return x
    return [x]
perturb_covariates={
    k: _to_list(v)
    for k, v in dm._perturbation_covariates.items()
}

In [None]:
perturbation_idx_to_covariates

{0: ['drug1', 0.3333333333333333, 'cell_line_a', 'batch_9'],
 1: ['drug1', 0.6666666666666666, 'cell_line_a', 'batch_9'],
 2: ['drug1', 1.0, 'cell_line_a', 'batch_9'],
 3: ['drug2', 0.3333333333333333, 'cell_line_a', 'batch_9'],
 4: ['drug2', 0.6666666666666666, 'cell_line_a', 'batch_9'],
 5: ['drug2', 1.0, 'cell_line_a', 'batch_9'],
 6: ['drug3', 0.3333333333333333, 'cell_line_a', 'batch_9'],
 7: ['drug3', 0.6666666666666666, 'cell_line_a', 'batch_9'],
 8: ['drug3', 1.0, 'cell_line_a', 'batch_9'],
 9: ['drug4', 0.3333333333333333, 'cell_line_a', 'batch_9'],
 10: ['drug4', 0.6666666666666666, 'cell_line_a', 'batch_9'],
 11: ['drug4', 1.0, 'cell_line_a', 'batch_9'],
 12: ['drug5', 0.3333333333333333, 'cell_line_a', 'batch_9'],
 13: ['drug5', 0.6666666666666666, 'cell_line_a', 'batch_9'],
 14: ['drug5', 1.0, 'cell_line_a', 'batch_9'],
 15: ['drug6', 0.3333333333333333, 'cell_line_a', 'batch_9'],
 16: ['drug6', 0.6666666666666666, 'cell_line_a', 'batch_9'],
 17: ['drug6', 1.0, 'cell_line_

In [None]:
pert_covs = df[['drug', 'dosage', 'cell_type','batch']].drop_duplicates().iloc[0]

dm._get_perturbation_covariates(
    condition_data=pert_covs,
    rep_dict=dm.adata.uns,
    perturb_covariates=perturb_covariates,
)

drug


{'drug': Array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0.]], dtype=float32),
 'dosage': Array([[0.]], dtype=float32),
 'cell_type': Array([[ 1.6821451 , -2.0315406 , -0.68855745, -0.54416126,  2.4167037 ,
         -1.9938573 , -0.9390149 ,  2.5959558 ,  1.3020833 ,  0.03520354,
         -0.02320502,  0.92716867,  1.9261419 ,  1.118803  , -0.219788  ,
         -0.7652571 , -0.05078882, -0.52959687,  0.07656004, -1.2925221 ,
          0.30351853,  0.18040532,  0.05841567, -0.38600418, -0.44474316,
         -0.33843622,  2.0356743 ,  0.99912345,  0.5590456 ,  0.21152794,
         -0.29966092,  1.3150667 , -0.0958819 , -0.25833365, -0.4785732 ,
          0.63827723,  0.252799  ,  1.2900095 ,  1.8067259 ,  0.5567474 ,
         -0.7481424 ,  0.9808109 ,  0.11897202, -0.20283894, -0.89784575,
          1.192107

In [None]:
pert_covs

Unnamed: 0,drug,dosage,cell_type,batch
0,control,0.000000,cell_line_a,batch_9
40,drug1,0.333333,cell_line_a,batch_9
90,drug1,0.666667,cell_line_a,batch_9
140,drug1,1.000000,cell_line_a,batch_9
190,drug2,0.333333,cell_line_a,batch_9
...,...,...,...,...
27950,drug5,0.666667,cell_line_~,batch_6
28000,drug5,1.000000,cell_line_~,batch_6
28050,drug6,0.333333,cell_line_~,batch_6
28100,drug6,0.666667,cell_line_~,batch_6


True

In [None]:
import pandas as pd
import numpy as np
import dask.dataframe as dd

# Start with a fresh copy
df = orig_df.copy()

# 1. Create split_covariates_mask for control cells
# Get unique split combinations and assign sequential IDs
split_groups = (df[df['control']]
                [['cell_type', 'batch']]
                .drop_duplicates()
                .reset_index(drop=True)
                .reset_index()
                .rename(columns={'index': 'split_id'}))

# 2. Create perturbation_covariates_mask for non-control cells
# Process one split at a time to maintain original order
pert_list = []
for _, split_comb in split_groups.iterrows():
    # Get perturbations for this split in their original order
    split_perts = (df[
        (df['cell_type'] == split_comb['cell_type']) & 
        (df['batch'] == split_comb['batch']) &
        ~df['control']
    ][['cell_type', 'batch', 'drug', 'dosage']]
     .drop_duplicates())

    if not split_perts.empty:
        pert_list.append(split_perts)

# Combine all perturbations maintaining order
ordered_perts = (pd.concat(pert_list, ignore_index=True)
                .reset_index(drop=True)
                .reset_index()
                .rename(columns={'index': 'pert_id'}))

# 3. Merge the masks back to main dataframe
df = (df.merge(split_groups, on=['cell_type', 'batch'], how='left')
      .assign(split_covariates_mask=lambda x: 
              x['split_id'].where(x['control'], -1))
      .drop(columns=['split_id']))

df = (df.merge(ordered_perts,
               on=['cell_type', 'batch', 'drug', 'dosage'],
               how='left')
      .assign(perturbation_covariates_mask=lambda x: 
              x['pert_id'].where(~x['control'], -1))
      .drop(columns=['pert_id']))

# 4. Convert to correct integer type
df['split_covariates_mask'] = df['split_covariates_mask'].fillna(-1).astype('int32')
df['perturbation_covariates_mask'] = df['perturbation_covariates_mask'].fillna(-1).astype('int32')

# Verify masks
assert (df['perturbation_covariates_mask'].values == condition_data.perturbation_covariates_mask).all()
assert (df['split_covariates_mask'].values == condition_data.split_covariates_mask).all()

In [None]:
import dask.dataframe as dd

# Start with a fresh copy and convert to dask
ddf = dd.from_pandas(orig_df, npartitions=8)  # Adjust partition count based on available cores

# Extract control and non-control data once
control_mask = ddf['control']
control_data = ddf[control_mask].persist()
non_control_data = ddf[~control_mask].persist()

# 1. Process split covariates mask (for control cells)
split_groups = (control_data[['cell_type', 'batch']]
                .drop_duplicates()
                .reset_index(drop=True)
                .compute())  # Compute for sequential IDs
split_groups = split_groups.reset_index().rename(columns={'index': 'split_id'})

# 2. Process perturbation covariates mask more efficiently
# Group by cell_type and batch first to reduce computation
grouped_non_control = non_control_data.groupby(['cell_type', 'batch'])
                     
# Create a function to get perturbations for each group
def get_perturbations(group_df):
    return group_df[['cell_type', 'batch', 'drug', 'dosage']].drop_duplicates()

# Apply to each group and collect results
pert_groups = grouped_non_control.apply(get_perturbations, 
                                       meta=non_control_data[['cell_type', 'batch', 'drug', 'dosage']])
ordered_perts = pert_groups.compute()
ordered_perts = ordered_perts.reset_index(drop=True).reset_index()
ordered_perts = ordered_perts.rename(columns={'index': 'pert_id'})

# 3. Map results back efficiently using map_partitions
def assign_masks(df, splits, perts):
    # Use fast dictionary lookups instead of merges
    split_dict = {(r['cell_type'], r['batch']): r['split_id'] 
                 for _, r in splits.iterrows()}
    pert_dict = {(r['cell_type'], r['batch'], r['drug'], r['dosage']): r['pert_id'] 
                for _, r in perts.iterrows()}
    
    # Create lookup keys
    df['split_key'] = list(zip(df['cell_type'], df['batch']))
    df['pert_key'] = list(zip(df['cell_type'], df['batch'], df['drug'], df['dosage']))
    
    # Apply masks
    df['split_covariates_mask'] = -1
    mask = df['control']
    df.loc[mask, 'split_covariates_mask'] = df.loc[mask, 'split_key'].map(split_dict)
    
    df['perturbation_covariates_mask'] = -1
    mask = ~df['control']
    df.loc[mask, 'perturbation_covariates_mask'] = df.loc[mask, 'pert_key'].map(pert_dict)
    
    # Clean up and convert types
    df = df.drop(['split_key', 'pert_key'], axis=1)
    df['split_covariates_mask'] = df['split_covariates_mask'].fillna(-1).astype('int32')
    df['perturbation_covariates_mask'] = df['perturbation_covariates_mask'].fillna(-1).astype('int32')
    return df

# Apply the function to each partition
final_ddf = ddf.map_partitions(assign_masks, splits=split_groups, perts=ordered_perts)

# Compute the final result
df = final_ddf.compute()

  self._meta = self.obj._meta.groupby(


AssertionError: 

In [None]:
import pandas as pd
import numpy as np
import dask.dataframe as dd

# Start with a fresh copy
df = orig_df.copy()

# 1. Create split_covariates_mask for control cells
# Get unique split combinations and assign sequential IDs
split_groups = (df[df['control']]
                [['cell_type', 'batch']]
                .drop_duplicates()
                .reset_index(drop=True)
                .reset_index()
                .rename(columns={'index': 'split_id'}))

# 2. Create perturbation_covariates_mask for non-control cells
# Process one split at a time to maintain original order
pert_list = []
for _, split_comb in split_groups.iterrows():
    # Get perturbations for this split in their original order
    split_perts = (df[
        (df['cell_type'] == split_comb['cell_type']) & 
        (df['batch'] == split_comb['batch']) &
        ~df['control']
    ][['cell_type', 'batch', 'drug', 'dosage']]
     .drop_duplicates())
    
    if not split_perts.empty:
        pert_list.append(split_perts)

# Combine all perturbations maintaining order
ordered_perts = (pd.concat(pert_list, ignore_index=True)
                .reset_index(drop=True)
                .reset_index()
                .rename(columns={'index': 'pert_id'}))

# 3. Merge the masks back to main dataframe
df = (df.merge(split_groups, on=['cell_type', 'batch'], how='left')
      .assign(split_covariates_mask=lambda x: 
              x['split_id'].where(x['control'], -1))
      .drop(columns=['split_id']))

df = (df.merge(ordered_perts,
               on=['cell_type', 'batch', 'drug', 'dosage'],
               how='left')
      .assign(perturbation_covariates_mask=lambda x: 
              x['pert_id'].where(~x['control'], -1))
      .drop(columns=['pert_id']))

# 4. Convert to correct integer type
df['split_covariates_mask'] = df['split_covariates_mask'].fillna(-1).astype('int32')
df['perturbation_covariates_mask'] = df['perturbation_covariates_mask'].fillna(-1).astype('int32')

# Verify masks
assert (df['perturbation_covariates_mask'].values == condition_data.perturbation_covariates_mask).all()
assert (df['split_covariates_mask'].values == condition_data.split_covariates_mask).all()

In [None]:
# Convert to Dask DataFrame
ddf = dd.from_pandas(orig_df, npartitions=4)  # Adjust npartitions as needed

# Compute unique combinations
split_groups = (ddf[ddf['control']]
                [['cell_type', 'batch']]
                .drop_duplicates()
                .compute()
                .reset_index(drop=True)
                .reset_index()
                .rename(columns={'index': 'split_id'}))

# Process perturbations split by split
pert_list = []
for _, split_comb in split_groups.iterrows():
    split_perts = (ddf[
        (ddf['cell_type'] == split_comb['cell_type']) & 
        (ddf['batch'] == split_comb['batch']) &
        ~ddf['control']
    ][['cell_type', 'batch', 'drug', 'dosage']]
     .drop_duplicates()
     .compute())
    
    if not split_perts.empty:
        pert_list.append(split_perts)

# Create ordered perturbation IDs
ordered_perts = (pd.concat(pert_list, ignore_index=True)
                .reset_index(drop=True)
                .reset_index()
                .rename(columns={'index': 'pert_id'}))

# Convert to Dask DataFrames for merging
split_ddf = dd.from_pandas(split_groups, npartitions=1)
pert_ddf = dd.from_pandas(ordered_perts, npartitions=1)

# Merge and assign values
result = ddf.merge(split_ddf, on=['cell_type', 'batch'], how='left')
result = result.merge(pert_ddf, on=['cell_type', 'batch', 'drug', 'dosage'], how='left')
result = result.assign(
    split_covariates_mask=lambda x: x['split_id'].where(x['control'], -1),
    perturbation_covariates_mask=lambda x: x['pert_id'].where(~x['control'], -1)
)
result = result.drop(['split_id', 'pert_id'], axis=1)

# Convert to integer type and compute
df = result.astype({
    'split_covariates_mask': 'int32',
    'perturbation_covariates_mask': 'int32'
}).compute()

# Final verification
assert (df['perturbation_covariates_mask'].values == condition_data.perturbation_covariates_mask).all()
assert (df['split_covariates_mask'].values == condition_data.split_covariates_mask).all()

AssertionError: 

In [None]:
# Pure Dask version
ddf = dd.from_pandas(orig_df, npartitions=4)

# Create keys
ddf = ddf.assign(
    split_key=lambda x: x['cell_type'].astype(str) + '_' + x['batch'].astype(str),
    pert_key=lambda x: x['cell_type'].astype(str) + '_' + x['batch'].astype(str) + 
              '_' + x['drug'].astype(str) + '_' + x['dosage'].astype(str)
)

# Force compute unique keys
split_keys = ddf[ddf['control']]['split_key'].unique().compute()
split_map = {key: idx for idx, key in enumerate(sorted(split_keys))}

pert_keys = ddf[~ddf['control']]['pert_key'].unique().compute()
pert_map = {key: idx for idx, key in enumerate(sorted(pert_keys))}

# Apply mapping
ddf = ddf.assign(
    split_covariates_mask=lambda x: x.apply(
        lambda row: split_map.get(row['split_key'], -1) if row['control'] else -1,
        axis=1,
        meta=('split_covariates_mask', 'int32')
    ),
    perturbation_covariates_mask=lambda x: x.apply(
        lambda row: pert_map.get(row['pert_key'], -1) if not row['control'] else -1,
        axis=1,
        meta=('perturbation_covariates_mask', 'int32')
    )
)

# Convert to proper types
ddf = ddf.drop(['split_key', 'pert_key'], axis=1)

# Compute final result
df = ddf.compute()

# Verify results
assert (df['split_covariates_mask'].values == condition_data.split_covariates_mask).all()
assert (df['perturbation_covariates_mask'].values == condition_data.perturbation_covariates_mask).all()

AssertionError: 

In [None]:
len(df.loc[df.control, ['cell_type', 'batch']].drop_duplicates()) == len(df.loc[~df.control, 'control_celltype_batch'].drop_duplicates())

True

In [None]:
len(df.loc[df.control, 'control_celltype_batch'].drop_duplicates())

694

In [None]:
import dask.dataframe as dd
import pandas as pd

# Assume orig_df is your large pandas DataFrame.
ddf = dd.from_pandas(orig_df, npartitions=10)

# Create a key for control rows:
ddf = ddf.assign(
    key_control=ddf[['cell_type', 'batch']]
                   .astype(str)
                   .apply(lambda row: '_'.join(row), axis=1, meta=('key_control', 'object'))
)

# Compute unique keys for control cells:
unique_control = ddf[ddf['control']].key_control.unique().compute()
control_map = {key: idx for idx, key in enumerate(unique_control)}

# Function to assign split_covariates_mask using the precomputed map:
def assign_split_mask(df, mapping):
    mask = df['control']
    keys = df[['cell_type', 'batch']].astype(str).agg('_'.join, axis=1)
    df.loc[mask, 'split_covariates_mask'] = keys[mask].map(mapping)
    return df

ddf = ddf.map_partitions(assign_split_mask, mapping=control_map)

# Similarly, for perturbations:
ddf = ddf.assign(
    key_pert=ddf[['cell_type', 'batch', 'drug', 'dosage']]
                   .astype(str)
                   .apply(lambda row: '_'.join(row), axis=1, meta=('key_control', 'object'))
)
# ddf = ddf.assign(key_control = ddf[['cell_type', 'batch', 'drug', 'dosage']].astype(str).agg('_'.join, axis=1))


unique_pert = ddf[~ddf['control']].key_pert.unique().compute()
pert_map = {key: idx for idx, key in enumerate(unique_pert)}

def assign_pert_mask(df, mapping):
    mask = ~df['control']
    keys = df[['cell_type', 'batch', 'drug', 'dosage']].astype(str).agg('_'.join, axis=1)
    df.loc[mask, 'perturbation_covariates_mask'] = keys[mask].map(mapping)
    return df

ddf = ddf.map_partitions(assign_pert_mask, mapping=pert_map)

# Finally, you may want to convert the new columns to int32
# ddf = ddf.astype({'split_covariates_mask': 'int32',
#                   'perturbation_covariates_mask': 'int32'})

ddf['split_covariates_mask'] = ddf['split_covariates_mask'].fillna(-1).astype('int32')
ddf['perturbation_covariates_mask'] = ddf['perturbation_covariates_mask'].fillna(-1).astype('int32')
# To materialize the result (if needed):
result_df = ddf.compute()

assert (result_df['split_covariates_mask'].values == condition_data.split_covariates_mask).all()

AssertionError: 

In [None]:
df = orig_df

Unnamed: 0,cell_type,batch,drug,dosage,control,key_control,split_covariates_mask,key_pert,perturbation_covariates_mask
0,cell_line_a,batch_18,control,0.000000,True,cell_line_a_batch_18,336,cell_line_a_batch_18_control_0.0,-1
1,cell_line_a,batch_21,control,0.000000,True,cell_line_a_batch_21,607,cell_line_a_batch_21_control_0.0,-1
10,cell_line_a,batch_30,control,0.000000,True,cell_line_a_batch_30,268,cell_line_a_batch_30_control_0.0,-1
100,cell_line_a,batch_6,drug1,0.666667,False,cell_line_a_batch_6,-1,cell_line_a_batch_6_drug1_0.6666666666666666,1321
1000,cell_line_b,batch_18,drug1,0.333333,False,cell_line_b_batch_18,-1,cell_line_b_batch_18_drug1_0.3333333333333333,6677
...,...,...,...,...,...,...,...,...,...
9995,cell_line_k,batch_24,drug4,1.000000,False,cell_line_k_batch_24,-1,cell_line_k_batch_24_drug4_1.0,3987
9996,cell_line_k,batch_8,drug4,1.000000,False,cell_line_k_batch_8,-1,cell_line_k_batch_8_drug4_1.0,2664
9997,cell_line_k,batch_18,drug4,1.000000,False,cell_line_k_batch_18,-1,cell_line_k_batch_18_drug4_1.0,10525
9998,cell_line_k,batch_20,drug4,1.000000,False,cell_line_k_batch_20,-1,cell_line_k_batch_20_drug4_1.0,8039


In [None]:
result_df

Unnamed: 0,cell_type,batch,drug,dosage,control,key_control,split_covariates_mask,key_pert,perturbation_covariates_mask
0,cell_line_a,batch_18,control,0.000000,True,cell_line_a_batch_18,336.0,cell_line_a_batch_18_control_0.0,
1,cell_line_a,batch_21,control,0.000000,True,cell_line_a_batch_21,607.0,cell_line_a_batch_21_control_0.0,
10,cell_line_a,batch_30,control,0.000000,True,cell_line_a_batch_30,268.0,cell_line_a_batch_30_control_0.0,
100,cell_line_a,batch_6,drug1,0.666667,False,cell_line_a_batch_6,,cell_line_a_batch_6_drug1_0.6666666666666666,1321.0
1000,cell_line_b,batch_18,drug1,0.333333,False,cell_line_b_batch_18,,cell_line_b_batch_18_drug1_0.3333333333333333,6677.0
...,...,...,...,...,...,...,...,...,...
9995,cell_line_k,batch_24,drug4,1.000000,False,cell_line_k_batch_24,,cell_line_k_batch_24_drug4_1.0,3987.0
9996,cell_line_k,batch_8,drug4,1.000000,False,cell_line_k_batch_8,,cell_line_k_batch_8_drug4_1.0,2664.0
9997,cell_line_k,batch_18,drug4,1.000000,False,cell_line_k_batch_18,,cell_line_k_batch_18_drug4_1.0,10525.0
9998,cell_line_k,batch_20,drug4,1.000000,False,cell_line_k_batch_20,,cell_line_k_batch_20_drug4_1.0,8039.0


In [None]:
import numpy as np
import pandas as pd

# Assume df is a copy of orig_df (and control is a boolean column)
df = orig_df.copy()

# Create structured array for control rows:
control_mask = df['control'].values
# Build a structured array for 'cell_type' and 'batch'
control_keys = np.empty(len(df), dtype=[('cell_type', df['cell_type'].dtype),
                                          ('batch', df['batch'].dtype)])
control_keys['cell_type'] = df['cell_type'].values
control_keys['batch'] = df['batch'].values

# Compute unique keys and group IDs for control rows:
unique_control, split_ids = np.unique(control_keys[control_mask], return_inverse=True)
split_covariates_mask = -1 * np.ones(len(df), dtype=np.int32)
split_covariates_mask[control_mask] = split_ids

# For perturbation rows, create a structured array with 4 fields:
pert_mask = ~control_mask
pert_keys = np.empty(len(df), dtype=[('cell_type', df['cell_type'].dtype),
                                     ('batch', df['batch'].dtype),
                                     ('drug', df['drug'].dtype),
                                     ('dosage', df['dosage'].dtype)])
pert_keys['cell_type'] = df['cell_type'].values
pert_keys['batch'] = df['batch'].values
pert_keys['drug'] = df['drug'].values
pert_keys['dosage'] = df['dosage'].values

# Compute unique keys and group IDs for perturbation rows:
unique_pert, pert_ids = np.unique(pert_keys[pert_mask], return_inverse=True)
perturbation_covariates_mask = -1 * np.ones(len(df), dtype=np.int32)
perturbation_covariates_mask[pert_mask] = pert_ids

# Assign the new columns back to the DataFrame:
df['split_covariates_mask'] = split_covariates_mask
df['perturbation_covariates_mask'] = perturbation_covariates_mask


TypeError: Cannot interpret 'CategoricalDtype(categories=['cell_line_a', 'cell_line_b', 'cell_line_c', 'cell_line_d',
                  'cell_line_e', 'cell_line_f', 'cell_line_g', 'cell_line_h',
                  'cell_line_i', 'cell_line_j', 'cell_line_k', 'cell_line_l',
                  'cell_line_m', 'cell_line_n', 'cell_line_o', 'cell_line_p',
                  'cell_line_q', 'cell_line_r', 'cell_line_s', 'cell_line_t',
                  'cell_line_u', 'cell_line_v', 'cell_line_w', 'cell_line_x',
                  'cell_line_y', 'cell_line_z', 'cell_line_{', 'cell_line_|',
                  'cell_line_}', 'cell_line_~'],
, ordered=False, categories_dtype=object)' as a data type

In [None]:
# works for both
# Start with a fresh copy
df = orig_df.copy()

# Initialize with -1 using assign
df = df.assign(
    split_covariates_mask=-1,
    perturbation_covariates_mask=-1
)

# First get unique split combinations
split_combinations = (df[['cell_type', 'batch']]
                     .drop_duplicates()
                     .reset_index(drop=True))

# Initialize counters
split_counter = 0
pert_counter = 0
split_id_map = {}
pert_id_map = {}

# Process each split combination
for _, split_comb in split_combinations.iterrows():
    # Get mask for this split combination
    split_mask = (
        (df['cell_type'] == split_comb['cell_type']) & 
        (df['batch'] == split_comb['batch'])
    )
    
    # Update split mask for control cells in this split
    control_mask = split_mask & df['control']
    if control_mask.any():
        split_id_map[split_counter] = (split_comb['cell_type'], split_comb['batch'])
        df.loc[control_mask, 'split_covariates_mask'] = split_counter
    
    # Get perturbations for this split
    pert_mask = split_mask & ~df['control']
    if pert_mask.any():
        # Get unique perturbations in this split
        perts = (df[pert_mask][['cell_type', 'batch', 'drug', 'dosage']]
                 .drop_duplicates()
                 .reset_index(drop=True))
        
        # Process each perturbation
        for _, pert in perts.iterrows():
            # Get mask for this perturbation
            full_pert_mask = (
                (df['cell_type'] == pert['cell_type']) & 
                (df['batch'] == pert['batch']) &
                (df['drug'] == pert['drug']) &
                (df['dosage'] == pert['dosage']) &
                ~df['control']
            )
            
            # Update perturbation mask
            pert_id_map[pert_counter] = tuple(pert)
            df.loc[full_pert_mask, 'perturbation_covariates_mask'] = pert_counter
            pert_counter += 1
    split_counter += 1
    

# Ensure integer types
df['split_covariates_mask'] = df['split_covariates_mask'].astype('int32')
df['perturbation_covariates_mask'] = df['perturbation_covariates_mask'].astype('int32')

In [None]:
# Start with a fresh copy
df = orig_df.copy()

# Initialize with -1 using assign
df = df.assign(
    split_covariates_mask=-1,
    perturbation_covariates_mask=-1
)

# Get unique split combinations and assign IDs
split_combinations = (df[['cell_type', 'batch']]
                     .drop_duplicates()
                     .reset_index(drop=True)
                     .reset_index()
                     .rename(columns={'index': 'split_id'}))


# Merge split IDs back to main DataFrame - but only for control cells
df = (df.merge(split_combinations, on=['cell_type', 'batch'], how='left')
      .assign(split_covariates_mask=lambda x: 
              np.where(x['control'], x['split_id'], -1))
      .drop(columns=['split_id']))



# Ensure integer types
df['split_covariates_mask'] = df['split_covariates_mask'].astype('int32')

# Create the mapping dictionaries if needed
split_id_map = (split_combinations
                .set_index('split_id')[['cell_type', 'batch']]
                .apply(tuple, axis=1)
                .to_dict())

assert (df['split_covariates_mask'].values == condition_data.split_covariates_mask).all()

In [None]:
# works for both
# Start with a fresh copy
df = orig_df.copy()

# Initialize with -1 using assign
df = df.assign(
    perturbation_covariates_mask=-1
)

# First get unique split combinations
split_combinations = (df[['cell_type', 'batch']]
                     .drop_duplicates()
                     .reset_index(drop=True))

# Initialize counters
pert_counter = 0
pert_id_map = {}

# Process each split combination
for _, split_comb in split_combinations.iterrows():
    # Get mask for this split combination
    split_mask = (
        (df['cell_type'] == split_comb['cell_type']) & 
        (df['batch'] == split_comb['batch'])
    )

    # Get perturbations for this split
    pert_mask = split_mask & ~df['control']
    if pert_mask.any():
        # Get unique perturbations in this split
        perts = (df[pert_mask][['cell_type', 'batch', 'drug', 'dosage']]
                 .drop_duplicates()
                 .reset_index(drop=True))
        # Process each perturbation
        for _, pert in perts.iterrows():
            # Get mask for this perturbation
            full_pert_mask = (
                (df['cell_type'] == pert['cell_type']) & 
                (df['batch'] == pert['batch']) &
                (df['drug'] == pert['drug']) &
                (df['dosage'] == pert['dosage']) &
                ~df['control']
            )
            # Update perturbation mask
            pert_id_map[pert_counter] = tuple(pert)
            df.loc[full_pert_mask, 'perturbation_covariates_mask'] = pert_counter
            pert_counter += 1


df['perturbation_covariates_mask'] = df['perturbation_covariates_mask'].astype('int32')
assert (df['perturbation_covariates_mask'].values == condition_data.perturbation_covariates_mask).all()

In [None]:
# Start with a fresh copy
df = orig_df.copy()

# Initialize with -1 using assign
df = df.assign(
    perturbation_covariates_mask=-1
)

# First get unique split combinations with their order
split_combinations = (df[['cell_type', 'batch']]
                     .drop_duplicates()
                     .reset_index(drop=True)
                     .reset_index()
                     .rename(columns={'index': 'split_order'}))

# Get all unique perturbation combinations with their split info
all_perts = (df[~df['control']]
             [['cell_type', 'batch', 'drug', 'dosage']]
             .drop_duplicates())

# Merge with split order to maintain the same ordering as original
ordered_perts = (all_perts
                .merge(split_combinations, on=['cell_type', 'batch'])
                .sort_values('split_order')
                .reset_index(drop=True)
                .reset_index()
                .rename(columns={'index': 'pert_id'})
                .drop(columns=['split_order']))

# Assign perturbation masks using a single merge operation
df = (df.merge(ordered_perts, 
               on=['cell_type', 'batch', 'drug', 'dosage'],
               how='left')
      .assign(perturbation_covariates_mask=lambda x: 
              x['pert_id'].where(~x['control'], -1))
      .drop(columns=['pert_id']))

# Ensure integer type
df['perturbation_covariates_mask'] = df['perturbation_covariates_mask'].astype('int32')
if not (df['perturbation_covariates_mask'].values == condition_data.perturbation_covariates_mask).all():
    the_diff = (df['perturbation_covariates_mask'].values == condition_data.perturbation_covariates_mask).astype(int)
    print(the_diff)
    print(np.where(the_diff == 0))
    print(df['perturbation_covariates_mask'].values)
    print(condition_data.perturbation_covariates_mask)


[1 1 1 ... 0 0 1]
(array([   41,    42,    43, ..., 28196, 28197, 28198]),)
[   -1    -1    -1 ... 12871 13200 12901]
[   -1    -1    -1 ... 12873 13201 12901]


In [None]:
# Start with a fresh copy
df = orig_df.copy()

# Initialize with -1 using assign
df = df.assign(
    perturbation_covariates_mask=-1
)

# First get unique split combinations in original order
split_combinations = (df[['cell_type', 'batch']]
                     .drop_duplicates()
                     .reset_index(drop=True))

# Process perturbations split by split but without loops
pert_list = []
for _, split_comb in split_combinations.iterrows():  # This loop is cheap - just over unique splits
    # Get perturbations for this split in their original order
    split_perts = (df[
        (df['cell_type'] == split_comb['cell_type']) & 
        (df['batch'] == split_comb['batch']) &
        ~df['control']
    ][['cell_type', 'batch', 'drug', 'dosage']]
     .drop_duplicates())
    
    if not split_perts.empty:
        pert_list.append(split_perts)

# Combine all perturbations maintaining order
ordered_perts = (pd.concat(pert_list, ignore_index=True)
                .reset_index(drop=True)
                .reset_index()
                .rename(columns={'index': 'pert_id'}))

# Assign perturbation masks using a single merge operation
df = (df.merge(ordered_perts, 
               on=['cell_type', 'batch', 'drug', 'dosage'],
               how='left')
      .assign(perturbation_covariates_mask=lambda x: 
              x['pert_id'].where(~x['control'], -1))
      .drop(columns=['pert_id']))

# Ensure integer type
df['perturbation_covariates_mask'] = df['perturbation_covariates_mask'].astype('int32')
assert (df['perturbation_covariates_mask'].values == condition_data.perturbation_covariates_mask).all()

In [None]:
import dask.dataframe as dd

# Start with a fresh copy and convert to Dask DataFrame
df = dd.from_pandas(orig_df, npartitions=4)  # adjust npartitions based on your data size

# Initialize with -1 using assign
df = df.assign(
    perturbation_covariates_mask=-1
)

# First get unique split combinations in original order
# Compute this since it's small and we need it for the loop
split_combinations = (df[['cell_type', 'batch']]
                     .drop_duplicates()
                     .compute()
                     .reset_index(drop=True))

# Process perturbations split by split
pert_list = []
for _, split_comb in split_combinations.iterrows():  # Still cheap - just over unique splits
    # Get perturbations for this split in their original order
    split_perts = (df[
        (df['cell_type'] == split_comb['cell_type']) & 
        (df['batch'] == split_comb['batch']) &
        ~df['control']
    ][['cell_type', 'batch', 'drug', 'dosage']]
     .drop_duplicates()
     .compute())  # compute here since we need it for the list
    
    if not split_perts.empty:
        pert_list.append(split_perts)

# Combine all perturbations maintaining order
ordered_perts = (dd.from_pandas(
    pd.concat(pert_list, ignore_index=True)
    .reset_index(drop=True)
    .reset_index()
    .rename(columns={'index': 'pert_id'}),
    npartitions=1))  # small DataFrame, one partition is fine

# Assign perturbation masks using a single merge operation
df = (df.merge(ordered_perts, 
               on=['cell_type', 'batch', 'drug', 'dosage'],
               how='left')
      .assign(perturbation_covariates_mask=lambda x: 
              x['pert_id'].where(~x['control'], -1))
      .drop(columns=['pert_id']))

# Ensure integer type
df['perturbation_covariates_mask'] = df['perturbation_covariates_mask'].astype('int32')

# If you need to verify against condition_data
result = df['perturbation_covariates_mask'].compute()
assert (result.values == condition_data.perturbation_covariates_mask).all()

AssertionError: 

In [None]:
import dask

In [None]:
condition_data.perturbation_covariates_mask

Array([   -1,    -1,    -1, ..., 12873, 13201, 12901], dtype=int32)

In [None]:
from dask.distributed import Client
client = Client()  # start distributed scheduler locally.

Perhaps you already have a cluster running?
Hosting the HTTP server on port 50158 instead


In [None]:
# Start with a fresh copy
df = orig_df.copy()

# Initialize with -1 using assign
df = df.assign(
    split_covariates_mask=-1,
    perturbation_covariates_mask=-1
)
df['split_covariates_mask'] = df['split_covariates_mask'].astype('int32')
# Get unique split combinations
split_combinations = (df[['cell_type', 'batch']]
                     .drop_duplicates()
                     .reset_index(drop=True)
                     .reset_index()
                     .rename(columns={'index': 'split_id'}))



df = (df.merge(split_combinations, on=['cell_type', 'batch'], how='left')
      .assign(split_covariates_mask=lambda x: 
              x['split_id'].where(x['control'], -1))
      .drop(columns=['split_id']))


ddf = dd.from_pandas(df, npartitions=1)
pert_list = []
for _, split_comb in split_combinations.iterrows():
    # Get perturbations for this split in their original order
    split_perts = (ddf[
        (ddf['cell_type'] == split_comb['cell_type']) &
        (ddf['batch'] == split_comb['batch']) &
        ~ddf['control']
    ][['cell_type', 'batch', 'drug', 'dosage']]
     .drop_duplicates())
    pert_list.append(split_perts)

pert_list = dask.compute(*pert_list)
# # filter out empty ones
pert_list = [x for x in pert_list if not x.empty]

# Combine all perturbations maintaining order
ordered_perts = (pd.concat(pert_list, ignore_index=True)
                .reset_index(drop=True)
                .reset_index()
                .rename(columns={'index': 'pert_id'}))
# Assign perturbation masks using a single merge operationb
df = (df.merge(ordered_perts,
               on=['cell_type', 'batch', 'drug', 'dosage'],
               how='left')
      .assign(perturbation_covariates_mask=lambda x: 
              x['pert_id'].where(~x['control'], -1).astype('int32'))
      .drop(columns=['pert_id']))


assert (df['perturbation_covariates_mask'].values == condition_data.perturbation_covariates_mask).all()
assert (df['split_covariates_mask'].values == condition_data.split_covariates_mask).all()

In [None]:
df

Unnamed: 0,cell_type,batch,drug,dosage,control,split_covariates_mask,perturbation_covariates_mask
0,cell_line_a,batch_13,control,0.0,True,0,-1
1,cell_line_a,batch_16,control,0.0,True,1,-1
2,cell_line_a,batch_9,control,0.0,True,2,-1
3,cell_line_a,batch_2,control,0.0,True,3,-1
4,cell_line_a,batch_7,control,0.0,True,4,-1
...,...,...,...,...,...,...,...
28195,cell_line_~,batch_25,drug6,1.0,False,-1,12906
28196,cell_line_~,batch_25,drug6,1.0,False,-1,12906
28197,cell_line_~,batch_11,drug6,1.0,False,-1,12846
28198,cell_line_~,batch_8,drug6,1.0,False,-1,12804


In [None]:
# Start with a fresh copy
df = orig_df.copy()

# Initialize with -1 using assign
df = df.assign(
    split_covariates_mask=-1,
    perturbation_covariates_mask=-1
)

# Get unique split combinations
split_combinations = (df[['cell_type', 'batch']]
                     .drop_duplicates()
                     .reset_index(drop=True)
                     .reset_index()
                     .rename(columns={'index': 'split_id'}))

# Create split_idx_to_covariates mapping
split_idx_to_covariates = (split_combinations
                          .set_index('split_id')
                          [['cell_type', 'batch']]
                          .apply(tuple, axis=1)
                          .to_dict())

df = (df.merge(split_combinations, on=['cell_type', 'batch'], how='left')
      .assign(split_covariates_mask=lambda x: 
              x['split_id'].where(x['control'], -1))
      .drop(columns=['split_id']))

ddf = dd.from_pandas(df, npartitions=1)
pert_list = []
for _, split_comb in split_combinations.iterrows():
    split_perts = (ddf[
        (ddf['cell_type'] == split_comb['cell_type']) &
        (ddf['batch'] == split_comb['batch']) &
        ~ddf['control']
    ][['cell_type', 'batch', 'drug', 'dosage']]
     .drop_duplicates())
    pert_list.append(split_perts)

pert_list = dask.compute(*pert_list)
pert_list = [x for x in pert_list if not x.empty]

# Combine all perturbations maintaining order
ordered_perts = (pd.concat(pert_list, ignore_index=True)
                .reset_index(drop=True)
                .reset_index()
                .rename(columns={'index': 'pert_id'}))

# Create perturbation_idx_to_covariates mapping
perturbation_idx_to_covariates = (ordered_perts
                                 .set_index('pert_id')
                                 [['cell_type', 'batch', 'drug', 'dosage']]
                                 .apply(tuple, axis=1)
                                 .to_dict())

# Create control_to_perturbation mapping
control_to_perturbation = {}
for split_id in split_idx_to_covariates.keys():
    cell_type, batch = split_idx_to_covariates[split_id]
    # Get perturbation IDs for this split
    matching_perts = ordered_perts[
        (ordered_perts['cell_type'] == cell_type) &
        (ordered_perts['batch'] == batch)
    ]['pert_id'].tolist()
    if matching_perts:
        control_to_perturbation[split_id] = matching_perts

# Assign perturbation masks using a single merge operation
df = (df.merge(ordered_perts,
               on=['cell_type', 'batch', 'drug', 'dosage'],
               how='left')
      .assign(perturbation_covariates_mask=lambda x: 
              x['pert_id'].where(~x['control'], -1))
      .drop(columns=['pert_id']))

# Verify masks
assert (df['perturbation_covariates_mask'].values == condition_data.perturbation_covariates_mask).all()
assert (df['split_covariates_mask'].values == condition_data.split_covariates_mask).all()

# Return all the mappings along with the DataFrame
a= {
    'df': df,
    'split_idx_to_covariates': split_idx_to_covariates,
    'perturbation_idx_to_covariates': perturbation_idx_to_covariates,
    'control_to_perturbation': control_to_perturbation
}

In [None]:
ddf = dd.from_pandas(orig_df, npartitions=2)

# Get unique combinations for splits (control cells)
split_combinations = (ddf[ddf['control']]
                    [['cell_type', 'batch']]
                    .drop_duplicates()
                    .compute()
                    .reset_index(drop=True))

# Create split mapping
split_idx_to_covariates = dict(enumerate(
    split_combinations[['cell_type', 'batch']].itertuples(index=False, name=None)
))

# Create a mapping DataFrame for splits
split_map_df = pd.DataFrame({
    'cell_type': [x[0] for x in split_idx_to_covariates.values()],
    'batch': [x[1] for x in split_idx_to_covariates.values()],
    'split_id': list(split_idx_to_covariates.keys())
})

# Convert to Dask
split_map_ddf = dd.from_pandas(split_map_df, npartitions=1)

# Merge to assign split IDs
ddf = ddf.merge(split_map_ddf, on=['cell_type', 'batch'], how='left')

# Get unique perturbation combinations
pert_combinations = (ddf[~ddf['control']]
                    [['cell_type', 'batch', 'drug', 'dosage']]
                    .drop_duplicates()
                    .compute()
                    .reset_index(drop=True)
                    .reset_index()
                    .rename(columns={'index': 'pert_id'}))

# Create perturbation mapping
perturbation_idx_to_covariates = dict(enumerate(
    pert_combinations[['cell_type', 'batch', 'drug', 'dosage']].itertuples(index=False, name=None)
))

# Create a mapping DataFrame for perturbations
pert_map_df = pert_combinations.copy()
pert_map_ddf = dd.from_pandas(pert_map_df, npartitions=1)

# Merge to assign perturbation IDs
ddf = ddf.merge(pert_map_ddf, 
                on=['cell_type', 'batch', 'drug', 'dosage'], 
                how='left')

# Assign masks using map_partitions
def assign_masks(df):
    df['split_covariates_mask'] = np.where(df['control'], 
                                            df['split_id'], 
                                            -1)
    df['perturbation_covariates_mask'] = np.where(~df['control'], 
                                                    df['pert_id'], 
                                                    -1)
    return df[['cell_type', 'batch', 'drug', 'dosage', 'control', 
                'split_covariates_mask', 'perturbation_covariates_mask']]

result = ddf.map_partitions(assign_masks)

# Create control_to_perturbation mapping
control_to_perturbation = (pert_combinations.groupby(['cell_type', 'batch'])['pert_id']
                            .agg(list)
                            .to_dict())

# Compute final result
df = result.compute()

# Ensure integer types
df['split_covariates_mask'] = df['split_covariates_mask'].astype('int32')
df['perturbation_covariates_mask'] = df['perturbation_covariates_mask'].astype('int32')


+----------------------------+------------+-------------+
| Merge columns              | left dtype | right dtype |
+----------------------------+------------+-------------+
| ('cell_type', 'cell_type') | category   | string      |
| ('batch', 'batch')         | category   | string      |
+----------------------------+------------+-------------+
Cast dtypes explicitly to avoid unexpected results.
+----------------------------+------------+-------------+
| Merge columns              | left dtype | right dtype |
+----------------------------+------------+-------------+
| ('cell_type', 'cell_type') | object     | string      |
| ('batch', 'batch')         | object     | string      |
+----------------------------+------------+-------------+
Cast dtypes explicitly to avoid unexpected results.


ValueError: Metadata inference failed in `assign_masks`.

You have supplied a custom function and Dask is unable to 
determine the type of output that that function returns. 

To resolve this please provide a meta= keyword.
The docstring of the Dask function you ran should have more information.

Original error is below:
------------------------
TypeError('boolean value of NA is ambiguous')

Traceback:
---------
  File "/Users/selman.ozleyen/mambaforge/envs/moscot/lib/python3.11/site-packages/dask/dataframe/utils.py", line 195, in raise_on_meta_error
    yield
  File "/Users/selman.ozleyen/mambaforge/envs/moscot/lib/python3.11/site-packages/dask_expr/_expr.py", line 3988, in _emulate
    return func(*_extract_meta(args, True), **_extract_meta(kwargs, True))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/var/folders/w4/rlbyb2md7y50tspf85v1lc440000gn/T/ipykernel_47592/4114690479.py", line 53, in assign_masks
    df['split_covariates_mask'] = np.where(df['control'],
                                 ^^^^^^^^^^^^^^^^^^^^^^^
  File "missing.pyx", line 392, in pandas._libs.missing.NAType.__bool__


In [None]:
# Start with a fresh copy
df = orig_df.copy()

# Initialize with -1 using assign
df = df.assign(
    perturbation_covariates_mask=-1
)

# First get unique split combinations in original order
split_combinations = (df[['cell_type', 'batch']]
                     .drop_duplicates()
                     .reset_index(drop=True))

# Process perturbations split by split but without loops
pert_list = []
for _, split_comb in split_combinations.iterrows():  # This loop is cheap - just over unique splits
    # Get perturbations for this split in their original order
    split_perts = (df[
        (df['cell_type'] == split_comb['cell_type']) & 
        (df['batch'] == split_comb['batch']) &
        ~df['control']
    ][['cell_type', 'batch', 'drug', 'dosage']]
     .drop_duplicates())
    
    if not split_perts.empty:
        pert_list.append(split_perts)

# Combine all perturbations maintaining order
ordered_perts = (pd.concat(pert_list, ignore_index=True)
                .reset_index(drop=True)
                .reset_index()
                .rename(columns={'index': 'pert_id'}))

# Assign perturbation masks using a single merge operation
df = (df.merge(ordered_perts, 
               on=['cell_type', 'batch', 'drug', 'dosage'],
               how='left')
      .assign(perturbation_covariates_mask=lambda x: 
              x['pert_id'].where(~x['control'], -1))
      .drop(columns=['pert_id']))

# Ensure integer type
df['perturbation_covariates_mask'] = df['perturbation_covariates_mask'].astype('int32')
assert (df['perturbation_covariates_mask'].values == condition_data.perturbation_covariates_mask).all()

In [None]:
assert (df['perturbation_covariates_mask'].values == condition_data.perturbation_covariates_mask).all()


AssertionError: 

In [None]:
assert (df['split_covariates_mask'].values == condition_data.split_covariates_mask).all()


In [None]:
print("Sample of original data:")
print(orig_df[['cell_type', 'batch', 'drug', 'dosage', 'control']].head(10))
print("\nUnique values in masks:")
print("split_covariates_mask unique:", df['split_covariates_mask'].unique())
print("perturbation_covariates_mask unique:", df['perturbation_covariates_mask'].unique())
print("\nComparison with condition_data:")
print("Matches in perturbation mask:", 
      (df['perturbation_covariates_mask'].values == condition_data.perturbation_covariates_mask).mean())

Sample of original data:
     cell_type     batch     drug  dosage  control
0  cell_line_a  batch_15  control     0.0     True
1  cell_line_a  batch_23  control     0.0     True
2  cell_line_a  batch_11  control     0.0     True
3  cell_line_a  batch_14  control     0.0     True
4  cell_line_a  batch_13  control     0.0     True
5  cell_line_a  batch_29  control     0.0     True
6  cell_line_a   batch_2  control     0.0     True
7  cell_line_a  batch_17  control     0.0     True
8  cell_line_a  batch_29  control     0.0     True
9  cell_line_a  batch_26  control     0.0     True

Unique values in masks:
split_covariates_mask unique: [  0   1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17
  18  19  20  21  22  23  24  -1  25  26  27  28  29  30  31  32  33  34
  35  36  37  38  39  40  41  42  43  44  45  46  47  48  49  50  51  52
  53  54  55  56  57  58  59  60  61  62  63  64  65  66  67  68  69  70
  71  72  73  74  75  76  77  78  79  80  81  82  83  84  85  86  8

In [None]:
print("First few rows comparison:")
print("Our mask:", df['perturbation_covariates_mask'].head(10))
print("Expected mask:", condition_data.perturbation_covariates_mask[:10])

print("\nShape check:")
print("Our mask shape:", df['perturbation_covariates_mask'].shape)
print("Expected mask shape:", condition_data.perturbation_covariates_mask.shape)

First few rows comparison:
Our mask: 0   -1
1   -1
2   -1
3   -1
4   -1
5   -1
6   -1
7   -1
8   -1
9   -1
Name: perturbation_covariates_mask, dtype: int32
Expected mask: [-1 -1 -1 -1 -1 -1 -1 -1 -1 -1]

Shape check:
Our mask shape: (28200,)
Expected mask shape: (28200,)


In [None]:
# Start with a fresh copy
df = orig_df.copy()

# Initialize with -1 using assign
df = df.assign(
    split_covariates_mask=-1,
    perturbation_covariates_mask=-1
)

# Get unique cell_type values to determine the base for IDs
cell_types = sorted(df['cell_type'].unique())
cell_type_bases = {ct: i * 30 for i, ct in enumerate(cell_types)}

# Process each cell type separately
for cell_type, base_id in cell_type_bases.items():
    # Get batches for this cell type
    cell_batches = (df[df['cell_type'] == cell_type]['batch']
                    .drop_duplicates()
                    .sort_values()
                    .values)
    
    # Assign IDs for this cell type's batches
    for i, batch in enumerate(cell_batches):
        mask = ((df['cell_type'] == cell_type) & 
                (df['batch'] == batch) & 
                df['control'])
        if mask.any():
            df.loc[mask, 'split_covariates_mask'] = base_id + i

# Ensure integer types
df['split_covariates_mask'] = df['split_covariates_mask'].astype('int32')

In [None]:
df['split_cov_comb'] = df[dm.split_covariates].astype(str).apply(lambda x: '_'.join(x), axis=1, meta=('split_cov_comb', 'str'))
df['pert_cov_comb'] = df[['split_cov_comb']+list(dm.perturbation_covariates.keys())].astype(str).apply(lambda x: '_'.join(x), axis=1, meta=('pert_cov_comb', 'str'))

In [None]:
# Start with a fresh copy
df = orig_df.copy()

# Initialize with -1 using assign
df = df.assign(
    split_covariates_mask=-1,
    perturbation_covariates_mask=-1
)

# Create mapping for split combinations in order of first appearance
split_id_map = {}
split_counter = 0

# Process rows in order
for idx, row in df.iterrows():
    if row['control']:
        # Create tuple of split values
        split_key = (row['cell_type'], row['batch'])
        
        # If we haven't seen this combination before, assign new ID
        if split_key not in split_id_map:
            split_id_map[split_key] = split_counter
            split_counter += 1
        
        # Assign the ID
        df.loc[idx, 'split_covariates_mask'] = split_id_map[split_key]

# Ensure integer types
df['split_covariates_mask'] = df['split_covariates_mask'].astype('int32')

In [None]:
# Show the first few split combinations and their IDs
print("Split combination to ID mapping:")
for (cell_type, batch), id_ in sorted(split_id_map.items(), key=lambda x: x[1])[:10]:
    print(f"{cell_type}, {batch} -> {id_}")

Split combination to ID mapping:
cell_line_a, batch_15 -> 0
cell_line_a, batch_23 -> 1
cell_line_a, batch_11 -> 2
cell_line_a, batch_14 -> 3
cell_line_a, batch_13 -> 4
cell_line_a, batch_29 -> 5
cell_line_a, batch_2 -> 6
cell_line_a, batch_17 -> 7
cell_line_a, batch_26 -> 8
cell_line_a, batch_5 -> 9


In [None]:
# Start with a fresh copy
df = orig_df.copy()

# Initialize masks with -1
df = df.assign(
    split_covariates_mask=-1,
    perturbation_covariates_mask=-1
)

# First handle split masks - use 30 as increment for each cell type
cell_types = sorted(df['cell_type'].unique())
base_ids = {ct: i * 30 for i, ct in enumerate(cell_types)}
split_counter = {}

# Assign split masks for control cells
for idx, row in df[df['control']].iterrows():
    cell_type = row['cell_type']
    if cell_type not in split_counter:
        split_counter[cell_type] = 0
        
    split_key = (cell_type, row['batch'])
    if split_key not in split_id_map:
        split_id_map[split_key] = base_ids[cell_type] + split_counter[cell_type]
        split_counter[cell_type] += 1
        
    df.loc[idx, 'split_covariates_mask'] = split_id_map[split_key]

# Now handle perturbation masks
# Group by cell_type and batch to match the original function's logic
perturb_groups = df[~df['control']].groupby(['cell_type', 'batch'])
tgt_counter = 0

for (cell_type, batch), group in perturb_groups:
    # Only process if we have control cells for this combination
    control_mask = df['control'] & (df['cell_type'] == cell_type) & (df['batch'] == batch)
    if control_mask.any():
        # Assign sequential IDs to all cells in this group
        df.loc[group.index, 'perturbation_covariates_mask'] = range(tgt_counter, tgt_counter + len(group))
        tgt_counter += len(group)

print("First few rows after assignment:")
print(df[['cell_type', 'batch', 'drug', 'dosage', 'control', 'split_covariates_mask', 'perturbation_covariates_mask']].head(10))

  perturb_groups = df[~df['control']].groupby(['cell_type', 'batch'])


First few rows after assignment:
     cell_type     batch     drug  dosage  control  split_covariates_mask  \
0  cell_line_a  batch_15  control     0.0     True                      0   
1  cell_line_a  batch_23  control     0.0     True                      1   
2  cell_line_a  batch_11  control     0.0     True                      2   
3  cell_line_a  batch_14  control     0.0     True                      3   
4  cell_line_a  batch_13  control     0.0     True                      4   
5  cell_line_a  batch_29  control     0.0     True                      5   
6  cell_line_a   batch_2  control     0.0     True                      6   
7  cell_line_a  batch_17  control     0.0     True                      7   
8  cell_line_a  batch_29  control     0.0     True                      5   
9  cell_line_a  batch_26  control     0.0     True                      8   

   perturbation_covariates_mask  
0                            -1  
1                            -1  
2                

In [None]:
split_covariates_mask.max()

899

In [None]:
split_covariates_mask

array([ -1,  -1,  -1, ..., 302, 327, 315], dtype=int16)

In [None]:
condition_data.split_covariates_mask.max()

Array(891, dtype=int32)

In [None]:
pert_covariates_mask.max()

13192

In [None]:
condition_data.perturbation_covariates_mask.max()

Array(13192, dtype=int32)

In [None]:
import numpy as np
df = df.assign(
    pert_cov_comb=lambda x: np.where(x.control, 'control', x.pert_cov_comb),
    split_cov_comb=lambda x: np.where(x.control, x.split_cov_comb, 'not_control')
)
df = df.categorize(columns=['split_cov_comb', 'pert_cov_comb'])

In [None]:
df.compute()

Unnamed: 0,cell_type,batch,drug,dosage,control,split_cov_comb,pert_cov_comb
0,cell_line_a,batch_18,control,0.000000,True,NotImplemented,NotImplemented
1,cell_line_a,batch_16,control,0.000000,True,NotImplemented,NotImplemented
10,cell_line_a,batch_17,control,0.000000,True,NotImplemented,NotImplemented
100,cell_line_a,batch_16,drug1,0.666667,False,NotImplemented,NotImplemented
1000,cell_line_b,batch_14,drug1,0.333333,False,NotImplemented,NotImplemented
...,...,...,...,...,...,...,...
9995,cell_line_k,batch_26,drug4,1.000000,False,NotImplemented,NotImplemented
9996,cell_line_k,batch_25,drug4,1.000000,False,NotImplemented,NotImplemented
9997,cell_line_k,batch_11,drug4,1.000000,False,NotImplemented,NotImplemented
9998,cell_line_k,batch_7,drug4,1.000000,False,NotImplemented,NotImplemented


In [None]:
df[df.control]['pert_cov_comb'] = 'control'
df[~df.control]['split_cov_comb'] = 'not_control'
df.categorize(columns=['split_cov_comb', 'pert_cov_comb'])

Unnamed: 0_level_0,cell_type,batch,drug,dosage,control,split_cov_comb,pert_cov_comb
npartitions=1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
0,category[known],category[known],category[known],float64,boolean,category[known],category[known]
9999,...,...,...,...,...,...,...


In [None]:
df[df.control].compute()

Unnamed: 0,cell_type,batch,drug,dosage,control,split_cov_comb,pert_cov_comb
0,cell_line_a,batch_18,control,0.0,True,cell_line_a_batch_18,cell_line_a_batch_18_control_0.0
1,cell_line_a,batch_16,control,0.0,True,cell_line_a_batch_16,cell_line_a_batch_16_control_0.0
10,cell_line_a,batch_17,control,0.0,True,cell_line_a_batch_17,cell_line_a_batch_17_control_0.0
10340,cell_line_l,batch_15,control,0.0,True,cell_line_l_batch_15,cell_line_l_batch_15_control_0.0
10341,cell_line_l,batch_23,control,0.0,True,cell_line_l_batch_23,cell_line_l_batch_23_control_0.0
...,...,...,...,...,...,...,...
975,cell_line_b,batch_10,control,0.0,True,cell_line_b_batch_10,cell_line_b_batch_10_control_0.0
976,cell_line_b,batch_26,control,0.0,True,cell_line_b_batch_26,cell_line_b_batch_26_control_0.0
977,cell_line_b,batch_2,control,0.0,True,cell_line_b_batch_2,cell_line_b_batch_2_control_0.0
978,cell_line_b,batch_30,control,0.0,True,cell_line_b_batch_30,cell_line_b_batch_30_control_0.0


In [None]:
list(df['pert_cov_comb'].cat.categories)

['cell_line_a_batch_10_control_0.0',
 'cell_line_a_batch_10_drug1_0.3333333333333333',
 'cell_line_a_batch_10_drug1_0.6666666666666666',
 'cell_line_a_batch_10_drug1_1.0',
 'cell_line_a_batch_10_drug2_0.3333333333333333',
 'cell_line_a_batch_10_drug2_1.0',
 'cell_line_a_batch_10_drug3_0.3333333333333333',
 'cell_line_a_batch_10_drug3_0.6666666666666666',
 'cell_line_a_batch_10_drug3_1.0',
 'cell_line_a_batch_10_drug4_0.6666666666666666',
 'cell_line_a_batch_10_drug4_1.0',
 'cell_line_a_batch_10_drug5_0.3333333333333333',
 'cell_line_a_batch_10_drug5_0.6666666666666666',
 'cell_line_a_batch_10_drug5_1.0',
 'cell_line_a_batch_10_drug6_0.3333333333333333',
 'cell_line_a_batch_10_drug6_0.6666666666666666',
 'cell_line_a_batch_10_drug6_1.0',
 'cell_line_a_batch_11_drug1_0.3333333333333333',
 'cell_line_a_batch_11_drug1_0.6666666666666666',
 'cell_line_a_batch_11_drug1_1.0',
 'cell_line_a_batch_11_drug2_0.3333333333333333',
 'cell_line_a_batch_11_drug2_0.6666666666666666',
 'cell_line_a_batc

In [None]:
# reassing codes so that control is -1
(df['pert_cov_comb'].cat.categories == 'control').sum()


0

In [None]:
df_split_cov_comb = dd.from_pandas(orig_df[dm.split_covariates+["control"]], npartitions=1)
# split_cov_comb.astype('category')

In [None]:
df_split_cov_comb['split_cov_comb'] = df_split_cov_comb[dm.split_covariates].astype(str).apply(lambda x: '_'.join(x), axis=1, meta=('split_cov_comb', 'str'))

In [None]:
df

Unnamed: 0,cell_type,batch,control,split_cov_comb
0,cell_line_a,batch_18,True,cell_line_a_batch_18
1,cell_line_a,batch_16,True,cell_line_a_batch_16
10,cell_line_a,batch_17,True,cell_line_a_batch_17
100,cell_line_a,batch_16,False,cell_line_a_batch_16
1000,cell_line_b,batch_14,False,cell_line_b_batch_14
...,...,...,...,...
9995,cell_line_k,batch_26,False,cell_line_k_batch_26
9996,cell_line_k,batch_25,False,cell_line_k_batch_25
9997,cell_line_k,batch_11,False,cell_line_k_batch_11
9998,cell_line_k,batch_7,False,cell_line_k_batch_7


In [None]:
df_split_cov_comb.compute()

Unnamed: 0,cell_type,batch,drug,dosage,control
0,cell_line_a,batch_18,control,0.0,True
1,cell_line_a,batch_16,control,0.0,True
10,cell_line_a,batch_17,control,0.0,True
10340,cell_line_l,batch_15,control,0.0,True
10341,cell_line_l,batch_23,control,0.0,True
...,...,...,...,...,...
975,cell_line_b,batch_10,control,0.0,True
976,cell_line_b,batch_26,control,0.0,True
977,cell_line_b,batch_2,control,0.0,True
978,cell_line_b,batch_30,control,0.0,True


In [None]:
df_split_cov_comb.compute()

Unnamed: 0,cell_type,batch,drug,dosage,control
0,cell_line_a,batch_18,control,0.0,True
1,cell_line_a,batch_16,control,0.0,True
10,cell_line_a,batch_17,control,0.0,True
10340,cell_line_l,batch_15,control,0.0,True
10341,cell_line_l,batch_23,control,0.0,True
...,...,...,...,...,...
975,cell_line_b,batch_10,control,0.0,True
976,cell_line_b,batch_26,control,0.0,True
977,cell_line_b,batch_2,control,0.0,True
978,cell_line_b,batch_30,control,0.0,True


In [None]:
df.loc[df.control, 'split_cov_comb'] = df.loc[df.control, dm.split_covariates].astype(str).apply(lambda x: '_'.join(x), axis=1)
df.loc[~df.control, 'pert_cov_comb'] = df.loc[~df.control, dm.split_covariates+list(dm.perturbation_covariates.keys())].astype(str).apply(lambda x: '_'.join(x), axis=1)

pert_cov_combs_len = len(df.loc[~df.control, 'pert_cov_comb'].drop_duplicates())
split_cov_combs_len = len(df['split_cov_comb'].drop_duplicates())

df.loc[~df.control, 'split_cov_comb'] = 'not_control'
df.loc[df.control, 'pert_cov_comb'] = 'control'
# cast to categorical
df['split_cov_comb'] = df['split_cov_comb'].astype('category')
df['pert_cov_comb'] = df['pert_cov_comb'].astype('category')

# get the order of the categories
split_cat_order = df['split_cov_comb'].cat.categories
# put control at the end
split_cat_order = np.concatenate([split_cat_order[split_cat_order != 'control'], ['control']])

pert_cat_order = df['pert_cov_comb'].cat.categories
# put control at the end
pert_cat_order = np.concatenate([pert_cat_order[pert_cat_order != 'control'], ['control']])



A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df.loc[df.control, 'split_cov_comb'] = df.loc[df.control, dm.split_covariates].astype(str).apply(lambda x: '_'.join(x), axis=1)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df.loc[~df.control, 'pert_cov_comb'] = df.loc[~df.control, dm.split_covariates+list(dm.perturbation_covariates.keys())].astype(str).apply(lambda x: '_'.join(x), axis=1)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pan

In [None]:
df['pert_cov_comb'].cat.codes[df['pert_cov_comb'] == 'control']

0        13193
1        13193
2        13193
3        13193
4        13193
         ...  
27295    13193
27296    13193
27297    13193
27298    13193
27299    13193
Length: 1200, dtype: int16

Unnamed: 0,cell_type,batch,drug,dosage,control,split_cov_comb,pert_cov_comb
0,cell_line_a,batch_18,control,0.0,True,cell_line_a_batch_18,control
1,cell_line_a,batch_16,control,0.0,True,cell_line_a_batch_16,control
2,cell_line_a,batch_27,control,0.0,True,cell_line_a_batch_27,control
3,cell_line_a,batch_1,control,0.0,True,cell_line_a_batch_1,control
4,cell_line_a,batch_17,control,0.0,True,cell_line_a_batch_17,control
...,...,...,...,...,...,...,...
28195,cell_line_~,batch_7,drug6,1.0,False,not_control,cell_line_~_batch_7_drug6_1.0
28196,cell_line_~,batch_30,drug6,1.0,False,not_control,cell_line_~_batch_30_drug6_1.0
28197,cell_line_~,batch_17,drug6,1.0,False,not_control,cell_line_~_batch_17_drug6_1.0
28198,cell_line_~,batch_18,drug6,1.0,False,not_control,cell_line_~_batch_18_drug6_1.0


In [None]:
df['split_cov_comb'].drop_duplicates()

0        cell_line_a_batch_18
1        cell_line_a_batch_16
2        cell_line_a_batch_27
3         cell_line_a_batch_1
4        cell_line_a_batch_17
                 ...         
27287    cell_line_~_batch_30
27291    cell_line_~_batch_29
27295    cell_line_~_batch_24
27297     cell_line_~_batch_4
27298     cell_line_~_batch_2
Name: split_cov_comb, Length: 669, dtype: object

In [None]:
np.unique(condition_data.split_covariates_mask)

array([ -1,   0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,
        12,  13,  14,  15,  16,  17,  18,  19,  20,  21,  30,  31,  32,
        33,  34,  35,  36,  37,  38,  39,  40,  41,  42,  43,  44,  45,
        46,  47,  48,  49,  50,  51,  52,  60,  61,  62,  63,  64,  65,
        66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,  78,
        79,  80,  81,  90,  91,  92,  93,  94,  95,  96,  97,  98,  99,
       100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112,
       113, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131,
       132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 150, 151,
       152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164,
       165, 166, 167, 168, 169, 170, 171, 172, 180, 181, 182, 183, 184,
       185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
       198, 199, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220,
       221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 24

In [None]:
condition_data.split_covariates_mask

Array(891, dtype=int32)

In [None]:
# First ensure we have -1 as integers in both mask columns
df = df.assign(
    split_covariates_mask=-1,
    perturbation_covariates_mask=-1
)

# Create mappings
split_groups = (df[df['control']]
                [['cell_type', 'batch']]
                .drop_duplicates()
                .reset_index(drop=True)
                .assign(split_id=lambda x: x.index))

pert_groups = (df[~df['control']]
               [['cell_type', 'batch', 'drug', 'dosage']]
               .drop_duplicates()
               .reset_index(drop=True)
               .assign(pert_id=lambda x: x.index))

# Update the masks using merge
df = (df.merge(split_groups, 
               on=['cell_type', 'batch'],
               how='left')
        .merge(pert_groups,
               on=['cell_type', 'batch', 'drug', 'dosage'],
               how='left'))

# Fill the masks based on control status
df['split_covariates_mask'] = df.apply(
    lambda row: row['split_id'] if row['control'] else -1,
    axis=1
).astype('int32')

df['perturbation_covariates_mask'] = df.apply(
    lambda row: row['pert_id'] if not row['control'] else -1,
    axis=1
).astype('int32')

# Drop the temporary columns
df = df.drop(['split_id', 'pert_id'], axis=1)

In [None]:
(df['perturbation_covariates_mask'].values == condition_data.perturbation_covariates_mask)

Array([ True,  True,  True, ..., False, False, False], dtype=bool)

In [None]:
(df['split_covariates_mask'].values == condition_data.split_covariates_mask)

Array([ True,  True,  True, ...,  True,  True,  True], dtype=bool)

In [None]:
split_covariates = dm.split_covariates
perturb_covar_keys = list(dm.perturbation_covariates.keys())
# First, create unique IDs for split combinations
split_groups = (df[split_covariates]
                 .drop_duplicates()
                 .reset_index(drop=True)
                 .reset_index()
                 .rename(columns={'index': 'split_id'}))

# Create unique IDs for perturbation combinations within each split
pert_groups = (df[~df.control]
                [split_covariates + perturb_covar_keys]
                .drop_duplicates()
                .reset_index(drop=True)
                .reset_index()
                .rename(columns={'index': 'pert_id'}))

# Merge split IDs back to main dataframe (only for control cells)
df_with_splits = (df[df.control]
                   .merge(split_groups, 
                         on=split_covariates, 
                         how='left')
                   .assign(split_covariates_mask=lambda x: x.split_id)
                   .drop(columns=['split_id']))

# Merge perturbation IDs back (only for non-control cells)
df_with_perts = (df[~df.control]
                   .merge(pert_groups,
                         on=split_covariates + perturb_covar_keys,
                         how='left')
                   .assign(perturbation_covariates_mask=lambda x: x.pert_id)
                   .drop(columns=['pert_id']))

# Combine control and non-control rows
df_final = (pd.concat([df_with_splits, df_with_perts])
             .sort_index())
df_final = df_final.fillna(-1)
# Convert mask columns to int32
df_final[['split_covariates_mask', 'perturbation_covariates_mask']] = (
    df_final[['split_covariates_mask', 'perturbation_covariates_mask']]
    .astype('int32')
)

TypeError: Cannot setitem on a Categorical with a new category (-1), set the categories first

In [None]:
len(np.unique(condition_data.split_covariates_mask))

676

In [None]:
len(split_cov_combs)

900

In [None]:
import dask.dataframe as dd
n=1
# Convert to Dask DataFrame
ddf = dd.from_pandas(df, npartitions=n)  # n depends on your data size

# Initialize columns with proper types
ddf = ddf.assign(
    split_covariates_mask=-1,
    perturbation_covariates_mask=-1
).astype({
    'split_covariates_mask': 'int32',
    'perturbation_covariates_mask': 'int32'
})

# Create the groupings (these will be smaller and can be computed)
split_groups = (ddf[split_covariates]
                  .drop_duplicates()
                  .reset_index(drop=True)
                  .reset_index()
                  .rename(columns={'index': 'split_id'})
                  .compute())

pert_groups = (ddf[~ddf.control]
                 [split_covariates + perturb_covar_keys]
                 .drop_duplicates()
                 .reset_index(drop=True)
                 .reset_index()
                 .rename(columns={'index': 'pert_id'})
                 .compute())

# Update masks using map_partitions
def update_masks(df_partition, split_groups, pert_groups):
    # Update split masks for control cells
    control_updates = (df_partition[df_partition.control]
                        .merge(split_groups, 
                              on=split_covariates, 
                              how='left'))
    df_partition.loc[df_partition.control, 'split_covariates_mask'] = control_updates['split_id']
    
    # Update perturbation masks for non-control cells
    pert_updates = (df_partition[~df_partition.control]
                     .merge(pert_groups,
                           on=split_covariates + perturb_covar_keys,
                           how='left'))
    df_partition.loc[~df_partition.control, 'perturbation_covariates_mask'] = pert_updates['pert_id']
    
    return df_partition

ddf = ddf.map_partitions(
    update_masks, 
    split_groups=split_groups, 
    pert_groups=pert_groups
)

  df_partition.loc[~df_partition.control, 'perturbation_covariates_mask'] = pert_updates['pert_id']


In [None]:
ddf.compute()

Unnamed: 0,cell_type,batch,drug,dosage,control,split_covariates_mask,perturbation_covariates_mask
0,cell_line_a,batch_18,control,0.000000,True,,-1.0
1,cell_line_a,batch_16,control,0.000000,True,,-1.0
10,cell_line_a,batch_17,control,0.000000,True,,-1.0
100,cell_line_a,batch_16,drug1,0.666667,False,-1.0,
1000,cell_line_b,batch_14,drug1,0.333333,False,-1.0,
...,...,...,...,...,...,...,...
9995,cell_line_k,batch_26,drug4,1.000000,False,-1.0,
9996,cell_line_k,batch_25,drug4,1.000000,False,-1.0,
9997,cell_line_k,batch_11,drug4,1.000000,False,-1.0,
9998,cell_line_k,batch_7,drug4,1.000000,False,-1.0,


In [None]:
df

Unnamed: 0,cell_type,batch,drug,dosage,control
0,cell_line_a,batch_18,control,0.0,True
1,cell_line_a,batch_16,control,0.0,True
2,cell_line_a,batch_27,control,0.0,True
3,cell_line_a,batch_1,control,0.0,True
4,cell_line_a,batch_17,control,0.0,True
...,...,...,...,...,...
28195,cell_line_~,batch_7,drug6,1.0,False
28196,cell_line_~,batch_30,drug6,1.0,False
28197,cell_line_~,batch_17,drug6,1.0,False
28198,cell_line_~,batch_18,drug6,1.0,False


KeyError: 'split_covariates_mask'

In [None]:
result

Unnamed: 0,cell_type,batch,drug,dosage,control,split_covariates_mask,perturbation_covariates_mask
0,cell_line_a,batch_18,control,0.0,True,,-1.0
1,cell_line_a,batch_16,control,0.0,True,,-1.0
2,cell_line_a,batch_27,control,0.0,True,,-1.0
3,cell_line_a,batch_1,control,0.0,True,,-1.0
4,cell_line_a,batch_17,control,0.0,True,,-1.0
...,...,...,...,...,...,...,...
28195,cell_line_~,batch_7,drug6,1.0,False,-1.0,
28196,cell_line_~,batch_30,drug6,1.0,False,-1.0,
28197,cell_line_~,batch_17,drug6,1.0,False,-1.0,
28198,cell_line_~,batch_18,drug6,1.0,False,-1.0,


In [None]:
df = dm.adata.obs[dm.split_covariates+list(dm.perturbation_covariates.keys())+["control"]]

In [None]:
def create_masks(df):
    """
    Create masks with sequential indices:
    - split_mask: sequential indices (0 to n-1) for unique [cell_type, batch] combinations in control, -1 otherwise
    - pert_mask: sequential indices (0 to n-1) for unique [cell_type, batch, drug, dosage] combinations in non-control, -1 otherwise
    """
    # Initialize masks with -1
    split_mask = np.full(len(df), -1)
    pert_mask = np.full(len(df), -1)
    
    control_mask = df['control'].values
    
    # Handle split mask (control cells)
    if control_mask.any():
        # Get unique combinations for control cells only
        control_df = df[control_mask]
        split_combinations = pd.Categorical(
            control_df['cell_type'].astype(str) + '_' + 
            control_df['batch'].astype(str)
        )
        
        # Create mapping from old codes to sequential indices
        unique_codes = np.unique(split_combinations.codes)
        code_to_seq = {code: idx for idx, code in enumerate(unique_codes)}
        
        # Apply sequential indexing
        split_mask[control_mask] = [code_to_seq[code] for code in split_combinations.codes]
    
    # Handle pert mask (non-control cells)
    if (~control_mask).any():
        # Get unique combinations for non-control cells only
        non_control_df = df[~control_mask]
        pert_combinations = pd.Categorical(
            non_control_df['cell_type'].astype(str) + '_' + 
            non_control_df['batch'].astype(str) + '_' +
            non_control_df['drug'].astype(str) + '_' + 
            non_control_df['dosage'].astype(str)
        )
        
        # Create mapping from old codes to sequential indices
        unique_codes = np.unique(pert_combinations.codes)
        code_to_seq = {code: idx for idx, code in enumerate(unique_codes)}
        
        # Apply sequential indexing
        pert_mask[~control_mask] = [code_to_seq[code] for code in pert_combinations.codes]
    
    return split_mask, pert_mask

In [None]:
# Assuming your dataframe is called 'df'
split_mask, pert_mask = create_masks(df)

# Example to verify the results:
print("Unique split mask values:", np.unique(split_mask))
print("Unique perturbation mask values:", np.unique(pert_mask))

Unique split mask values: [ -1   0   1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16
  17  18  19  20  21  22  23  24  25  26  27  28  29  30  31  32  33  34
  35  36  37  38  39  40  41  42  43  44  45  46  47  48  49  50  51  52
  53  54  55  56  57  58  59  60  61  62  63  64  65  66  67  68  69  70
  71  72  73  74  75  76  77  78  79  80  81  82  83  84  85  86  87  88
  89  90  91  92  93  94  95  96  97  98  99 100 101 102 103 104 105 106
 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124
 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142
 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160
 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178
 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196
 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214
 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232
 233 234 235 236 237 238 

In [None]:
split_mask

array([ 8,  6,  0, ..., -1, -1, -1])

Array([ 0,  1,  2, ..., -1, -1, -1], dtype=int32)