From b061737bbb29e4ca148a11fb6ad40431ce632997 Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Fri, 1 Apr 2022 14:12:30 +0100 Subject: [PATCH] Refactor TreeSequence.variants to use new decode method. --- python/CHANGELOG.rst | 17 +++- python/_tskitmodule.c | 8 +- python/tests/test_genotypes.py | 25 +----- python/tests/test_highlevel.py | 23 +++-- python/tests/test_topology.py | 71 +++++++-------- python/tskit/trees.py | 157 +++++++++++++++++++-------------- 6 files changed, 155 insertions(+), 146 deletions(-) diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index 9f509e4d97..73c64c8a20 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -1,5 +1,5 @@ ---------------------- -[0.4.2] - 2022-0X-XX +[0.5.0] - 2022-0X-XX ---------------------- **Changes** @@ -10,14 +10,25 @@ - Make dumping of tables and tree seqences to disk a zero-copy operation. (:user:`benjeffery`, :issue:`2111`, :pr:`2124`) +- Add ``copy`` argument to ``TreeSequence.variants`` which, if False, reuses the + returned ``Variant`` object for improved performance. Defaults to True. + (:user:`benjeffery`, :issue:`605`, :pr:`2172`) + +- ``tree.mrca`` now takes 2 or more arguments and gives the common ancestor of them all. + (:user:`savitakartik`, :issue:`1340`, :pr:`2121`) + **Breaking Changes** - The JSON metadata codec now interprets the empty string as an empty object. This means that applying a schema to an existing table will no longer necessitate modifying the existing rows. (:user:`benjeffery`, :issue:`2064`, :pr:`2104`) -- ``tree.mrca`` now takes 2 or more arguments. - (:user:`savitakartik`, :issue:`1340`, :pr:`2121`) +- Remove the previously deprecated ``as_bytes`` argument to ``TreeSequence.variants``. + If you need genotypes in byte form this can be done following the code in the + ``to_macs`` method on line ``5573`` of ``trees.py``. + This argument was initially deprecated more than 3 years ago when the code was part of + ``msprime``. 
+ (:user:`benjeffery`, :issue:`605`, :pr:`2172`) ---------------------- [0.4.1] - 2022-01-11 diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index f4cbdeb5fa..f6a61f9ca9 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -11402,18 +11402,16 @@ Variant_init(Variant *self, PyObject *args, PyObject *kwds) } static PyObject * -Variant_decode(Variant *self, PyObject *args, PyObject *kwds) +Variant_decode(Variant *self, PyObject *args) { int err; PyObject *ret = NULL; tsk_id_t site_id; - static char *kwlist[] = { "site", NULL }; if (Variant_check_state(self) != 0) { goto out; } - if (!PyArg_ParseTupleAndKeywords( - args, kwds, "O&", kwlist, &tsk_id_converter, &site_id)) { + if (!PyArg_ParseTuple(args, "O&", &tsk_id_converter, &site_id)) { goto out; } err = tsk_variant_decode(self->variant, site_id, 0); @@ -11534,7 +11532,7 @@ static PyGetSetDef Variant_getsetters[] static PyMethodDef Variant_methods[] = { { .ml_name = "decode", .ml_meth = (PyCFunction) Variant_decode, - .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_flags = METH_VARARGS, .ml_doc = "Sets the variant's genotypes to those of a given tree and site" }, { .ml_name = "restricted_copy", .ml_meth = (PyCFunction) Variant_restricted_copy, diff --git a/python/tests/test_genotypes.py b/python/tests/test_genotypes.py index 9146fa633b..9db1f64afa 100644 --- a/python/tests/test_genotypes.py +++ b/python/tests/test_genotypes.py @@ -191,29 +191,6 @@ def get_tree_sequence(self): assert ts.get_num_mutations() > 10 return ts - def test_as_bytes(self): - ts = self.get_tree_sequence() - n = ts.get_sample_size() - m = ts.get_num_mutations() - A = np.zeros((m, n), dtype="u1") - B = np.zeros((m, n), dtype="u1") - for variant in ts.variants(): - A[variant.index] = variant.genotypes - for variant in ts.variants(as_bytes=True): - assert isinstance(variant.genotypes, bytes) - B[variant.index] = np.frombuffer(variant.genotypes, np.uint8) - ord("0") - assert np.all(A == B) - bytes_variants = 
list(ts.variants(as_bytes=True)) - for j, variant in enumerate(bytes_variants): - assert j == variant.index - row = np.frombuffer(variant.genotypes, np.uint8) - ord("0") - assert np.all(A[j] == row) - - def test_as_bytes_fails(self): - ts = tsutil.insert_multichar_mutations(self.get_tree_sequence()) - with pytest.raises(ValueError): - list(ts.variants(as_bytes=True)) - def test_dtype(self): ts = self.get_tree_sequence() for var in ts.variants(): @@ -913,7 +890,7 @@ def test_simple_01_duplicate_alleles(self): ): assert v2.alleles == alleles assert v1.site == v2.site - g = v1.genotypes + g = np.array(v1.genotypes) index = np.where(g == 1) g[index] = 2 assert np.array_equal(g, v2.genotypes) diff --git a/python/tests/test_highlevel.py b/python/tests/test_highlevel.py index 9f29f19b30..fb9132a806 100644 --- a/python/tests/test_highlevel.py +++ b/python/tests/test_highlevel.py @@ -4196,18 +4196,6 @@ def get_instances(self, n): return [tskit.Edgeset(left=j, right=j, parent=j, children=j) for j in range(n)] -class TestVariantContainer(SimpleContainersMixin): - def get_instances(self, n): - return [ - tskit.Variant( - site=TestSiteContainer().get_instances(1)[0], - alleles=["A" * j, "T"], - genotypes=np.zeros(j, dtype=np.int8), - ) - for j in range(n) - ] - - class TestContainersAppend: def test_containers_append(self, ts_fixture): """ @@ -4262,3 +4250,14 @@ def test_macs(self): assert len(col) == n for j in range(n): assert col[j] == haplotypes[j][site_id] + + def test_macs_error(self): + tables = tskit.TableCollection(1) + tables.sites.add_row(position=0.5, ancestral_state="A") + tables.nodes.add_row(time=1, flags=tskit.NODE_IS_SAMPLE) + tables.mutations.add_row(node=0, site=0, derived_state="FOO") + ts = tables.tree_sequence() + with pytest.raises( + ValueError, match="macs output only supports single letter alleles" + ): + ts.to_macs() diff --git a/python/tests/test_topology.py b/python/tests/test_topology.py index a31ac4b077..0be896a3da 100644 --- 
a/python/tests/test_topology.py +++ b/python/tests/test_topology.py @@ -1686,9 +1686,12 @@ def assert_haplotypes_equal(self, ts1, ts2): assert h1 == h2 def assert_variants_equal(self, ts1, ts2): - v1 = list(ts1.variants(as_bytes=True)) - v2 = list(ts2.variants(as_bytes=True)) - assert v1 == v2 + for v1, v2 in zip( + ts1.variants(copy=False), + ts2.variants(copy=False), + ): + assert v1.alleles == v2.alleles + assert np.array_equal(v1.genotypes, v2.genotypes) def check_num_samples(self, ts, x): """ @@ -2541,9 +2544,12 @@ def verify_permuted_nodes(self, ts): assert ts.sequence_length == permuted.sequence_length assert list(permuted.samples()) == samples assert list(permuted.haplotypes()) == list(ts.haplotypes()) - assert [v.genotypes for v in permuted.variants(as_bytes=True)] == [ - v.genotypes for v in ts.variants(as_bytes=True) - ] + for v1, v2 in zip( + permuted.variants(copy=False), + ts.variants(copy=False), + ): + assert np.array_equal(v1.genotypes, v2.genotypes) + assert ts.num_trees == permuted.num_trees j = 0 for t1, t2 in zip(ts.trees(), permuted.trees()): @@ -3355,9 +3361,10 @@ def test_simplest_degenerate_case(self): assert t.parent_dict == {} assert sorted(t.roots) == [0, 1] assert list(ts.haplotypes(isolated_as_missing=False)) == ["10", "01"] - assert [ - v.genotypes for v in ts.variants(as_bytes=True, isolated_as_missing=False) - ] == [b"10", b"01"] + assert np.array_equal( + np.stack([v.genotypes for v in ts.variants(isolated_as_missing=False)]), + [[1, 0], [0, 1]], + ) simplified = ts.simplify() t1 = ts.dump_tables() t2 = simplified.dump_tables() @@ -3412,12 +3419,10 @@ def test_simplest_non_degenerate_case(self): t = next(ts.trees()) assert t.parent_dict == {0: 4, 1: 4, 2: 5, 3: 5} assert list(ts.haplotypes()) == ["1000", "0100", "0010", "0001"] - assert [v.genotypes for v in ts.variants(as_bytes=True)] == [ - b"1000", - b"0100", - b"0010", - b"0001", - ] + assert np.array_equal( + np.stack([v.genotypes for v in ts.variants()]), + [[1, 0, 0, 0], [0, 
1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]], + ) assert t.mrca(0, 1) == 4 assert t.mrca(0, 4) == 4 assert t.mrca(2, 3) == 5 @@ -3489,13 +3494,10 @@ def test_two_reducible_trees(self): t = next(ts.trees()) assert t.parent_dict == {0: 4, 1: 5, 2: 7, 3: 7, 4: 6, 5: 6, 8: 7} assert list(ts.haplotypes()) == ["10000", "01000", "00100", "00010"] - assert [v.genotypes for v in ts.variants(as_bytes=True)] == [ - b"1000", - b"0100", - b"0010", - b"0001", - b"0000", - ] + assert np.array_equal( + np.stack([v.genotypes for v in ts.variants()]), + [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1], [0, 0, 0, 0]], + ) assert t.mrca(0, 1) == 6 assert t.mrca(2, 3) == 7 assert t.mrca(2, 8) == 7 @@ -3508,14 +3510,11 @@ def test_two_reducible_trees(self): assert ts_simplified.num_nodes == 6 assert ts_simplified.num_trees == 1 t = next(ts_simplified.trees()) - # print(ts_simplified.tables) assert list(ts_simplified.haplotypes()) == ["1000", "0100", "0010", "0001"] - assert [v.genotypes for v in ts_simplified.variants(as_bytes=True)] == [ - b"1000", - b"0100", - b"0010", - b"0001", - ] + assert np.array_equal( + np.stack([v.genotypes for v in ts_simplified.variants()]), + [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]], + ) # The site over the non-sample external node should have been discarded. 
sites = list(t.sites()) assert sites[-1].position == 0.4 @@ -3618,15 +3617,17 @@ def test_mutations_over_roots(self): t = next(ts.trees()) assert len(list(t.sites())) == 6 haplotypes = ["101100", "011100", "000011"] - variants = [b"100", b"010", b"110", b"110", b"001", b"001"] + variants = [[1, 0, 0], [0, 1, 0], [1, 1, 0], [1, 1, 0], [0, 0, 1], [0, 0, 1]] assert list(ts.haplotypes()) == haplotypes - assert [v.genotypes for v in ts.variants(as_bytes=True)] == variants + assert np.array_equal(np.stack([v.genotypes for v in ts.variants()]), variants) ts_simplified = ts.simplify(filter_sites=False) assert list(ts_simplified.haplotypes(isolated_as_missing=False)) == haplotypes - assert variants == [ - v.genotypes - for v in ts_simplified.variants(as_bytes=True, isolated_as_missing=False) - ] + assert np.array_equal( + np.stack( + [v.genotypes for v in ts_simplified.variants(isolated_as_missing=False)] + ), + variants, + ) def test_break_single_tree(self): # Take a single largish tree from tskit, and remove the oldest record. diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 2a1f9585c3..71514396cc 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -24,6 +24,8 @@ """ Module responsible for managing trees and tree sequences. """ +from __future__ import annotations + import base64 import collections import concurrent.futures @@ -35,8 +37,6 @@ from dataclasses import dataclass from typing import Any from typing import NamedTuple -from typing import Optional -from typing import Union import numpy as np @@ -131,7 +131,7 @@ class Individual(util.Dataclass): a numpy array (dtype=np.int32). If no nodes are associated with the individual this array will be empty. """ - metadata: Optional[Union[bytes, dict]] + metadata: bytes | dict | None """ The :ref:`metadata ` for this individual, decoded if a schema applies. @@ -183,7 +183,7 @@ class Node(util.Dataclass): """ The integer ID of the individual that this node was a part of. 
""" - metadata: Optional[Union[bytes, dict]] + metadata: bytes | dict | None """ The :ref:`metadata ` for this node, decoded if a schema applies. @@ -230,7 +230,7 @@ class Edge(util.Dataclass): To obtain further information about a node with a given ID, use :meth:`TreeSequence.node`. """ - metadata: Optional[Union[bytes, dict]] + metadata: bytes | dict | None """ The :ref:`metadata ` for this edge, decoded if a schema applies. @@ -291,7 +291,7 @@ class Site(util.Dataclass): The list of mutations at this site. Mutations within a site are returned in the order they are specified in the underlying :class:`MutationTable`. """ - metadata: Optional[Union[bytes, dict]] + metadata: bytes | dict | None """ The :ref:`metadata ` for this site, decoded if a schema applies. @@ -354,7 +354,7 @@ class Mutation(util.Dataclass): To obtain further information about a mutation with a given ID, use :meth:`TreeSequence.mutation`. """ - metadata: Optional[Union[bytes, dict]] + metadata: bytes | dict | None """ The :ref:`metadata ` for this mutation, decoded if a schema applies. @@ -441,7 +441,7 @@ class Migration(util.Dataclass): """ The time at which this migration occurred at. """ - metadata: Optional[Union[bytes, dict]] + metadata: bytes | dict | None """ The :ref:`metadata ` for this migration, decoded if a schema applies. @@ -469,15 +469,14 @@ class Population(util.Dataclass): The integer ID of this population. Varies from 0 to :attr:`TreeSequence.num_populations` - 1. """ - metadata: Optional[Union[bytes, dict]] + metadata: bytes | dict | None """ The :ref:`metadata ` for this population, decoded if a schema applies. """ -@dataclass -class Variant(util.Dataclass): +class Variant: """ A variant in a tree sequence, describing the observed genetic variation among samples for a given site. 
A variant consists (a) of a reference to @@ -523,38 +522,51 @@ class Variant(util.Dataclass): As ``tskit.MISSING_DATA`` is equal to -1, code that decodes genotypes into allelic values without taking missing data into account would otherwise incorrectly output the last allele in the list. - - Modifying the attributes in this class will have **no effect** on the - underlying tree sequence data. """ - __slots__ = ["site", "alleles", "genotypes"] - site: Site - """ - The site object for this variant. - """ - alleles: tuple - """ - A tuple of the allelic values that may be observed at the - samples at the current site. The first element of this tuple is always - the site's ancestral state. - """ - genotypes: np.ndarray - """ - An array of indexes into the list ``alleles``, giving the - state of each sample at the current site. - """ + def __init__(self, tree_sequence, samples, isolated_as_missing, alleles): + self.tree_sequence = tree_sequence + self._ll_variant = _tskit.Variant( + tree_sequence._ll_tree_sequence, + samples=samples, + isolated_as_missing=isolated_as_missing, + alleles=alleles, + ) @property - def has_missing_data(self): + def site(self) -> Site: + """ + The site object for this variant. + """ + return self.tree_sequence.site(self._ll_variant.site_id) + + @property + def alleles(self) -> tuple: + """ + A tuple of the allelic values that may be observed at the + samples at the current site. The first element of this tuple is always + the site's ancestral state. + """ + return self._ll_variant.alleles + + @property + def genotypes(self) -> np.ndarray: + """ + An array of indexes into the list ``alleles``, giving the + state of each sample at the current site. + """ + return self._ll_variant.genotypes + + @property + def has_missing_data(self) -> bool: """ True if there is missing data for any of the samples at the current site. 
""" - return self.alleles[-1] is None + return self._ll_variant.alleles[-1] is None @property - def num_alleles(self): + def num_alleles(self) -> int: """ The number of distinct alleles at this site. Note that this may be greater than the number of distinct values in the genotypes @@ -564,23 +576,33 @@ def num_alleles(self): # Deprecated alias to avoid breaking existing code. @property - def position(self): + def position(self) -> float: return self.site.position # Deprecated alias to avoid breaking existing code. @property - def index(self): - return self.site.id + def index(self) -> int: + return self._ll_variant.site_id # We need a custom eq for the numpy array - def __eq__(self, other): + def __eq__(self, other) -> bool: return ( isinstance(other, Variant) - and self.site == other.site - and self.alleles == other.alleles - and np.array_equal(self.genotypes, other.genotypes) + and self.tree_sequence == other.tree_sequence + and self._ll_variant.site_id == other._ll_variant.site_id + and self._ll_variant.alleles == other._ll_variant.alleles + and np.array_equal(self._ll_variant.genotypes, other._ll_variant.genotypes) ) + def decode(self, site_id) -> None: + self._ll_variant.decode(site_id) + + def copy(self) -> Variant: + variant_copy = Variant.__new__(Variant) + variant_copy.tree_sequence = self.tree_sequence + variant_copy._ll_variant = self._ll_variant.restricted_copy() + return variant_copy + @dataclass class Edgeset(util.Dataclass): @@ -3924,7 +3946,7 @@ def num_samples(self): return self._ll_tree_sequence.get_num_samples() @property - def table_metadata_schemas(self) -> "_TableMetadataSchemas": + def table_metadata_schemas(self) -> _TableMetadataSchemas: """ The set of metadata schemas for the tables in this tree sequence. 
""" @@ -4652,11 +4674,11 @@ def haplotypes( def variants( self, *, - as_bytes=False, samples=None, isolated_as_missing=None, alleles=None, impute_missing_data=None, + copy=None, ): """ Returns an iterator over the variants (each site with its genotypes @@ -4697,14 +4719,6 @@ def variants( state (this was the default behaviour in versions prior to 0.2.0). Prior to 0.3.0 the `impute_missing_data` argument controlled this behaviour. - .. note:: - The ``as_bytes`` parameter is kept as a compatibility - option for older code. It is not the recommended way of - accessing variant data, and will be deprecated in a later - release. - - :param bool as_bytes: If True, the genotype values will be returned - as a Python bytes object. Legacy use only. :param array_like samples: An array of node IDs for which to generate genotypes, or None for all sample nodes. Default: None. :param bool isolated_as_missing: If True, the genotype value assigned to @@ -4721,6 +4735,10 @@ def variants( :param bool impute_missing_data: *Deprecated in 0.3.0. Use ``isolated_as_missing``, but inverting value. Will be removed in a future version* + :param bool copy: + If False re-use the same Variant object for each site such that any + references held to it are overwritten when the next site is visited. + If True return a fresh :class:`Variant` for each site. Default: True. :return: An iterator over all variants in this tree sequence. :rtype: iter(:class:`Variant`) """ @@ -4734,26 +4752,26 @@ def variants( # Only use impute_missing_data if isolated_as_missing has the default value if isolated_as_missing is None: isolated_as_missing = not impute_missing_data + if copy is None: + copy = True + # See comments for the Variant type for discussion on why the # present form was chosen. 
- iterator = _tskit.VariantGenerator( - self._ll_tree_sequence, + variant = tskit.Variant( + self, samples=samples, isolated_as_missing=isolated_as_missing, alleles=alleles, ) - for site_id, genotypes, alleles in iterator: - site = self.site(site_id) - if as_bytes: - if any(len(allele) > 1 for allele in alleles): - raise ValueError( - "as_bytes only supported for single-letter alleles" - ) - bytes_genotypes = np.empty(self.num_samples, dtype=np.uint8) - lookup = np.array([ord(a[0]) for a in alleles], dtype=np.uint8) - bytes_genotypes[:] = lookup[genotypes] - genotypes = bytes_genotypes.tobytes() - yield Variant(site, alleles, genotypes) + sites = range(self.num_sites) + if copy: + for site_id in sites: + variant.decode(site_id) + yield variant.copy() + else: + for site_id in sites: + variant.decode(site_id) + yield variant def genotype_matrix( self, *, isolated_as_missing=None, alleles=None, impute_missing_data=None @@ -5552,10 +5570,15 @@ def to_macs(self): m = self.get_sequence_length() output = [f"COMMAND:\tnot_macs {n} {m}"] output.append("SEED:\tASEED") - for variant in self.variants(as_bytes=True): + for variant in self.variants(copy=False): + if any(len(allele) > 1 for allele in variant.alleles): + raise ValueError("macs output only supports single letter alleles") + bytes_genotypes = np.empty(self.num_samples, dtype=np.uint8) + lookup = np.array([ord(a[0]) for a in variant.alleles], dtype=np.uint8) + bytes_genotypes[:] = lookup[variant.genotypes] + genotypes = bytes_genotypes.tobytes().decode() output.append( - f"SITE:\t{variant.index}\t{variant.position / m}\t0.0\t" - f"{variant.genotypes.decode()}" + f"SITE:\t{variant.index}\t{variant.position / m}\t0.0\t" f"{genotypes}" ) return "\n".join(output) + "\n"