In [1]:
%load_ext autoreload
%autoreload 2


In [2]:
import anndata as ad
import h5py
import zarr
from cellflow.data._utils import write_sharded
from anndata.experimental import read_lazy
from cellflow.data import DataManager
import cupy as cp
import tqdm
import dask
import numpy as np

print("loading data")
with h5py.File("/lustre/groups/ml01/workspace/100mil/100m_int_indices.h5ad", "r") as f:
    adata_all = ad.AnnData(
        obs=ad.io.read_elem(f["obs"]),
        var=read_lazy(f["var"]),
        uns = read_lazy(f["uns"]),
        obsm = read_lazy(f["obsm"]),
    )

dm = DataManager(adata_all,  
    sample_rep="X_pca",
    control_key="control",
    perturbation_covariates={"drugs": ("drug",), "dosage": ("dosage",)},
    perturbation_covariate_reps={"drugs": "drug_embeddings"},
    sample_covariates=["cell_line"],
    sample_covariate_reps={"cell_line": "cell_line_embeddings"},
    split_covariates=["cell_line"],
    max_combination_length=None,
    null_value=0.0
)


loading data


  return dispatch(args[0].__class__)(*args, **kw)


In [3]:
cond_data = dm._get_condition_data(adata=adata_all)
cell_data = dm._get_cell_data(adata_all)

[########################################] | 100% Completed | 910.75 ms
[########################################] | 100% Completed | 23.67 s
[########################################] | 100% Completed | 252.54 s


In [4]:
n_source_dists = len(cond_data.split_idx_to_covariates)
n_target_dists = len(cond_data.perturbation_idx_to_covariates)

tgt_cell_data = {}
src_cell_data = {}
gpu_per_cov_mask = cp.asarray(cond_data.perturbation_covariates_mask)
gpu_spl_cov_mask = cp.asarray(cond_data.split_covariates_mask)

for src_idx in tqdm.tqdm(range(n_source_dists), desc="Computing source to cell data idcs"):
    mask = gpu_spl_cov_mask == src_idx
    src_cell_data[str(src_idx)] = {
        "cell_data_index": cp.where(mask)[0].get(),
    }

for tgt_idx in tqdm.tqdm(range(n_target_dists), desc="Computing target to cell data idcs"):
    mask = gpu_per_cov_mask == tgt_idx
    tgt_cell_data[str(tgt_idx)] = {
        "cell_data_index": cp.where(mask)[0].get(),
    }

Computing source to cell data idcs: 100%|██████████| 50/50 [00:00<00:00, 62.77it/s]
Computing target to cell data idcs: 100%|██████████| 56827/56827 [01:05<00:00, 863.54it/s]


In [5]:

import dask.array as da
from dask.diagnostics import ProgressBar

src_delayed_objs = []
for src_idx in tqdm.tqdm(range(n_source_dists), desc="Computing source to cell data"):
    indices = src_cell_data[str(src_idx)]["cell_data_index"]
    delayed_obj = dask.delayed(lambda x: cell_data[x])(indices)
    src_delayed_objs.append((str(src_idx), delayed_obj))

tgt_delayed_objs = []
for tgt_idx in tqdm.tqdm(range(n_target_dists), desc="Computing target to cell data"):
    indices = tgt_cell_data[str(tgt_idx)]["cell_data_index"]
    delayed_obj = dask.delayed(lambda x: cell_data[x])(indices)
    tgt_delayed_objs.append((str(tgt_idx), delayed_obj))

src_results = []
tgt_results = []
with ProgressBar():
    src_results, tgt_results = dask.compute(src_delayed_objs, tgt_delayed_objs)

for k, v in src_results:
    src_cell_data[k]["cell_data"] = v

for k, v in tgt_results:
    tgt_cell_data[k]["cell_data"] = v


Computing source to cell data: 100%|██████████| 50/50 [00:00<00:00, 22329.13it/s]
Computing target to cell data:   6%|▌         | 3184/56827 [00:00<00:01, 31833.81it/s]

Computing target to cell data: 100%|██████████| 56827/56827 [00:02<00:00, 23426.54it/s]


[#####################                   ] | 52% Completed | 36m 17ss

IOStream.flush timed out


[########################################] | 100% Completed | 73m 49s


In [6]:

split_covariates_mask = np.asarray(cond_data.split_covariates_mask)
perturbation_covariates_mask = np.asarray(cond_data.perturbation_covariates_mask)
condition_data = {str(k): np.asarray(v) for k, v in (cond_data.condition_data or {}).items()}
control_to_perturbation = {str(k): np.asarray(v) for k, v in (cond_data.control_to_perturbation or {}).items()}
split_idx_to_covariates = {str(k): np.asarray(v) for k, v in (cond_data.split_idx_to_covariates or {}).items()}
perturbation_idx_to_covariates = {
    str(k): np.asarray(v) for k, v in (cond_data.perturbation_idx_to_covariates or {}).items()
}
perturbation_idx_to_id = {str(k): v for k, v in (cond_data.perturbation_idx_to_id or {}).items()}

train_data_dict = {
    "split_covariates_mask": split_covariates_mask,
    "perturbation_covariates_mask": perturbation_covariates_mask,
    "split_idx_to_covariates": split_idx_to_covariates,
    "perturbation_idx_to_covariates": perturbation_idx_to_covariates,
    "perturbation_idx_to_id": perturbation_idx_to_id,
    "condition_data": condition_data,
    "control_to_perturbation": control_to_perturbation,
    "max_combination_length": int(cond_data.max_combination_length),
    # "src_cell_data": src_cell_data,
    # "tgt_cell_data": tgt_cell_data,
}


In [7]:
path = "/lustre/groups/ml01/workspace/100mil/tahoe.zarr"
zgroup = zarr.open_group(path, mode="w")
chunk_size = 65536
shard_size = chunk_size * 16

ad.settings.zarr_write_format = 3  # Needed to support sharding in Zarr

def get_size(shape: tuple[int, ...], chunk_size: int, shard_size: int) -> tuple[int, int]:
    shard_size_used = shard_size
    chunk_size_used = chunk_size
    if chunk_size > shape[0] or shard_size > shape[0]:
        chunk_size_used = shard_size_used = shape[0]
    return chunk_size_used, shard_size_used




In [8]:

def write_arr(z_arr, arr, k):
    z_arr[:] = arr
    return k

def allocate_cell_data(group, cell_data, chunk_size, shard_size):
    delayed_objs = []

    for k in tqdm.tqdm(cell_data.keys(), desc="Allocating cell data"):
        chunk_size_used, shard_size_used = get_size(cell_data[k]["cell_data"].shape, chunk_size, shard_size)
        arr = cell_data[k]["cell_data"]

        z_arr = group.create_array(
            name=k,
            shape=arr.shape,
            chunks=(chunk_size_used, arr.shape[1]),
            shards=(shard_size_used, arr.shape[1]),
            compressors=None,
            dtype=arr.dtype,
        )

        delayed_objs.append(dask.delayed(write_arr)(z_arr, arr, k))
    
    return delayed_objs


src_group = zgroup.create_group("src_cell_data", overwrite=True)
tgt_group = zgroup.create_group("tgt_cell_data", overwrite=True)


src_delayed_objs = allocate_cell_data(src_group, src_cell_data, chunk_size, shard_size)
tgt_delayed_objs = allocate_cell_data(tgt_group, tgt_cell_data, chunk_size, shard_size)



# for k in tqdm.tqdm(src_cell_data.keys(), desc="Writing src cell data"):
#     chunk_size_used, shard_size_used = get_size(src_cell_data[k]["cell_data"].shape, chunk_size, shard_size)
#     arr = src_cell_data[k]["cell_data"]

#     z_arr = src_group.create_array(
#         name=k,
#         shape=arr.shape,
#         chunks=(chunk_size_used, arr.shape[1]),
#         shards=(shard_size_used, arr.shape[1]),
#         compressors=None,
#         dtype=arr.dtype,
#     )
    
#     delayed_objs.append(dask.delayed(write_arr)(z_arr, arr, k))

# for k in tqdm.tqdm(tgt_cell_data.keys(), desc="Writing tgt cell data"):
#     chunk_size_used, shard_size_used = get_size(tgt_cell_data[k]["cell_data"].shape, chunk_size, shard_size)
#     arr = tgt_cell_data[k]["cell_data"]
#     z_arr = tgt_group.create_array(
#         name=k,
#         shape=arr.shape,
#         chunks=(chunk_size_used, arr.shape[1]),
#         shards=(shard_size_used, arr.shape[1]),
#         compressors=None,
#         dtype=arr.dtype,
#     )
    
    
#     delayed_objs.append(dask.delayed(write_arr)(z_arr, arr, k))

Allocating cell data:  14%|█▍        | 7/50 [00:45<04:47,  6.69s/it]

Allocating cell data: 100%|██████████| 50/50 [06:47<00:00,  8.15s/it]
Allocating cell data: 100%|██████████| 56827/56827 [41:54<00:00, 22.60it/s]    


In [1]:

with ProgressBar():
    res = dask.compute(tgt_delayed_objs)

NameError: name 'ProgressBar' is not defined

In [9]:

mapping_data = zarr.create_group(zgroup, "mapping_data")

write_sharded(
    mapping_data,
    train_data_dict,
    chunk_size=chunk_size,
    shard_size=shard_size,
    compressors=None,
)
print("done")

TypeError: create_group() takes 1 positional argument but 2 were given