In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
from scaleflow.data._datamanager_new import DataManager
from scaleflow.data._anndata_location import AnnDataLocation
from pathlib import Path
import anndata as ad
import h5py

In [4]:
DATA_PATH = Path("/lustre/groups/ml01/workspace/100mil/100m_int_indices.h5ad")

In [None]:
with h5py.File(DATA_PATH, "r") as f:
    adata = ad.AnnData(
        obs=ad.io.read_elem(f["obs"]),
        obsm=ad.experimental.read_lazy(f["obsm"]),
        uns=ad.io.read_elem(f["uns"]),
    )

In [5]:
adl = AnnDataLocation()
dm = DataManager(
    dist_flag_key="control",
    src_dist_keys=["cell_line"],
    tgt_dist_keys=["drug", "dosage"],
    rep_keys={
        'cell_line': 'cell_line_embeddings',
        'drug': 'drug_embeddings',
    },
    data_location=adl.obsm['X_pca'][:,:50],
)

In [None]:
gd = dm.prepare_data(
    adata=adata,
)

In [8]:
import zarr
path = "/lustre/groups/ml01/workspace/100mil/tahoe.zarr"
chunk_size = 131072
shard_size = chunk_size * 8

In [1]:
s.write_zarr(path, chunk_size=chunk_size, shard_size=shard_size) 


NameError: name 's' is not defined

In [None]:
    # adata_all,
    # sample_rep="X_pca",
    # control_key="control",
    # perturbation_covariates={"drugs": ("drug",), "dosage": ("dosage",)},
    # perturbation_covariate_reps={"drugs": "drug_embeddings"},
    # sample_covariates=["cell_line"],
    # sample_covariate_reps={"cell_line": "cell_line_embeddings"},
    # split_covariates=["cell_line"],
    # max_combination_length=None,
    # null_value=0.0,

In [32]:
adl[src_grps[0]].obsm['X_pca']

<AnnDataLocation: AnnDataAccessor()[Index(['54637', '54764', '54794', '54839', '54920', '54940', '54941', '54945',
       '54956', '54968',
       ...
       '95623798', '95623811', '95623828', '95623832', '95623845', '95624002',
       '95624106', '95624182', '95624231', '95624262'],
      dtype='object', length=2567838)].obsm['X_pca']>

In [1]:
# %%
# %load_ext autoreload
# %autoreload 2


# %%
import anndata as ad
import h5py
import zarr
from scaleflow.data._utils import write_sharded
from anndata.experimental import read_lazy
from scaleflow.data import DataManager
import cupy as cp
import tqdm
import dask
import concurrent.futures
from functools import partial
import numpy as np
import dask.array as da
from dask.diagnostics import ProgressBar

print("loading data")
with h5py.File("/lustre/groups/ml01/workspace/100mil/100m_int_indices.h5ad", "r") as f:
    adata_all = ad.AnnData(
        obs=ad.io.read_elem(f["obs"]),
        var=read_lazy(f["var"]),
        uns=read_lazy(f["uns"]),
        obsm=read_lazy(f["obsm"]),
    )

dm = DataManager(
    adata_all,
    sample_rep="X_pca",
    control_key="control",
    perturbation_covariates={"drugs": ("drug",), "dosage": ("dosage",)},
    perturbation_covariate_reps={"drugs": "drug_embeddings"},
    sample_covariates=["cell_line"],
    sample_covariate_reps={"cell_line": "cell_line_embeddings"},
    split_covariates=["cell_line"],
    max_combination_length=None,
    null_value=0.0,
)
print("data loaded")

# %%
cond_data = dm._get_condition_data(adata=adata_all)
cell_data = dm._get_cell_data(adata_all)

# %%
n_source_dists = len(cond_data.split_idx_to_covariates)
n_target_dists = len(cond_data.perturbation_idx_to_covariates)

tgt_cell_data = {}
src_cell_data = {}
gpu_per_cov_mask = cp.asarray(cond_data.perturbation_covariates_mask)
gpu_spl_cov_mask = cp.asarray(cond_data.split_covariates_mask)

for src_idx in tqdm.tqdm(range(n_source_dists), desc="Computing source to cell data idcs"):
    mask = gpu_spl_cov_mask == src_idx
    src_cell_data[str(src_idx)] = {
        "cell_data_index": cp.where(mask)[0].get(),
    }

for tgt_idx in tqdm.tqdm(range(n_target_dists), desc="Computing target to cell data idcs"):
    mask = gpu_per_cov_mask == tgt_idx
    tgt_cell_data[str(tgt_idx)] = {
        "cell_data_index": cp.where(mask)[0].get(),
    }

# %%

print("Computing cell data")
cell_data = cell_data.compute()
print("cell data computed")

for src_idx in tqdm.tqdm(range(n_source_dists), desc="Computing source to cell data"):
    indices = src_cell_data[str(src_idx)]["cell_data_index"]
    src_cell_data[str(src_idx)]["cell_data"] = cell_data[indices]

for tgt_idx in tqdm.tqdm(range(n_target_dists), desc="Computing target to cell data"):
    indices = tgt_cell_data[str(tgt_idx)]["cell_data_index"]
    tgt_cell_data[str(tgt_idx)]["cell_data"] = cell_data[indices]


# %%

split_covariates_mask = np.asarray(cond_data.split_covariates_mask)
perturbation_covariates_mask = np.asarray(cond_data.perturbation_covariates_mask)
condition_data = {str(k): np.asarray(v) for k, v in (cond_data.condition_data or {}).items()}
control_to_perturbation = {str(k): np.asarray(v) for k, v in (cond_data.control_to_perturbation or {}).items()}
split_idx_to_covariates = {str(k): np.asarray(v) for k, v in (cond_data.split_idx_to_covariates or {}).items()}
perturbation_idx_to_covariates = {
    str(k): np.asarray(v) for k, v in (cond_data.perturbation_idx_to_covariates or {}).items()
}
perturbation_idx_to_id = {str(k): v for k, v in (cond_data.perturbation_idx_to_id or {}).items()}

train_data_dict = {
    "split_covariates_mask": split_covariates_mask,
    "perturbation_covariates_mask": perturbation_covariates_mask,
    "split_idx_to_covariates": split_idx_to_covariates,
    "perturbation_idx_to_covariates": perturbation_idx_to_covariates,
    "perturbation_idx_to_id": perturbation_idx_to_id,
    "condition_data": condition_data,
    "control_to_perturbation": control_to_perturbation,
    "max_combination_length": int(cond_data.max_combination_length),
    # "src_cell_data": src_cell_data,
    # "tgt_cell_data": tgt_cell_data,
}

print("prepared train_data_dict")
# %%
path = "/lustre/groups/ml01/workspace/100mil/tahoe.zarr"
zgroup = zarr.open_group(path, mode="w")
chunk_size = 131072
shard_size = chunk_size * 8

ad.settings.zarr_write_format = 3  # Needed to support sharding in Zarr


def get_size(shape: tuple[int, ...], chunk_size: int, shard_size: int) -> tuple[int, int]:
    shard_size_used = shard_size
    chunk_size_used = chunk_size
    if chunk_size > shape[0]:
        chunk_size_used = shard_size_used = shape[0]
    elif chunk_size < shape[0] or shard_size > shape[0]:
        chunk_size_used = shard_size_used = shape[0]
    return chunk_size_used, shard_size_used


def write_single_array(group, key, arr, idxs, chunk_size, shard_size):
    """Write a single array - designed for threading"""
    chunk_size_used, shard_size_used = get_size(arr.shape, chunk_size, shard_size)

    group.create_array(
        name=key,
        data=arr,
        chunks=(chunk_size_used, arr.shape[1]),
        shards=(shard_size_used, arr.shape[1]),
        compressors=None,
    )

    group.create_array(
        name=f"{key}_index",
        data=idxs,
        chunks=(len(idxs),),
        shards=(len(idxs),),
        compressors=None,
    )
    return key


def write_cell_data_threaded(group, cell_data, chunk_size, shard_size, max_workers=8):
    """Write cell data using threading for I/O parallelism"""

    write_func = partial(write_single_array, group, chunk_size=chunk_size, shard_size=shard_size)

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Submit all write tasks
        future_to_key = {
            executor.submit(
                write_single_array,
                group,
                k,
                cell_data[k]["cell_data"],
                cell_data[k]["cell_data_index"],
                chunk_size,
                shard_size,
            ): k
            for k in cell_data.keys()
        }

        # Process results with progress bar
        for future in tqdm.tqdm(
            concurrent.futures.as_completed(future_to_key), total=len(future_to_key), desc=f"Writing {group.name}"
        ):
            key = future_to_key[future]
            try:
                future.result()  # This will raise any exceptions
            except Exception as exc:
                print(f"Array {key} generated an exception: {exc}")
                raise


# %%


src_group = zgroup.create_group("src_cell_data", overwrite=True)
tgt_group = zgroup.create_group("tgt_cell_data", overwrite=True)


# Use the fast threaded approach you already implemented
write_cell_data_threaded(src_group, src_cell_data, chunk_size, shard_size, max_workers=24)
print("done writing src_cell_data")
write_cell_data_threaded(tgt_group, tgt_cell_data, chunk_size, shard_size, max_workers=24)
print("done writing tgt_cell_data")


# %%

print("Writing mapping data")
mapping_data = zgroup.create_group("mapping_data", overwrite=True)


write_sharded(
    group=mapping_data,
    name="mapping_data",
    data=train_data_dict,
    chunk_size=chunk_size,
    shard_size=shard_size,
    compressors=None,
)
print("done")


IndentationError: expected an indented block after class definition on line 216 (_data.py, line 219)

In [5]:

# %%
import anndata as ad
import h5py
from anndata.experimental import read_lazy
import dask.dataframe as ddf

In [2]:
data_path = "/lustre/groups/ml01/projects/big_perturbation/datasets/nadig_jurkat.h5ad"



In [10]:

print("loading data")
tahoe_path = "/lustre/groups/ml01/workspace/100mil/100m_int_indices.h5ad"
with h5py.File(tahoe_path, "r") as f:
    obs = ddf.read_hdf(
        tahoe_path, "obs"
    ) 

loading data


TypeError: An error occurred while calling the read_hdf method registered to the pandas backend.
Original Message: cannot create a storer if the object is not existing nor a value are passed

In [None]:
import anndata, session_info2; session_info2.session_info(dependencies=True)

  and (v := getattr(pkg, "__version__", None))


Package,Version
Component,Info
xarray,2025.9.0
anndata,0.13.0.dev28+g1ba19458f
h5py,3.14.0
zarr,3.1.2
pandas,2.3.2
Python,"3.12.11 | packaged by conda-forge | (main, Jun 4 2025, 14:45:31) [GCC 13.3.0]"
OS,Linux-5.14.0-570.25.1.el9_6.x86_64-x86_64-with-glibc2.34
Updated,2025-10-16 13:10

Dependency,Version
packaging,25.0
debugpy,1.8.12
msgpack,1.1.1
cupy-cuda12x,13.6.0
urllib3,2.5.0
setuptools,80.9.0
click,8.2.1
Pygments,2.19.2
legacy-api-wrap,1.4.1
cloudpickle,3.1.1


In [16]:

print("loading data")
with h5py.File("/lustre/groups/ml01/workspace/100mil/100m_int_indices.h5ad", "r") as f:
    obs=ad.io.read_elem(f["obs"])

loading data


In [46]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [53]:
from scaleflow.data._datamanager_new import DataManager

ImportError: cannot import name 'ConditionData' from 'scaleflow.data._data' (/ictstr01/home/icb/selman.ozleyen/projects/CellFlow2/src/scaleflow/data/_data.py)

In [20]:
src_keys = ['cell_line']
tgt_keys = ['drug', 'dosage']
control_key = 'control'

In [21]:
obs.sort_values([control_key,*src_keys,*tgt_keys])

Unnamed: 0,cell_indices,drug,cell_line,phase,plate,train_379_pos1,train_379_pos2,train_379_pos3,train_379_pos4,train_379_pos5,...,train_55000_pos4,train_55000_pos5,train_55000_pos6,train_55000_pos7,train_55000_pos8,train_55000_pos9,train_55000_pos10,test,dosage,control
54637,72_005_017-lib_841,4EGI-1,CVCL_0023,G1,plate_1,False,False,False,False,False,...,True,True,True,True,True,True,False,False,0.05,False
54764,72_036_191-lib_841,4EGI-1,CVCL_0023,G2M,plate_1,False,False,False,False,False,...,True,True,True,True,True,True,False,False,0.05,False
54794,72_045_112-lib_841,4EGI-1,CVCL_0023,G1,plate_1,False,False,False,False,False,...,True,True,True,True,True,True,False,False,0.05,False
54839,72_058_077-lib_841,4EGI-1,CVCL_0023,G1,plate_1,False,False,False,False,False,...,True,True,True,True,True,True,False,False,0.05,False
54920,72_075_175-lib_841,4EGI-1,CVCL_0023,G1,plate_1,False,False,False,False,False,...,True,True,True,True,True,True,False,False,0.05,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95623703,95_135_160-lib_2613,DMSO_TF,CVCL_C466,G1,plate_14,False,False,False,False,False,...,False,False,False,False,False,False,False,False,0.00,True
95623742,95_158_048-lib_2613,DMSO_TF,CVCL_C466,G2M,plate_14,False,False,False,False,False,...,False,False,False,False,False,False,False,False,0.00,True
95624003,96_076_165-lib_2613,DMSO_TF,CVCL_C466,G1,plate_14,False,False,False,False,False,...,False,False,False,False,False,False,False,False,0.00,True
95624009,96_080_017-lib_2613,DMSO_TF,CVCL_C466,G1,plate_14,False,False,False,False,False,...,False,False,False,False,False,False,False,False,0.00,True


In [22]:
obs["src_dist_idx"] = obs.groupby(src_keys).ngroup()
obs["tgt_dist_idx"] = obs.groupby([*src_keys, *tgt_keys]).ngroup()

  obs["src_dist_idx"] = obs.groupby(src_keys).ngroup()
  obs["tgt_dist_idx"] = obs.groupby([*src_keys, *tgt_keys]).ngroup()


In [32]:
import anndata
class AnnDataLocation:
    """
    An object that stores a sequence of access operations (attributes and keys)
    and can be called on an AnnData object to execute them.
    """
    def __init__(self, path=None):
        # The path is a list of tuples, e.g., [('getattr', 'obsm'), ('getitem', 's')]
        self._path = path if path is not None else []

    def __getattr__(self, name):
        """
        Handles attribute access, like .obs or .X.
        It returns a new AnnDataLocation with the attribute access added to the path.
        """
        if name.startswith('__') and name.endswith('__'):
            # Avoid interfering with special methods
            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
            
        new_path = self._path + [('getattr', name)]
        return AnnDataLocation(new_path)

    def __getitem__(self, key):
        """
        Handles item access, like ['my_key'].
        It returns a new AnnDataLocation with the item access added to the path.
        """
        new_path = self._path + [('getitem', key)]
        return AnnDataLocation(new_path)

    def __call__(self, adata: anndata.AnnData):
        """
        Executes the stored path of operations on the provided AnnData object.
        """
        target = adata
        try:
            for op_type, op_arg in self._path:
                if op_type == 'getattr':
                    target = getattr(target, op_arg)
                elif op_type == 'getitem':
                    target = target[op_arg]
            return target
        except (AttributeError, KeyError) as e:
            raise type(e)(f"Failed to resolve location {self!r} on the AnnData object. Reason: {e}") from e

    def __repr__(self):
        """Provides a user-friendly string representation of the stored path."""
        representation = "AnnDataAccessor()"
        for op_type, op_arg in self._path:
            if op_type == 'getattr':
                representation += f'.{op_arg}'
            elif op_type == 'getitem':
                # Use repr() to correctly handle string keys with quotes
                representation += f'[{repr(op_arg)}]'
        return f"<AnnDataLocation: {representation}>"


    

In [33]:
ac = AnnDataLocation()

In [38]:
ac.obsm['X']

<AnnDataLocation: AnnDataAccessor().obsm['X']>

In [39]:
src_dist_map = obs.groupby('src_dist_idx').groups
tgt_dist_map = obs.groupby('tgt_dist_idx').groups

In [40]:
with h5py.File("/lustre/groups/ml01/workspace/100mil/100m_int_indices.h5ad", "r") as f:
    adata_all = ad.AnnData(
        obs=obs,
        obsm=read_lazy(f["obsm"]),
    )

  return dispatch(args[0].__class__)(*args, **kw)
