## Example tasks using vcfzarr

Run on 2024-05-20 by Ben Jeffery

In [1]:
import zarr
import pandas as pd
import numpy as np
import dataclasses
import numba
import zarr
import numcodecs

# Restrict Blosc to a single internal thread for (de)compression. The call
# returns the previous thread count, which is the value this cell displays.
numcodecs.blosc.set_nthreads(1)

8

In [2]:
# Location of the VCF Zarr store that every task below opens.
VCFZARR = "/home/bjeffery/re/messy_vcf/b2z/chr2_parallel_explode_009_3.zarr"

# Bounds of the query region; the tasks look these up in variant_position
# with searchsorted to obtain a variant index range.
REGION_START = 58219159
REGION_END = 60650943

### Task 1 - Dump POS

In [3]:
%%time
# Load every variant position into memory, then find the index range whose
# positions lie within [REGION_START, REGION_END] (the array must be sorted
# for searchsorted to be meaningful).
root = zarr.open(VCFZARR, mode="r")
pos = root["variant_position"][:]
start_index = pos.searchsorted(REGION_START, side="left")
end_index = pos.searchsorted(REGION_END, side="right")
region_pos = pos[start_index:end_index]

# to_csv is called with its defaults, so the file gets a header row and an
# index column in front of the positions.
df = pd.DataFrame(region_pos)
df.to_csv("pos_result_vcfzarr.txt")

CPU times: user 846 ms, sys: 1.93 s, total: 2.78 s
Wall time: 1min 44s


In [4]:
# Total number of variant positions vs. how many fall inside the region.
pos.shape[0], region_pos.shape[0]

(59880903, 562640)

### Task 1b - Dump POS to RAM only

In [5]:
%%time
# Region lookup only: read all positions, locate the index range for
# [REGION_START, REGION_END], and keep the slice in memory (no file output).
root = zarr.open(VCFZARR, mode="r")
pos = root["variant_position"][:]
start_index = pos.searchsorted(REGION_START, side="left")
end_index = pos.searchsorted(REGION_END, side="right")
region_pos = pos[start_index:end_index]

CPU times: user 560 ms, sys: 572 ms, total: 1.13 s
Wall time: 20.9 s


### Task 2 - afdist

In [14]:
@dataclasses.dataclass
class GenotypeCounts:
    """Per-variant genotype tallies over the selected samples.

    Every field is a 1-D array with one entry per selected variant, in the
    form produced by ``classify_genotypes_subset`` (int32 NumPy arrays).
    """

    # Samples whose two alleles are both 0.
    hom_ref: np.ndarray
    # Samples whose two alleles are equal and non-zero.
    hom_alt: np.ndarray
    # Samples whose two alleles differ.
    het: np.ndarray
    # Number of allele calls equal to 0, summed over the selected samples.
    ref_count: np.ndarray

@numba.njit(
    "void(int64, int8[:,:,:], b1[:], b1[:], int32[:], int32[:], int32[:], int32[:])"
)
def count_genotypes_chunk_subset(
    offset, G, variant_mask, sample_mask, hom_ref, hom_alt, het, ref_count
):
    """Accumulate genotype counts for one chunk of genotype calls.

    ``G`` is indexed as ``G[variant, sample, allele]`` and only allele
    slots 0 and 1 are read. ``variant_mask`` and ``sample_mask`` select the
    rows and columns of this chunk to include. Results are added in place
    to the four output arrays: the first selected variant of the chunk goes
    to index ``offset``, the next to ``offset + 1``, and so on.
    """
    # NB Assuming diploids and no missing data!
    out = offset
    num_variants = G.shape[0]
    num_samples = G.shape[1]
    for v in range(num_variants):
        if not variant_mask[v]:
            continue
        for s in range(num_samples):
            if not sample_mask[s]:
                continue
            first = G[v, s, 0]
            second = G[v, s, 1]
            if first != second:
                het[out] += 1
            elif first == 0:
                hom_ref[out] += 1
            else:
                hom_alt[out] += 1
            # Each comparison contributes 0 or 1 reference allele.
            ref_count[out] += (first == 0) + (second == 0)
        # The output index only advances for variants that were selected.
        out += 1

def classify_genotypes_subset(call_genotype, variant_mask, sample_mask):
    """Count genotype classes for the selected variants and samples.

    Walks ``call_genotype`` one chunk at a time (via its ``chunks``,
    ``cdata_shape`` and ``blocks`` attributes), skipping variant chunks and
    sample chunks in which nothing is selected, and accumulates counts with
    ``count_genotypes_chunk_subset``.

    Parameters
    ----------
    call_genotype
        Chunked genotype array; blocks are addressed as
        ``call_genotype.blocks[variant_chunk, sample_chunk]``.
    variant_mask, sample_mask
        Boolean arrays marking which variants / samples to include.

    Returns
    -------
    GenotypeCounts
        int32 arrays with one entry per selected variant.
    """
    num_selected = np.sum(variant_mask)

    # Wrapping the masks in zarr arrays chunked like the genotype data lets
    # us fetch the mask block matching each data block by the same index.
    z_variant_mask = zarr.array(variant_mask, chunks=call_genotype.chunks[0])
    z_sample_mask = zarr.array(sample_mask, chunks=call_genotype.chunks[1])

    hom_ref = np.zeros(num_selected, dtype=np.int32)
    hom_alt = np.zeros(num_selected, dtype=np.int32)
    het = np.zeros(num_selected, dtype=np.int32)
    ref_count = np.zeros(num_selected, dtype=np.int32)

    num_variant_chunks = call_genotype.cdata_shape[0]
    num_sample_chunks = call_genotype.cdata_shape[1]

    # Index into the output arrays of the first selected variant in the
    # current variant chunk.
    out_offset = 0
    # We should probably skip to the first non-zero chunk, but there probably
    # isn't much difference unless we have a huge number of chunks, and we're
    # only selecting a tiny subset
    for v_chunk in range(num_variant_chunks):
        variant_mask_chunk = z_variant_mask.blocks[v_chunk]
        selected_here = np.sum(variant_mask_chunk)
        if selected_here == 0:
            continue
        for s_chunk in range(num_sample_chunks):
            sample_mask_chunk = z_sample_mask.blocks[s_chunk]
            if np.sum(sample_mask_chunk) == 0:
                continue
            genotype_block = call_genotype.blocks[v_chunk, s_chunk]
            count_genotypes_chunk_subset(
                out_offset,
                genotype_block,
                variant_mask_chunk,
                sample_mask_chunk,
                hom_ref,
                hom_alt,
                het,
                ref_count,
            )
        out_offset += selected_here
    return GenotypeCounts(hom_ref, hom_alt, het, ref_count)
    
def zarr_afdist(path, num_bins=10, variant_slice=None, sample_slice=None):
    """Compute a binned allele-frequency distribution from a VCF Zarr store.

    Parameters
    ----------
    path
        Location of the Zarr store; it must contain a ``call_genotype`` array.
    num_bins : int
        Number of equal-width bins spanning [0, 1].
    variant_slice, sample_slice
        Index expressions selecting the variants / samples to include.
        ``None`` (the default) selects everything.

    Returns
    -------
    pandas.DataFrame
        One row per bin with columns ``start``, ``stop`` and ``prob_dist``.
        ``prob_dist`` is the het count of variants whose ``2 * af * (1 - af)``
        falls in the bin plus the hom-alt count of variants whose ``af ** 2``
        falls in the bin.
    """
    # Read-only: nothing is written here, and zarr's default mode ("a")
    # would silently create an empty store if the path were mistyped.
    root = zarr.open(path, mode="r")
    call_genotype = root["call_genotype"]
    num_variants = call_genotype.shape[0]
    num_samples = call_genotype.shape[1]

    # Make the "select everything" default explicit rather than relying on
    # how NumPy treats assignment through a None index.
    if variant_slice is None:
        variant_slice = slice(None)
    if sample_slice is None:
        sample_slice = slice(None)

    variant_mask = np.zeros(num_variants, dtype=bool)
    variant_mask[variant_slice] = 1
    sample_mask = np.zeros(num_samples, dtype=bool)
    sample_mask[sample_slice] = 1
    counts = classify_genotypes_subset(call_genotype, variant_mask, sample_mask)
    num_selected_samples = np.sum(sample_mask)

    # Two allele calls per selected sample (diploid assumption).
    alt_count = 2 * num_selected_samples - counts.ref_count
    af = alt_count / (num_selected_samples * 2)
    bins = np.linspace(0, 1.0, num_bins + 1)
    # Widen the last bin so that a value of exactly 1.0 lands inside it
    # instead of in np.digitize's overflow bucket.
    bins[-1] += 0.0125
    pRA = 2 * af * (1 - af)
    pAA = af * af
    a = np.bincount(np.digitize(pRA, bins), weights=counts.het, minlength=num_bins + 1)
    b = np.bincount(
        np.digitize(pAA, bins), weights=counts.hom_alt, minlength=num_bins + 1
    )
    count = (a + b).astype(int)

    # np.digitize index 0 means "below the first edge"; drop that bucket so
    # the remaining entries line up with the bin edges.
    return pd.DataFrame({"start": bins[:-1], "stop": bins[1:], "prob_dist": count[1:]})

In [15]:
%%time
# Allele-frequency distribution for the variants in the region found above;
# sample_slice is left at its default, which selects every sample.
df = zarr_afdist(
    VCFZARR,
    num_bins=10,
    variant_slice=slice(start_index, end_index),
)
df

CPU times: user 2min 18s, sys: 693 ms, total: 2min 19s
Wall time: 4min 15s


Unnamed: 0,start,stop,prob_dist
0,0.0,0.1,286405469
1,0.1,0.2,137172734
2,0.2,0.3,136385315
3,0.3,0.4,158273300
4,0.4,0.5,325497447
5,0.5,0.6,42187173
6,0.6,0.7,44968576
7,0.7,0.8,37326793
8,0.8,0.9,34890232
9,0.9,1.0125,44520767


### Task 3 - filtering on FORMAT fields

In [16]:
%%time
# Opened read/write because the resulting mask is stored back into the
# group at the end. "r+" requires the store to exist, whereas zarr's default
# mode ("a") would create a new empty store at a mistyped path.
root = zarr.open(VCFZARR, mode="r+")
DP = root['call_DP']
GQ = root['call_GQ']
# The loop below pairs DP and GQ blocks by chunk index, which is only valid
# when both arrays have the same shape and chunk grid.
if DP.shape != GQ.shape or DP.chunks != GQ.chunks:
    raise ValueError("call_DP and call_GQ must have identical shape and chunking")
num_variants = DP.shape[0]
num_samples = DP.shape[1]

# Mask of variants inside the query region, chunked like DP along the
# variants axis so that mask blocks line up with data blocks.
variant_mask = np.zeros(num_variants, dtype=bool)
variant_mask[slice(start_index, end_index)] = 1
z_variant_mask = zarr.array(variant_mask, chunks=DP.chunks[0])
output_mask = np.zeros(num_variants, dtype=bool)
z_output_mask = zarr.array(output_mask, chunks=DP.chunks[0])

for v_chunk in range(DP.cdata_shape[0]):
    variant_mask_chunk = z_variant_mask.blocks[v_chunk]
    count = np.sum(variant_mask_chunk)
    # Skip variant chunks that contain nothing from the region.
    if count > 0:
        output_mask_chunk = np.zeros_like(variant_mask_chunk)
        for s_chunk in range(DP.cdata_shape[1]):
            # A variant passes if at least one sample in any sample chunk
            # has both DP > 10 and GQ > 20.
            output_mask_chunk = np.logical_or(
                output_mask_chunk,
                np.any(
                    np.logical_and(
                        DP.blocks[v_chunk, s_chunk] > 10,
                        GQ.blocks[v_chunk, s_chunk] > 20,
                    ),
                    axis=1,
                ),
            )
        # Only keep passing variants that are also inside the region.
        z_output_mask.blocks[v_chunk] = np.logical_and(output_mask_chunk, variant_mask_chunk)
root['variant_composite_filter'] = z_output_mask

CPU times: user 1min 53s, sys: 5.73 s, total: 1min 58s
Wall time: 11min


### Consistency checks

In [9]:
# Read with header=None so the first line of the file is treated as data
# rather than as column names.
bcftools_pos_df = pd.read_csv('pos_result_bcftools.txt', header=None)

In [10]:
# This file is written by the Task 1 cell using DataFrame.to_csv with default
# options, so its first line is a header and its first column is the index.
vcfzarr_pos_df = pd.read_csv('pos_result_vcfzarr.txt')

In [11]:
# bcftools positions are in column 0 of their file; the vcfzarr positions are
# in column 1 because column 0 of that CSV is the index written by to_csv.
np.array_equal(
    bcftools_pos_df.iloc[:, 0].to_numpy(),
    vcfzarr_pos_df.iloc[:, 1].to_numpy(),
)

True

In [12]:
# Re-open the store read-only; the next cell lists its keys.
root = zarr.open(VCFZARR, mode="r")

In [13]:
# Names of all arrays stored in the top-level group.
[name for name in root.keys()]

['call_AD',
 'call_ADF',
 'call_ADR',
 'call_DP',
 'call_DPF',
 'call_FT',
 'call_GQ',
 'call_GQX',
 'call_PL',
 'call_PS',
 'call_genotype',
 'call_genotype_mask',
 'call_genotype_phased',
 'contig_id',
 'contig_length',
 'filter_id',
 'sample_id',
 'variant_ABratio',
 'variant_AC',
 'variant_AC_Hemi',
 'variant_AC_Het',
 'variant_AC_Hom',
 'variant_AN',
 'variant_MendelSite',
 'variant_OLD_CLUMPED',
 'variant_OLD_MULTIALLELIC',
 'variant_allele',
 'variant_completeGTRatio',
 'variant_composite_filter',
 'variant_contig',
 'variant_filter',
 'variant_id',
 'variant_id_mask',
 'variant_medianDepthAll',
 'variant_medianDepthNonMiss',
 'variant_medianGQ',
 'variant_missingness',
 'variant_phwe_afr',
 'variant_phwe_amr',
 'variant_phwe_eas',
 'variant_phwe_eur',
 'variant_phwe_sas',
 'variant_position',
 'variant_quality']