# Sgkit: benchmarking Numba vs JAX

We have used Numba to accelerate CPU computations in sgkit for a long time. This notebook is an experiment to compare the performance of some basic allele counting code that uses Numba with the equivalent in JAX.

In [1]:
import time
import sgkit as sg

## Numba

We'll start by running the Numba function `count_call_alleles` from sgkit over datasets of several different sizes.

In [2]:
from sgkit import count_call_alleles as count_call_alleles_numba

In [3]:
matrix = [(10000, 100), (100000, 1000), (20000, 5000), (10000, 10000)]

In [4]:
for n_variant, n_sample in matrix:
    ds = sg.simulate_genotype_call_dataset(
        n_variant=n_variant, n_sample=n_sample, missing_pct=0.01
    )
    ds = ds.chunk({"variants": 10000, "samples": 1000})

    ds = count_call_alleles_numba(ds)
    start = time.time()
    ds = ds.load()
    end = time.time()

    print(f"n_variant: {n_variant}, n_sample: {n_sample}, time: {end - start}")

n_variant: 10000, n_sample: 100, time: 0.0037958621978759766
n_variant: 100000, n_sample: 1000, time: 0.1375119686126709
n_variant: 20000, n_sample: 5000, time: 0.20614981651306152
n_variant: 10000, n_sample: 10000, time: 0.22558307647705078


## JAX

For JAX we have to implement the equivalent of `count_call_alleles`. JAX provides its own version of the NumPy API, so we can use `bincount` to implement the inner function that operates on a single dimension of the `call_genotype` array. Note that this is different to Numba where we can write loops that operate directly on arrays.
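
To make the `bincount` idea concrete, here is a minimal sketch on single genotype vectors. The name `count_alleles_sketch` and the example inputs are ours, for illustration only; the shift by 2 mirrors the approach used in this notebook.

```python
import jax.numpy as jnp


def count_alleles_sketch(g, n_alleles=2):
    # Shift so that -2 (non-allele) lands in bin 0 and -1 (missing) in bin 1,
    # then drop those two bins so only real alleles are counted.
    counts = jnp.bincount(g + 2, length=n_alleles + 2)
    return counts[2:]


het = count_alleles_sketch(jnp.array([0, 1], dtype=jnp.int8))
hom_alt = count_alleles_sketch(jnp.array([1, 1], dtype=jnp.int8))
half_missing = count_alleles_sketch(jnp.array([-1, 1], dtype=jnp.int8))

print(het, hom_alt, half_missing)
```

Passing `length` explicitly gives the output a fixed size, which is what allows the function to be compiled later on.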

In [5]:
import jax
import jax.numpy as jnp

In [6]:
def count_alleles_jax(g):
    # jnp.bincount clips negative values to 0, so we shift by 2 and then drop
    # the first two bins; this discards the counts for -1 (missing) and -2 (non-allele)
    # n_alleles is hardcoded for now, see
    # https://jax.readthedocs.io/en/latest/jit-compilation.html#marking-arguments-as-static
    n_alleles = 2
    counts = jnp.bincount(g + 2, length=n_alleles + 2)
    counts = counts[2:]
    return jnp.astype(counts, jnp.uint8)

The user-level function is very similar to the Numba version. The main difference is that we use JAX's `vmap` and `jit` functions to vectorize and compile the `count_alleles_jax` function.
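
As a small illustration of how the two `vmap` calls and `jit` compose, the sketch below applies a per-call counting function to a tiny (variants, samples, ploidy) array. The names `count_one_call` and `count_all_calls` and the toy data are made up for this example.

```python
import jax
import jax.numpy as jnp
import numpy as np


def count_one_call(g):
    # g is the 1-D ploidy vector of a single (variant, sample) pair
    counts = jnp.bincount(g + 2, length=4)
    return counts[2:].astype(jnp.uint8)


# The inner vmap maps over samples, the outer vmap over variants;
# jit then compiles the resulting batched function.
count_all_calls = jax.jit(jax.vmap(jax.vmap(count_one_call)))

# 2 variants x 3 samples x ploidy 2, with -1 marking a missing allele
G = np.array(
    [[[0, 0], [0, 1], [1, 1]],
     [[-1, 0], [1, 0], [-1, -1]]],
    dtype=np.int8,
)
counts = np.asarray(count_all_calls(G))
print(counts.shape, counts.dtype)
```

The output has one count per allele for every (variant, sample) pair, so the ploidy axis is replaced by an alleles axis.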

In [7]:
from typing import Hashable

import dask.array as da
import numpy as np
import xarray as xr
from typing_extensions import Literal
from xarray import Dataset

from sgkit import variables
from sgkit.utils import conditional_merge_datasets, create_dataset

In [8]:
def count_call_alleles_jax(
    ds: Dataset,
    *,
    call_genotype: Hashable = variables.call_genotype,
    merge: bool = True,
) -> Dataset:
    variables.validate(ds, {call_genotype: variables.call_genotype_spec})
    n_alleles = ds.sizes["alleles"]
    G = da.asarray(ds[call_genotype])
    if G.numblocks[2] > 1:
        raise ValueError(
            f"Variable {call_genotype} must have only a single chunk in the ploidy dimension. "
            "Consider rechunking to change the size of chunks."
        )
    shape = (G.chunks[0], G.chunks[1], n_alleles)

    # call vmap twice to vectorize over first two dimensions (variants, samples)
    count_alleles_vectorized = jax.vmap(jax.vmap(count_alleles_jax))

    # jit compile
    count_alleles_vectorized_jit = jax.jit(count_alleles_vectorized)

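    # warm-up call; jax.jit specializes on input shape and dtype, so this
    # compiles the (4, 4, 2) int8 case only, not the chunk shapes used below
    count_alleles_vectorized_jit(np.ones((4, 4, 2), dtype=np.int8)).block_until_ready()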

    new_ds = create_dataset(
        {
            variables.call_allele_count: (
                ("variants", "samples", "alleles"),
                da.map_blocks(
                    count_alleles_vectorized_jit,
                    G,
                    chunks=shape,
                    dtype=np.uint8,
                    drop_axis=2,
                    new_axis=2,
                ),
            )
        }
    )
    return conditional_merge_datasets(ds, new_ds, merge)

In [9]:
for n_variant, n_sample in matrix:
    ds = sg.simulate_genotype_call_dataset(
        n_variant=n_variant, n_sample=n_sample, missing_pct=0.01
    )
    ds = ds.chunk({"variants": 10000, "samples": 1000})

    ds = count_call_alleles_jax(ds)
    start = time.time()
    ds = ds.load()
    end = time.time()

    print(f"n_variant: {n_variant}, n_sample: {n_sample}, time: {end - start}")

n_variant: 10000, n_sample: 100, time: 0.028610944747924805
n_variant: 100000, n_sample: 1000, time: 3.8228919506073
n_variant: 20000, n_sample: 5000, time: 3.199921131134033
n_variant: 10000, n_sample: 10000, time: 3.4262049198150635


The JAX version is a lot slower than Numba: about 7.5 times slower on the smallest dataset and roughly 15 to 28 times slower on the other three.

I also tried running the `bincount` code using regular NumPy and it was around 100 times slower than JAX. This tells us that Numba and JAX both provide massive performance improvements compared to NumPy.
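
The NumPy code used for that comparison is not shown in this notebook. One way such a baseline could look is sketched below; this is an assumption, and the function names are ours. It applies `np.bincount` to each ploidy vector with `np.apply_along_axis`, which loops in Python over every (variant, sample) pair.

```python
import numpy as np


def count_alleles_numpy(g, n_alleles=2):
    # Same shift-and-truncate trick as the JAX version; after adding 2
    # all values are non-negative, which np.bincount requires.
    return np.bincount(g + 2, minlength=n_alleles + 2)[2:].astype(np.uint8)


def count_call_alleles_numpy(G, n_alleles=2):
    # Apply the 1-D function along the ploidy axis (axis 2)
    return np.apply_along_axis(count_alleles_numpy, 2, G, n_alleles)


G = np.array(
    [[[0, 0], [0, 1], [1, 1]],
     [[-1, 0], [1, 0], [-1, -1]]],
    dtype=np.int8,
)
numpy_counts = count_call_alleles_numpy(G)
print(numpy_counts.shape, numpy_counts.dtype)
```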

But it's not clear whether there is something wrong with the JAX code or whether JAX simply can't do as well as Numba for this problem.
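
One thing that may be worth checking is whether the timings include compilation. `jax.jit` compiles a separate version of a function for each combination of input shapes and dtypes, so a warm-up call on a `(4, 4, 2)` array does not cover the chunk shapes that Dask passes in later. The sketch below makes this visible with a counter that only increments while JAX traces the function; the counter and the array shapes are illustrative, and this has not been measured against the benchmark above.

```python
import jax
import jax.numpy as jnp
import numpy as np

trace_count = 0


def count_alleles(g):
    global trace_count
    # Python side effect: runs only while JAX traces, not on cached calls
    trace_count += 1
    counts = jnp.bincount(g + 2, length=4)
    return counts[2:].astype(jnp.uint8)


count_jit = jax.jit(jax.vmap(jax.vmap(count_alleles)))

small = np.ones((4, 4, 2), dtype=np.int8)
large = np.ones((100, 50, 2), dtype=np.int8)

count_jit(small).block_until_ready()
traces_after_warmup = trace_count

count_jit(large).block_until_ready()  # new shape, so JAX traces again
traces_after_new_shape = trace_count

count_jit(large).block_until_ready()  # same shape, served from the cache
traces_after_repeat = trace_count

print(traces_after_warmup, traces_after_new_shape, traces_after_repeat)
```

If compilation does turn out to matter, warming up with an array of the real chunk shape and reusing the same jitted function across calls would be the natural next experiment.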

## LAX

JAX has some lower-level primitives in the LAX module that might be suitable for this problem. In particular, `jax.lax.scan` can be used for implementing a `count_alleles` function without using NumPy operations. Would this be more efficient?

In [10]:
from jax import lax

def _count_alleles(res, el):
    res = res.at[el].add(1)
    return res, None

def count_alleles_lax(g, out):
    counts, _ = lax.scan(_count_alleles, out, g)
    return counts

Note that we pass in the output array like Numba does, rather than allocating it in the loop (like `jax.numpy` does above).
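
Here is a minimal sketch of the same scan-based counting on single genotype vectors, with illustrative names and inputs. It also shows what happens to a missing value: JAX follows NumPy-style negative indexing, so `-1` updates the last element of the output rather than being dropped. The scan version as written therefore appears to need input without missing values.

```python
import jax.numpy as jnp
from jax import lax


def _add_one(counts, allele):
    # Carry the running counts through the scan; nothing is emitted per step
    return counts.at[allele].add(1), None


def count_alleles_scan(g, out):
    counts, _ = lax.scan(_add_one, out, g)
    return counts


out = jnp.zeros(2, dtype=jnp.uint8)  # counts must start from zero
het = count_alleles_scan(jnp.array([0, 1], dtype=jnp.int8), out)
hom_alt = count_alleles_scan(jnp.array([1, 1], dtype=jnp.int8), out)

# -1 is treated as "last element", so the missing allele is counted as allele 1
with_missing = count_alleles_scan(jnp.array([-1, 0], dtype=jnp.int8), out)

print(het, hom_alt, with_missing)
```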

In [11]:
def count_call_alleles_lax(
    ds: Dataset,
    *,
    call_genotype: Hashable = variables.call_genotype,
    merge: bool = True,
) -> Dataset:
    variables.validate(ds, {call_genotype: variables.call_genotype_spec})
    n_alleles = ds.sizes["alleles"]
    G = da.asarray(ds[call_genotype])
    if G.numblocks[2] > 1:
        raise ValueError(
            f"Variable {call_genotype} must have only a single chunk in the ploidy dimension. "
            "Consider rechunking to change the size of chunks."
        )
    shape = (G.chunks[0], G.chunks[1], n_alleles)

    # call vmap twice to vectorize over first two dimensions (variants, samples)
    count_alleles_vectorized = jax.vmap(jax.vmap(count_alleles_lax))

    # jit compile
    count_alleles_vectorized_jit = jax.jit(count_alleles_vectorized)

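    # warm-up call; jax.jit specializes on input shape and dtype, so this
    # compiles the (4, 4, 2) case only
    count_alleles_vectorized_jit(
        np.ones((4, 4, 2), dtype=np.int8), np.zeros((4, 4, 2), dtype=np.uint8)
    ).block_until_ready()

    # the counts must start at zero (np.empty would leave them uninitialized)
    N = np.zeros((G.chunks[0][0], G.chunks[1][0], n_alleles), dtype=np.uint8)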
    new_ds = create_dataset(
        {
            variables.call_allele_count: (
                ("variants", "samples", "alleles"),
                da.map_blocks(
                    count_alleles_vectorized_jit,
                    G,
                    N,
                    chunks=shape,
                    dtype=np.uint8,
                    drop_axis=2,
                    new_axis=2,
                ),
            )
        }
    )
    return conditional_merge_datasets(ds, new_ds, merge)

In [12]:
for n_variant, n_sample in matrix:
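    # missing_pct is 0.0 here (it was 0.01 in the earlier runs), so the data
    # contains no missing (-1) values
    ds = sg.simulate_genotype_call_dataset(
        n_variant=n_variant, n_sample=n_sample, missing_pct=0.0
    )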
    ds = ds.chunk({"variants": 10000, "samples": 1000})

    ds = count_call_alleles_lax(ds)
    start = time.time()
    ds = ds.load()
    end = time.time()

    print(f"n_variant: {n_variant}, n_sample: {n_sample}, time: {end - start}")

n_variant: 10000, n_sample: 100, time: 0.02508997917175293
n_variant: 100000, n_sample: 1000, time: 2.639747142791748
n_variant: 20000, n_sample: 5000, time: 2.751932144165039
n_variant: 10000, n_sample: 10000, time: 2.7614219188690186


The LAX version is a little faster than the JAX NumPy version (about 1.2 to 1.5 times on the three larger datasets), but it is still roughly 12 to 19 times slower than Numba on those datasets.

In [13]:
! pip freeze

aiohappyeyeballs==2.4.0
aiohttp==3.10.5
aiosignal==1.3.1
anyio==4.6.0
appnope==0.1.4
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
asciitree==0.3.3
asttokens==2.4.1
asv==0.6.4
asv_runner==0.2.1
async-lru==2.0.4
attrs==24.2.0
babel==2.16.0
beautifulsoup4==4.12.3
bed-reader==1.0.5
bleach==6.1.0
bokeh==3.5.2
build==1.2.2
callee==0.3.1
certifi==2024.8.30
cffi==1.17.1
cfgv==3.4.0
charset-normalizer==3.3.2
click==8.1.7
cloudpickle==3.0.0
coloredlogs==15.0.1
comm==0.2.2
contourpy==1.3.0
coverage==7.6.1
cycler==0.12.1
cyvcf2==0.31.1
dask==2024.8.0
dask-expr==1.1.10
dask-glm==0.3.2
dask-ml==2024.4.4
debugpy==1.8.5
decorator==5.1.1
defusedxml==0.7.1
distlib==0.3.8
distributed==2024.8.0
executing==2.1.0
fasteners==0.19
fastjsonschema==2.20.0
filelock==3.16.1
fonttools==4.53.1
fqdn==1.5.1
frozenlist==1.4.1
fsspec==2024.9.0
graphviz==0.20.3
h11==0.14.0
httpcore==1.0.5
httpx==0.27.2
humanfriendly==10.0
hypothesis==6.112.1
identi