In [1]:
from dask.distributed import Client
import xarray as xr
import gcsfs
from sgkit_bgen.bgen_reader import unpack_variables
# Use xarray's plain-text repr for displayed objects.
xr.set_options(display_style='text')
# Create a dask.distributed client with default arguments; the bare `client`
# expression on the last line displays its summary as the cell output.
client = Client()
client

0,1
Client  Scheduler: tcp://127.0.0.1:8086  Dashboard: http://127.0.0.1:8787/status,Cluster  Workers: 0  Cores: 0  Memory: 0 B


In [2]:
# Open the chrXY imputed-genotype Zarr store on Google Cloud Storage.
fs = gcsfs.GCSFileSystem()
# NOTE(review): check=True / create=False presumably mean "verify the path
# exists, never create it" -- confirm against the gcsfs documentation.
store = gcsfs.mapping.GCSMap('rs-ukb/prep-data/gt-imputation/ukb_chrXY.zarr', gcs=fs, check=True, create=False)
ds = xr.open_zarr(store)
# Rebind `ds` to the result of sgkit_bgen's unpack_variables (dtype='float16').
# Later cells read call_genotype_probability, call_genotype_probability_mask
# and call_dosage_mask from this dataset.
ds = unpack_variables(ds, dtype='float16')
ds

In [3]:
# Mean of the genotype-probability mask over the 'samples' dimension,
# computed eagerly.
# NOTE(review): if the mask is True for *missing* calls, this is a missing
# rate rather than a call rate -- confirm what `cr` is meant to stand for.
cr = ds.call_genotype_probability_mask.mean(dim='samples').compute()
# Display the largest of those means.
cr.max()

In [6]:
import numpy as np
import dask.array as da
from numba import guvectorize
from sgkit.typing import ArrayLike
from xarray import Dataset, DataArray

@guvectorize(
    [
        "void(float32[:], uint8[:], uint8[:])",
    ],
    "(k),(n)->(n)",
    nopython=True,
)
def _hard_calls(gp: ArrayLike, _: ArrayLike, out: ArrayLike) -> None:
    """Convert one call's genotype probabilities into hard allele calls.

    Generalized ufunc with layout ``(k),(n)->(n)``.

    Parameters
    ----------
    gp
        Genotype probabilities for a single call (core dimension ``k``).
    _
        Dummy array that only fixes the length ``n`` of the output core
        dimension; its values are never read.
    out
        Output buffer of length ``n``, filled in place.

    Notes
    -----
    The encoding mirrors ``hard_calls`` in this notebook: most likely genotype
    index 0 -> all zeros, index 1 -> a single 1 in the first slot, any higher
    index -> all ones.
    NOTE(review): this assumes the genotype axis is ordered hom-ref / het /
    hom-alt -- confirm against the BGEN reader.
    """
    # Start from the all-reference call so that argmax == 0 stays (0, ..., 0).
    # Previously the fall-through branch overwrote this case with all ones.
    out[:] = 0
    i = np.argmax(gp)
    if i == 1:
        # One alternate allele, placed in the first slot to agree with
        # `hard_calls` (calls are unphased, so the slot carries no meaning).
        out[0] = 1
    elif i > 1:
        out[:] = 1
        
# def hard_calls(ds: Dataset) -> DataArray:
#     ploidy: int = 2
#     n_genotypes = ds.dims["genotypes"]
#     G = da.asarray(ds["call_genotype_probability"])
#     shape = (G.chunks[0], G.chunks[1], ploidy)
#     N = da.empty(ploidy, dtype=np.uint8)
#     return xr.DataArray(
#         da.map_blocks(_hard_calls, G, N, chunks=shape, drop_axis=2, new_axis=2),
#         dims=("variants", "samples", "ploidy"),
#         name="call_genotype",
#     )

def hard_calls(ds: Dataset) -> DataArray:
    """Derive hard allele calls from ``ds.call_genotype_probability``.

    The index of the largest probability along the ``genotypes`` dimension is
    mapped to a pair of uint8 alleles: index 0 -> (0, 0), index 2 -> (1, 1),
    and any other index -> (1, 0).

    Parameters
    ----------
    ds
        Dataset providing ``call_genotype_probability`` with a ``genotypes``
        dimension.

    Returns
    -------
    DataArray
        Allele calls with dimensions ordered ``(variants, samples, ploidy)``.
    """
    zero = np.uint8(0)
    one = np.uint8(1)
    best_genotype = ds.call_genotype_probability.argmax(dim='genotypes').astype('uint8')
    # First allele is 0 only when the most likely genotype is index 0.
    first_allele = xr.where(best_genotype == zero, zero, one)
    # Second allele is 1 only when the most likely genotype is index 2.
    second_allele = xr.where(best_genotype == np.uint8(2), one, zero)
    stacked = xr.concat([first_allele, second_allele], dim='ploidy')
    return stacked.transpose('variants', 'samples', 'ploidy')

In [7]:
# Attach the hard calls derived from the genotype probabilities to the
# dataset as `call_genotype`, then display the dataset.
ds['call_genotype'] = hard_calls(ds)
ds

In [6]:
# Spot-check: materialize the first ploidy entry for the first 10 variants
# and first 10 samples (hard_calls orders dims as variants, samples, ploidy).
ds['call_genotype'][:10, :10, 0].compute()

In [8]:
# Broadcast the dosage mask against call_genotype and cast it to bool, so
# that it can be stored alongside the calls as `call_genotype_mask`.
ds['call_genotype_mask'] = ds.call_dosage_mask.broadcast_like(ds.call_genotype).astype(bool)
ds

In [10]:
import sgkit
# Run sgkit's Hardy-Weinberg test on the dataset and keep only the
# `variant_hwe_p_value` variable from its result.
ds['variant_hwe_p_value'] = sgkit.hardy_weinberg_test(ds)['variant_hwe_p_value']
ds

In [None]:
# This requires about 12 GiB per core.
# Materialize the Hardy-Weinberg p-values in memory and display them.
p = ds['variant_hwe_p_value'].compute()
p