Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions c/tskit/genotypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,9 @@ tsk_variant_init(tsk_variant_t *self, const tsk_treeseq_t *tree_sequence,

tsk_memset(self, 0, sizeof(tsk_variant_t));

/* Set site id to NULL to indicate the variant is not decoded */
self->site.id = TSK_NULL;

self->tree_sequence = tree_sequence;
ret = tsk_tree_init(
&self->tree, tree_sequence, samples == NULL ? TSK_SAMPLE_LISTS : 0);
Expand Down
228 changes: 228 additions & 0 deletions python/_tskitmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,12 @@ typedef struct {
tsk_vargen_t *variant_generator;
} VariantGenerator;

typedef struct {
PyObject_HEAD
TreeSequence *tree_sequence;
tsk_variant_t *variant;
} Variant;

typedef struct {
PyObject_HEAD
TreeSequence *tree_sequence;
Expand Down Expand Up @@ -11295,6 +11301,221 @@ static PyTypeObject VariantGeneratorType = {
// clang-format on
};

/*===================================================================
* Variant
*===================================================================
*/

static int
Variant_check_state(Variant *self)
{
int ret = 0;
if (self->variant == NULL) {
PyErr_SetString(PyExc_SystemError, "variant not initialised");
ret = -1;
}
return ret;
}

static void
Variant_dealloc(Variant *self)
{
if (self->variant != NULL) {
tsk_variant_free(self->variant);
PyMem_Free(self->variant);
self->variant = NULL;
}
Py_XDECREF(self->tree_sequence);
Py_TYPE(self)->tp_free((PyObject *) self);
}

static int
Variant_init(Variant *self, PyObject *args, PyObject *kwds)
{
int ret = -1;
int err;
static char *kwlist[]
= { "tree_sequence", "samples", "isolated_as_missing", "alleles", NULL };
TreeSequence *tree_sequence = NULL;
PyObject *samples_input = Py_None;
PyObject *py_alleles = Py_None;
PyArrayObject *samples_array = NULL;
tsk_id_t *samples = NULL;
tsk_size_t num_samples = 0;
int isolated_as_missing = 1;
const char **alleles = NULL;
npy_intp *shape;
tsk_flags_t options = 0;

self->variant = NULL;
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O!|OiO", kwlist, &TreeSequenceType,
&tree_sequence, &samples_input, &isolated_as_missing, &py_alleles)) {
goto out;
}
if (!isolated_as_missing) {
options |= TSK_ISOLATED_NOT_MISSING;
}
/* tsk_variant_t holds a reference to the tree sequence so we must too*/
self->tree_sequence = tree_sequence;
Py_INCREF(self->tree_sequence);
if (TreeSequence_check_state(self->tree_sequence) != 0) {
goto out;
}
if (samples_input != Py_None) {
samples_array = (PyArrayObject *) PyArray_FROMANY(
samples_input, NPY_INT32, 1, 1, NPY_ARRAY_IN_ARRAY);
if (samples_array == NULL) {
goto out;
}
shape = PyArray_DIMS(samples_array);
num_samples = (tsk_size_t) shape[0];
samples = PyArray_DATA(samples_array);
}
if (py_alleles != Py_None) {
alleles = parse_allele_list(py_alleles);
if (alleles == NULL) {
goto out;
}
}
self->variant = PyMem_Malloc(sizeof(tsk_vargen_t));
if (self->variant == NULL) {
PyErr_NoMemory();
goto out;
}
/* Note: the variant currently takes a copy of the samples list. If we wanted
* to avoid this we would INCREF the samples array above and keep a reference
* to in the object struct */
err = tsk_variant_init(self->variant, self->tree_sequence->tree_sequence, samples,
num_samples, alleles, options);
if (err != 0) {
handle_library_error(err);
goto out;
}
ret = 0;
out:
PyMem_Free(alleles);
Py_XDECREF(samples_array);
return ret;
}

static PyObject *
Variant_decode(Variant *self, PyObject *args, PyObject *kwds)
{
int err;
PyObject *ret = NULL;
tsk_id_t site_id;
static char *kwlist[] = { "site", NULL };

if (Variant_check_state(self) != 0) {
goto out;
}
if (!PyArg_ParseTupleAndKeywords(
args, kwds, "O&", kwlist, &tsk_id_converter, &site_id)) {
goto out;
}
err = tsk_variant_decode(self->variant, site_id, 0);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should get test coverage on this - won't be covered by vargen as you can't specify the ID.

Easy to do, just pass in a bad site ID.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed with a call to site -1

if (err != 0) {
handle_library_error(err);
goto out;
}

ret = Py_BuildValue("");
out:
return ret;
}

static PyObject *
Variant_get_site_id(Variant *self, void *closure)
{
PyObject *ret = NULL;
if (Variant_check_state(self) != 0) {
goto out;
}
ret = Py_BuildValue("n", (Py_ssize_t) self->variant->site.id);
out:
return ret;
}

static PyObject *
Variant_get_alleles(Variant *self, void *closure)
{
PyObject *ret = NULL;

if (Variant_check_state(self) != 0) {
goto out;
}
ret = make_alleles(self->variant);
out:
return ret;
}

static PyObject *
Variant_get_genotypes(Variant *self, void *closure)
{
PyObject *ret = NULL;
PyArrayObject *array = NULL;
npy_intp dims;

if (Variant_check_state(self) != 0) {
goto out;
}

dims = self->variant->num_samples;
array = (PyArrayObject *) PyArray_SimpleNewFromData(
1, &dims, NPY_INT32, self->variant->genotypes);
if (array == NULL) {
goto out;
}
PyArray_CLEARFLAGS(array, NPY_ARRAY_WRITEABLE);
if (PyArray_SetBaseObject(array, (PyObject *) self) != 0) {
goto out;
}
/* PyArray_SetBaseObject steals a reference, so we have to incref the variant
* object. This makes sure that the Variant instance will stay alive if there
* are any arrays that refer to its memory. */
Py_INCREF(self);
ret = (PyObject *) array;
array = NULL;
out:
Py_XDECREF(array);
return ret;
}

static PyGetSetDef Variant_getsetters[]
= { { .name = "site_id",
.get = (getter) Variant_get_site_id,
.doc = "The site id that the Variant is decoded at" },
{ .name = "alleles",
.get = (getter) Variant_get_alleles,
.doc = "The alleles of the Variant" },
{ .name = "genotypes",
.get = (getter) Variant_get_genotypes,
.doc = "The genotypes of the Variant" },
{ NULL } };

static PyMethodDef Variant_methods[] = {
{ .ml_name = "decode",
.ml_meth = (PyCFunction) Variant_decode,
.ml_flags = METH_VARARGS | METH_KEYWORDS,
.ml_doc = "Sets the variant's genotypes to those of a given tree and site" },
{ NULL } /* Sentinel */
};

static PyTypeObject VariantType = {
// clang-format off
PyVarObject_HEAD_INIT(NULL, 0)
.tp_name = "_tskit.Variant",
.tp_basicsize = sizeof(Variant),
.tp_dealloc = (destructor) Variant_dealloc,
.tp_flags = Py_TPFLAGS_DEFAULT,
.tp_doc = "Variant objects",
.tp_methods = Variant_methods,
.tp_getset = Variant_getsetters,
.tp_init = (initproc) Variant_init,
.tp_new = PyType_GenericNew,
// clang-format on
};

/*===================================================================
* LdCalculator
*===================================================================
Expand Down Expand Up @@ -12276,6 +12497,13 @@ PyInit__tskit(void)
Py_INCREF(&VariantGeneratorType);
PyModule_AddObject(module, "VariantGenerator", (PyObject *) &VariantGeneratorType);

/* Variant type */
if (PyType_Ready(&VariantType) < 0) {
return NULL;
}
Py_INCREF(&VariantType);
PyModule_AddObject(module, "Variant", (PyObject *) &VariantType);

/* LdCalculator type */
if (PyType_Ready(&LdCalculatorType) < 0) {
return NULL;
Expand Down
99 changes: 99 additions & 0 deletions python/tests/test_lowlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2333,6 +2333,105 @@ def test_missing_data(self):
assert alleles == ("A", None)


class TestVariant(LowLevelTestCase):
"""
Tests for the Variant class.
"""

def test_uninitialised_tree_sequence(self):
ts = _tskit.TreeSequence()
with pytest.raises(ValueError):
_tskit.Variant(ts)

def test_constructor(self):
with pytest.raises(TypeError):
_tskit.Variant()
with pytest.raises(TypeError):
_tskit.Variant(None)
ts = self.get_example_tree_sequence()
with pytest.raises(ValueError):
_tskit.Variant(ts, samples={})
with pytest.raises(TypeError):
_tskit.Variant(ts, isolated_as_missing=None)
with pytest.raises(_tskit.LibraryError):
_tskit.Variant(ts, samples=[-1, 2])
with pytest.raises(TypeError):
_tskit.Variant(ts, alleles=1234)

def test_bad_decode(self):
ts = self.get_example_tree_sequence()
variant = _tskit.Variant(ts)
with pytest.raises(tskit.LibraryError, match="Site out of bounds"):
variant.decode(-1)
with pytest.raises(TypeError):
variant.decode("42")
with pytest.raises(TypeError):
variant.decode({})
with pytest.raises(TypeError):
variant.decode()

def test_alleles(self):
ts = self.get_example_tree_sequence()
for bad_type in [["a", "b"], "sdf", 234]:
with pytest.raises(TypeError):
_tskit.Variant(ts, samples=[1, 2], alleles=bad_type)
with pytest.raises(ValueError):
_tskit.Variant(ts, samples=[1, 2], alleles=tuple())

for bad_allele_type in [None, 0, b"x", []]:
with pytest.raises(TypeError):
_tskit.Variant(ts, samples=[1, 2], alleles=(bad_allele_type,))

def test_undecoded(self):
tables = _tskit.TableCollection(1)
tables.build_index()
ts = _tskit.TreeSequence(0)
ts.load_tables(tables)
variant = _tskit.Variant(ts)
assert variant.site_id == tskit.NULL
assert np.array_equal(variant.genotypes, [])
assert variant.alleles == ()

def test_properties_unwritable(self):
ts = self.get_example_tree_sequence()
variant = _tskit.Variant(ts)
with pytest.raises(AttributeError):
variant.site_id = 1
with pytest.raises(AttributeError):
variant.genotypes = [1]
with pytest.raises(AttributeError):
variant.alleles = "A"

def test_missing_data(self):
tables = _tskit.TableCollection(1)
tables.nodes.add_row(flags=1, time=0)
tables.nodes.add_row(flags=1, time=0)
tables.sites.add_row(0.1, "A")
tables.build_index()
ts = _tskit.TreeSequence(0)
ts.load_tables(tables)
variant = _tskit.Variant(ts)
variant.decode(0)
assert variant.site_id == 0
assert np.array_equal(variant.genotypes, [-1, -1])
assert variant.alleles == ("A", None)

def test_variants_lifecycle(self):
ts = self.get_example_tree_sequence(random_seed=42)
variant = _tskit.Variant(ts)
variant.decode(0)
genotypes = variant.genotypes
expected = [1, 0, 0, 1, 0, 0, 0, 0, 1, 1]
assert np.array_equal(genotypes, expected)
del variant
assert np.array_equal(genotypes, expected)
variant = _tskit.Variant(ts)
del ts
variant.decode(0)
del variant
assert np.array_equal(genotypes, expected)


class TestLdCalculator(LowLevelTestCase):
"""
Tests for the LdCalculator class.
Expand Down