In [1]:
# Automatically reload edited modules (e.g. the in-development scaleflow package)
# before executing each cell, so library edits are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [3]:
from scaleflow.data._datamanager_new import DataManager
from scaleflow.data._anndata_location import AnnDataLocation
from scaleflow.data._data import GroupedDistribution
from pathlib import Path
import anndata as ad
import h5py
from anndata.experimental import read_lazy

In [4]:
# Dataset to convert. The input .h5ad and the output .zarr store are derived from this
# single name so the notebook can be pointed at another dataset by editing one line.
DATASET_NAME = "nadig_jurkat"
INPUT_DIR = Path("/lustre/groups/ml01/projects/big_perturbation/datasets")
OUTPUT_DIR = Path("/lustre/groups/ml01/workspace/100mil")
DATA_PATH = INPUT_DIR / f"{DATASET_NAME}.h5ad"
OUTPUT_PATH = OUTPUT_DIR / f"{DATASET_NAME}.zarr"

In [5]:
# Build an AnnData without X: obs and uns are read eagerly, obsm through read_lazy.
# NOTE(review): the lazy obsm arrays are used after this `with` block closes the file;
# this relies on read_lazy reopening the file by path when chunks are accessed — confirm.
with h5py.File(DATA_PATH, "r") as h5_file:
    adata = ad.AnnData(
        obs=ad.io.read_elem(h5_file["obs"]),
        obsm=read_lazy(h5_file["obsm"]),
        uns=ad.io.read_elem(h5_file["uns"]),
    )

# Prepare data to be compatible with DataManager: every value of obs['pert_target']
# needs an entry in uns['gene_embeddings'], and a boolean obs['control'] flag is required.
# Targets without an embedding reuse the embedding of PLACEHOLDER_GENE.
# NOTE(review): placeholder targets become indistinguishable from PLACEHOLDER_GENE in
# embedding space — confirm this is intended (a neutral e.g. zero vector may be safer).
PLACEHOLDER_GENE = 'AAAS'
gene_embeddings = adata.uns['gene_embeddings']
placeholder_embedding = gene_embeddings[PLACEHOLDER_GENE]
missing_genes = set(adata.obs['pert_target'].unique()) - set(gene_embeddings.keys())
# Iterate in sorted order: set iteration order over str values changes between kernel
# sessions (hash randomization), which would make the dict insertion order unstable.
# key=str keeps sorting safe if a non-string value (e.g. NaN) is among the targets.
for gene in sorted(missing_genes, key=str):
    gene_embeddings[gene] = placeholder_embedding
# Cells with no perturbations are the control population.
adata.obs['control'] = (adata.obs['nperts'] == 0)


In [6]:
# Number of leading columns of obsm['X_pca'] passed to DataManager as data_location.
N_PCS = 50

# Source distributions are grouped by cell line and target distributions by perturbation
# target; rep_keys maps each obs column to the uns entry holding its embeddings.
adl = AnnDataLocation()
dm = DataManager(
    dist_flag_key="control",
    src_dist_keys=["cell_line"],
    tgt_dist_keys=["pert_target"],
    rep_keys={"cell_line": "cell_line_embeddings", "pert_target": "gene_embeddings"},
    data_location=adl.obsm["X_pca"][:, :N_PCS],
)
gd = dm.prepare_data(adata=adata)

In [7]:
# Zarr layout used by write_zarr: chunks of 2**17 (= 131072) elements, 8 chunks per shard.
chunk_size = 2 ** 17
shard_size = 8 * chunk_size

In [8]:
# Persist the prepared grouped distributions as a zarr store at OUTPUT_PATH.
zarr_write_options = {
    "path": OUTPUT_PATH,
    "chunk_size": chunk_size,
    "shard_size": shard_size,
    "max_workers": 14,
}
gd.write_zarr(**zarr_write_options)

Writing /data/src_data: 100%|██████████| 1/1 [00:00<00:00, 76.75it/s]
Writing /data/tgt_data: 100%|██████████| 2394/2394 [00:06<00:00, 374.44it/s]
Writing /data/conditions: 100%|██████████| 2393/2393 [00:06<00:00, 388.11it/s]


In [9]:
# Read the freshly written store back from OUTPUT_PATH; this rebinds `gd` to the
# reloaded GroupedDistribution, replacing the in-memory object produced by prepare_data.
gd = GroupedDistribution.read_zarr(OUTPUT_PATH)