In [1]:
from scaleflow.data import MappedCellData

In [8]:
import numpy as np

def calculate_memory_cost(
    data: MappedCellData,
    src_idx: int,
    include_condition_data: bool = True
) -> dict[str, int | list | dict]:
    """Calculate memory cost in bytes for a given source index and its target distributions.
    
    Parameters
    ----------
    data
        The training data.
    src_idx
        The source distribution index.
    include_condition_data
        Whether to include condition data in memory calculations.
        
    Returns
    -------
    Dictionary with memory statistics in bytes for the source and its targets.
    """
    if src_idx not in data.control_to_perturbation:
        raise ValueError(f"Source index {src_idx} not found in control_to_perturbation mapping")
    
    # Get target indices for this source
    target_indices = data.control_to_perturbation[src_idx]
    
    # Calculate memory for source cells
    source_mask = data.split_covariates_mask == src_idx
    n_source_cells = data.src_cell_idx[src_idx].shape[0]
    source_memory = data.src_cell_data[src_idx].nbytes
    
    # Calculate memory for target cells
    target_memories = {}
    total_target_memory = 0
    
    for target_idx in target_indices:
        n_target_cells = data.tgt_cell_idx[target_idx].shape[0]
        target_memory = data.tgt_cell_data[target_idx].nbytes
        target_memories[f"target_{target_idx}"] = target_memory
        total_target_memory += target_memory
    
    # Calculate condition data memory if available and requested
    condition_memory = 0
    condition_details = {}
    if include_condition_data and data.condition_data is not None:
        for cond_name, cond_array in data.condition_data.items():
            # Condition data is indexed by target indices
            relevant_condition_size = len(target_indices) * cond_array.shape[1] * cond_array.dtype.itemsize
            condition_details[f"condition_{cond_name}"] = relevant_condition_size
            condition_memory += relevant_condition_size
    
    # Calculate total memory
    total_memory = source_memory + total_target_memory + condition_memory
    
    # Calculate average target memory
    avg_target_memory = total_target_memory // len(target_indices) if target_indices.size > 0 else 0
    
    result = {
        "source_idx": src_idx,
        "target_indices": target_indices.tolist(),
        "source_memory": source_memory,
        "source_cell_count": int(n_source_cells),
        "total_target_memory": total_target_memory,
        "avg_target_memory": avg_target_memory,
        "condition_memory": condition_memory,
        "total_memory": total_memory,
        "target_details": target_memories,
    }
    
    if condition_details:
        result["condition_details"] = condition_details
        
    return result

def format_memory_stats(memory_stats: dict, unit: str = "auto", summary: bool = False) -> str:
    """Format memory statistics into a human-readable string.
    
    Parameters
    ----------
    memory_stats
        Dictionary with memory statistics from calculate_memory_cost.
    unit
        Memory unit to use for display. Options: 'B', 'KB', 'MB', 'GB', 'auto'.
        If 'auto', the most appropriate unit will be chosen automatically.
    summary
        If True, includes a summary with average, min, and max target memory statistics
        and omits detailed per-target breakdown.
        
    Returns
    -------
    Human-readable string representation of memory statistics.
    """
    def format_bytes(bytes_value, unit="auto"):
        if unit == "auto":
            # Choose appropriate unit
            for unit in ["B", "KB", "MB", "GB"]:
                if bytes_value < 1024 or unit == "GB":
                    break
                bytes_value /= 1024
        elif unit == "KB":
            bytes_value /= 1024
        elif unit == "MB":
            bytes_value /= (1024 * 1024)
        elif unit == "GB":
            bytes_value /= (1024 * 1024 * 1024)
        
        return f"{bytes_value:.2f} {unit}"
    
    src_idx = memory_stats["source_idx"]
    target_indices = memory_stats["target_indices"]
    
    # Base information
    lines = [
        f"Memory statistics for source index {src_idx} with {len(target_indices)} targets:",
        f"- Source cells: {memory_stats['source_cell_count']} cells, {format_bytes(memory_stats['source_memory'], unit)}",
        f"- Total memory: {format_bytes(memory_stats['total_memory'], unit)}",
    ]
    
    # Calculate min and max target memory if summary is requested
    if summary and memory_stats["target_details"]:
        target_memories = list(memory_stats["target_details"].values())
        min_target = min(target_memories)
        max_target = max(target_memories)
        
        lines.extend([
            "\nTarget memory summary:",
            f"- Total: {format_bytes(memory_stats['total_target_memory'], unit)}",
            f"- Average: {format_bytes(memory_stats['avg_target_memory'], unit)}",
            f"- Min: {format_bytes(min_target, unit)}",
            f"- Max: {format_bytes(max_target, unit)}",
            f"- Range: {format_bytes(max_target - min_target, unit)}"
        ])
        
        # Add condition memory summary if available
        if memory_stats["condition_memory"] > 0:
            lines.append(f"\nCondition memory: {format_bytes(memory_stats['condition_memory'], unit)}")
    else:
        # Detailed output (original format)
        lines.extend([
            f"- Target memory: {format_bytes(memory_stats['total_target_memory'], unit)} total, {format_bytes(memory_stats['avg_target_memory'], unit)} average per target",
            f"- Condition memory: {format_bytes(memory_stats['condition_memory'], unit)}",
            "\nTarget details:"
        ])
        
        for target_key, target_memory in memory_stats["target_details"].items():
            target_id = target_key.split("_")[1]
            lines.append(f"  - Target {target_id}: {format_bytes(target_memory, unit)}")
        
        if "condition_details" in memory_stats:
            lines.append("\nCondition details:")
            for cond_key, cond_memory in memory_stats["condition_details"].items():
                cond_name = cond_key.split("_", 1)[1]
                lines.append(f"  - {cond_name}: {format_bytes(cond_memory, unit)}")
    
    return "\n".join(lines)

In [5]:
data = MappedCellData.read_zarr(
    "/lustre/groups/ml01/workspace/100mil/tahoe.zarr"
)

In [9]:
stats = calculate_memory_cost(data, 0)

In [None]:
print(format_memory_stats(stats, summary=True))

Memory statistics for source index 0 with 194 targets:
- Source cells: 60135 cells, 68.82 MB
- Total memory: 548.11 MB

Target memory summary:
- Total: 479.28 MB
- Average: 2.47 MB
- Min: 44.53 KB
- Max: 6.35 MB
- Range: 6.31 MB

Condition memory: 4.55 KB


In [10]:
data_stats = {}
for i in range(data.n_controls):
    data_stats[i] = calculate_memory_cost(data, i)


In [11]:
def print_average_memory_per_source(stats_dict):
    """Print the average total memory per source index.
    
    Parameters
    ----------
    stats_dict
        Optional pre-calculated memory statistics dictionary.
        If None, statistics will be calculated for all source indices.
    """
    
    
    # Extract total memory for each source index
    total_memories = [stats["total_memory"] for stats in stats_dict.values()]
    
    # Calculate statistics
    avg_memory = np.mean(total_memories)
    min_memory = np.min(total_memories)
    max_memory = np.max(total_memories)
    median_memory = np.median(total_memories)
    
    # Format the output
    def format_bytes(bytes_value):
        for unit in ["B", "KB", "MB", "GB"]:
            if bytes_value < 1024 or unit == "GB":
                break
            bytes_value /= 1024
        return f"{bytes_value:.2f} {unit}"
    
    print(f"Memory statistics across {len(stats_dict)} source indices:")
    print(f"- Average total memory per source: {format_bytes(avg_memory)}")
    print(f"- Minimum total memory: {format_bytes(min_memory)}")
    print(f"- Maximum total memory: {format_bytes(max_memory)}")
    print(f"- Median total memory: {format_bytes(median_memory)}")
    print(f"- Range: {format_bytes(max_memory - min_memory)}")
    
    # Identify source indices with min and max memory
    min_idx = min(stats_dict.keys(), key=lambda k: stats_dict[k]["total_memory"])
    max_idx = max(stats_dict.keys(), key=lambda k: stats_dict[k]["total_memory"])
    
    print(f"\nSource index with minimum memory: {min_idx} ({format_bytes(min_memory)})")
    print(f"Source index with maximum memory: {max_idx} ({format_bytes(max_memory)})")

In [13]:
print_average_memory_per_source(data_stats)

Memory statistics across 50 source indices:
- Average total memory per source: 2.14 GB
- Minimum total memory: 21.01 MB
- Maximum total memory: 6.75 GB
- Median total memory: 2.05 GB
- Range: 6.73 GB

Source index with minimum memory: 39 (21.01 MB)
Source index with maximum memory: 22 (6.75 GB)
