Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add merge parameter to method functions #250

Merged
merged 7 commits into from
Sep 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ sgkit: Statistical genetics toolkit in Python
:caption: Contents:

api
usage
contributing


Expand Down
24 changes: 24 additions & 0 deletions docs/usage.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
.. usage:

*****
Usage
*****

.. contents:: Table of contents:
:local:


.. _dataset_merge:

Dataset merge behavior
======================

Generally, method functions in sgkit compute some new variables based on the
input dataset, then return a new output dataset that consists of the input
dataset plus the new computed variables. The input dataset is unchanged.

This behavior can be controlled using the ``merge`` parameter. If set to ``True``
(the default), then the function will merge the input dataset and the computed
output variables into a single dataset. Output variables will overwrite any
input variables with the same name, and a warning will be issued in this case.
If ``False``, the function will return only the computed output variables.
29 changes: 13 additions & 16 deletions sgkit/stats/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from xarray import Dataset

from sgkit.typing import ArrayLike
from sgkit.utils import merge_datasets
from sgkit.utils import conditional_merge_datasets

Dimension = Literal["samples", "variants"]

Expand Down Expand Up @@ -61,10 +61,9 @@ def count_call_alleles(ds: Dataset, merge: bool = True) -> Dataset:
`sgkit.create_genotype_call_dataset`.
merge : bool, optional
If True (the default), merge the input dataset and the computed
output variables into a single dataset. Output variables will
overwrite any input variables with the same name, and a warning
will be issued in this case.
If False, return only the computed output variables.
output variables into a single dataset, otherwise return only
the computed output variables.
See :ref:`dataset_merge` for more details.

Returns
-------
Expand Down Expand Up @@ -114,7 +113,7 @@ def count_call_alleles(ds: Dataset, merge: bool = True) -> Dataset:
)
}
)
return merge_datasets(ds, new_ds) if merge else new_ds
return conditional_merge_datasets(ds, new_ds, merge)


def count_variant_alleles(ds: Dataset, merge: bool = True) -> Dataset:
Expand All @@ -127,10 +126,9 @@ def count_variant_alleles(ds: Dataset, merge: bool = True) -> Dataset:
`sgkit.create_genotype_call_dataset`.
merge : bool, optional
If True (the default), merge the input dataset and the computed
output variables into a single dataset. Output variables will
overwrite any input variables with the same name, and a warning
will be issued in this case.
If False, return only the computed output variables.
output variables into a single dataset, otherwise return only
the computed output variables.
See :ref:`dataset_merge` for more details.

Returns
-------
Expand Down Expand Up @@ -167,7 +165,7 @@ def count_variant_alleles(ds: Dataset, merge: bool = True) -> Dataset:
)
}
)
return merge_datasets(ds, new_ds) if merge else new_ds
return conditional_merge_datasets(ds, new_ds, merge)


def _swap(dim: Dimension) -> Dimension:
Expand Down Expand Up @@ -229,10 +227,9 @@ def variant_stats(ds: Dataset, merge: bool = True) -> Dataset:
`sgkit.create_genotype_call_dataset`.
merge : bool, optional
If True (the default), merge the input dataset and the computed
output variables into a single dataset. Output variables will
overwrite any input variables with the same name, and a warning
will be issued in this case.
If False, return only the computed output variables.
output variables into a single dataset, otherwise return only
the computed output variables.
See :ref:`dataset_merge` for more details.

Returns
-------
Expand All @@ -255,4 +252,4 @@ def variant_stats(ds: Dataset, merge: bool = True) -> Dataset:
allele_frequency(ds),
]
)
return merge_datasets(ds, new_ds) if merge else new_ds
return conditional_merge_datasets(ds, new_ds, merge)
10 changes: 9 additions & 1 deletion sgkit/stats/association.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from xarray import Dataset

from ..typing import ArrayLike
from ..utils import conditional_merge_datasets
from .utils import concat_2d


Expand Down Expand Up @@ -116,6 +117,7 @@ def gwas_linear_regression(
covariates: Union[str, Sequence[str]],
traits: Union[str, Sequence[str]],
add_intercept: bool = True,
merge: bool = True,
) -> Dataset:
"""Run linear regression to identify continuous trait associations with genetic variants.

Expand Down Expand Up @@ -148,6 +150,11 @@ def gwas_linear_regression(
and concatenated to any 1D traits along the second axis (columns).
add_intercept : bool, optional
Add intercept term to covariate set, by default True.
merge : bool, optional
If True (the default), merge the input dataset and the computed
output variables into a single dataset, otherwise return only
the computed output variables.
See :ref:`dataset_merge` for more details.

Warnings
--------
Expand Down Expand Up @@ -203,10 +210,11 @@ def gwas_linear_regression(
Y = Y.rechunk((None, -1))

res = linear_regression(G.T, X, Y)
return xr.Dataset(
new_ds = xr.Dataset(
{
"variant_beta": (("variants", "traits"), res.beta),
"variant_t_value": (("variants", "traits"), res.t_value),
"variant_p_value": (("variants", "traits"), res.p_value),
}
)
return conditional_merge_datasets(ds, new_ds, merge)
12 changes: 10 additions & 2 deletions sgkit/stats/hwe.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from numpy import ndarray
from xarray import Dataset

from sgkit.utils import conditional_merge_datasets


def hardy_weinberg_p_value(obs_hets: int, obs_hom1: int, obs_hom2: int) -> float:
"""Exact test for HWE as described in Wigginton et al. 2005 [1].
Expand Down Expand Up @@ -122,7 +124,7 @@ def hardy_weinberg_p_value_vec(


def hardy_weinberg_test(
ds: Dataset, genotype_counts: Optional[Hashable] = None
ds: Dataset, genotype_counts: Optional[Hashable] = None, merge: bool = True,
) -> Dataset:
"""Exact test for HWE as described in Wigginton et al. 2005 [1].

Expand All @@ -137,6 +139,11 @@ def hardy_weinberg_test(
where `N` is equal to the number of variants and the 3 columns contain
heterozygous, homozygous reference, and homozygous alternate counts
(in that order) across all samples for a variant.
merge : bool, optional
If True (the default), merge the input dataset and the computed
jeromekelleher marked this conversation as resolved.
Show resolved Hide resolved
output variables into a single dataset, otherwise return only
the computed output variables.
See :ref:`dataset_merge` for more details.

Warnings
--------
Expand Down Expand Up @@ -178,4 +185,5 @@ def hardy_weinberg_test(
cts = [1, 0, 2] # arg order: hets, hom1, hom2
obs = [da.asarray((AC == ct).sum(dim="samples")) for ct in cts]
p = da.map_blocks(hardy_weinberg_p_value_vec_jit, *obs)
return xr.Dataset({"variant_hwe_p_value": ("variants", p)})
new_ds = xr.Dataset({"variant_hwe_p_value": ("variants", p)})
return conditional_merge_datasets(ds, new_ds, merge)
11 changes: 5 additions & 6 deletions sgkit/stats/pc_relate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import xarray as xr

from sgkit.typing import ArrayLike
from sgkit.utils import merge_datasets
from sgkit.utils import conditional_merge_datasets


def gramian(a: ArrayLike) -> ArrayLike:
Expand Down Expand Up @@ -69,10 +69,9 @@ def pc_relate(ds: xr.Dataset, maf: float = 0.01, merge: bool = True) -> xr.Datas
The default value is 0.01. Must be between (0.0, 0.1).
merge : bool, optional
If True (the default), merge the input dataset and the computed
output variables into a single dataset. Output variables will
overwrite any input variables with the same name, and a warning
will be issued in this case.
If False, return only the computed output variables.
output variables into a single dataset, otherwise return only
the computed output variables.
See :ref:`dataset_merge` for more details.

Warnings
--------
Expand Down Expand Up @@ -149,4 +148,4 @@ def pc_relate(ds: xr.Dataset, maf: float = 0.01, merge: bool = True) -> xr.Datas
# NOTE: phi is of shape (S x S), S = num samples
assert phi.shape == (call_g.shape[1],) * 2
new_ds = xr.Dataset({"pc_relate_phi": (("sample_x", "sample_y"), phi)})
return merge_datasets(ds, new_ds) if merge else new_ds
return conditional_merge_datasets(ds, new_ds, merge)
13 changes: 10 additions & 3 deletions sgkit/stats/regenie.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from xarray import Dataset

from ..typing import ArrayLike
from ..utils import split_array_chunks
from ..utils import conditional_merge_datasets, split_array_chunks
from .utils import (
assert_array_shape,
assert_block_shape,
Expand Down Expand Up @@ -732,6 +732,7 @@ def regenie(
add_intercept: bool = True,
normalize: bool = False,
orthogonalize: bool = False,
merge: bool = True,
**kwargs: Any,
) -> Dataset:
"""Regenie trait transformation.
Expand Down Expand Up @@ -779,6 +780,11 @@ def regenie(
orthogonalize : bool
**Experimental**: Remove covariates through orthogonalization
of genotypes and traits, by default False.
merge : bool, optional
If True (the default), merge the input dataset and the computed
output variables into a single dataset, otherwise return only
the computed output variables.
See :ref:`dataset_merge` for more details.

Warnings
--------
Expand Down Expand Up @@ -817,7 +823,7 @@ def regenie(
>>> ds["call_dosage"] = (("variants", "samples"), rs.normal(size=(n_variant, n_sample)))
>>> ds["sample_covariate"] = (("samples", "covariates"), rs.normal(size=(n_sample, n_covariate)))
>>> ds["sample_trait"] = (("samples", "traits"), rs.normal(size=(n_sample, n_trait)))
>>> res = regenie(ds, dosage="call_dosage", covariates="sample_covariate", traits="sample_trait")
>>> res = regenie(ds, dosage="call_dosage", covariates="sample_covariate", traits="sample_trait", merge=False)
>>> res.compute() # doctest: +NORMALIZE_WHITESPACE
<xarray.Dataset>
Dimensions: (alphas: 5, blocks: 2, contigs: 2, outcomes: 5, samples: 50)
Expand All @@ -843,7 +849,7 @@ def regenie(
X = da.asarray(concat_2d(ds[list(covariates)], dims=("samples", "covariates")))
Y = da.asarray(concat_2d(ds[list(traits)], dims=("samples", "traits")))
contigs = ds["variant_contig"]
return regenie_transform(
new_ds = regenie_transform(
G.T,
X,
Y,
Expand All @@ -856,3 +862,4 @@ def regenie(
orthogonalize=orthogonalize,
**kwargs,
)
return conditional_merge_datasets(ds, new_ds, merge)
2 changes: 1 addition & 1 deletion sgkit/tests/test_hwe.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def test_hwep_dataset__precomputed_counts(ds_neq: Dataset) -> None:
cts = [1, 0, 2] # arg order: hets, hom1, hom2
gtc = xr.concat([(ac == ct).sum(dim="samples") for ct in cts], dim="counts").T # type: ignore[no-untyped-call]
ds = ds.assign(**{"variant_genotype_counts": gtc})
p = hwep_test(ds, genotype_counts="variant_genotype_counts")
p = hwep_test(ds, genotype_counts="variant_genotype_counts", merge=False)
assert np.all(p < 1e-8)


Expand Down
4 changes: 2 additions & 2 deletions sgkit/tests/test_regenie.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def test_regenie__no_loco_with_one_contig():
ds = simulate_regression_dataset(
n_variant=10, n_sample=5, n_contig=1, n_covariate=1, n_trait=1
)
res = regenie_sim(ds=ds)
res = regenie_sim(ds=ds, merge=False)
assert len(res) == 2
assert "loco_prediction" not in res

Expand All @@ -311,7 +311,7 @@ def test_regenie__32bit_float(ds):
)
# Ensure that a uniform demotion in types for input arrays (aside from contigs)
# results in arrays with the same type
res = regenie_sim(ds=ds)
res = regenie_sim(ds=ds, merge=False)
for v in res:
assert res[v].dtype == np.float32

Expand Down
5 changes: 5 additions & 0 deletions sgkit/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,11 @@ def merge_datasets(input: Dataset, output: Dataset) -> Dataset:
return output.merge(input, compat="override")


def conditional_merge_datasets(input: Dataset, output: Dataset, merge: bool) -> Dataset:
"""Merge the input and output datasets only if `merge` is true, otherwise just return the output."""
return merge_datasets(input, output) if merge else output


def split_array_chunks(n: int, blocks: int) -> Tuple[int, ...]:
"""Compute chunk sizes for an array split into blocks.

Expand Down