## Comparison of CPU and GPU accelerated Genotype counting

Firstly here is a CPU baseline for the computation of call_genotype statistics. This is a simplified version of the code in zarr_afdist.py.

In this case, the data (int8) that is being loaded is comprised of 26504 variants across 2504 diploid samples.

The data is loaded using Zarr using an Intel Xeon Silver 4216 x 64 CPU and Samsung NVMe drive (SSD 970 EVO Plus 2TB)

In [7]:
import dataclasses

import numpy as np
import pandas as pd
import numba
import zarr
import numcodecs
import math
import time

dpath = "/datasets/sgkit/vcf-zarr-publication/"

@numba.njit("void(int64, int8[:, :, :], int32[:], int32[:], int32[:], int32[:])")
def count_genotypes_chunk(offset, G, hom_ref, hom_alt, het, ref_count):
    # NB Assuming diploids and no missing data!
    index = offset
    for j in range(G.shape[0]):
        for k in range(G.shape[1]):
            a = G[j, k, 0]
            b = G[j, k, 1]
            if a == b:
                if a == 0:
                    hom_ref[index] += 1
                else:
                    hom_alt[index] += 1
            else:
                het[index] += 1
            ref_count[index] += (a == 0) + (b == 0)
        index += 1


@dataclasses.dataclass
class GenotypeCounts:
    hom_ref: list
    hom_alt: list
    het: list
    ref_count: list

def classify_genotypes(call_genotype):
    m = call_genotype.shape[0]

    het = np.zeros(m, dtype=np.int32)
    hom_alt = np.zeros(m, dtype=np.int32)
    hom_ref = np.zeros(m, dtype=np.int32)
    ref_count = np.zeros(m, dtype=np.int32)
    j = 0
    t1 = 0
    
    for v_chunk in range(call_genotype.cdata_shape[0]):
        for s_chunk in range(call_genotype.cdata_shape[1]):
            G = call_genotype.blocks[v_chunk, s_chunk]
            t = time.time()
            count_genotypes_chunk(j, G, hom_ref, hom_alt, het, ref_count)
            t1 += (time.time()-t)
        j += G.shape[0]
   
    return GenotypeCounts(hom_ref, hom_alt, het, ref_count), t1

def zarr_afdist(root, num_bins=10, variant_slice=None, sample_slice=None):
    call_genotype = root["call_genotype"]
    m = call_genotype.shape[0]
    n = call_genotype.shape[1]

    counts, t = classify_genotypes(call_genotype)

    alt_count = 2 * n - counts.ref_count
    af = alt_count / (n * 2)
    bins = np.linspace(0, 1.0, num_bins + 1)
    bins[-1] += 0.0125
    pRA = 2 * af * (1 - af)
    pAA = af * af

    a = np.bincount(np.digitize(pRA, bins), weights=counts.het, minlength=num_bins + 1)
    b = np.bincount(
        np.digitize(pAA, bins), weights=counts.hom_alt, minlength=num_bins + 1
    )

    count = (a + b).astype(int)

    return pd.DataFrame({"start": bins[:-1], "stop": bins[1:], "prob_dist": count[1:]}), t

z_root = zarr.open(dpath + "real_data/data/WGS/chr22.zarr",mode='r')

# warm up
zarr_afdist(z_root)
gc, t = zarr_afdist(z_root)

print(gc)
print("Summarisation Time (ms) = {:.3f}".format(t*1000))

   start    stop  prob_dist
0    0.0  0.1000    7081263
1    0.1  0.2000    2212177
2    0.2  0.3000    2096409
3    0.3  0.4000    3027269
4    0.4  0.5000   15854472
5    0.5  0.6000    1769666
6    0.6  0.7000     276639
7    0.7  0.8000     223031
8    0.8  0.9000     112195
9    0.9  1.0125     488105
Summarisation Time (ms) = 1457.252


Note that what is measured in "Summarisation Time" is the time to compute the summarisation only (i.e. no loading the data)

Next, the same computation is done using a cupy array and a cupy kernel. cupy arrays are functionally the same as numpy arrays, but reside within GPU memory.

The GPU used is an Nvidia A6000 RTX.

In [8]:
import cupy as cp
import time
from cupyx import jit

start_gpu = cp.cuda.Event()
end_gpu = cp.cuda.Event()

threads_per_block = 256

@jit.rawkernel()
def cu_count_genotypes_chunk(offset, G, hom_ref, hom_alt, het, ref_count, vs_size):
    # NB Assuming diploids and no missing data!
    idx = jit.blockIdx.x * jit.blockDim.x + jit.threadIdx.x
    
    variant_idx = idx + offset[0]

    if idx<vs_size[0]:
        for k in range(vs_size[1]):
            a = G[idx, k, 0]
            b = G[idx, k, 1]
            if a == b:
                if a == 0:
                    hom_ref[variant_idx] += 1
                else:
                    hom_alt[variant_idx] += 1
            else:
                het[variant_idx] += 1

            ref_count[variant_idx] += (a == 0)
            ref_count[variant_idx] += (b == 0)

@dataclasses.dataclass
class cu_GenotypeCounts:
    hom_ref: 'np.ndarray[np.int32]'
    hom_alt: 'np.ndarray[np.int32]'
    het: 'np.ndarray[np.int32]'
    ref_count: 'np.ndarray[np.int32]'
    
def cu_classify_genotypes_chunked(call_genotype):
    m = call_genotype.shape[0]

    cu_het = cp.zeros(m, dtype=np.int32)
    cu_hom_alt = cp.zeros(m, dtype=np.int32)
    cu_hom_ref = cp.zeros(m, dtype=np.int32)
    cu_ref_count = cp.zeros(m, dtype=np.int32)
    t_compute=0
    t_transfer=0
    
    j = cp.zeros(1,dtype=cp.uint32)
    for v_chunk in range(call_genotype.cdata_shape[0]):
        for s_chunk in range(call_genotype.cdata_shape[1]):
            # measure transfer time
            G = call_genotype.blocks[v_chunk, s_chunk]
            t = time.time()
            cu_G = cp.array(G)
            t_transfer += (time.time()-t)
            vs_size = cp.zeros(2,cp.uint32)
            vs_size[0] = cu_G.shape[0]
            vs_size[1] = cu_G.shape[1]
            start_gpu.record()
            cu_count_genotypes_chunk[math.ceil(vs_size[0]/threads_per_block),threads_per_block](j, cu_G, cu_hom_ref, cu_hom_alt, cu_het, cu_ref_count,vs_size)
            end_gpu.record()
            end_gpu.synchronize()
            t_compute += cp.cuda.get_elapsed_time(start_gpu, end_gpu)
        j += cu_G.shape[0]
        
    return cu_GenotypeCounts(cu_hom_ref, cu_hom_alt, cu_het, cu_ref_count), t_compute, t_transfer

def cu_zarr_afdist(root, num_bins=10, variant_slice=None, sample_slice=None):
    call_genotype = root["call_genotype"]

    n = call_genotype.shape[1]
    counts, t1, t2 = cu_classify_genotypes_chunked(call_genotype)
    alt_count = 2 * n - counts.ref_count
    af = alt_count / (n * 2)
    bins = cp.linspace(0, 1.0, num_bins + 1)
    bins[-1] += 0.0125
    pRA = 2 * af * (1 - af)
    pAA = af * af
    x = cp.digitize(cp.array(pRA).astype(np.float32), cp.array(bins).astype(np.float32))
    a = cp.bincount(x, weights=counts.het, minlength=num_bins + 1)
    y = cp.digitize(cp.array(pAA).astype(np.float32), cp.array(bins).astype(np.float32))
    b = cp.bincount(
        y, weights=counts.hom_alt, minlength=num_bins + 1
    )
    count = (a + b).astype(int)

    return pd.DataFrame({"start": bins[:-1].get(), "stop": bins[1:].get(), "prob_dist": count[1:].get()}), t1, t2


z_root = zarr.open(dpath + "real_data/data/WGS/chr22.zarr",mode='r')

# warm up
cu_zarr_afdist(z_root)
gc, t1, t2 = cu_zarr_afdist(z_root)

print(gc)
print("Summarisation Time (ms) = {:.3f}".format(t1))
print("Host to GPU Transfer Time (ms) = {:.3f}".format(t2*1000))

  cupy._util.experimental('cupyx.jit.rawkernel')


   start    stop  prob_dist
0    0.0  0.1000    7081263
1    0.1  0.2000    2212177
2    0.2  0.3000    2096409
3    0.3  0.4000    3027269
4    0.4  0.5000   15854472
5    0.5  0.6000    1769666
6    0.6  0.7000     276639
7    0.7  0.8000     223031
8    0.8  0.9000     112195
9    0.9  1.0125     488105
Summarisation Time (ms) = 24.684
Host to GPU Transfer Time (ms) = 179.458


This simple kernel produces a modest speed-up, but there is more that can be done to optimise the cupy kernel.

Next, the data is passed into the kernel in one large chunk and shared memory is used to cache interim results

In [9]:
import cupy as cp
import time
from cupyx import jit

threads_per_block = 64

@dataclasses.dataclass
class cu_GenotypeCounts:
    hom_ref: 'np.ndarray[np.int32]'
    hom_alt: 'np.ndarray[np.int32]'
    het: 'np.ndarray[np.int32]'
    ref_count: 'np.ndarray[np.int32]'


@jit.rawkernel()
def cu_count_genotypes_chunk(variant_offset, G, hom_ref, hom_alt, het, ref_count,v_size,s_size):
    # NB Assuming diploids and no missing data!
    thread_idx = jit.threadIdx.x
    grid_idx = jit.blockIdx.x * jit.blockDim.x + thread_idx
    sm_offset = thread_idx*4
    
    variant_idx = grid_idx + variant_offset[0]

    if grid_idx<v_size:
        for k in range(s_size):
            a = G[grid_idx, k, 0]
            b = G[grid_idx, k, 1]
            
            if a == b:
                if a == 0:
                    hom_ref[variant_idx] += 1
                else:
                    hom_alt[variant_idx] += 1
            else:
                het[variant_idx] += 1
            ref_count[variant_idx] += (a == 0)
            ref_count[variant_idx] += (b == 0)
    
                    
def cu_classify_genotypes(call_genotype):
    m = call_genotype.shape[0]

    cu_het = cp.zeros(m, dtype=np.int32)
    cu_hom_alt = cp.zeros(m, dtype=np.int32)
    cu_hom_ref = cp.zeros(m, dtype=np.int32)
    cu_ref_count = cp.zeros(m, dtype=np.int32)
    G = call_genotype[:]
    t=time.time()
    cu_arr = cp.array(G)
    t_transfer = time.time() - t
    
    j = cp.zeros(1,dtype=cp.uint32)
    v_size=call_genotype.shape[0]
    s_size=call_genotype.shape[1]

    start_gpu.record()
    cu_count_genotypes_chunk[math.ceil(v_size/threads_per_block),threads_per_block](j, cu_arr, cu_hom_ref, cu_hom_alt, cu_het, cu_ref_count,v_size,s_size)
    end_gpu.record()
    end_gpu.synchronize()
    t_compute = cp.cuda.get_elapsed_time(start_gpu, end_gpu)

    return cu_GenotypeCounts(cu_hom_ref, cu_hom_alt, cu_het, cu_ref_count), t_compute, t_transfer

def cu_zarr_afdist(root, num_bins=10, variant_slice=None, sample_slice=None):
    call_genotype = root["call_genotype"]

    n = call_genotype.shape[1]
   
    counts, tc, tt = cu_classify_genotypes(call_genotype)
    
    alt_count = 2 * n - counts.ref_count
    af = alt_count / (n * 2)
    bins = cp.linspace(0, 1.0, num_bins + 1)
    bins[-1] += 0.0125
    pRA = 2 * af * (1 - af)
    pAA = af * af
    x = cp.digitize(cp.array(pRA).astype(np.float32), cp.array(bins).astype(np.float32))
    a = cp.bincount(x, weights=counts.het, minlength=num_bins + 1)
    y = cp.digitize(cp.array(pAA).astype(np.float32), cp.array(bins).astype(np.float32))
    b = cp.bincount(
        y, weights=counts.hom_alt, minlength=num_bins + 1
    )
    count = (a + b).astype(int)

    return pd.DataFrame({"start": bins[:-1].get(), "stop": bins[1:].get(), "prob_dist": count[1:].get()}), tc, tt


z_root = zarr.open(dpath + "real_data/data/WGS/chr22.zarr",mode='r')

#warm up
cu_zarr_afdist(z_root)
gc, t1, t2 = cu_zarr_afdist(z_root)

print(gc)
print("Summarisation Time (ms) = {:.3f}".format(t1))
print("Host to GPU Transfer Time (ms) = {:.3f}".format(t2*1000))

  cupy._util.experimental('cupyx.jit.rawkernel')


   start    stop  prob_dist
0    0.0  0.1000    7081263
1    0.1  0.2000    2212177
2    0.2  0.3000    2096409
3    0.3  0.4000    3027269
4    0.4  0.5000   15854472
5    0.5  0.6000    1769666
6    0.6  0.7000     276639
7    0.7  0.8000     223031
8    0.8  0.9000     112195
9    0.9  1.0125     488105
Summarisation Time (ms) = 21.120
Host to GPU Transfer Time (ms) = 135.895


Again, an improvement in latency is acheived. However, each instance of the kernel is having to iterate over all samples so, to get the full benefit of the GPUs threads, the kernel laucnh is set to create as many threads as there are combinations of variant and sample. The GPU's scheduler will then find the most efficient way of allocating the computation to free threads.

In this example, an abstraction is used to map the 2 dimensions of the array to be summarised to 2 dimensions of threads. This makes the kernel a little simpler to write, since it can be thought of as variants in the x dimension and samples in the y dimension.

In [10]:
import time
from cupyx.jit import atomic_add
from cupyx import jit
import cupy as cp

# This value can be changed. 
threads_per_block = 8

@jit.rawkernel()
def cu_count_genotypes_chunk(G, hom_ref, hom_alt, het, ref_count,vs_size):
    # Get the index of the current thread within a 2D block
    thread_idx_x = jit.threadIdx.x # The x dimension is along the variant axis
    thread_idx_y = jit.threadIdx.y # The y dimension is along the sample axis
    
    # The combination of block size, block index and thead index provide the variant and sample indices
    variant_idx = jit.blockIdx.x * jit.blockDim.x + thread_idx_x 
    sample_idx = jit.blockIdx.y * jit.blockDim.y + thread_idx_y
    
    # because we round the thread block size up to the problem size, check the bounds
    if variant_idx<vs_size[0]:
        if sample_idx < vs_size[1]:
            a = G[variant_idx, sample_idx, 0]
            b = G[variant_idx, sample_idx, 1]
            
            if a == b:
                if a == 0:
                    atomic_add(hom_ref,variant_idx,1)
                else:
                    atomic_add(hom_alt,variant_idx,1)
            else:
                atomic_add(het,variant_idx,1)

            atomic_add(ref_count,variant_idx,(a == 0))
            atomic_add(ref_count,variant_idx,(b == 0))
                    
def cu_classify_genotypes(call_genotype):
    m = call_genotype.shape[0]

    # allocate cupy arrays for the results
    cu_het = cp.zeros(m, dtype=np.int32)
    cu_hom_alt = cp.zeros(m, dtype=np.int32)
    cu_hom_ref = cp.zeros(m, dtype=np.int32)
    cu_ref_count = cp.zeros(m, dtype=np.int32)

    # use a cupy array to pass in the array sizes to the cuda kernel
    G = call_genotype[:]
    t = time.time()
    cu_G = cp.array(G)
    t_transfer = time.time()-t
    vs_size = cp.zeros(2,cp.uint32)
    vs_size[0]=cu_G.shape[0]
    vs_size[1]=cu_G.shape[1]
    
    # Set the 'grid' size to the number of elements in the two array dimensions: variants and samples
    grid_size = (math.ceil(vs_size[0]/threads_per_block),math.ceil(vs_size[1]/threads_per_block))
    block_size = (threads_per_block,threads_per_block)
    
    start_gpu.record()
    # call the cupyx kernel, with the grid size setting the number of instances of the kernel to use
    cu_count_genotypes_chunk[grid_size,block_size](cu_G, cu_hom_ref, cu_hom_alt, cu_het, cu_ref_count,vs_size)
    end_gpu.record()
    end_gpu.synchronize()
    t_compute = cp.cuda.get_elapsed_time(start_gpu, end_gpu)
    
    return cu_GenotypeCounts(cu_hom_ref, cu_hom_alt, cu_het, cu_ref_count), t_compute, t_transfer

def cu_zarr_afdist(root, num_bins=10, variant_slice=None, sample_slice=None):
    
    call_genotype = root["call_genotype"]

    n = call_genotype.shape[1]
    counts, t1, t2 = cu_classify_genotypes(call_genotype)
    
    alt_count = 2 * n - counts.ref_count
    af = alt_count / (n * 2)
    bins = cp.linspace(0, 1.0, num_bins + 1)
    bins[-1] += 0.0125
    pRA = 2 * af * (1 - af)
    pAA = af * af
    x = cp.digitize(cp.array(pRA).astype(np.float32), cp.array(bins).astype(np.float32))
    a = cp.bincount(x, weights=counts.het, minlength=num_bins + 1)
    y = cp.digitize(cp.array(pAA).astype(np.float32), cp.array(bins).astype(np.float32))
    b = cp.bincount(
        y, weights=counts.hom_alt, minlength=num_bins + 1
    )
    count = (a + b).astype(int)

    # note the use of .get(), which copies data over from GPU memory to CPU memory
    return pd.DataFrame({"start": bins[:-1].get(), "stop": bins[1:].get(), "prob_dist": count[1:].get()}), t1, t2


z_root = zarr.open(dpath + "real_data/data/WGS/chr22.zarr",mode='r')
t = time.time()
gc = cu_zarr_afdist(z_root)

# warm up
cu_zarr_afdist(z_root)
gc, t1, t2 = cu_zarr_afdist(z_root)

print(gc)
print("Summarisation Time (ms) = {:.3f}".format(t1))
print("Host to GPU Transfer Time (ms) = {:.3f}".format(t2*1000))

  cupy._util.experimental('cupyx.jit.rawkernel')


   start    stop  prob_dist
0    0.0  0.1000    7081263
1    0.1  0.2000    2212177
2    0.2  0.3000    2096409
3    0.3  0.4000    3027269
4    0.4  0.5000   15854472
5    0.5  0.6000    1769666
6    0.6  0.7000     276639
7    0.7  0.8000     223031
8    0.8  0.9000     112195
9    0.9  1.0125     488105
Summarisation Time (ms) = 6.342
Host to GPU Transfer Time (ms) = 138.196


As the results show, letting the GPU do the allocation of computation, hugely reduces the execution time.

To deal with larger datasets without overwhelming GPU memory, the chunking approach can be used, so that each chunk is sufficiently large to fully utilise the GPU. These chunk sizes can then be used within the Zarr storage, to maximise efficiency. Where latency needs to be minimised, chunk sizes can be reduced and separate streams created, so that as one batch is being moved into GPU memory, the other is processing the previous batch.

## Zarr Decompression Using GPU on call_AD data

The same, call_AD, data has been saved into several different versions of a Zarr array - each using a different algorithm. 
The GPU arrays are being loaded into GPU memory as a cupy array whereas the standard CPU Zarr data is being loaded into a numpy array.
The data (int16) is the same size as the call_genotype (96514, 2504, 2) but has a 3rd dimension size of 7 rather than 2

N.B. GPUDirect Storage has not been used here, which provides direct memory transfers from NVMe to GOU memory.


In [None]:
import kvikio
import cupy as cp
import zarr
import kvikio.zarr
import time

def compress_nv(data, comp, root):
    
    t = time.time()
    
    z1 = zarr.array(cp_call_ad,
        chunks=(20000,cp_call_ad.shape[1],2),
        store=root,
        meta_array=cp.empty(()),
        compressor=comp,
        overwrite=True)
    
    return time.time()-t

z = zarr.open(dpath + "real_data/data/WGS/chr22.zarr", mode='r')

cp_call_ad=cp.array(z['call_AD'][:])
print(cp_call_ad.shape)

root = kvikio.zarr.GDSStore(dpath + "real_data/data/WGS/chr22_cuda_gt_bitcomp.zarr", normalize_keys=True)
t = compress_nv(cp_call_ad, kvikio.zarr.Bitcomp(), root)
print("GPU Bitcomp compress time = {:.3f} (s)".format(t))

root = kvikio.zarr.GDSStore(dpath + "real_data/data/WGS/chr22_cuda_gt_snappy.zarr", normalize_keys=True)
t = compress_nv(cp_call_ad, kvikio.zarr.Snappy(),root)
print("GPU Snappy compress time = {:.3f} (s)".format(t))

root = kvikio.zarr.GDSStore(dpath + "real_data/data/WGS/chr22_cuda_gt_lz4.zarr", normalize_keys=True)
t = compress_nv(cp_call_ad, kvikio.zarr.LZ4(),root)
print("GPU LZ4 compress time = {:.3f} (s)".format(t))

root = kvikio.zarr.GDSStore(dpath + "real_data/data/WGS/chr22_cuda_gt_casc.zarr", normalize_keys=True)
t = compress_nv(cp_call_ad, kvikio.zarr.Cascaded(),root)
print("GPU Cascaded compress time = {:.3f} (s)".format(t))

root = kvikio.zarr.GDSStore(dpath + "real_data/data/WGS/chr22_cuda_gt_gdeflate.zarr", normalize_keys=True)
t = compress_nv(cp_call_ad, kvikio.zarr.Gdeflate(),root)
print("GPU Gdeflate compress time = {:.3f} (s)".format(t))


(96514, 2504, 7)
GPU Bitcomp compress time = 1.372 (s)
GPU Snappy compress time = 0.996 (s)
GPU LZ4 compress time = 0.937 (s)
GPU Cascaded compress time = 0.703 (s)
GPU Gdeflate compress time = 0.955 (s)


Data has been compressed using the GPU (no comparison is made with CPU compression time here, but GPU saving and compression can be very quick). There are several different compression types, with each suited to a different data type (see [here](https://developer.nvidia.com/nvcomp))

Now the same compressed Zarr files can be loaded and the time measured and compared to the CPU Zarr version

In [None]:
import kvikio
import numpy as np
import cupy as cp
import zarr
import kvikio.zarr
import time

# To prevent initialisation skewing the result, load both a CPU and GPU zarr library
z = kvikio.zarr.open_cupy_array(store="real_data/data/WGS/chr22_cuda_ad_snappy.zarr", mode='r')
a=z[:]
del(a)
z1 = zarr.open_group(store="real_data/data/WGS/chr22.zarr", mode='r')
arr2 = cp.array(z1["call_genotype"][:])
del(arr2)
# Data is deleted and then the actual measurements start

t = time.time()
z = kvikio.zarr.open_cupy_array(store="real_data/data/WGS/chr22_cuda_ad_bitcomp.zarr", mode='r')
a=z[:]
print("GPU Bitcomp deflate time = {:.3f} (s)".format(time.time()-t))
del(a)

t = time.time()
z = kvikio.zarr.open_cupy_array(store="real_data/data/WGS/chr22_cuda_ad_snappy.zarr", mode='r')
a=z[:]
print("GPU Snappy deflate time = {:.3f} (s)".format(time.time()-t))
del(a)

t = time.time()
z = kvikio.zarr.open_cupy_array(store="real_data/data/WGS/chr22_cuda_ad_lz4.zarr", mode='r')
a=z[:]
print("GPU LZ4 deflate time = {:.3f} (s)".format(time.time()-t))
del(a)

t = time.time()    
z = kvikio.zarr.open_cupy_array(store="real_data/data/WGS/chr22_cuda_ad_casc.zarr", mode='r')
a=z[:]
print("GPU Cascaded deflate time = {:.3f} (s)".format(time.time()-t))
del(a)

t = time.time()
z = kvikio.zarr.open_cupy_array(store="real_data/data/WGS/chr22_cuda_ad_gdeflate.zarr", mode='r')
a=z[:]
print("GPU Gdeflate deflate time = {:.3f} (s)".format(time.time()-t))
del(a)

t = time.time()
z1 = zarr.open_group(store="real_data/data/WGS/chr22.zarr", mode='r')
arr2 = z1["call_AD"][:]

print("CPU zarr time = ", time.time()-t)

GPU snappy deflate time =  0.606 (s)
GPU bitcomp deflate time =  0.626 (s)
GPU cascade deflate time =  0.574 (s)
GPU gdeflate deflate time =  0.967 (s)
GPU lz4 deflate time =  0.567 (s)
CPU zarr time =  7.560 (s)



## Appendix

For some kernels, it can be worth using the Shared Memory feature of a GPU. Shared Memory (only visible to threads within the same block) is faster than global memory (visibile to all threads), but is limited.
In this case it makes little or no difference, but for other tasks in which memory is being accessed repeatedly, it can offer noticeable speed ups. For reference, below is an implementation of the kernel that uses Shared Memory.
For more information in cupy kernels, see [here](https://docs.cupy.dev/en/stable/user_guide/kernel.html).

In [None]:
@jit.rawkernel()
def cu_count_genotypes_chunk(G, hom_ref, hom_alt, het, ref_count,vs_size):
    # Get the index of the current thread within a 2D block
    thread_idx_x = jit.threadIdx.x # The x dimension is along the variant axis
    thread_idx_y = jit.threadIdx.y # The y dimension is along the sample axis
    
    # The combination of block size, block index and thead index provide the  variant and sample indexes
    variant_idx = jit.blockIdx.x * jit.blockDim.x + thread_idx_x 
    sample_idx = jit.blockIdx.y * jit.blockDim.y + thread_idx_y
    sm_offset = (thread_idx_x*threads_per_block+thread_idx_y)*4
    
    # Shared memory is shared between threads within the same block
    # It is limited in size but faster than global memory so can be used to cache data
    shmem = jit.shared_memory(cp.int32,threads_per_block*threads_per_block*4)
    
    # Initialise the shared memory for the cached counts
    shmem[sm_offset + 0]=0 # used for hom_ref
    shmem[sm_offset + 1]=0 # used for hom_alt
    shmem[sm_offset + 2]=0 # used for het
    shmem[sm_offset + 3]=0 # used for ref_count
    
    # because we round the thread block size up to the problem size, check the bounds
    if variant_idx<vs_size[0]:
        if sample_idx < vs_size[1]:
            a = G[variant_idx, sample_idx, 0]
            b = G[variant_idx, sample_idx, 1]
            if a == b:
                if a == 0:
                    shmem[sm_offset] += 1
                else:
                    shmem[sm_offset + 1] += 1
            else:
                shmem[sm_offset + 2] += 1
            shmem[sm_offset + 3] += (a==0)
            shmem[sm_offset + 3] += (b==0)
    
    # Now add the final shared memory cache vakues to the
    # global memory arrays that were passed in
    # using atomic adds to avoid data races
    atomic_add(hom_ref,variant_idx,shmem[sm_offset])
    atomic_add(hom_alt,variant_idx, shmem[sm_offset+1])
    atomic_add(het,variant_idx, shmem[sm_offset+2])
    atomic_add(ref_count,variant_idx, shmem[sm_offset+3])