# On-Disk Concatenation of AnnData Files

**Author:** Selman Özleyen


## Initializing

Let's begin by importing the necessary libraries and modules. This notebook also uses the [memray](https://pypi.org/project/memray/) module. Ensure you've installed it using `pip install memray` before proceeding.

For all dependencies, run `pip install anndata zarr "dask[array,distributed]" pytest memray`. Quoting the Dask extras stops shells such as zsh from interpreting the square brackets.

In [1]:
%load_ext autoreload
%autoreload 2
import gc
import shutil
import logging
import tempfile
import itertools
from pathlib import Path
from typing import Literal, Callable

import zarr
import numpy as np
import pandas as pd
import memray
from scipy import sparse
from dask.distributed import Client, LocalCluster

import anndata
from anndata.tests.helpers import gen_typed_df
from anndata.experimental import write_elem
from anndata.experimental import concat_on_disk
from distributed.diagnostics.memray import memray_scheduler, memray_workers

## Data Creation and Analysis

In this section, we'll demonstrate the core functionality of the `concat_on_disk` method. We'll create datasets and analyze how this method performs in terms of memory usage. This will help us understand its efficiency and benefits, especially when working with large datasets.

We will define parameters that will influence the structure of our datasets:

- **Shapes**: The shape of each array ("fat", "tall", or "square").
- **Sizes**: The base length of each axis. For non-square arrays, one axis is made 20% to 40% longer.
- **Densities**: The fraction of non-zero entries. A density of 1 produces a dense NumPy array.

These parameters will be utilized in subsequent sections to generate and analyze datasets.
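As a rough guide to how density drives file size, the arithmetic below estimates the uncompressed in-memory footprint of one square `10_000 x 10_000` array. It is a sketch that assumes float64 values and int32 index arrays. The helper names are hypothetical. The files written later are compressed, so their reported sizes are expected to be smaller than these estimates.

```python
# Rough, uncompressed memory estimates for a single M x N matrix.
# Assumptions: float64 values (8 bytes each) and int32 indices (4 bytes each).
VALUE_BYTES = 8
INDEX_BYTES = 4
MIB = 2**20


def dense_nbytes(m: int, n: int) -> int:
    """Bytes needed by a dense float64 array."""
    return m * n * VALUE_BYTES


def csr_nbytes(m: int, n: int, density: float) -> int:
    """Bytes needed by a CSR matrix: data + column indices + row pointer."""
    nnz = round(m * n * density)
    return nnz * VALUE_BYTES + nnz * INDEX_BYTES + (m + 1) * INDEX_BYTES


m = n = 10_000
print(f"dense: {dense_nbytes(m, n) / MIB:.1f} MiB")            # dense: 762.9 MiB
print(f"csr, density 0.1: {csr_nbytes(m, n, 0.1) / MIB:.1f} MiB")  # csr, density 0.1: 114.5 MiB
```

A CSC matrix needs the same amount, except that its pointer array has `N + 1` entries instead of `M + 1`.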

### Ignoring Logs or Not

By default, logs are suppressed for readability. Most of them are reports from Dask distributed. If you would like to see what is happening inside the distributed system, set `ignore_logs` below to `False`. These logs usually also include a link to the Dask dashboard, which you can use to monitor the workers.

In [2]:
# Directory where the data will be stored
TMPDIR = tempfile.TemporaryDirectory()
OUTDIR = Path(TMPDIR.name)

# Parameters that will influence the structure and size of our datasets:

# Shapes of the arrays: "fat", "tall", or "square"
shapes = ["fat", "tall", "square"]

# Base length of each axis; non-square arrays get one axis 20% to 40% longer
sizes = [10_000]

# Densities: Specifies the data density. A higher value means more non-zero elements
densities = [0.1, 1]

# Number of times each array type will be created
num_runs = 3

# Set to False to see the logs and warnings
ignore_logs = True

dask_log = logging.CRITICAL
if not ignore_logs:
    dask_log = logging.DEBUG

### create_adata

This function creates an `AnnData` object, a data structure used in bioinformatics to store high-dimensional data such as gene expression matrices. Given a data matrix `X`, it derives the shape from `X`, generates observation and variable names, and attaches observation (`obs`) and variable (`var`) metadata generated with `gen_typed_df`.

- `X`: The data matrix (dense or sparse).

Returns: An `AnnData` object built from `X` and the generated metadata.


In [3]:
def create_adata(X):
    # Shape of the data matrix
    M, N = X.shape

    # Generating observation and variable names
    obs_names = pd.Index(f"cell{i}" for i in range(M))
    var_names = pd.Index(f"gene{i}" for i in range(N))

    # Creating observation and variable dataframes
    obs = gen_typed_df(M, obs_names)
    var = gen_typed_df(N, var_names)

    # Renaming columns to ensure uniqueness
    obs.rename(columns=dict(cat="obs_cat"), inplace=True)
    var.rename(columns=dict(cat="var_cat"), inplace=True)

    # Constructing the AnnData object
    adata = anndata.AnnData(X, obs=obs, var=var)

    return adata

### array_creators

This function returns a `dict` that maps an array-type name to a function converting a sparse input matrix into that type. Which types are returned depends on the `density` parameter.

- `density`: The density of the dataset. If the density is 1, only a dense NumPy creator (`"np"`) is returned; otherwise, CSC (`"csc"`) and CSR (`"csr"`) creators are returned.

Returns: A dict mapping array-type names to their creator functions.


In [4]:
def array_creators(
    density: Literal[1] | float,
) -> dict[str, Callable[[np.ndarray | sparse.spmatrix], np.ndarray | sparse.spmatrix]]:
    """Returns a dictionary of array creators for the given density"""
    array_funcs = {}

    # Check if dataset is dense
    if density == 1:
        array_funcs["np"] = lambda x: x.toarray()
    else:
        # For sparse datasets, consider both csc and csr formats
        array_funcs["csc"] = sparse.csc_matrix
        array_funcs["csr"] = sparse.csr_matrix
    return array_funcs

### generate_dimensions

Given a shape description ("fat", "tall", or "square") and a base size, this function computes the dimensions \(M\) (rows) and \(N\) (columns) of the dataset. A "square" dataset is `size` by `size`. For the other shapes, one axis is enlarged by a random 20% to 40%: "fat" enlarges the rows \(M\), and "tall" enlarges the columns \(N\).

- `shape_type`: String description of the desired shape of the dataset.
- `size`: Base size for the dataset.

Returns: The dimensions \(M\) and \(N\) of the dataset.
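Written out as a formula, the rule in `generate_dimensions` is as follows. The notation is ours: $s$ is `size`, $u$ is the value drawn by `np.random.uniform(0.2, 0.4)`, and $o$ is `other_size`. Because $s\,u$ is positive, `int()` acts as a floor.

$$
o = s + \lfloor s\,u \rfloor, \qquad u \sim \mathcal{U}[0.2,\ 0.4)
$$

$$
(M, N) =
\begin{cases}
(o,\ s) & \text{fat} \\
(s,\ o) & \text{tall} \\
(s,\ s) & \text{square}
\end{cases}
$$

For $s = 10\,000$, $2000 \le \lfloor s\,u \rfloor \le 3999$, so $12\,000 \le o \le 13\,999$. This is consistent with printed shapes such as `13747x10000`.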


In [5]:
def generate_dimensions(shape_type, size):
    # Default dimensions
    M = size
    N = size

    # If the shape isn't square, adjust the dimensions
    if shape_type != "square":
        other_size = size + int(size * np.random.uniform(0.2, 0.4))
        if shape_type == "fat":
            M = other_size
        elif shape_type == "tall":
            N = other_size

    return M, N

## Writing The Arrays To Disk

The functions defined below write the AnnData objects to disk. You do not need to understand all of them to follow the notebook. They are explained for readers who would like to create their own datasets for these measurements.

### Functions Overview

#### 1. `write_data_to_zarr`

This function wraps a given dataset `X` in an `AnnData` object (using `create_adata`), writes it to a Zarr store, and consolidates the store's metadata. Zarr is a format for the storage of chunked, compressed, N-dimensional arrays, which is useful for efficient on-disk storage and retrieval of large datasets.

- **Parameters**:
    - `X`: The dataset to be written.
    - `shape_type`: Descriptive shape type of the dataset.
    - `array_name`: Name representing the type of array (e.g., "np", "csc", "csr").
    - `outdir`: Directory where the Zarr file should be stored.
    - `file_id`: Identifier for the file, used in naming.

- **Returns**: A string report detailing the writing operation.

#### 2. `write_temp_data`

This function writes the temporary datasets to the output directory. For every combination of run, shape, size, and density, it generates one random sparse matrix, converts it to each array type returned by `array_creators`, and writes each result to a Zarr store using `write_data_to_zarr`.

- **Parameters**:
    - `shapes`: List of dataset shapes (e.g., "fat", "tall", "square").
    - `sizes`: List of base dataset sizes.
    - `densities`: List of dataset densities.
    - `num_runs`: Number of iterations for data generation.
    - `outdir`: Directory where the Zarr files should be stored.
    - `rewrite`: Boolean flag; if True, the `done.txt` marker is deleted and all datasets are regenerated, overwriting files with the same names.

The function also records each written dataset in a file named `done.txt`. If that file already exists (and `rewrite` is False), it prints the recorded log and skips data generation.

In [6]:
def write_data_to_zarr(X, shape_type, array_name, outdir, file_id):
    outfile = outdir / f"{file_id:02d}_{shape_type}_{array_name}.zarr"
    adata = create_adata(X)
    z = zarr.open_group(outfile, mode="w")
    write_elem(z, "/", adata)
    zarr.consolidate_metadata(z.store)
    return f"wrote {X.shape[0]}x{X.shape[1]}_{array_name} -> {str(outfile)}\n"


def write_temp_data(shapes, sizes, densities, num_runs, outdir, rewrite=False):
    outdir.mkdir(exist_ok=True)
    if rewrite:
        (outdir / "done.txt").unlink(missing_ok=True)
    if (outdir / "done.txt").exists():
        print("already done")
        with open(outdir / "done.txt", "r") as f:
            for line in f.readlines():
                print(line)
        return

    saved = []
    file_id = 1
    for _, shape_type, size, density in itertools.product(
        range(num_runs), shapes, sizes, densities
    ):
        array_funcs = array_creators(density)
        M, N = generate_dimensions(shape_type, size)

        X_base = sparse.random(M, N, density=density, format="csc")

        for array_name, array_func in array_funcs.items():
            X = array_func(X_base)
            report = write_data_to_zarr(X, shape_type, array_name, outdir, file_id)
            del X
            print(report, end="")
            saved.append(report)
            file_id += 1
    with open(outdir / "done.txt", "w") as f:
        f.writelines(saved)

In [7]:
# You can call the function like this:
write_temp_data(shapes, sizes, densities, num_runs, OUTDIR)

wrote 13747x10000_csc -> /var/folders/w4/rlbyb2md7y50tspf85v1lc440000gn/T/tmp75esna2p/01_fat_csc.zarr
wrote 13747x10000_csr -> /var/folders/w4/rlbyb2md7y50tspf85v1lc440000gn/T/tmp75esna2p/02_fat_csr.zarr
wrote 12361x10000_np -> /var/folders/w4/rlbyb2md7y50tspf85v1lc440000gn/T/tmp75esna2p/03_fat_np.zarr
wrote 10000x13069_csc -> /var/folders/w4/rlbyb2md7y50tspf85v1lc440000gn/T/tmp75esna2p/04_tall_csc.zarr
wrote 10000x13069_csr -> /var/folders/w4/rlbyb2md7y50tspf85v1lc440000gn/T/tmp75esna2p/05_tall_csr.zarr
wrote 10000x12903_np -> /var/folders/w4/rlbyb2md7y50tspf85v1lc440000gn/T/tmp75esna2p/06_tall_np.zarr
wrote 10000x10000_csc -> /var/folders/w4/rlbyb2md7y50tspf85v1lc440000gn/T/tmp75esna2p/07_square_csc.zarr
wrote 10000x10000_csr -> /var/folders/w4/rlbyb2md7y50tspf85v1lc440000gn/T/tmp75esna2p/08_square_csr.zarr
wrote 10000x10000_np -> /var/folders/w4/rlbyb2md7y50tspf85v1lc440000gn/T/tmp75esna2p/09_square_np.zarr
wrote 13588x10000_csc -> /var/folders/w4/rlbyb2md7y50tspf85v1lc440000gn/T/tm

### Putting our arrays in categories

The `create_datasets` function builds a dictionary that maps each (array type, axis) pair to a set of file paths, where the type is dense or sparse and the axis is 0 or 1. Sparse CSR files are concatenated along axis 0 and sparse CSC files along axis 1, while dense files are used for both axes. Concatenating along `axis` without reindexing requires all datasets to have the same size along the other axis, `1 - axis`. If the sizes differ, reindexing is required, and a more costly concatenation strategy has to be used than in the case without reindexing. With `requires_reindexing=False`, only files whose other axis is fixed at the base size are kept: "fat" and "square" files for axis 0, and "tall" and "square" files for axis 1. For this reason, we test the cases that require reindexing separately from those that do not.
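The alignment rule above can be stated as a small check. `needs_reindexing` is a hypothetical helper that only compares sizes. That is enough here only because this notebook names variables and observations `gene0`, `gene1`, ... and `cell0`, `cell1`, ..., so equal sizes imply identical indices. With real data, the index labels themselves would have to be compared.

```python
from typing import Sequence


def needs_reindexing(shapes: Sequence[tuple[int, int]], axis: int) -> bool:
    """Return True if the arrays disagree in size along the non-concatenated axis.

    Hypothetical helper: compares sizes only, which suffices when indices are
    generated as "gene0", "gene1", ... as in this notebook.
    """
    other = 1 - axis  # the axis that must line up across all inputs
    return len({shape[other] for shape in shapes}) > 1


# "fat" arrays in this notebook vary in rows but share 10_000 columns
fats = [(13_747, 10_000), (12_361, 10_000), (10_000, 10_000)]
print(needs_reindexing(fats, axis=0))  # False: columns match, rows are stacked
print(needs_reindexing(fats, axis=1))  # True: row counts differ
```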


In [8]:
# files by properties
filesets = {
    "nps": set(OUTDIR.glob("*np*")),
    "csrs": set(OUTDIR.glob("*csr*")),
    "cscs": set(OUTDIR.glob("*csc*")),
    "fats": set(OUTDIR.glob("*fat*")),
    "talls": set(OUTDIR.glob("*tall*")),
    "squares": set(OUTDIR.glob("*square*")),
}

In [9]:
def create_datasets(filesets, requires_reindexing=False):
    data = dict()
    for fileset, axis in (("cscs", 1), ("csrs", 0), ("nps", 0), ("nps", 1)):
        filepaths = filesets[fileset].copy()
        if not requires_reindexing:
            tall_or_fat = filesets["talls"] if axis == 1 else filesets["fats"]
            filepaths = filepaths.intersection(tall_or_fat.union(filesets["squares"]))
        fileset_name = "dense" if fileset == "nps" else "sparse"
        data[fileset_name, axis] = filepaths
    return data

Below, we build both groups: the datasets that can be concatenated without reindexing (aligned) and those that require reindexing because their sizes along the other axis do not match (unaligned).

In [10]:
datasets_aligned = create_datasets(filesets, requires_reindexing=False)
datasets_unaligned = create_datasets(filesets, requires_reindexing=True)

## Measuring Performance

### `get_arr_sizes`

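This function returns, for each of the given file paths, the number of bytes the `X` array occupies in its Zarr store. For sparse data, it adds up the `data`, `indices`, and `indptr` arrays. For dense data, it uses the `X` array directly. These are sizes in the store (for a directory-backed store, the size of the stored and typically compressed chunk files), not in-memory sizes.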
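For a directory-backed store, a comparable on-disk number can be obtained without Zarr by summing the sizes of the files under an array's directory. The following is a minimal standard-library sketch. `dir_nbytes` is a hypothetical helper, not a Zarr function. It recurses into subdirectories and also counts Zarr's small metadata files, so its result may differ slightly from the store's `getsize`.

```python
import tempfile
from pathlib import Path


def dir_nbytes(path: Path) -> int:
    """Total size in bytes of all regular files below `path` (recursive)."""
    return sum(p.stat().st_size for p in Path(path).rglob("*") if p.is_file())


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "data").mkdir()
    (root / "data" / "0").write_bytes(b"x" * 1000)  # stand-in for a chunk file
    (root / "data" / ".zarray").write_bytes(b"{}")  # stand-in for array metadata
    print(dir_nbytes(root / "data"))  # 1002
```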

---

### `get_mem_usage`

The function `get_mem_usage` runs `concat_on_disk` under memray and returns the peak memory usage in bytes. For dense data, it starts a local Dask cluster whose `memory_limit` (applied per worker) is `max_arg`, and traces every worker and the scheduler. For sparse data, it passes `max_arg` to `concat_on_disk` as `max_loaded_elems` and traces the current process. The returned value is the sum of the peak memory recorded by each traced process. Because those peaks may occur at different times, the sum is an upper bound on the memory used simultaneously.
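The notebook measures peak memory with memray. To illustrate the same idea without extra dependencies, the standard library's `tracemalloc` can report the peak of Python-level allocations during a call. This is a swapped-in technique, not what `get_mem_usage` does. `tracemalloc` only sees allocations made through Python's allocators or reported to it, and it only watches the current process, so Dask workers would not be included. `peak_python_memory` is a hypothetical helper.

```python
import tracemalloc
from typing import Any, Callable


def peak_python_memory(fn: Callable[[], Any]) -> tuple[Any, int]:
    """Call `fn()` and return its result together with the peak traced bytes.

    Assumes tracemalloc is not already active elsewhere, since it is stopped
    at the end. Only the current process is observed.
    """
    tracemalloc.start()
    try:
        result = fn()
        _, peak = tracemalloc.get_traced_memory()  # (current, peak) in bytes
    finally:
        tracemalloc.stop()  # also discards the collected traces
    return result, peak


def allocate_temporary_buffer() -> int:
    buf = bytearray(10 * 2**20)  # a 10 MiB temporary allocation
    return len(buf)


size, peak = peak_python_memory(allocate_temporary_buffer)
print(f"allocated {size // 2**20} MiB, traced peak {peak // 2**20} MiB")
```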

---

### `dataset_max_mem`

The `dataset_max_mem` function profiles the concatenation of one array type (sparse or dense) along each axis. For each axis, it prints the files to concatenate and their sizes, measures the peak memory during the concatenation with `get_mem_usage`, and prints the result. It returns a dictionary that maps `(array_type, axis)` to the peak memory in bytes.


In [11]:
def get_arr_sizes(filepaths, is_sparse):
    def get_arr_size(g):
        if is_sparse:
            size = (
                g.store.getsize("X/data")
                + g.store.getsize("X/indices")
                + g.store.getsize("X/indptr")
            )
        else:
            size = g.store.getsize("X")
        return size

    return [get_arr_size(zarr.open_group(filepath)) for filepath in filepaths]


def get_mem_usage(filepaths, writepth, axis, max_arg, is_sparse):
    """Run `concat_on_disk` under memray and return the summed peak memory in bytes."""
    concat_kwargs = {
        "in_files": filepaths,
        "out_file": writepth,
        "axis": axis,
    }
    tracer_kwargs = dict(trace_python_allocators=True, native_traces=True, follow_fork=True)

    # Remove profiles from a previous run so they are not counted again
    for stat_file in OUTDIR.glob("*.memray"):
        stat_file.unlink()

    if is_sparse:
        # Sparse: bound memory through the number of elements loaded at once
        concat_kwargs["max_loaded_elems"] = max_arg
        with memray.Tracker(OUTDIR / "test-profile.memray", **tracer_kwargs):
            concat_on_disk(**concat_kwargs)
    else:
        # Dense: Dask does the work, so limit memory per worker and trace
        # every worker plus the scheduler
        cluster = LocalCluster(memory_limit=max_arg, silence_logs=dask_log)
        client = Client(cluster)
        try:
            with (
                memray_workers(OUTDIR, report_args=False, **tracer_kwargs),
                memray_scheduler(OUTDIR, report_args=False, **tracer_kwargs),
            ):
                concat_on_disk(**concat_kwargs)
        finally:
            # Shut the cluster down even if the concatenation fails
            client.shutdown()
            client.close()
            cluster.close()

    # Sum the peak memory recorded by every traced process
    max_mem = 0
    for stat_file in OUTDIR.glob("*.memray"):
        with memray.FileReader(stat_file) as reader:
            max_mem += reader.metadata.peak_memory

    return max_mem


def dataset_max_mem(max_arg, datasets, array_type):
    results = {}
    is_sparse = array_type == "sparse"
    for axis in (0, 1):
        filepaths = datasets[array_type, axis]
        writepth = OUTDIR / f"{array_type}_{axis}.zarr"
        if writepth.exists():
            shutil.rmtree(writepth)

        # print the files we are concatenating
        print("Dataset:", array_type, axis)
        print(f"Concatenating {len(filepaths)} files with sizes:")
        sizes = get_arr_sizes(filepaths, is_sparse)
        print([str(s // (2**20)) + "MiB" for s in sizes])
        print(f"Total size: {sum(sizes) // (2**20)}MiB")

        # force garbage collection
        gc.collect()
        # perform profiling; the result is the peak memory in bytes
        peak_mem = get_mem_usage(filepaths, writepth, axis, max_arg, is_sparse)
        # force garbage collection again
        gc.collect()

        print("Concatenation finished")
        print("Peak Memory:", int(peak_mem) // (2**20), "MiB")
        print("--------------------------------------------------")
        results[array_type, axis] = peak_mem
    return results

## Results of concatenation without reindexing

In this section, we evaluate the memory performance of the `concat_on_disk` function when concatenating datasets **without** the need for reindexing. The printed reports show the individual file sizes, the total size of the inputs, and the peak memory during the concatenation.


### Sparse Datasets

For sparse datasets:

- In both runs (axis 0 and axis 1), the peak memory (16 to 17 MiB) is a small fraction of the total size of the inputs (roughly 550 MiB).
- This highlights the primary advantage of the `concat_on_disk` function: it performs the concatenation **on disk**, so memory consumption remains low even when the combined size of the files is much larger.

However, if you have enough memory to fit the files, performing the concatenation in memory would be faster.
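One reason the aligned sparse case can stay cheap follows from the CSR layout. When all inputs share the same columns, stacking CSR matrices along axis 0 only requires appending their `data` and `indices` arrays and shifting each `indptr` by the number of non-zeros already written, so the inputs can be processed one at a time. The sketch below illustrates that bookkeeping with plain Python lists. It is not `concat_on_disk`'s actual code, which works on on-disk arrays, and `append_csr` is a hypothetical name. The same reasoning applies to CSC matrices along axis 1.

```python
def append_csr(out: dict, block: dict) -> None:
    """Append the rows of CSR `block` to CSR `out` in place.

    Both are dicts with keys "data", "indices", "indptr" and must have the
    same number of columns (no reindexing needed).
    """
    offset = out["indptr"][-1]               # non-zeros already stored in `out`
    out["data"].extend(block["data"])        # values are copied unchanged
    out["indices"].extend(block["indices"])  # column ids are unchanged
    # drop the block's leading 0 and shift its remaining row pointers
    out["indptr"].extend(p + offset for p in block["indptr"][1:])


# Two 2x3 matrices: A = [[1,0,2],[0,0,3]] and B = [[0,4,0],[5,0,6]]
out = {"data": [], "indices": [], "indptr": [0]}
A = {"data": [1, 2, 3], "indices": [0, 2, 2], "indptr": [0, 2, 3]}
B = {"data": [4, 5, 6], "indices": [1, 0, 2], "indptr": [0, 1, 3]}
for block in (A, B):  # in practice, each block would come from one input file
    append_csr(out, block)
print(out)
# {'data': [1, 2, 3, 4, 5, 6], 'indices': [0, 2, 2, 1, 0, 2], 'indptr': [0, 2, 3, 4, 6]}
```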

### Dense Datasets

The results for dense datasets follow a similar pattern:

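- The datasets are concatenated successfully with a memory limit of 2000 MiB per Dask worker.
- The peak memory (2740 MiB for axis 0 and 3450 MiB for axis 1) stays below the total size of the inputs (about 4.5 GiB). It can exceed the 2000 MiB limit because it is summed over all workers and the scheduler, while the limit applies to each worker individually.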


In [12]:
dataset_max_mem(max_arg=1_000_000_000, datasets=datasets_aligned, array_type="sparse");

Dataset: sparse 0
Concatenating 6 files with sizes:
['78MiB', '78MiB', '106MiB', '78MiB', '107MiB', '106MiB']
Total size: 556MiB
Concatenation finished
Peak Memory: 17 MiB
--------------------------------------------------
Dataset: sparse 1
Concatenating 6 files with sizes:
['101MiB', '106MiB', '78MiB', '78MiB', '78MiB', '102MiB']
Total size: 546MiB
Concatenation finished
Peak Memory: 16 MiB
--------------------------------------------------


In [13]:
dataset_max_mem(max_arg="2000MiB", datasets=datasets_aligned, array_type="dense");

Dataset: dense 0
Concatenating 6 files with sizes:
['668MiB', '827MiB', '668MiB', '920MiB', '875MiB', '668MiB']
Total size: 4630MiB
Concatenation finished
Peak Memory: 2740 MiB
--------------------------------------------------
Dataset: dense 1
Concatenating 6 files with sizes:
['912MiB', '823MiB', '668MiB', '668MiB', '668MiB', '864MiB']
Total size: 4606MiB
Concatenation finished
Peak Memory: 3450 MiB
--------------------------------------------------


## Results of concatenation with reindexing

This section presents the results of the `concat_on_disk` function when concatenating datasets that **require** reindexing.

The observations are similar to those of the "without reindexing" section. The primary difference is the set of datasets being concatenated. Reindexing is noticeably more expensive for sparse data: the peak memory rises to 286 MiB (axis 0) and 445 MiB (axis 1), compared with 16 to 17 MiB without reindexing, but stays below the total input size of 867 MiB. For dense data, the peak memory (2931 MiB and 3152 MiB) remains well below the total input size of 7230 MiB.

The effect of the memory constraint itself is examined in the next section.

In [14]:
dataset_max_mem(max_arg=1_000_000_000, datasets=datasets_unaligned, array_type="sparse");

Dataset: sparse 0
Concatenating 9 files with sizes:
['78MiB', '78MiB', '106MiB', '102MiB', '106MiB', '102MiB', '78MiB', '107MiB', '106MiB']
Total size: 867MiB
Concatenation finished
Peak Memory: 286 MiB
--------------------------------------------------
Dataset: sparse 1
Concatenating 9 files with sizes:
['106MiB', '101MiB', '106MiB', '106MiB', '108MiB', '78MiB', '78MiB', '78MiB', '102MiB']
Total size: 867MiB
Concatenation finished
Peak Memory: 445 MiB
--------------------------------------------------


In [15]:
dataset_max_mem(max_arg="2000MiB", datasets=datasets_unaligned, array_type="dense");

Dataset: dense 0
Concatenating 9 files with sizes:
['912MiB', '823MiB', '827MiB', '668MiB', '668MiB', '920MiB', '875MiB', '668MiB', '864MiB']
Total size: 7230MiB
Concatenation finished
Peak Memory: 2931 MiB
--------------------------------------------------
Dataset: dense 1
Concatenating 9 files with sizes:
['912MiB', '823MiB', '827MiB', '668MiB', '668MiB', '920MiB', '875MiB', '668MiB', '864MiB']
Total size: 7230MiB
Concatenation finished
Peak Memory: 3152 MiB
--------------------------------------------------


## The effect of `max_loaded_elems` on performance
The parameter `max_loaded_elems` is used in very specific cases when the data is sparse and the concatenation requires reindexing. Ideally, for each concatenation element (i.e., file), the function would load the entire file into memory, reindex it, and then write it to disk. However, this is not always possible due to memory constraints. In such cases, the `max_loaded_elems` parameter is used to specify the maximum number of elements that can be loaded into memory at once. The function then iteratively loads the data, reindexes it, and writes it to disk. This process is repeated until all the data has been processed.

With the datasets created above, observing the effect of this parameter would require setting `max_loaded_elems` to a very small number, which would make the concatenation very slow. Therefore, we use a subset of the datasets to demonstrate its effect.

The benefits of this feature are most visible when the input arrays have very different sizes (e.g., 100 arrays of 10 MB plus 2 arrays of 1 GB). For simplicity, however, we use arrays of similar sizes.
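The chunked loading described above can be sketched for a CSR input that is concatenated along axis 0 but whose columns must be reindexed into a new variable index. Rows are grouped into blocks whose non-zero count stays within a budget. A single row larger than the budget still gets its own block. Each block's column indices are then remapped to positions in the target index, and entries for columns missing from the target are dropped. This is a simplified pure-Python illustration, not anndata's implementation. It assumes the budget counts non-zero entries, does not re-sort indices within a row, and uses the hypothetical names `row_blocks` and `reindex_block`.

```python
def row_blocks(indptr: list[int], max_loaded_elems: int) -> list[tuple[int, int]]:
    """Split rows into [start, stop) ranges holding at most max_loaded_elems non-zeros.

    A single row with more non-zeros than the budget still forms its own block.
    """
    blocks, start = [], 0
    n_rows = len(indptr) - 1
    for row in range(n_rows):
        # would adding this row push the current block over the budget?
        if row > start and indptr[row + 1] - indptr[start] > max_loaded_elems:
            blocks.append((start, row))
            start = row
    if n_rows:
        blocks.append((start, n_rows))
    return blocks


def reindex_block(data, indices, indptr, start, stop, col_map: dict[int, int]):
    """Load rows [start, stop) and remap column ids via col_map (old id -> new id).

    Columns absent from col_map are dropped. Returns block-local CSR arrays.
    """
    new_data, new_indices, new_indptr = [], [], [0]
    for row in range(start, stop):
        for k in range(indptr[row], indptr[row + 1]):
            target = col_map.get(indices[k])
            if target is not None:
                new_data.append(data[k])
                new_indices.append(target)
        new_indptr.append(len(new_data))
    return new_data, new_indices, new_indptr


# 4x3 CSR input with columns g0, g1, g2; the target index keeps g2 and g0 (in that order)
data = [1, 2, 3, 4, 5, 6, 7]
indices = [0, 2, 1, 0, 1, 2, 0]
indptr = [0, 2, 3, 6, 7]
old, new = ["g0", "g1", "g2"], ["g2", "g0"]
col_map = {j: new.index(name) for j, name in enumerate(old) if name in new}

for start, stop in row_blocks(indptr, max_loaded_elems=3):
    print((start, stop), reindex_block(data, indices, indptr, start, stop, col_map))
# (0, 2) ([1, 2], [1, 0], [0, 2, 2])
# (2, 3) ([4, 6], [1, 0], [0, 2])
# (3, 4) ([7], [1], [0, 1])
```

Each reindexed block could then be appended to the output before the next block is loaded, so memory is bounded by the budget rather than by the size of the whole file.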

In [18]:
# Take three files per axis; set order is arbitrary, so the selection may vary between runs
subset = {
    ("sparse", 0): list(datasets_unaligned[("sparse", 0)])[:3],
    ("sparse", 1): list(datasets_unaligned[("sparse", 1)])[:3],
}

In [19]:
dataset_max_mem(max_arg=10_000_000, datasets=subset, array_type="sparse");

Dataset: sparse 0
Concatenating 3 files with sizes:
['78MiB', '78MiB', '106MiB']
Total size: 263MiB
Concatenation finished
Peak Memory: 11 MiB
--------------------------------------------------
Dataset: sparse 1
Concatenating 3 files with sizes:
['106MiB', '101MiB', '106MiB']
Total size: 315MiB
Concatenation finished
Peak Memory: 38 MiB
--------------------------------------------------


In [20]:
dataset_max_mem(max_arg=1_000_000_000, datasets=subset, array_type="sparse");

Dataset: sparse 0
Concatenating 3 files with sizes:
['78MiB', '78MiB', '106MiB']
Total size: 263MiB
Concatenation finished
Peak Memory: 11 MiB
--------------------------------------------------
Dataset: sparse 1
Concatenating 3 files with sizes:
['106MiB', '101MiB', '106MiB']
Total size: 315MiB
Concatenation finished
Peak Memory: 432 MiB
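--------------------------------------------------

Lowering `max_loaded_elems` from 1,000,000,000 to 10,000,000 reduces the peak memory of the axis-1 concatenation from 432 MiB to 38 MiB, while the axis-0 concatenation peaks at 11 MiB with either setting.
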
## (Optional) Cleaning Up Temporary Files
Once you have finished with this notebook, you can clean up the created files.

In [None]:
TMPDIR.cleanup()