Skip to content

Commit

Permalink
Reference variables instead of using strings
Browse files (browse the repository at this point in the history)
  • Loading branch information
ravwojdyla committed Sep 29, 2020
1 parent 8ca9f8b commit 6271e89
Show file tree
Hide file tree
Showing 10 changed files with 360 additions and 270 deletions.
70 changes: 35 additions & 35 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -54,41 +54,41 @@ Variables
.. autosummary::
:toctree: generated/

variables.base_prediction
variables.call_allele_count
variables.call_dosage
variables.call_dosage_mask
variables.call_genotype
variables.call_genotype_mask
variables.call_genotype_phased
variables.call_genotype_probability
variables.call_genotype_probability_mask
variables.covariates
variables.dosage
variables.genotype_counts
variables.loco_prediction
variables.meta_prediction
variables.pc_relate_phi
variables.sample_id
variables.sample_pcs
variables.traits
variables.variant_allele
variables.variant_allele_count
variables.variant_allele_frequency
variables.variant_allele_total
variables.variant_beta
variables.variant_call_rate
variables.variant_contig
variables.variant_hwe_p_value
variables.variant_id
variables.variant_n_called
variables.variant_n_het
variables.variant_n_hom_alt
variables.variant_n_hom_ref
variables.variant_n_non_ref
variables.variant_p_value
variables.variant_position
variables.variant_t_value
variables.base_prediction_spec
variables.call_allele_count_spec
variables.call_dosage_spec
variables.call_dosage_mask_spec
variables.call_genotype_spec
variables.call_genotype_mask_spec
variables.call_genotype_phased_spec
variables.call_genotype_probability_spec
variables.call_genotype_probability_mask_spec
variables.covariates_spec
variables.dosage_spec
variables.genotype_counts_spec
variables.loco_prediction_spec
variables.meta_prediction_spec
variables.pc_relate_phi_spec
variables.sample_id_spec
variables.sample_pcs_spec
variables.traits_spec
variables.variant_allele_spec
variables.variant_allele_count_spec
variables.variant_allele_frequency_spec
variables.variant_allele_total_spec
variables.variant_beta_spec
variables.variant_call_rate_spec
variables.variant_contig_spec
variables.variant_hwe_p_value_spec
variables.variant_id_spec
variables.variant_n_called_spec
variables.variant_n_het_spec
variables.variant_n_hom_alt_spec
variables.variant_n_hom_ref_spec
variables.variant_n_non_ref_spec
variables.variant_p_value_spec
variables.variant_position_spec
variables.variant_t_value_spec

Utilities
=========
Expand Down
4 changes: 0 additions & 4 deletions sgkit/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from . import variables
from .typing import ArrayLike
from .utils import check_array_like

DIM_VARIANT = "variants"
DIM_SAMPLE = "samples"
Expand Down Expand Up @@ -70,13 +69,11 @@ def create_genotype_call_dataset(
),
}
if call_genotype_phased is not None:
check_array_like(call_genotype_phased, kind="b", ndim=2)
data_vars["call_genotype_phased"] = (
[DIM_VARIANT, DIM_SAMPLE],
call_genotype_phased,
)
if variant_id is not None:
check_array_like(variant_id, kind={"U", "O"}, ndim=1)
data_vars["variant_id"] = ([DIM_VARIANT], variant_id)
attrs: Dict[Hashable, Any] = {"contigs": variant_contig_names}
return variables.validate(xr.Dataset(data_vars=data_vars, attrs=attrs))
Expand Down Expand Up @@ -145,7 +142,6 @@ def create_genotype_dosage_dataset(
),
}
if variant_id is not None:
check_array_like(variant_id, kind={"U", "O"}, ndim=1)
data_vars["variant_id"] = ([DIM_VARIANT], variant_id)
attrs: Dict[Hashable, Any] = {"contigs": variant_contig_names}
return variables.validate(xr.Dataset(data_vars=data_vars, attrs=attrs))
60 changes: 31 additions & 29 deletions sgkit/stats/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def count_alleles(g: ArrayLike, _: ArrayLike, out: ArrayLike) -> None:


def count_call_alleles(
ds: Dataset, *, call_genotype: str = "call_genotype", merge: bool = True
ds: Dataset, *, call_genotype: str = variables.call_genotype, merge: bool = True
) -> Dataset:
"""Compute per sample allele counts from genotype calls.
Expand All @@ -64,7 +64,7 @@ def count_call_alleles(
:func:`sgkit.create_genotype_call_dataset`.
call_genotype
Input variable name holding call_genotype as defined by
:data:`sgkit.variables.call_genotype`
:data:`sgkit.variables.call_genotype_spec`
merge
If True (the default), merge the input dataset and the computed
output variables into a single dataset, otherwise return only
Expand Down Expand Up @@ -104,14 +104,14 @@ def count_call_alleles(
[[2, 0],
[2, 0]]], dtype=uint8)
"""
variables.validate(ds, {call_genotype: variables.call_genotype})
variables.validate(ds, {call_genotype: variables.call_genotype_spec})
n_alleles = ds.dims["alleles"]
G = da.asarray(ds[call_genotype])
shape = (G.chunks[0], G.chunks[1], n_alleles)
N = da.empty(n_alleles, dtype=np.uint8)
new_ds = Dataset(
{
"call_allele_count": (
variables.call_allele_count: (
("variants", "samples", "alleles"),
da.map_blocks(
count_alleles, G, N, chunks=shape, drop_axis=2, new_axis=2
Expand All @@ -123,7 +123,7 @@ def count_call_alleles(


def count_variant_alleles(
ds: Dataset, *, call_genotype: str = "call_genotype", merge: bool = True
ds: Dataset, *, call_genotype: str = variables.call_genotype, merge: bool = True
) -> Dataset:
"""Compute allele count from genotype calls.
Expand All @@ -134,7 +134,7 @@ def count_variant_alleles(
:func:`sgkit.create_genotype_call_dataset`.
call_genotype
Input variable name holding call_genotype as defined by
:data:`sgkit.variables.call_genotype`
:data:`sgkit.variables.call_genotype_spec`
merge
If True (the default), merge the input dataset and the computed
output variables into a single dataset, otherwise return only
Expand Down Expand Up @@ -169,10 +169,10 @@ def count_variant_alleles(
"""
new_ds = Dataset(
{
"variant_allele_count": (
variables.variant_allele_count: (
("variants", "alleles"),
count_call_alleles(ds, call_genotype=call_genotype)[
"call_allele_count"
variables.call_allele_count
].sum(dim="samples"),
)
}
Expand Down Expand Up @@ -222,28 +222,30 @@ def allele_frequency(
data_vars: Dict[Hashable, Any] = {}
# only compute variant allele count if not already in dataset
if variant_allele_count is not None:
variables.validate(ds, {variant_allele_count: variables.variant_allele_count})
variables.validate(
ds, {variant_allele_count: variables.variant_allele_count_spec}
)
AC = ds[variant_allele_count]
else:
AC = count_variant_alleles(ds, merge=False, call_genotype=call_genotype)[
"variant_allele_count"
variables.variant_allele_count
]
data_vars["variant_allele_count"] = AC
data_vars[variables.variant_allele_count] = AC

M = ds[call_genotype_mask].stack(calls=("samples", "ploidy"))
AN = (~M).sum(dim="calls") # type: ignore
assert AN.shape == (ds.dims["variants"],)

data_vars["variant_allele_total"] = AN
data_vars["variant_allele_frequency"] = AC / AN
data_vars[variables.variant_allele_total] = AN
data_vars[variables.variant_allele_frequency] = AC / AN
return Dataset(data_vars)


def variant_stats(
ds: Dataset,
*,
call_genotype_mask: str = "call_genotype_mask",
call_genotype: str = "call_genotype",
call_genotype_mask: str = variables.call_genotype_mask,
call_genotype: str = variables.call_genotype,
variant_allele_count: Optional[str] = None,
merge: bool = True,
) -> Dataset:
Expand All @@ -256,13 +258,13 @@ def variant_stats(
:func:`sgkit.create_genotype_call_dataset`.
call_genotype
Input variable name holding call_genotype.
Defined by :data:`sgkit.variables.call_genotype`.
Defined by :data:`sgkit.variables.call_genotype_spec`.
call_genotype_mask
Input variable name holding call_genotype_mask.
Defined by :data:`sgkit.variables.call_genotype_mask`
Defined by :data:`sgkit.variables.call_genotype_mask_spec`.
variant_allele_count
Optional name of the input variable holding variant_allele_count,
as defined by :data:`sgkit.variables.variant_allele_count`.
as defined by :data:`sgkit.variables.variant_allele_count_spec`.
merge
If True (the default), merge the input dataset and the computed
output variables into a single dataset, otherwise return only
Expand All @@ -273,30 +275,30 @@ def variant_stats(
-------
A dataset containing the following variables:
- :data:`sgkit.variables.variant_n_called` (variants):
- :data:`sgkit.variables.variant_n_called_spec` (variants):
The number of samples with called genotypes.
- :data:`sgkit.variables.variant_call_rate` (variants):
- :data:`sgkit.variables.variant_call_rate_spec` (variants):
The fraction of samples with called genotypes.
- :data:`sgkit.variables.variant_n_het` (variants):
- :data:`sgkit.variables.variant_n_het_spec` (variants):
The number of samples with heterozygous calls.
- :data:`sgkit.variables.variant_n_hom_ref` (variants):
- :data:`sgkit.variables.variant_n_hom_ref_spec` (variants):
The number of samples with homozygous reference calls.
- :data:`sgkit.variables.variant_n_hom_alt` (variants):
- :data:`sgkit.variables.variant_n_hom_alt_spec` (variants):
The number of samples with homozygous alternate calls.
- :data:`sgkit.variables.variant_n_non_ref` (variants):
- :data:`sgkit.variables.variant_n_non_ref_spec` (variants):
The number of samples that are not homozygous reference calls.
- :data:`sgkit.variables.variant_allele_count` (variants, alleles):
- :data:`sgkit.variables.variant_allele_count_spec` (variants, alleles):
The number of occurrences of each allele.
- :data:`sgkit.variables.variant_allele_total` (variants):
- :data:`sgkit.variables.variant_allele_total_spec` (variants):
The number of occurrences of all alleles.
- :data:`sgkit.variables.variant_allele_frequency` (variants, alleles):
- :data:`sgkit.variables.variant_allele_frequency_spec` (variants, alleles):
The frequency of occurrence of each allele.
"""
variables.validate(
ds,
{
call_genotype: variables.call_genotype,
call_genotype_mask: variables.call_genotype_mask,
call_genotype: variables.call_genotype_spec,
call_genotype_mask: variables.call_genotype_mask_spec,
},
)
new_ds = xr.merge(
Expand Down
30 changes: 18 additions & 12 deletions sgkit/stats/association.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,13 @@ def linear_regression(
return LinearRegressionResult(beta=B, t_value=T, p_value=P)


def _get_loop_covariates(ds: Dataset, dosage: Optional[str] = None) -> Array:
def _get_loop_covariates(
ds: Dataset, call_genotype: str, dosage: Optional[str] = None
) -> Array:
if dosage is None:
# TODO: This should probably be a (GWAS-specific) allele
# count that takes sex chromosomes into account
G = ds["call_genotype"].sum(dim="ploidy") # pragma: no cover
G = ds[call_genotype].sum(dim="ploidy") # pragma: no cover
else:
G = ds[dosage]
return da.asarray(G.data)
Expand All @@ -121,6 +123,7 @@ def gwas_linear_regression(
covariates: Union[str, Sequence[str]],
traits: Union[str, Sequence[str]],
add_intercept: bool = True,
call_genotype: str = variables.call_genotype,
merge: bool = True,
) -> Dataset:
"""Run linear regression to identify continuous trait associations with genetic variants.
Expand All @@ -138,15 +141,18 @@ def gwas_linear_regression(
Dataset containing necessary dependent and independent variables.
dosage
Name of genetic dosage variable.
Defined by :data:`sgkit.variables.dosage`.
Defined by :data:`sgkit.variables.dosage_spec`.
covariates
Names of covariate variables (1D or 2D).
Defined by :data:`sgkit.variables.covariates`.
Defined by :data:`sgkit.variables.covariates_spec`.
traits
Names of trait variables (1D or 2D).
Defined by :data:`sgkit.variables.traits`.
Defined by :data:`sgkit.variables.traits_spec`.
add_intercept
Add intercept term to covariate set, by default True.
call_genotype
Input variable name holding call_genotype.
Defined by :data:`sgkit.variables.call_genotype_spec`.
merge
If True (the default), merge the input dataset and the computed
output variables into a single dataset, otherwise return only
Expand Down Expand Up @@ -193,12 +199,12 @@ def gwas_linear_regression(

variables.validate(
ds,
{dosage: variables.dosage},
{c: variables.covariates for c in covariates},
{t: variables.traits for t in traits},
{dosage: variables.dosage_spec},
{c: variables.covariates_spec for c in covariates},
{t: variables.traits_spec for t in traits},
)

G = _get_loop_covariates(ds, dosage=dosage)
G = _get_loop_covariates(ds, dosage=dosage, call_genotype=call_genotype)

X = da.asarray(concat_2d(ds[list(covariates)], dims=("samples", "covariates")))
if add_intercept:
Expand All @@ -216,9 +222,9 @@ def gwas_linear_regression(
res = linear_regression(G.T, X, Y)
new_ds = xr.Dataset(
{
"variant_beta": (("variants", "traits"), res.beta),
"variant_t_value": (("variants", "traits"), res.t_value),
"variant_p_value": (("variants", "traits"), res.p_value),
variables.variant_beta: (("variants", "traits"), res.beta),
variables.variant_t_value: (("variants", "traits"), res.t_value),
variables.variant_p_value: (("variants", "traits"), res.p_value),
}
)
return conditional_merge_datasets(ds, variables.validate(new_ds), merge)
16 changes: 8 additions & 8 deletions sgkit/stats/hwe.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,8 @@ def hardy_weinberg_test(
ds: Dataset,
*,
genotype_counts: Optional[Hashable] = None,
call_genotype: str = "call_genotype",
call_genotype_mask: str = "call_genotype_mask",
call_genotype: str = variables.call_genotype,
call_genotype_mask: str = variables.call_genotype_mask,
merge: bool = True,
) -> Dataset:
"""Exact test for HWE as described in Wigginton et al. 2005 [1].
Expand All @@ -146,10 +146,10 @@ def hardy_weinberg_test(
(in that order) across all samples for a variant.
call_genotype
Input variable name holding call_genotype.
Defined by :data:`sgkit.variables.call_genotype`.
Defined by :data:`sgkit.variables.call_genotype_spec`.
call_genotype_mask
Input variable name holding call_genotype_mask.
Defined by :data:`sgkit.variables.call_genotype_mask`
Defined by :data:`sgkit.variables.call_genotype_mask_spec`.
merge
If True (the default), merge the input dataset and the computed
output variables into a single dataset, otherwise return only
Expand Down Expand Up @@ -185,15 +185,15 @@ def hardy_weinberg_test(
raise NotImplementedError("HWE test only implemented for biallelic genotypes")
# Use precomputed genotype counts if provided
if genotype_counts is not None:
variables.validate(ds, {genotype_counts: variables.genotype_counts})
variables.validate(ds, {genotype_counts: variables.genotype_counts_spec})
obs = list(da.asarray(ds[genotype_counts]).T)
# Otherwise compute genotype counts from calls
else:
variables.validate(
ds,
{
call_genotype_mask: variables.call_genotype_mask,
call_genotype: variables.call_genotype,
call_genotype_mask: variables.call_genotype_mask_spec,
call_genotype: variables.call_genotype_spec,
},
)
# TODO: Use API genotype counting function instead, e.g.
Expand All @@ -203,5 +203,5 @@ def hardy_weinberg_test(
cts = [1, 0, 2] # arg order: hets, hom1, hom2
obs = [da.asarray((AC == ct).sum(dim="samples")) for ct in cts]
p = da.map_blocks(hardy_weinberg_p_value_vec_jit, *obs)
new_ds = xr.Dataset({"variant_hwe_p_value": ("variants", p)})
new_ds = xr.Dataset({variables.variant_hwe_p_value: ("variants", p)})
return conditional_merge_datasets(ds, variables.validate(new_ds), merge)

0 comments on commit 6271e89

Please sign in to comment.