In [31]:
import anndata as ad
import pandas as pd
import numpy as np
from scipy import sparse

In [6]:
# List the NeurIPS 2023 data directory (long format, human-readable sizes)
!ls -lh ../../data/neurips-2023-data

total 4331736
-rw-r--r--@ 1 arturszalata  staff   316M May 18 01:14 de_per_donor.h5ad
-rw-r--r--@ 1 arturszalata  staff   253M May 17 14:34 de_per_donor_old.h5ad
-rw-r--r--@ 1 arturszalata  staff    16M May 17 14:51 de_per_donor_test.h5ad
-rw-r--r--@ 1 arturszalata  staff   104M Oct 18 08:27 de_test.h5ad
-rw-r--r--@ 1 arturszalata  staff   175M Oct 18 08:27 de_train.h5ad
-rw-r--r--@ 1 arturszalata  staff   3.8K Oct 18 08:27 id_map.csv
-rw-r--r--@ 1 arturszalata  staff   3.5M Oct 18 08:27 prediction.h5ad
-rw-r--r--  1 arturszalata  staff    44M May 17 14:20 pseudobulk_cleaned.h5ad
-rw-r--r--@ 1 arturszalata  staff   100M Oct 18 08:27 pseudobulk_filtered_with_uns.h5ad
-rw-r--r--@ 1 arturszalata  staff   180M Oct 18 08:27 sc_test.h5ad
-rw-r--r--@ 1 arturszalata  staff   883M Oct 18 08:27 sc_train.h5ad
-rw-r--r--@ 1 arturszalata  staff    22K Oct 18 08:29 score.h5ad
-rw-r--r--@ 1 arturszalata  staff    40M May 16 20:08 small_pseudobulk.h5ad
-rw-r--r--@ 1 arturszalata  staff  

In [7]:
# Load the original pseudobulk AnnData ("filtered with uns" file) from the
# NeurIPS 2023 data directory.
pseudobulk = ad.read_h5ad("../../data/neurips-2023-data/pseudobulk_filtered_with_uns.h5ad")

In [9]:
# Load the AnnData stored in prediction.h5ad from the NeurIPS 2023 data directory.
# NOTE(review): `preds` is not referenced by any later cell visible here —
# confirm it is still needed.
preds = ad.read_h5ad("../../data/neurips-2023-data/prediction.h5ad")

In [2]:
# List the perturbench data directory
!ls ../../data/perturbench_data

[1m[36mlogs[m[m                        op3_processed_train.h5ad
op3_processed.h5ad          sciplex3_processed.h5ad
op3_processed_test.h5ad     srivatsan20_downloaded.h5ad


In [3]:
# List the prediction files written by the 2024-10-30_15-47-50 predict run
!ls ../../data/perturbench_data/logs/predict/runs/2024-10-30_15-47-50/predictions/

prediction_chunk_0.h5ad prediction_chunk_2.h5ad
prediction_chunk_1.h5ad prediction_chunk_3.h5ad


In [17]:
# Directory holding the prediction chunks of the 2024-10-30_15-47-50 predict run.
# Keep the trailing slash: chunk file names are appended to this string directly.
path_to_output = "../../data/perturbench_data/logs/predict/runs/2024-10-30_15-47-50/predictions/"

In [18]:
# Load each prediction chunk and stack them row-wise (observations along
# axis 0) into a single AnnData.
chunk_files = [f"{path_to_output}prediction_chunk_{i}.h5ad" for i in range(4)]
chunks = [ad.read_h5ad(chunk_file) for chunk_file in chunk_files]
combined_adata = ad.concat(chunks, axis=0)
# Free the per-chunk objects; only the concatenated AnnData is used from here on
del chunks

# The concatenated obs names are not unique (anndata warns about duplicate
# "obs" names at this point), which makes any label-based row selection
# ambiguous. Append numeric suffixes to the repeated names so every
# observation has a distinct label.
combined_adata.obs_names_make_unique()

  utils.warn_names_duplicates("obs")


In [27]:
# Expose the "condition" column under the name "sm_name"
combined_adata.obs = combined_adata.obs.rename(columns={"condition": "sm_name"})

In [33]:
# Build a combined "<sm_name>_<cell_type>" key for every observation
combined_adata.obs = combined_adata.obs.assign(
    sm_cell_type=lambda obs: obs["sm_name"].astype(str) + "_" + obs["cell_type"].astype(str)
)

In [35]:
# Store the combined key as a categorical column
combined_adata.obs = combined_adata.obs.astype({"sm_cell_type": "category"})

In [36]:
def sum_by(adata: ad.AnnData, col: str) -> ad.AnnData:
    """
    Sum the rows of `adata` within each category of `adata.obs[col]`.

    Adapted from this forum post:
    https://discourse.scverse.org/t/group-sum-rows-based-on-jobs-feature/371/4

    Parameters
    ----------
    adata
        AnnData whose `.X` (when present) and every layer are summed per group.
    col
        Name of a categorical column of `adata.obs` that defines the groups.

    Returns
    -------
    A new AnnData with one row per category of `col`, in category order
    (a category without observations yields an all-zero row), the same `.var`
    as `adata`, and an `.obs` containing `col` plus every other obs column that
    takes a single value within each group. The obs index is "0", "1", ...

    Raises
    ------
    AssertionError
        If `adata.obs[col]` is not categorical.
    ValueError
        If `adata.obs[col]` contains missing values.
    """

    # assert pd.api.types.is_categorical_dtype(adata.obs[col])
    assert isinstance(adata.obs[col].dtype, pd.CategoricalDtype)

    cat = adata.obs[col].values
    # A missing value has category code -1, which is not a valid row index for
    # the indicator matrix below; fail with an explicit message instead.
    if (cat.codes < 0).any():
        raise ValueError(
            f"adata.obs[{col!r}] contains missing values; "
            "every observation must belong to a category"
        )

    # indicator[g, i] is True when observation i belongs to category g, so
    # `indicator @ M` sums the rows of M within each category.
    indicator = sparse.coo_matrix(
        (
            np.broadcast_to(True, adata.n_obs),
            (cat.codes, np.arange(adata.n_obs))
        ),
        shape=(len(cat.categories), adata.n_obs),
    )
    sum_adata = ad.AnnData(
        var=adata.var,
        obs=pd.DataFrame(index=cat.categories),
    )
    if adata.X is not None:
        sum_adata.X = indicator @ adata.X
    for layer in adata.layers:
        sum_adata.layers[layer] = indicator @ adata.layers[layer]

    # copy over `.obs` values that have a one-to-one-mapping with `.obs[col]`
    # Keep the original column order: going through a set would make the order
    # of the copied columns depend on string hashing and change between runs.
    obs_cols = [other_col for other_col in adata.obs.columns if other_col != col]

    # A column is constant within every group exactly when the number of
    # distinct (col, other_col) pairs equals the number of observed groups.
    nunique_in_col = adata.obs[col].nunique()
    one_to_one_mapped_obs_cols = [
        other_col
        for other_col in obs_cols
        if len(adata.obs[[col, other_col]].drop_duplicates()) == nunique_in_col
    ]

    joining_df = adata.obs[[col] + one_to_one_mapped_obs_cols].drop_duplicates().set_index(col)
    joined_obs = sum_adata.obs.join(joining_df)
    # The join must neither add, drop nor reorder rows relative to the summed data
    assert (sum_adata.obs.index == joined_obs.index).all()
    sum_adata.obs = joined_obs
    sum_adata.obs.index.name = col
    # Turn the group labels into a regular column and use "0", "1", ... as obs names
    sum_adata.obs = sum_adata.obs.reset_index()
    sum_adata.obs.index = sum_adata.obs.index.astype('str')

    return sum_adata

# Collapse the combined predictions into one summed row per sm_cell_type
print(">> Create pseudobulk dataset", flush=True)
bulk_adata = sum_by(adata=combined_adata, col="sm_cell_type")

>> Create pseudobulk dataset


In [49]:
# Preview the top-left 10x10 corner of the pseudobulk rows whose sm_cell_type
# also occurs in the predictions. `.toarray()` is the documented way to densify
# a sparse matrix; the `.A` shorthand is deprecated in recent SciPy releases.
pseudobulk[pseudobulk.obs["sm_cell_type"].isin(bulk_adata.obs["sm_cell_type"].unique())].X[:10, :10].toarray()

array([[  1.,  27.,  13.,  27.,  27.,  13.,  18.,  32.,  64.,  18.],
       [119.,  73.,  24.,  76.,  97.,   2.,  23., 194., 187.,  13.],
       [  1.,  34.,   9.,  35.,  40.,  15.,  18.,  37.,  65.,  10.],
       [141., 120.,  31., 100., 135.,   3.,  33., 218., 305.,  30.],
       [  1.,  14.,   6.,  22.,  11.,   7.,   7.,  24.,  30.,   6.],
       [ 45.,  29.,   7.,  39.,  44.,   0.,   3.,  65.,  67.,  10.],
       [  2.,  16.,   9.,  15.,  18.,   3.,   8.,  28.,  39.,  10.],
       [106.,  56.,  27.,  69.,  72.,   0.,  20., 148., 179.,  17.],
       [  3.,  17.,  12.,  21.,  25.,   6.,  11.,  27.,  31.,   9.],
       [ 81.,  58.,  20.,  51.,  72.,   1.,  18., 119., 146.,  13.]])

In [51]:
# Replace the pseudobulk profiles with the summed predictions for every
# sm_cell_type that occurs in both `pseudobulk` and `bulk_adata`.

# Identify the matching cell types
matching_cell_types = set(pseudobulk.obs["sm_cell_type"].unique()).intersection(bulk_adata.obs["sm_cell_type"].unique())
print("Number of matching cell types:", len(matching_cell_types))

# Rows are copied by position, so the prediction columns must follow the same
# gene order as pseudobulk. Reorder by name when they do not (this raises a
# KeyError if a pseudobulk gene is absent from the predictions, instead of
# silently writing values into the wrong genes).
if pseudobulk.var_names.equals(bulk_adata.var_names):
    bulk_aligned = bulk_adata
else:
    bulk_aligned = bulk_adata[:, pseudobulk.var_names]

# Create a mask for matching sm_cell_types
mask_matching = pseudobulk.obs["sm_cell_type"].isin(matching_cell_types)

# Create a mask for duplicates in pseudobulk based on sm_cell_type
mask_duplicates = pseudobulk.obs.duplicated(subset=["sm_cell_type"], keep="first")

# Only the first pseudobulk row of each matching sm_cell_type is kept; rows of
# non-matching sm_cell_types are left untouched.
mask_to_drop = mask_matching & mask_duplicates
pseudobulk = pseudobulk[~mask_to_drop].copy()

# Renumber the observations after dropping rows. Use string labels (as sum_by
# does for its result) and set them through `obs_names` rather than mutating
# `.obs` in place with an integer index.
pseudobulk.obs_names = pd.RangeIndex(pseudobulk.n_obs).astype(str)

# First row position of every sm_cell_type in each object (one pass each,
# instead of one boolean-mask scan per matching cell type).
pseudobulk_row_of = {}
for pos, cell_type in enumerate(pseudobulk.obs["sm_cell_type"]):
    pseudobulk_row_of.setdefault(cell_type, pos)
bulk_row_of = {}
for pos, cell_type in enumerate(bulk_aligned.obs["sm_cell_type"]):
    bulk_row_of.setdefault(cell_type, pos)

matched_types = list(matching_cell_types)
pseudobulk_rows = [pseudobulk_row_of[cell_type] for cell_type in matched_types]
bulk_rows = [bulk_row_of[cell_type] for cell_type in matched_types]

# Replace values in pseudobulk.X for the matching sm_cell_type entries
if pseudobulk_rows:
    predicted_rows = bulk_aligned.X[bulk_rows]
    if sparse.issparse(pseudobulk.X):
        # Assigning rows into a CSR/CSC matrix changes its sparsity structure
        # on every write (SparseEfficiencyWarning). Do a single assignment in
        # LIL format and convert back to the original format afterwards.
        original_format = pseudobulk.X.format
        X_lil = pseudobulk.X.tolil()
        X_lil[pseudobulk_rows] = predicted_rows
        pseudobulk.X = X_lil.asformat(original_format)
    else:
        if sparse.issparse(predicted_rows):
            predicted_rows = predicted_rows.toarray()
        pseudobulk.X[pseudobulk_rows] = predicted_rows

Number of matching cell types: 151


  self._set_arrayXarray(i, j, x)


In [57]:
# Save the pseudobulk AnnData, with the prediction rows substituted in,
# alongside the other NeurIPS 2023 data files.
pseudobulk.write_h5ad("../../data/neurips-2023-data/pseudobulk_with_preds.h5ad")

In [60]:
# Show the distinct values of the "split" column
pseudobulk.obs["split"].unique()

['private_test', 'train', 'public_test', 'control']
Categories (4, object): ['control', 'private_test', 'public_test', 'train']