Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 43 additions & 1 deletion python/_tskitmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -11306,6 +11306,9 @@ static PyTypeObject VariantGeneratorType = {
*===================================================================
*/

/* Forward declaration */
static PyTypeObject VariantType;

static int
Variant_check_state(Variant *self)
{
Expand Down Expand Up @@ -11377,7 +11380,7 @@ Variant_init(Variant *self, PyObject *args, PyObject *kwds)
goto out;
}
}
self->variant = PyMem_Malloc(sizeof(tsk_vargen_t));
self->variant = PyMem_Malloc(sizeof(tsk_variant_t));
if (self->variant == NULL) {
PyErr_NoMemory();
goto out;
Expand Down Expand Up @@ -11424,6 +11427,41 @@ Variant_decode(Variant *self, PyObject *args, PyObject *kwds)
return ret;
}

static PyObject *
Variant_restricted_copy(Variant *self)
{
int err;
PyObject *ret = NULL;
Variant *copy = NULL;

if (Variant_check_state(self) != 0) {
goto out;
}
copy = (Variant *) _PyObject_New((PyTypeObject *) &VariantType);
if (copy == NULL) {
goto out;
}
copy->variant = PyMem_Malloc(sizeof(tsk_variant_t));
if (copy->variant == NULL) {
PyErr_NoMemory();
goto out;
}
/* Copies have no ts as a way of indicating they shouldn't be decoded
This is safe as the copy has no reference to the mutation state strings */
copy->tree_sequence = NULL;
err = tsk_variant_restricted_copy(self->variant, copy->variant);
if (err != 0) {
handle_library_error(err);
goto out;
}

ret = (PyObject *) copy;
copy = NULL;
out:
Py_XDECREF(copy);
return ret;
}

static PyObject *
Variant_get_site_id(Variant *self, void *closure)
{
Expand Down Expand Up @@ -11498,6 +11536,10 @@ static PyMethodDef Variant_methods[] = {
.ml_meth = (PyCFunction) Variant_decode,
.ml_flags = METH_VARARGS | METH_KEYWORDS,
.ml_doc = "Sets the variant's genotypes to those of a given tree and site" },
{ .ml_name = "restricted_copy",
.ml_meth = (PyCFunction) Variant_restricted_copy,
.ml_flags = METH_NOARGS,
.ml_doc = "Copies the variant" },
{ NULL } /* Sentinel */
};

Expand Down
37 changes: 37 additions & 0 deletions python/tests/test_lowlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2431,6 +2431,43 @@ def test_variants_lifecycle(self):
del variant
assert np.array_equal(genotypes, expected)

def test_copy(self):
ts = self.get_example_tree_sequence(random_seed=42)
variant = _tskit.Variant(ts)
variant.decode(0)
# Everything below should work even if the Python ts is free'd
del ts
variant2 = variant.restricted_copy()
# Take a copy for comparison, then move the variant to check the copy
# doesn't move too
genotypes = variant.genotypes
genotypes_copy = np.array(variant.genotypes)
alleles = variant.alleles
site_id = variant.site_id
variant.decode(1)
with pytest.raises(
tskit.LibraryError, match="Can't decode a copy of a variant"
):
variant2.decode(1)
assert site_id == variant2.site_id
assert alleles == variant2.alleles

# Variant should be equal to the copy we took earlier
assert np.array_equal(genotypes_copy, variant2.genotypes)
# But not equal to the un-copies genotypes anymore as they
# have decoded a new site as a side effect of reusing the
# array when decoding
assert not np.array_equal(genotypes, variant2.genotypes)

# Check the lifecycle of copies and copies of copies
del variant
variant3 = variant2.restricted_copy()
del variant2
assert np.array_equal(genotypes_copy, variant3.genotypes)
genotypes3 = variant3.genotypes
del variant3
assert np.array_equal(genotypes_copy, genotypes3)


class TestLdCalculator(LowLevelTestCase):
"""
Expand Down