diff --git a/c/tests/test_stats.c b/c/tests/test_stats.c index 313b0e6fad..6af356b948 100644 --- a/c/tests/test_stats.c +++ b/c/tests/test_stats.c @@ -869,7 +869,8 @@ verify_one_way_stat_func_errors(tsk_treeseq_t *ts, one_way_sample_stat_method *m } static void -verify_two_way_stat_func_errors(tsk_treeseq_t *ts, general_sample_stat_method *method) +verify_two_way_stat_func_errors( + tsk_treeseq_t *ts, general_sample_stat_method *method, tsk_flags_t options) { int ret; tsk_id_t samples[] = { 0, 1, 2, 3 }; @@ -878,30 +879,30 @@ verify_two_way_stat_func_errors(tsk_treeseq_t *ts, general_sample_stat_method *m double result; ret = method(ts, 0, sample_set_sizes, samples, 1, set_indexes, 0, NULL, - TSK_STAT_SITE, &result); + options | TSK_STAT_SITE, &result); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_INSUFFICIENT_SAMPLE_SETS); ret = method(ts, 1, sample_set_sizes, samples, 1, set_indexes, 0, NULL, - TSK_STAT_SITE, &result); + options | TSK_STAT_SITE, &result); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_INSUFFICIENT_SAMPLE_SETS); ret = method(ts, 2, sample_set_sizes, samples, 0, set_indexes, 0, NULL, - TSK_STAT_SITE, &result); + options | TSK_STAT_SITE, &result); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_INSUFFICIENT_INDEX_TUPLES); set_indexes[0] = -1; ret = method(ts, 2, sample_set_sizes, samples, 1, set_indexes, 0, NULL, - TSK_STAT_SITE, &result); + options | TSK_STAT_SITE, &result); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_SAMPLE_SET_INDEX); set_indexes[0] = 0; set_indexes[1] = 2; ret = method(ts, 2, sample_set_sizes, samples, 1, set_indexes, 0, NULL, - TSK_STAT_SITE, &result); + options | TSK_STAT_SITE, &result); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_SAMPLE_SET_INDEX); } static void verify_two_way_weighted_stat_func_errors( - tsk_treeseq_t *ts, two_way_weighted_method *method) + tsk_treeseq_t *ts, two_way_weighted_method *method, tsk_flags_t options) { int ret; tsk_id_t indexes[] = { 0, 0, 0, 1 }; @@ -911,13 +912,17 @@ verify_two_way_weighted_stat_func_errors( memset(weights, 0, sizeof(weights)); - ret = method(ts, 2, weights, 2, indexes, 0, NULL, result, 0); + ret = method(ts, 2, weights, 2, indexes, 0, NULL, result, options); CU_ASSERT_EQUAL_FATAL(ret, 0); - ret = method(ts, 0, weights, 2, indexes, 0, NULL, result, 0); + ret = method(ts, 2, weights, 2, indexes, 0, NULL, result, + options | TSK_STAT_SITE | TSK_STAT_NODE); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_MULTIPLE_STAT_MODES); + + ret = method(ts, 0, weights, 2, indexes, 0, NULL, result, options); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_INSUFFICIENT_WEIGHTS); - ret = method(ts, 2, weights, 2, indexes, 1, bad_windows, result, 0); + ret = method(ts, 2, weights, 2, indexes, 1, bad_windows, result, options); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS); } @@ -1866,7 +1871,7 @@ test_paper_ex_divergence_errors(void) tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites, paper_ex_mutations, paper_ex_individuals, NULL, 0); - verify_two_way_stat_func_errors(&ts, tsk_treeseq_divergence); + verify_two_way_stat_func_errors(&ts, tsk_treeseq_divergence, 0); tsk_treeseq_free(&ts); } @@ -1914,6 +1919,11 @@ test_paper_ex_genetic_relatedness(void) ret = tsk_treeseq_genetic_relatedness(&ts, 2, sample_set_sizes, samples, 1, set_indexes, 0, NULL, TSK_STAT_SITE, &result); CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret = tsk_treeseq_genetic_relatedness(&ts, 2, sample_set_sizes, samples, 1, + set_indexes, 0, NULL, TSK_STAT_SITE | TSK_STAT_NONCENTRED, &result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_treeseq_free(&ts); } @@ -1924,7 +1934,11 @@ test_paper_ex_genetic_relatedness_errors(void) tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites, paper_ex_mutations, paper_ex_individuals, NULL, 0); - verify_two_way_stat_func_errors(&ts, tsk_treeseq_genetic_relatedness); + verify_two_way_stat_func_errors(&ts, tsk_treeseq_genetic_relatedness, 0); + verify_two_way_stat_func_errors( + &ts, tsk_treeseq_genetic_relatedness, TSK_STAT_NONCENTRED); + verify_two_way_stat_func_errors( + &ts, tsk_treeseq_genetic_relatedness, TSK_STAT_POLARISED); tsk_treeseq_free(&ts); } @@ -1951,6 +1965,15 @@ test_paper_ex_genetic_relatedness_weighted(void) ret = tsk_treeseq_genetic_relatedness_weighted( &ts, num_weights, weights, 2, indexes, 0, NULL, result, TSK_STAT_NODE); CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_treeseq_genetic_relatedness_weighted(&ts, num_weights, weights, 2, + indexes, 0, NULL, result, TSK_STAT_SITE | TSK_STAT_NONCENTRED); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_treeseq_genetic_relatedness_weighted(&ts, num_weights, weights, 2, + indexes, 0, NULL, result, TSK_STAT_BRANCH | TSK_STAT_NONCENTRED); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_treeseq_genetic_relatedness_weighted(&ts, num_weights, weights, 2, + indexes, 0, NULL, result, TSK_STAT_NODE | TSK_STAT_NONCENTRED); + CU_ASSERT_EQUAL_FATAL(ret, 0); } tsk_treeseq_free(&ts); @@ -1964,7 +1987,11 @@ test_paper_ex_genetic_relatedness_weighted_errors(void) tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites, paper_ex_mutations, paper_ex_individuals, NULL, 0); verify_two_way_weighted_stat_func_errors( - &ts, tsk_treeseq_genetic_relatedness_weighted); + &ts, tsk_treeseq_genetic_relatedness_weighted, 0); + verify_two_way_weighted_stat_func_errors( + &ts, tsk_treeseq_genetic_relatedness_weighted, TSK_STAT_NONCENTRED); + verify_two_way_weighted_stat_func_errors( + &ts, tsk_treeseq_genetic_relatedness_weighted, TSK_STAT_POLARISED); tsk_treeseq_free(&ts); } @@ -1975,7 +2002,7 @@ test_paper_ex_Y2_errors(void) tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites, paper_ex_mutations, paper_ex_individuals, NULL, 0); - verify_two_way_stat_func_errors(&ts, tsk_treeseq_Y2); + verify_two_way_stat_func_errors(&ts, tsk_treeseq_Y2, 0); tsk_treeseq_free(&ts); } @@ -2013,7 +2040,7 @@ test_paper_ex_f2_errors(void) tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites, paper_ex_mutations, paper_ex_individuals, NULL, 0); - verify_two_way_stat_func_errors(&ts, tsk_treeseq_f2); + verify_two_way_stat_func_errors(&ts, tsk_treeseq_f2, 0); tsk_treeseq_free(&ts); } diff --git a/c/tskit/trees.c b/c/tskit/trees.c index 044d4a488d..f989a73146 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -1297,6 +1297,8 @@ tsk_treeseq_branch_general_stat(const tsk_treeseq_t *self, tsk_size_t state_dim, double *state = tsk_calloc(num_nodes * state_dim, sizeof(*state)); double *summary = tsk_calloc(num_nodes * result_dim, sizeof(*summary)); double *running_sum = tsk_calloc(result_dim, sizeof(*running_sum)); + double *zero_state = tsk_calloc(state_dim, sizeof(*zero_state)); + double *zero_summary = tsk_calloc(result_dim, sizeof(*zero_state)); if (self->time_uncalibrated && !(options & TSK_STAT_ALLOW_TIME_UNCALIBRATED)) { ret = TSK_ERR_TIME_UNCALIBRATED; @@ -1304,12 +1306,21 @@ tsk_treeseq_branch_general_stat(const tsk_treeseq_t *self, tsk_size_t state_dim, } if (parent == NULL || branch_length == NULL || state == NULL || running_sum == NULL - || summary == NULL) { + || summary == NULL || zero_state == NULL || zero_summary == NULL) { ret = TSK_ERR_NO_MEMORY; goto out; } tsk_memset(parent, 0xff, num_nodes * sizeof(*parent)); + /* If f is not strict, we may need to set conditions for non-sample nodes as well. */ + ret = f(state_dim, zero_state, result_dim, zero_summary, f_params); + if (ret != 0) { + goto out; + } + for (j = 0; j < num_nodes; j++) { // we could skip this if zero_summary is zero + summary_u = GET_2D_ROW(summary, result_dim, j); + tsk_memcpy(summary_u, zero_summary, result_dim * sizeof(*zero_summary)); + } /* Set the initial conditions */ for (j = 0; j < self->num_samples; j++) { u = self->samples[j]; @@ -1322,6 +1333,7 @@ tsk_treeseq_branch_general_stat(const tsk_treeseq_t *self, tsk_size_t state_dim, goto out; } } + tsk_memset(result, 0, num_windows * result_dim * sizeof(*result)); /* Iterate over the trees */ @@ -1425,6 +1437,8 @@ tsk_treeseq_branch_general_stat(const tsk_treeseq_t *self, tsk_size_t state_dim, tsk_safe_free(state); tsk_safe_free(summary); tsk_safe_free(running_sum); + tsk_safe_free(zero_state); + tsk_safe_free(zero_summary); return ret; } @@ -2072,6 +2086,7 @@ typedef struct { } sample_count_stat_params_t; typedef struct { + tsk_size_t num_samples; double *total_weights; const tsk_id_t *index_tuples; } indexed_weight_stat_params_t; @@ -4542,21 +4557,39 @@ genetic_relatedness_summary_func(tsk_size_t state_dim, const double *state, tsk_id_t i, j; tsk_size_t k; double sumx = 0; - double sumn = 0; double meanx, ni, nj; for (k = 0; k < state_dim; k++) { - sumx += x[k]; - sumn += (double) args.sample_set_sizes[k]; + sumx += x[k] / (double) args.sample_set_sizes[k]; } - meanx = sumx / sumn; + meanx = sumx / (double) state_dim; for (k = 0; k < result_dim; k++) { i = args.set_indexes[2 * k]; j = args.set_indexes[2 * k + 1]; ni = (double) args.sample_set_sizes[i]; nj = (double) args.sample_set_sizes[j]; - result[k] = (x[i] - ni * meanx) * (x[j] - nj * meanx) / 2; + result[k] = (x[i] / ni - meanx) * (x[j] / nj - meanx); + } + return 0; +} + +static int +genetic_relatedness_noncentred_summary_func(tsk_size_t TSK_UNUSED(state_dim), + const double *state, tsk_size_t result_dim, double *result, void *params) +{ + sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; + const double *x = state; + tsk_id_t i, j; + tsk_size_t k; + double ni, nj; + + for (k = 0; k < result_dim; k++) { + i = args.set_indexes[2 * k]; + j = args.set_indexes[2 * k + 1]; + ni = (double) args.sample_set_sizes[i]; + nj = (double) args.sample_set_sizes[j]; + result[k] = x[i] * x[j] / (ni * nj); } return 0; } @@ -4572,9 +4605,16 @@ tsk_treeseq_genetic_relatedness(const tsk_treeseq_t *self, tsk_size_t num_sample if (ret != 0) { goto out; } - ret = tsk_treeseq_sample_count_stat(self, num_sample_sets, sample_set_sizes, - sample_sets, num_index_tuples, index_tuples, genetic_relatedness_summary_func, - num_windows, windows, options, result); + if (!(options & TSK_STAT_NONCENTRED)) { + ret = tsk_treeseq_sample_count_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_index_tuples, index_tuples, + genetic_relatedness_summary_func, num_windows, windows, options, result); + } else { + ret = tsk_treeseq_sample_count_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_index_tuples, index_tuples, + genetic_relatedness_noncentred_summary_func, num_windows, windows, options, + result); + } out: return ret; } @@ -4587,15 +4627,32 @@ genetic_relatedness_weighted_summary_func(tsk_size_t state_dim, const double *st const double *x = state; tsk_id_t i, j; tsk_size_t k; - double meanx, ni, nj; + double pn, ni, nj; - meanx = state[state_dim - 1] / args.total_weights[state_dim - 1]; + pn = state[state_dim - 1]; for (k = 0; k < result_dim; k++) { i = args.index_tuples[2 * k]; j = args.index_tuples[2 * k + 1]; ni = args.total_weights[i]; nj = args.total_weights[j]; - result[k] = (x[i] - ni * meanx) * (x[j] - nj * meanx) / 2; + result[k] = (x[i] - ni * pn) * (x[j] - nj * pn); + } + return 0; +} + +static int +genetic_relatedness_weighted_noncentred_summary_func(tsk_size_t TSK_UNUSED(state_dim), + const double *state, tsk_size_t result_dim, double *result, void *params) +{ + indexed_weight_stat_params_t args = *(indexed_weight_stat_params_t *) params; + const double *x = state; + tsk_id_t i, j; + tsk_size_t k; + + for (k = 0; k < result_dim; k++) { + i = args.index_tuples[2 * k]; + j = args.index_tuples[2 * k + 1]; + result[k] = x[i] * x[j]; } return 0; } @@ -4633,17 +4690,26 @@ tsk_treeseq_genetic_relatedness_weighted(const tsk_treeseq_t *self, new_row[k] = row[k]; total_weights[k] += row[k]; } - new_row[num_weights] = 1.0; + new_row[num_weights] = 1.0 / (double) num_samples; } - total_weights[num_weights] = (double) num_samples; + total_weights[num_weights] = 1.0; args.total_weights = total_weights; args.index_tuples = index_tuples; - ret = tsk_treeseq_general_stat(self, num_weights + 1, new_weights, num_index_tuples, - genetic_relatedness_weighted_summary_func, &args, num_windows, windows, options, - result); - if (ret != 0) { - goto out; + if (!(options & TSK_STAT_NONCENTRED)) { + ret = tsk_treeseq_general_stat(self, num_weights + 1, new_weights, + num_index_tuples, genetic_relatedness_weighted_summary_func, &args, + num_windows, windows, options, result); + if (ret != 0) { + goto out; + } + } else { + ret = tsk_treeseq_general_stat(self, num_weights + 1, new_weights, + num_index_tuples, genetic_relatedness_weighted_noncentred_summary_func, + &args, num_windows, windows, options, result); + if (ret != 0) { + goto out; + } } out: diff --git a/c/tskit/trees.h b/c/tskit/trees.h index 4bb2b58ac2..667848415b 100644 --- a/c/tskit/trees.h +++ b/c/tskit/trees.h @@ -53,6 +53,7 @@ extern "C" { #define TSK_STAT_SPAN_NORMALISE (1 << 11) #define TSK_STAT_ALLOW_TIME_UNCALIBRATED (1 << 12) #define TSK_STAT_PAIR_NORMALISE (1 << 13) +#define TSK_STAT_NONCENTRED (1 << 14) /* Options for map_mutations */ #define TSK_MM_FIXED_ANCESTRAL_STATE (1 << 0) diff --git a/docs/data-model.md b/docs/data-model.md index 1951a3892e..406177f8ca 100644 --- a/docs/data-model.md +++ b/docs/data-model.md @@ -1095,3 +1095,39 @@ See the {meth}`TreeSequence.variants` method and {class}`Variant` class for more information on how missing data is represented in variant data. +(sec_gotchas)= + +## Possibly surprising consequences of the data model + +This is a section of miscellaneous issues that might trip even an experienced user up, +also known as "gotchas". +The current examples are quite uncommon, so can be ignored for most purposes, +but the list may be expanded in the future. + +### Unrelated material + +Usually, all parts of a tree sequence are ancestral to at least one sample, +since that's essentially the definition of a sample: the genomes that +we're describing the ancestry of. +However, in some cases there will be portions of the tree sequence from which +no samples inherit - notably, the result of a forwards simulation that has +not been simplified. +In fact, if the simulation has not coalesced, +one can have entire portions of some marginal tree that are +unrelated to any of the samples +(for instance, an individual in the initial generation of the simulation +that had no offspring). +This can lead to a gotcha: +the *roots* of a tree are defined to be only those roots *reachable from the samples* +(and, furthermore, reachable from at least `root_threshold` samples; +see {meth}`TreeSequence.trees`). +So, our unlucky ancestor would not appear in the list of `roots`, even though +if we drew all the relationships provided by the tree sequence, +they'd definitely be a root. +Furthermore, only nodes *reachable from a root* are included in the +{meth}`Tree.nodes`. So, if you iterate over all the nodes in each marginal tree, +you won't see those parts of the tree sequence that are unrelated to the samples. +If you need to get those, too, you could either +work with the {meth}`TreeSequence.edge_diffs` directly, +or iterate over all nodes (instead of over {meth}`Tree.nodes`). + diff --git a/docs/stats.md b/docs/stats.md index 064cbb6c39..33b41ae20a 100644 --- a/docs/stats.md +++ b/docs/stats.md @@ -556,6 +556,26 @@ associated with each allele; but if polarised, then the ancestral allele is left For branch or node statistics, summary functions are applied to the total weight or number of samples below, and above each branch or node; if polarised, then only the weight below is used. +(sec_stats_strictness)= + +### Strictness, and which branches count? + +Most statistics are not affected by invariant sites, +and hence do not depend on any part of the tree that is not ancestral to any of the sample sets. +However, some statistics are different: for instance, +given a pair of samples, {meth}`TreeSequence.genetic_relatedness` +with `centre=False` (and `polarised=True`, the default for that method) +adds up the total number of alleles (or total area of branches) that is +either ancestral to both samples *or ancestral to neither*. +So, it depends on what else is in the tree sequence. +(For this reason, we don't recommend actually *using* this combination of options for genetic +relatedness.) + +In terms of the summary function {math}`f(x)`, "not affected by invariant sites" translates to +{math}`f(0) = f(n) = 0`, where {math}`n` is the vector of sample set sizes. +By default, {meth}`TreeSequence.general_stat` checks if the summary function satisfies this condition, +and throws an error if not; this check can be disabled by setting `strict=False`. + (sec_stats_summary_functions)= @@ -585,21 +605,34 @@ and boolean expressions (e.g., {math}`(x > 0)`) are interpreted as 0/1. unless the two indices are the same, when the diversity function is used. - For an unpolarized statistic with biallelic loci, this calculates + For an unpolarised statistic with biallelic loci, this calculates {math}`p_1 (1-p_2) + (1 - p_1) p_2`. -`genetic_relatedness` -: {math}`f(x_i, x_j) = \frac{1}{2}(x_i - m)(x_j - m)`, +`genetic_relatedness, centre=True` +: {math}`f(x_i, x_j) = (x_i / n_i - m)(x_j / n_j - m)`, where {math}`m = \frac{1}{n}\sum_{k=1}^n x_k` with {math}`n` the total number - of samples. + of sample sets. + For a polarised statistic (the default) with biallelic loci, this calculates + {math}`(p_1 - \bar{p}) (p_2 - \bar{p})`, where {math}`\bar{p}` is the average + derived allele frequency across sample sets. + +`genetic_relatedness, centre=False` +: {math}`f(x_i, x_j) = (x_i / n_i) (x_j / n_j)`. + + For an polarised statistic (the default) with biallelic loci, this calculates + {math}`p_1 p_2`. + +`genetic_relatedness_weighted, centre=True` +: {math}`f(w_i, w_j, x_i, x_j) = (x_i - w_i p) (x_j - w_j p)`, + + where {math}`p` is the proportion of all samples below the focal node, + and {math}`w_j = \sum_{k=1}^n W_{kj}` is the sum of the weights in the {math}`j`th column of the weight matrix. + +`genetic_relatedness_weighted, centre=False` +: {math}`f(w_i, w_j, x_i, x_j) = x_i x_j`. -`genetic_relatedness_weighted` -: {math}`f(w_i, w_j, x_i, x_j) = \frac{1}{2}(x_i - w_i m) (x_j - w_j m)`, - where {math}`m = \frac{1}{n}\sum_{k=1}^n x_k` with {math}`n` the total number - of samples, and {math}`w_j = \sum_{k=1}^n W_kj` is the sum of the weights in the {math}`j`th column of the weight matrix. - `Y2` : {math}`f(x_1, x_2) = \frac{x_1 (n_2 - x_2) (n_2 - x_2 - 1)}{n_1 n_2 (n_2 - 1)}` @@ -772,4 +805,3 @@ do not have this property (since both are ratios of statistics that do have this The {meth}`~TreeSequence.genealogical_nearest_neighbours` statistic is not based on branch lengths, but on topologies. therefore it currently has a slightly different interface to the other single site statistics. This may be revised in the future. - diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index e715182163..f3bc0e1629 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -2,6 +2,28 @@ [0.5.8] - 2024-XX-XX -------------------- +**Breaking Changes** + +- The definition of ``TreeSequence.genetic_relatedness`` and + ``TreeSequence.genetic_relatedness_weighted`` are changed + to *average* over sample sets, rather than summing over them. + For computation with diploid sample sets, this will change the result + by a factor of four; for larger sample sets it will now produce + sensible values that are comparable between sample sets of different sizes. + The default for these methods is also changed to ``polarised=True``, + but the output is unchanged for ``centre=True`` (the default). + See the documentation for these methods for more discussion. + (:user:`petrelharp`, :user:`mmosmond`, :pr:`1623`) + +**Bugfixes** + +- Fix to ``TreeSequence.genetic_relatedness`` with ``indexes=None`` and + ``proportion=True``. (:user:`petrelharp`, :issue:`2984`, :pr:`1623`) + +- Fix to ``TreeSequence.general_stat`` when using non-strict summary functions + in the presence of non-ancestral material (very rare). + (:user:`petrelharp`, :issue:`2983`, :pr:`1623`) + **Features** - Add ``TreeSequence.extend_edges`` method that extends ancestral haplotypes @@ -11,6 +33,8 @@ - Add ``Table.drop_metadata`` to make clearing metadata from tables easy. (:user:`jeromekelleher`, :pr:`2944`) +- Add the ``centre`` option to ``TreeSequence.genetic_relatedness`` and + ``TreeSequence.genetic_relatedness_weighted``. -------------------- [0.5.7] - 2024-06-17 diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index f0474e7063..a37d8b0160 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -9560,7 +9560,7 @@ TreeSequence_k_way_stat_method(TreeSequence *self, PyObject *args, PyObject *kwd { PyObject *ret = NULL; static char *kwlist[] = { "sample_set_sizes", "sample_sets", "indexes", "windows", - "mode", "span_normalise", "polarised", NULL }; + "mode", "span_normalise", "polarised", "centre", NULL }; PyObject *sample_set_sizes = NULL; PyObject *sample_sets = NULL; PyObject *indexes = NULL; @@ -9576,13 +9576,15 @@ TreeSequence_k_way_stat_method(TreeSequence *self, PyObject *args, PyObject *kwd char *mode = NULL; int span_normalise = true; int polarised = false; + int centre = true; int err; if (TreeSequence_check_state(self) != 0) { goto out; } - if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOOO|sii", kwlist, &sample_set_sizes, - &sample_sets, &indexes, &windows, &mode, &span_normalise, &polarised)) { + if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOOO|siii", kwlist, &sample_set_sizes, + &sample_sets, &indexes, &windows, &mode, &span_normalise, &polarised, + ¢re)) { goto out; } if (parse_stats_mode(mode, &options) != 0) { @@ -9594,6 +9596,10 @@ TreeSequence_k_way_stat_method(TreeSequence *self, PyObject *args, PyObject *kwd if (polarised) { options |= TSK_STAT_POLARISED; } + if (!centre) { + // only currently used by genetic_relatedness + options |= TSK_STAT_NONCENTRED; + } if (parse_sample_sets(sample_set_sizes, &sample_set_sizes_array, sample_sets, &sample_sets_array, &num_sample_sets) != 0) { @@ -9646,7 +9652,7 @@ TreeSequence_k_way_weighted_stat_method(TreeSequence *self, PyObject *args, { PyObject *ret = NULL; static char *kwlist[] = { "weights", "indexes", "windows", "mode", "span_normalise", - "polarised", NULL }; + "polarised", "centre", NULL }; PyObject *weights = NULL; PyObject *indexes = NULL; PyObject *windows = NULL; @@ -9660,13 +9666,14 @@ TreeSequence_k_way_weighted_stat_method(TreeSequence *self, PyObject *args, char *mode = NULL; int span_normalise = true; int polarised = false; + int centre = true; int err; if (TreeSequence_check_state(self) != 0) { goto out; } - if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOO|sii", kwlist, &weights, &indexes, - &windows, &mode, &span_normalise, &polarised)) { + if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOO|siii", kwlist, &weights, &indexes, + &windows, &mode, &span_normalise, &polarised, ¢re)) { goto out; } if (parse_stats_mode(mode, &options) != 0) { @@ -9678,6 +9685,10 @@ TreeSequence_k_way_weighted_stat_method(TreeSequence *self, PyObject *args, if (polarised) { options |= TSK_STAT_POLARISED; } + if (!centre) { + // only currently used by genetic_relatedness_weighted + options |= TSK_STAT_NONCENTRED; + } if (parse_windows(windows, &windows_array, &num_windows) != 0) { goto out; } diff --git a/python/tests/test_divmat.py b/python/tests/test_divmat.py index ea83cc560d..ac66f43f1e 100644 --- a/python/tests/test_divmat.py +++ b/python/tests/test_divmat.py @@ -1348,6 +1348,16 @@ def test_bad_arg_types(self, arg): class TestGeneticRelatednessMatrix: def check(self, ts, mode, *, sample_sets=None, windows=None, span_normalise=True): + # These are *only* expected to be the same + # under infinite-sites mutations + if mode == "site" and np.any([len(s.mutations) > 1 for s in ts.sites()]): + ts = msprime.sim_mutations( + ts, + rate=100 / ts.segregating_sites(mode="branch", span_normalise=False), + random_seed=123, + discrete_genome=False, + keep=False, + ) G1 = stats_api_genetic_relatedness_matrix( ts, mode=mode, @@ -1385,8 +1395,7 @@ def test_single_tree_sample_sets(self, mode): # 0 1 ts = tskit.Tree.generate_balanced(4).tree_sequence ts = tsutil.insert_branch_sites(ts) - with pytest.raises(ValueError, match="2888"): - self.check(ts, mode, sample_sets=[[0, 1], [2, 3]]) + self.check(ts, mode, sample_sets=[[0, 1], [2, 3]]) @pytest.mark.parametrize("mode", DIVMAT_MODES) def test_single_tree_single_samples(self, mode): @@ -1425,7 +1434,6 @@ def test_suite_defaults(self, ts, mode): def test_suite_span_normalise(self, ts, mode, span_normalise): self.check(ts, mode=mode, span_normalise=span_normalise) - @pytest.mark.skip("fix sample sets #2888") @pytest.mark.parametrize("ts", get_example_tree_sequences()) @pytest.mark.parametrize("mode", DIVMAT_MODES) @pytest.mark.parametrize("num_sets", [2]) # [[2, 3, 4, 5]) diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py index 48eabc133c..908c0dcb1d 100644 --- a/python/tests/test_lowlevel.py +++ b/python/tests/test_lowlevel.py @@ -2602,6 +2602,24 @@ def get_method(self): ts = self.get_example_tree_sequence() return ts, ts.genetic_relatedness + def test_options(self): + ts, _, params = self.get_example() + x = ts.genetic_relatedness(**params) + new_params = params.copy() + new_params["centre"] = False + y = ts.genetic_relatedness(**new_params) + assert x.shape == y.shape + new_params["polarised"] = False + y = ts.genetic_relatedness(**new_params) + assert x.shape == y.shape + del new_params["centre"] + y = ts.genetic_relatedness(**new_params) + assert x.shape == y.shape + + del new_params["indexes"] + with pytest.raises(ValueError, match="object of too small depth"): + ts.genetic_relatedness(**new_params, indexes="foo") + class TestY3(LowLevelTestCase, ThreeWaySampleStatsMixin): def get_method(self): @@ -2626,6 +2644,27 @@ def get_method(self): ts = self.get_example_tree_sequence() return ts, ts.genetic_relatedness_weighted + def test_options(self): + ts, _, params = self.get_example() + x = ts.genetic_relatedness_weighted(**params) + + new_params = params.copy() + new_params["centre"] = False + y = ts.genetic_relatedness_weighted(**new_params) + assert x.shape == y.shape + new_params["polarised"] = False + y = ts.genetic_relatedness_weighted(**new_params) + assert x.shape == y.shape + del new_params["centre"] + y = ts.genetic_relatedness_weighted(**new_params) + assert x.shape == y.shape + + del new_params["weights"] + with pytest.raises(ValueError, match="First dimension"): + ts.genetic_relatedness_weighted( + **new_params, weights=np.ones((ts.get_num_samples() + 2, 1)) + ) + class TestGeneralStatsInterface(LowLevelTestCase, StatsInterfaceMixin): """ diff --git a/python/tests/test_tree_stats.py b/python/tests/test_tree_stats.py index 4b608b5d75..2a3b9041f8 100644 --- a/python/tests/test_tree_stats.py +++ b/python/tests/test_tree_stats.py @@ -67,14 +67,14 @@ def subset_combos(*args, p=0.5, min_tests=3): # of them, using this function, below. If we don't set a seed, a different # random set is run each time. Ensures that at least min_tests are run. # Uncomment this line to run all tests (takes about an hour): - # p = 1.0 + p = 1.0 num_tests = 0 skipped_tests = [] # total_tests = 0 for x in itertools.product(*args): # total_tests = total_tests + 1 if np.random.uniform() < p: - num_tests += num_tests + 1 + num_tests += 1 yield x elif len(skipped_tests) < min_tests: skipped_tests.append(x) @@ -82,7 +82,7 @@ def subset_combos(*args, p=0.5, min_tests=3): skipped_tests[np.random.randint(min_tests)] = x while num_tests < min_tests: yield skipped_tests.pop() - num_tests = num_tests + 1 + num_tests += 1 # print("tests", num_tests) assert num_tests >= min_tests @@ -147,6 +147,8 @@ def windowed_tree_stat(ts, stat, windows, span_normalise=True): def naive_branch_general_stat( ts, w, f, windows=None, polarised=False, span_normalise=True ): + # NOTE: does not behave correctly for unpolarised stats + # with non-ancestral material. if windows is None: windows = [0.0, ts.sequence_length] n, k = w.shape @@ -211,7 +213,7 @@ def polarised_summary(u): s += summary_func(total_weight - state[u]) return s - for u in ts.samples(): + for u in range(ts.num_nodes): summary[u] = polarised_summary(u) window_index = 0 @@ -1794,14 +1796,19 @@ class TestSiteDivergence(TestDivergence, MutatedTopologyExamplesMixin): def site_genetic_relatedness( - ts, sample_sets, indexes, windows=None, span_normalise=True, proportion=True + ts, + sample_sets, + indexes, + windows=None, + span_normalise=True, + polarised=True, + proportion=True, + centre=True, ): + if windows is None: + windows = [0.0, ts.sequence_length] out = np.zeros((len(windows) - 1, len(indexes))) - samples = [u for u in ts.samples()] - all_samples = list({u for s in sample_sets for u in s}) - sample_ind = [samples.index(x) for x in all_samples] - haps = ts.genotype_matrix(isolated_as_missing=False).T - haps = haps[sample_ind] + all_samples = np.array(list({u for s in sample_sets for u in s})) denom = np.ones(len(windows)) if proportion: denom = ts.segregating_sites( @@ -1810,42 +1817,50 @@ def site_genetic_relatedness( mode="site", span_normalise=span_normalise, ) - alleles = np.unique(haps) for j in range(len(windows) - 1): begin = windows[j] end = windows[j + 1] - site_positions = [x.position for x in ts.sites()] - for i, (ix, iy) in enumerate(indexes): - X = sample_sets[ix] - Y = sample_sets[iy] - S = 0 + for vv in zip( + *[ + ts.variants(left=begin, right=end, samples=x, isolated_as_missing=False) + for x in sample_sets + ] + ): + ancestral_state = vv[0].site.ancestral_state + alleles = vv[0].alleles + ff = [v.frequencies() for v in vv] for a in alleles: - this_haps = haps == a - haps_mean = this_haps.mean(axis=0) - haps_centered = this_haps - haps_mean - for k in range(ts.num_sites): - if (site_positions[k] >= begin) and (site_positions[k] < end): - for x in X: - x_index = np.where(all_samples == x)[0][0] - for y in Y: - y_index = np.where(all_samples == y)[0][0] - S += ( - haps_centered[x_index][k] - * haps_centered[y_index][k] - / 2 - ) + mean_f = sum([f[a] for f in ff]) / len(ff) + for i, (ix, iy) in enumerate(indexes): + fx = ff[ix][a] + fy = ff[iy][a] + if not (polarised and a == ancestral_state): + if centre: + out[j][i] += (fx - mean_f) * (fy - mean_f) + else: + out[j][i] += fx * fy + for i in range(len(indexes)): with np.errstate(invalid="ignore", divide="ignore"): - out[j][i] = S / denom[j] + out[j][i] /= denom[j] if span_normalise: out[j][i] /= end - begin return out def branch_genetic_relatedness( - ts, sample_sets, indexes, windows=None, span_normalise=True, proportion=True + ts, + sample_sets, + indexes, + windows=None, + span_normalise=True, + polarised=True, + proportion=True, + centre=True, ): + if windows is None: + windows = [0.0, ts.sequence_length] out = np.zeros((len(windows) - 1, len(indexes))) - all_samples = list({u for s in sample_sets for u in s}) + all_samples = np.array(list({u for s in sample_sets for u in s})) denom = np.ones(len(windows)) if proportion: denom = ts.segregating_sites( @@ -1862,26 +1877,30 @@ def branch_genetic_relatedness( continue if tr.interval.left >= end: break - branches = [(c, tr.parent(c)) for c in tr.nodes()] span = min(end, tr.interval.right) - max(begin, tr.interval.left) - for B in branches: - v = B[0] + # iterating over tr.nodes will miss nodes unreachable from samples + for v in range(ts.num_nodes): area = tr.branch_length(v) * span - haps = np.zeros(len(all_samples)) - for x, u in enumerate(all_samples): - haps[x] = int(tr.is_descendant(u, v)) - haps_mean = haps.mean() - haps_centered = haps - haps_mean + freqs = [ + sum([tr.is_descendant(u, v) for u in x]) / len(x) + for x in sample_sets + ] + mean_freq = sum(freqs) / len(freqs) for i, (ix, iy) in enumerate(indexes): - X = sample_sets[ix] - Y = sample_sets[iy] - for x in X: - x_index = np.where(all_samples == x)[0][0] - for y in Y: - y_index = np.where(all_samples == y)[0][0] + fx = freqs[ix] + fy = freqs[iy] + if centre: + out[j][i] += area * (fx - mean_freq) * (fy - mean_freq) + if not polarised: out[j][i] += ( - area * haps_centered[x_index] * haps_centered[y_index] + area + * (1 - fx - (1 - mean_freq)) + * (1 - fy - (1 - mean_freq)) ) + else: + out[j][i] += area * fx * fy + if not polarised: + out[j][i] += area * (1 - fx) * (1 - fy) for i in range(len(indexes)): with np.errstate(invalid="ignore", divide="ignore"): out[j][i] /= denom[j] @@ -1891,10 +1910,19 @@ def branch_genetic_relatedness( def node_genetic_relatedness( - ts, sample_sets, indexes, windows=None, span_normalise=True, proportion=True + ts, + sample_sets, + indexes, + windows=None, + span_normalise=True, + proportion=True, + centre=True, + polarised=True, ): + if windows is None: + windows = [0.0, ts.sequence_length] out = np.zeros((len(windows) - 1, ts.num_nodes, len(indexes))) - all_samples = list({u for s in sample_sets for u in s}) + all_samples = np.array(list({u for s in sample_sets for u in s})) denom = np.ones((len(windows), ts.num_nodes)) if proportion: denom = ts.segregating_sites( @@ -1907,27 +1935,32 @@ def node_genetic_relatedness( begin = windows[j] end = windows[j + 1] for tr in ts.trees(): - span = min(end, tr.interval.right) - max(begin, tr.interval.left) if tr.interval.right <= begin: continue if tr.interval.left >= end: break - for v in tr.nodes(): - haps = np.zeros(len(all_samples)) - for x, u in enumerate(all_samples): - haps[x] = int(tr.is_descendant(u, v)) - haps_mean = haps.mean() - haps_centered = haps - haps_mean + span = min(end, tr.interval.right) - max(begin, tr.interval.left) + for v in range(ts.num_nodes): + freqs = [ + sum([tr.is_descendant(u, v) for u in x]) / len(x) + for x in sample_sets + ] + mean_freq = sum(freqs) / len(freqs) for i, (ix, iy) in enumerate(indexes): - X = sample_sets[ix] - Y = sample_sets[iy] - for x in X: - x_index = np.where(all_samples == x)[0][0] - for y in Y: - y_index = np.where(all_samples == y)[0][0] + fx = freqs[ix] + fy = freqs[iy] + if centre: + out[j][v][i] += span * (fx - mean_freq) * (fy - mean_freq) + if not polarised: out[j][v][i] += ( - haps_centered[x_index] * haps_centered[y_index] * span + span + * (1 - fx - (1 - mean_freq)) + * (1 - fy - (1 - mean_freq)) ) + else: + out[j][v][i] += span * fx * fy + if not polarised: + out[j][v][i] += span * (1 - fx) * (1 - fy) for i in range(len(indexes)): for v in ts.nodes(): iV = v.id @@ -1946,6 +1979,8 @@ def genetic_relatedness( mode="site", span_normalise=True, proportion=True, + centre=True, + polarised=True, ): """ Computes genetic relatedness between two random choices from x @@ -1965,7 +2000,9 @@ def genetic_relatedness( indexes=indexes, windows=windows, span_normalise=span_normalise, + polarised=polarised, proportion=proportion, + centre=centre, ) @@ -1983,6 +2020,8 @@ def verify_definition( ts_method, definition, proportion, + polarised=True, + centre=True, ): def wrapped_summary_func(x): with suppress_division_by_zero_warning(): @@ -2000,44 +2039,52 @@ def wrapped_summary_func(x): with np.errstate(divide="ignore", invalid="ignore"): sigma1 = ( - ts.general_stat(W, wrapped_summary_func, M, windows, mode=self.mode) - / denom - ) - sigma2 = ( - general_stat(ts, W, wrapped_summary_func, windows, mode=self.mode) + ts.general_stat( + W, + wrapped_summary_func, + M, + windows, + mode=self.mode, + strict=centre, + polarised=polarised, + ) / denom ) - sigma3 = ts_method( + sigma2 = ts_method( sample_sets, indexes=indexes, windows=windows, mode=self.mode, proportion=proportion, + centre=centre, + polarised=polarised, ) - sigma4 = definition( + sigma3 = definition( ts, sample_sets, indexes=indexes, windows=windows, mode=self.mode, proportion=proportion, + centre=centre, + polarised=polarised, ) assert sigma1.shape == sigma2.shape assert sigma1.shape == sigma3.shape - assert sigma1.shape == sigma4.shape self.assertArrayAlmostEqual(sigma1, sigma2) self.assertArrayAlmostEqual(sigma1, sigma3) - self.assertArrayAlmostEqual(sigma1, sigma4) def verify_sample_sets_indexes(self, ts, sample_sets, indexes, windows): n = np.array([len(x) for x in sample_sets]) - n_total = sum(n) - def f(x): - mx = np.sum(x) / n_total - return np.array( - [(x[i] - n[i] * mx) * (x[j] - n[j] * mx) / 2 for i, j in indexes] - ) + def f_noncentred(x): + p = x / n + return np.array([p[i] * p[j] for i, j in indexes]) + + def f_centred(x): + p = x / n + mp = np.mean(p) + return np.array([(p[i] - mp) * (p[j] - mp) for i, j in indexes]) for proportion in [True, False]: self.verify_definition( @@ -2045,16 +2092,178 @@ def f(x): sample_sets, indexes, windows, - f, + f_centred, ts.genetic_relatedness, genetic_relatedness, proportion, ) + for centre, polarised in [ + (True, True), + (False, True), + (True, False), + (False, False), + ]: + f = f_centred if centre else f_noncentred + self.verify_definition( + ts, + sample_sets, + indexes, + windows, + f, + ts.genetic_relatedness, + genetic_relatedness, + proportion=False, + centre=centre, + polarised=polarised, + ) + + @pytest.mark.parametrize("proportion", [None, True, False]) + def test_shapes(self, proportion): + # exclude this test in the parent class + if self.mode is None: + return + ts = msprime.sim_ancestry( + 8, + random_seed=1, + end_time=10, + sequence_length=10, + population_size=10, + recombination_rate=0.02, + ) + ts = msprime.sim_mutations(ts, rate=0.01, random_seed=2) + x = ts.genetic_relatedness( + sample_sets=[[0, 1, 2], [3]], + indexes=None, + windows=None, + mode=self.mode, + proportion=proportion, + ) + if self.mode == "node": + assert x.shape == (ts.num_nodes,) + else: + assert x.shape == () + x = ts.genetic_relatedness( + sample_sets=[[0, 1, 2], [3]], + indexes=[(0, 1)], + windows=None, + mode=self.mode, + proportion=proportion, + ) + if self.mode == "node": + assert x.shape == (ts.num_nodes, 1) + else: + assert x.shape == (1,) + x = ts.genetic_relatedness( + sample_sets=[[0, 1, 2], [3]], + indexes=[(0, 1)], + windows=[0, 10], + mode=self.mode, + proportion=proportion, + ) + if self.mode == "node": + assert x.shape == (1, ts.num_nodes, 1) + else: + assert x.shape == (1, 1) + x = ts.genetic_relatedness( + sample_sets=[[0, 1, 2], [3]], + indexes=[(0, 1)], + windows=[0, 5, 10], + mode=self.mode, + proportion=proportion, + ) + if self.mode == "node": + assert x.shape == (2, ts.num_nodes, 1) + else: + assert x.shape == (2, 1) + x = ts.genetic_relatedness( + sample_sets=[[0, 1, 2], [3]], + indexes=None, + windows=[0, 5, 10], + mode=self.mode, + proportion=proportion, + ) + if self.mode == "node": + assert x.shape == (2, ts.num_nodes) + else: + assert x.shape == (2,) + x = ts.genetic_relatedness( + sample_sets=[[0, 1, 2], [3], [4, 5]], + indexes=[(0, 1), (1, 2)], + windows=[0, 5, 9, 10], + mode=self.mode, + proportion=proportion, + ) + if self.mode == "node": + assert x.shape == (3, ts.num_nodes, 2) + else: + assert x.shape == (3, 2) + x = ts.genetic_relatedness( + sample_sets=[[0, 1, 2], [3], [4, 5]], + indexes=[(0, 1), (1, 2)], + windows=None, + mode=self.mode, + proportion=proportion, + ) + if self.mode == "node": + assert x.shape == (ts.num_nodes, 2) + else: + assert x.shape == (2,) + class TestBranchGeneticRelatedness(TestGeneticRelatedness, TopologyExamplesMixin): mode = "branch" + @pytest.mark.parametrize("polarised", [True, False]) + def test_simple_tree_noncentred(self, polarised): + # 2.00┊ 4 ┊ + # ┊ ┏━┻┓ ┊ + # 1.00┊ ┃ 3 ┊ + # ┊ ┃ ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + ts = tskit.Tree.generate_balanced(3).tree_sequence + indexes = [(0, 0), (0, 1), (1, 1), (1, 2), (2, 2)] + sample_sets = [[0], [1], [2]] + if polarised: + A = np.array( + [ + 2, # (0, 0) + 0, # (0, 1) + 2, # (1, 1) + 1, # (1, 2), + 2, # (2, 2) + ] + ) + else: + A = np.array( + [ + (2 + 3), # (0, 0) + (0 + 1), # (0, 1) + (2 + 3), # (1, 1) + (1 + 2), # (1, 2), + (2 + 3), # (2, 2) + ] + ) + B = branch_genetic_relatedness( + ts, + sample_sets=sample_sets, + indexes=indexes, + polarised=polarised, + proportion=False, + centre=False, + ).squeeze() + C = ts.genetic_relatedness( + sample_sets=sample_sets, + indexes=indexes, + mode="branch", + polarised=polarised, + proportion=False, + centre=False, + ).squeeze() + self.assertArrayAlmostEqual(A, B) + self.assertArrayAlmostEqual(A, C) + class TestNodeGeneticRelatedness(TestGeneticRelatedness, TopologyExamplesMixin): mode = "node" @@ -2092,7 +2301,7 @@ def test_match_K_c0(self): @ (G_centered[y1] + G_centered[y2]) / ts.segregating_sites(sample_sets=all_samples, span_normalise=False) ) - self.assertArrayAlmostEqual(A, B) + self.assertArrayAlmostEqual(4 * A, B) ############################################ @@ -2100,7 +2309,9 @@ def test_match_K_c0(self): ############################################ -def genetic_relatedness_matrix(ts, sample_sets, windows=None, mode="site"): +def genetic_relatedness_matrix( + ts, sample_sets, windows=None, mode="site", polarised=True, centre=True +): n = len(sample_sets) indexes = [ (n1, n2) for n1, n2 in itertools.combinations_with_replacement(range(n), 2) @@ -2110,7 +2321,13 @@ def genetic_relatedness_matrix(ts, sample_sets, windows=None, mode="site"): n_nodes = ts.num_nodes K = np.zeros((n_nodes, n, n)) out = ts.genetic_relatedness( - sample_sets, indexes, mode=mode, proportion=False, span_normalise=True + sample_sets, + indexes, + mode=mode, + proportion=False, + span_normalise=True, + polarised=polarised, + centre=centre, ) for node in range(n_nodes): this_K = np.zeros((n, n)) @@ -2120,7 +2337,13 @@ def genetic_relatedness_matrix(ts, sample_sets, windows=None, mode="site"): else: K = np.zeros((n, n)) K[np.triu_indices(n)] = ts.genetic_relatedness( - sample_sets, indexes, mode=mode, proportion=False, span_normalise=True + sample_sets, + indexes, + mode=mode, + proportion=False, + span_normalise=True, + centre=centre, + polarised=polarised, ) K = K + np.triu(K, 1).transpose() else: @@ -2133,6 +2356,8 @@ def genetic_relatedness_matrix(ts, sample_sets, windows=None, mode="site"): windows=windows, proportion=False, span_normalise=True, + polarised=polarised, + centre=centre, ) if mode == "node": n_nodes = ts.num_nodes @@ -2153,11 +2378,16 @@ def genetic_relatedness_matrix(ts, sample_sets, windows=None, mode="site"): return K -def genetic_relatedness_weighted(ts, W, indexes, windows=None, mode="site"): - W_mean = W.mean(axis=0) - W = W - W_mean +def genetic_relatedness_weighted( + ts, W, indexes, windows=None, mode="site", polarised=True, centre=True +): + if centre: + W_mean = W.mean(axis=0) + W = W - W_mean sample_sets = [[u] for u in ts.samples()] - K = genetic_relatedness_matrix(ts, sample_sets, windows, mode) + K = genetic_relatedness_matrix( + ts, sample_sets, windows=windows, mode=mode, centre=centre, polarised=polarised + ) n_indexes = len(indexes) n_nodes = ts.num_nodes if windows is None: @@ -2207,16 +2437,38 @@ class TestGeneticRelatednessWeighted(StatsTestCase, WeightStatsMixin): mode = None def verify_definition( - self, ts, W, indexes, windows, summary_func, ts_method, definition + self, + ts, + W, + indexes, + windows, + summary_func, + ts_method, + definition, + polarised=True, + centre=True, ): # Determine output_dim of the function M = len(indexes) sigma1 = ts.general_stat( - W, summary_func, M, windows, mode=self.mode, span_normalise=True + W, + summary_func, + M, + windows, + mode=self.mode, + span_normalise=True, + strict=centre, + polarised=polarised, ) sigma2 = general_stat( - ts, W, summary_func, windows, mode=self.mode, span_normalise=True + ts, + W, + summary_func, + windows, + mode=self.mode, + span_normalise=True, + polarised=polarised, ) sigma3 = ts_method( @@ -2224,6 +2476,8 @@ def verify_definition( indexes=indexes, windows=windows, mode=self.mode, + polarised=polarised, + centre=centre, ) sigma4 = definition( ts, @@ -2231,6 +2485,8 @@ def verify_definition( indexes=indexes, windows=windows, mode=self.mode, + polarised=polarised, + centre=centre, ) assert sigma1.shape == sigma2.shape assert sigma1.shape == sigma3.shape @@ -2247,29 +2503,38 @@ def verify(self, ts): self.verify_weighted_stat(ts, W, indexes, windows) def verify_weighted_stat(self, ts, W, indexes, windows): - W_mean = W.mean(axis=0) - W = W - W_mean - W_sum = W.sum(axis=0) n = W.shape[0] + K = W.shape[1] + WW = np.column_stack([W, np.ones(n) / n]) + W_sum = WW.sum(axis=0) - def f(x): - mx = np.sum(x) / n + def f_noncentred(x): + return np.array([x[i] * x[j] for i, j in indexes]) + + def f_centred(x): + pn = x[K] return np.array( - [ - (x[i] - W_sum[i] * mx) * (x[j] - W_sum[j] * mx) / 2 - for i, j in indexes - ] + [(x[i] - W_sum[i] * pn) * (x[j] - W_sum[j] * pn) for i, j in indexes] ) - self.verify_definition( - ts, - W, - indexes, - windows, - f, - ts.genetic_relatedness_weighted, - genetic_relatedness_weighted, - ) + for centre, polarised in [ + (True, True), + (False, True), + (True, False), + (False, False), + ]: + f = f_centred if centre else f_noncentred + self.verify_definition( + ts, + WW, + indexes, + windows, + f, + ts.genetic_relatedness_weighted, + genetic_relatedness_weighted, + centre=centre, + polarised=polarised, + ) class TestBranchGeneticRelatednessWeighted( @@ -2296,8 +2561,8 @@ class TestSiteGeneticRelatednessWeighted( class TestGeneticRelatednessWeightedSimpleExamples: # Values verified against the simple implementations above - site_value = 11.12 - branch_value = 14.72 + site_value = 22.24 + branch_value = 29.44 def fixture(self): ts = tskit.Tree.generate_balanced(5).tree_sequence @@ -3344,9 +3609,6 @@ def f(x): self.verify_definition(ts, sample_sets, indexes, windows, f, ts.f4, f4) - def verify_interface(self, ts): - self.verify_interface_method(ts.f4) - class TestBranchf4(Testf4, TopologyExamplesMixin): mode = "branch" @@ -4306,6 +4568,64 @@ def test_simple_identity_f_w_zeros_windows(self): assert sigma.shape == (10, W.shape[1]) assert np.all(sigma == 0) + def test_nonstrict_nonancestral_material(self): + # 0 is a sample, 1 is not + # + # 2.00┊ 2 ┊ 2 ┊ 2 ┊ + # ┊ ┏┻┓ ┊ ┃ ┊ ┏┻┓ ┊ + # 1.00┊ ┃ 1 ┊ 1 ┊ ┃ 1 ┊ + # ┊ ┃ ┊ ┃ ┊ ┃ ┊ + # 0.00┊ 0 ┊ 0 ┊ 0 ┊ + # 0 1 2 3 + + tables = tskit.TableCollection(sequence_length=3) + + node_times = [0, 1, 2] + samples = [0] + for n, t in enumerate(node_times): + tables.nodes.add_row( + time=t, flags=tskit.NODE_IS_SAMPLE if n in samples else 0 + ) + + # p, c, l, r + edges = [ + (1, 0, 1, 2), + (2, 0, 0, 1), + (2, 0, 2, 3), + (2, 1, 0, 3), + ] + for p, c, l, r in edges: + tables.edges.add_row(parent=p, child=c, left=l, right=r) + + # this makes it so 'site' mode counts branches + for x in range(int(tables.sequence_length)): + for n in range(tables.nodes.num_rows - 1): + offset = n / tables.nodes.num_rows + s = tables.sites.add_row(position=x + offset, ancestral_state="0") + tables.mutations.add_row(site=s, node=n, derived_state="1") + + ts = tables.tree_sequence() + + def f(x): + return x + + for polarised, mode, answer in [ + (True, "branch", 6), + (True, "site", 4), + (False, "branch", 8), + (False, "site", 6), + ]: + (stat,) = ts.sample_count_stat( + [[0]], + f, + 1, + strict=False, + span_normalise=False, + polarised=polarised, + mode=mode, + ) + assert stat == answer + class TestGeneralSiteStats(StatsTestCase): """ diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 105d0fbc37..73830bb966 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -7675,6 +7675,7 @@ def __k_way_sample_set_stat( mode=None, span_normalise=True, polarised=False, + centre=True, ): sample_set_sizes = np.array( [len(sample_set) for sample_set in sample_sets], dtype=np.uint32 @@ -7710,6 +7711,7 @@ def __k_way_sample_set_stat( mode=mode, span_normalise=span_normalise, polarised=polarised, + centre=centre, ) if drop_dimension: stat = stat.reshape(stat.shape[:-1]) @@ -7727,6 +7729,7 @@ def __k_way_weighted_stat( mode=None, span_normalise=True, polarised=False, + centre=True, ): W = np.asarray(W) if indexes is None: @@ -7754,6 +7757,7 @@ def __k_way_weighted_stat( mode=mode, span_normalise=span_normalise, polarised=polarised, + centre=centre, ) if drop_dimension: stat = stat.reshape(stat.shape[:-1]) @@ -8115,8 +8119,9 @@ def genetic_relatedness( windows=None, mode="site", span_normalise=True, - polarised=False, + polarised=True, proportion=True, + centre=True, ): """ Computes genetic relatedness between (and within) pairs of @@ -8135,44 +8140,70 @@ def genetic_relatedness( What is computed depends on ``mode``: "site" - Number of pairwise allelic matches in the window between two + Frequency of pairwise allelic matches in the window between two sample sets relative to the rest of the sample sets. To be precise, let `m(u,v)` denote the total number of alleles shared between - nodes `u` and `v`, and let `m(I,J)` be the sum of `m(u,v)` over all - nodes `u` in sample set `I` and `v` in sample set `J`. Let `S` and - `T` be independently chosen sample sets. Then, for sample sets `I` - and `J`, this computes `E[m(I,J) - m(I,S) - m(J,T) + m(S,T)]`. + nodes `u` and `v`, and let `m(I,J)` be the average of `m(u,v)` over + all nodes `u` in sample set `I` and `v` in sample set `J`. Let `S` + and `T` be independently chosen sample sets. Then, for sample sets + `I` and `J`, this computes `E[m(I,J) - m(I,S) - m(J,T) + m(S,T)]` + if centre=True (the default), or `E[m(I,J)]` if centre=False. This can also be seen as the covariance of a quantitative trait determined by additive contributions from the genomes in each - sample set. Let each allele be associated with an effect drawn from - a `N(0,1/2)` distribution, and let the trait value of a sample set - be the sum of its allele effects. Then, this computes the covariance - between the trait values of two sample sets. For example, to - compute covariance between the traits of diploid individuals, each - sample set would be the pair of genomes of each individual; if - ``proportion=True``, this then corresponds to :math:`K_{c0}` in - `Speed & Balding (2014) `_. + sample set. Let each derived allele be associated with an effect + drawn from a `N(0,1)` distribution, and let the trait value of a + sample be the sum of its allele effects. Then, this computes + the covariance between the average trait values of two sample sets. + For example, to compute covariance between the traits of diploid + individuals, each sample set would be the pair of genomes of each + individual, with the trait being the average of the two genomes. + If ``proportion=True``, this then corresponds to :math:`K_{c0}` in + `Speed & Balding (2014) `_, + multiplied by four (see below). "branch" - Total area of branches in the window ancestral to pairs of samples + Average area of branches in the window ancestral to pairs of samples in two sample sets relative to the rest of the sample sets. To be precise, let `B(u,v)` denote the total area of all branches - ancestral to nodes `u` and `v`, and let `B(I,J)` be the sum of + ancestral to nodes `u` and `v`, and let `B(I,J)` be the average of `B(u,v)` over all nodes `u` in sample set `I` and `v` in sample set `J`. Let `S` and `T` be two independently chosen sample sets. Then for sample sets `I` and `J`, this computes - `E[B(I,J) - B(I,S) - B(J,T) + B(S,T)]`. + `E[B(I,J) - B(I,S) - B(J,T) + B(S,T)]` if centre=True (the default), + or `E[B(I,J)]` if centre=False. "node" For each node, the proportion of the window over which pairs of samples in two sample sets are descendants, relative to the rest of the sample sets. To be precise, for each node `n`, let `N(u,v)` denote the proportion of the window over which samples `u` and `v` - are descendants of `n`, and let and let `N(I,J)` be the sum of + are descendants of `n`, and let and let `N(I,J)` be the average of `N(u,v)` over all nodes `u` in sample set `I` and `v` in sample set `J`. Let `S` and `T` be two independently chosen sample sets. Then for sample sets `I` and `J`, this computes - `E[N(I,J) - N(I,S) - N(J,T) + N(S,T)]`. + `E[N(I,J) - N(I,S) - N(J,T) + N(S,T)]` if centre=True (the default), + or `E[N(I,J)]` if centre=False. + + *Note:* The default for this statistic - unlike most other statistics - is + ``polarised=True``. Using the default value ``centre=True``, setting + ``polarised=False`` will only multiply the result by a factor of two + for branch-mode, or site-mode if all sites are biallelic. (With + multiallelic sites the difference is more complicated.) The uncentred + and unpolarised value is probably not what you are looking for: for + instance, the unpolarised, uncentred site statistic between two samples + counts the number of alleles inherited by both *and* the number of + alleles inherited by neither of the two samples. + + *Note:* Some authors + (see `Speed & Balding (2014) `_) + compute relatedness between `I` and `J` as the total number of all pairwise + allelic matches between `I` and `J`, rather than the frequency, + which would define `m(I,J)` as the sum of `m(u,v)` rather than the average + in the definition of "site" relatedness above. If every sample set is the + samples of a :math:`k`-ploid individual, this would simply multiply the + result by :math:`k^2`. However, this definition would make the result not + useful as a summary statistic of typical relatedness for larger sample + sets. :param list sample_sets: A list of lists of Node IDs, specifying the groups of nodes to compute the statistic with. @@ -8189,23 +8220,16 @@ def genetic_relatedness( that are segregating between *any* of the samples of *any* of the sample sets (rather than segregating between all of the samples of the tree sequence). + :param bool polarised: Whether to leave the ancestral state out of computations: + see :ref:`sec_stats` for more details. Defaults to True. + :param bool centre: Defaults to True. Whether to 'centre' the result, as + described above (the usual definition is centred). :return: A ndarray with shape equal to (num windows, num statistics). If there is one pair of sample sets and windows=None, a numpy scalar is returned. """ - if proportion: - # TODO this should be done in C also - all_samples = list({u for s in sample_sets for u in s}) - denominator = self.segregating_sites( - sample_sets=[all_samples], - windows=windows, - mode=mode, - span_normalise=span_normalise, - ) - else: - denominator = 1 - numerator = self.__k_way_sample_set_stat( + out = self.__k_way_sample_set_stat( self._ll_tree_sequence.genetic_relatedness, 2, sample_sets, @@ -8214,9 +8238,25 @@ def genetic_relatedness( mode=mode, span_normalise=span_normalise, polarised=polarised, + centre=centre, ) - with np.errstate(divide="ignore", invalid="ignore"): - out = numerator / denominator + if proportion: + # TODO this should be done in C also + all_samples = np.array(list({u for s in sample_sets for u in s})) + denominator = self.segregating_sites( + sample_sets=all_samples, + windows=windows, + mode=mode, + span_normalise=span_normalise, + ) + # the shapes of out and denominator should be the same except that + # out may have an extra dimension if indexes is not None + if indexes is not None and not isinstance(denominator, float): + oshape = list(out.shape) + oshape[-1] = 1 + denominator = denominator.reshape(oshape) + with np.errstate(divide="ignore", invalid="ignore"): + out /= denominator return out @@ -8229,6 +8269,53 @@ def genetic_relatedness_matrix( mode=None, span_normalise=True, ): + """ + Computes the full matrix of pairwise genetic relatedness values + between (and within) pairs of sets of nodes from ``sample_sets``. + *Warning:* this does not compute exactly the same thing as + :meth:`.genetic_relatedness`: see below for more details. + + If `mode="branch"`, then the value obtained is the same as that from + :meth:`.genetic_relatedness`, using the options `centre=True` and + `proportion=False`. The same is true if `mode="site"` and all sites have + at most one mutation. + + However, if some sites have more than one mutation, the value may differ. + The reason is that this function (for efficiency) computes relatedness + using :meth:`.divergence` and the following relationship. + "Relatedness" measures the number of *shared* alleles (or branches), + while "divergence" measures the number of *non-shared* alleles (or branches). + Let :math:`T_i` be the total distance from sample :math:`i` up to the root; + then if :math:`D_{ij}` is the divergence between :math:`i` and :math:`j` + and :math:`R_{ij}` is the relatedness between :math:`i` and :math:`j`, then + :math:`T_i + T_j = D_{ij} + 2 R_{ij}.` + So, for any samples :math:`I`, :math:`J`, :math:`S`, :math:`T` + (that may now be random choices), + :math:`R_{IJ}-R_{IS}-R_{JT}+R_{ST} = (D_{IJ}-D_{IS}-D_{JT}+D_{ST})/ (-2)`. + Note, however, that this relationship only holds for `mode="site"` + if we can treat "number of differing alleles" as distances on the tree; + this is not necessarily the case in the presence of multiple mutations. + + Another caveat in the above relationship between :math:`R` and :math:`D` + is that :meth:`.divergence` of a sample set to itself does not include + the "self" comparisons (so as to provide an unbiased estimator of a + population quantity), while the usual definition of genetic relatedness + *does* include such comparisons (to provide, for instance, an appropriate + value for prospective results beginning with only a given set of + individuals). + + :param list sample_sets: A list of lists of Node IDs, specifying the + groups of nodes to compute the statistic with. + :param list windows: An increasing list of breakpoints between the windows + to compute the statistic in. + :param str mode: A string giving the "type" of the statistic to be computed + (defaults to "site"). + :param bool span_normalise: Whether to divide the result by the span of the + window (defaults to True). Has no effect if ``proportion`` is True. + :return: A ndarray with shape equal to (num windows, num statistics). + If there is one pair of sample sets and windows=None, a numpy scalar is + returned. + """ D = self.divergence_matrix( sample_sets, windows=windows, @@ -8237,24 +8324,20 @@ def genetic_relatedness_matrix( span_normalise=span_normalise, ) - # FIXME remove this when sample sets bug has been fixed. - # https://github.com/tskit-dev/tskit/issues/2888 - if sample_sets is not None: - if any(len(ss) > 1 for ss in sample_sets): - raise ValueError( - "Only single entry sample sets allowed for now." - " See https://github.com/tskit-dev/tskit/issues/2888" - ) + if sample_sets is None: + n = np.ones(self.num_samples) + else: + n = np.array([len(x) for x in sample_sets]) def _normalise(B): if len(B) == 0: return B + # correct for lack of self comparisons in divergence + np.fill_diagonal(B, np.diag(B) * (n - 1) / n) K = B + np.mean(B) y = np.mean(B, axis=0) X = y[:, np.newaxis] + y[np.newaxis, :] K -= X - # FIXME this factor of 2 works for single-sample sample-sets, but not - # otherwise. https://github.com/tskit-dev/tskit/issues/2888 return K / -2 if windows is None: @@ -8272,15 +8355,20 @@ def genetic_relatedness_weighted( mode="site", span_normalise=True, polarised=False, + centre=True, ): r""" - Computes weighted genetic relatedness. If the k-th pair of indices is (i, j) - then the k-th column of output will be + Computes weighted genetic relatedness. If the :math:`k` th pair of indices + is (i, j) then the :math:`k` th column of output will be :math:`\sum_{a,b} W_{ai} W_{bj} C_{ab}`, where :math:`W` is the matrix of weights, and :math:`C_{ab}` is the :meth:`genetic_relatedness <.TreeSequence.genetic_relatedness>` between sample a and sample b, summing over all pairs of samples in the tree sequence. + *Note:* the genetic relatedness matrix :math:`C` here is as returned by + :meth:`.genetic_relatedness`, rather than by :meth:`.genetic_relatedness_matrix` + (see the latter's documentation for the difference). + :param numpy.ndarray W: An array of values with one row for each sample node and one column for each set of weights. :param list indexes: A list of 2-tuples, or None (default). Note that if @@ -8292,6 +8380,10 @@ def genetic_relatedness_weighted( (defaults to "site"). :param bool span_normalise: Whether to divide the result by the span of the window (defaults to True). + :param bool polarised: Whether to leave the ancestral state out of computations: + see :ref:`sec_stats` for more details. Defaults to True. + :param bool centre: Defaults to True. Whether to 'centre' the result, as + described above (the usual definition is centred). :return: A ndarray with shape equal to (num windows, num statistics). """ if len(W) != self.num_samples: @@ -8307,6 +8399,7 @@ def genetic_relatedness_weighted( mode=mode, span_normalise=span_normalise, polarised=polarised, + centre=centre, ) def trait_covariance(self, W, windows=None, mode="site", span_normalise=True): @@ -8806,12 +8899,35 @@ def Fst( # two-way stats and (b) it's a bit more efficient because we're not messing # around with indexes and samples sets twice. - def fst_func(sample_set_sizes, flattened, indexes, **kwargs): - diversities = self._ll_tree_sequence.diversity( - sample_set_sizes, flattened, **kwargs - ) + def fst_func( + sample_set_sizes, + flattened, + indexes, + windows, + mode, + span_normalise, + polarised, + centre, + ): + # note: this is kinda hacky - polarised and centre are not used here - + # but this seems necessary to use our __k_way_sample_set_stat framework divergences = self._ll_tree_sequence.divergence( - sample_set_sizes, flattened, indexes, **kwargs + sample_set_sizes, + flattened, + indexes=indexes, + windows=windows, + mode=mode, + span_normalise=span_normalise, + polarised=polarised, + centre=centre, + ) + diversities = self._ll_tree_sequence.diversity( + sample_set_sizes, + flattened, + windows=windows, + mode=mode, + span_normalise=span_normalise, + polarised=polarised, ) orig_shape = divergences.shape