In [1]:
%load_ext autoreload
%autoreload 2

import cellflow as cf


In [2]:
from cellflow.model import CellFlow
import anndata as ad
import h5py

from anndata.experimental import read_lazy

print("loading data")
with h5py.File("/lustre/groups/ml01/workspace/100mil/100m_int_indices.h5ad", "r") as f:
    adata_all = ad.AnnData(
        obs=ad.io.read_elem(f["obs"]),
        var=read_lazy(f["var"]),
        uns = read_lazy(f["uns"]),
        obsm = read_lazy(f["obsm"]),
    )

loading data


  return dispatch(args[0].__class__)(*args, **kw)


In [3]:
from cellflow.data import DataManager
dm = DataManager(adata_all,  
    sample_rep="X_pca",
    control_key="control",
    perturbation_covariates={"drugs": ("drug",), "dosage": ("dosage",)},
    perturbation_covariate_reps={"drugs": "drug_embeddings"},
    sample_covariates=["cell_line"],
    sample_covariate_reps={"cell_line": "cell_line_embeddings"},
    split_covariates=["cell_line"],
    max_combination_length=None,
    null_value=0.0
)

cond_data = dm._get_condition_data(adata=adata_all)
cell_data = dm._get_cell_data(adata_all)

[########################################] | 100% Completed | 908.17 ms
[########################################] | 100% Completed | 21.42 s
[########################################] | 100% Completed | 375.38 s


In [27]:
import cupy as cp
import tqdm

n_source_dists = len(cond_data.split_idx_to_covariates)
n_target_dists = len(cond_data.perturbation_idx_to_covariates)

tgt_cell_data = {}
src_cell_data = {}
gpu_per_cov_mask = cp.asarray(cond_data.perturbation_covariates_mask)
gpu_spl_cov_mask = cp.asarray(cond_data.split_covariates_mask)

for src_idx in tqdm.tqdm(range(n_source_dists), desc="Computing source to cell data idcs"):
    mask = gpu_spl_cov_mask == src_idx
    src_cell_data[str(src_idx)] = {
        "cell_data_index": cp.where(mask)[0].get(),
    }

for tgt_idx in tqdm.tqdm(range(n_target_dists), desc="Computing target to cell data idcs"):
    mask = gpu_per_cov_mask == tgt_idx
    tgt_cell_data[str(tgt_idx)] = {
        "cell_data_index": cp.where(mask)[0].get(),
    }


Computing source to cell data idcs: 100%|██████████| 50/50 [00:00<00:00, 75.74it/s]
Computing target to cell data idcs:  14%|█▍        | 8030/56827 [00:27<02:48, 289.15it/s]


KeyboardInterrupt: 

In [None]:
cell_data

Unnamed: 0,Array,Chunk
Bytes,106.87 GiB,1.14 MiB
Shape,"(95624334, 300)","(1000, 300)"
Dask graph,95625 chunks in 1 graph layer,95625 chunks in 1 graph layer
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 106.87 GiB 1.14 MiB Shape (95624334, 300) (1000, 300) Dask graph 95625 chunks in 1 graph layer Data type float32 numpy.ndarray",300  95624334,

Unnamed: 0,Array,Chunk
Bytes,106.87 GiB,1.14 MiB
Shape,"(95624334, 300)","(1000, 300)"
Dask graph,95625 chunks in 1 graph layer,95625 chunks in 1 graph layer
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [None]:
import dask


for src_idx in tqdm.tqdm(range(n_source_dists), desc="Computing source to cell data"):
    indices = src_cell_data[str(src_idx)]["cell_data_index"]
    delayed_obj = dask.delayed(lambda x: cell_data[x])(indices)
    src_cell_data[str(src_idx)]["cell_data"] = dask.array.from_delayed(delayed_obj, shape=(len(indices), cell_data.shape[1]), dtype=cell_data.dtype)

for tgt_idx in tqdm.tqdm(range(n_target_dists), desc="Computing target to cell data"):
    indices = tgt_cell_data[str(tgt_idx)]["cell_data_index"]
    delayed_obj = dask.delayed(lambda x: cell_data[x])(indices)
    tgt_cell_data[str(tgt_idx)]["cell_data"] = dask.array.from_delayed(delayed_obj, shape=(len(indices), cell_data.shape[1]), dtype=cell_data.dtype)


Computing source to cell data: 100%|██████████| 50/50 [00:00<00:00, 246.04it/s]
Computing target to cell data: 100%|██████████| 56827/56827 [00:08<00:00, 6554.24it/s]


In [None]:
import numpy as np

split_covariates_mask = np.asarray(cond_data.split_covariates_mask)
perturbation_covariates_mask = np.asarray(cond_data.perturbation_covariates_mask)
condition_data = {str(k): np.asarray(v) for k, v in (cond_data.condition_data or {}).items()}
control_to_perturbation = {str(k): np.asarray(v) for k, v in (cond_data.control_to_perturbation or {}).items()}
split_idx_to_covariates = {str(k): np.asarray(v) for k, v in (cond_data.split_idx_to_covariates or {}).items()}
perturbation_idx_to_covariates = {
    str(k): np.asarray(v) for k, v in (cond_data.perturbation_idx_to_covariates or {}).items()
}
perturbation_idx_to_id = {str(k): v for k, v in (cond_data.perturbation_idx_to_id or {}).items()}

In [None]:
train_data_dict = {
    "split_covariates_mask": split_covariates_mask,
    "perturbation_covariates_mask": perturbation_covariates_mask,
    "split_idx_to_covariates": split_idx_to_covariates,
    "perturbation_idx_to_covariates": perturbation_idx_to_covariates,
    "perturbation_idx_to_id": perturbation_idx_to_id,
    "condition_data": condition_data,
    "control_to_perturbation": control_to_perturbation,
    "max_combination_length": int(cond_data.max_combination_length),
    "src_cell_data": src_cell_data,
    "tgt_cell_data": tgt_cell_data,
}

In [None]:
import zarr
from cellflow.data._utils import write_sharded

path = "test.zarr"
zgroup = zarr.open_group(path, mode="w")
chunk_size = 65536
shard_size = chunk_size * 16
write_sharded(
    zgroup,
    train_data_dict,
    chunk_size=chunk_size,
    shard_size=shard_size,
    compressors=None,
)

KeyboardInterrupt: 

In [None]:
from cellflow.data import TrainSamplerWithPool, ZarrTrainingData

In [5]:
import numpy as np

def calculate_memory_cost(
    data: ZarrTrainingData,
    src_idx: int,
    include_condition_data: bool = True
) -> dict[str, int | list | dict]:
    """Calculate memory cost in bytes for a given source index and its target distributions.
    
    Parameters
    ----------
    data
        The training data.
    src_idx
        The source distribution index.
    include_condition_data
        Whether to include condition data in memory calculations.
        
    Returns
    -------
    Dictionary with memory statistics in bytes for the source and its targets.
    """
    if src_idx not in data.control_to_perturbation:
        raise ValueError(f"Source index {src_idx} not found in control_to_perturbation mapping")
    
    # Get target indices for this source
    target_indices = data.control_to_perturbation[src_idx]
    
    # Calculate memory for source cells
    source_mask = data.split_covariates_mask == src_idx
    n_source_cells = np.sum(source_mask)
    source_memory = n_source_cells * data.cell_data.shape[1] * data.cell_data.dtype.itemsize
    
    # Calculate memory for target cells
    target_memories = {}
    total_target_memory = 0
    
    for target_idx in target_indices:
        target_mask = data.perturbation_covariates_mask == target_idx
        n_target_cells = np.sum(target_mask)
        target_memory = n_target_cells * data.cell_data.shape[1] * data.cell_data.dtype.itemsize
        target_memories[f"target_{target_idx}"] = target_memory
        total_target_memory += target_memory
    
    # Calculate condition data memory if available and requested
    condition_memory = 0
    condition_details = {}
    if include_condition_data and data.condition_data is not None:
        for cond_name, cond_array in data.condition_data.items():
            # Condition data is indexed by target indices
            relevant_condition_size = len(target_indices) * cond_array.shape[1] * cond_array.dtype.itemsize
            condition_details[f"condition_{cond_name}"] = relevant_condition_size
            condition_memory += relevant_condition_size
    
    # Calculate total memory
    total_memory = source_memory + total_target_memory + condition_memory
    
    # Calculate average target memory
    avg_target_memory = total_target_memory // len(target_indices) if target_indices.size > 0 else 0
    
    result = {
        "source_idx": src_idx,
        "target_indices": target_indices.tolist(),
        "source_memory": source_memory,
        "source_cell_count": int(n_source_cells),
        "total_target_memory": total_target_memory,
        "avg_target_memory": avg_target_memory,
        "condition_memory": condition_memory,
        "total_memory": total_memory,
        "target_details": target_memories,
    }
    
    if condition_details:
        result["condition_details"] = condition_details
        
    return result

def format_memory_stats(memory_stats: dict, unit: str = "auto", summary: bool = False) -> str:
    """Format memory statistics into a human-readable string.
    
    Parameters
    ----------
    memory_stats
        Dictionary with memory statistics from calculate_memory_cost.
    unit
        Memory unit to use for display. Options: 'B', 'KB', 'MB', 'GB', 'auto'.
        If 'auto', the most appropriate unit will be chosen automatically.
    summary
        If True, includes a summary with average, min, and max target memory statistics
        and omits detailed per-target breakdown.
        
    Returns
    -------
    Human-readable string representation of memory statistics.
    """
    def format_bytes(bytes_value, unit="auto"):
        if unit == "auto":
            # Choose appropriate unit
            for unit in ["B", "KB", "MB", "GB"]:
                if bytes_value < 1024 or unit == "GB":
                    break
                bytes_value /= 1024
        elif unit == "KB":
            bytes_value /= 1024
        elif unit == "MB":
            bytes_value /= (1024 * 1024)
        elif unit == "GB":
            bytes_value /= (1024 * 1024 * 1024)
        
        return f"{bytes_value:.2f} {unit}"
    
    src_idx = memory_stats["source_idx"]
    target_indices = memory_stats["target_indices"]
    
    # Base information
    lines = [
        f"Memory statistics for source index {src_idx} with {len(target_indices)} targets:",
        f"- Source cells: {memory_stats['source_cell_count']} cells, {format_bytes(memory_stats['source_memory'], unit)}",
        f"- Total memory: {format_bytes(memory_stats['total_memory'], unit)}",
    ]
    
    # Calculate min and max target memory if summary is requested
    if summary and memory_stats["target_details"]:
        target_memories = list(memory_stats["target_details"].values())
        min_target = min(target_memories)
        max_target = max(target_memories)
        
        lines.extend([
            "\nTarget memory summary:",
            f"- Total: {format_bytes(memory_stats['total_target_memory'], unit)}",
            f"- Average: {format_bytes(memory_stats['avg_target_memory'], unit)}",
            f"- Min: {format_bytes(min_target, unit)}",
            f"- Max: {format_bytes(max_target, unit)}",
            f"- Range: {format_bytes(max_target - min_target, unit)}"
        ])
        
        # Add condition memory summary if available
        if memory_stats["condition_memory"] > 0:
            lines.append(f"\nCondition memory: {format_bytes(memory_stats['condition_memory'], unit)}")
    else:
        # Detailed output (original format)
        lines.extend([
            f"- Target memory: {format_bytes(memory_stats['total_target_memory'], unit)} total, {format_bytes(memory_stats['avg_target_memory'], unit)} average per target",
            f"- Condition memory: {format_bytes(memory_stats['condition_memory'], unit)}",
            "\nTarget details:"
        ])
        
        for target_key, target_memory in memory_stats["target_details"].items():
            target_id = target_key.split("_")[1]
            lines.append(f"  - Target {target_id}: {format_bytes(target_memory, unit)}")
        
        if "condition_details" in memory_stats:
            lines.append("\nCondition details:")
            for cond_key, cond_memory in memory_stats["condition_details"].items():
                cond_name = cond_key.split("_", 1)[1]
                lines.append(f"  - {cond_name}: {format_bytes(cond_memory, unit)}")
    
    return "\n".join(lines)

In [3]:
ztd = ZarrTrainingData.read_zarr(data_paths[0])



In [7]:
stats = calculate_memory_cost(ztd, 0)

In [8]:
print(format_memory_stats(stats, summary=True))

Memory statistics for source index 0 with 194 targets:
- Source cells: 60135 cells, 68.82 MB
- Total memory: 548.11 MB

Target memory summary:
- Total: 479.28 MB
- Average: 2.47 MB
- Min: 44.53 KB
- Max: 6.35 MB
- Range: 6.31 MB

Condition memory: 4.55 KB


In [13]:
ztd_stats = {}
for i in range(ztd.n_controls):
    ztd_stats[i] = calculate_memory_cost(ztd, i)


In [16]:
def print_average_memory_per_source(stats_dict):
    """Print the average total memory per source index.
    
    Parameters
    ----------
    stats_dict
        Optional pre-calculated memory statistics dictionary.
        If None, statistics will be calculated for all source indices.
    """
    
    
    # Extract total memory for each source index
    total_memories = [stats["total_memory"] for stats in stats_dict.values()]
    
    # Calculate statistics
    avg_memory = np.mean(total_memories)
    min_memory = np.min(total_memories)
    max_memory = np.max(total_memories)
    median_memory = np.median(total_memories)
    
    # Format the output
    def format_bytes(bytes_value):
        for unit in ["B", "KB", "MB", "GB"]:
            if bytes_value < 1024 or unit == "GB":
                break
            bytes_value /= 1024
        return f"{bytes_value:.2f} {unit}"
    
    print(f"Memory statistics across {len(stats_dict)} source indices:")
    print(f"- Average total memory per source: {format_bytes(avg_memory)}")
    print(f"- Minimum total memory: {format_bytes(min_memory)}")
    print(f"- Maximum total memory: {format_bytes(max_memory)}")
    print(f"- Median total memory: {format_bytes(median_memory)}")
    print(f"- Range: {format_bytes(max_memory - min_memory)}")
    
    # Identify source indices with min and max memory
    min_idx = min(stats_dict.keys(), key=lambda k: stats_dict[k]["total_memory"])
    max_idx = max(stats_dict.keys(), key=lambda k: stats_dict[k]["total_memory"])
    
    print(f"\nSource index with minimum memory: {min_idx} ({format_bytes(min_memory)})")
    print(f"Source index with maximum memory: {max_idx} ({format_bytes(max_memory)})")

In [59]:
print_average_memory_per_source(ztd_stats)

Memory statistics across 50 source indices:
- Average total memory per source: 423.18 MB
- Minimum total memory: 4.33 MB
- Maximum total memory: 1.29 GB
- Median total memory: 404.51 MB
- Range: 1.28 GB

Source index with minimum memory: 39 (4.33 MB)
Source index with maximum memory: 22 (1.29 GB)


In [11]:
from cellflow.data import TrainSamplerWithPool
import numpy as np
rng = np.random.default_rng(0)

In [13]:
tswp = TrainSamplerWithPool(ztd, batch_size=1024, pool_size=20, replacement_prob=0.01)
tswp.init_pool_n_cache(rng)

Computing target to cell data idcs: 100%|██████████| 9980/9980 [00:11<00:00, 891.95it/s] 
Computing source to cell data idcs: 100%|██████████| 50/50 [00:00<00:00, 1232.06it/s]


In [None]:
import time
iter_times = []
rng = np.random.default_rng(0)
start_time = time.time()
for iter in range(40):
    batch = tswp.sample(rng)
    end_time = time.time()
    iter_times.append(end_time - start_time)
    start_time = end_time

print("average time per iteration: ", np.mean(iter_times))
print("iterations per second: ", 1 / np.mean(iter_times))


replaced 47 with 34
replaced 32 with 30


In [64]:
tswp.get_pool_stats()

{'pool_size': 20,
 'avg_usage': 1.95,
 'unique_sources': 20,
 'pool_elements': array([31, 18, 47, 34, 12, 35, 29, 23, 32, 14,  6, 41, 25,  3,  1, 49, 24,
        10, 46, 33]),
 'usage_counts': array([2, 2, 3, 2, 1, 0, 2, 2, 2, 0, 3, 1, 2, 0, 3, 3, 2, 6, 1, 2])}