From 276bec73b2f401aa20d646ad4dc7aeeae2448372 Mon Sep 17 00:00:00 2001
From: Ben Jeffery
Date: Tue, 15 Mar 2022 22:52:55 +0000
Subject: [PATCH] Add initial lowlevel Variant class

---
Review note: Variant_init allocates self->variant, which is a
tsk_variant_t *, so the allocation size is sizeof(tsk_variant_t)
rather than sizeof(tsk_vargen_t).

 c/tskit/genotypes.c           |   3 +
 python/_tskitmodule.c         | 228 ++++++++++++++++++++++++++++++++++
 python/tests/test_lowlevel.py |  99 +++++++++++++++
 3 files changed, 330 insertions(+)

diff --git a/c/tskit/genotypes.c b/c/tskit/genotypes.c
index 8595c5b963..e37642c1ca 100644
--- a/c/tskit/genotypes.c
+++ b/c/tskit/genotypes.c
@@ -142,6 +142,9 @@ tsk_variant_init(tsk_variant_t *self, const tsk_treeseq_t *tree_sequence,
 
     tsk_memset(self, 0, sizeof(tsk_variant_t));
 
+    /* Set site id to NULL to indicate the variant is not decoded */
+    self->site.id = TSK_NULL;
+
     self->tree_sequence = tree_sequence;
     ret = tsk_tree_init(
         &self->tree, tree_sequence, samples == NULL ? TSK_SAMPLE_LISTS : 0);
diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c
index 3483e2491d..fd4064a653 100644
--- a/python/_tskitmodule.c
+++ b/python/_tskitmodule.c
@@ -154,6 +154,12 @@ typedef struct {
     tsk_vargen_t *variant_generator;
 } VariantGenerator;
 
+typedef struct {
+    PyObject_HEAD
+    TreeSequence *tree_sequence;
+    tsk_variant_t *variant;
+} Variant;
+
 typedef struct {
     PyObject_HEAD
     TreeSequence *tree_sequence;
@@ -11295,6 +11301,221 @@ static PyTypeObject VariantGeneratorType = {
     // clang-format on
 };
 
+/*===================================================================
+ * Variant
+ *===================================================================
+ */
+
+static int
+Variant_check_state(Variant *self)
+{
+    int ret = 0;
+    if (self->variant == NULL) {
+        PyErr_SetString(PyExc_SystemError, "variant not initialised");
+        ret = -1;
+    }
+    return ret;
+}
+
+static void
+Variant_dealloc(Variant *self)
+{
+    if (self->variant != NULL) {
+        tsk_variant_free(self->variant);
+        PyMem_Free(self->variant);
+        self->variant = NULL;
+    }
+    Py_XDECREF(self->tree_sequence);
+    Py_TYPE(self)->tp_free((PyObject *) self);
+}
+
+static int
+Variant_init(Variant *self, PyObject *args, PyObject *kwds)
+{
+    int ret = -1;
+    int err;
+    static char *kwlist[]
+        = { "tree_sequence", "samples", "isolated_as_missing", "alleles", NULL };
+    TreeSequence *tree_sequence = NULL;
+    PyObject *samples_input = Py_None;
+    PyObject *py_alleles = Py_None;
+    PyArrayObject *samples_array = NULL;
+    tsk_id_t *samples = NULL;
+    tsk_size_t num_samples = 0;
+    int isolated_as_missing = 1;
+    const char **alleles = NULL;
+    npy_intp *shape;
+    tsk_flags_t options = 0;
+
+    self->variant = NULL;
+    if (!PyArg_ParseTupleAndKeywords(args, kwds, "O!|OiO", kwlist, &TreeSequenceType,
+            &tree_sequence, &samples_input, &isolated_as_missing, &py_alleles)) {
+        goto out;
+    }
+    if (!isolated_as_missing) {
+        options |= TSK_ISOLATED_NOT_MISSING;
+    }
+    /* tsk_variant_t holds a reference to the tree sequence so we must too*/
+    self->tree_sequence = tree_sequence;
+    Py_INCREF(self->tree_sequence);
+    if (TreeSequence_check_state(self->tree_sequence) != 0) {
+        goto out;
+    }
+    if (samples_input != Py_None) {
+        samples_array = (PyArrayObject *) PyArray_FROMANY(
+            samples_input, NPY_INT32, 1, 1, NPY_ARRAY_IN_ARRAY);
+        if (samples_array == NULL) {
+            goto out;
+        }
+        shape = PyArray_DIMS(samples_array);
+        num_samples = (tsk_size_t) shape[0];
+        samples = PyArray_DATA(samples_array);
+    }
+    if (py_alleles != Py_None) {
+        alleles = parse_allele_list(py_alleles);
+        if (alleles == NULL) {
+            goto out;
+        }
+    }
+    self->variant = PyMem_Malloc(sizeof(tsk_variant_t));
+    if (self->variant == NULL) {
+        PyErr_NoMemory();
+        goto out;
+    }
+    /* Note: the variant currently takes a copy of the samples list. If we wanted
+     * to avoid this we would INCREF the samples array above and keep a reference
+     * to in the object struct */
+    err = tsk_variant_init(self->variant, self->tree_sequence->tree_sequence, samples,
+        num_samples, alleles, options);
+    if (err != 0) {
+        handle_library_error(err);
+        goto out;
+    }
+    ret = 0;
+out:
+    PyMem_Free(alleles);
+    Py_XDECREF(samples_array);
+    return ret;
+}
+
+static PyObject *
+Variant_decode(Variant *self, PyObject *args, PyObject *kwds)
+{
+    int err;
+    PyObject *ret = NULL;
+    tsk_id_t site_id;
+    static char *kwlist[] = { "site", NULL };
+
+    if (Variant_check_state(self) != 0) {
+        goto out;
+    }
+    if (!PyArg_ParseTupleAndKeywords(
+            args, kwds, "O&", kwlist, &tsk_id_converter, &site_id)) {
+        goto out;
+    }
+    err = tsk_variant_decode(self->variant, site_id, 0);
+    if (err != 0) {
+        handle_library_error(err);
+        goto out;
+    }
+
+    ret = Py_BuildValue("");
+out:
+    return ret;
+}
+
+static PyObject *
+Variant_get_site_id(Variant *self, void *closure)
+{
+    PyObject *ret = NULL;
+    if (Variant_check_state(self) != 0) {
+        goto out;
+    }
+    ret = Py_BuildValue("n", (Py_ssize_t) self->variant->site.id);
+out:
+    return ret;
+}
+
+static PyObject *
+Variant_get_alleles(Variant *self, void *closure)
+{
+    PyObject *ret = NULL;
+
+    if (Variant_check_state(self) != 0) {
+        goto out;
+    }
+    ret = make_alleles(self->variant);
+out:
+    return ret;
+}
+
+static PyObject *
+Variant_get_genotypes(Variant *self, void *closure)
+{
+    PyObject *ret = NULL;
+    PyArrayObject *array = NULL;
+    npy_intp dims;
+
+    if (Variant_check_state(self) != 0) {
+        goto out;
+    }
+
+    dims = self->variant->num_samples;
+    array = (PyArrayObject *) PyArray_SimpleNewFromData(
+        1, &dims, NPY_INT32, self->variant->genotypes);
+    if (array == NULL) {
+        goto out;
+    }
+    PyArray_CLEARFLAGS(array, NPY_ARRAY_WRITEABLE);
+    if (PyArray_SetBaseObject(array, (PyObject *) self) != 0) {
+        goto out;
+    }
+    /* PyArray_SetBaseObject steals a reference, so we have to incref the variant
+     * object. This makes sure that the Variant instance will stay alive if there
+     * are any arrays that refer to its memory. */
+    Py_INCREF(self);
+    ret = (PyObject *) array;
+    array = NULL;
+out:
+    Py_XDECREF(array);
+    return ret;
+}
+
+static PyGetSetDef Variant_getsetters[]
+    = { { .name = "site_id",
+            .get = (getter) Variant_get_site_id,
+            .doc = "The site id that the Variant is decoded at" },
+          { .name = "alleles",
+              .get = (getter) Variant_get_alleles,
+              .doc = "The alleles of the Variant" },
+          { .name = "genotypes",
+              .get = (getter) Variant_get_genotypes,
+              .doc = "The genotypes of the Variant" },
+          { NULL } };
+
+static PyMethodDef Variant_methods[] = {
+    { .ml_name = "decode",
+        .ml_meth = (PyCFunction) Variant_decode,
+        .ml_flags = METH_VARARGS | METH_KEYWORDS,
+        .ml_doc = "Sets the variant's genotypes to those of a given tree and site" },
+    { NULL } /* Sentinel */
+};
+
+static PyTypeObject VariantType = {
+    // clang-format off
+    PyVarObject_HEAD_INIT(NULL, 0)
+    .tp_name = "_tskit.Variant",
+    .tp_basicsize = sizeof(Variant),
+    .tp_dealloc = (destructor) Variant_dealloc,
+    .tp_flags = Py_TPFLAGS_DEFAULT,
+    .tp_doc = "Variant objects",
+    .tp_methods = Variant_methods,
+    .tp_getset = Variant_getsetters,
+    .tp_init = (initproc) Variant_init,
+    .tp_new = PyType_GenericNew,
+    // clang-format on
+};
+
 /*===================================================================
  * LdCalculator
  *===================================================================
@@ -12276,6 +12497,13 @@ PyInit__tskit(void)
     Py_INCREF(&VariantGeneratorType);
     PyModule_AddObject(module, "VariantGenerator", (PyObject *) &VariantGeneratorType);
 
+    /* Variant type */
+    if (PyType_Ready(&VariantType) < 0) {
+        return NULL;
+    }
+    Py_INCREF(&VariantType);
+    PyModule_AddObject(module, "Variant", (PyObject *) &VariantType);
+
     /* LdCalculator type */
     if (PyType_Ready(&LdCalculatorType) < 0) {
         return NULL;
diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py
index b84ec22b94..01e7d6904d 100644
--- a/python/tests/test_lowlevel.py
+++ b/python/tests/test_lowlevel.py
@@ -2333,6 +2333,105 @@ def test_missing_data(self):
         assert alleles == ("A", None)
 
 
+class TestVariant(LowLevelTestCase):
+    """
+    Tests for the Variant class.
+    """
+
+    def test_uninitialised_tree_sequence(self):
+        ts = _tskit.TreeSequence()
+        with pytest.raises(ValueError):
+            _tskit.Variant(ts)
+
+    def test_constructor(self):
+        with pytest.raises(TypeError):
+            _tskit.Variant()
+        with pytest.raises(TypeError):
+            _tskit.Variant(None)
+        ts = self.get_example_tree_sequence()
+        with pytest.raises(ValueError):
+            _tskit.Variant(ts, samples={})
+        with pytest.raises(TypeError):
+            _tskit.Variant(ts, isolated_as_missing=None)
+        with pytest.raises(_tskit.LibraryError):
+            _tskit.Variant(ts, samples=[-1, 2])
+        with pytest.raises(TypeError):
+            _tskit.Variant(ts, alleles=1234)
+
+    def test_bad_decode(self):
+        ts = self.get_example_tree_sequence()
+        variant = _tskit.Variant(ts)
+        with pytest.raises(tskit.LibraryError, match="Site out of bounds"):
+            variant.decode(-1)
+        with pytest.raises(TypeError):
+            variant.decode("42")
+        with pytest.raises(TypeError):
+            variant.decode({})
+        with pytest.raises(TypeError):
+            variant.decode()
+
+    def test_alleles(self):
+        ts = self.get_example_tree_sequence()
+        for bad_type in [["a", "b"], "sdf", 234]:
+            with pytest.raises(TypeError):
+                _tskit.Variant(ts, samples=[1, 2], alleles=bad_type)
+        with pytest.raises(ValueError):
+            _tskit.Variant(ts, samples=[1, 2], alleles=tuple())
+
+        for bad_allele_type in [None, 0, b"x", []]:
+            with pytest.raises(TypeError):
+                _tskit.Variant(ts, samples=[1, 2], alleles=(bad_allele_type,))
+
+    def test_undecoded(self):
+        tables = _tskit.TableCollection(1)
+        tables.build_index()
+        ts = _tskit.TreeSequence(0)
+        ts.load_tables(tables)
+        variant = _tskit.Variant(ts)
+        assert variant.site_id == tskit.NULL
+        assert np.array_equal(variant.genotypes, [])
+        assert variant.alleles == ()
+
+    def test_properties_unwritable(self):
+        ts = self.get_example_tree_sequence()
+        variant = _tskit.Variant(ts)
+        with pytest.raises(AttributeError):
+            variant.site_id = 1
+        with pytest.raises(AttributeError):
+            variant.genotypes = [1]
+        with pytest.raises(AttributeError):
+            variant.alleles = "A"
+
+    def test_missing_data(self):
+        tables = _tskit.TableCollection(1)
+        tables.nodes.add_row(flags=1, time=0)
+        tables.nodes.add_row(flags=1, time=0)
+        tables.sites.add_row(0.1, "A")
+        tables.build_index()
+        ts = _tskit.TreeSequence(0)
+        ts.load_tables(tables)
+        variant = _tskit.Variant(ts)
+        variant.decode(0)
+        assert variant.site_id == 0
+        assert np.array_equal(variant.genotypes, [-1, -1])
+        assert variant.alleles == ("A", None)
+
+    def test_variants_lifecycle(self):
+        ts = self.get_example_tree_sequence(random_seed=42)
+        variant = _tskit.Variant(ts)
+        variant.decode(0)
+        genotypes = variant.genotypes
+        expected = [1, 0, 0, 1, 0, 0, 0, 0, 1, 1]
+        assert np.array_equal(genotypes, expected)
+        del variant
+        assert np.array_equal(genotypes, expected)
+        variant = _tskit.Variant(ts)
+        del ts
+        variant.decode(0)
+        del variant
+        assert np.array_equal(genotypes, expected)
+
+
 class TestLdCalculator(LowLevelTestCase):
     """
     Tests for the LdCalculator class.