From a2c01faded277ece3492d9e21dc7068f34b86427 Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Thu, 31 Mar 2022 18:34:59 +0100 Subject: [PATCH] Add copy to lowlevel python variant --- python/_tskitmodule.c | 44 ++++++++++++++++++++++++++++++++++- python/tests/test_lowlevel.py | 37 +++++++++++++++++++++++++++++ 2 files changed, 80 insertions(+), 1 deletion(-) diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index fd4064a653..f4cbdeb5fa 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -11306,6 +11306,9 @@ static PyTypeObject VariantGeneratorType = { *=================================================================== */ +/* Forward declaration */ +static PyTypeObject VariantType; + static int Variant_check_state(Variant *self) { @@ -11377,7 +11380,7 @@ Variant_init(Variant *self, PyObject *args, PyObject *kwds) goto out; } } - self->variant = PyMem_Malloc(sizeof(tsk_vargen_t)); + self->variant = PyMem_Malloc(sizeof(tsk_variant_t)); if (self->variant == NULL) { PyErr_NoMemory(); goto out; @@ -11424,6 +11427,41 @@ Variant_decode(Variant *self, PyObject *args, PyObject *kwds) return ret; } +static PyObject * +Variant_restricted_copy(Variant *self) +{ + int err; + PyObject *ret = NULL; + Variant *copy = NULL; + + if (Variant_check_state(self) != 0) { + goto out; + } + copy = (Variant *) _PyObject_New((PyTypeObject *) &VariantType); + if (copy == NULL) { + goto out; + } + copy->variant = PyMem_Malloc(sizeof(tsk_variant_t)); + if (copy->variant == NULL) { + PyErr_NoMemory(); + goto out; + } + /* Copies have no ts as a way of indicating they shouldn't be decoded + This is safe as the copy has no reference to the mutation state strings */ + copy->tree_sequence = NULL; + err = tsk_variant_restricted_copy(self->variant, copy->variant); + if (err != 0) { + handle_library_error(err); + goto out; + } + + ret = (PyObject *) copy; + copy = NULL; +out: + Py_XDECREF(copy); + return ret; +} + static PyObject * Variant_get_site_id(Variant *self, void *closure) { @@ -11498,6 +11536,10 @@ static PyMethodDef Variant_methods[] = { .ml_meth = (PyCFunction) Variant_decode, .ml_flags = METH_VARARGS | METH_KEYWORDS, .ml_doc = "Sets the variant's genotypes to those of a given tree and site" }, + { .ml_name = "restricted_copy", + .ml_meth = (PyCFunction) Variant_restricted_copy, + .ml_flags = METH_NOARGS, + .ml_doc = "Copies the variant" }, { NULL } /* Sentinel */ }; diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py index 01e7d6904d..3c09300058 100644 --- a/python/tests/test_lowlevel.py +++ b/python/tests/test_lowlevel.py @@ -2431,6 +2431,43 @@ def test_variants_lifecycle(self): del variant assert np.array_equal(genotypes, expected) + def test_copy(self): + ts = self.get_example_tree_sequence(random_seed=42) + variant = _tskit.Variant(ts) + variant.decode(0) + # Everything below should work even if the Python ts is free'd + del ts + variant2 = variant.restricted_copy() + # Take a copy for comparison, then move the variant to check the copy + # doesn't move too + genotypes = variant.genotypes + genotypes_copy = np.array(variant.genotypes) + alleles = variant.alleles + site_id = variant.site_id + variant.decode(1) + with pytest.raises( + tskit.LibraryError, match="Can't decode a copy of a variant" + ): + variant2.decode(1) + assert site_id == variant2.site_id + assert alleles == variant2.alleles + + # Variant should be equal to the copy we took earlier + assert np.array_equal(genotypes_copy, variant2.genotypes) + # But not equal to the un-copies genotypes anymore as they + # have decoded a new site as a side effect of reusing the + # array when decoding + assert not np.array_equal(genotypes, variant2.genotypes) + + # Check the lifecycle of copies and copies of copies + del variant + variant3 = variant2.restricted_copy() + del variant2 + assert np.array_equal(genotypes_copy, variant3.genotypes) + genotypes3 = variant3.genotypes + del variant3 + assert np.array_equal(genotypes_copy, genotypes3) + class TestLdCalculator(LowLevelTestCase): """