# On-Disk Concatenation of AnnData Files

**Author:** Selman Özleyen

## Initalizing

First let's do our imports and initalize adata objects with the help of the `adata_with_dask` function defined below.

This notebook uses the [memory-profiler](https://pypi.org/project/memory-profiler/) extension, call `pip install memory-profiler` before running this notebook.

In [1]:
from memory_profiler import memory_usage
import numpy as np
from scipy import sparse
import pandas as pd
import shutil
from anndata.tests.helpers import gen_typed_df
from anndata.experimental import write_elem
import zarr
import anndata
from pathlib import Path
import glob
import dask.distributed as dd

import anndata
import dask.array as da
import zarr
import gc
from anndata.experimental import concat_on_disk
from dask.distributed import Client, LocalCluster


OUTDIR = Path("tmpdata")


shapes = ["fat", "tall", "square"]
sizes = [10_000]
densities = [0.1, 1]
num_runs = 3


In [2]:


def create_adata(shape, X):
    M, N = shape
    obs_names = pd.Index(f"cell{i}" for i in range(shape[0]))
    var_names = pd.Index(f"gene{i}" for i in range(shape[1]))
    obs = gen_typed_df(M, obs_names)
    var = gen_typed_df(N, var_names)
    # For #147
    
    obs.rename(columns=dict(cat="obs_cat"), inplace=True)
    var.rename(columns=dict(cat="var_cat"), inplace=True)
    adata = anndata.AnnData(X, obs=obs, var=var)
    adata.var_names_make_unique()
    adata.obs_names_make_unique()
    
    return adata


In [3]:
def generate_array_funcs_and_names(density):
    array_funcs = []
    array_names = []
    is_dense = density == 1
    if is_dense:
        array_names.append("np")
        array_funcs.append(lambda x: x.toarray())
    else:
        array_names.extend(["csc", "csr"])
        array_funcs.extend([sparse.csc_matrix, sparse.csr_matrix])
    return array_funcs, array_names

def generate_dimensions(shape, size):
    M = size
    N = size
    if shape != "square":
        other_size = size + int(size * np.random.uniform(0.2, 0.4))
        if shape == "fat":
            M = other_size
        elif shape == "tall":
            N = other_size
    return M, N

def write_data_to_zarr(X, shape, array_name, outdir, file_id):
    fname = str(outdir) + f"/{file_id:02d}_{shape}_{array_name}"
    adata = create_adata((X.shape[0], X.shape[1]), X)
    output_zarr_path = f"{str(fname)}.zarr"
    z = zarr.open_group(output_zarr_path)
    write_elem(z, "/", adata)
    zarr.consolidate_metadata(z.store)
    return f"wrote {X.shape[0]}x{X.shape[1]}_{array_name} -> {fname}\n"

def write_temp_data(shapes, sizes, densities, num_runs, outdir, rewrite=False):
    outdir.mkdir(exist_ok=True)
    if rewrite:
        (outdir / "done.txt").unlink(missing_ok=True)
    if (outdir / "done.txt").exists():
        print("already done")
        with open(outdir / "done.txt", "r") as f:
            for line in f.readlines():
                print(line)
        return

    saved = []
    file_id = 1
    for _ in range(num_runs):
        for shape in shapes:
            for size in sizes:
                for density in densities:
                    array_funcs, array_names = generate_array_funcs_and_names(density)
                    M, N = generate_dimensions(shape, size)

                    X_base = sparse.random(M, N, density=density, format="csc")

                    for array_func, array_name in zip(array_funcs, array_names):
                        X = array_func(X_base)
                        report = write_data_to_zarr(X, shape, array_name, outdir, file_id)
                        print(report)
                        saved.append(report)
                        file_id += 1
    with open(outdir / "done.txt", "w") as f:
        f.writelines(saved)



In [4]:

# You can call the function like this:
write_temp_data(shapes, sizes, densities, num_runs, OUTDIR)


wrote 12941x10000_csc -> tmpdata/01_fat_csc

wrote 12941x10000_csr -> tmpdata/02_fat_csr

wrote 12289x10000_np -> tmpdata/03_fat_np

wrote 10000x12366_csc -> tmpdata/04_tall_csc

wrote 10000x12366_csr -> tmpdata/05_tall_csr

wrote 10000x13624_np -> tmpdata/06_tall_np

wrote 10000x10000_csc -> tmpdata/07_square_csc

wrote 10000x10000_csr -> tmpdata/08_square_csr

wrote 10000x10000_np -> tmpdata/09_square_np

wrote 13924x10000_csc -> tmpdata/10_fat_csc

wrote 13924x10000_csr -> tmpdata/11_fat_csr

wrote 13321x10000_np -> tmpdata/12_fat_np

wrote 10000x12377_csc -> tmpdata/13_tall_csc

wrote 10000x12377_csr -> tmpdata/14_tall_csr

wrote 10000x12595_np -> tmpdata/15_tall_np

wrote 10000x10000_csc -> tmpdata/16_square_csc

wrote 10000x10000_csr -> tmpdata/17_square_csr

wrote 10000x10000_np -> tmpdata/18_square_np

wrote 13778x10000_csc -> tmpdata/19_fat_csc

wrote 13778x10000_csr -> tmpdata/20_fat_csr

wrote 12484x10000_np -> tmpdata/21_fat_np

wrote 10000x13293_csc -> tmpdata/22_tall_csc


In [5]:
# files by properties
filesets = {
    'nps' : set(glob.glob(str(OUTDIR) + "/*np*")),
    'csrs' : set(glob.glob(str(OUTDIR) + "/*csr*")),
    'cscs' : set(glob.glob(str(OUTDIR) + "/*csc*")),
    'fats' : set(glob.glob(str(OUTDIR) + "/*fat*")),
    'talls' : set(glob.glob(str(OUTDIR) + "/*tall*")),
    'squares' :set(glob.glob(str(OUTDIR) + "/*square*")),
}



In [6]:

def create_datasets(filesets, requires_reindexing=False):
    data = dict()
    for fileset, axis in (("cscs",1), ("csrs",0), ("nps",0), ("nps",1)):
        filepaths = filesets[fileset].copy()
        if not requires_reindexing:
            tall_or_fat = filesets['talls'] if axis == 1 else filesets['fats']
            filepaths = filepaths.intersection(tall_or_fat.union(filesets['squares']))
        fileset_name = "dense" if fileset == "nps" else "sparse"
        data[fileset_name, axis] = filepaths
    return data    

In [7]:
datasets = create_datasets(filesets, requires_reindexing=True)
datasets


{('sparse', 1): {'tmpdata/01_fat_csc.zarr',
  'tmpdata/04_tall_csc.zarr',
  'tmpdata/07_square_csc.zarr',
  'tmpdata/10_fat_csc.zarr',
  'tmpdata/13_tall_csc.zarr',
  'tmpdata/16_square_csc.zarr',
  'tmpdata/19_fat_csc.zarr',
  'tmpdata/22_tall_csc.zarr',
  'tmpdata/25_square_csc.zarr'},
 ('sparse', 0): {'tmpdata/02_fat_csr.zarr',
  'tmpdata/05_tall_csr.zarr',
  'tmpdata/08_square_csr.zarr',
  'tmpdata/11_fat_csr.zarr',
  'tmpdata/14_tall_csr.zarr',
  'tmpdata/17_square_csr.zarr',
  'tmpdata/20_fat_csr.zarr',
  'tmpdata/23_tall_csr.zarr',
  'tmpdata/26_square_csr.zarr'},
 ('dense', 0): {'tmpdata/03_fat_np.zarr',
  'tmpdata/06_tall_np.zarr',
  'tmpdata/09_square_np.zarr',
  'tmpdata/12_fat_np.zarr',
  'tmpdata/15_tall_np.zarr',
  'tmpdata/18_square_np.zarr',
  'tmpdata/21_fat_np.zarr',
  'tmpdata/24_tall_np.zarr',
  'tmpdata/27_square_np.zarr'},
 ('dense', 1): {'tmpdata/03_fat_np.zarr',
  'tmpdata/06_tall_np.zarr',
  'tmpdata/09_square_np.zarr',
  'tmpdata/12_fat_np.zarr',
  'tmpdata/15

In [8]:
def get_dense_mem_usage(
    filepaths=None,
    writepth=None,
    axis=None,
    max_arg="600MiB",
):

    cluster = LocalCluster(n_workers=1, threads_per_worker=1, memory_limit=max_arg)
    client = Client(cluster)

    # get the current memory usage
    initial_mem = memory_usage(-1, interval=0.001)[0]

    mem_usages = memory_usage(
        (
            concat_on_disk,
            (),
            {
                "in_files": filepaths,
                "out_file": writepth,
                "axis": axis,
                "index_unique": "-",
            },
        ),
        include_children=True,
        interval=0.001,
    )
    max_mem = max(mem_usages)
    mem_increment = max_mem - initial_mem
    
    client.close()
    cluster.close()
    return mem_increment, max_mem, mem_usages, initial_mem



In [9]:


def get_arr_sizes(array_type, filepaths):
    res = []
    for f in filepaths:
        store = zarr.open_group(f).store
        additional_size = 0
        if array_type == 'sparse':
            additional_size = store.getsize('X/data')+store.getsize('X/indices')+store.getsize('X/indptr')
        res.append(store.getsize('X')+additional_size)
    return res

def get_sparse_mem_usage(filepaths, writepth, axis, max_arg):
    # get the current memory usage
    initial_mem = memory_usage(-1, interval=0.001)[0]

    mem_usages = memory_usage(
        (
            concat_on_disk,
            (),
            {
                "in_files": filepaths,
                "out_file": writepth,
                "axis": axis,
                "max_loaded_elems": max_arg,
            },
        ),
        include_children=True,
        interval=0.001,
    )
    max_mem = max(mem_usages)
    mem_increment = max_mem - initial_mem
    
    return mem_increment, max_mem, mem_usages, initial_mem



In [10]:

def dataset_max_mem(max_arg, datasets, array_type):
    results = {}

    for filepaths,axis in [(datasets[array_type,axis],axis) for axis in [0,1]]:
        writepth = OUTDIR / f"{array_type}_{axis}.zarr"
        if writepth.exists():
            shutil.rmtree(writepth)

        # print the files we are concatenating
        print("Dataset:", array_type, axis)
        print(f"Concatenating {len(filepaths)} files with sizes:")
        sizes = get_arr_sizes(array_type, filepaths)
        print([str(s//(2**20))+'MiB' for s in sizes])
        print(f"Total size: {sum(sizes)//(2**20)}MiB")
        
        mem_usage_func = get_sparse_mem_usage if array_type == 'sparse' else get_dense_mem_usage

        # force garbage collection
        gc.collect()
        # perform profiling
        mem_increment, max_mem, mem_usages, initial_mem = mem_usage_func(filepaths, writepth, axis, max_arg)
        # force garbage collection again
        gc.collect()

        print("Concatenation finished")
        print("Max memory increase:", int(mem_increment), "MiB")
        print("--------------------------------------------------")
        results[array_type, axis] = {"max_mem": max_mem, "increment": mem_increment}
    return results


In [11]:
dataset_max_mem(max_arg=1_000_000_000, datasets=datasets, array_type='sparse')

Dataset: sparse 0
Concatenating 9 files with sizes:
['104MiB', '97MiB', '101MiB', '78MiB', '97MiB', '78MiB', '78MiB', '108MiB', '109MiB']
Total size: 853MiB
Concatenation finished
Max memory increase: 426 MiB
--------------------------------------------------
Dataset: sparse 1
Concatenating 9 files with sizes:
['78MiB', '108MiB', '104MiB', '101MiB', '78MiB', '97MiB', '78MiB', '97MiB', '109MiB']
Total size: 853MiB
Concatenation finished
Max memory increase: 590 MiB
--------------------------------------------------


{('sparse', 0): {'max_mem': 577.796875, 'increment': 426.50390625},
 ('sparse', 1): {'max_mem': 756.609375, 'increment': 590.16015625}}

In [12]:
dataset_max_mem(max_arg=100_000_000, datasets=datasets, array_type='sparse')

Dataset: sparse 0
Concatenating 9 files with sizes:
['104MiB', '97MiB', '101MiB', '78MiB', '97MiB', '78MiB', '78MiB', '108MiB', '109MiB']
Total size: 853MiB
Concatenation finished
Max memory increase: 427 MiB
--------------------------------------------------
Dataset: sparse 1
Concatenating 9 files with sizes:
['78MiB', '108MiB', '104MiB', '101MiB', '78MiB', '97MiB', '78MiB', '97MiB', '109MiB']
Total size: 853MiB
Concatenation finished
Max memory increase: 589 MiB
--------------------------------------------------


{('sparse', 0): {'max_mem': 596.15625, 'increment': 427.5},
 ('sparse', 1): {'max_mem': 759.921875, 'increment': 589.625}}

In [13]:
dataset_max_mem(max_arg="1000MiB", datasets=datasets, array_type='dense')

Dataset: dense 0
Concatenating 9 files with sizes:
['910MiB', '668MiB', '923MiB', '891MiB', '668MiB', '822MiB', '835MiB', '843MiB', '668MiB']
Total size: 7233MiB
Concatenation finished
Max memory increase: 1023 MiB
--------------------------------------------------
Dataset: dense 1
Concatenating 9 files with sizes:
['910MiB', '668MiB', '923MiB', '891MiB', '668MiB', '822MiB', '835MiB', '843MiB', '668MiB']
Total size: 7233MiB
Concatenation finished
Max memory increase: 967 MiB
--------------------------------------------------


{('dense', 0): {'max_mem': 1210.96875, 'increment': 1023.703125},
 ('dense', 1): {'max_mem': 1163.109375, 'increment': 967.73046875}}