Skip to content

Commit

Permalink
sgkit variable specs act str-like
Browse files Browse the repository at this point in the history
  • Loading branch information
ravwojdyla committed Sep 29, 2020
1 parent 8ca9f8b commit b6197a9
Show file tree
Hide file tree
Showing 9 changed files with 99 additions and 41 deletions.
38 changes: 17 additions & 21 deletions sgkit/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from . import variables
from .typing import ArrayLike
from .utils import check_array_like

DIM_VARIANT = "variants"
DIM_SAMPLE = "samples"
Expand Down Expand Up @@ -59,25 +58,23 @@ def create_genotype_call_dataset(
The dataset of genotype calls.
"""
data_vars: Dict[Hashable, Any] = {
"variant_contig": ([DIM_VARIANT], variant_contig),
"variant_position": ([DIM_VARIANT], variant_position),
"variant_allele": ([DIM_VARIANT, DIM_ALLELE], variant_alleles),
"sample_id": ([DIM_SAMPLE], sample_id),
"call_genotype": ([DIM_VARIANT, DIM_SAMPLE, DIM_PLOIDY], call_genotype),
"call_genotype_mask": (
variables.variant_contig: ([DIM_VARIANT], variant_contig),
variables.variant_position: ([DIM_VARIANT], variant_position),
variables.variant_allele: ([DIM_VARIANT, DIM_ALLELE], variant_alleles),
variables.sample_id: ([DIM_SAMPLE], sample_id),
variables.call_genotype: ([DIM_VARIANT, DIM_SAMPLE, DIM_PLOIDY], call_genotype),
variables.call_genotype_mask: (
[DIM_VARIANT, DIM_SAMPLE, DIM_PLOIDY],
call_genotype < 0,
),
}
if call_genotype_phased is not None:
check_array_like(call_genotype_phased, kind="b", ndim=2)
data_vars["call_genotype_phased"] = (
data_vars[variables.call_genotype_phased] = (
[DIM_VARIANT, DIM_SAMPLE],
call_genotype_phased,
)
if variant_id is not None:
check_array_like(variant_id, kind={"U", "O"}, ndim=1)
data_vars["variant_id"] = ([DIM_VARIANT], variant_id)
data_vars[variables.variant_id] = ([DIM_VARIANT], variant_id)
attrs: Dict[Hashable, Any] = {"contigs": variant_contig_names}
return variables.validate(xr.Dataset(data_vars=data_vars, attrs=attrs))

Expand Down Expand Up @@ -129,23 +126,22 @@ def create_genotype_dosage_dataset(
"""
data_vars: Dict[Hashable, Any] = {
"variant_contig": ([DIM_VARIANT], variant_contig),
"variant_position": ([DIM_VARIANT], variant_position),
"variant_allele": ([DIM_VARIANT, DIM_ALLELE], variant_alleles),
"sample_id": ([DIM_SAMPLE], sample_id),
"call_dosage": ([DIM_VARIANT, DIM_SAMPLE], call_dosage),
"call_dosage_mask": ([DIM_VARIANT, DIM_SAMPLE], np.isnan(call_dosage)),
"call_genotype_probability": (
variables.variant_contig: ([DIM_VARIANT], variant_contig),
variables.variant_position: ([DIM_VARIANT], variant_position),
variables.variant_allele: ([DIM_VARIANT, DIM_ALLELE], variant_alleles),
variables.sample_id: ([DIM_SAMPLE], sample_id),
variables.call_dosage: ([DIM_VARIANT, DIM_SAMPLE], call_dosage),
variables.call_dosage_mask: ([DIM_VARIANT, DIM_SAMPLE], np.isnan(call_dosage)),
variables.call_genotype_probability: (
[DIM_VARIANT, DIM_SAMPLE, DIM_GENOTYPE],
call_genotype_probability,
),
"call_genotype_probability_mask": (
variables.call_genotype_probability_mask: (
[DIM_VARIANT, DIM_SAMPLE, DIM_GENOTYPE],
np.isnan(call_genotype_probability),
),
}
if variant_id is not None:
check_array_like(variant_id, kind={"U", "O"}, ndim=1)
data_vars["variant_id"] = ([DIM_VARIANT], variant_id)
data_vars[variables.variant_id] = ([DIM_VARIANT], variant_id)
attrs: Dict[Hashable, Any] = {"contigs": variant_contig_names}
return variables.validate(xr.Dataset(data_vars=data_vars, attrs=attrs))
6 changes: 3 additions & 3 deletions sgkit/stats/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,14 +228,14 @@ def allele_frequency(
AC = count_variant_alleles(ds, merge=False, call_genotype=call_genotype)[
"variant_allele_count"
]
data_vars["variant_allele_count"] = AC
data_vars[variables.variant_allele_count] = AC

M = ds[call_genotype_mask].stack(calls=("samples", "ploidy"))
AN = (~M).sum(dim="calls") # type: ignore
assert AN.shape == (ds.dims["variants"],)

data_vars["variant_allele_total"] = AN
data_vars["variant_allele_frequency"] = AC / AN
data_vars[variables.variant_allele_total] = AN
data_vars[variables.variant_allele_frequency] = AC / AN
return Dataset(data_vars)


Expand Down
18 changes: 12 additions & 6 deletions sgkit/stats/association.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,13 @@ def linear_regression(
return LinearRegressionResult(beta=B, t_value=T, p_value=P)


def _get_loop_covariates(
    ds: Dataset, dosage: Optional[str] = None, call_genotype: str = "call_genotype"
) -> Array:
    """Extract the per-variant genetic covariate matrix from *ds*.

    Parameters
    ----------
    ds
        Dataset holding the genotype/dosage variables.
    dosage
        Name of the dosage variable to use; if None, an allele count is
        derived from *call_genotype* instead.
    call_genotype
        Name of the call-genotype variable, summed over the ploidy
        dimension when *dosage* is None.

    Returns
    -------
    A dask array of shape (variants, samples).
    """
    if dosage is None:
        # TODO: This should be (probably gwas-specific) allele
        # count with sex chromosome considerations
        G = ds[call_genotype].sum(dim="ploidy")  # pragma: no cover
    else:
        G = ds[dosage]
    return da.asarray(G.data)
Expand All @@ -121,6 +123,7 @@ def gwas_linear_regression(
covariates: Union[str, Sequence[str]],
traits: Union[str, Sequence[str]],
add_intercept: bool = True,
call_genotype: str = "call_genotype",
merge: bool = True,
) -> Dataset:
"""Run linear regression to identify continuous trait associations with genetic variants.
Expand All @@ -147,6 +150,9 @@ def gwas_linear_regression(
Defined by :data:`sgkit.variables.traits`.
add_intercept
Add intercept term to covariate set, by default True.
call_genotype
Input variable name holding call_genotype.
Defined by :data:`sgkit.variables.call_genotype`.
merge
If True (the default), merge the input dataset and the computed
output variables into a single dataset, otherwise return only
Expand Down Expand Up @@ -198,7 +204,7 @@ def gwas_linear_regression(
{t: variables.traits for t in traits},
)

G = _get_loop_covariates(ds, dosage=dosage)
G = _get_loop_covariates(ds, dosage=dosage, call_genotype=call_genotype)

X = da.asarray(concat_2d(ds[list(covariates)], dims=("samples", "covariates")))
if add_intercept:
Expand All @@ -216,9 +222,9 @@ def gwas_linear_regression(
res = linear_regression(G.T, X, Y)
new_ds = xr.Dataset(
{
"variant_beta": (("variants", "traits"), res.beta),
"variant_t_value": (("variants", "traits"), res.t_value),
"variant_p_value": (("variants", "traits"), res.p_value),
variables.variant_beta: (("variants", "traits"), res.beta),
variables.variant_t_value: (("variants", "traits"), res.t_value),
variables.variant_p_value: (("variants", "traits"), res.p_value),
}
)
return conditional_merge_datasets(ds, variables.validate(new_ds), merge)
2 changes: 1 addition & 1 deletion sgkit/stats/hwe.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,5 +203,5 @@ def hardy_weinberg_test(
cts = [1, 0, 2] # arg order: hets, hom1, hom2
obs = [da.asarray((AC == ct).sum(dim="samples")) for ct in cts]
p = da.map_blocks(hardy_weinberg_p_value_vec_jit, *obs)
new_ds = xr.Dataset({"variant_hwe_p_value": ("variants", p)})
new_ds = xr.Dataset({variables.variant_hwe_p_value: ("variants", p)})
return conditional_merge_datasets(ds, variables.validate(new_ds), merge)
2 changes: 1 addition & 1 deletion sgkit/stats/pc_relate.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,5 +170,5 @@ def pc_relate(
phi = gramian(centered_af) / gramian(stddev)
# NOTE: phi is of shape (S x S), S = num samples
assert phi.shape == (call_g.shape[1],) * 2
new_ds = xr.Dataset({"pc_relate_phi": (("sample_x", "sample_y"), phi)})
new_ds = xr.Dataset({variables.pc_relate_phi: (("sample_x", "sample_y"), phi)})
return conditional_merge_datasets(ds, variables.validate(new_ds), merge)
6 changes: 3 additions & 3 deletions sgkit/stats/regenie.py
Original file line number Diff line number Diff line change
Expand Up @@ -708,16 +708,16 @@ def regenie_transform(
YP3 = _stage_3(B2, YP1, X, Y, contigs, variant_chunk_start)

data_vars: Dict[Hashable, Any] = {}
data_vars["base_prediction"] = xr.DataArray(
data_vars[variables.base_prediction] = xr.DataArray(
YP1,
dims=("blocks", "alphas", "samples", "outcomes"),
attrs={"description": DESC_BASE_PRED},
)
data_vars["meta_prediction"] = xr.DataArray(
data_vars[variables.meta_prediction] = xr.DataArray(
YP2, dims=("samples", "outcomes"), attrs={"description": DESC_META_PRED}
)
if YP3 is not None:
data_vars["loco_prediction"] = xr.DataArray(
data_vars[variables.loco_prediction] = xr.DataArray(
YP3,
dims=("contigs", "samples", "outcomes"),
attrs={"description": DESC_LOCO_PRED},
Expand Down
7 changes: 6 additions & 1 deletion sgkit/tests/test_association.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from sgkit.stats.association import gwas_linear_regression, linear_regression
from sgkit.typing import ArrayLike
from sgkit.variables import Spec

with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
Expand Down Expand Up @@ -133,7 +134,11 @@ def _get_statistics(
res = _sm_statistics(ds, i, add_intercept)
df_pred.append(
dsr.to_dataframe()
.rename(columns=lambda c: c.replace("variant_", ""))
.rename(
columns=lambda c: c.default_name.replace("variant_", "")
if isinstance(c, Spec)
else c.replace("variant_", "")
)
.iloc[i]
.to_dict()
)
Expand Down
16 changes: 16 additions & 0 deletions sgkit/tests/test_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,19 @@ def test_variables__whole_ds(dummy_ds: xr.Dataset) -> None:
finally:
SgkitVariables.registered_variables.pop("foo", None)
SgkitVariables.registered_variables.pop("bar", None)


def test_variables__eq_hash(dummy_ds: xr.Dataset) -> None:
    """Check that specs behave str-like for equality, hashing and indexing."""
    spec_foo = ArrayLikeSpec("foo", kind="i", ndim=1)
    spec_foo_2 = ArrayLikeSpec("foo", kind="i", ndim=2)
    # Spec <-> str equality is symmetric.
    assert spec_foo == "foo"
    assert spec_foo != "foo2"
    assert "foo" == spec_foo
    assert "foo2" != spec_foo
    # Spec hashes like its default_name, so it interchanges with str as a dict key.
    assert {spec_foo: 3}[spec_foo] == 3
    assert {spec_foo: 3}["foo"] == 3  # type: ignore[index]
    assert "foo" in {spec_foo: 3}  # type: ignore[comparison-overlap]
    # Specs work directly as xarray variable keys.
    assert spec_foo in dummy_ds
    # Exact-type check: `is` is the idiomatic identity comparison for types.
    assert type(dummy_ds[spec_foo]) is xr.DataArray
    assert dummy_ds[spec_foo].name == "foo"
    # Specs with the same name but different fields are not equal.
    assert spec_foo != spec_foo_2
45 changes: 40 additions & 5 deletions sgkit/variables.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,55 @@
import logging
from dataclasses import dataclass
from typing import Dict, Hashable, Mapping, Set, Union, overload
from dataclasses import dataclass, fields
from typing import Any, Dict, Hashable, Mapping, Set, Union, overload

import xarray as xr

logger = logging.getLogger(__name__)


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False, repr=False)
class Spec:
"""Root type Spec"""

default_name: str

# Note: these eq,hash dunder methods make Spec essentially
# act/look like a string-like object. When
# used in the context of a dictionary (in xr.Dataset
# variables or the sgkit variable registry).
# This essentially means that you can do things like:
#
# foo = ArrayLikeSpec("foo", 3, "u")
# foo == "foo" # True
# {foo: 3}[foo] == {foo: 3}["foo"] # True
# ds = xr.Dataset({foo: (("sample_x", "sample_y"), phi)})
# ds[foo]
#
# This makes sgkit variables API more concise, but
# it does introduce some level of complexity if
# a dev doesn't expect equality like foo == "foo".
# I'm not entirely sure it's a good idea yet.

@dataclass(frozen=True)
def __eq__(self, other: Any) -> bool:
if isinstance(other, self.__class__):
return all(
getattr(self, f.name) == getattr(other, f.name) for f in fields(self)
)
elif isinstance(other, str):
return self.default_name == other
return NotImplemented

def __hash__(self) -> int:
return self.default_name.__hash__()

def __repr__(self) -> str:
return self.default_name.__repr__()

def __str__(self) -> str:
return self.default_name.__str__()


@dataclass(frozen=True, eq=False, repr=False)
class ArrayLikeSpec(Spec):
"""ArrayLike type spec"""

Expand Down Expand Up @@ -161,7 +196,7 @@ class SgkitVariables:
"""Holds registry of Sgkit variables, and can validate a dataset against a spec"""

registered_variables: Dict[Hashable, ArrayLikeSpec] = {
x.default_name: x for x in globals().values() if isinstance(x, ArrayLikeSpec)
x: x for x in globals().values() if isinstance(x, ArrayLikeSpec)
}

@classmethod
Expand Down

0 comments on commit b6197a9

Please sign in to comment.