diff --git a/subprojects/tskit/tskit.h b/subprojects/tskit/tskit.h index e55ffc664..e8f1b69f6 100644 --- a/subprojects/tskit/tskit.h +++ b/subprojects/tskit/tskit.h @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2019 Tskit Developers + * Copyright (c) 2019-2024 Tskit Developers * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/subprojects/tskit/tskit/convert.c b/subprojects/tskit/tskit/convert.c index e100b3fea..657b97868 100644 --- a/subprojects/tskit/tskit/convert.c +++ b/subprojects/tskit/tskit/convert.c @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2018-2021 Tskit Developers + * Copyright (c) 2018-2025 Tskit Developers * Copyright (c) 2015-2017 University of Oxford * * Permission is hereby granted, free of charge, to any person obtaining a copy @@ -68,11 +68,11 @@ tsk_newick_converter_run( const char *label_format = ms_labels ? "%d" : "n%d"; if (root < 0 || root >= (tsk_id_t) self->tree->num_nodes) { - ret = TSK_ERR_NODE_OUT_OF_BOUNDS; + ret = tsk_trace_error(TSK_ERR_NODE_OUT_OF_BOUNDS); goto out; } if (buffer == NULL) { - ret = TSK_ERR_BAD_PARAM_VALUE; + ret = tsk_trace_error(TSK_ERR_BAD_PARAM_VALUE); goto out; } root_parent = tree->parent[root]; @@ -82,7 +82,7 @@ tsk_newick_converter_run( v = stack[stack_top]; if (tree->left_child[v] != TSK_NULL && v != u) { if (s >= buffer_size) { - ret = TSK_ERR_BUFFER_OVERFLOW; + ret = tsk_trace_error(TSK_ERR_BUFFER_OVERFLOW); goto out; } buffer[s] = '('; @@ -104,17 +104,17 @@ tsk_newick_converter_run( } if (label != -1) { if (s >= buffer_size) { - ret = TSK_ERR_BUFFER_OVERFLOW; + ret = tsk_trace_error(TSK_ERR_BUFFER_OVERFLOW); goto out; } r = snprintf(buffer + s, buffer_size - s, label_format, label); if (r < 0) { - ret = TSK_ERR_IO; + ret = tsk_trace_error(TSK_ERR_IO); goto out; } s += (size_t) r; if (s >= buffer_size) { - ret = TSK_ERR_BUFFER_OVERFLOW; + ret = 
tsk_trace_error(TSK_ERR_BUFFER_OVERFLOW); goto out; } } @@ -123,12 +123,12 @@ tsk_newick_converter_run( r = snprintf(buffer + s, buffer_size - s, ":%.*f", (int) self->precision, branch_length); if (r < 0) { - ret = TSK_ERR_IO; + ret = tsk_trace_error(TSK_ERR_IO); goto out; } s += (size_t) r; if (s >= buffer_size) { - ret = TSK_ERR_BUFFER_OVERFLOW; + ret = tsk_trace_error(TSK_ERR_BUFFER_OVERFLOW); goto out; } if (v == tree->right_child[u]) { @@ -141,7 +141,7 @@ tsk_newick_converter_run( } } if ((s + 1) >= buffer_size) { - ret = TSK_ERR_BUFFER_OVERFLOW; + ret = tsk_trace_error(TSK_ERR_BUFFER_OVERFLOW); goto out; } buffer[s] = ';'; @@ -164,7 +164,7 @@ tsk_newick_converter_init(tsk_newick_converter_t *self, const tsk_tree_t *tree, self->traversal_stack = tsk_malloc(tsk_tree_get_size_bound(tree) * sizeof(*self->traversal_stack)); if (self->traversal_stack == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } out: diff --git a/subprojects/tskit/tskit/core.c b/subprojects/tskit/tskit/core.c index b1ea25bad..53cc0ce67 100644 --- a/subprojects/tskit/tskit/core.c +++ b/subprojects/tskit/tskit/core.c @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2019-2023 Tskit Developers + * Copyright (c) 2019-2025 Tskit Developers * Copyright (c) 2015-2018 University of Oxford * * Permission is hereby granted, free of charge, to any person obtaining a copy @@ -43,22 +43,24 @@ static int TSK_WARN_UNUSED get_random_bytes(uint8_t *buf) { /* Based on CPython's code in bootstrap_hash.c */ - int ret = TSK_ERR_GENERATE_UUID; + int ret = 0; HCRYPTPROV hCryptProv = (HCRYPTPROV) NULL; if (!CryptAcquireContext( &hCryptProv, NULL, NULL, PROV_RSA_FULL, CRYPT_VERIFYCONTEXT)) { + ret = tsk_trace_error(TSK_ERR_GENERATE_UUID); goto out; } if (!CryptGenRandom(hCryptProv, (DWORD) UUID_NUM_BYTES, buf)) { + ret = tsk_trace_error(TSK_ERR_GENERATE_UUID); goto out; } if (!CryptReleaseContext(hCryptProv, 0)) { hCryptProv = (HCRYPTPROV) NULL; + ret = 
tsk_trace_error(TSK_ERR_GENERATE_UUID); goto out; } hCryptProv = (HCRYPTPROV) NULL; - ret = 0; out: if (hCryptProv != (HCRYPTPROV) NULL) { CryptReleaseContext(hCryptProv, 0); @@ -72,19 +74,21 @@ get_random_bytes(uint8_t *buf) static int TSK_WARN_UNUSED get_random_bytes(uint8_t *buf) { - int ret = TSK_ERR_GENERATE_UUID; + int ret = 0; FILE *f = fopen("/dev/urandom", "r"); if (f == NULL) { + ret = tsk_trace_error(TSK_ERR_GENERATE_UUID); goto out; } if (fread(buf, UUID_NUM_BYTES, 1, f) != 1) { + ret = tsk_trace_error(TSK_ERR_GENERATE_UUID); goto out; } if (fclose(f) != 0) { + ret = tsk_trace_error(TSK_ERR_GENERATE_UUID); goto out; } - ret = 0; out: return ret; } @@ -111,7 +115,7 @@ tsk_generate_uuid(char *dest, int TSK_UNUSED(flags)) buf[4], buf[5], buf[6], buf[7], buf[8], buf[9], buf[10], buf[11], buf[12], buf[13], buf[14], buf[15]) < 0) { - ret = TSK_ERR_GENERATE_UUID; + ret = tsk_trace_error(TSK_ERR_GENERATE_UUID); goto out; } out: @@ -164,7 +168,8 @@ tsk_strerror_internal(int err) break; case TSK_ERR_FILE_VERSION_TOO_OLD: ret = "tskit file version too old. Please upgrade using the " - "'tskit upgrade' command. (TSK_ERR_FILE_VERSION_TOO_OLD)"; + "'tskit upgrade' command from tskit version<0.6.2. " + "(TSK_ERR_FILE_VERSION_TOO_OLD)"; break; case TSK_ERR_FILE_VERSION_TOO_NEW: ret = "tskit file version is too new for this instance. " @@ -226,13 +231,16 @@ tsk_strerror_internal(int err) ret = "One of the kept rows in the table refers to a deleted row. " "(TSK_ERR_KEEP_ROWS_MAP_TO_DELETED)"; break; + case TSK_ERR_POSITION_OUT_OF_BOUNDS: + ret = "Position out of bounds. (TSK_ERR_POSITION_OUT_OF_BOUNDS)"; + break; /* Edge errors */ case TSK_ERR_NULL_PARENT: - ret = "Edge in parent is null. (TSK_ERR_NULL_PARENT)"; + ret = "Edge parent is null. (TSK_ERR_NULL_PARENT)"; break; case TSK_ERR_NULL_CHILD: - ret = "Edge in parent is null. (TSK_ERR_NULL_CHILD)"; + ret = "Edge child is null. 
(TSK_ERR_NULL_CHILD)"; break; case TSK_ERR_EDGES_NOT_SORTED_PARENT_TIME: ret = "Edges must be listed in (time[parent], child, left) order;" @@ -274,7 +282,10 @@ tsk_strerror_internal(int err) break; case TSK_ERR_CANT_PROCESS_EDGES_WITH_METADATA: ret = "Can't squash, flush, simplify or link ancestors with edges that have " - "non-empty metadata. (TSK_ERR_CANT_PROCESS_EDGES_WITH_METADATA)"; + "non-empty metadata. Removing the metadata from the edges will allow " + "these operations to proceed. For example using " + "tables.edges.drop_metadata() in the tskit Python API. " + "(TSK_ERR_CANT_PROCESS_EDGES_WITH_METADATA)"; break; /* Site errors */ @@ -330,6 +341,17 @@ tsk_strerror_internal(int err) "values for any single site. " "(TSK_ERR_MUTATION_TIME_HAS_BOTH_KNOWN_AND_UNKNOWN)"; break; + case TSK_ERR_DISALLOWED_UNKNOWN_MUTATION_TIME: + ret = "Some mutation times are marked 'unknown' for a method that requires " + "no unknown times. (Use compute_mutation_times to add times?) " + "(TSK_ERR_DISALLOWED_UNKNOWN_MUTATION_TIME)"; + break; + + case TSK_ERR_BAD_MUTATION_PARENT: + ret = "A mutation's parent is not consistent with the topology of the tree. " + "Use compute_mutation_parents to set the parents correctly." + "(TSK_ERR_BAD_MUTATION_PARENT)"; + break; /* Migration errors */ case TSK_ERR_UNSORTED_MIGRATIONS: @@ -466,6 +488,73 @@ tsk_strerror_internal(int err) ret = "Statistics using branch lengths cannot be calculated when time_units " "is 'uncalibrated'. (TSK_ERR_TIME_UNCALIBRATED)"; break; + case TSK_ERR_STAT_POLARISED_UNSUPPORTED: + ret = "The TSK_STAT_POLARISED option is not supported by this statistic. " + "(TSK_ERR_STAT_POLARISED_UNSUPPORTED)"; + break; + case TSK_ERR_STAT_SPAN_NORMALISE_UNSUPPORTED: + ret = "The TSK_STAT_SPAN_NORMALISE option is not supported by this " + "statistic. " + "(TSK_ERR_STAT_SPAN_NORMALISE_UNSUPPORTED)"; + break; + case TSK_ERR_INSUFFICIENT_WEIGHTS: + ret = "Insufficient weights provided (at least 1 required). 
" + "(TSK_ERR_INSUFFICIENT_WEIGHTS)"; + break; + + /* Pair coalescence errors */ + case TSK_ERR_BAD_NODE_BIN_MAP: + ret = "Node-to-bin map contains values less than TSK_NULL. " + "(TSK_ERR_BAD_NODE_BIN_MAP)"; + break; + case TSK_ERR_BAD_NODE_BIN_MAP_DIM: + ret = "Maximum index in node-to-bin map is greater than the " + "output dimension. (TSK_ERR_BAD_NODE_BIN_MAP_DIM)"; + break; + case TSK_ERR_BAD_QUANTILES: + ret = "Quantiles must be between 0 and 1 (inclusive) " + "and strictly increasing. (TSK_ERR_BAD_QUANTILES)"; + break; + case TSK_ERR_UNSORTED_TIMES: + ret = "Times must be strictly increasing. (TSK_ERR_UNSORTED_TIMES)"; + break; + case TSK_ERR_BAD_TIME_WINDOWS_DIM: + ret = "Must have at least one time window. (TSK_ERR_BAD_TIME_WINDOWS_DIM)"; + break; + case TSK_ERR_BAD_SAMPLE_PAIR_TIMES: + ret = "All sample times must be equal to the start of first time window. " + "(TSK_ERR_BAD_SAMPLE_PAIR_TIMES)"; + break; + case TSK_ERR_BAD_TIME_WINDOWS: + ret = "Time windows must start at zero and be strictly increasing. " + "(TSK_ERR_BAD_TIME_WINDOWS)"; + break; + case TSK_ERR_BAD_TIME_WINDOWS_END: + ret = "Time windows must end at infinity for this method. " + "(TSK_ERR_BAD_TIME_WINDOWS_END)"; + break; + case TSK_ERR_BAD_NODE_TIME_WINDOW: + ret = "Node time does not fall within assigned time window. " + "(TSK_ERR_BAD_NODE_TIME_WINDOW)"; + break; + + /* Two locus errors */ + case TSK_ERR_STAT_UNSORTED_POSITIONS: + ret = "The provided positions are not sorted in strictly increasing " + "order. (TSK_ERR_STAT_UNSORTED_POSITIONS)"; + break; + case TSK_ERR_STAT_DUPLICATE_POSITIONS: + ret = "The provided positions contain duplicates. " + "(TSK_ERR_STAT_DUPLICATE_POSITIONS)"; + break; + case TSK_ERR_STAT_UNSORTED_SITES: + ret = "The provided sites are not sorted in strictly increasing position " + "order. (TSK_ERR_STAT_UNSORTED_SITES)"; + break; + case TSK_ERR_STAT_DUPLICATE_SITES: + ret = "The provided sites contain duplicated entries. 
" + "(TSK_ERR_STAT_DUPLICATE_SITES)"; + break; /* Mutation mapping errors */ case TSK_ERR_GENOTYPES_ALL_MISSING: @@ -602,6 +691,11 @@ tsk_strerror_internal(int err) "if an individual has nodes from more than one time. " "(TSK_ERR_INDIVIDUAL_TIME_MISMATCH)"; break; + + case TSK_ERR_EXTEND_EDGES_BAD_MAXITER: + ret = "Maximum number of iterations must be positive. " + "(TSK_ERR_EXTEND_EDGES_BAD_MAXITER)"; + break; } return ret; } @@ -685,7 +779,7 @@ tsk_blkalloc_init(tsk_blkalloc_t *self, size_t chunk_size) tsk_memset(self, 0, sizeof(tsk_blkalloc_t)); if (chunk_size < 1) { - ret = TSK_ERR_BAD_PARAM_VALUE; + ret = tsk_trace_error(TSK_ERR_BAD_PARAM_VALUE); goto out; } self->chunk_size = chunk_size; @@ -696,12 +790,12 @@ tsk_blkalloc_init(tsk_blkalloc_t *self, size_t chunk_size) self->num_chunks = 0; self->mem_chunks = malloc(sizeof(char *)); if (self->mem_chunks == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } self->mem_chunks[0] = malloc(chunk_size); if (self->mem_chunks[0] == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } self->num_chunks = 1; @@ -1164,3 +1258,130 @@ tsk_avl_tree_int_ordered_nodes(const tsk_avl_tree_int_t *self, tsk_avl_node_int_ ordered_nodes_traverse(self->head.rlink, 0, out); return 0; } + +// Bit Array implementation. Allows us to store unsigned integers in a compact manner. +// Currently implemented as an array of 32-bit unsigned integers for ease of counting. + +int +tsk_bit_array_init(tsk_bit_array_t *self, tsk_size_t num_bits, tsk_size_t length) +{ + int ret = 0; + + self->size = (num_bits >> TSK_BIT_ARRAY_CHUNK) + + (num_bits % TSK_BIT_ARRAY_NUM_BITS ? 
1 : 0); + self->data = tsk_calloc(self->size * length, sizeof(*self->data)); + if (self->data == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; + } +out: + return ret; +} + +void +tsk_bit_array_get_row(const tsk_bit_array_t *self, tsk_size_t row, tsk_bit_array_t *out) +{ + out->size = self->size; + out->data = self->data + (row * self->size); +} + +void +tsk_bit_array_intersect( + const tsk_bit_array_t *self, const tsk_bit_array_t *other, tsk_bit_array_t *out) +{ + for (tsk_size_t i = 0; i < self->size; i++) { + out->data[i] = self->data[i] & other->data[i]; + } +} + +void +tsk_bit_array_subtract(tsk_bit_array_t *self, const tsk_bit_array_t *other) +{ + for (tsk_size_t i = 0; i < self->size; i++) { + self->data[i] &= ~(other->data[i]); + } +} + +void +tsk_bit_array_add(tsk_bit_array_t *self, const tsk_bit_array_t *other) +{ + for (tsk_size_t i = 0; i < self->size; i++) { + self->data[i] |= other->data[i]; + } +} + +void +tsk_bit_array_add_bit(tsk_bit_array_t *self, const tsk_bit_array_value_t bit) +{ + tsk_bit_array_value_t i = bit >> TSK_BIT_ARRAY_CHUNK; + self->data[i] |= (tsk_bit_array_value_t) 1 << (bit - (TSK_BIT_ARRAY_NUM_BITS * i)); +} + +bool +tsk_bit_array_contains(const tsk_bit_array_t *self, const tsk_bit_array_value_t bit) +{ + tsk_bit_array_value_t i = bit >> TSK_BIT_ARRAY_CHUNK; + return self->data[i] + & ((tsk_bit_array_value_t) 1 << (bit - (TSK_BIT_ARRAY_NUM_BITS * i))); +} + +tsk_size_t +tsk_bit_array_count(const tsk_bit_array_t *self) +{ + // Utilizes 12 operations per bit array. NB this only works on 32 bit integers. + // Taken from: + // https://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetParallel + // There's a nice breakdown of this algorithm here: + // https://stackoverflow.com/a/109025 + // Could probably do better with explicit SIMD (instead of SWAR), but not as + // portable: https://arxiv.org/pdf/1611.07612.pdf + // + // There is one solution to explore further, which uses __builtin_popcountll. 
+ // This option is relatively simple, but requires a 64 bit bit array and also + // involves some compiler flag plumbing (-mpopcnt) + + tsk_bit_array_value_t tmp; + tsk_size_t i, count = 0; + + for (i = 0; i < self->size; i++) { + tmp = self->data[i] - ((self->data[i] >> 1) & 0x55555555); + tmp = (tmp & 0x33333333) + ((tmp >> 2) & 0x33333333); + count += (((tmp + (tmp >> 4)) & 0xF0F0F0F) * 0x1010101) >> 24; + } + return count; +} + +void +tsk_bit_array_get_items( + const tsk_bit_array_t *self, tsk_id_t *items, tsk_size_t *n_items) +{ + // Get the items stored in the row of a bitset. + // Uses a de Bruijn sequence lookup table to determine the lowest bit set. See the + // wikipedia article for more info: https://w.wiki/BYiF + + tsk_size_t i, n, off; + tsk_bit_array_value_t v, lsb; // least significant bit + static const tsk_id_t lookup[32] = { 0, 1, 28, 2, 29, 14, 24, 3, 30, 22, 20, 15, 25, + 17, 4, 8, 31, 27, 13, 23, 21, 19, 16, 7, 26, 12, 18, 6, 11, 5, 10, 9 }; + + n = 0; + for (i = 0; i < self->size; i++) { + v = self->data[i]; + off = i * ((tsk_size_t) TSK_BIT_ARRAY_NUM_BITS); + if (v == 0) { + continue; + } + while ((lsb = v & -v)) { + items[n] = lookup[(lsb * 0x077cb531U) >> 27] + (tsk_id_t) off; + n++; + v ^= lsb; + } + } + *n_items = n; +} + +void +tsk_bit_array_free(tsk_bit_array_t *self) +{ + tsk_safe_free(self->data); +} diff --git a/subprojects/tskit/tskit/core.h b/subprojects/tskit/tskit/core.h index b8b9f354b..7dd24eba5 100644 --- a/subprojects/tskit/tskit/core.h +++ b/subprojects/tskit/tskit/core.h @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2019-2023 Tskit Developers + * Copyright (c) 2019-2025 Tskit Developers * Copyright (c) 2015-2018 University of Oxford * * Permission is hereby granted, free of charge, to any person obtaining a copy @@ -147,12 +147,12 @@ sizes and types of externally visible structs. The library minor version. 
Incremented when non-breaking backward-compatible changes to the API or ABI are introduced, i.e., the addition of a new function. */ -#define TSK_VERSION_MINOR 1 +#define TSK_VERSION_MINOR 2 /** The library patch version. Incremented when any changes not relevant to the to the API or ABI are introduced, i.e., internal refactors of bugfixes. */ -#define TSK_VERSION_PATCH 2 +#define TSK_VERSION_PATCH 0 /** @} */ /* @@ -283,7 +283,7 @@ A file could not be read because it is in the wrong format /** The file is in tskit format, but the version is too old for the library to read. The file should be upgraded to the latest version -using the ``tskit upgrade`` command line utility. +using the ``tskit upgrade`` command line utility from tskit version<0.6.2. */ #define TSK_ERR_FILE_VERSION_TOO_OLD -101 /** @@ -370,6 +370,11 @@ One of the rows in the retained table refers to a row that has been deleted. */ #define TSK_ERR_KEEP_ROWS_MAP_TO_DELETED -212 +/** +A genomic position was less than zero or greater equal to the sequence +length +*/ +#define TSK_ERR_POSITION_OUT_OF_BOUNDS -213 /** @} */ @@ -500,6 +505,17 @@ the edge on which it occurs, and wasn't TSK_UNKNOWN_TIME. A single site had a mixture of known mutation times and TSK_UNKNOWN_TIME */ #define TSK_ERR_MUTATION_TIME_HAS_BOTH_KNOWN_AND_UNKNOWN -509 +/** +Some mutations have TSK_UNKNOWN_TIME in an algorithm where that's +disallowed (use compute_mutation_times?). +*/ +#define TSK_ERR_DISALLOWED_UNKNOWN_MUTATION_TIME -510 + +/** +A mutation's parent was not consistent with the topology of the tree. + */ +#define TSK_ERR_BAD_MUTATION_PARENT -511 + /** @} */ /** @@ -675,6 +691,72 @@ Statistics based on branch lengths were attempted when the ``time_units`` were ``uncalibrated``. */ #define TSK_ERR_TIME_UNCALIBRATED -910 +/** +The TSK_STAT_POLARISED option was passed to a statistic that does +not support it. 
+*/ +#define TSK_ERR_STAT_POLARISED_UNSUPPORTED -911 +/** +The TSK_STAT_SPAN_NORMALISE option was passed to a statistic that does +not support it. +*/ +#define TSK_ERR_STAT_SPAN_NORMALISE_UNSUPPORTED -912 +/** +Insufficient weights were provided. +*/ +#define TSK_ERR_INSUFFICIENT_WEIGHTS -913 +/** +The node bin map contains a value less than TSK_NULL. +*/ +#define TSK_ERR_BAD_NODE_BIN_MAP -914 +/** +Maximum index in node bin map is greater than output dimension. +*/ +#define TSK_ERR_BAD_NODE_BIN_MAP_DIM -915 +/** +The vector of quantiles is out of bounds or in nonascending order. +*/ +#define TSK_ERR_BAD_QUANTILES -916 +/** +Times are not in ascending order +*/ +#define TSK_ERR_UNSORTED_TIMES -917 +/* +The provided positions are not provided in strictly increasing order +*/ +#define TSK_ERR_STAT_UNSORTED_POSITIONS -918 +/** +The provided positions are not unique +*/ +#define TSK_ERR_STAT_DUPLICATE_POSITIONS -919 +/** +The provided sites are not provided in strictly increasing position order +*/ +#define TSK_ERR_STAT_UNSORTED_SITES -920 +/** +The provided sites are not unique +*/ +#define TSK_ERR_STAT_DUPLICATE_SITES -921 +/** +The number of time windows is zero +*/ +#define TSK_ERR_BAD_TIME_WINDOWS_DIM -922 +/** +Sample times do not all equal the start of first time window +*/ +#define TSK_ERR_BAD_SAMPLE_PAIR_TIMES -923 +/** +Time windows are not strictly increasing +*/ +#define TSK_ERR_BAD_TIME_WINDOWS -924 +/** +Time windows do not end at infinity +*/ +#define TSK_ERR_BAD_TIME_WINDOWS_END -925 +/** +Node time does not fall within assigned time window +*/ +#define TSK_ERR_BAD_NODE_TIME_WINDOW -926 /** @} */ /** @@ -851,6 +933,16 @@ An individual had nodes from more than one time */ #define TSK_ERR_INDIVIDUAL_TIME_MISMATCH -1704 /** @} */ + +/** +@defgroup EXTEND_EDGES_ERROR_GROUP Extend edges errors. +@{ +*/ +/** +Maximum iteration number (max_iter) must be positive. 
+*/ +#define TSK_ERR_EXTEND_EDGES_BAD_MAXITER -1800 +/** @} */ // clang-format on /* This bit is 0 for any errors originating from kastore */ @@ -871,6 +963,21 @@ not be freed by client code. */ const char *tsk_strerror(int err); +#ifdef TSK_TRACE_ERRORS + +static inline int +_tsk_trace_error(int err, int line, const char *file) +{ + fprintf(stderr, "tskit-trace-error: %d='%s' at line %d in %s\n", err, + tsk_strerror(err), line, file); + return err; +} + +#define tsk_trace_error(err) (_tsk_trace_error(err, __LINE__, __FILE__)) +#else +#define tsk_trace_error(err) (err) +#endif + #ifndef TSK_BUG_ASSERT_MESSAGE #define TSK_BUG_ASSERT_MESSAGE \ "If you are using tskit directly please open an issue on" \ @@ -995,6 +1102,32 @@ int tsk_memcmp(const void *s1, const void *s2, tsk_size_t size); void tsk_set_debug_stream(FILE *f); FILE *tsk_get_debug_stream(void); +/* Bit Array functionality */ + +typedef uint32_t tsk_bit_array_value_t; +typedef struct { + tsk_size_t size; // Number of chunks per row + tsk_bit_array_value_t *data; // Array data +} tsk_bit_array_t; + +#define TSK_BIT_ARRAY_CHUNK 5U +#define TSK_BIT_ARRAY_NUM_BITS (1U << TSK_BIT_ARRAY_CHUNK) + +int tsk_bit_array_init(tsk_bit_array_t *self, tsk_size_t num_bits, tsk_size_t length); +void tsk_bit_array_free(tsk_bit_array_t *self); +void tsk_bit_array_get_row( + const tsk_bit_array_t *self, tsk_size_t row, tsk_bit_array_t *out); +void tsk_bit_array_intersect( + const tsk_bit_array_t *self, const tsk_bit_array_t *other, tsk_bit_array_t *out); +void tsk_bit_array_subtract(tsk_bit_array_t *self, const tsk_bit_array_t *other); +void tsk_bit_array_add(tsk_bit_array_t *self, const tsk_bit_array_t *other); +void tsk_bit_array_add_bit(tsk_bit_array_t *self, const tsk_bit_array_value_t bit); +bool tsk_bit_array_contains( + const tsk_bit_array_t *self, const tsk_bit_array_value_t bit); +tsk_size_t tsk_bit_array_count(const tsk_bit_array_t *self); +void tsk_bit_array_get_items( + const tsk_bit_array_t *self, tsk_id_t *items, 
tsk_size_t *n_items); + #ifdef __cplusplus } #endif diff --git a/subprojects/tskit/tskit/genotypes.c b/subprojects/tskit/tskit/genotypes.c index d4d1ecb08..c2385281b 100644 --- a/subprojects/tskit/tskit/genotypes.c +++ b/subprojects/tskit/tskit/genotypes.c @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2019-2022 Tskit Developers + * Copyright (c) 2019-2025 Tskit Developers * Copyright (c) 2016-2018 University of Oxford * * Permission is hereby granted, free of charge, to any person obtaining a copy @@ -74,7 +74,7 @@ tsk_variant_copy_alleles(tsk_variant_t *self, const char **alleles) } self->user_alleles_mem = tsk_malloc(total_len * sizeof(char *)); if (self->user_alleles_mem == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } offset = 0; @@ -103,7 +103,7 @@ variant_init_samples_and_index_map(tsk_variant_t *self, self->alt_sample_index_map = tsk_malloc(num_nodes * sizeof(*self->alt_sample_index_map)); if (self->alt_samples == NULL || self->alt_sample_index_map == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } tsk_memcpy(self->alt_samples, samples, num_samples * sizeof(*samples)); @@ -113,16 +113,16 @@ variant_init_samples_and_index_map(tsk_variant_t *self, for (j = 0; j < num_samples; j++) { u = samples[j]; if (u < 0 || u >= (tsk_id_t) num_nodes) { - ret = TSK_ERR_NODE_OUT_OF_BOUNDS; + ret = tsk_trace_error(TSK_ERR_NODE_OUT_OF_BOUNDS); goto out; } if (self->alt_sample_index_map[u] != TSK_NULL) { - ret = TSK_ERR_DUPLICATE_SAMPLE; + ret = tsk_trace_error(TSK_ERR_DUPLICATE_SAMPLE); goto out; } /* We can only detect missing data for samples */ if (!impute_missing && !(flags[u] & TSK_NODE_IS_SAMPLE)) { - ret = TSK_ERR_MUST_IMPUTE_NON_SAMPLES; + ret = tsk_trace_error(TSK_ERR_MUST_IMPUTE_NON_SAMPLES); goto out; } self->alt_sample_index_map[samples[j]] = (tsk_id_t) j; @@ -156,7 +156,7 @@ tsk_variant_init(tsk_variant_t *self, const tsk_treeseq_t *tree_sequence, /* Take a copy of the 
samples so we don't have to manage the lifecycle*/ self->samples = tsk_malloc(num_samples * sizeof(*samples)); if (self->samples == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } tsk_memcpy(self->samples, samples, num_samples * sizeof(*samples)); @@ -176,11 +176,11 @@ tsk_variant_init(tsk_variant_t *self, const tsk_treeseq_t *tree_sequence, for (max_alleles = 0; alleles[max_alleles] != NULL; max_alleles++) ; if (max_alleles > max_alleles_limit) { - ret = TSK_ERR_TOO_MANY_ALLELES; + ret = tsk_trace_error(TSK_ERR_TOO_MANY_ALLELES); goto out; } if (max_alleles == 0) { - ret = TSK_ERR_ZERO_ALLELES; + ret = tsk_trace_error(TSK_ERR_ZERO_ALLELES); goto out; } } @@ -188,7 +188,7 @@ tsk_variant_init(tsk_variant_t *self, const tsk_treeseq_t *tree_sequence, self->alleles = tsk_calloc(max_alleles, sizeof(*self->alleles)); self->allele_lengths = tsk_malloc(max_alleles * sizeof(*self->allele_lengths)); if (self->alleles == NULL || self->allele_lengths == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } if (self->user_alleles) { @@ -201,7 +201,7 @@ tsk_variant_init(tsk_variant_t *self, const tsk_treeseq_t *tree_sequence, self->num_samples = tsk_treeseq_get_num_samples(tree_sequence); self->samples = tsk_malloc(self->num_samples * sizeof(*self->samples)); if (self->samples == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } tsk_memcpy(self->samples, tsk_treeseq_get_samples(tree_sequence), @@ -224,7 +224,7 @@ tsk_variant_init(tsk_variant_t *self, const tsk_treeseq_t *tree_sequence, self->traversal_stack = tsk_malloc( tsk_treeseq_get_num_nodes(tree_sequence) * sizeof(*self->traversal_stack)); if (self->traversal_stack == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } } @@ -232,7 +232,7 @@ tsk_variant_init(tsk_variant_t *self, const tsk_treeseq_t *tree_sequence, self->genotypes = tsk_malloc(num_samples_alloc * 
sizeof(*self->genotypes)); if (self->genotypes == NULL || self->alleles == NULL || self->allele_lengths == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } @@ -293,20 +293,20 @@ tsk_variant_expand_alleles(tsk_variant_t *self) tsk_size_t hard_limit = INT32_MAX; if (self->max_alleles == hard_limit) { - ret = TSK_ERR_TOO_MANY_ALLELES; + ret = tsk_trace_error(TSK_ERR_TOO_MANY_ALLELES); goto out; } self->max_alleles = TSK_MIN(hard_limit, self->max_alleles * 2); p = tsk_realloc(self->alleles, self->max_alleles * sizeof(*self->alleles)); if (p == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } self->alleles = p; p = tsk_realloc( self->allele_lengths, self->max_alleles * sizeof(*self->allele_lengths)); if (p == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } self->allele_lengths = p; @@ -470,7 +470,7 @@ tsk_variant_decode( tsk_size_t (*mark_missing)(tsk_variant_t *); if (self->tree_sequence == NULL) { - ret = TSK_ERR_VARIANT_CANT_DECODE_COPY; + ret = tsk_trace_error(TSK_ERR_VARIANT_CANT_DECODE_COPY); goto out; } @@ -487,7 +487,7 @@ tsk_variant_decode( /* When we have no specified samples we need sample lists to be active * on the tree, as indicated by the presence of left_sample */ if (!by_traversal && self->tree.left_sample == NULL) { - ret = TSK_ERR_NO_SAMPLE_LISTS; + ret = tsk_trace_error(TSK_ERR_NO_SAMPLE_LISTS); goto out; } @@ -508,7 +508,7 @@ tsk_variant_decode( allele_index = tsk_variant_get_allele_index( self, self->site.ancestral_state, self->site.ancestral_state_length); if (allele_index == -1) { - ret = TSK_ERR_ALLELE_NOT_FOUND; + ret = tsk_trace_error(TSK_ERR_ALLELE_NOT_FOUND); goto out; } } else { @@ -545,7 +545,7 @@ tsk_variant_decode( self, mutation.derived_state, mutation.derived_state_length); if (allele_index == -1) { if (self->user_alleles) { - ret = TSK_ERR_ALLELE_NOT_FOUND; + ret = tsk_trace_error(TSK_ERR_ALLELE_NOT_FOUND); 
goto out; } if (self->num_alleles == self->max_alleles) { @@ -606,7 +606,7 @@ tsk_variant_restricted_copy(const tsk_variant_t *self, tsk_variant_t *other) if (other->samples == NULL || other->genotypes == NULL || other->user_alleles_mem == NULL || other->allele_lengths == NULL || other->alleles == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } tsk_memcpy( diff --git a/subprojects/tskit/tskit/haplotype_matching.c b/subprojects/tskit/tskit/haplotype_matching.c index b942da18d..a8acd204b 100644 --- a/subprojects/tskit/tskit/haplotype_matching.c +++ b/subprojects/tskit/tskit/haplotype_matching.c @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2019-2023 Tskit Developers + * Copyright (c) 2019-2025 Tskit Developers * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -167,7 +167,7 @@ tsk_ls_hmm_init(tsk_ls_hmm_t *self, tsk_treeseq_t *tree_sequence, || self->transition_time_order == NULL || self->values == NULL || self->recombination_rate == NULL || self->mutation_rate == NULL || self->alleles == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } for (l = 0; l < self->num_sites; l++) { @@ -209,7 +209,6 @@ int tsk_ls_hmm_free(tsk_ls_hmm_t *self) { tsk_tree_free(&self->tree); - tsk_diff_iter_free(&self->diffs); tsk_safe_free(self->recombination_rate); tsk_safe_free(self->mutation_rate); tsk_safe_free(self->recombination_rate); @@ -230,10 +229,9 @@ tsk_ls_hmm_free(tsk_ls_hmm_t *self) } static int -tsk_ls_hmm_reset(tsk_ls_hmm_t *self) +tsk_ls_hmm_reset(tsk_ls_hmm_t *self, double value) { int ret = 0; - double n = (double) self->num_samples; tsk_size_t j; tsk_id_t u; const tsk_id_t *samples; @@ -248,21 +246,14 @@ tsk_ls_hmm_reset(tsk_ls_hmm_t *self) tsk_memset(self->transition_parent, 0xff, self->max_transitions * sizeof(*self->transition_parent)); - /* This is safe because we've 
already zero'd out the memory. */ - tsk_diff_iter_free(&self->diffs); - ret = tsk_diff_iter_init_from_ts(&self->diffs, self->tree_sequence, false); - if (ret != 0) { - goto out; - } samples = tsk_treeseq_get_samples(self->tree_sequence); for (j = 0; j < self->num_samples; j++) { u = samples[j]; self->transitions[j].tree_node = u; - self->transitions[j].value = 1.0 / n; + self->transitions[j].value = value; self->transition_index[u] = (tsk_id_t) j; } self->num_transitions = self->num_samples; -out: return ret; } @@ -301,26 +292,23 @@ tsk_ls_hmm_remove_dead_roots(tsk_ls_hmm_t *self) } static int -tsk_ls_hmm_update_tree(tsk_ls_hmm_t *self) +tsk_ls_hmm_update_tree(tsk_ls_hmm_t *self, int direction) { int ret = 0; tsk_id_t *restrict parent = self->parent; tsk_id_t *restrict T_index = self->transition_index; + const tsk_id_t *restrict edges_child = self->tree_sequence->tables->edges.child; + const tsk_id_t *restrict edges_parent = self->tree_sequence->tables->edges.parent; tsk_value_transition_t *restrict T = self->transitions; - tsk_edge_list_node_t *record; - tsk_edge_list_t records_out, records_in; - tsk_edge_t edge; - double left, right; - tsk_id_t u; + tsk_id_t u, c, p, j, e; tsk_value_transition_t *vt; + tsk_tree_position_t tree_pos; - ret = tsk_diff_iter_next(&self->diffs, &left, &right, &records_out, &records_in); - if (ret < 0) { - goto out; - } - - for (record = records_out.head; record != NULL; record = record->next) { - u = record->edge.child; + tree_pos = self->tree.tree_pos; + for (j = tree_pos.out.start; j != tree_pos.out.stop; j += direction) { + e = tree_pos.out.order[j]; + c = edges_child[e]; + u = c; if (T_index[u] == TSK_NULL) { /* Ensure the subtree we're detaching has a transition at the root */ while (T_index[u] == TSK_NULL) { @@ -328,25 +316,27 @@ tsk_ls_hmm_update_tree(tsk_ls_hmm_t *self) tsk_bug_assert(u != TSK_NULL); } tsk_bug_assert(self->num_transitions < self->max_transitions); - T_index[record->edge.child] = (tsk_id_t) 
self->num_transitions; - T[self->num_transitions].tree_node = record->edge.child; + T_index[c] = (tsk_id_t) self->num_transitions; + T[self->num_transitions].tree_node = c; T[self->num_transitions].value = T[T_index[u]].value; self->num_transitions++; } - parent[record->edge.child] = TSK_NULL; + parent[c] = TSK_NULL; } - for (record = records_in.head; record != NULL; record = record->next) { - edge = record->edge; - parent[edge.child] = edge.parent; - u = edge.parent; - if (parent[edge.parent] == TSK_NULL) { + for (j = tree_pos.in.start; j != tree_pos.in.stop; j += direction) { + e = tree_pos.in.order[j]; + c = edges_child[e]; + p = edges_parent[e]; + parent[c] = p; + u = p; + if (parent[p] == TSK_NULL) { /* Grafting onto a new root. */ - if (T_index[record->edge.parent] == TSK_NULL) { - T_index[edge.parent] = (tsk_id_t) self->num_transitions; + if (T_index[p] == TSK_NULL) { + T_index[p] = (tsk_id_t) self->num_transitions; tsk_bug_assert(self->num_transitions < self->max_transitions); - T[self->num_transitions].tree_node = edge.parent; - T[self->num_transitions].value = T[T_index[edge.child]].value; + T[self->num_transitions].tree_node = p; + T[self->num_transitions].value = T[T_index[c]].value; self->num_transitions++; } } else { @@ -356,18 +346,17 @@ tsk_ls_hmm_update_tree(tsk_ls_hmm_t *self) } tsk_bug_assert(u != TSK_NULL); } - tsk_bug_assert(T_index[u] != -1 && T_index[edge.child] != -1); - if (T[T_index[u]].value == T[T_index[edge.child]].value) { - vt = &T[T_index[edge.child]]; + tsk_bug_assert(T_index[u] != -1 && T_index[c] != -1); + if (T[T_index[u]].value == T[T_index[c]].value) { + vt = &T[T_index[c]]; /* Mark the value transition as unusued */ vt->value = -1; vt->tree_node = TSK_NULL; - T_index[edge.child] = TSK_NULL; + T_index[c] = TSK_NULL; } } ret = tsk_ls_hmm_remove_dead_roots(self); -out: return ret; } @@ -375,6 +364,8 @@ static int tsk_ls_hmm_get_allele_index(tsk_ls_hmm_t *self, tsk_id_t site, const char *allele_state, const tsk_size_t 
allele_length) { + /* Note we're not doing tsk_trace_error here because it would require changing + * the logic of the function. Could be done easily enough, though */ int ret = TSK_ERR_ALLELE_NOT_FOUND; const char **alleles = self->alleles[site]; const tsk_id_t num_alleles = (tsk_id_t) self->num_alleles[site]; @@ -721,7 +712,7 @@ tsk_ls_hmm_setup_optimal_value_sets(tsk_ls_hmm_t *self) * worth the bother. */ self->num_optimal_value_set_words = (self->num_values / 64) + 1; if (self->num_optimal_value_set_words > self->max_parsimony_words) { - ret = TSK_ERR_TOO_MANY_VALUES; + ret = tsk_trace_error(TSK_ERR_TOO_MANY_VALUES); goto out; } if (self->num_values >= self->max_values) { @@ -731,7 +722,7 @@ tsk_ls_hmm_setup_optimal_value_sets(tsk_ls_hmm_t *self) = tsk_calloc(self->num_nodes * self->num_optimal_value_set_words, sizeof(*self->optimal_value_sets)); if (self->optimal_value_sets == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } } @@ -921,7 +912,7 @@ tsk_ls_hmm_compress(tsk_ls_hmm_t *self) } static int -tsk_ls_hmm_process_site( +tsk_ls_hmm_process_site_forward( tsk_ls_hmm_t *self, const tsk_site_t *site, int32_t haplotype_state) { int ret = 0; @@ -945,7 +936,7 @@ tsk_ls_hmm_process_site( normalisation_factor = self->compute_normalisation_factor(self); if (normalisation_factor == 0) { - ret = TSK_ERR_MATCH_IMPOSSIBLE; + ret = tsk_trace_error(TSK_ERR_MATCH_IMPOSSIBLE); goto out; } for (j = 0; j < self->num_transitions; j++) { @@ -960,28 +951,23 @@ tsk_ls_hmm_process_site( return ret; } -int -tsk_ls_hmm_run(tsk_ls_hmm_t *self, int32_t *haplotype, - int (*next_probability)(tsk_ls_hmm_t *, tsk_id_t, double, bool, tsk_id_t, double *), - double (*compute_normalisation_factor)(struct _tsk_ls_hmm_t *), void *output) +static int +tsk_ls_hmm_run_forward(tsk_ls_hmm_t *self, int32_t *haplotype) { int ret = 0; int t_ret; const tsk_site_t *sites; tsk_size_t j, num_sites; + const double n = (double) self->num_samples; - 
self->next_probability = next_probability; - self->compute_normalisation_factor = compute_normalisation_factor; - self->output = output; - - ret = tsk_ls_hmm_reset(self); + ret = tsk_ls_hmm_reset(self, 1 / n); if (ret != 0) { goto out; } for (t_ret = tsk_tree_first(&self->tree); t_ret == TSK_TREE_OK; t_ret = tsk_tree_next(&self->tree)) { - ret = tsk_ls_hmm_update_tree(self); + ret = tsk_ls_hmm_update_tree(self, TSK_DIR_FORWARD); if (ret != 0) { goto out; } @@ -991,7 +977,8 @@ tsk_ls_hmm_run(tsk_ls_hmm_t *self, int32_t *haplotype, goto out; } for (j = 0; j < num_sites; j++) { - ret = tsk_ls_hmm_process_site(self, &sites[j], haplotype[sites[j].id]); + ret = tsk_ls_hmm_process_site_forward( + self, &sites[j], haplotype[sites[j].id]); if (ret != 0) { goto out; } @@ -1073,7 +1060,7 @@ tsk_ls_hmm_forward(tsk_ls_hmm_t *self, int32_t *haplotype, } } else { if (output->tree_sequence != self->tree_sequence) { - ret = TSK_ERR_BAD_PARAM_VALUE; + ret = tsk_trace_error(TSK_ERR_BAD_PARAM_VALUE); goto out; } ret = tsk_compressed_matrix_clear(output); @@ -1081,11 +1068,170 @@ tsk_ls_hmm_forward(tsk_ls_hmm_t *self, int32_t *haplotype, goto out; } } - ret = tsk_ls_hmm_run(self, haplotype, tsk_ls_hmm_next_probability_forward, - tsk_ls_hmm_compute_normalisation_factor_forward, output); + + self->next_probability = tsk_ls_hmm_next_probability_forward; + self->compute_normalisation_factor = tsk_ls_hmm_compute_normalisation_factor_forward; + self->output = output; + + ret = tsk_ls_hmm_run_forward(self, haplotype); +out: + return ret; +} + +/**************************************************************** + * Backward Algorithm + ****************************************************************/ + +static int +tsk_ls_hmm_next_probability_backward(tsk_ls_hmm_t *self, tsk_id_t site_id, double p_last, + bool is_match, tsk_id_t TSK_UNUSED(node), double *result) +{ + const double mu = self->mutation_rate[site_id]; + const double num_alleles = self->num_alleles[site_id]; + double p_e; + + p_e = 
mu; + if (is_match) { + p_e = 1 - (num_alleles - 1) * mu; + } + *result = p_last * p_e; + return 0; +} + +static int +tsk_ls_hmm_process_site_backward(tsk_ls_hmm_t *self, const tsk_site_t *site, + const int32_t haplotype_state, const double normalisation_factor) +{ + int ret = 0; + double x, b_last_sum; + tsk_compressed_matrix_t *output = (tsk_compressed_matrix_t *) self->output; + tsk_value_transition_t *restrict T = self->transitions; + const unsigned int precision = (unsigned int) self->precision; + const double rho = self->recombination_rate[site->id]; + const double n = (double) self->num_samples; + tsk_size_t j; + + /* FIXME!!! We are calling compress twice here because we need to compress + * immediately before calling store_site in order to filter out -1 nodes, + * and also (crucially) to ensure that the value transitions are listed + * in preorder, which we rely on later for decoding. + * + * https://github.com/tskit-dev/tskit/issues/2803 + */ + ret = tsk_ls_hmm_compress(self); + if (ret != 0) { + goto out; + } + ret = tsk_compressed_matrix_store_site( + output, site->id, normalisation_factor, (tsk_size_t) self->num_transitions, T); if (ret != 0) { goto out; } + + ret = tsk_ls_hmm_update_probabilities(self, site, haplotype_state); + if (ret != 0) { + goto out; + } + /* DO WE NEED THIS compress?? 
See above */ + ret = tsk_ls_hmm_compress(self); + if (ret != 0) { + goto out; + } + tsk_bug_assert(self->num_transitions <= self->num_samples); + b_last_sum = self->compute_normalisation_factor(self); + for (j = 0; j < self->num_transitions; j++) { + tsk_bug_assert(T[j].tree_node != TSK_NULL); + x = rho * b_last_sum / n + (1 - rho) * T[j].value; + x /= normalisation_factor; + T[j].value = tsk_round(x, precision); + } +out: + return ret; +} + +static int +tsk_ls_hmm_run_backward( + tsk_ls_hmm_t *self, int32_t *haplotype, const double *forward_norm) +{ + int ret = 0; + int t_ret; + const tsk_site_t *sites; + double s; + tsk_size_t num_sites; + tsk_id_t j; + + ret = tsk_ls_hmm_reset(self, 1); + if (ret != 0) { + goto out; + } + + for (t_ret = tsk_tree_last(&self->tree); t_ret == TSK_TREE_OK; + t_ret = tsk_tree_prev(&self->tree)) { + ret = tsk_ls_hmm_update_tree(self, TSK_DIR_REVERSE); + if (ret != 0) { + goto out; + } + /* tsk_ls_hmm_check_state(self); */ + ret = tsk_tree_get_sites(&self->tree, &sites, &num_sites); + if (ret != 0) { + goto out; + } + for (j = (tsk_id_t) num_sites - 1; j >= 0; j--) { + s = forward_norm[sites[j].id]; + if (s <= 0) { + /* NOTE: I'm not sure if this is the correct interpretation, + * but norm values of 0 do lead to problems, and this seems + * like a simple way of guarding against it. We do seem to + * get norm values of 0 with impossible matches from the fwd + * matrix. + */ + ret = tsk_trace_error(TSK_ERR_MATCH_IMPOSSIBLE); + goto out; + } + ret = tsk_ls_hmm_process_site_backward( + self, &sites[j], haplotype[sites[j].id], s); + if (ret != 0) { + goto out; + } + } + } + /* Set to zero so we can print and check the state OK. 
*/ + self->num_transitions = 0; + if (t_ret != 0) { + ret = t_ret; + goto out; + } +out: + return ret; +} + +int +tsk_ls_hmm_backward(tsk_ls_hmm_t *self, int32_t *haplotype, const double *forward_norm, + tsk_compressed_matrix_t *output, tsk_flags_t options) +{ + int ret = 0; + + if (!(options & TSK_NO_INIT)) { + ret = tsk_compressed_matrix_init(output, self->tree_sequence, 0, 0); + if (ret != 0) { + goto out; + } + } else { + if (output->tree_sequence != self->tree_sequence) { + ret = tsk_trace_error(TSK_ERR_BAD_PARAM_VALUE); + goto out; + } + ret = tsk_compressed_matrix_clear(output); + if (ret != 0) { + goto out; + } + } + + self->next_probability = tsk_ls_hmm_next_probability_backward; + self->compute_normalisation_factor = tsk_ls_hmm_compute_normalisation_factor_forward; + self->output = output; + + ret = tsk_ls_hmm_run_backward(self, haplotype, forward_norm); out: return ret; } @@ -1155,7 +1301,7 @@ tsk_ls_hmm_viterbi(tsk_ls_hmm_t *self, int32_t *haplotype, tsk_viterbi_matrix_t } } else { if (output->matrix.tree_sequence != self->tree_sequence) { - ret = TSK_ERR_BAD_PARAM_VALUE; + ret = tsk_trace_error(TSK_ERR_BAD_PARAM_VALUE); goto out; } ret = tsk_viterbi_matrix_clear(output); @@ -1163,11 +1309,12 @@ tsk_ls_hmm_viterbi(tsk_ls_hmm_t *self, int32_t *haplotype, tsk_viterbi_matrix_t goto out; } } - ret = tsk_ls_hmm_run(self, haplotype, tsk_ls_hmm_next_probability_viterbi, - tsk_ls_hmm_compute_normalisation_factor_viterbi, output); - if (ret != 0) { - goto out; - } + + self->next_probability = tsk_ls_hmm_next_probability_viterbi; + self->compute_normalisation_factor = tsk_ls_hmm_compute_normalisation_factor_viterbi; + self->output = output; + + ret = tsk_ls_hmm_run_forward(self, haplotype); out: return ret; } @@ -1193,7 +1340,7 @@ tsk_compressed_matrix_init(tsk_compressed_matrix_t *self, tsk_treeseq_t *tree_se self->values = tsk_malloc(self->num_sites * sizeof(*self->values)); self->nodes = tsk_malloc(self->num_sites * sizeof(*self->nodes)); if 
(self->num_transitions == NULL || self->values == NULL || self->nodes == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } if (block_size == 0) { @@ -1264,7 +1411,7 @@ tsk_compressed_matrix_store_site(tsk_compressed_matrix_t *self, tsk_id_t site, tsk_size_t j; if (site < 0 || site >= (tsk_id_t) self->num_sites) { - ret = TSK_ERR_SITE_OUT_OF_BOUNDS; + ret = tsk_trace_error(TSK_ERR_SITE_OUT_OF_BOUNDS); goto out; } @@ -1275,14 +1422,16 @@ tsk_compressed_matrix_store_site(tsk_compressed_matrix_t *self, tsk_id_t site, self->values[site] = tsk_blkalloc_get(&self->memory, (size_t) num_transitions * sizeof(double)); if (self->nodes[site] == NULL || self->values[site] == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } for (j = 0; j < num_transitions; j++) { + tsk_bug_assert(transitions[j].tree_node >= 0); self->values[site][j] = transitions[j].value; self->nodes[site][j] = transitions[j].tree_node; } + out: return ret; } @@ -1303,14 +1452,14 @@ tsk_compressed_matrix_decode_site(tsk_compressed_matrix_t *self, const tsk_tree_ for (j = 0; j < self->num_transitions[site]; j++) { node = self->nodes[site][j]; if (node < 0 || node >= num_nodes) { - ret = TSK_ERR_NODE_OUT_OF_BOUNDS; + ret = tsk_trace_error(TSK_ERR_NODE_OUT_OF_BOUNDS); goto out; } value = self->values[site][j]; index = list_left[node]; if (index == TSK_NULL) { /* It's an error if there are nodes that don't subtend any samples */ - ret = TSK_ERR_BAD_COMPRESSED_MATRIX_NODE; + ret = tsk_trace_error(TSK_ERR_BAD_COMPRESSED_MATRIX_NODE); goto out; } stop = list_right[node]; @@ -1383,7 +1532,7 @@ tsk_viterbi_matrix_expand_recomb_records(tsk_viterbi_matrix_t *self) self->recombination_required, self->max_recomb_records * sizeof(*tmp)); if (tmp == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } self->recombination_required = tmp; @@ -1498,7 +1647,7 @@ tsk_viterbi_matrix_choose_sample( bool found; if 
(num_transitions == 0) { - ret = TSK_ERR_NULL_VITERBI_MATRIX; + ret = tsk_trace_error(TSK_ERR_NULL_VITERBI_MATRIX); goto out; } for (j = 0; j < num_transitions; j++) { @@ -1552,7 +1701,7 @@ tsk_viterbi_matrix_traceback( goto out; } if (recombination_tree == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } /* Initialise the path an recombination_tree to contain TSK_NULL */ diff --git a/subprojects/tskit/tskit/haplotype_matching.h b/subprojects/tskit/tskit/haplotype_matching.h index 46631fb08..151eb321b 100644 --- a/subprojects/tskit/tskit/haplotype_matching.h +++ b/subprojects/tskit/tskit/haplotype_matching.h @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2019-2022 Tskit Developers + * Copyright (c) 2019-2024 Tskit Developers * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -98,7 +98,6 @@ typedef struct _tsk_ls_hmm_t { tsk_size_t num_nodes; /* state */ tsk_tree_t tree; - tsk_diff_iter_t diffs; tsk_id_t *parent; /* The probability value transitions on the tree */ tsk_value_transition_t *transitions; @@ -131,6 +130,7 @@ typedef struct _tsk_ls_hmm_t { void *output; } tsk_ls_hmm_t; +/* TODO constify these APIs */ int tsk_ls_hmm_init(tsk_ls_hmm_t *self, tsk_treeseq_t *tree_sequence, double *recombination_rate, double *mutation_rate, tsk_flags_t options); int tsk_ls_hmm_set_precision(tsk_ls_hmm_t *self, unsigned int precision); @@ -138,11 +138,10 @@ int tsk_ls_hmm_free(tsk_ls_hmm_t *self); void tsk_ls_hmm_print_state(tsk_ls_hmm_t *self, FILE *out); int tsk_ls_hmm_forward(tsk_ls_hmm_t *self, int32_t *haplotype, tsk_compressed_matrix_t *output, tsk_flags_t options); +int tsk_ls_hmm_backward(tsk_ls_hmm_t *self, int32_t *haplotype, + const double *forward_norm, tsk_compressed_matrix_t *output, tsk_flags_t options); int tsk_ls_hmm_viterbi(tsk_ls_hmm_t *self, int32_t *haplotype, tsk_viterbi_matrix_t *output, 
tsk_flags_t options); -int tsk_ls_hmm_run(tsk_ls_hmm_t *self, int32_t *haplotype, - int (*next_probability)(tsk_ls_hmm_t *, tsk_id_t, double, bool, tsk_id_t, double *), - double (*compute_normalisation_factor)(struct _tsk_ls_hmm_t *), void *output); int tsk_compressed_matrix_init(tsk_compressed_matrix_t *self, tsk_treeseq_t *tree_sequence, tsk_size_t block_size, tsk_flags_t options); diff --git a/subprojects/tskit/tskit/stats.c b/subprojects/tskit/tskit/stats.c index 1c1aeea68..e89ce0382 100644 --- a/subprojects/tskit/tskit/stats.c +++ b/subprojects/tskit/tskit/stats.c @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2018-2022 Tskit Developers + * Copyright (c) 2018-2025 Tskit Developers * Copyright (c) 2016-2017 University of Oxford * * Permission is hereby granted, free of charge, to any person obtaining a copy @@ -53,6 +53,7 @@ tsk_ld_calc_init(tsk_ld_calc_t *self, const tsk_treeseq_t *tree_sequence) self->sample_buffer = tsk_malloc(self->total_samples * sizeof(*self->sample_buffer)); if (self->sample_buffer == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } out: @@ -75,14 +76,14 @@ tsk_ld_calc_check_site(tsk_ld_calc_t *TSK_UNUSED(self), const tsk_site_t *site) /* These are both limitations in the current implementation, there's no * fundamental reason why we can't support them */ if (site->mutations_length != 1) { - ret = TSK_ERR_ONLY_INFINITE_SITES; + ret = tsk_trace_error(TSK_ERR_ONLY_INFINITE_SITES); goto out; } if (site->ancestral_state_length == site->mutations[0].derived_state_length && tsk_memcmp(site->ancestral_state, site->mutations[0].derived_state, site->ancestral_state_length) == 0) { - ret = TSK_ERR_SILENT_MUTATIONS_NOT_SUPPORTED; + ret = tsk_trace_error(TSK_ERR_SILENT_MUTATIONS_NOT_SUPPORTED); goto out; } out: @@ -294,7 +295,7 @@ tsk_ld_calc_get_r2_array(tsk_ld_calc_t *self, tsk_id_t a, int direction, } else if (direction == TSK_DIR_REVERSE) { ret = tsk_ld_calc_run_reverse(self); } else { - ret = TSK_ERR_BAD_PARAM_VALUE; + ret 
= tsk_trace_error(TSK_ERR_BAD_PARAM_VALUE); } if (ret != 0) { goto out; diff --git a/subprojects/tskit/tskit/tables.c b/subprojects/tskit/tskit/tables.c index 8eea85f5a..9805d669a 100644 --- a/subprojects/tskit/tskit/tables.c +++ b/subprojects/tskit/tskit/tables.c @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2019-2023 Tskit Developers + * Copyright (c) 2019-2025 Tskit Developers * Copyright (c) 2017-2018 University of Oxford * * Permission is hereby granted, free of charge, to any person obtaining a copy @@ -128,16 +128,16 @@ read_table_cols(kastore_t *store, tsk_size_t *num_rows, read_table_col_t *cols, *num_rows = (tsk_size_t) len; } else { if (*num_rows != (tsk_size_t) len) { - ret = TSK_ERR_FILE_FORMAT; + ret = tsk_trace_error(TSK_ERR_FILE_FORMAT); goto out; } } if (type != col->type) { - ret = TSK_ERR_BAD_COLUMN_TYPE; + ret = tsk_trace_error(TSK_ERR_BAD_COLUMN_TYPE); goto out; } } else if (!(col->options & TSK_COL_OPTIONAL)) { - ret = TSK_ERR_REQUIRED_COL_NOT_FOUND; + ret = tsk_trace_error(TSK_ERR_REQUIRED_COL_NOT_FOUND); goto out; } } @@ -154,7 +154,7 @@ cast_offset_array(read_table_ragged_col_t *col, uint32_t *source, tsk_size_t num uint64_t *dest = tsk_malloc(len * sizeof(*dest)); if (dest == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } *col->offset_array_dest = dest; @@ -193,13 +193,13 @@ read_table_ragged_cols(kastore_t *store, tsk_size_t *num_rows, goto out; } if (type != col->data_type) { - ret = TSK_ERR_BAD_COLUMN_TYPE; + ret = tsk_trace_error(TSK_ERR_BAD_COLUMN_TYPE); goto out; } *col->data_len_dest = (tsk_size_t) data_len; data_col_present = true; } else if (!(col->options & TSK_COL_OPTIONAL)) { - ret = TSK_ERR_REQUIRED_COL_NOT_FOUND; + ret = tsk_trace_error(TSK_ERR_REQUIRED_COL_NOT_FOUND); goto out; } @@ -214,7 +214,7 @@ read_table_ragged_cols(kastore_t *store, tsk_size_t *num_rows, } offset_col_present = ret == 1; if (offset_col_present != data_col_present) { - ret = 
TSK_ERR_BOTH_COLUMNS_REQUIRED; + ret = tsk_trace_error(TSK_ERR_BOTH_COLUMNS_REQUIRED); goto out; } if (offset_col_present) { @@ -227,7 +227,7 @@ read_table_ragged_cols(kastore_t *store, tsk_size_t *num_rows, /* A table with zero rows will still have an offset length of 1; * catching this here prevents underflows in the logic below */ if (offset_len == 0) { - ret = TSK_ERR_FILE_FORMAT; + ret = tsk_trace_error(TSK_ERR_FILE_FORMAT); goto out; } /* Some tables have only ragged columns */ @@ -235,7 +235,7 @@ read_table_ragged_cols(kastore_t *store, tsk_size_t *num_rows, *num_rows = (tsk_size_t) offset_len - 1; } else { if (*num_rows != (tsk_size_t) offset_len - 1) { - ret = TSK_ERR_FILE_FORMAT; + ret = tsk_trace_error(TSK_ERR_FILE_FORMAT); goto out; } } @@ -250,12 +250,12 @@ read_table_ragged_cols(kastore_t *store, tsk_size_t *num_rows, tsk_safe_free(store_offset_array); store_offset_array = NULL; } else { - ret = TSK_ERR_BAD_COLUMN_TYPE; + ret = tsk_trace_error(TSK_ERR_BAD_COLUMN_TYPE); goto out; } offset_array = *col->offset_array_dest; if (offset_array[*num_rows] != (tsk_size_t) data_len) { - ret = TSK_ERR_BAD_OFFSET; + ret = tsk_trace_error(TSK_ERR_BAD_OFFSET); goto out; } } @@ -288,7 +288,7 @@ read_table_properties( goto out; } if (type != property->type) { - ret = TSK_ERR_BAD_COLUMN_TYPE; + ret = tsk_trace_error(TSK_ERR_BAD_COLUMN_TYPE); goto out; } *property->len_dest = (tsk_size_t) len; @@ -320,7 +320,7 @@ read_table(kastore_t *store, tsk_size_t *num_rows, read_table_col_t *cols, } } if (*num_rows == TSK_NUM_ROWS_UNSET) { - ret = TSK_ERR_FILE_FORMAT; + ret = tsk_trace_error(TSK_ERR_FILE_FORMAT); goto out; } if (properties != NULL) { @@ -384,7 +384,7 @@ write_offset_col( } else { offset32 = tsk_malloc(len * sizeof(*offset32)); if (offset32 == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } for (j = 0; j < len; j++) { @@ -468,17 +468,20 @@ static int check_offsets( tsk_size_t num_rows, const tsk_size_t *offsets, 
tsk_size_t length, bool check_length) { - int ret = TSK_ERR_BAD_OFFSET; + int ret = 0; tsk_size_t j; if (offsets[0] != 0) { + ret = tsk_trace_error(TSK_ERR_BAD_OFFSET); goto out; } if (check_length && offsets[num_rows] != length) { + ret = tsk_trace_error(TSK_ERR_BAD_OFFSET); goto out; } for (j = 0; j < num_rows; j++) { if (offsets[j] > offsets[j + 1]) { + ret = tsk_trace_error(TSK_ERR_BAD_OFFSET); goto out; } } @@ -496,7 +499,7 @@ calculate_max_rows(tsk_size_t num_rows, tsk_size_t max_rows, int ret = 0; if (check_table_overflow(num_rows, additional_rows)) { - ret = TSK_ERR_TABLE_OVERFLOW; + ret = tsk_trace_error(TSK_ERR_TABLE_OVERFLOW); goto out; } @@ -517,7 +520,7 @@ calculate_max_rows(tsk_size_t num_rows, tsk_size_t max_rows, } else { /* Use user increment value */ if (check_table_overflow(max_rows, max_rows_increment)) { - ret = TSK_ERR_TABLE_OVERFLOW; + ret = tsk_trace_error(TSK_ERR_TABLE_OVERFLOW); goto out; } new_max_rows = max_rows + max_rows_increment; @@ -538,7 +541,7 @@ calculate_max_length(tsk_size_t current_length, tsk_size_t max_length, int ret = 0; if (check_offset_overflow(current_length, additional_length)) { - ret = TSK_ERR_COLUMN_OVERFLOW; + ret = tsk_trace_error(TSK_ERR_COLUMN_OVERFLOW); goto out; } @@ -564,7 +567,7 @@ calculate_max_length(tsk_size_t current_length, tsk_size_t max_length, * Instead we are erroring out as this is much easier to test. * The cost is that (at most) the last "max_length_increment"-1 * bytes of the possible array space can't be used. 
*/ - ret = TSK_ERR_COLUMN_OVERFLOW; + ret = tsk_trace_error(TSK_ERR_COLUMN_OVERFLOW); goto out; } new_max_length = max_length + max_length_increment; @@ -584,7 +587,7 @@ expand_column(void **column, tsk_size_t new_max_rows, size_t element_size) tmp = tsk_realloc((void **) *column, new_max_rows * element_size); if (tmp == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } *column = tmp; @@ -629,7 +632,7 @@ replace_string( if (new_len > 0) { *str = tsk_malloc(new_len * sizeof(char)); if (*str == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } tsk_memcpy(*str, new_str, new_len * sizeof(char)); @@ -655,7 +658,7 @@ alloc_empty_ragged_column(tsk_size_t num_rows, void **data_col, tsk_size_t **off *data_col = tsk_malloc(1); *offset_col = tsk_calloc(num_rows + 1, sizeof(tsk_size_t)); if (*data_col == NULL || *offset_col == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } out: @@ -667,7 +670,7 @@ check_ragged_column(tsk_size_t num_rows, void *data, tsk_size_t *offset) { int ret = 0; if ((data == NULL) != (offset == NULL)) { - ret = TSK_ERR_BAD_PARAM_VALUE; + ret = tsk_trace_error(TSK_ERR_BAD_PARAM_VALUE); goto out; } if (data != NULL) { @@ -710,7 +713,7 @@ takeset_optional_id_column(tsk_size_t num_rows, tsk_id_t *input, tsk_id_t **dest buffsize = num_rows * sizeof(*buff); buff = tsk_malloc(buffsize); if (buff == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } *dest = buff; @@ -1284,7 +1287,7 @@ tsk_individual_table_takeset_columns(tsk_individual_table_t *self, tsk_size_t nu * unused so this is a worthwhile optimisation. 
*/ self->flags = tsk_calloc(num_rows, sizeof(*self->flags)); if (self->flags == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } } else { @@ -1320,19 +1323,19 @@ tsk_individual_table_append_columns(tsk_individual_table_t *self, tsk_size_t num tsk_size_t j, metadata_length, location_length, parents_length; if (flags == NULL) { - ret = TSK_ERR_BAD_PARAM_VALUE; + ret = tsk_trace_error(TSK_ERR_BAD_PARAM_VALUE); goto out; } if ((location == NULL) != (location_offset == NULL)) { - ret = TSK_ERR_BAD_PARAM_VALUE; + ret = tsk_trace_error(TSK_ERR_BAD_PARAM_VALUE); goto out; } if ((parents == NULL) != (parents_offset == NULL)) { - ret = TSK_ERR_BAD_PARAM_VALUE; + ret = tsk_trace_error(TSK_ERR_BAD_PARAM_VALUE); goto out; } if ((metadata == NULL) != (metadata_offset == NULL)) { - ret = TSK_ERR_BAD_PARAM_VALUE; + ret = tsk_trace_error(TSK_ERR_BAD_PARAM_VALUE); goto out; } ret = tsk_individual_table_expand_main_columns(self, (tsk_size_t) num_rows); @@ -1490,7 +1493,7 @@ tsk_individual_table_update_row_rewrite(tsk_individual_table_t *self, tsk_id_t i } rows = tsk_malloc(self->num_rows * sizeof(*rows)); if (rows == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } @@ -1565,7 +1568,7 @@ tsk_individual_table_truncate(tsk_individual_table_t *self, tsk_size_t num_rows) int ret = 0; if (num_rows > self->num_rows) { - ret = TSK_ERR_BAD_TABLE_POSITION; + ret = tsk_trace_error(TSK_ERR_BAD_TABLE_POSITION); goto out; } self->num_rows = num_rows; @@ -1587,7 +1590,7 @@ tsk_individual_table_extend(tsk_individual_table_t *self, tsk_individual_t individual; if (self == other) { - ret = TSK_ERR_CANNOT_EXTEND_FROM_SELF; + ret = tsk_trace_error(TSK_ERR_CANNOT_EXTEND_FROM_SELF); goto out; } @@ -1689,7 +1692,7 @@ tsk_individual_table_get_row( int ret = 0; if (index < 0 || index >= (tsk_id_t) self->num_rows) { - ret = TSK_ERR_INDIVIDUAL_OUT_OF_BOUNDS; + ret = tsk_trace_error(TSK_ERR_INDIVIDUAL_OUT_OF_BOUNDS); goto out; } 
tsk_individual_table_get_row_unsafe(self, index, row); @@ -1819,7 +1822,7 @@ tsk_individual_table_keep_rows(tsk_individual_table_t *self, const tsk_bool_t *k if (ret_id_map == NULL) { id_map = tsk_malloc(current_num_rows * sizeof(*id_map)); if (id_map == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } } @@ -1834,12 +1837,12 @@ tsk_individual_table_keep_rows(tsk_individual_table_t *self, const tsk_bool_t *k pk = parents[k]; if (pk != TSK_NULL) { if (pk < 0 || pk >= (tsk_id_t) current_num_rows) { - ret = TSK_ERR_INDIVIDUAL_OUT_OF_BOUNDS; + ret = tsk_trace_error(TSK_ERR_INDIVIDUAL_OUT_OF_BOUNDS); ; goto out; } if (id_map[pk] == TSK_NULL) { - ret = TSK_ERR_KEEP_ROWS_MAP_TO_DELETED; + ret = tsk_trace_error(TSK_ERR_KEEP_ROWS_MAP_TO_DELETED); goto out; } } @@ -2118,7 +2121,7 @@ tsk_node_table_takeset_columns(tsk_node_table_t *self, tsk_size_t num_rows, /* We need to check all the inputs before we start freeing or taking memory */ if (flags == NULL || time == NULL) { - ret = TSK_ERR_BAD_PARAM_VALUE; + ret = tsk_trace_error(TSK_ERR_BAD_PARAM_VALUE); goto out; } ret = check_ragged_column(num_rows, metadata, metadata_offset); @@ -2159,11 +2162,11 @@ tsk_node_table_append_columns(tsk_node_table_t *self, tsk_size_t num_rows, tsk_size_t j, metadata_length; if (flags == NULL || time == NULL) { - ret = TSK_ERR_BAD_PARAM_VALUE; + ret = tsk_trace_error(TSK_ERR_BAD_PARAM_VALUE); goto out; } if ((metadata == NULL) != (metadata_offset == NULL)) { - ret = TSK_ERR_BAD_PARAM_VALUE; + ret = tsk_trace_error(TSK_ERR_BAD_PARAM_VALUE); goto out; } ret = tsk_node_table_expand_main_columns(self, num_rows); @@ -2270,7 +2273,7 @@ tsk_node_table_update_row_rewrite(tsk_node_table_t *self, tsk_id_t index, } rows = tsk_malloc(self->num_rows * sizeof(*rows)); if (rows == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } @@ -2341,7 +2344,7 @@ tsk_node_table_truncate(tsk_node_table_t *self, tsk_size_t num_rows) int ret = 
0; if (num_rows > self->num_rows) { - ret = TSK_ERR_BAD_TABLE_POSITION; + ret = tsk_trace_error(TSK_ERR_BAD_TABLE_POSITION); goto out; } self->num_rows = num_rows; @@ -2360,7 +2363,7 @@ tsk_node_table_extend(tsk_node_table_t *self, const tsk_node_table_t *other, tsk_node_t node; if (self == other) { - ret = TSK_ERR_CANNOT_EXTEND_FROM_SELF; + ret = tsk_trace_error(TSK_ERR_CANNOT_EXTEND_FROM_SELF); goto out; } @@ -2510,7 +2513,7 @@ tsk_node_table_get_row(const tsk_node_table_t *self, tsk_id_t index, tsk_node_t int ret = 0; if (index < 0 || index >= (tsk_id_t) self->num_rows) { - ret = TSK_ERR_NODE_OUT_OF_BOUNDS; + ret = tsk_trace_error(TSK_ERR_NODE_OUT_OF_BOUNDS); goto out; } tsk_node_table_get_row_unsafe(self, index, row); @@ -2752,7 +2755,7 @@ tsk_edge_table_add_row(tsk_edge_table_t *self, double left, double right, tsk_id_t ret = 0; if (metadata_length > 0 && !tsk_edge_table_has_metadata(self)) { - ret = TSK_ERR_METADATA_DISABLED; + ret = tsk_trace_error(TSK_ERR_METADATA_DISABLED); goto out; } @@ -2802,7 +2805,7 @@ tsk_edge_table_update_row_rewrite(tsk_edge_table_t *self, tsk_id_t index, double } rows = tsk_malloc(self->num_rows * sizeof(*rows)); if (rows == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } @@ -2882,7 +2885,7 @@ tsk_edge_table_copy( * This also captures the case where TSK_TABLE_NO_METADATA is set on this table. 
*/ if (self->metadata_length > 0 && !tsk_edge_table_has_metadata(dest)) { - ret = TSK_ERR_METADATA_DISABLED; + ret = tsk_trace_error(TSK_ERR_METADATA_DISABLED); goto out; } if (tsk_edge_table_has_metadata(dest)) { @@ -2926,11 +2929,11 @@ tsk_edge_table_takeset_columns(tsk_edge_table_t *self, tsk_size_t num_rows, doub /* We need to check all the inputs before we start freeing or taking memory */ if (left == NULL || right == NULL || parent == NULL || child == NULL) { - ret = TSK_ERR_BAD_PARAM_VALUE; + ret = tsk_trace_error(TSK_ERR_BAD_PARAM_VALUE); goto out; } if (metadata != NULL && !tsk_edge_table_has_metadata(self)) { - ret = TSK_ERR_METADATA_DISABLED; + ret = tsk_trace_error(TSK_ERR_METADATA_DISABLED); goto out; } ret = check_ragged_column(num_rows, metadata, metadata_offset); @@ -2964,15 +2967,15 @@ tsk_edge_table_append_columns(tsk_edge_table_t *self, tsk_size_t num_rows, tsk_size_t j, metadata_length; if (left == NULL || right == NULL || parent == NULL || child == NULL) { - ret = TSK_ERR_BAD_PARAM_VALUE; + ret = tsk_trace_error(TSK_ERR_BAD_PARAM_VALUE); goto out; } if ((metadata == NULL) != (metadata_offset == NULL)) { - ret = TSK_ERR_BAD_PARAM_VALUE; + ret = tsk_trace_error(TSK_ERR_BAD_PARAM_VALUE); goto out; } if (metadata != NULL && !tsk_edge_table_has_metadata(self)) { - ret = TSK_ERR_METADATA_DISABLED; + ret = tsk_trace_error(TSK_ERR_METADATA_DISABLED); goto out; } @@ -3028,7 +3031,7 @@ tsk_edge_table_truncate(tsk_edge_table_t *self, tsk_size_t num_rows) int ret = 0; if (num_rows > self->num_rows) { - ret = TSK_ERR_BAD_TABLE_POSITION; + ret = tsk_trace_error(TSK_ERR_BAD_TABLE_POSITION); goto out; } self->num_rows = num_rows; @@ -3049,7 +3052,7 @@ tsk_edge_table_extend(tsk_edge_table_t *self, const tsk_edge_table_t *other, tsk_edge_t edge; if (self == other) { - ret = TSK_ERR_CANNOT_EXTEND_FROM_SELF; + ret = tsk_trace_error(TSK_ERR_CANNOT_EXTEND_FROM_SELF); goto out; } @@ -3101,7 +3104,7 @@ tsk_edge_table_get_row(const tsk_edge_table_t *self, tsk_id_t 
index, tsk_edge_t int ret = 0; if (index < 0 || index >= (tsk_id_t) self->num_rows) { - ret = TSK_ERR_EDGE_OUT_OF_BOUNDS; + ret = tsk_trace_error(TSK_ERR_EDGE_OUT_OF_BOUNDS); goto out; } tsk_edge_table_get_row_unsafe(self, index, row); @@ -3336,13 +3339,13 @@ tsk_edge_table_squash(tsk_edge_table_t *self) tsk_size_t num_output_edges; if (self->metadata_length > 0) { - ret = TSK_ERR_CANT_PROCESS_EDGES_WITH_METADATA; + ret = tsk_trace_error(TSK_ERR_CANT_PROCESS_EDGES_WITH_METADATA); goto out; } edges = tsk_malloc(self->num_rows * sizeof(tsk_edge_t)); if (edges == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } @@ -3560,7 +3563,7 @@ tsk_site_table_update_row_rewrite(tsk_site_table_t *self, tsk_id_t index, } rows = tsk_malloc(self->num_rows * sizeof(*rows)); if (rows == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } @@ -3629,11 +3632,11 @@ tsk_site_table_append_columns(tsk_site_table_t *self, tsk_size_t num_rows, tsk_size_t j, ancestral_state_length, metadata_length; if (position == NULL || ancestral_state == NULL || ancestral_state_offset == NULL) { - ret = TSK_ERR_BAD_PARAM_VALUE; + ret = tsk_trace_error(TSK_ERR_BAD_PARAM_VALUE); goto out; } if ((metadata == NULL) != (metadata_offset == NULL)) { - ret = TSK_ERR_BAD_PARAM_VALUE; + ret = tsk_trace_error(TSK_ERR_BAD_PARAM_VALUE); goto out; } @@ -3744,7 +3747,7 @@ tsk_site_table_takeset_columns(tsk_site_table_t *self, tsk_size_t num_rows, /* We need to check all the inputs before we start freeing or taking memory */ if (position == NULL || ancestral_state == NULL || ancestral_state_offset == NULL) { - ret = TSK_ERR_BAD_PARAM_VALUE; + ret = tsk_trace_error(TSK_ERR_BAD_PARAM_VALUE); goto out; } ret = check_ragged_column(num_rows, ancestral_state, ancestral_state_offset); @@ -3819,7 +3822,7 @@ tsk_site_table_truncate(tsk_site_table_t *self, tsk_size_t num_rows) int ret = 0; if (num_rows > self->num_rows) { - ret = 
TSK_ERR_BAD_TABLE_POSITION; + ret = tsk_trace_error(TSK_ERR_BAD_TABLE_POSITION); goto out; } self->num_rows = num_rows; @@ -3839,7 +3842,7 @@ tsk_site_table_extend(tsk_site_table_t *self, const tsk_site_table_t *other, tsk_site_t site; if (self == other) { - ret = TSK_ERR_CANNOT_EXTEND_FROM_SELF; + ret = tsk_trace_error(TSK_ERR_CANNOT_EXTEND_FROM_SELF); goto out; } @@ -3918,7 +3921,7 @@ tsk_site_table_get_row(const tsk_site_table_t *self, tsk_id_t index, tsk_site_t int ret = 0; if (index < 0 || index >= (tsk_id_t) self->num_rows) { - ret = TSK_ERR_SITE_OUT_OF_BOUNDS; + ret = tsk_trace_error(TSK_ERR_SITE_OUT_OF_BOUNDS); goto out; } tsk_site_table_get_row_unsafe(self, index, row); @@ -4275,7 +4278,7 @@ tsk_mutation_table_update_row_rewrite(tsk_mutation_table_t *self, tsk_id_t index } rows = tsk_malloc(self->num_rows * sizeof(*rows)); if (rows == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } @@ -4349,11 +4352,11 @@ tsk_mutation_table_append_columns(tsk_mutation_table_t *self, tsk_size_t num_row if (site == NULL || node == NULL || derived_state == NULL || derived_state_offset == NULL) { - ret = TSK_ERR_BAD_PARAM_VALUE; + ret = tsk_trace_error(TSK_ERR_BAD_PARAM_VALUE); goto out; } if ((metadata == NULL) != (metadata_offset == NULL)) { - ret = TSK_ERR_BAD_PARAM_VALUE; + ret = tsk_trace_error(TSK_ERR_BAD_PARAM_VALUE); goto out; } @@ -4438,7 +4441,7 @@ tsk_mutation_table_takeset_columns(tsk_mutation_table_t *self, tsk_size_t num_ro if (site == NULL || node == NULL || derived_state == NULL || derived_state_offset == NULL) { - ret = TSK_ERR_BAD_PARAM_VALUE; + ret = tsk_trace_error(TSK_ERR_BAD_PARAM_VALUE); goto out; } /* We need to check all the inputs before we start freeing or taking memory */ @@ -4465,7 +4468,7 @@ tsk_mutation_table_takeset_columns(tsk_mutation_table_t *self, tsk_size_t num_ro /* Time defaults to unknown time if not specified. 
*/ self->time = tsk_malloc(num_rows * sizeof(*self->time)); if (self->time == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } for (j = 0; j < num_rows; j++) { @@ -4583,7 +4586,7 @@ tsk_mutation_table_truncate(tsk_mutation_table_t *mutations, tsk_size_t num_rows int ret = 0; if (num_rows > mutations->num_rows) { - ret = TSK_ERR_BAD_TABLE_POSITION; + ret = tsk_trace_error(TSK_ERR_BAD_TABLE_POSITION); goto out; } mutations->num_rows = num_rows; @@ -4603,7 +4606,7 @@ tsk_mutation_table_extend(tsk_mutation_table_t *self, const tsk_mutation_table_t tsk_mutation_t mutation; if (self == other) { - ret = TSK_ERR_CANNOT_EXTEND_FROM_SELF; + ret = tsk_trace_error(TSK_ERR_CANNOT_EXTEND_FROM_SELF); goto out; } @@ -4683,7 +4686,7 @@ tsk_mutation_table_get_row( int ret = 0; if (index < 0 || index >= (tsk_id_t) self->num_rows) { - ret = TSK_ERR_MUTATION_OUT_OF_BOUNDS; + ret = tsk_trace_error(TSK_ERR_MUTATION_OUT_OF_BOUNDS); goto out; } tsk_mutation_table_get_row_unsafe(self, index, row); @@ -4747,7 +4750,7 @@ tsk_mutation_table_keep_rows(tsk_mutation_table_t *self, const tsk_bool_t *keep, if (ret_id_map == NULL) { id_map = tsk_malloc(current_num_rows * sizeof(*id_map)); if (id_map == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } } @@ -4763,11 +4766,11 @@ tsk_mutation_table_keep_rows(tsk_mutation_table_t *self, const tsk_bool_t *keep, pj = parent[j]; if (pj != TSK_NULL) { if (pj < 0 || pj >= (tsk_id_t) current_num_rows) { - ret = TSK_ERR_MUTATION_OUT_OF_BOUNDS; + ret = tsk_trace_error(TSK_ERR_MUTATION_OUT_OF_BOUNDS); goto out; } if (id_map[pj] == TSK_NULL) { - ret = TSK_ERR_KEEP_ROWS_MAP_TO_DELETED; + ret = tsk_trace_error(TSK_ERR_KEEP_ROWS_MAP_TO_DELETED); goto out; } } @@ -5020,11 +5023,11 @@ tsk_migration_table_append_columns(tsk_migration_table_t *self, tsk_size_t num_r if (left == NULL || right == NULL || node == NULL || source == NULL || dest == NULL || time == NULL) { - ret = 
TSK_ERR_BAD_PARAM_VALUE; + ret = tsk_trace_error(TSK_ERR_BAD_PARAM_VALUE); goto out; } if ((metadata == NULL) != (metadata_offset == NULL)) { - ret = TSK_ERR_BAD_PARAM_VALUE; + ret = tsk_trace_error(TSK_ERR_BAD_PARAM_VALUE); goto out; } @@ -5076,7 +5079,7 @@ tsk_migration_table_takeset_columns(tsk_migration_table_t *self, tsk_size_t num_ if (left == NULL || right == NULL || node == NULL || source == NULL || dest == NULL || time == NULL) { - ret = TSK_ERR_BAD_PARAM_VALUE; + ret = tsk_trace_error(TSK_ERR_BAD_PARAM_VALUE); goto out; } @@ -5198,7 +5201,7 @@ tsk_migration_table_update_row_rewrite(tsk_migration_table_t *self, tsk_id_t ind } rows = tsk_malloc(self->num_rows * sizeof(*rows)); if (rows == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } @@ -5271,7 +5274,7 @@ tsk_migration_table_truncate(tsk_migration_table_t *self, tsk_size_t num_rows) int ret = 0; if (num_rows > self->num_rows) { - ret = TSK_ERR_BAD_TABLE_POSITION; + ret = tsk_trace_error(TSK_ERR_BAD_TABLE_POSITION); goto out; } self->num_rows = num_rows; @@ -5291,7 +5294,7 @@ tsk_migration_table_extend(tsk_migration_table_t *self, tsk_migration_t migration; if (self == other) { - ret = TSK_ERR_CANNOT_EXTEND_FROM_SELF; + ret = tsk_trace_error(TSK_ERR_CANNOT_EXTEND_FROM_SELF); goto out; } @@ -5360,7 +5363,7 @@ tsk_migration_table_get_row( int ret = 0; if (index < 0 || index >= (tsk_id_t) self->num_rows) { - ret = TSK_ERR_MIGRATION_OUT_OF_BOUNDS; + ret = tsk_trace_error(TSK_ERR_MIGRATION_OUT_OF_BOUNDS); goto out; } tsk_migration_table_get_row_unsafe(self, index, row); @@ -5692,7 +5695,7 @@ tsk_population_table_append_columns(tsk_population_table_t *self, tsk_size_t num tsk_size_t j, metadata_length; if (metadata == NULL || metadata_offset == NULL) { - ret = TSK_ERR_BAD_PARAM_VALUE; + ret = tsk_trace_error(TSK_ERR_BAD_PARAM_VALUE); goto out; } ret = tsk_population_table_expand_main_columns(self, num_rows); @@ -5731,7 +5734,7 @@ 
tsk_population_table_takeset_columns(tsk_population_table_t *self, tsk_size_t nu /* We need to check all the inputs before we start freeing or taking memory */ if (metadata == NULL || metadata_offset == NULL) { - ret = TSK_ERR_BAD_PARAM_VALUE; + ret = tsk_trace_error(TSK_ERR_BAD_PARAM_VALUE); goto out; } ret = check_ragged_column(num_rows, metadata, metadata_offset); @@ -5803,7 +5806,7 @@ tsk_population_table_update_row_rewrite(tsk_population_table_t *self, tsk_id_t i } rows = tsk_malloc(self->num_rows * sizeof(*rows)); if (rows == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } @@ -5868,7 +5871,7 @@ tsk_population_table_truncate(tsk_population_table_t *self, tsk_size_t num_rows) int ret = 0; if (num_rows > self->num_rows) { - ret = TSK_ERR_BAD_TABLE_POSITION; + ret = tsk_trace_error(TSK_ERR_BAD_TABLE_POSITION); goto out; } self->num_rows = num_rows; @@ -5888,7 +5891,7 @@ tsk_population_table_extend(tsk_population_table_t *self, tsk_population_t population; if (self == other) { - ret = TSK_ERR_CANNOT_EXTEND_FROM_SELF; + ret = tsk_trace_error(TSK_ERR_CANNOT_EXTEND_FROM_SELF); goto out; } @@ -5961,7 +5964,7 @@ tsk_population_table_get_row( int ret = 0; if (index < 0 || index >= (tsk_id_t) self->num_rows) { - ret = TSK_ERR_POPULATION_OUT_OF_BOUNDS; + ret = tsk_trace_error(TSK_ERR_POPULATION_OUT_OF_BOUNDS); goto out; } tsk_population_table_get_row_unsafe(self, index, row); @@ -6279,7 +6282,7 @@ tsk_provenance_table_append_columns(tsk_provenance_table_t *self, tsk_size_t num if (timestamp == NULL || timestamp_offset == NULL || record == NULL || record_offset == NULL) { - ret = TSK_ERR_BAD_PARAM_VALUE; + ret = tsk_trace_error(TSK_ERR_BAD_PARAM_VALUE); goto out; } ret = tsk_provenance_table_expand_main_columns(self, num_rows); @@ -6336,7 +6339,7 @@ tsk_provenance_table_takeset_columns(tsk_provenance_table_t *self, tsk_size_t nu /* We need to check all the inputs before we start freeing or taking memory */ if (timestamp == NULL || 
timestamp_offset == NULL || record == NULL || record_offset == NULL) { - ret = TSK_ERR_BAD_PARAM_VALUE; + ret = tsk_trace_error(TSK_ERR_BAD_PARAM_VALUE); goto out; } ret = check_ragged_column(num_rows, timestamp, timestamp_offset); @@ -6430,7 +6433,7 @@ tsk_provenance_table_update_row_rewrite(tsk_provenance_table_t *self, tsk_id_t i } rows = tsk_malloc(self->num_rows * sizeof(*rows)); if (rows == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } @@ -6500,7 +6503,7 @@ tsk_provenance_table_truncate(tsk_provenance_table_t *self, tsk_size_t num_rows) int ret = 0; if (num_rows > self->num_rows) { - ret = TSK_ERR_BAD_TABLE_POSITION; + ret = tsk_trace_error(TSK_ERR_BAD_TABLE_POSITION); goto out; } self->num_rows = num_rows; @@ -6521,7 +6524,7 @@ tsk_provenance_table_extend(tsk_provenance_table_t *self, tsk_provenance_t provenance; if (self == other) { - ret = TSK_ERR_CANNOT_EXTEND_FROM_SELF; + ret = tsk_trace_error(TSK_ERR_CANNOT_EXTEND_FROM_SELF); goto out; } @@ -6603,7 +6606,7 @@ tsk_provenance_table_get_row( int ret = 0; if (index < 0 || index >= (tsk_id_t) self->num_rows) { - ret = TSK_ERR_PROVENANCE_OUT_OF_BOUNDS; + ret = tsk_trace_error(TSK_ERR_PROVENANCE_OUT_OF_BOUNDS); goto out; } tsk_provenance_table_get_row_unsafe(self, index, row); @@ -6753,7 +6756,8 @@ typedef struct { typedef struct { tsk_mutation_t mut; int num_descendants; -} mutation_canonical_sort_t; + double node_time; +} mutation_sort_t; typedef struct { tsk_individual_t ind; @@ -6794,39 +6798,30 @@ cmp_site(const void *a, const void *b) static int cmp_mutation(const void *a, const void *b) { - const tsk_mutation_t *ia = (const tsk_mutation_t *) a; - const tsk_mutation_t *ib = (const tsk_mutation_t *) b; - /* Compare mutations by site */ - int ret = (ia->site > ib->site) - (ia->site < ib->site); - /* Within a particular site sort by time if known, then ID. 
This ensures that - * relative ordering within a site is maintained */ - if (ret == 0 && !tsk_is_unknown_time(ia->time) && !tsk_is_unknown_time(ib->time)) { - ret = (ia->time < ib->time) - (ia->time > ib->time); - } - if (ret == 0) { - ret = (ia->id > ib->id) - (ia->id < ib->id); - } - return ret; -} - -static int -cmp_mutation_canonical(const void *a, const void *b) -{ - const mutation_canonical_sort_t *ia = (const mutation_canonical_sort_t *) a; - const mutation_canonical_sort_t *ib = (const mutation_canonical_sort_t *) b; + const mutation_sort_t *ia = (const mutation_sort_t *) a; + const mutation_sort_t *ib = (const mutation_sort_t *) b; /* Compare mutations by site */ int ret = (ia->mut.site > ib->mut.site) - (ia->mut.site < ib->mut.site); + + /* Within a particular site sort by time if known */ if (ret == 0 && !tsk_is_unknown_time(ia->mut.time) && !tsk_is_unknown_time(ib->mut.time)) { ret = (ia->mut.time < ib->mut.time) - (ia->mut.time > ib->mut.time); } + /* Or node times when mutation times are unknown or equal */ + if (ret == 0) { + ret = (ia->node_time < ib->node_time) - (ia->node_time > ib->node_time); + } + /* If node times are equal, sort by number of descendants */ if (ret == 0) { ret = (ia->num_descendants < ib->num_descendants) - (ia->num_descendants > ib->num_descendants); } + /* If number of descendants are equal, sort by node */ if (ret == 0) { ret = (ia->mut.node > ib->mut.node) - (ia->mut.node < ib->mut.node); } + /* Final tiebreaker: ID */ if (ret == 0) { ret = (ia->mut.id > ib->mut.id) - (ia->mut.id < ib->mut.id); } @@ -6911,7 +6906,7 @@ tsk_table_sorter_sort_edges(tsk_table_sorter_t *self, tsk_size_t start) bool has_metadata = tsk_edge_table_has_metadata(edges); if (sorted_edges == NULL || old_metadata == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } tsk_memcpy(old_metadata, edges->metadata, edges->metadata_length); @@ -6964,7 +6959,7 @@ tsk_table_sorter_sort_migrations(tsk_table_sorter_t *self, 
tsk_size_t start) char *old_metadata = tsk_malloc(migrations->metadata_length); if (sorted_migrations == NULL || old_metadata == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } tsk_memcpy(old_metadata, migrations->metadata, migrations->metadata_length); @@ -7020,7 +7015,7 @@ tsk_table_sorter_sort_sites(tsk_table_sorter_t *self) goto out; } if (sorted_sites == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } for (j = 0; j < num_sites; j++) { @@ -7053,77 +7048,15 @@ tsk_table_sorter_sort_sites(tsk_table_sorter_t *self) static int tsk_table_sorter_sort_mutations(tsk_table_sorter_t *self) -{ - int ret = 0; - tsk_size_t j; - tsk_id_t ret_id, parent, mapped_parent; - tsk_mutation_table_t *mutations = &self->tables->mutations; - tsk_size_t num_mutations = mutations->num_rows; - tsk_mutation_table_t copy; - tsk_mutation_t *sorted_mutations - = tsk_malloc(num_mutations * sizeof(*sorted_mutations)); - tsk_id_t *mutation_id_map = tsk_malloc(num_mutations * sizeof(*mutation_id_map)); - - ret = tsk_mutation_table_copy(mutations, ©, 0); - if (ret != 0) { - goto out; - } - if (mutation_id_map == NULL || sorted_mutations == NULL) { - ret = TSK_ERR_NO_MEMORY; - goto out; - } - - for (j = 0; j < num_mutations; j++) { - tsk_mutation_table_get_row_unsafe(©, (tsk_id_t) j, sorted_mutations + j); - sorted_mutations[j].site = self->site_id_map[sorted_mutations[j].site]; - } - ret = tsk_mutation_table_clear(mutations); - if (ret != 0) { - goto out; - } - - qsort(sorted_mutations, (size_t) num_mutations, sizeof(*sorted_mutations), - cmp_mutation); - - /* Make a first pass through the sorted mutations to build the ID map. 
*/ - for (j = 0; j < num_mutations; j++) { - mutation_id_map[sorted_mutations[j].id] = (tsk_id_t) j; - } - - for (j = 0; j < num_mutations; j++) { - mapped_parent = TSK_NULL; - parent = sorted_mutations[j].parent; - if (parent != TSK_NULL) { - mapped_parent = mutation_id_map[parent]; - } - ret_id = tsk_mutation_table_add_row(mutations, sorted_mutations[j].site, - sorted_mutations[j].node, mapped_parent, sorted_mutations[j].time, - sorted_mutations[j].derived_state, sorted_mutations[j].derived_state_length, - sorted_mutations[j].metadata, sorted_mutations[j].metadata_length); - if (ret_id < 0) { - ret = (int) ret_id; - goto out; - } - } - ret = 0; - -out: - tsk_safe_free(mutation_id_map); - tsk_safe_free(sorted_mutations); - tsk_mutation_table_free(©); - return ret; -} - -static int -tsk_table_sorter_sort_mutations_canonical(tsk_table_sorter_t *self) { int ret = 0; tsk_size_t j; tsk_id_t ret_id, parent, mapped_parent, p; tsk_mutation_table_t *mutations = &self->tables->mutations; + tsk_node_table_t *nodes = &self->tables->nodes; tsk_size_t num_mutations = mutations->num_rows; tsk_mutation_table_t copy; - mutation_canonical_sort_t *sorted_mutations + mutation_sort_t *sorted_mutations = tsk_malloc(num_mutations * sizeof(*sorted_mutations)); tsk_id_t *mutation_id_map = tsk_malloc(num_mutations * sizeof(*mutation_id_map)); @@ -7132,7 +7065,7 @@ tsk_table_sorter_sort_mutations_canonical(tsk_table_sorter_t *self) goto out; } if (mutation_id_map == NULL || sorted_mutations == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } @@ -7145,7 +7078,7 @@ tsk_table_sorter_sort_mutations_canonical(tsk_table_sorter_t *self) while (p != TSK_NULL) { sorted_mutations[p].num_descendants += 1; if (sorted_mutations[p].num_descendants > (int) num_mutations) { - ret = TSK_ERR_MUTATION_PARENT_INCONSISTENT; + ret = tsk_trace_error(TSK_ERR_MUTATION_PARENT_INCONSISTENT); goto out; } p = mutations->parent[p]; @@ -7155,6 +7088,7 @@ 
tsk_table_sorter_sort_mutations_canonical(tsk_table_sorter_t *self) for (j = 0; j < num_mutations; j++) { tsk_mutation_table_get_row_unsafe(©, (tsk_id_t) j, &sorted_mutations[j].mut); sorted_mutations[j].mut.site = self->site_id_map[sorted_mutations[j].mut.site]; + sorted_mutations[j].node_time = nodes->time[sorted_mutations[j].mut.node]; } ret = tsk_mutation_table_clear(mutations); if (ret != 0) { @@ -7162,7 +7096,7 @@ tsk_table_sorter_sort_mutations_canonical(tsk_table_sorter_t *self) } qsort(sorted_mutations, (size_t) num_mutations, sizeof(*sorted_mutations), - cmp_mutation_canonical); + cmp_mutation); /* Make a first pass through the sorted mutations to build the ID map. */ for (j = 0; j < num_mutations; j++) { @@ -7209,7 +7143,7 @@ tsk_individual_table_topological_sort( bool count_descendants = (num_descendants != NULL); if (incoming_edge_count == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } @@ -7262,7 +7196,7 @@ tsk_individual_table_topological_sort( /* Any edges left are parts of cycles */ for (i = 0; i < (tsk_id_t) num_individuals; i++) { if (incoming_edge_count[i] > 0) { - ret = TSK_ERR_INDIVIDUAL_PARENT_CYCLE; + ret = tsk_trace_error(TSK_ERR_INDIVIDUAL_PARENT_CYCLE); goto out; } } @@ -7287,7 +7221,7 @@ tsk_table_collection_individual_topological_sort( tsk_id_t *new_id_map = tsk_malloc(num_individuals * sizeof(*new_id_map)); if (new_id_map == NULL || traversal_order == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } tsk_memset(new_id_map, 0xff, num_individuals * sizeof(*new_id_map)); @@ -7365,7 +7299,7 @@ tsk_table_sorter_sort_individuals_canonical(tsk_table_sorter_t *self) if (individual_id_map == NULL || sorted_individuals == NULL || traversal_order == NULL || num_descendants == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } @@ -7457,12 +7391,12 @@ tsk_table_sorter_run(tsk_table_sorter_t *self, const tsk_bookmark_t *start) 
if (start != NULL) { if (start->edges > self->tables->edges.num_rows) { - ret = TSK_ERR_EDGE_OUT_OF_BOUNDS; + ret = tsk_trace_error(TSK_ERR_EDGE_OUT_OF_BOUNDS); goto out; } edge_start = start->edges; if (start->migrations > self->tables->migrations.num_rows) { - ret = TSK_ERR_MIGRATION_OUT_OF_BOUNDS; + ret = tsk_trace_error(TSK_ERR_MIGRATION_OUT_OF_BOUNDS); goto out; } migration_start = start->migrations; @@ -7474,7 +7408,7 @@ tsk_table_sorter_run(tsk_table_sorter_t *self, const tsk_bookmark_t *start) && start->mutations == self->tables->mutations.num_rows) { skip_sites = true; } else if (start->sites != 0 || start->mutations != 0) { - ret = TSK_ERR_SORT_OFFSET_NOT_SUPPORTED; + ret = tsk_trace_error(TSK_ERR_SORT_OFFSET_NOT_SUPPORTED); goto out; } } @@ -7536,7 +7470,7 @@ tsk_table_sorter_init( self->site_id_map = tsk_malloc(self->tables->sites.num_rows * sizeof(tsk_id_t)); if (self->site_id_map == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } @@ -7652,7 +7586,7 @@ segment_overlapper_alloc(segment_overlapper_t *self) self->max_overlapping = 8; /* Making sure we call tsk_realloc in tests */ self->overlapping = tsk_malloc(self->max_overlapping * sizeof(*self->overlapping)); if (self->overlapping == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } out: @@ -7682,7 +7616,7 @@ segment_overlapper_start( p = tsk_realloc( self->overlapping, self->max_overlapping * sizeof(*self->overlapping)); if (p == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } self->overlapping = p; @@ -7890,7 +7824,7 @@ ancestor_mapper_record_edge( self->num_buffered_children++; x = ancestor_mapper_alloc_interval_list(self, left, right); if (x == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } self->child_edge_map_head[child] = x; @@ -7901,7 +7835,7 @@ ancestor_mapper_record_edge( } else { x = 
ancestor_mapper_alloc_interval_list(self, left, right); if (x == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } tail->next = x; @@ -7924,7 +7858,7 @@ ancestor_mapper_add_ancestry(ancestor_mapper_t *self, tsk_id_t input_id, double if (tail == NULL) { x = ancestor_mapper_alloc_segment(self, left, right, output_id); if (x == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } self->ancestor_map_head[input_id] = x; @@ -7935,7 +7869,7 @@ ancestor_mapper_add_ancestry(ancestor_mapper_t *self, tsk_id_t input_id, double } else { x = ancestor_mapper_alloc_segment(self, left, right, output_id); if (x == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } tail->next = x; @@ -7972,11 +7906,11 @@ ancestor_mapper_init_samples(ancestor_mapper_t *self, tsk_id_t *samples) /* Go through the samples to check for errors. */ for (j = 0; j < self->num_samples; j++) { if (samples[j] < 0 || samples[j] > (tsk_id_t) self->tables->nodes.num_rows) { - ret = TSK_ERR_NODE_OUT_OF_BOUNDS; + ret = tsk_trace_error(TSK_ERR_NODE_OUT_OF_BOUNDS); goto out; } if (self->is_sample[samples[j]]) { - ret = TSK_ERR_DUPLICATE_SAMPLE; + ret = tsk_trace_error(TSK_ERR_DUPLICATE_SAMPLE); goto out; } self->is_sample[samples[j]] = true; @@ -7999,11 +7933,11 @@ ancestor_mapper_init_ancestors(ancestor_mapper_t *self, tsk_id_t *ancestors) /* Go through the samples to check for errors. 
*/ for (j = 0; j < self->num_ancestors; j++) { if (ancestors[j] < 0 || ancestors[j] > (tsk_id_t) self->tables->nodes.num_rows) { - ret = TSK_ERR_NODE_OUT_OF_BOUNDS; + ret = tsk_trace_error(TSK_ERR_NODE_OUT_OF_BOUNDS); goto out; } if (self->is_ancestor[ancestors[j]]) { - ret = TSK_ERR_DUPLICATE_SAMPLE; + ret = tsk_trace_error(TSK_ERR_DUPLICATE_SAMPLE); goto out; } self->is_ancestor[ancestors[j]] = true; @@ -8030,7 +7964,7 @@ ancestor_mapper_init(ancestor_mapper_t *self, tsk_id_t *samples, tsk_size_t num_ self->sequence_length = self->tables->sequence_length; if (samples == NULL || num_samples == 0 || ancestors == NULL || num_ancestors == 0) { - ret = TSK_ERR_BAD_PARAM_VALUE; + ret = tsk_trace_error(TSK_ERR_BAD_PARAM_VALUE); goto out; } @@ -8065,7 +7999,7 @@ ancestor_mapper_init(ancestor_mapper_t *self, tsk_id_t *samples, tsk_size_t num_ || self->child_edge_map_head == NULL || self->child_edge_map_tail == NULL || self->is_sample == NULL || self->is_ancestor == NULL || self->segment_queue == NULL || self->buffered_children == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } // Clear memory. 
@@ -8119,7 +8053,7 @@ ancestor_mapper_enqueue_segment( p = tsk_realloc(self->segment_queue, self->max_segment_queue_size * sizeof(*self->segment_queue)); if (p == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } self->segment_queue = p; @@ -8333,11 +8267,11 @@ tsk_identity_segments_get_key( tsk_id_t N = (tsk_id_t) self->num_nodes; if (a < 0 || b < 0 || a >= N || b >= N) { - ret = TSK_ERR_NODE_OUT_OF_BOUNDS; + ret = tsk_trace_error(TSK_ERR_NODE_OUT_OF_BOUNDS); goto out; } if (a == b) { - ret = TSK_ERR_SAME_NODES_IN_PAIR; + ret = tsk_trace_error(TSK_ERR_SAME_NODES_IN_PAIR); goto out; } ret = pair_to_integer(a, b, self->num_nodes); @@ -8553,7 +8487,7 @@ tsk_identity_segments_update_pair(tsk_identity_segments_t *self, tsk_id_t a, tsk /* We haven't seen this pair before */ avl_node = tsk_identity_segments_alloc_new_pair(self, key); if (avl_node == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } ret = tsk_avl_tree_int_insert(&self->pair_map, avl_node); @@ -8610,7 +8544,7 @@ tsk_identity_segments_get(const tsk_identity_segments_t *self, tsk_id_t sample_a goto out; } if (!self->store_pairs) { - ret = TSK_ERR_IBD_PAIRS_NOT_STORED; + ret = tsk_trace_error(TSK_ERR_IBD_PAIRS_NOT_STORED); goto out; } avl_node = tsk_avl_tree_int_search(&self->pair_map, key); @@ -8673,7 +8607,7 @@ tsk_ibd_finder_add_ancestry(tsk_ibd_finder_t *self, tsk_id_t input_id, double le tsk_bug_assert(left < right); x = tsk_ibd_finder_alloc_segment(self, left, right, output_id); if (x == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } if (tail == NULL) { @@ -8699,11 +8633,11 @@ tsk_ibd_finder_init_samples_from_set( u = samples[j]; if (u < 0 || u > (tsk_id_t) self->tables->nodes.num_rows) { - ret = TSK_ERR_NODE_OUT_OF_BOUNDS; + ret = tsk_trace_error(TSK_ERR_NODE_OUT_OF_BOUNDS); goto out; } if (self->sample_set_id[u] != TSK_NULL) { - ret = TSK_ERR_DUPLICATE_SAMPLE; + ret = 
tsk_trace_error(TSK_ERR_DUPLICATE_SAMPLE); goto out; } self->sample_set_id[u] = 0; @@ -8757,11 +8691,11 @@ tsk_ibd_finder_init(tsk_ibd_finder_t *self, const tsk_table_collection_t *tables tsk_memset(self, 0, sizeof(tsk_ibd_finder_t)); if (min_span < 0) { - ret = TSK_ERR_BAD_PARAM_VALUE; + ret = tsk_trace_error(TSK_ERR_BAD_PARAM_VALUE); goto out; } if (max_time < 0) { - ret = TSK_ERR_BAD_PARAM_VALUE; + ret = tsk_trace_error(TSK_ERR_BAD_PARAM_VALUE); goto out; } @@ -8785,7 +8719,7 @@ tsk_ibd_finder_init(tsk_ibd_finder_t *self, const tsk_table_collection_t *tables = tsk_malloc(self->max_segment_queue_size * sizeof(*self->segment_queue)); if (self->ancestor_map_head == NULL || self->ancestor_map_tail == NULL || self->sample_set_id == NULL || self->segment_queue == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } tsk_memset(self->sample_set_id, TSK_NULL, num_nodes * sizeof(*self->sample_set_id)); @@ -8809,7 +8743,7 @@ tsk_ibd_finder_enqueue_segment( p = tsk_realloc(self->segment_queue, self->max_segment_queue_size * sizeof(*self->segment_queue)); if (p == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } self->segment_queue = p; @@ -8953,11 +8887,11 @@ tsk_ibd_finder_init_between(tsk_ibd_finder_t *self, tsk_size_t num_sample_sets, for (k = 0; k < sample_set_sizes[j]; k++) { u = sample_sets[index]; if (u < 0 || u > (tsk_id_t) self->tables->nodes.num_rows) { - ret = TSK_ERR_NODE_OUT_OF_BOUNDS; + ret = tsk_trace_error(TSK_ERR_NODE_OUT_OF_BOUNDS); goto out; } if (self->sample_set_id[u] != TSK_NULL) { - ret = TSK_ERR_DUPLICATE_SAMPLE; + ret = tsk_trace_error(TSK_ERR_DUPLICATE_SAMPLE); goto out; } self->sample_set_id[u] = (tsk_id_t) j; @@ -9373,7 +9307,7 @@ simplifier_record_edge(simplifier_t *self, double left, double right, tsk_id_t c self->num_buffered_children++; x = simplifier_alloc_interval_list(self, left, right); if (x == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = 
tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } self->child_edge_map_head[child] = x; @@ -9384,7 +9318,7 @@ simplifier_record_edge(simplifier_t *self, double left, double right, tsk_id_t c } else { x = simplifier_alloc_interval_list(self, left, right); if (x == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } tail->next = x; @@ -9414,7 +9348,7 @@ simplifier_init_sites(simplifier_t *self) if (self->mutation_node_map == NULL || self->node_mutation_list_mem == NULL || self->node_mutation_list_map_head == NULL || self->node_mutation_list_map_tail == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } tsk_memset(self->mutation_node_map, 0xff, @@ -9467,7 +9401,7 @@ simplifier_add_ancestry( if (tail == NULL) { x = simplifier_alloc_segment(self, left, right, output_id); if (x == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } self->ancestor_map_head[input_id] = x; @@ -9478,7 +9412,7 @@ simplifier_add_ancestry( } else { x = simplifier_alloc_segment(self, left, right, output_id); if (x == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } tail->next = x; @@ -9645,18 +9579,18 @@ simplifier_init(simplifier_t *self, const tsk_id_t *samples, tsk_size_t num_samp || self->child_edge_map_head == NULL || self->child_edge_map_tail == NULL || self->node_id_map == NULL || self->is_sample == NULL || self->segment_queue == NULL || self->buffered_children == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } /* Go through the samples to check for errors before we clear the tables. 
*/ for (j = 0; j < self->num_samples; j++) { if (samples[j] < 0 || samples[j] >= (tsk_id_t) num_nodes) { - ret = TSK_ERR_NODE_OUT_OF_BOUNDS; + ret = tsk_trace_error(TSK_ERR_NODE_OUT_OF_BOUNDS); goto out; } if (self->is_sample[samples[j]]) { - ret = TSK_ERR_DUPLICATE_SAMPLE; + ret = tsk_trace_error(TSK_ERR_DUPLICATE_SAMPLE); goto out; } self->is_sample[samples[j]] = true; @@ -9725,7 +9659,7 @@ simplifier_enqueue_segment(simplifier_t *self, double left, double right, tsk_id p = tsk_realloc(self->segment_queue, self->max_segment_queue_size * sizeof(*self->segment_queue)); if (p == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } self->segment_queue = p; @@ -9878,7 +9812,7 @@ simplifier_extract_ancestry( if (x->left != y.left) { seg_left = simplifier_alloc_segment(self, x->left, y.left, x->node); if (seg_left == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } if (x_prev == NULL) { @@ -10001,7 +9935,7 @@ simplifier_finalise_population_references(simplifier_t *self) tsk_bug_assert(self->options & TSK_SIMPLIFY_FILTER_POPULATIONS); if (population_referenced == NULL || population_id_map == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } @@ -10059,7 +9993,7 @@ simplifier_finalise_individual_references(simplifier_t *self) tsk_bug_assert(self->options & TSK_SIMPLIFY_FILTER_INDIVIDUALS); if (individual_referenced == NULL || individual_id_map == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } @@ -10128,7 +10062,7 @@ simplifier_output_sites(simplifier_t *self) const tsk_id_t *mutation_site = self->input_tables.mutations.site; if (site_referenced == NULL || site_id_map == NULL || mutation_id_map == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } @@ -10187,7 +10121,7 @@ simplifier_flush_output(simplifier_t *self) * or whether to add these nodes back in (probably the 
former is the correct * approach).*/ if (self->input_tables.migrations.num_rows != 0) { - ret = TSK_ERR_SIMPLIFY_MIGRATIONS_NOT_SUPPORTED; + ret = tsk_trace_error(TSK_ERR_SIMPLIFY_MIGRATIONS_NOT_SUPPORTED); goto out; } @@ -10439,19 +10373,19 @@ tsk_table_collection_check_node_integrity( for (j = 0; j < self->nodes.num_rows; j++) { node_time = self->nodes.time[j]; if (!tsk_isfinite(node_time)) { - ret = TSK_ERR_TIME_NONFINITE; + ret = tsk_trace_error(TSK_ERR_TIME_NONFINITE); goto out; } if (check_population_refs) { population = self->nodes.population[j]; if (population < TSK_NULL || population >= num_populations) { - ret = TSK_ERR_POPULATION_OUT_OF_BOUNDS; + ret = tsk_trace_error(TSK_ERR_POPULATION_OUT_OF_BOUNDS); goto out; } } individual = self->nodes.individual[j]; if (individual < TSK_NULL || individual >= num_individuals) { - ret = TSK_ERR_INDIVIDUAL_OUT_OF_BOUNDS; + ret = tsk_trace_error(TSK_ERR_INDIVIDUAL_OUT_OF_BOUNDS); goto out; } } @@ -10477,7 +10411,7 @@ tsk_table_collection_check_edge_integrity( if (check_ordering) { parent_seen = tsk_calloc((tsk_size_t) num_nodes, sizeof(*parent_seen)); if (parent_seen == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } } @@ -10493,67 +10427,67 @@ tsk_table_collection_check_edge_integrity( right = edges.right[j]; /* Node ID integrity */ if (parent == TSK_NULL) { - ret = TSK_ERR_NULL_PARENT; + ret = tsk_trace_error(TSK_ERR_NULL_PARENT); goto out; } if (parent < 0 || parent >= num_nodes) { - ret = TSK_ERR_NODE_OUT_OF_BOUNDS; + ret = tsk_trace_error(TSK_ERR_NODE_OUT_OF_BOUNDS); goto out; } if (child == TSK_NULL) { - ret = TSK_ERR_NULL_CHILD; + ret = tsk_trace_error(TSK_ERR_NULL_CHILD); goto out; } if (child < 0 || child >= num_nodes) { - ret = TSK_ERR_NODE_OUT_OF_BOUNDS; + ret = tsk_trace_error(TSK_ERR_NODE_OUT_OF_BOUNDS); goto out; } /* Spatial requirements for edges */ if (!(tsk_isfinite(left) && tsk_isfinite(right))) { - ret = TSK_ERR_GENOME_COORDS_NONFINITE; + ret = 
tsk_trace_error(TSK_ERR_GENOME_COORDS_NONFINITE); goto out; } if (left < 0) { - ret = TSK_ERR_LEFT_LESS_ZERO; + ret = tsk_trace_error(TSK_ERR_LEFT_LESS_ZERO); goto out; } if (right > L) { - ret = TSK_ERR_RIGHT_GREATER_SEQ_LENGTH; + ret = tsk_trace_error(TSK_ERR_RIGHT_GREATER_SEQ_LENGTH); goto out; } if (left >= right) { - ret = TSK_ERR_BAD_EDGE_INTERVAL; + ret = tsk_trace_error(TSK_ERR_BAD_EDGE_INTERVAL); goto out; } /* time[child] must be < time[parent] */ if (time[child] >= time[parent]) { - ret = TSK_ERR_BAD_NODE_TIME_ORDERING; + ret = tsk_trace_error(TSK_ERR_BAD_NODE_TIME_ORDERING); goto out; } if (check_ordering) { if (parent_seen[parent]) { - ret = TSK_ERR_EDGES_NONCONTIGUOUS_PARENTS; + ret = tsk_trace_error(TSK_ERR_EDGES_NONCONTIGUOUS_PARENTS); goto out; } if (j > 0) { /* Input data must sorted by (time[parent], parent, child, left). */ if (time[parent] < time[last_parent]) { - ret = TSK_ERR_EDGES_NOT_SORTED_PARENT_TIME; + ret = tsk_trace_error(TSK_ERR_EDGES_NOT_SORTED_PARENT_TIME); goto out; } if (time[parent] == time[last_parent]) { if (parent == last_parent) { if (child < last_child) { - ret = TSK_ERR_EDGES_NOT_SORTED_CHILD; + ret = tsk_trace_error(TSK_ERR_EDGES_NOT_SORTED_CHILD); goto out; } if (child == last_child) { if (left == last_left) { - ret = TSK_ERR_DUPLICATE_EDGES; + ret = tsk_trace_error(TSK_ERR_DUPLICATE_EDGES); goto out; } else if (left < last_left) { - ret = TSK_ERR_EDGES_NOT_SORTED_LEFT; + ret = tsk_trace_error(TSK_ERR_EDGES_NOT_SORTED_LEFT); goto out; } } @@ -10588,20 +10522,20 @@ tsk_table_collection_check_site_integrity( position = sites.position[j]; /* Spatial requirements */ if (!tsk_isfinite(position)) { - ret = TSK_ERR_BAD_SITE_POSITION; + ret = tsk_trace_error(TSK_ERR_BAD_SITE_POSITION); goto out; } if (position < 0 || position >= L) { - ret = TSK_ERR_BAD_SITE_POSITION; + ret = tsk_trace_error(TSK_ERR_BAD_SITE_POSITION); goto out; } if (j > 0) { if (check_site_duplicates && sites.position[j - 1] == position) { - ret = 
TSK_ERR_DUPLICATE_SITE_POSITION; + ret = tsk_trace_error(TSK_ERR_DUPLICATE_SITE_POSITION); goto out; } if (check_site_ordering && sites.position[j - 1] > position) { - ret = TSK_ERR_UNSORTED_SITES; + ret = tsk_trace_error(TSK_ERR_UNSORTED_SITES); goto out; } } @@ -10632,21 +10566,21 @@ tsk_table_collection_check_mutation_integrity( for (j = 0; j < mutations.num_rows; j++) { /* Basic reference integrity */ if (mutations.site[j] < 0 || mutations.site[j] >= num_sites) { - ret = TSK_ERR_SITE_OUT_OF_BOUNDS; + ret = tsk_trace_error(TSK_ERR_SITE_OUT_OF_BOUNDS); goto out; } if (mutations.node[j] < 0 || mutations.node[j] >= num_nodes) { - ret = TSK_ERR_NODE_OUT_OF_BOUNDS; + ret = tsk_trace_error(TSK_ERR_NODE_OUT_OF_BOUNDS); goto out; } /* Integrity check for mutation parent */ parent_mut = mutations.parent[j]; if (parent_mut < TSK_NULL || parent_mut >= num_mutations) { - ret = TSK_ERR_MUTATION_OUT_OF_BOUNDS; + ret = tsk_trace_error(TSK_ERR_MUTATION_OUT_OF_BOUNDS); goto out; } if (parent_mut == (tsk_id_t) j) { - ret = TSK_ERR_MUTATION_PARENT_EQUAL; + ret = tsk_trace_error(TSK_ERR_MUTATION_PARENT_EQUAL); goto out; } /* Check that time is finite and not more recent than node time */ @@ -10654,11 +10588,11 @@ tsk_table_collection_check_mutation_integrity( unknown_time = tsk_is_unknown_time(mutation_time); if (!unknown_time) { if (!tsk_isfinite(mutation_time)) { - ret = TSK_ERR_TIME_NONFINITE; + ret = tsk_trace_error(TSK_ERR_TIME_NONFINITE); goto out; } if (mutation_time < node_time[mutations.node[j]]) { - ret = TSK_ERR_MUTATION_TIME_YOUNGER_THAN_NODE; + ret = tsk_trace_error(TSK_ERR_MUTATION_TIME_YOUNGER_THAN_NODE); goto out; } } @@ -10677,14 +10611,14 @@ tsk_table_collection_check_mutation_integrity( num_known_times++; } if ((num_unknown_times > 0) && (num_known_times > 0)) { - ret = TSK_ERR_MUTATION_TIME_HAS_BOTH_KNOWN_AND_UNKNOWN; + ret = tsk_trace_error(TSK_ERR_MUTATION_TIME_HAS_BOTH_KNOWN_AND_UNKNOWN); goto out; } /* check parent site agrees */ if (parent_mut != TSK_NULL) 
{ if (mutations.site[parent_mut] != mutations.site[j]) { - ret = TSK_ERR_MUTATION_PARENT_DIFFERENT_SITE; + ret = tsk_trace_error(TSK_ERR_MUTATION_PARENT_DIFFERENT_SITE); goto out; } /* If this mutation time is known, then the parent time @@ -10692,7 +10626,7 @@ tsk_table_collection_check_mutation_integrity( * TSK_ERR_MUTATION_TIME_HAS_BOTH_KNOWN_AND_UNKNOWN check * above will fail. */ if (!unknown_time && mutation_time > mutations.time[parent_mut]) { - ret = TSK_ERR_MUTATION_TIME_OLDER_THAN_PARENT_MUTATION; + ret = tsk_trace_error(TSK_ERR_MUTATION_TIME_OLDER_THAN_PARENT_MUTATION); goto out; } } @@ -10700,13 +10634,13 @@ tsk_table_collection_check_mutation_integrity( if (check_mutation_ordering) { /* Check site ordering */ if (j > 0 && mutations.site[j - 1] > mutations.site[j]) { - ret = TSK_ERR_UNSORTED_MUTATIONS; + ret = tsk_trace_error(TSK_ERR_UNSORTED_MUTATIONS); goto out; } /* Check if parents are listed before their children */ if (parent_mut != TSK_NULL && parent_mut > (tsk_id_t) j) { - ret = TSK_ERR_MUTATION_PARENT_AFTER_CHILD; + ret = tsk_trace_error(TSK_ERR_MUTATION_PARENT_AFTER_CHILD); goto out; } @@ -10714,7 +10648,7 @@ tsk_table_collection_check_mutation_integrity( * so that more specific errors trigger first */ if (!unknown_time) { if (mutation_time > last_known_time) { - ret = TSK_ERR_UNSORTED_MUTATIONS; + ret = tsk_trace_error(TSK_ERR_UNSORTED_MUTATIONS); goto out; } last_known_time = mutation_time; @@ -10741,27 +10675,27 @@ tsk_table_collection_check_migration_integrity( for (j = 0; j < migrations.num_rows; j++) { if (migrations.node[j] < 0 || migrations.node[j] >= num_nodes) { - ret = TSK_ERR_NODE_OUT_OF_BOUNDS; + ret = tsk_trace_error(TSK_ERR_NODE_OUT_OF_BOUNDS); goto out; } if (check_population_refs) { if (migrations.source[j] < 0 || migrations.source[j] >= num_populations) { - ret = TSK_ERR_POPULATION_OUT_OF_BOUNDS; + ret = tsk_trace_error(TSK_ERR_POPULATION_OUT_OF_BOUNDS); goto out; } if (migrations.dest[j] < 0 || migrations.dest[j] >= 
num_populations) { - ret = TSK_ERR_POPULATION_OUT_OF_BOUNDS; + ret = tsk_trace_error(TSK_ERR_POPULATION_OUT_OF_BOUNDS); goto out; } } time = migrations.time[j]; if (!tsk_isfinite(time)) { - ret = TSK_ERR_TIME_NONFINITE; + ret = tsk_trace_error(TSK_ERR_TIME_NONFINITE); goto out; } if (j > 0) { if (check_migration_ordering && migrations.time[j - 1] > time) { - ret = TSK_ERR_UNSORTED_MIGRATIONS; + ret = tsk_trace_error(TSK_ERR_UNSORTED_MIGRATIONS); goto out; } } @@ -10770,19 +10704,19 @@ tsk_table_collection_check_migration_integrity( /* Spatial requirements */ /* TODO it's a bit misleading to use the edge-specific errors here. */ if (!(tsk_isfinite(left) && tsk_isfinite(right))) { - ret = TSK_ERR_GENOME_COORDS_NONFINITE; + ret = tsk_trace_error(TSK_ERR_GENOME_COORDS_NONFINITE); goto out; } if (left < 0) { - ret = TSK_ERR_LEFT_LESS_ZERO; + ret = tsk_trace_error(TSK_ERR_LEFT_LESS_ZERO); goto out; } if (right > L) { - ret = TSK_ERR_RIGHT_GREATER_SEQ_LENGTH; + ret = tsk_trace_error(TSK_ERR_RIGHT_GREATER_SEQ_LENGTH); goto out; } if (left >= right) { - ret = TSK_ERR_BAD_EDGE_INTERVAL; + ret = tsk_trace_error(TSK_ERR_BAD_EDGE_INTERVAL); goto out; } } @@ -10807,18 +10741,18 @@ tsk_table_collection_check_individual_integrity( if (individuals.parents[k] != TSK_NULL && (individuals.parents[k] < 0 || individuals.parents[k] >= num_individuals)) { - ret = TSK_ERR_INDIVIDUAL_OUT_OF_BOUNDS; + ret = tsk_trace_error(TSK_ERR_INDIVIDUAL_OUT_OF_BOUNDS); goto out; } /* Check no-one is their own parent */ if (individuals.parents[k] == (tsk_id_t) j) { - ret = TSK_ERR_INDIVIDUAL_SELF_PARENT; + ret = tsk_trace_error(TSK_ERR_INDIVIDUAL_SELF_PARENT); goto out; } /* Check parents are ordered */ if (check_individual_ordering && individuals.parents[k] != TSK_NULL && individuals.parents[k] >= (tsk_id_t) j) { - ret = TSK_ERR_UNSORTED_INDIVIDUALS; + ret = tsk_trace_error(TSK_ERR_UNSORTED_INDIVIDUALS); goto out; } } @@ -10856,7 +10790,7 @@ tsk_table_collection_check_tree_integrity(const 
tsk_table_collection_t *self) parent = tsk_malloc(self->nodes.num_rows * sizeof(*parent)); used_edges = tsk_malloc(num_edges * sizeof(*used_edges)); if (parent == NULL || used_edges == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } tsk_memset(parent, 0xff, self->nodes.num_rows * sizeof(*parent)); @@ -10875,7 +10809,7 @@ tsk_table_collection_check_tree_integrity(const tsk_table_collection_t *self) while (k < num_edges && edge_right[O[k]] == tree_left) { e = O[k]; if (used_edges[e] != 1) { - ret = TSK_ERR_TABLES_BAD_INDEXES; + ret = tsk_trace_error(TSK_ERR_TABLES_BAD_INDEXES); goto out; } parent[edge_child[e]] = TSK_NULL; @@ -10885,13 +10819,13 @@ tsk_table_collection_check_tree_integrity(const tsk_table_collection_t *self) while (j < num_edges && edge_left[I[j]] == tree_left) { e = I[j]; if (used_edges[e] != 0) { - ret = TSK_ERR_TABLES_BAD_INDEXES; + ret = tsk_trace_error(TSK_ERR_TABLES_BAD_INDEXES); goto out; } used_edges[e]++; u = edge_child[e]; if (parent[u] != TSK_NULL) { - ret = TSK_ERR_BAD_EDGES_CONTRADICTORY_CHILDREN; + ret = tsk_trace_error(TSK_ERR_BAD_EDGES_CONTRADICTORY_CHILDREN); goto out; } parent[u] = edge_parent[e]; @@ -10910,7 +10844,7 @@ tsk_table_collection_check_tree_integrity(const tsk_table_collection_t *self) && parent[mutation_node[mutation]] != TSK_NULL && node_time[parent[mutation_node[mutation]]] <= mutation_time[mutation]) { - ret = TSK_ERR_MUTATION_TIME_OLDER_THAN_PARENT_NODE; + ret = tsk_trace_error(TSK_ERR_MUTATION_TIME_OLDER_THAN_PARENT_NODE); goto out; } mutation++; @@ -10918,7 +10852,7 @@ tsk_table_collection_check_tree_integrity(const tsk_table_collection_t *self) site++; } if (tree_right <= tree_left) { - ret = TSK_ERR_TABLES_BAD_INDEXES; + ret = tsk_trace_error(TSK_ERR_TABLES_BAD_INDEXES); goto out; } tree_left = tree_right; @@ -10926,7 +10860,7 @@ tsk_table_collection_check_tree_integrity(const tsk_table_collection_t *self) * a single tree, and there's a gap between each of these edges 
we * would overflow this counter. */ if (num_trees == TSK_MAX_ID) { - ret = TSK_ERR_TREE_OVERFLOW; + ret = tsk_trace_error(TSK_ERR_TREE_OVERFLOW); goto out; } num_trees++; @@ -10938,7 +10872,7 @@ tsk_table_collection_check_tree_integrity(const tsk_table_collection_t *self) * and so hit the error above. */ e = O[k]; if (edge_right[e] != sequence_length) { - ret = TSK_ERR_TABLES_BAD_INDEXES; + ret = tsk_trace_error(TSK_ERR_TABLES_BAD_INDEXES); goto out; } used_edges[e]++; @@ -10966,16 +10900,16 @@ tsk_table_collection_check_index_integrity(const tsk_table_collection_t *self) const tsk_id_t *edge_removal_order = self->indexes.edge_removal_order; if (!tsk_table_collection_has_index(self, 0)) { - ret = TSK_ERR_TABLES_NOT_INDEXED; + ret = tsk_trace_error(TSK_ERR_TABLES_NOT_INDEXED); goto out; } for (j = 0; j < num_edges; j++) { if (edge_insertion_order[j] < 0 || edge_insertion_order[j] >= num_edges) { - ret = TSK_ERR_EDGE_OUT_OF_BOUNDS; + ret = tsk_trace_error(TSK_ERR_EDGE_OUT_OF_BOUNDS); goto out; } if (edge_removal_order[j] < 0 || edge_removal_order[j] >= num_edges) { - ret = TSK_ERR_EDGE_OUT_OF_BOUNDS; + ret = tsk_trace_error(TSK_ERR_EDGE_OUT_OF_BOUNDS); goto out; } } @@ -10983,11 +10917,159 @@ tsk_table_collection_check_index_integrity(const tsk_table_collection_t *self) return ret; } +static int TSK_WARN_UNUSED +tsk_table_collection_compute_mutation_parents_to_array( + const tsk_table_collection_t *self, tsk_id_t *mutation_parent) +{ + int ret = 0; + const tsk_id_t *I, *O; + const tsk_edge_table_t edges = self->edges; + const tsk_node_table_t nodes = self->nodes; + const tsk_site_table_t sites = self->sites; + const tsk_mutation_table_t mutations = self->mutations; + const tsk_id_t M = (tsk_id_t) edges.num_rows; + tsk_id_t tj, tk; + tsk_id_t *parent = NULL; + tsk_id_t *bottom_mutation = NULL; + tsk_id_t u; + double left, right; + tsk_id_t site; + /* Using unsigned values here avoids potentially undefined behaviour */ + tsk_size_t j, mutation, first_mutation; + + 
parent = tsk_malloc(nodes.num_rows * sizeof(*parent)); + bottom_mutation = tsk_malloc(nodes.num_rows * sizeof(*bottom_mutation)); + if (parent == NULL || bottom_mutation == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; + } + tsk_memset(parent, 0xff, nodes.num_rows * sizeof(*parent)); + tsk_memset(bottom_mutation, 0xff, nodes.num_rows * sizeof(*bottom_mutation)); + tsk_memset(mutation_parent, 0xff, self->mutations.num_rows * sizeof(tsk_id_t)); + + I = self->indexes.edge_insertion_order; + O = self->indexes.edge_removal_order; + tj = 0; + tk = 0; + site = 0; + mutation = 0; + left = 0; + while (tj < M || left < self->sequence_length) { + while (tk < M && edges.right[O[tk]] == left) { + parent[edges.child[O[tk]]] = TSK_NULL; + tk++; + } + while (tj < M && edges.left[I[tj]] == left) { + parent[edges.child[I[tj]]] = edges.parent[I[tj]]; + tj++; + } + right = self->sequence_length; + if (tj < M) { + right = TSK_MIN(right, edges.left[I[tj]]); + } + if (tk < M) { + right = TSK_MIN(right, edges.right[O[tk]]); + } + + /* Tree is now ready. We look at each site on this tree in turn */ + while (site < (tsk_id_t) sites.num_rows && sites.position[site] < right) { + /* Create a mapping from mutations to nodes. If we see more than one + * mutation at a node, the previously seen one must be the parent + * of the current since we assume they are in order. */ + first_mutation = mutation; + while (mutation < mutations.num_rows && mutations.site[mutation] == site) { + u = mutations.node[mutation]; + if (bottom_mutation[u] != TSK_NULL) { + mutation_parent[mutation] = bottom_mutation[u]; + } + bottom_mutation[u] = (tsk_id_t) mutation; + mutation++; + } + /* Make the common case of 1 mutation fast */ + if (mutation > first_mutation + 1) { + /* If we have more than one mutation, compute the parent for each + * one by traversing up the tree until we find a node that has a + * mutation. 
*/ + for (j = first_mutation; j < mutation; j++) { + if (mutation_parent[j] == TSK_NULL) { + u = parent[mutations.node[j]]; + while (u != TSK_NULL && bottom_mutation[u] == TSK_NULL) { + u = parent[u]; + } + if (u != TSK_NULL) { + mutation_parent[j] = bottom_mutation[u]; + } + } + } + } + /* Reset the mapping for the next site */ + for (j = first_mutation; j < mutation; j++) { + u = mutations.node[j]; + bottom_mutation[u] = TSK_NULL; + /* Check that we haven't violated the sortedness property */ + if (mutation_parent[j] > (tsk_id_t) j) { + ret = tsk_trace_error(TSK_ERR_MUTATION_PARENT_AFTER_CHILD); + goto out; + } + } + site++; + } + /* Move on to the next tree */ + left = right; + } + +out: + tsk_safe_free(parent); + tsk_safe_free(bottom_mutation); + return ret; +} + +static int TSK_WARN_UNUSED +tsk_table_collection_check_mutation_parents(const tsk_table_collection_t *self) +{ + int ret = 0; + tsk_mutation_table_t mutations = self->mutations; + tsk_id_t *new_parents = NULL; + tsk_size_t j; + + if (mutations.num_rows == 0) { + return ret; + } + + new_parents = tsk_malloc(mutations.num_rows * sizeof(*new_parents)); + if (new_parents == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; + } + + ret = tsk_table_collection_compute_mutation_parents_to_array(self, new_parents); + if (ret != 0) { + goto out; + } + + for (j = 0; j < mutations.num_rows; j++) { + if (mutations.parent[j] != new_parents[j]) { + ret = tsk_trace_error(TSK_ERR_BAD_MUTATION_PARENT); + goto out; + } + } + +out: + tsk_safe_free(new_parents); + return ret; +} + tsk_id_t TSK_WARN_UNUSED tsk_table_collection_check_integrity( const tsk_table_collection_t *self, tsk_flags_t options) { tsk_id_t ret = 0; + int mut_ret = 0; + + if (options & TSK_CHECK_MUTATION_PARENTS) { + /* If we're checking mutation parents, we need to check the trees first */ + options |= TSK_CHECK_TREES; + } if (options & TSK_CHECK_TREES) { /* Checking the trees implies these checks */ @@ -10997,7 +11079,7 @@ 
tsk_table_collection_check_integrity( } if (self->sequence_length <= 0) { - ret = TSK_ERR_BAD_SEQUENCE_LENGTH; + ret = tsk_trace_error(TSK_ERR_BAD_SEQUENCE_LENGTH); goto out; } ret = tsk_table_collection_check_offsets(self); @@ -11040,6 +11122,14 @@ tsk_table_collection_check_integrity( if (ret < 0) { goto out; } + /* This check requires tree integrity so do it last */ + if (options & TSK_CHECK_MUTATION_PARENTS) { + mut_ret = tsk_table_collection_check_mutation_parents(self); + if (mut_ret != 0) { + ret = mut_ret; + goto out; + } + } } out: return ret; @@ -11244,7 +11334,7 @@ tsk_table_collection_set_indexes(tsk_table_collection_t *self, self->indexes.edge_removal_order = tsk_malloc(index_size); if (self->indexes.edge_insertion_order == NULL || self->indexes.edge_removal_order == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } tsk_memcpy(self->indexes.edge_insertion_order, edge_insertion_order, index_size); @@ -11261,7 +11351,7 @@ tsk_table_collection_takeset_indexes(tsk_table_collection_t *self, int ret = 0; if (edge_insertion_order == NULL || edge_removal_order == NULL) { - ret = TSK_ERR_BAD_PARAM_VALUE; + ret = tsk_trace_error(TSK_ERR_BAD_PARAM_VALUE); goto out; } tsk_table_collection_drop_index(self, 0); @@ -11326,7 +11416,7 @@ tsk_table_collection_build_index( sort_buff = tsk_malloc(self->edges.num_rows * sizeof(index_sort_t)); if (self->indexes.edge_insertion_order == NULL || self->indexes.edge_removal_order == NULL || sort_buff == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } @@ -11379,7 +11469,7 @@ tsk_table_collection_set_file_uuid(tsk_table_collection_t *self, const char *uui /* Allow space for \0 so we can print it as a string */ self->file_uuid = tsk_malloc(TSK_UUID_SIZE + 1); if (self->file_uuid == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } tsk_memcpy(self->file_uuid, uuid, TSK_UUID_SIZE); @@ -11501,12 +11591,12 @@ 
tsk_table_collection_read_format_data(tsk_table_collection_t *self, kastore_t *s goto out; } if (len != TSK_FILE_FORMAT_NAME_LENGTH) { - ret = TSK_ERR_FILE_FORMAT; + ret = tsk_trace_error(TSK_ERR_FILE_FORMAT); goto out; } if (tsk_memcmp(TSK_FILE_FORMAT_NAME, format_name, TSK_FILE_FORMAT_NAME_LENGTH) != 0) { - ret = TSK_ERR_FILE_FORMAT; + ret = tsk_trace_error(TSK_ERR_FILE_FORMAT); goto out; } @@ -11516,15 +11606,15 @@ tsk_table_collection_read_format_data(tsk_table_collection_t *self, kastore_t *s goto out; } if (len != 2) { - ret = TSK_ERR_FILE_FORMAT; + ret = tsk_trace_error(TSK_ERR_FILE_FORMAT); goto out; } if (version[0] < TSK_FILE_FORMAT_VERSION_MAJOR) { - ret = TSK_ERR_FILE_VERSION_TOO_OLD; + ret = tsk_trace_error(TSK_ERR_FILE_VERSION_TOO_OLD); goto out; } if (version[0] > TSK_FILE_FORMAT_VERSION_MAJOR) { - ret = TSK_ERR_FILE_VERSION_TOO_NEW; + ret = tsk_trace_error(TSK_ERR_FILE_VERSION_TOO_NEW); goto out; } @@ -11534,11 +11624,11 @@ tsk_table_collection_read_format_data(tsk_table_collection_t *self, kastore_t *s goto out; } if (len != 1) { - ret = TSK_ERR_FILE_FORMAT; + ret = tsk_trace_error(TSK_ERR_FILE_FORMAT); goto out; } if (L[0] <= 0.0) { - ret = TSK_ERR_BAD_SEQUENCE_LENGTH; + ret = tsk_trace_error(TSK_ERR_BAD_SEQUENCE_LENGTH); goto out; } self->sequence_length = L[0]; @@ -11549,7 +11639,7 @@ tsk_table_collection_read_format_data(tsk_table_collection_t *self, kastore_t *s goto out; } if (len != TSK_UUID_SIZE) { - ret = TSK_ERR_FILE_FORMAT; + ret = tsk_trace_error(TSK_ERR_FILE_FORMAT); goto out; } ret = tsk_table_collection_set_file_uuid(self, (const char *) uuid); @@ -11616,7 +11706,7 @@ tsk_table_collection_read_format_data(tsk_table_collection_t *self, kastore_t *s out: if ((ret ^ (1 << TSK_KAS_ERR_BIT)) == KAS_ERR_KEY_NOT_FOUND) { - ret = TSK_ERR_REQUIRED_COL_NOT_FOUND; + ret = tsk_trace_error(TSK_ERR_REQUIRED_COL_NOT_FOUND); } tsk_safe_free(version); tsk_safe_free(format_name); @@ -11672,12 +11762,12 @@ 
tsk_table_collection_load_indexes(tsk_table_collection_t *self, kastore_t *store } if ((edge_insertion_order == NULL) != (edge_removal_order == NULL)) { - ret = TSK_ERR_BOTH_COLUMNS_REQUIRED; + ret = tsk_trace_error(TSK_ERR_BOTH_COLUMNS_REQUIRED); goto out; } if (edge_insertion_order != NULL) { if (num_rows != self->edges.num_rows) { - ret = TSK_ERR_FILE_FORMAT; + ret = tsk_trace_error(TSK_ERR_FILE_FORMAT); goto out; } ret = tsk_table_collection_takeset_indexes( @@ -11778,7 +11868,7 @@ tsk_table_collection_loadf_inited( * and we hit EOF immediately without reading any bytes. We signal * this back to the client, which allows it to read an indefinite * number of stores from a stream */ - ret = TSK_ERR_EOF; + ret = tsk_trace_error(TSK_ERR_EOF); } else { ret = tsk_set_kas_error(ret); } @@ -11883,7 +11973,7 @@ tsk_table_collection_load( } file = fopen(filename, "rb"); if (file == NULL) { - ret = TSK_ERR_IO; + ret = tsk_trace_error(TSK_ERR_IO); goto out; } ret = tsk_table_collection_loadf_inited(self, file, options); @@ -11891,7 +11981,7 @@ tsk_table_collection_load( goto out; } if (fclose(file) != 0) { - ret = TSK_ERR_IO; + ret = tsk_trace_error(TSK_ERR_IO); goto out; } file = NULL; @@ -11933,7 +12023,7 @@ tsk_table_collection_dump( FILE *file = fopen(filename, "wb"); if (file == NULL) { - ret = TSK_ERR_IO; + ret = tsk_trace_error(TSK_ERR_IO); goto out; } ret = tsk_table_collection_dumpf(self, file, options); @@ -11941,7 +12031,7 @@ tsk_table_collection_dump( goto out; } if (fclose(file) != 0) { - ret = TSK_ERR_IO; + ret = tsk_trace_error(TSK_ERR_IO); goto out; } file = NULL; @@ -12068,21 +12158,21 @@ tsk_table_collection_simplify(tsk_table_collection_t *self, const tsk_id_t *samp if ((options & TSK_SIMPLIFY_KEEP_UNARY) && (options & TSK_SIMPLIFY_KEEP_UNARY_IN_INDIVIDUALS)) { - ret = TSK_ERR_KEEP_UNARY_MUTUALLY_EXCLUSIVE; + ret = tsk_trace_error(TSK_ERR_KEEP_UNARY_MUTUALLY_EXCLUSIVE); goto out; } /* For now we don't bother with edge metadata, but it can easily be * 
implemented. */ if (self->edges.metadata_length > 0) { - ret = TSK_ERR_CANT_PROCESS_EDGES_WITH_METADATA; + ret = tsk_trace_error(TSK_ERR_CANT_PROCESS_EDGES_WITH_METADATA); goto out; } if (samples == NULL) { local_samples = tsk_malloc(self->nodes.num_rows * sizeof(*local_samples)); if (local_samples == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } num_samples = 0; @@ -12125,7 +12215,7 @@ tsk_table_collection_link_ancestors(tsk_table_collection_t *self, tsk_id_t *samp tsk_memset(&ancestor_mapper, 0, sizeof(ancestor_mapper_t)); if (self->edges.metadata_length > 0) { - ret = TSK_ERR_CANT_PROCESS_EDGES_WITH_METADATA; + ret = tsk_trace_error(TSK_ERR_CANT_PROCESS_EDGES_WITH_METADATA); goto out; } @@ -12242,12 +12332,12 @@ tsk_table_collection_canonicalise(tsk_table_collection_t *self, tsk_flags_t opti if (ret != 0) { goto out; } - sorter.sort_mutations = tsk_table_sorter_sort_mutations_canonical; + sorter.sort_mutations = tsk_table_sorter_sort_mutations; sorter.sort_individuals = tsk_table_sorter_sort_individuals_canonical; nodes = tsk_malloc(self->nodes.num_rows * sizeof(*nodes)); if (nodes == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } for (k = 0; k < (tsk_id_t) self->nodes.num_rows; k++) { @@ -12302,7 +12392,7 @@ tsk_table_collection_deduplicate_sites( site_id_map = tsk_malloc(copy.num_rows * sizeof(*site_id_map)); if (site_id_map == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } ret = tsk_site_table_clear(&self->sites); @@ -12343,117 +12433,25 @@ tsk_table_collection_deduplicate_sites( int TSK_WARN_UNUSED tsk_table_collection_compute_mutation_parents( - tsk_table_collection_t *self, tsk_flags_t TSK_UNUSED(options)) + tsk_table_collection_t *self, tsk_flags_t options) { int ret = 0; - tsk_id_t num_trees; - const tsk_id_t *I, *O; - const tsk_edge_table_t edges = self->edges; - const tsk_node_table_t nodes = self->nodes; - const 
tsk_site_table_t sites = self->sites; - const tsk_mutation_table_t mutations = self->mutations; - const tsk_id_t M = (tsk_id_t) edges.num_rows; - tsk_id_t tj, tk; - tsk_id_t *parent = NULL; - tsk_id_t *bottom_mutation = NULL; - tsk_id_t u; - double left, right; - tsk_id_t site; - /* Using unsigned values here avoids potentially undefined behaviour */ - tsk_size_t j, mutation, first_mutation; - /* Set the mutation parent to TSK_NULL so that we don't check the - * parent values we are about to write over. */ - tsk_memset(mutations.parent, 0xff, mutations.num_rows * sizeof(*mutations.parent)); - num_trees = tsk_table_collection_check_integrity(self, TSK_CHECK_TREES); - if (num_trees < 0) { - ret = (int) num_trees; - goto out; - } - parent = tsk_malloc(nodes.num_rows * sizeof(*parent)); - bottom_mutation = tsk_malloc(nodes.num_rows * sizeof(*bottom_mutation)); - if (parent == NULL || bottom_mutation == NULL) { - ret = TSK_ERR_NO_MEMORY; - goto out; - } - tsk_memset(parent, 0xff, nodes.num_rows * sizeof(*parent)); - tsk_memset(bottom_mutation, 0xff, nodes.num_rows * sizeof(*bottom_mutation)); - tsk_memset(mutations.parent, 0xff, self->mutations.num_rows * sizeof(tsk_id_t)); - - I = self->indexes.edge_insertion_order; - O = self->indexes.edge_removal_order; - tj = 0; - tk = 0; - site = 0; - mutation = 0; - left = 0; - while (tj < M || left < self->sequence_length) { - while (tk < M && edges.right[O[tk]] == left) { - parent[edges.child[O[tk]]] = TSK_NULL; - tk++; - } - while (tj < M && edges.left[I[tj]] == left) { - parent[edges.child[I[tj]]] = edges.parent[I[tj]]; - tj++; - } - right = self->sequence_length; - if (tj < M) { - right = TSK_MIN(right, edges.left[I[tj]]); - } - if (tk < M) { - right = TSK_MIN(right, edges.right[O[tk]]); + if (!(options & TSK_NO_CHECK_INTEGRITY)) { + /* Safe to cast here as we're not counting trees */ + ret = (int) tsk_table_collection_check_integrity(self, TSK_CHECK_TREES); + if (ret < 0) { + goto out; } + } - /* Tree is now ready. 
We look at each site on this tree in turn */ - while (site < (tsk_id_t) sites.num_rows && sites.position[site] < right) { - /* Create a mapping from mutations to nodes. If we see more than one - * mutation at a node, the previously seen one must be the parent - * of the current since we assume they are in order. */ - first_mutation = mutation; - while (mutation < mutations.num_rows && mutations.site[mutation] == site) { - u = mutations.node[mutation]; - if (bottom_mutation[u] != TSK_NULL) { - mutations.parent[mutation] = bottom_mutation[u]; - } - bottom_mutation[u] = (tsk_id_t) mutation; - mutation++; - } - /* Make the common case of 1 mutation fast */ - if (mutation > first_mutation + 1) { - /* If we have more than one mutation, compute the parent for each - * one by traversing up the tree until we find a node that has a - * mutation. */ - for (j = first_mutation; j < mutation; j++) { - if (mutations.parent[j] == TSK_NULL) { - u = parent[mutations.node[j]]; - while (u != TSK_NULL && bottom_mutation[u] == TSK_NULL) { - u = parent[u]; - } - if (u != TSK_NULL) { - mutations.parent[j] = bottom_mutation[u]; - } - } - } - } - /* Reset the mapping for the next site */ - for (j = first_mutation; j < mutation; j++) { - u = mutations.node[j]; - bottom_mutation[u] = TSK_NULL; - /* Check that we haven't violated the sortedness property */ - if (mutations.parent[j] > (tsk_id_t) j) { - ret = TSK_ERR_MUTATION_PARENT_AFTER_CHILD; - goto out; - } - } - site++; - } - /* Move on to the next tree */ - left = right; + ret = tsk_table_collection_compute_mutation_parents_to_array( + self, self->mutations.parent); + if (ret != 0) { + goto out; } out: - tsk_safe_free(parent); - tsk_safe_free(bottom_mutation); return ret; } @@ -12483,7 +12481,7 @@ tsk_table_collection_compute_mutation_times( /* The random param is for future usage */ if (random != NULL) { - ret = TSK_ERR_BAD_PARAM_VALUE; + ret = tsk_trace_error(TSK_ERR_BAD_PARAM_VALUE); goto out; } @@ -12491,6 +12489,7 @@ 
tsk_table_collection_compute_mutation_times( for (j = 0; j < mutations.num_rows; j++) { mutations.time[j] = TSK_UNKNOWN_TIME; } + /* TSK_CHECK_MUTATION_PARENTS isn't needed here as we're not using the parents */ num_trees = tsk_table_collection_check_integrity(self, TSK_CHECK_TREES); if (num_trees < 0) { ret = (int) num_trees; @@ -12500,7 +12499,7 @@ tsk_table_collection_compute_mutation_times( numerator = tsk_malloc(nodes.num_rows * sizeof(*numerator)); denominator = tsk_malloc(nodes.num_rows * sizeof(*denominator)); if (parent == NULL || numerator == NULL || denominator == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } tsk_memset(parent, 0xff, nodes.num_rows * sizeof(*parent)); @@ -12631,7 +12630,7 @@ tsk_table_collection_delete_older( mutation_map = tsk_malloc(self->mutations.num_rows * sizeof(*mutation_map)); if (mutation_map == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } ret = tsk_mutation_table_copy(&self->mutations, &mutations, 0); @@ -12927,7 +12926,7 @@ tsk_table_collection_subset(tsk_table_collection_t *self, const tsk_id_t *nodes, mutation_map = tsk_malloc(tables.mutations.num_rows * sizeof(*mutation_map)); if (node_map == NULL || individual_map == NULL || population_map == NULL || site_map == NULL || mutation_map == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } tsk_memset(node_map, 0xff, tables.nodes.num_rows * sizeof(*node_map)); @@ -12961,7 +12960,7 @@ tsk_table_collection_subset(tsk_table_collection_t *self, const tsk_id_t *nodes, } else { for (k = 0; k < (tsk_id_t) num_nodes; k++) { if (nodes[k] < 0 || nodes[k] >= (tsk_id_t) tables.nodes.num_rows) { - ret = TSK_ERR_NODE_OUT_OF_BOUNDS; + ret = tsk_trace_error(TSK_ERR_NODE_OUT_OF_BOUNDS); goto out; } j = tables.nodes.individual[nodes[k]]; @@ -13018,7 +13017,7 @@ tsk_table_collection_subset(tsk_table_collection_t *self, const tsk_id_t *nodes, * that we don't remove 
populations that are referenced, so it would * need to be done before the next code block. */ if (tables.migrations.num_rows != 0) { - ret = TSK_ERR_MIGRATIONS_NOT_SUPPORTED; + ret = tsk_trace_error(TSK_ERR_MIGRATIONS_NOT_SUPPORTED); goto out; } @@ -13133,7 +13132,7 @@ tsk_check_subset_equality(tsk_table_collection_t *self, self_nodes = tsk_malloc(num_shared_nodes * sizeof(*self_nodes)); other_nodes = tsk_malloc(num_shared_nodes * sizeof(*other_nodes)); if (self_nodes == NULL || other_nodes == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } @@ -13173,7 +13172,7 @@ tsk_check_subset_equality(tsk_table_collection_t *self, if (!tsk_table_collection_equals(&self_copy, &other_copy, TSK_CMP_IGNORE_TS_METADATA | TSK_CMP_IGNORE_PROVENANCE | TSK_CMP_IGNORE_REFERENCE_SEQUENCE)) { - ret = TSK_ERR_UNION_DIFF_HISTORIES; + ret = tsk_trace_error(TSK_ERR_UNION_DIFF_HISTORIES); goto out; } @@ -13216,7 +13215,7 @@ tsk_table_collection_union(tsk_table_collection_t *self, for (k = 0; k < (tsk_id_t) other->nodes.num_rows; k++) { if (other_node_mapping[k] >= (tsk_id_t) self->nodes.num_rows || other_node_mapping[k] < TSK_NULL) { - ret = TSK_ERR_UNION_BAD_MAP; + ret = tsk_trace_error(TSK_ERR_UNION_BAD_MAP); goto out; } if (other_node_mapping[k] != TSK_NULL) { @@ -13239,7 +13238,7 @@ tsk_table_collection_union(tsk_table_collection_t *self, site_map = tsk_malloc(other->sites.num_rows * sizeof(*site_map)); if (node_map == NULL || individual_map == NULL || population_map == NULL || site_map == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } tsk_memset(node_map, 0xff, other->nodes.num_rows * sizeof(*node_map)); @@ -13335,7 +13334,7 @@ tsk_table_collection_union(tsk_table_collection_t *self, * union operation on Migrations Tables is that tsk_table_collection_sort * does not sort migrations by time, and instead throws an error. 
*/ if (self->migrations.num_rows != 0 || other->migrations.num_rows != 0) { - ret = TSK_ERR_MIGRATIONS_NOT_SUPPORTED; + ret = tsk_trace_error(TSK_ERR_MIGRATIONS_NOT_SUPPORTED); goto out; } @@ -13410,7 +13409,7 @@ tsk_squash_edges(tsk_edge_t *edges, tsk_size_t num_edges, tsk_size_t *num_output l = 0; for (k = 1; k < num_edges; k++) { if (edges[k - 1].metadata_length > 0) { - ret = TSK_ERR_CANT_PROCESS_EDGES_WITH_METADATA; + ret = tsk_trace_error(TSK_ERR_CANT_PROCESS_EDGES_WITH_METADATA); goto out; } @@ -13418,7 +13417,7 @@ tsk_squash_edges(tsk_edge_t *edges, tsk_size_t num_edges, tsk_size_t *num_output if (edges[k - 1].parent == edges[k].parent && edges[k - 1].child == edges[k].child && edges[k - 1].right > edges[k].left) { - ret = TSK_ERR_BAD_EDGES_CONTRADICTORY_CHILDREN; + ret = tsk_trace_error(TSK_ERR_BAD_EDGES_CONTRADICTORY_CHILDREN); goto out; } @@ -13445,165 +13444,3 @@ tsk_squash_edges(tsk_edge_t *edges, tsk_size_t num_edges, tsk_size_t *num_output out: return ret; } - -/* ======================================================== * - * Tree diff iterator. 
- * ======================================================== */ - -int TSK_WARN_UNUSED -tsk_diff_iter_init(tsk_diff_iter_t *self, const tsk_table_collection_t *tables, - tsk_id_t num_trees, tsk_flags_t options) -{ - int ret = 0; - - tsk_bug_assert(tables != NULL); - tsk_memset(self, 0, sizeof(tsk_diff_iter_t)); - self->num_nodes = tables->nodes.num_rows; - self->num_edges = tables->edges.num_rows; - self->tables = tables; - self->insertion_index = 0; - self->removal_index = 0; - self->tree_left = 0; - self->tree_index = -1; - if (num_trees < 0) { - num_trees = tsk_table_collection_check_integrity(self->tables, TSK_CHECK_TREES); - if (num_trees < 0) { - ret = (int) num_trees; - goto out; - } - } - self->last_index = num_trees; - - if (options & TSK_INCLUDE_TERMINAL) { - self->last_index = self->last_index + 1; - } - self->edge_list_nodes = tsk_malloc(self->num_edges * sizeof(*self->edge_list_nodes)); - if (self->edge_list_nodes == NULL) { - ret = TSK_ERR_NO_MEMORY; - goto out; - } -out: - return ret; -} - -int -tsk_diff_iter_free(tsk_diff_iter_t *self) -{ - tsk_safe_free(self->edge_list_nodes); - return 0; -} - -void -tsk_diff_iter_print_state(const tsk_diff_iter_t *self, FILE *out) -{ - fprintf(out, "tree_diff_iterator state\n"); - fprintf(out, "num_edges = %lld\n", (long long) self->num_edges); - fprintf(out, "insertion_index = %lld\n", (long long) self->insertion_index); - fprintf(out, "removal_index = %lld\n", (long long) self->removal_index); - fprintf(out, "tree_left = %f\n", self->tree_left); - fprintf(out, "tree_index = %lld\n", (long long) self->tree_index); -} - -int TSK_WARN_UNUSED -tsk_diff_iter_next(tsk_diff_iter_t *self, double *ret_left, double *ret_right, - tsk_edge_list_t *edges_out_ret, tsk_edge_list_t *edges_in_ret) -{ - int ret = 0; - tsk_id_t k; - const double sequence_length = self->tables->sequence_length; - double left = self->tree_left; - double right = sequence_length; - tsk_size_t next_edge_list_node = 0; - tsk_edge_list_node_t *out_head = 
NULL; - tsk_edge_list_node_t *out_tail = NULL; - tsk_edge_list_node_t *in_head = NULL; - tsk_edge_list_node_t *in_tail = NULL; - tsk_edge_list_node_t *w = NULL; - tsk_edge_list_t edges_out; - tsk_edge_list_t edges_in; - const tsk_edge_table_t *edges = &self->tables->edges; - const tsk_id_t *insertion_order = self->tables->indexes.edge_insertion_order; - const tsk_id_t *removal_order = self->tables->indexes.edge_removal_order; - - tsk_memset(&edges_out, 0, sizeof(edges_out)); - tsk_memset(&edges_in, 0, sizeof(edges_in)); - - if (self->tree_index + 1 < self->last_index) { - /* First we remove the stale records */ - while (self->removal_index < (tsk_id_t) self->num_edges - && left == edges->right[removal_order[self->removal_index]]) { - k = removal_order[self->removal_index]; - tsk_bug_assert(next_edge_list_node < self->num_edges); - w = &self->edge_list_nodes[next_edge_list_node]; - next_edge_list_node++; - w->edge.id = k; - w->edge.left = edges->left[k]; - w->edge.right = edges->right[k]; - w->edge.parent = edges->parent[k]; - w->edge.child = edges->child[k]; - w->edge.metadata = edges->metadata + edges->metadata_offset[k]; - w->edge.metadata_length - = edges->metadata_offset[k + 1] - edges->metadata_offset[k]; - w->next = NULL; - w->prev = NULL; - if (out_head == NULL) { - out_head = w; - out_tail = w; - } else { - out_tail->next = w; - w->prev = out_tail; - out_tail = w; - } - self->removal_index++; - } - edges_out.head = out_head; - edges_out.tail = out_tail; - - /* Now insert the new records */ - while (self->insertion_index < (tsk_id_t) self->num_edges - && left == edges->left[insertion_order[self->insertion_index]]) { - k = insertion_order[self->insertion_index]; - tsk_bug_assert(next_edge_list_node < self->num_edges); - w = &self->edge_list_nodes[next_edge_list_node]; - next_edge_list_node++; - w->edge.id = k; - w->edge.left = edges->left[k]; - w->edge.right = edges->right[k]; - w->edge.parent = edges->parent[k]; - w->edge.child = edges->child[k]; - 
w->edge.metadata = edges->metadata + edges->metadata_offset[k]; - w->edge.metadata_length - = edges->metadata_offset[k + 1] - edges->metadata_offset[k]; - w->next = NULL; - w->prev = NULL; - if (in_head == NULL) { - in_head = w; - in_tail = w; - } else { - in_tail->next = w; - w->prev = in_tail; - in_tail = w; - } - self->insertion_index++; - } - edges_in.head = in_head; - edges_in.tail = in_tail; - - right = sequence_length; - if (self->insertion_index < (tsk_id_t) self->num_edges) { - right = TSK_MIN(right, edges->left[insertion_order[self->insertion_index]]); - } - if (self->removal_index < (tsk_id_t) self->num_edges) { - right = TSK_MIN(right, edges->right[removal_order[self->removal_index]]); - } - self->tree_index++; - ret = TSK_TREE_OK; - } - *edges_out_ret = edges_out; - *edges_in_ret = edges_in; - *ret_left = left; - *ret_right = right; - /* Set the left coordinate for the next tree */ - self->tree_left = right; - return ret; -} diff --git a/subprojects/tskit/tskit/tables.h b/subprojects/tskit/tskit/tables.h index 38f3096c9..85ed29d58 100644 --- a/subprojects/tskit/tskit/tables.h +++ b/subprojects/tskit/tskit/tables.h @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2019-2023 Tskit Developers + * Copyright (c) 2019-2024 Tskit Developers * Copyright (c) 2017-2018 University of Oxford * * Permission is hereby granted, free of charge, to any person obtaining a copy @@ -158,6 +158,10 @@ typedef struct { /** @brief The ID of the edge that this mutation lies on, or TSK_NULL if there is no corresponding edge.*/ tsk_id_t edge; + /** @brief Inherited state. */ + const char *inherited_state; + /** @brief Size of the inherited state in bytes. 
*/ + tsk_size_t inherited_state_length; } tsk_mutation_t; /** @@ -682,18 +686,6 @@ typedef struct { tsk_edge_list_node_t *tail; } tsk_edge_list_t; -typedef struct { - tsk_size_t num_nodes; - tsk_size_t num_edges; - double tree_left; - const tsk_table_collection_t *tables; - tsk_id_t insertion_index; - tsk_id_t removal_index; - tsk_id_t tree_index; - tsk_id_t last_index; - tsk_edge_list_node_t *edge_list_nodes; -} tsk_diff_iter_t; - /****************************************************************************/ /* Common function options */ /****************************************************************************/ @@ -797,6 +789,11 @@ All checks needed to define a valid tree sequence. Note that this implies all of the above checks. */ #define TSK_CHECK_TREES (1 << 7) +/** +Check mutation parents are consistent with topology. +Implies TSK_CHECK_TREES. +*/ +#define TSK_CHECK_MUTATION_PARENTS (1 << 8) /* Leave room for more positive check flags */ /** @@ -4771,19 +4768,6 @@ int tsk_identity_segments_get(const tsk_identity_segments_t *self, tsk_id_t a, void tsk_identity_segments_print_state(tsk_identity_segments_t *self, FILE *out); int tsk_identity_segments_free(tsk_identity_segments_t *self); -/* Edge differences */ - -/* Internal API - currently used in a few places, but a better API is envisaged - * at some point. - * IMPORTANT: tskit-rust uses this API, so don't break without discussing! 
- */ -int tsk_diff_iter_init(tsk_diff_iter_t *self, const tsk_table_collection_t *tables, - tsk_id_t num_trees, tsk_flags_t options); -int tsk_diff_iter_free(tsk_diff_iter_t *self); -int tsk_diff_iter_next(tsk_diff_iter_t *self, double *left, double *right, - tsk_edge_list_t *edges_out, tsk_edge_list_t *edges_in); -void tsk_diff_iter_print_state(const tsk_diff_iter_t *self, FILE *out); - #ifdef __cplusplus } #endif diff --git a/subprojects/tskit/tskit/trees.c b/subprojects/tskit/tskit/trees.c index 4604579e0..7a159a7fe 100644 --- a/subprojects/tskit/tskit/trees.c +++ b/subprojects/tskit/tskit/trees.c @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2019-2023 Tskit Developers + * Copyright (c) 2019-2025 Tskit Developers * Copyright (c) 2015-2018 University of Oxford * * Permission is hereby granted, free of charge, to any person obtaining a copy @@ -145,7 +145,7 @@ tsk_treeseq_init_sites(tsk_treeseq_t *self) self->tree_sites_mem = tsk_malloc(num_sites * sizeof(*self->tree_sites_mem)); if (self->site_mutations_mem == NULL || self->site_mutations_length == NULL || self->site_mutations == NULL || self->tree_sites_mem == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } @@ -197,7 +197,7 @@ tsk_treeseq_init_individuals(tsk_treeseq_t *self) node_count = tsk_calloc(TSK_MAX(1, num_inds), sizeof(*node_count)); if (self->individual_nodes_length == NULL || node_count == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } @@ -213,7 +213,7 @@ tsk_treeseq_init_individuals(tsk_treeseq_t *self) = tsk_malloc(TSK_MAX(1, total_node_refs) * sizeof(tsk_node_t)); self->individual_nodes = tsk_malloc(TSK_MAX(1, num_inds) * sizeof(tsk_node_t *)); if (self->individual_nodes_mem == NULL || self->individual_nodes == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } @@ -253,6 +253,13 @@ tsk_treeseq_init_trees(tsk_treeseq_t *self) const tsk_size_t num_nodes = 
self->tables->nodes.num_rows; const double *restrict site_position = self->tables->sites.position; const tsk_id_t *restrict mutation_site = self->tables->mutations.site; + const tsk_id_t *restrict mutation_parent = self->tables->mutations.parent; + const char *restrict sites_ancestral_state = self->tables->sites.ancestral_state; + const tsk_size_t *restrict sites_ancestral_state_offset + = self->tables->sites.ancestral_state_offset; + const char *restrict mutations_derived_state = self->tables->mutations.derived_state; + const tsk_size_t *restrict mutations_derived_state_offset + = self->tables->mutations.derived_state_offset; const tsk_id_t *restrict I = self->tables->indexes.edge_insertion_order; const tsk_id_t *restrict O = self->tables->indexes.edge_removal_order; const double *restrict edge_right = self->tables->edges.right; @@ -262,6 +269,7 @@ tsk_treeseq_init_trees(tsk_treeseq_t *self) bool discrete_breakpoints = true; tsk_id_t *node_edge_map = tsk_malloc(num_nodes * sizeof(*node_edge_map)); tsk_mutation_t *mutation; + tsk_id_t parent_id; self->tree_sites_length = tsk_malloc(num_trees_alloc * sizeof(*self->tree_sites_length)); @@ -269,7 +277,7 @@ tsk_treeseq_init_trees(tsk_treeseq_t *self) self->breakpoints = tsk_malloc(num_trees_alloc * sizeof(*self->breakpoints)); if (node_edge_map == NULL || self->tree_sites == NULL || self->tree_sites_length == NULL || self->breakpoints == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } tsk_memset( @@ -311,6 +319,26 @@ tsk_treeseq_init_trees(tsk_treeseq_t *self) mutation_id < num_mutations && mutation_site[mutation_id] == site_id) { mutation = self->site_mutations_mem + mutation_id; mutation->edge = node_edge_map[mutation->node]; + + /* Compute inherited state */ + if (mutation_parent[mutation_id] == TSK_NULL) { + /* No parent: inherited state is the site's ancestral state */ + mutation->inherited_state + = sites_ancestral_state + sites_ancestral_state_offset[site_id]; + 
mutation->inherited_state_length + = sites_ancestral_state_offset[site_id + 1] + - sites_ancestral_state_offset[site_id]; + } else { + /* Has parent: inherited state is parent's derived state */ + parent_id = mutation_parent[mutation_id]; + mutation->inherited_state + = mutations_derived_state + + mutations_derived_state_offset[parent_id]; + mutation->inherited_state_length + = mutations_derived_state_offset[parent_id + 1] + - mutations_derived_state_offset[parent_id]; + } + mutation_id++; } site_id++; @@ -393,7 +421,7 @@ tsk_treeseq_init_nodes(tsk_treeseq_t *self) self->samples = tsk_malloc(self->num_samples * sizeof(tsk_id_t)); self->sample_index_map = tsk_malloc(num_nodes * sizeof(tsk_id_t)); if (self->samples == NULL || self->sample_index_map == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } k = 0; @@ -434,13 +462,13 @@ tsk_treeseq_init( if (options & TSK_TAKE_OWNERSHIP) { self->tables = tables; if (tables->edges.options & TSK_TABLE_NO_METADATA) { - ret = TSK_ERR_CANT_TAKE_OWNERSHIP_NO_EDGE_METADATA; + ret = tsk_trace_error(TSK_ERR_CANT_TAKE_OWNERSHIP_NO_EDGE_METADATA); goto out; } } else { self->tables = tsk_malloc(sizeof(*self->tables)); if (self->tables == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } @@ -458,10 +486,29 @@ tsk_treeseq_init( goto out; } } - num_trees = tsk_table_collection_check_integrity(self->tables, TSK_CHECK_TREES); - if (num_trees < 0) { - ret = (int) num_trees; - goto out; + + if (options & TSK_TS_INIT_COMPUTE_MUTATION_PARENTS) { + /* As tsk_table_collection_compute_mutation_parents performs an + integrity check, and we don't wish to do that twice we perform + our own check here */ + num_trees = tsk_table_collection_check_integrity(self->tables, TSK_CHECK_TREES); + if (num_trees < 0) { + ret = (int) num_trees; + goto out; + } + + ret = tsk_table_collection_compute_mutation_parents( + self->tables, TSK_NO_CHECK_INTEGRITY); + if (ret != 0) { + goto 
out; + } + } else { + num_trees = tsk_table_collection_check_integrity( + self->tables, TSK_CHECK_TREES | TSK_CHECK_MUTATION_PARENTS); + if (num_trees < 0) { + ret = (int) num_trees; + goto out; + } } self->num_trees = (tsk_size_t) num_trees; self->discrete_genome = true; @@ -513,7 +560,7 @@ tsk_treeseq_load(tsk_treeseq_t *self, const char *filename, tsk_flags_t options) tsk_memset(self, 0, sizeof(*self)); if (tables == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } @@ -543,7 +590,7 @@ tsk_treeseq_loadf(tsk_treeseq_t *self, FILE *file, tsk_flags_t options) tsk_memset(self, 0, sizeof(*self)); if (tables == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } @@ -765,7 +812,7 @@ tsk_treeseq_get_individuals_population(const tsk_treeseq_t *self, tsk_id_t *outp if (ind_pop == -2) { ind_pop = node_population[ind.nodes[j]]; } else if (ind_pop != node_population[ind.nodes[j]]) { - ret = TSK_ERR_INDIVIDUAL_POPULATION_MISMATCH; + ret = tsk_trace_error(TSK_ERR_INDIVIDUAL_POPULATION_MISMATCH); goto out; } } @@ -796,7 +843,7 @@ tsk_treeseq_get_individuals_time(const tsk_treeseq_t *self, double *output) if (j == 0) { ind_time = node_time[ind.nodes[j]]; } else if (ind_time != node_time[ind.nodes[j]]) { - ret = TSK_ERR_INDIVIDUAL_TIME_MISMATCH; + ret = tsk_trace_error(TSK_ERR_INDIVIDUAL_TIME_MISMATCH); goto out; } } @@ -808,7 +855,7 @@ tsk_treeseq_get_individuals_time(const tsk_treeseq_t *self, double *output) /* Stats functions */ -#define GET_2D_ROW(array, row_len, row) (array + (((size_t)(row_len)) * (size_t) row)) +#define GET_2D_ROW(array, row_len, row) (array + (((size_t)(row_len)) * (size_t)(row))) static inline double * GET_3D_ROW(double *base, tsk_size_t num_nodes, tsk_size_t output_dim, @@ -876,12 +923,12 @@ tsk_treeseq_genealogical_nearest_neighbours(const tsk_treeseq_t *self, /* We support a max of 8K focal sets */ if (num_reference_sets == 0 || num_reference_sets > (INT16_MAX - 1)) { /* 
TODO: more specific error */ - ret = TSK_ERR_BAD_PARAM_VALUE; + ret = tsk_trace_error(TSK_ERR_BAD_PARAM_VALUE); goto out; } if (parent == NULL || ref_count == NULL || reference_set_map == NULL || length == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } @@ -896,12 +943,12 @@ tsk_treeseq_genealogical_nearest_neighbours(const tsk_treeseq_t *self, for (j = 0; j < reference_set_size[k]; j++) { u = reference_sets[k][j]; if (u < 0 || u >= (tsk_id_t) num_nodes) { - ret = TSK_ERR_NODE_OUT_OF_BOUNDS; + ret = tsk_trace_error(TSK_ERR_NODE_OUT_OF_BOUNDS); goto out; } if (reference_set_map[u] != TSK_NULL) { /* FIXME Technically inaccurate here: duplicate focal not sample */ - ret = TSK_ERR_DUPLICATE_SAMPLE; + ret = tsk_trace_error(TSK_ERR_DUPLICATE_SAMPLE); goto out; } reference_set_map[u] = k; @@ -914,7 +961,7 @@ tsk_treeseq_genealogical_nearest_neighbours(const tsk_treeseq_t *self, for (j = 0; j < num_focal; j++) { u = focal[j]; if (u < 0 || u >= (tsk_id_t) num_nodes) { - ret = TSK_ERR_NODE_OUT_OF_BOUNDS; + ret = tsk_trace_error(TSK_ERR_NODE_OUT_OF_BOUNDS); goto out; } } @@ -1055,12 +1102,12 @@ tsk_treeseq_mean_descendants(const tsk_treeseq_t *self, if (num_reference_sets == 0 || num_reference_sets > (INT32_MAX - 1)) { /* TODO: more specific error */ - ret = TSK_ERR_BAD_PARAM_VALUE; + ret = tsk_trace_error(TSK_ERR_BAD_PARAM_VALUE); goto out; } if (parent == NULL || ref_count == NULL || last_update == NULL || total_length == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } /* TODO add check for duplicate values in the reference sets */ @@ -1073,7 +1120,7 @@ tsk_treeseq_mean_descendants(const tsk_treeseq_t *self, for (j = 0; j < reference_set_size[k]; j++) { u = reference_sets[k][j]; if (u < 0 || u >= (tsk_id_t) num_nodes) { - ret = TSK_ERR_NODE_OUT_OF_BOUNDS; + ret = tsk_trace_error(TSK_ERR_NODE_OUT_OF_BOUNDS); goto out; } row = GET_2D_ROW(ref_count, K, u); @@ -1191,24 +1238,69 @@ 
tsk_treeseq_mean_descendants(const tsk_treeseq_t *self, * General stats framework ***********************************/ +#define TSK_REQUIRE_FULL_SPAN 1 + static int -tsk_treeseq_check_windows( - const tsk_treeseq_t *self, tsk_size_t num_windows, const double *windows) +tsk_treeseq_check_windows(const tsk_treeseq_t *self, tsk_size_t num_windows, + const double *windows, tsk_flags_t options) { - int ret = TSK_ERR_BAD_WINDOWS; + int ret = 0; tsk_size_t j; if (num_windows < 1) { - ret = TSK_ERR_BAD_NUM_WINDOWS; + ret = tsk_trace_error(TSK_ERR_BAD_NUM_WINDOWS); goto out; } - /* TODO these restrictions can be lifted later if we want a specific interval. */ - if (windows[0] != 0) { + if (options & TSK_REQUIRE_FULL_SPAN) { + /* TODO the general stat code currently requires that we include the + * entire tree sequence span. This should be relaxed, so hopefully + * this branch (and the option) can be removed at some point */ + if (windows[0] != 0) { + ret = tsk_trace_error(TSK_ERR_BAD_WINDOWS); + goto out; + } + if (windows[num_windows] != self->tables->sequence_length) { + ret = tsk_trace_error(TSK_ERR_BAD_WINDOWS); + goto out; + } + } else { + if (windows[0] < 0) { + ret = tsk_trace_error(TSK_ERR_BAD_WINDOWS); + goto out; + } + if (windows[num_windows] > self->tables->sequence_length) { + ret = tsk_trace_error(TSK_ERR_BAD_WINDOWS); + goto out; + } + } + for (j = 0; j < num_windows; j++) { + if (windows[j] >= windows[j + 1]) { + ret = tsk_trace_error(TSK_ERR_BAD_WINDOWS); + goto out; + } + } + ret = 0; +out: + return ret; +} + +static int +tsk_treeseq_check_time_windows(tsk_size_t num_windows, const double *windows) +{ + // This does not check the last window ends at infinity, + // which is required for some time window functions. 
+ int ret = TSK_ERR_BAD_TIME_WINDOWS; + tsk_size_t j; + + if (num_windows < 1) { + ret = TSK_ERR_BAD_TIME_WINDOWS_DIM; goto out; } - if (windows[num_windows] != self->tables->sequence_length) { + + if (windows[0] != 0.0) { goto out; } + for (j = 0; j < num_windows; j++) { if (windows[j] >= windows[j + 1]) { goto out; @@ -1284,19 +1376,30 @@ tsk_treeseq_branch_general_stat(const tsk_treeseq_t *self, tsk_size_t state_dim, double *state = tsk_calloc(num_nodes * state_dim, sizeof(*state)); double *summary = tsk_calloc(num_nodes * result_dim, sizeof(*summary)); double *running_sum = tsk_calloc(result_dim, sizeof(*running_sum)); + double *zero_state = tsk_calloc(state_dim, sizeof(*zero_state)); + double *zero_summary = tsk_calloc(result_dim, sizeof(*zero_state)); if (self->time_uncalibrated && !(options & TSK_STAT_ALLOW_TIME_UNCALIBRATED)) { - ret = TSK_ERR_TIME_UNCALIBRATED; + ret = tsk_trace_error(TSK_ERR_TIME_UNCALIBRATED); goto out; } if (parent == NULL || branch_length == NULL || state == NULL || running_sum == NULL - || summary == NULL) { - ret = TSK_ERR_NO_MEMORY; + || summary == NULL || zero_state == NULL || zero_summary == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } tsk_memset(parent, 0xff, num_nodes * sizeof(*parent)); + /* If f is not strict, we may need to set conditions for non-sample nodes as well. 
*/ + ret = f(state_dim, zero_state, result_dim, zero_summary, f_params); + if (ret != 0) { + goto out; + } + for (j = 0; j < num_nodes; j++) { // we could skip this if zero_summary is zero + summary_u = GET_2D_ROW(summary, result_dim, j); + tsk_memcpy(summary_u, zero_summary, result_dim * sizeof(*zero_summary)); + } /* Set the initial conditions */ for (j = 0; j < self->num_samples; j++) { u = self->samples[j]; @@ -1309,6 +1412,7 @@ tsk_treeseq_branch_general_stat(const tsk_treeseq_t *self, tsk_size_t state_dim, goto out; } } + tsk_memset(result, 0, num_windows * result_dim * sizeof(*result)); /* Iterate over the trees */ @@ -1412,6 +1516,8 @@ tsk_treeseq_branch_general_stat(const tsk_treeseq_t *self, tsk_size_t state_dim, tsk_safe_free(state); tsk_safe_free(summary); tsk_safe_free(running_sum); + tsk_safe_free(zero_state); + tsk_safe_free(zero_summary); return ret; } @@ -1433,7 +1539,7 @@ get_allele_weights(const tsk_site_t *site, const double *state, tsk_size_t state const char *alt_allele; if (alleles == NULL || allele_lengths == NULL || allele_states == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } @@ -1470,7 +1576,7 @@ get_allele_weights(const tsk_site_t *site, const double *state, tsk_size_t state allele_row[k] += state_row[k]; } - /* Get the index for the alternate allele that we must substract from */ + /* Get the index for the alternate allele that we must subtract from */ alt_allele = site->ancestral_state; alt_allele_length = site->ancestral_state_length; if (mutation.parent != TSK_NULL) { @@ -1516,7 +1622,7 @@ compute_general_stat_site_result(tsk_site_t *site, double *state, tsk_size_t sta double *result_tmp = tsk_calloc(result_dim, sizeof(*result_tmp)); if (result_tmp == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } tsk_memset(result, 0, result_dim * sizeof(*result)); @@ -1574,7 +1680,7 @@ tsk_treeseq_site_general_stat(const tsk_treeseq_t *self, tsk_size_t 
state_dim, bool polarised = false; if (parent == NULL || state == NULL || total_weight == NULL || site_result == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } tsk_memset(parent, 0xff, num_nodes * sizeof(*parent)); @@ -1706,7 +1812,7 @@ tsk_treeseq_node_general_stat(const tsk_treeseq_t *self, tsk_size_t state_dim, double t_left, t_right, w_right; if (parent == NULL || state == NULL || node_summary == NULL || last_update == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } tsk_memset(parent, 0xff, num_nodes * sizeof(*parent)); @@ -1894,7 +2000,7 @@ tsk_polarisable_func_general_stat(const tsk_treeseq_t *self, tsk_size_t state_di if (upargs.total_weight == NULL || upargs.total_minus_state == NULL || upargs.result_tmp == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } @@ -1944,23 +2050,24 @@ tsk_treeseq_general_stat(const tsk_treeseq_t *self, tsk_size_t state_dim, } /* It's an error to specify more than one mode */ if (stat_site + stat_branch + stat_node > 1) { - ret = TSK_ERR_MULTIPLE_STAT_MODES; + ret = tsk_trace_error(TSK_ERR_MULTIPLE_STAT_MODES); goto out; } if (state_dim < 1) { - ret = TSK_ERR_BAD_STATE_DIMS; + ret = tsk_trace_error(TSK_ERR_BAD_STATE_DIMS); goto out; } if (result_dim < 1) { - ret = TSK_ERR_BAD_RESULT_DIMS; + ret = tsk_trace_error(TSK_ERR_BAD_RESULT_DIMS); goto out; } if (windows == NULL) { num_windows = 1; windows = default_windows; } else { - ret = tsk_treeseq_check_windows(self, num_windows, windows); + ret = tsk_treeseq_check_windows( + self, num_windows, windows, TSK_REQUIRE_FULL_SPAN); if (ret != 0) { goto out; } @@ -1995,7 +2102,7 @@ check_set_indexes( for (j = 0; j < num_set_indexes; j++) { if (set_indexes[j] < 0 || set_indexes[j] >= (tsk_id_t) num_sets) { - ret = TSK_ERR_BAD_SAMPLE_SET_INDEX; + ret = tsk_trace_error(TSK_ERR_BAD_SAMPLE_SET_INDEX); goto out; } } @@ -2013,24 +2120,24 @@ 
tsk_treeseq_check_sample_sets(const tsk_treeseq_t *self, tsk_size_t num_sample_s tsk_id_t u, sample_index; if (num_sample_sets == 0) { - ret = TSK_ERR_INSUFFICIENT_SAMPLE_SETS; + ret = tsk_trace_error(TSK_ERR_INSUFFICIENT_SAMPLE_SETS); goto out; } j = 0; for (k = 0; k < num_sample_sets; k++) { if (sample_set_sizes[k] == 0) { - ret = TSK_ERR_EMPTY_SAMPLE_SET; + ret = tsk_trace_error(TSK_ERR_EMPTY_SAMPLE_SET); goto out; } for (l = 0; l < sample_set_sizes[k]; l++) { u = sample_sets[j]; if (u < 0 || u >= num_nodes) { - ret = TSK_ERR_NODE_OUT_OF_BOUNDS; + ret = tsk_trace_error(TSK_ERR_NODE_OUT_OF_BOUNDS); goto out; } sample_index = self->sample_index_map[u]; if (sample_index == TSK_NULL) { - ret = TSK_ERR_BAD_SAMPLES; + ret = tsk_trace_error(TSK_ERR_BAD_SAMPLES); goto out; } j++; @@ -2057,6 +2164,12 @@ typedef struct { const tsk_id_t *set_indexes; } sample_count_stat_params_t; +typedef struct { + tsk_size_t num_samples; + double *total_weights; + const tsk_id_t *index_tuples; +} indexed_weight_stat_params_t; + static int tsk_treeseq_sample_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, @@ -2081,7 +2194,7 @@ tsk_treeseq_sample_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sample_s } weights = tsk_calloc(num_samples * num_sample_sets, sizeof(*weights)); if (weights == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } j = 0; @@ -2091,7 +2204,7 @@ tsk_treeseq_sample_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sample_s sample_index = self->sample_index_map[u]; weight_row = GET_2D_ROW(weights, num_sample_sets, sample_index); if (weight_row[k] != 0) { - ret = TSK_ERR_DUPLICATE_SAMPLE; + ret = tsk_trace_error(TSK_ERR_DUPLICATE_SAMPLE); goto out; } weight_row[k] = 1; @@ -2106,3906 +2219,8511 @@ tsk_treeseq_sample_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sample_s } /*********************************** - * Allele frequency 
spectrum + * Two Locus Statistics ***********************************/ -static inline void -fold(tsk_size_t *restrict coordinate, const tsk_size_t *restrict dims, - tsk_size_t num_dims) +static int +get_allele_samples(const tsk_site_t *site, const tsk_bit_array_t *state, + tsk_bit_array_t *out_allele_samples, tsk_size_t *out_num_alleles) +{ + int ret = 0; + tsk_mutation_t mutation, parent_mut; + tsk_size_t mutation_index, allele, alt_allele_length; + /* The allele table */ + tsk_size_t max_alleles = site->mutations_length + 1; + const char **alleles = tsk_malloc(max_alleles * sizeof(*alleles)); + tsk_size_t *allele_lengths = tsk_calloc(max_alleles, sizeof(*allele_lengths)); + const char *alt_allele; + tsk_bit_array_t state_row; + tsk_bit_array_t allele_samples_row; + tsk_bit_array_t alt_allele_samples_row; + tsk_size_t num_alleles = 1; + + if (alleles == NULL || allele_lengths == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; + } + + tsk_bug_assert(state != NULL); + alleles[0] = site->ancestral_state; + allele_lengths[0] = site->ancestral_state_length; + + for (mutation_index = 0; mutation_index < site->mutations_length; mutation_index++) { + mutation = site->mutations[mutation_index]; + /* Compute the allele index for this derived state value. 
*/ + for (allele = 0; allele < num_alleles; allele++) { + if (mutation.derived_state_length == allele_lengths[allele] + && tsk_memcmp( + mutation.derived_state, alleles[allele], allele_lengths[allele]) + == 0) { + break; + } + } + if (allele == num_alleles) { + tsk_bug_assert(allele < max_alleles); + alleles[allele] = mutation.derived_state; + allele_lengths[allele] = mutation.derived_state_length; + num_alleles++; + } + + /* Add the mutation's samples to this allele */ + tsk_bit_array_get_row(out_allele_samples, allele, &allele_samples_row); + tsk_bit_array_get_row(state, mutation_index, &state_row); + tsk_bit_array_add(&allele_samples_row, &state_row); + + /* Get the index for the alternate allele that we must subtract from */ + alt_allele = site->ancestral_state; + alt_allele_length = site->ancestral_state_length; + if (mutation.parent != TSK_NULL) { + parent_mut = site->mutations[mutation.parent - site->mutations[0].id]; + alt_allele = parent_mut.derived_state; + alt_allele_length = parent_mut.derived_state_length; + } + for (allele = 0; allele < num_alleles; allele++) { + if (alt_allele_length == allele_lengths[allele] + && tsk_memcmp(alt_allele, alleles[allele], allele_lengths[allele]) + == 0) { + break; + } + } + tsk_bug_assert(allele < num_alleles); + + tsk_bit_array_get_row(out_allele_samples, allele, &alt_allele_samples_row); + tsk_bit_array_subtract(&alt_allele_samples_row, &allele_samples_row); + } + *out_num_alleles = num_alleles; +out: + tsk_safe_free(alleles); + tsk_safe_free(allele_lengths); + return ret; +} + +static int +norm_hap_weighted(tsk_size_t result_dim, const double *hap_weights, + tsk_size_t TSK_UNUSED(n_a), tsk_size_t TSK_UNUSED(n_b), double *result, void *params) { + sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; + const double *weight_row; + double n; tsk_size_t k; - double n = 0; - int s = 0; - for (k = 0; k < num_dims; k++) { - tsk_bug_assert(coordinate[k] < dims[k]); - n += (double) dims[k] - 1; - s += 
(int) coordinate[k]; + for (k = 0; k < result_dim; k++) { + weight_row = GET_2D_ROW(hap_weights, 3, k); + n = (double) args.sample_set_sizes[k]; + // TODO: what to do when n = 0 + result[k] = weight_row[0] / n; } - n /= 2; - k = num_dims; - while (s == n && k > 0) { - k--; - n -= ((double) (dims[k] - 1)) / 2; - s -= (int) coordinate[k]; + return 0; +} + +static int +norm_hap_weighted_ij(tsk_size_t result_dim, const double *hap_weights, + tsk_size_t TSK_UNUSED(n_a), tsk_size_t TSK_UNUSED(n_b), double *result, void *params) +{ + sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; + const double *weight_row; + double ni, nj, wAB_i, wAB_j; + tsk_id_t i, j; + tsk_size_t k; + + for (k = 0; k < result_dim; k++) { + i = args.set_indexes[2 * k]; + j = args.set_indexes[2 * k + 1]; + ni = (double) args.sample_set_sizes[i]; + nj = (double) args.sample_set_sizes[j]; + weight_row = GET_2D_ROW(hap_weights, 3, i); + wAB_i = weight_row[0]; + weight_row = GET_2D_ROW(hap_weights, 3, j); + wAB_j = weight_row[0]; + + result[k] = (wAB_i + wAB_j) / (ni + nj); } - if (s > n) { - for (k = 0; k < num_dims; k++) { - s = (int) (dims[k] - 1 - coordinate[k]); - tsk_bug_assert(s >= 0); - coordinate[k] = (tsk_size_t) s; - } + + return 0; +} + +static int +norm_total_weighted(tsk_size_t result_dim, const double *TSK_UNUSED(hap_weights), + tsk_size_t n_a, tsk_size_t n_b, double *result, void *TSK_UNUSED(params)) +{ + tsk_size_t k; + + for (k = 0; k < result_dim; k++) { + result[k] = 1 / (double) (n_a * n_b); + } + return 0; +} + +static void +get_all_samples_bits(tsk_bit_array_t *all_samples, tsk_size_t n) +{ + tsk_size_t i; + const tsk_bit_array_value_t all = ~((tsk_bit_array_value_t) 0); + const tsk_bit_array_value_t remainder_samples = n % TSK_BIT_ARRAY_NUM_BITS; + + all_samples->data[all_samples->size - 1] + = remainder_samples ? 
~(all << remainder_samples) : all; + for (i = 0; i < all_samples->size - 1; i++) { + all_samples->data[i] = all; } } static int -tsk_treeseq_update_site_afs(const tsk_treeseq_t *self, const tsk_site_t *site, - const double *total_counts, const double *counts, tsk_size_t num_sample_sets, - tsk_size_t window_index, tsk_size_t *result_dims, tsk_flags_t options, +compute_general_two_site_stat_result(const tsk_bit_array_t *site_a_state, + const tsk_bit_array_t *site_b_state, tsk_size_t num_a_alleles, + tsk_size_t num_b_alleles, tsk_size_t num_samples, tsk_size_t state_dim, + const tsk_bit_array_t *sample_sets, tsk_size_t result_dim, general_stat_func_t *f, + sample_count_stat_params_t *f_params, norm_func_t *norm_f, bool polarised, double *result) { int ret = 0; - tsk_size_t afs_size; - tsk_size_t k, allele, num_alleles, all_samples; - double increment, *afs, *allele_counts, *allele_count; - tsk_size_t *coordinate = tsk_malloc(num_sample_sets * sizeof(*coordinate)); - bool polarised = !!(options & TSK_STAT_POLARISED); - const tsk_size_t K = num_sample_sets + 1; - - if (coordinate == NULL) { - ret = TSK_ERR_NO_MEMORY; + tsk_bit_array_t A_samples, B_samples; + // ss_ prefix refers to a sample set + tsk_bit_array_t ss_row; + tsk_bit_array_t ss_A_samples, ss_B_samples, ss_AB_samples, AB_samples; + // Sample sets and b sites are rows, a sites are columns + // b1 b2 b3 + // a1 [s1, s2, s3] [s1, s2, s3] [s1, s2, s3] + // a2 [s1, s2, s3] [s1, s2, s3] [s1, s2, s3] + // a3 [s1, s2, s3] [s1, s2, s3] [s1, s2, s3] + tsk_size_t k, mut_a, mut_b; + tsk_size_t result_row_len = num_b_alleles * result_dim; + tsk_size_t w_A = 0, w_B = 0, w_AB = 0; + uint8_t polarised_val = polarised ? 
1 : 0; + double *hap_weight_row; + double *result_tmp_row; + double *weights = tsk_malloc(3 * state_dim * sizeof(*weights)); + double *norm = tsk_malloc(result_dim * sizeof(*norm)); + double *result_tmp + = tsk_malloc(result_row_len * num_a_alleles * sizeof(*result_tmp)); + + tsk_memset(&ss_A_samples, 0, sizeof(ss_A_samples)); + tsk_memset(&ss_B_samples, 0, sizeof(ss_B_samples)); + tsk_memset(&ss_AB_samples, 0, sizeof(ss_AB_samples)); + tsk_memset(&AB_samples, 0, sizeof(AB_samples)); + + if (weights == NULL || norm == NULL || result_tmp == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; + } + + ret = tsk_bit_array_init(&ss_A_samples, num_samples, 1); + if (ret != 0) { goto out; } - ret = get_allele_weights( - site, counts, K, total_counts, &num_alleles, &allele_counts); + ret = tsk_bit_array_init(&ss_B_samples, num_samples, 1); + if (ret != 0) { + goto out; + } + ret = tsk_bit_array_init(&ss_AB_samples, num_samples, 1); + if (ret != 0) { + goto out; + } + ret = tsk_bit_array_init(&AB_samples, num_samples, 1); if (ret != 0) { goto out; } - afs_size = result_dims[num_sample_sets]; - afs = result + afs_size * window_index; + for (mut_a = polarised_val; mut_a < num_a_alleles; mut_a++) { + result_tmp_row = GET_2D_ROW(result_tmp, result_row_len, mut_a); + for (mut_b = polarised_val; mut_b < num_b_alleles; mut_b++) { + tsk_bit_array_get_row(site_a_state, mut_a, &A_samples); + tsk_bit_array_get_row(site_b_state, mut_b, &B_samples); + tsk_bit_array_intersect(&A_samples, &B_samples, &AB_samples); + for (k = 0; k < state_dim; k++) { + tsk_bit_array_get_row(sample_sets, k, &ss_row); + hap_weight_row = GET_2D_ROW(weights, 3, k); - increment = polarised ? 1 : 0.5; - /* Sum over the allele weights. Skip the ancestral state if polarised. */ - for (allele = polarised ? 
1 : 0; allele < num_alleles; allele++) { - allele_count = GET_2D_ROW(allele_counts, K, allele); - all_samples = (tsk_size_t) allele_count[num_sample_sets]; - if (all_samples > 0 && all_samples < self->num_samples) { - for (k = 0; k < num_sample_sets; k++) { - coordinate[k] = (tsk_size_t) allele_count[k]; + tsk_bit_array_intersect(&A_samples, &ss_row, &ss_A_samples); + tsk_bit_array_intersect(&B_samples, &ss_row, &ss_B_samples); + tsk_bit_array_intersect(&AB_samples, &ss_row, &ss_AB_samples); + + w_AB = tsk_bit_array_count(&ss_AB_samples); + w_A = tsk_bit_array_count(&ss_A_samples); + w_B = tsk_bit_array_count(&ss_B_samples); + + hap_weight_row[0] = (double) w_AB; + hap_weight_row[1] = (double) (w_A - w_AB); // w_Ab + hap_weight_row[2] = (double) (w_B - w_AB); // w_aB } - if (!polarised) { - fold(coordinate, result_dims, num_sample_sets); + ret = f(state_dim, weights, result_dim, result_tmp_row, f_params); + if (ret != 0) { + goto out; } - increment_nd_array_value( - afs, num_sample_sets, result_dims, coordinate, increment); + ret = norm_f(result_dim, weights, num_a_alleles - polarised_val, + num_b_alleles - polarised_val, norm, f_params); + if (ret != 0) { + goto out; + } + for (k = 0; k < result_dim; k++) { + result[k] += result_tmp_row[k] * norm[k]; + } + result_tmp_row += result_dim; // Advance to the next column } } + out: - tsk_safe_free(coordinate); - tsk_safe_free(allele_counts); + tsk_safe_free(weights); + tsk_safe_free(norm); + tsk_safe_free(result_tmp); + tsk_bit_array_free(&ss_A_samples); + tsk_bit_array_free(&ss_B_samples); + tsk_bit_array_free(&ss_AB_samples); + tsk_bit_array_free(&AB_samples); return ret; } +static void +get_site_row_col_indices(tsk_size_t n_rows, const tsk_id_t *row_sites, tsk_size_t n_cols, + const tsk_id_t *col_sites, tsk_id_t *sites, tsk_size_t *n_sites, tsk_size_t *row_idx, + tsk_size_t *col_idx) +{ + tsk_size_t r = 0, c = 0, s = 0; + + // Iterate rows and columns until we've exhaused one of the lists + while ((r < n_rows) && (c 
< n_cols)) { + if (row_sites[r] < col_sites[c]) { + sites[s] = row_sites[r]; + row_idx[r] = s; + s++; + r++; + } else if (col_sites[c] < row_sites[r]) { + sites[s] = col_sites[c]; + col_idx[c] = s; + s++; + c++; + } else { // row == col + sites[s] = row_sites[r]; + col_idx[c] = s; + row_idx[r] = s; + s++; + r++; + c++; + } + } + + // If there are any items remaining in the other list, drain it + while (r < n_rows) { + sites[s] = row_sites[r]; + row_idx[r] = s; + s++; + r++; + } + while (c < n_cols) { + sites[s] = col_sites[c]; + col_idx[c] = s; + s++; + c++; + } + *n_sites = s; +} + static int -tsk_treeseq_site_allele_frequency_spectrum(const tsk_treeseq_t *self, - tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, double *counts, - tsk_size_t num_windows, const double *windows, tsk_size_t *result_dims, - tsk_flags_t options, double *result) +get_mutation_samples(const tsk_treeseq_t *ts, const tsk_id_t *sites, tsk_size_t n_sites, + tsk_size_t *num_alleles, tsk_bit_array_t *allele_samples) { int ret = 0; - tsk_id_t u, v; - tsk_size_t tree_site, tree_index, window_index; - tsk_size_t num_nodes = self->tables->nodes.num_rows; - const tsk_id_t num_edges = (tsk_id_t) self->tables->edges.num_rows; - const tsk_id_t *restrict I = self->tables->indexes.edge_insertion_order; - const tsk_id_t *restrict O = self->tables->indexes.edge_removal_order; - const double *restrict edge_left = self->tables->edges.left; - const double *restrict edge_right = self->tables->edges.right; - const tsk_id_t *restrict edge_parent = self->tables->edges.parent; - const tsk_id_t *restrict edge_child = self->tables->edges.child; - const double sequence_length = self->tables->sequence_length; - tsk_id_t *restrict parent = tsk_malloc(num_nodes * sizeof(*parent)); - tsk_site_t *site; - tsk_id_t tj, tk, h; - tsk_size_t j; - const tsk_size_t K = num_sample_sets + 1; - double t_left, t_right; - double *total_counts = tsk_malloc((1 + num_sample_sets) * sizeof(*total_counts)); + const 
tsk_flags_t *restrict flags = ts->tables->nodes.flags; + const tsk_size_t num_samples = tsk_treeseq_get_num_samples(ts); + const tsk_size_t *restrict site_muts_len = ts->site_mutations_length; + tsk_site_t site; + tsk_tree_t tree; + tsk_bit_array_t all_samples_bits, mut_samples, mut_samples_row, out_row; + tsk_size_t max_muts_len, site_offset, num_nodes, site_idx, s, m, n; + tsk_id_t node, *nodes = NULL; + void *tmp_nodes; - if (parent == NULL || total_counts == NULL) { - ret = TSK_ERR_NO_MEMORY; + tsk_memset(&mut_samples, 0, sizeof(mut_samples)); + tsk_memset(&all_samples_bits, 0, sizeof(all_samples_bits)); + + max_muts_len = 0; + for (s = 0; s < n_sites; s++) { + if (site_muts_len[sites[s]] > max_muts_len) { + max_muts_len = site_muts_len[sites[s]]; + } + } + // Allocate a bit array of size max alleles for all sites + ret = tsk_bit_array_init(&mut_samples, num_samples, max_muts_len); + if (ret != 0) { goto out; } - tsk_memset(parent, 0xff, num_nodes * sizeof(*parent)); - - for (j = 0; j < num_sample_sets; j++) { - total_counts[j] = (double) sample_set_sizes[j]; + ret = tsk_bit_array_init(&all_samples_bits, num_samples, 1); + if (ret != 0) { + goto out; + } + get_all_samples_bits(&all_samples_bits, num_samples); + ret = tsk_tree_init(&tree, ts, TSK_NO_SAMPLE_COUNTS); + if (ret != 0) { + goto out; } - total_counts[num_sample_sets] = (double) self->num_samples; - - /* Iterate over the trees */ - tj = 0; - tk = 0; - t_left = 0; - tree_index = 0; - window_index = 0; - while (tj < num_edges || t_left < sequence_length) { - while (tk < num_edges && edge_right[O[tk]] == t_left) { - h = O[tk]; - tk++; - u = edge_child[h]; - v = edge_parent[h]; - while (v != TSK_NULL) { - update_state(counts, K, v, u, -1); - v = parent[v]; - } - parent[u] = TSK_NULL; - } - while (tj < num_edges && edge_left[I[tj]] == t_left) { - h = I[tj]; - tj++; - u = edge_child[h]; - v = edge_parent[h]; - parent[u] = v; - while (v != TSK_NULL) { - update_state(counts, K, v, u, +1); - v = parent[v]; - } 
- } - t_right = sequence_length; - if (tj < num_edges) { - t_right = TSK_MIN(t_right, edge_left[I[tj]]); + // For each mutation within each site, perform one preorder traversal to gather + // the samples under each mutation's node. + site_offset = 0; + for (site_idx = 0; site_idx < n_sites; site_idx++) { + tsk_treeseq_get_site(ts, sites[site_idx], &site); + ret = tsk_tree_seek(&tree, site.position, 0); + if (ret != 0) { + goto out; } - if (tk < num_edges) { - t_right = TSK_MIN(t_right, edge_right[O[tk]]); + tmp_nodes = tsk_realloc(nodes, tsk_tree_get_size_bound(&tree) * sizeof(*nodes)); + if (tmp_nodes == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; } + nodes = tmp_nodes; - /* Update the sites */ - for (tree_site = 0; tree_site < self->tree_sites_length[tree_index]; - tree_site++) { - site = self->tree_sites[tree_index] + tree_site; - while (windows[window_index + 1] <= site->position) { - window_index++; - tsk_bug_assert(window_index < num_windows); - } - ret = tsk_treeseq_update_site_afs(self, site, total_counts, counts, - num_sample_sets, window_index, result_dims, options, result); + tsk_bit_array_get_row(allele_samples, site_offset, &out_row); + tsk_bit_array_add(&out_row, &all_samples_bits); + + // Zero out results before the start of each iteration + tsk_memset(mut_samples.data, 0, + mut_samples.size * max_muts_len * sizeof(tsk_bit_array_value_t)); + for (m = 0; m < site.mutations_length; m++) { + tsk_bit_array_get_row(&mut_samples, m, &mut_samples_row); + node = site.mutations[m].node; + ret = tsk_tree_preorder_from(&tree, node, nodes, &num_nodes); if (ret != 0) { goto out; } - tsk_bug_assert(windows[window_index] <= site->position); - tsk_bug_assert(site->position < windows[window_index + 1]); + for (n = 0; n < num_nodes; n++) { + node = nodes[n]; + if (flags[node] & TSK_NODE_IS_SAMPLE) { + tsk_bit_array_add_bit(&mut_samples_row, + (tsk_bit_array_value_t) ts->sample_index_map[node]); + } + } } - tree_index++; - t_left = t_right; + 
site_offset += site.mutations_length + 1; + get_allele_samples(&site, &mut_samples, &out_row, &(num_alleles[site_idx])); } +// if adding code below, check ret before continuing out: - /* Can't use msp_safe_free here because of restrict */ - if (parent != NULL) { - free(parent); - } - tsk_safe_free(total_counts); - return ret; + tsk_safe_free(nodes); + tsk_tree_free(&tree); + tsk_bit_array_free(&mut_samples); + tsk_bit_array_free(&all_samples_bits); + return ret == TSK_TREE_OK ? 0 : ret; } -static int TSK_WARN_UNUSED -tsk_treeseq_update_branch_afs(const tsk_treeseq_t *self, tsk_id_t u, double right, - const double *restrict branch_length, double *restrict last_update, - const double *counts, tsk_size_t num_sample_sets, tsk_size_t window_index, - const tsk_size_t *result_dims, tsk_flags_t options, double *result) +static int +tsk_treeseq_two_site_count_stat(const tsk_treeseq_t *self, tsk_size_t state_dim, + const tsk_bit_array_t *sample_sets, tsk_size_t result_dim, general_stat_func_t *f, + sample_count_stat_params_t *f_params, norm_func_t *norm_f, tsk_size_t n_rows, + const tsk_id_t *row_sites, tsk_size_t n_cols, const tsk_id_t *col_sites, + tsk_flags_t options, double *result) { + int ret = 0; - tsk_size_t afs_size; - tsk_size_t k; - double *afs; - tsk_size_t *coordinate = tsk_malloc(num_sample_sets * sizeof(*coordinate)); - bool polarised = !!(options & TSK_STAT_POLARISED); - const double *count_row = GET_2D_ROW(counts, num_sample_sets + 1, u); - double x = (right - last_update[u]) * branch_length[u]; - const tsk_size_t all_samples = (tsk_size_t) count_row[num_sample_sets]; + tsk_bit_array_t allele_samples, c_state, r_state; + bool polarised = false; + tsk_id_t *sites; + tsk_size_t r, c, s, n_alleles, n_sites, *row_idx, *col_idx; + double *result_row; + const tsk_size_t num_samples = self->num_samples; + tsk_size_t *num_alleles = NULL, *site_offsets = NULL; + tsk_size_t result_row_len = n_cols * result_dim; - if (coordinate == NULL) { - ret = TSK_ERR_NO_MEMORY; + 
tsk_memset(&allele_samples, 0, sizeof(allele_samples)); + + sites = tsk_malloc(self->tables->sites.num_rows * sizeof(*sites)); + row_idx = tsk_malloc(self->tables->sites.num_rows * sizeof(*row_idx)); + col_idx = tsk_malloc(self->tables->sites.num_rows * sizeof(*col_idx)); + if (sites == NULL || row_idx == NULL || col_idx == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } + get_site_row_col_indices( + n_rows, row_sites, n_cols, col_sites, sites, &n_sites, row_idx, col_idx); - if (0 < all_samples && all_samples < self->num_samples) { - if (!polarised) { - x *= 0.5; - } - afs_size = result_dims[num_sample_sets]; - afs = result + afs_size * window_index; - for (k = 0; k < num_sample_sets; k++) { - coordinate[k] = (tsk_size_t) count_row[k]; - } - if (!polarised) { - fold(coordinate, result_dims, num_sample_sets); - } - increment_nd_array_value(afs, num_sample_sets, result_dims, coordinate, x); + // We rely on n_sites to allocate these arrays, which are initialized + // to NULL for safe deallocation if the previous allocation fails + num_alleles = tsk_malloc(n_sites * sizeof(*num_alleles)); + site_offsets = tsk_malloc(n_sites * sizeof(*site_offsets)); + if (num_alleles == NULL || site_offsets == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; } - last_update[u] = right; -out: - tsk_safe_free(coordinate); - return ret; -} - -static int -tsk_treeseq_branch_allele_frequency_spectrum(const tsk_treeseq_t *self, - tsk_size_t num_sample_sets, double *counts, tsk_size_t num_windows, - const double *windows, const tsk_size_t *result_dims, tsk_flags_t options, - double *result) -{ - int ret = 0; - tsk_id_t u, v; - tsk_size_t window_index; - tsk_size_t num_nodes = self->tables->nodes.num_rows; - const tsk_id_t num_edges = (tsk_id_t) self->tables->edges.num_rows; - const tsk_id_t *restrict I = self->tables->indexes.edge_insertion_order; - const tsk_id_t *restrict O = self->tables->indexes.edge_removal_order; - const double *restrict edge_left = 
self->tables->edges.left; - const double *restrict edge_right = self->tables->edges.right; - const tsk_id_t *restrict edge_parent = self->tables->edges.parent; - const tsk_id_t *restrict edge_child = self->tables->edges.child; - const double *restrict node_time = self->tables->nodes.time; - const double sequence_length = self->tables->sequence_length; - tsk_id_t *restrict parent = tsk_malloc(num_nodes * sizeof(*parent)); - double *restrict last_update = tsk_calloc(num_nodes, sizeof(*last_update)); - double *restrict branch_length = tsk_calloc(num_nodes, sizeof(*branch_length)); - tsk_id_t tj, tk, h; - double t_left, t_right, w_right; - const tsk_size_t K = num_sample_sets + 1; - if (self->time_uncalibrated && !(options & TSK_STAT_ALLOW_TIME_UNCALIBRATED)) { - ret = TSK_ERR_TIME_UNCALIBRATED; + n_alleles = 0; + for (s = 0; s < n_sites; s++) { + site_offsets[s] = n_alleles; + n_alleles += self->site_mutations_length[sites[s]] + 1; + } + ret = tsk_bit_array_init(&allele_samples, num_samples, n_alleles); + if (ret != 0) { goto out; } - - if (parent == NULL || last_update == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = get_mutation_samples(self, sites, n_sites, num_alleles, &allele_samples); + if (ret != 0) { goto out; } - tsk_memset(parent, 0xff, num_nodes * sizeof(*parent)); - /* Iterate over the trees */ - tj = 0; - tk = 0; - t_left = 0; - window_index = 0; - while (tj < num_edges || t_left < sequence_length) { - tsk_bug_assert(window_index < num_windows); - while (tk < num_edges && edge_right[O[tk]] == t_left) { - h = O[tk]; - tk++; - u = edge_child[h]; - v = edge_parent[h]; - ret = tsk_treeseq_update_branch_afs(self, u, t_left, branch_length, - last_update, counts, num_sample_sets, window_index, result_dims, options, - result); + if (options & TSK_STAT_POLARISED) { + polarised = true; + } + + // For each row/column pair, fill in the sample set in the result matrix. 
+ for (r = 0; r < n_rows; r++) { + result_row = GET_2D_ROW(result, result_row_len, r); + for (c = 0; c < n_cols; c++) { + tsk_bit_array_get_row(&allele_samples, site_offsets[row_idx[r]], &r_state); + tsk_bit_array_get_row(&allele_samples, site_offsets[col_idx[c]], &c_state); + ret = compute_general_two_site_stat_result(&r_state, &c_state, + num_alleles[row_idx[r]], num_alleles[col_idx[c]], num_samples, state_dim, + sample_sets, result_dim, f, f_params, norm_f, polarised, + &(result_row[c * result_dim])); if (ret != 0) { goto out; } - while (v != TSK_NULL) { - ret = tsk_treeseq_update_branch_afs(self, v, t_left, branch_length, - last_update, counts, num_sample_sets, window_index, result_dims, - options, result); - if (ret != 0) { - goto out; - } - update_state(counts, K, v, u, -1); - v = parent[v]; - } - parent[u] = TSK_NULL; - branch_length[u] = 0; } + } - while (tj < num_edges && edge_left[I[tj]] == t_left) { - h = I[tj]; - tj++; - u = edge_child[h]; - v = edge_parent[h]; - parent[u] = v; - branch_length[u] = node_time[v] - node_time[u]; - while (v != TSK_NULL) { - ret = tsk_treeseq_update_branch_afs(self, v, t_left, branch_length, - last_update, counts, num_sample_sets, window_index, result_dims, - options, result); - if (ret != 0) { - goto out; - } - update_state(counts, K, v, u, +1); - v = parent[v]; - } - } +out: + tsk_safe_free(sites); + tsk_safe_free(row_idx); + tsk_safe_free(col_idx); + tsk_safe_free(num_alleles); + tsk_safe_free(site_offsets); + tsk_bit_array_free(&allele_samples); + return ret; +} - t_right = sequence_length; - if (tj < num_edges) { - t_right = TSK_MIN(t_right, edge_left[I[tj]]); - } - if (tk < num_edges) { - t_right = TSK_MIN(t_right, edge_right[O[tk]]); - } +static int +sample_sets_to_bit_array(const tsk_treeseq_t *self, const tsk_size_t *sample_set_sizes, + const tsk_id_t *sample_sets, tsk_size_t num_sample_sets, + tsk_bit_array_t *sample_sets_bits) +{ + int ret; + tsk_bit_array_t bits_row; + tsk_size_t j, k, l; + tsk_id_t u, 
sample_index; - while (window_index < num_windows && windows[window_index + 1] <= t_right) { - w_right = windows[window_index + 1]; - /* Flush the contributions of all nodes to the current window */ - for (u = 0; u < (tsk_id_t) num_nodes; u++) { - tsk_bug_assert(last_update[u] < w_right); - ret = tsk_treeseq_update_branch_afs(self, u, w_right, branch_length, - last_update, counts, num_sample_sets, window_index, result_dims, - options, result); - if (ret != 0) { - goto out; - } + ret = tsk_bit_array_init(sample_sets_bits, self->num_samples, num_sample_sets); + if (ret != 0) { + return ret; + } + + j = 0; + for (k = 0; k < num_sample_sets; k++) { + tsk_bit_array_get_row(sample_sets_bits, k, &bits_row); + for (l = 0; l < sample_set_sizes[k]; l++) { + u = sample_sets[j]; + sample_index = self->sample_index_map[u]; + if (tsk_bit_array_contains( + &bits_row, (tsk_bit_array_value_t) sample_index)) { + ret = tsk_trace_error(TSK_ERR_DUPLICATE_SAMPLE); + goto out; } - window_index++; + tsk_bit_array_add_bit(&bits_row, (tsk_bit_array_value_t) sample_index); + j++; } - - t_left = t_right; } + out: - /* Can't use msp_safe_free here because of restrict */ - if (parent != NULL) { - free(parent); + return ret; +} + +static int +check_sites(const tsk_id_t *sites, tsk_size_t num_sites, tsk_size_t num_site_rows) +{ + int ret = 0; + tsk_size_t i; + + if (num_sites == 0) { + return ret; // No need to verify sites if there aren't any } - if (last_update != NULL) { - free(last_update); + + for (i = 0; i < num_sites - 1; i++) { + if (sites[i] < 0 || sites[i] >= (tsk_id_t) num_site_rows) { + ret = tsk_trace_error(TSK_ERR_SITE_OUT_OF_BOUNDS); + goto out; + } + if (sites[i] > sites[i + 1]) { + ret = tsk_trace_error(TSK_ERR_STAT_UNSORTED_SITES); + goto out; + } + if (sites[i] == sites[i + 1]) { + ret = tsk_trace_error(TSK_ERR_STAT_DUPLICATE_SITES); + goto out; + } } - if (branch_length != NULL) { - free(branch_length); + // check the last value + if (sites[i] < 0 || sites[i] >= (tsk_id_t) 
num_site_rows) { + ret = tsk_trace_error(TSK_ERR_SITE_OUT_OF_BOUNDS); + goto out; } +out: return ret; } -int -tsk_treeseq_allele_frequency_spectrum(const tsk_treeseq_t *self, - tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, - const tsk_id_t *sample_sets, tsk_size_t num_windows, const double *windows, - tsk_flags_t options, double *result) +static int +check_positions( + const double *positions, tsk_size_t num_positions, double sequence_length) { int ret = 0; - bool stat_site = !!(options & TSK_STAT_SITE); - bool stat_branch = !!(options & TSK_STAT_BRANCH); - bool stat_node = !!(options & TSK_STAT_NODE); - double default_windows[] = { 0, self->tables->sequence_length }; - const tsk_size_t num_nodes = self->tables->nodes.num_rows; - const tsk_size_t K = num_sample_sets + 1; - tsk_size_t j, k, l, afs_size; - tsk_id_t u; - tsk_size_t *result_dims = NULL; - /* These counts should really be ints, but we use doubles so that we can - * reuse code from the general_stats code paths. 
*/ - double *counts = NULL; - double *count_row; + tsk_size_t i; - if (stat_node) { - ret = TSK_ERR_UNSUPPORTED_STAT_MODE; - goto out; - } - /* If no mode is specified, we default to site mode */ - if (!(stat_site || stat_branch)) { - stat_site = true; - } - /* It's an error to specify more than one mode */ - if (stat_site + stat_branch > 1) { - ret = TSK_ERR_MULTIPLE_STAT_MODES; - goto out; + if (num_positions == 0) { + return ret; // No need to verify positions if there aren't any } - if (windows == NULL) { - num_windows = 1; - windows = default_windows; - } else { - ret = tsk_treeseq_check_windows(self, num_windows, windows); - if (ret != 0) { + + for (i = 0; i < num_positions - 1; i++) { + if (positions[i] < 0 || positions[i] >= sequence_length) { + ret = tsk_trace_error(TSK_ERR_POSITION_OUT_OF_BOUNDS); + goto out; + } + if (positions[i] > positions[i + 1]) { + ret = tsk_trace_error(TSK_ERR_STAT_UNSORTED_POSITIONS); + goto out; + } + if (positions[i] == positions[i + 1]) { + ret = tsk_trace_error(TSK_ERR_STAT_DUPLICATE_POSITIONS); goto out; } } - ret = tsk_treeseq_check_sample_sets( - self, num_sample_sets, sample_set_sizes, sample_sets); - if (ret != 0) { + // check bounds of last value + if (positions[i] < 0 || positions[i] >= sequence_length) { + ret = tsk_trace_error(TSK_ERR_POSITION_OUT_OF_BOUNDS); goto out; } +out: + return ret; +} - /* the last element of result_dims stores the total size of the dimenensions */ - result_dims = tsk_malloc((num_sample_sets + 1) * sizeof(*result_dims)); - counts = tsk_calloc(num_nodes * K, sizeof(*counts)); - if (counts == NULL || result_dims == NULL) { - ret = TSK_ERR_NO_MEMORY; +static int +positions_to_tree_indexes(const tsk_treeseq_t *ts, const double *positions, + tsk_size_t num_positions, tsk_id_t **tree_indexes) +{ + int ret = 0; + tsk_id_t tree_index = 0; + tsk_size_t i, num_trees = ts->num_trees; + + // This is tricky. 
If there are 0 positions, we calloc a size of 1 + // we must calloc, because memset will have no effect when called with size 0 + *tree_indexes = tsk_calloc(num_positions, sizeof(*tree_indexes)); + if (tree_indexes == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } - afs_size = 1; - j = 0; - for (k = 0; k < num_sample_sets; k++) { - result_dims[k] = 1 + sample_set_sizes[k]; - afs_size *= result_dims[k]; - for (l = 0; l < sample_set_sizes[k]; l++) { - u = sample_sets[j]; - count_row = GET_2D_ROW(counts, K, u); - if (count_row[k] != 0) { - ret = TSK_ERR_DUPLICATE_SAMPLE; - goto out; - } - count_row[k] = 1; - j++; + tsk_memset(*tree_indexes, TSK_NULL, num_positions * sizeof(**tree_indexes)); + for (i = 0; i < num_positions; i++) { + while (ts->breakpoints[tree_index + 1] <= positions[i]) { + tree_index++; } + (*tree_indexes)[i] = tree_index; } - for (j = 0; j < self->num_samples; j++) { - u = self->samples[j]; - count_row = GET_2D_ROW(counts, K, u); - count_row[num_sample_sets] = 1; - } - result_dims[num_sample_sets] = (tsk_size_t) afs_size; - - tsk_memset(result, 0, num_windows * afs_size * sizeof(*result)); - if (stat_site) { - ret = tsk_treeseq_site_allele_frequency_spectrum(self, num_sample_sets, - sample_set_sizes, counts, num_windows, windows, result_dims, options, - result); - } else { - ret = tsk_treeseq_branch_allele_frequency_spectrum(self, num_sample_sets, counts, - num_windows, windows, result_dims, options, result); - } + tsk_bug_assert(tree_index <= (tsk_id_t)(num_trees - 1)); - if (options & TSK_STAT_SPAN_NORMALISE) { - span_normalise(num_windows, windows, afs_size, result); - } out: - tsk_safe_free(counts); - tsk_safe_free(result_dims); return ret; } -/*********************************** - * One way stats - ***********************************/ - static int -diversity_summary_func(tsk_size_t state_dim, const double *state, - tsk_size_t TSK_UNUSED(result_dim), double *result, void *params) +get_index_counts( + const tsk_id_t *indexes, 
tsk_size_t num_indexes, tsk_size_t **out_counts) { - sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; - const double *x = state; - double n; - tsk_size_t j; + int ret = 0; + tsk_id_t index = indexes[0]; + tsk_size_t count, i; + tsk_size_t *counts = tsk_calloc( + (tsk_size_t)(indexes[num_indexes ? num_indexes - 1 : 0] - indexes[0] + 1), + sizeof(*counts)); + if (counts == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; + } - for (j = 0; j < state_dim; j++) { - n = (double) args.sample_set_sizes[j]; - result[j] = x[j] * (n - x[j]) / (n * (n - 1)); + count = 1; + for (i = 1; i < num_indexes; i++) { + if (indexes[i] == indexes[i - 1]) { + count++; + } else { + counts[index - indexes[0]] = count; + count = 1; + index = indexes[i]; + } } - return 0; + counts[index - indexes[0]] = count; + *out_counts = counts; +out: + return ret; } -int -tsk_treeseq_diversity(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, - const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, - tsk_size_t num_windows, const double *windows, tsk_flags_t options, double *result) -{ - return tsk_treeseq_sample_count_stat(self, num_sample_sets, sample_set_sizes, - sample_sets, num_sample_sets, NULL, diversity_summary_func, num_windows, windows, - options, result); -} +typedef struct { + tsk_tree_t tree; + tsk_bit_array_t *node_samples; + tsk_id_t *parent; + tsk_id_t *edges_out; + tsk_id_t *edges_in; + double *branch_len; + tsk_size_t n_edges_out; + tsk_size_t n_edges_in; +} iter_state; static int -trait_covariance_summary_func(tsk_size_t state_dim, const double *state, - tsk_size_t TSK_UNUSED(result_dim), double *result, void *params) +iter_state_init(iter_state *self, const tsk_treeseq_t *ts, tsk_size_t state_dim) { - weight_stat_params_t args = *(weight_stat_params_t *) params; - const double n = (double) args.num_samples; - const double *x = state; - tsk_size_t j; + int ret = 0; + const tsk_size_t num_nodes = ts->tables->nodes.num_rows; - for (j 
= 0; j < state_dim; j++) { - result[j] = (x[j] * x[j]) / (2 * (n - 1) * (n - 1)); + ret = tsk_tree_init(&self->tree, ts, TSK_NO_SAMPLE_COUNTS); + if (ret != 0) { + goto out; } - return 0; + self->node_samples = tsk_calloc(1, sizeof(*self->node_samples)); + if (self->node_samples == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; + } + ret = tsk_bit_array_init(self->node_samples, ts->num_samples, state_dim * num_nodes); + if (ret != 0) { + goto out; + } + self->parent = tsk_malloc(num_nodes * sizeof(*self->parent)); + self->edges_out = tsk_malloc(num_nodes * sizeof(*self->edges_out)); + self->edges_in = tsk_malloc(num_nodes * sizeof(*self->edges_in)); + self->branch_len = tsk_calloc(num_nodes, sizeof(*self->branch_len)); + if (self->parent == NULL || self->edges_out == NULL || self->edges_in == NULL + || self->branch_len == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; + } +out: + return ret; } -int -tsk_treeseq_trait_covariance(const tsk_treeseq_t *self, tsk_size_t num_weights, - const double *weights, tsk_size_t num_windows, const double *windows, - tsk_flags_t options, double *result) +static int +get_node_samples(const tsk_treeseq_t *ts, tsk_size_t state_dim, + const tsk_bit_array_t *sample_sets, tsk_bit_array_t *node_samples) { - tsk_size_t num_samples = self->num_samples; - tsk_size_t j, k; - int ret; - const double *row; - double *new_row; - double *means = tsk_calloc(num_weights, sizeof(double)); - double *new_weights = tsk_malloc((num_weights + 1) * num_samples * sizeof(double)); - weight_stat_params_t args = { num_samples = self->num_samples }; - - if (new_weights == NULL || means == NULL) { - ret = TSK_ERR_NO_MEMORY; + int ret = 0; + tsk_size_t n, k; + tsk_bit_array_t sample_set_row, node_samples_row; + tsk_size_t num_nodes = ts->tables->nodes.num_rows; + tsk_bit_array_value_t sample; + const tsk_id_t *restrict sample_index_map = ts->sample_index_map; + const tsk_flags_t *restrict flags = ts->tables->nodes.flags; + + ret 
= tsk_bit_array_init(node_samples, ts->num_samples, num_nodes * state_dim); + if (ret != 0) { goto out; } - - // center weights - for (j = 0; j < num_samples; j++) { - row = GET_2D_ROW(weights, num_weights, j); - for (k = 0; k < num_weights; k++) { - means[k] += row[k]; - } - } - for (k = 0; k < num_weights; k++) { - means[k] /= (double) num_samples; - } - for (j = 0; j < num_samples; j++) { - row = GET_2D_ROW(weights, num_weights, j); - new_row = GET_2D_ROW(new_weights, num_weights, j); - for (k = 0; k < num_weights; k++) { - new_row[k] = row[k] - means[k]; + for (k = 0; k < state_dim; k++) { + tsk_bit_array_get_row(sample_sets, k, &sample_set_row); + for (n = 0; n < num_nodes; n++) { + if (flags[n] & TSK_NODE_IS_SAMPLE) { + sample = (tsk_bit_array_value_t) sample_index_map[n]; + if (tsk_bit_array_contains(&sample_set_row, sample)) { + tsk_bit_array_get_row( + node_samples, (state_dim * n) + k, &node_samples_row); + tsk_bit_array_add_bit(&node_samples_row, sample); + } + } } } - - ret = tsk_treeseq_general_stat(self, num_weights, new_weights, num_weights, - trait_covariance_summary_func, &args, num_windows, windows, options, result); - out: - tsk_safe_free(means); - tsk_safe_free(new_weights); return ret; } -static int -trait_correlation_summary_func(tsk_size_t state_dim, const double *state, - tsk_size_t TSK_UNUSED(result_dim), double *result, void *params) +static void +iter_state_clear(iter_state *self, tsk_size_t state_dim, tsk_size_t num_nodes, + const tsk_bit_array_t *node_samples) { - weight_stat_params_t args = *(weight_stat_params_t *) params; - const double n = (double) args.num_samples; - const double *x = state; - double p; - tsk_size_t j; + self->n_edges_out = 0; + self->n_edges_in = 0; + tsk_tree_clear(&self->tree); + tsk_memset(self->parent, TSK_NULL, num_nodes * sizeof(*self->parent)); + tsk_memset(self->edges_out, TSK_NULL, num_nodes * sizeof(*self->edges_out)); + tsk_memset(self->edges_in, TSK_NULL, num_nodes * sizeof(*self->edges_in)); + 
tsk_memset(self->branch_len, 0, num_nodes * sizeof(*self->branch_len)); + tsk_memcpy(self->node_samples->data, node_samples->data, + node_samples->size * state_dim * num_nodes * sizeof(*node_samples->data)); +} - p = x[state_dim - 1]; - for (j = 0; j < state_dim - 1; j++) { - if ((p > 0.0) && (p < 1.0)) { - result[j] = (x[j] * x[j]) / (2 * (p * (1 - p)) * n * (n - 1)); +static void +iter_state_free(iter_state *self) +{ + tsk_tree_free(&self->tree); + tsk_bit_array_free(self->node_samples); + tsk_safe_free(self->node_samples); + tsk_safe_free(self->parent); + tsk_safe_free(self->edges_out); + tsk_safe_free(self->edges_in); + tsk_safe_free(self->branch_len); +} + +static int +advance_collect_edges(iter_state *s, tsk_id_t index) +{ + int ret = 0; + tsk_id_t j, e; + tsk_size_t i; + double left, right; + tsk_tree_position_t pos; + tsk_tree_t *tree = &s->tree; + const double *restrict edge_left = tree->tree_sequence->tables->edges.left; + const double *restrict edge_right = tree->tree_sequence->tables->edges.right; + + // Either we're seeking forward one step from some nonzero position in the tree, or + // from the beginning of the tree sequence. + if (tree->index != TSK_NULL || index == 0) { + ret = tsk_tree_next(tree); + if (ret < 0) { + goto out; + } + pos = tree->tree_pos; + i = 0; + for (j = pos.out.start; j != pos.out.stop; j++) { + s->edges_out[i] = pos.out.order[j]; + i++; + } + s->n_edges_out = i; + i = 0; + for (j = pos.in.start; j != pos.in.stop; j++) { + s->edges_in[i] = pos.in.order[j]; + i++; + } + s->n_edges_in = i; + } else { + // Seek from an arbitrary nonzero position from an uninitialized tree. 
+ tsk_bug_assert(tree->index == -1); + ret = tsk_tree_seek_index(tree, index, 0); + if (ret < 0) { + goto out; + } + pos = tree->tree_pos; + i = 0; + if (pos.direction == TSK_DIR_FORWARD) { + left = pos.interval.left; + for (j = pos.in.start; j != pos.in.stop; j++) { + e = pos.in.order[j]; + if (edge_left[e] <= left && left < edge_right[e]) { + s->edges_in[i] = pos.in.order[j]; + i++; + } + } } else { - result[j] = 0.0; + right = pos.interval.right; + for (j = pos.in.start; j != pos.in.stop; j--) { + e = pos.in.order[j]; + if (edge_right[e] >= right && right > edge_left[e]) { + s->edges_in[i] = pos.in.order[j]; + i++; + } + } } + s->n_edges_out = 0; + s->n_edges_in = i; } - return 0; + ret = 0; +out: + return ret; } -int -tsk_treeseq_trait_correlation(const tsk_treeseq_t *self, tsk_size_t num_weights, - const double *weights, tsk_size_t num_windows, const double *windows, - tsk_flags_t options, double *result) +static int +compute_two_tree_branch_state_update(const tsk_treeseq_t *ts, tsk_id_t c, + const iter_state *A_state, const iter_state *B_state, tsk_size_t state_dim, + tsk_size_t result_dim, int sign, general_stat_func_t *f, + sample_count_stat_params_t *f_params, double *result) { - tsk_size_t num_samples = self->num_samples; - tsk_size_t j, k; - int ret; - double *means = tsk_calloc(num_weights, sizeof(double)); - double *meansqs = tsk_calloc(num_weights, sizeof(double)); - double *sds = tsk_calloc(num_weights, sizeof(double)); - const double *row; - double *new_row; - double *new_weights = tsk_malloc((num_weights + 1) * num_samples * sizeof(double)); - weight_stat_params_t args = { num_samples = self->num_samples }; - - if (new_weights == NULL || means == NULL || meansqs == NULL || sds == NULL) { - ret = TSK_ERR_NO_MEMORY; + int ret = 0; + double a_len, b_len; + double *restrict B_branch_len = B_state->branch_len; + double *weights = NULL, *weights_row, *result_tmp = NULL; + tsk_size_t n, k, a_row, b_row; + tsk_bit_array_t A_samples, B_samples, AB_samples, 
B_samples_tmp; + const double *restrict A_branch_len = A_state->branch_len; + const tsk_bit_array_t *restrict A_state_samples = A_state->node_samples; + const tsk_bit_array_t *restrict B_state_samples = B_state->node_samples; + tsk_size_t num_samples = ts->num_samples; + tsk_size_t num_nodes = ts->tables->nodes.num_rows; + b_len = B_branch_len[c] * sign; + if (b_len == 0) { + return ret; + } + + tsk_memset(&AB_samples, 0, sizeof(AB_samples)); + tsk_memset(&B_samples_tmp, 0, sizeof(B_samples_tmp)); + + weights = tsk_calloc(3 * state_dim, sizeof(*weights)); + result_tmp = tsk_calloc(result_dim, sizeof(*result_tmp)); + if (weights == NULL || result_tmp == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; + } + ret = tsk_bit_array_init(&AB_samples, num_samples, 1); + if (ret != 0) { goto out; } - - if (num_weights < 1) { - ret = TSK_ERR_BAD_STATE_DIMS; + ret = tsk_bit_array_init(&B_samples_tmp, num_samples, 1); + if (ret != 0) { goto out; } - - // center and scale weights - for (j = 0; j < num_samples; j++) { - row = GET_2D_ROW(weights, num_weights, j); - for (k = 0; k < num_weights; k++) { - means[k] += row[k]; - meansqs[k] += row[k] * row[k]; + for (n = 0; n < num_nodes; n++) { + a_len = A_branch_len[n]; + if (a_len == 0) { + continue; + } + for (k = 0; k < state_dim; k++) { + a_row = (state_dim * n) + k; + b_row = (state_dim * (tsk_size_t) c) + k; + tsk_bit_array_get_row(A_state_samples, a_row, &A_samples); + tsk_bit_array_get_row(B_state_samples, b_row, &B_samples); + tsk_bit_array_intersect(&A_samples, &B_samples, &AB_samples); + weights_row = GET_2D_ROW(weights, 3, k); + weights_row[0] = (double) tsk_bit_array_count(&AB_samples); // w_AB + weights_row[1] + = (double) tsk_bit_array_count(&A_samples) - weights_row[0]; // w_Ab + weights_row[2] + = (double) tsk_bit_array_count(&B_samples) - weights_row[0]; // w_aB + } + ret = f(state_dim, weights, result_dim, result_tmp, f_params); + if (ret != 0) { + goto out; + } + for (k = 0; k < result_dim; k++) { + 
result[k] += result_tmp[k] * a_len * b_len; } } - for (k = 0; k < num_weights; k++) { - means[k] /= (double) num_samples; - meansqs[k] -= means[k] * means[k] * (double) num_samples; - meansqs[k] /= (double) (num_samples - 1); - sds[k] = sqrt(meansqs[k]); - } - for (j = 0; j < num_samples; j++) { - row = GET_2D_ROW(weights, num_weights, j); - new_row = GET_2D_ROW(new_weights, num_weights + 1, j); - for (k = 0; k < num_weights; k++) { - new_row[k] = (row[k] - means[k]) / sds[k]; - } - // set final row to 1/n to compute frequency - new_row[num_weights] = 1.0 / (double) num_samples; - } - - ret = tsk_treeseq_general_stat(self, num_weights + 1, new_weights, num_weights, - trait_correlation_summary_func, &args, num_windows, windows, options, result); - out: - tsk_safe_free(means); - tsk_safe_free(meansqs); - tsk_safe_free(sds); - tsk_safe_free(new_weights); + tsk_safe_free(weights); + tsk_safe_free(result_tmp); + tsk_bit_array_free(&AB_samples); + tsk_bit_array_free(&B_samples_tmp); return ret; } static int -trait_linear_model_summary_func(tsk_size_t state_dim, const double *state, - tsk_size_t result_dim, double *result, void *params) +compute_two_tree_branch_stat(const tsk_treeseq_t *ts, const iter_state *l_state, + iter_state *r_state, general_stat_func_t *f, sample_count_stat_params_t *f_params, + tsk_size_t result_dim, tsk_size_t state_dim, double *result) { - covariates_stat_params_t args = *(covariates_stat_params_t *) params; - const double num_samples = (double) args.num_samples; - const tsk_size_t k = args.num_covariates; - const double *V = args.V; - ; - const double *x = state; - const double *v; - double m, a, denom, z; - tsk_size_t i, j; - // x[0], ..., x[result_dim - 1] contains the traits, W - // x[result_dim], ..., x[state_dim - 2] contains the covariates, Z - // x[state_dim - 1] has the number of samples below the node - - m = x[state_dim - 1]; - for (i = 0; i < result_dim; i++) { - if ((m > 0.0) && (m < num_samples)) { - v = GET_2D_ROW(V, k, i); - a = 
x[i]; - denom = m; - for (j = 0; j < k; j++) { - z = x[result_dim + j]; - a -= z * v[j]; - denom -= z * z; + int ret = 0; + tsk_id_t e, c, ec, p, *updated_nodes = NULL; + tsk_size_t j, k, n_updates; + const double *restrict time = ts->tables->nodes.time; + const tsk_id_t *restrict edges_child = ts->tables->edges.child; + const tsk_id_t *restrict edges_parent = ts->tables->edges.parent; + const tsk_size_t num_nodes = ts->tables->nodes.num_rows; + tsk_bit_array_t updates, row, ec_row, *r_samples = r_state->node_samples; + + tsk_memset(&updates, 0, sizeof(updates)); + ret = tsk_bit_array_init(&updates, num_nodes, 1); + if (ret != 0) { + goto out; + } + updated_nodes = tsk_calloc(num_nodes, sizeof(*updated_nodes)); + if (updated_nodes == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; + } + // Identify modified nodes both added and removed + for (j = 0; j < r_state->n_edges_out + r_state->n_edges_in; j++) { + e = j < r_state->n_edges_out ? r_state->edges_out[j] + : r_state->edges_in[j - r_state->n_edges_out]; + p = edges_parent[e]; + c = edges_child[e]; + // Identify affected nodes above child + while (p != TSK_NULL) { + tsk_bit_array_add_bit(&updates, (tsk_bit_array_value_t) c); + c = p; + p = r_state->parent[p]; + } + } + // Subtract the whole contribution from the child node + tsk_bit_array_get_items(&updates, updated_nodes, &n_updates); + while (n_updates != 0) { + n_updates--; + c = updated_nodes[n_updates]; + compute_two_tree_branch_state_update( + ts, c, l_state, r_state, state_dim, result_dim, -1, f, f_params, result); + } + // Remove samples under nodes from removed edges to parent nodes + for (j = 0; j < r_state->n_edges_out; j++) { + e = r_state->edges_out[j]; + p = edges_parent[e]; + ec = edges_child[e]; // edge child + while (p != TSK_NULL) { + for (k = 0; k < state_dim; k++) { + tsk_bit_array_get_row( + r_samples, (state_dim * (tsk_size_t) ec) + k, &ec_row); + tsk_bit_array_get_row(r_samples, (state_dim * (tsk_size_t) p) + k, &row); + 
tsk_bit_array_subtract(&row, &ec_row); } - // denom is the length of projection of the trait onto the subspace - // spanned by the covariates, so if it is zero then the system is - // singular and the solution is nonunique. This numerical tolerance - // could be smaller without hitting floating-point error, but being - // a tiny bit conservative about when the trait is almost in the - // span of the covariates is probably good. - if (denom < 1e-8) { - result[i] = 0.0; - } else { - result[i] = (a * a) / (2 * denom * denom); + p = r_state->parent[p]; + } + r_state->branch_len[ec] = 0; + r_state->parent[ec] = TSK_NULL; + } + // Add samples under nodes from added edges + for (j = 0; j < r_state->n_edges_in; j++) { + e = r_state->edges_in[j]; + p = edges_parent[e]; + ec = c = edges_child[e]; + r_state->branch_len[c] = time[p] - time[c]; + r_state->parent[c] = p; + while (p != TSK_NULL) { + tsk_bit_array_add_bit(&updates, (tsk_bit_array_value_t) c); + for (k = 0; k < state_dim; k++) { + tsk_bit_array_get_row( + r_samples, (state_dim * (tsk_size_t) ec) + k, &ec_row); + tsk_bit_array_get_row(r_samples, (state_dim * (tsk_size_t) p) + k, &row); + tsk_bit_array_add(&row, &ec_row); } - } else { - result[i] = 0.0; + c = p; + p = r_state->parent[p]; } } - return 0; + // Update all affected child nodes (fully subtracted, deferred from addition) + n_updates = 0; + tsk_bit_array_get_items(&updates, updated_nodes, &n_updates); + while (n_updates != 0) { + n_updates--; + c = updated_nodes[n_updates]; + compute_two_tree_branch_state_update( + ts, c, l_state, r_state, state_dim, result_dim, +1, f, f_params, result); + } +out: + tsk_safe_free(updated_nodes); + tsk_bit_array_free(&updates); + return ret; } -int -tsk_treeseq_trait_linear_model(const tsk_treeseq_t *self, tsk_size_t num_weights, - const double *weights, tsk_size_t num_covariates, const double *covariates, - tsk_size_t num_windows, const double *windows, tsk_flags_t options, double *result) +static int 
+tsk_treeseq_two_branch_count_stat(const tsk_treeseq_t *self, tsk_size_t state_dim, + const tsk_bit_array_t *sample_sets, tsk_size_t result_dim, general_stat_func_t *f, + sample_count_stat_params_t *f_params, norm_func_t *TSK_UNUSED(norm_f), + tsk_size_t n_rows, const double *row_positions, tsk_size_t n_cols, + const double *col_positions, tsk_flags_t TSK_UNUSED(options), double *result) { - tsk_size_t num_samples = self->num_samples; - tsk_size_t i, j, k; - int ret; - const double *w, *z; - double *v, *new_row; - double *V = tsk_calloc(num_covariates * num_weights, sizeof(double)); - double *new_weights - = tsk_malloc((num_weights + num_covariates + 1) * num_samples * sizeof(double)); - - covariates_stat_params_t args - = { .num_samples = self->num_samples, .num_covariates = num_covariates, .V = V }; - - // We assume that the covariates have been *already standardised*, - // so that (a) 1 is in the span of the columns, and - // (b) their crossproduct is the identity. - // We could do this instead here with gsl linalg. 
+ int ret = 0; + int r, c; + tsk_id_t *row_indexes = NULL, *col_indexes = NULL; + tsk_size_t i, j, k, row, col, *row_repeats = NULL, *col_repeats = NULL; + tsk_bit_array_t node_samples; + iter_state l_state, r_state; + double *result_tmp = NULL, *result_row; + const tsk_size_t num_nodes = self->tables->nodes.num_rows; - if (new_weights == NULL || V == NULL) { - ret = TSK_ERR_NO_MEMORY; + tsk_memset(&node_samples, 0, sizeof(node_samples)); + tsk_memset(&l_state, 0, sizeof(l_state)); + tsk_memset(&r_state, 0, sizeof(r_state)); + result_tmp = tsk_malloc(result_dim * sizeof(*result_tmp)); + if (result_tmp == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } - - if (num_weights < 1) { - ret = TSK_ERR_BAD_STATE_DIMS; + ret = iter_state_init(&l_state, self, state_dim); + if (ret != 0) { goto out; } - - // V = weights^T (matrix mult) covariates - for (k = 0; k < num_samples; k++) { - w = GET_2D_ROW(weights, num_weights, k); - z = GET_2D_ROW(covariates, num_covariates, k); - for (i = 0; i < num_weights; i++) { - v = GET_2D_ROW(V, num_covariates, i); - for (j = 0; j < num_covariates; j++) { - v[j] += w[i] * z[j]; - } - } + ret = iter_state_init(&r_state, self, state_dim); + if (ret != 0) { + goto out; } - - for (k = 0; k < num_samples; k++) { - w = GET_2D_ROW(weights, num_weights, k); - z = GET_2D_ROW(covariates, num_covariates, k); - new_row = GET_2D_ROW(new_weights, num_covariates + num_weights + 1, k); - for (i = 0; i < num_weights; i++) { - new_row[i] = w[i]; + ret = positions_to_tree_indexes(self, row_positions, n_rows, &row_indexes); + if (ret != 0) { + goto out; + } + ret = positions_to_tree_indexes(self, col_positions, n_cols, &col_indexes); + if (ret != 0) { + goto out; + } + ret = get_index_counts(row_indexes, n_rows, &row_repeats); + if (ret != 0) { + goto out; + } + ret = get_index_counts(col_indexes, n_cols, &col_repeats); + if (ret != 0) { + goto out; + } + ret = get_node_samples(self, state_dim, sample_sets, &node_samples); + if (ret != 0) { + 
goto out; + } + iter_state_clear(&l_state, state_dim, num_nodes, &node_samples); + row = 0; + for (r = 0; r < (row_indexes[n_rows ? n_rows - 1U : 0] - row_indexes[0] + 1); r++) { + tsk_memset(result_tmp, 0, result_dim * sizeof(*result_tmp)); + iter_state_clear(&r_state, state_dim, num_nodes, &node_samples); + ret = advance_collect_edges(&l_state, (tsk_id_t) r + row_indexes[0]); + if (ret != 0) { + goto out; } - for (i = 0; i < num_covariates; i++) { - new_row[i + num_weights] = z[i]; + result_row = GET_2D_ROW(result, result_dim * n_cols, row); + ret = compute_two_tree_branch_stat( + self, &r_state, &l_state, f, f_params, result_dim, state_dim, result_tmp); + if (ret != 0) { + goto out; } - // set final row to 1 to count alleles - new_row[num_weights + num_covariates] = 1.0; + col = 0; + for (c = 0; c < (col_indexes[n_cols ? n_cols - 1 : 0] - col_indexes[0] + 1); + c++) { + ret = advance_collect_edges(&r_state, (tsk_id_t) c + col_indexes[0]); + if (ret != 0) { + goto out; + } + ret = compute_two_tree_branch_stat(self, &l_state, &r_state, f, f_params, + result_dim, state_dim, result_tmp); + if (ret != 0) { + goto out; + } + for (i = 0; i < row_repeats[r]; i++) { + for (j = 0; j < col_repeats[c]; j++) { + result_row = GET_2D_ROW(result, result_dim * n_cols, row + i); + for (k = 0; k < result_dim; k++) { + result_row[col + (j * result_dim) + k] = result_tmp[k]; + } + } + } + col += (col_repeats[c] * result_dim); + } + row += row_repeats[r]; } - - ret = tsk_treeseq_general_stat(self, num_weights + num_covariates + 1, new_weights, - num_weights, trait_linear_model_summary_func, &args, num_windows, windows, - options, result); - out: - tsk_safe_free(V); - tsk_safe_free(new_weights); + tsk_safe_free(result_tmp); + tsk_safe_free(row_indexes); + tsk_safe_free(col_indexes); + tsk_safe_free(row_repeats); + tsk_safe_free(col_repeats); + iter_state_free(&l_state); + iter_state_free(&r_state); + tsk_bit_array_free(&node_samples); return ret; } -static int 
-segregating_sites_summary_func(tsk_size_t state_dim, const double *state, - tsk_size_t TSK_UNUSED(result_dim), double *result, void *params) -{ - sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; - const double *x = state; - double n; - tsk_size_t j; - - // this works because sum_{i=1}^k (1-p_i) = k-1 - for (j = 0; j < state_dim; j++) { - n = (double) args.sample_set_sizes[j]; - result[j] = (x[j] > 0) * (1 - x[j] / n); - } - return 0; -} - int -tsk_treeseq_segregating_sites(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, +tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, - tsk_size_t num_windows, const double *windows, tsk_flags_t options, double *result) + tsk_size_t result_dim, const tsk_id_t *set_indexes, general_stat_func_t *f, + norm_func_t *norm_f, tsk_size_t out_rows, const tsk_id_t *row_sites, + const double *row_positions, tsk_size_t out_cols, const tsk_id_t *col_sites, + const double *col_positions, tsk_flags_t options, double *result) { - return tsk_treeseq_sample_count_stat(self, num_sample_sets, sample_set_sizes, - sample_sets, num_sample_sets, NULL, segregating_sites_summary_func, num_windows, - windows, options, result); -} + // TODO: generalize this function if we ever decide to do weighted two_locus stats. + // We only implement count stats and therefore we don't handle weights. 
+ int ret = 0; + tsk_bit_array_t sample_sets_bits; + bool stat_site = !!(options & TSK_STAT_SITE); + bool stat_branch = !!(options & TSK_STAT_BRANCH); + tsk_size_t state_dim = num_sample_sets; + sample_count_stat_params_t f_params = { .sample_sets = sample_sets, + .num_sample_sets = num_sample_sets, + .sample_set_sizes = sample_set_sizes, + .set_indexes = set_indexes }; -static int -Y1_summary_func(tsk_size_t TSK_UNUSED(state_dim), const double *state, - tsk_size_t result_dim, double *result, void *params) -{ - sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; - const double *x = state; - double ni, denom, numer; - tsk_size_t i; + tsk_memset(&sample_sets_bits, 0, sizeof(sample_sets_bits)); - for (i = 0; i < result_dim; i++) { - ni = (double) args.sample_set_sizes[i]; - denom = ni * (ni - 1) * (ni - 2); - numer = x[i] * (ni - x[i]) * (ni - x[i] - 1); - result[i] = numer / denom; + // We do not support two-locus node stats + if (!!(options & TSK_STAT_NODE)) { + ret = tsk_trace_error(TSK_ERR_UNSUPPORTED_STAT_MODE); + goto out; } - return 0; -} - -int -tsk_treeseq_Y1(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, - const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, - tsk_size_t num_windows, const double *windows, tsk_flags_t options, double *result) -{ - return tsk_treeseq_sample_count_stat(self, num_sample_sets, sample_set_sizes, - sample_sets, num_sample_sets, NULL, Y1_summary_func, num_windows, windows, - options, result); -} - -/*********************************** - * Two way stats - ***********************************/ - -static int -check_sample_stat_inputs(tsk_size_t num_sample_sets, tsk_size_t tuple_size, - tsk_size_t num_index_tuples, const tsk_id_t *index_tuples) -{ - int ret = 0; - - if (num_sample_sets < tuple_size) { - ret = TSK_ERR_INSUFFICIENT_SAMPLE_SETS; + // If no mode is specified, we default to site mode + if (!(stat_site || stat_branch)) { + stat_site = true; + } + // It's an error to specify more than 
one mode + if (stat_site + stat_branch > 1) { + ret = tsk_trace_error(TSK_ERR_MULTIPLE_STAT_MODES); goto out; } - if (num_index_tuples < 1) { - ret = TSK_ERR_INSUFFICIENT_INDEX_TUPLES; + ret = tsk_treeseq_check_sample_sets( + self, num_sample_sets, sample_set_sizes, sample_sets); + if (ret != 0) { goto out; } - ret = check_set_indexes( - num_sample_sets, tuple_size * num_index_tuples, index_tuples); + if (result_dim < 1) { + ret = tsk_trace_error(TSK_ERR_BAD_RESULT_DIMS); + goto out; + } + ret = sample_sets_to_bit_array( + self, sample_set_sizes, sample_sets, num_sample_sets, &sample_sets_bits); if (ret != 0) { goto out; } + + if (stat_site) { + ret = check_sites(row_sites, out_rows, self->tables->sites.num_rows); + if (ret != 0) { + goto out; + } + ret = check_sites(col_sites, out_cols, self->tables->sites.num_rows); + if (ret != 0) { + goto out; + } + ret = tsk_treeseq_two_site_count_stat(self, state_dim, &sample_sets_bits, + result_dim, f, &f_params, norm_f, out_rows, row_sites, out_cols, col_sites, + options, result); + } else if (stat_branch) { + ret = check_positions( + row_positions, out_rows, tsk_treeseq_get_sequence_length(self)); + if (ret != 0) { + goto out; + } + ret = check_positions( + col_positions, out_cols, tsk_treeseq_get_sequence_length(self)); + if (ret != 0) { + goto out; + } + ret = tsk_treeseq_two_branch_count_stat(self, state_dim, &sample_sets_bits, + result_dim, f, &f_params, norm_f, out_rows, row_positions, out_cols, + col_positions, options, result); + } + out: + tsk_bit_array_free(&sample_sets_bits); return ret; } -static int -divergence_summary_func(tsk_size_t TSK_UNUSED(state_dim), const double *state, - tsk_size_t result_dim, double *result, void *params) +/*********************************** + * Allele frequency spectrum + ***********************************/ + +static inline void +fold(tsk_size_t *restrict coordinate, const tsk_size_t *restrict dims, + tsk_size_t num_dims) { - sample_count_stat_params_t args = 
*(sample_count_stat_params_t *) params; - const double *x = state; - double ni, nj, denom; - tsk_id_t i, j; tsk_size_t k; + double n = 0; + int s = 0; - for (k = 0; k < result_dim; k++) { - i = args.set_indexes[2 * k]; - j = args.set_indexes[2 * k + 1]; - ni = (double) args.sample_set_sizes[i]; - nj = (double) args.sample_set_sizes[j]; - denom = ni * (nj - (i == j)); - result[k] = x[i] * (nj - x[j]) / denom; + for (k = 0; k < num_dims; k++) { + tsk_bug_assert(coordinate[k] < dims[k]); + n += (double) dims[k] - 1; + s += (int) coordinate[k]; + } + n /= 2; + k = num_dims; + while (s == n && k > 0) { + k--; + n -= ((double) (dims[k] - 1)) / 2; + s -= (int) coordinate[k]; + } + if (s > n) { + for (k = 0; k < num_dims; k++) { + s = (int) (dims[k] - 1 - coordinate[k]); + tsk_bug_assert(s >= 0); + coordinate[k] = (tsk_size_t) s; + } } - return 0; } -int -tsk_treeseq_divergence(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, - const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, - tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_windows, - const double *windows, tsk_flags_t options, double *result) +static int +tsk_treeseq_update_site_afs(const tsk_treeseq_t *self, const tsk_site_t *site, + const double *total_counts, const double *counts, tsk_size_t num_sample_sets, + tsk_size_t window_index, tsk_size_t *result_dims, tsk_flags_t options, + double *result) { int ret = 0; - ret = check_sample_stat_inputs(num_sample_sets, 2, num_index_tuples, index_tuples); + tsk_size_t afs_size; + tsk_size_t k, allele, num_alleles, all_samples; + double increment, *afs, *allele_counts, *allele_count; + tsk_size_t *coordinate = tsk_malloc(num_sample_sets * sizeof(*coordinate)); + bool polarised = !!(options & TSK_STAT_POLARISED); + const tsk_size_t K = num_sample_sets + 1; + + if (coordinate == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; + } + ret = get_allele_weights( + site, counts, K, total_counts, &num_alleles, 
&allele_counts); if (ret != 0) { goto out; } - ret = tsk_treeseq_sample_count_stat(self, num_sample_sets, sample_set_sizes, - sample_sets, num_index_tuples, index_tuples, divergence_summary_func, - num_windows, windows, options, result); + + afs_size = result_dims[num_sample_sets]; + afs = result + afs_size * window_index; + + increment = polarised ? 1 : 0.5; + /* Sum over the allele weights. Skip the ancestral state if polarised. */ + for (allele = polarised ? 1 : 0; allele < num_alleles; allele++) { + allele_count = GET_2D_ROW(allele_counts, K, allele); + all_samples = (tsk_size_t) allele_count[num_sample_sets]; + if (all_samples > 0 && all_samples < self->num_samples) { + for (k = 0; k < num_sample_sets; k++) { + coordinate[k] = (tsk_size_t) allele_count[k]; + } + if (!polarised) { + fold(coordinate, result_dims, num_sample_sets); + } + increment_nd_array_value( + afs, num_sample_sets, result_dims, coordinate, increment); + } + } out: + tsk_safe_free(coordinate); + tsk_safe_free(allele_counts); return ret; } static int -genetic_relatedness_summary_func(tsk_size_t state_dim, const double *state, - tsk_size_t result_dim, double *result, void *params) +tsk_treeseq_site_allele_frequency_spectrum(const tsk_treeseq_t *self, + tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, double *counts, + tsk_size_t num_windows, const double *windows, tsk_size_t *result_dims, + tsk_flags_t options, double *result) { - sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; - const double *x = state; - tsk_id_t i, j; - tsk_size_t k; - double sumx = 0; - double sumn = 0; - double meanx, ni, nj; + int ret = 0; + tsk_id_t u, v; + tsk_size_t tree_site, tree_index, window_index; + tsk_size_t num_nodes = self->tables->nodes.num_rows; + const tsk_id_t num_edges = (tsk_id_t) self->tables->edges.num_rows; + const tsk_id_t *restrict I = self->tables->indexes.edge_insertion_order; + const tsk_id_t *restrict O = self->tables->indexes.edge_removal_order; + const 
double *restrict edge_left = self->tables->edges.left; + const double *restrict edge_right = self->tables->edges.right; + const tsk_id_t *restrict edge_parent = self->tables->edges.parent; + const tsk_id_t *restrict edge_child = self->tables->edges.child; + const double sequence_length = self->tables->sequence_length; + tsk_id_t *restrict parent = tsk_malloc(num_nodes * sizeof(*parent)); + tsk_site_t *site; + tsk_id_t tj, tk, h; + tsk_size_t j; + const tsk_size_t K = num_sample_sets + 1; + double t_left, t_right; + double *total_counts = tsk_malloc((1 + num_sample_sets) * sizeof(*total_counts)); - for (k = 0; k < state_dim; k++) { - sumx += x[k]; - sumn += (double) args.sample_set_sizes[k]; + if (parent == NULL || total_counts == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; } + tsk_memset(parent, 0xff, num_nodes * sizeof(*parent)); - meanx = sumx / sumn; - for (k = 0; k < result_dim; k++) { - i = args.set_indexes[2 * k]; - j = args.set_indexes[2 * k + 1]; - ni = (double) args.sample_set_sizes[i]; - nj = (double) args.sample_set_sizes[j]; - result[k] = (x[i] - ni * meanx) * (x[j] - nj * meanx) / 2; + for (j = 0; j < num_sample_sets; j++) { + total_counts[j] = (double) sample_set_sizes[j]; } - return 0; -} + total_counts[num_sample_sets] = (double) self->num_samples; -int -tsk_treeseq_genetic_relatedness(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, - const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, - tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_windows, - const double *windows, tsk_flags_t options, double *result) -{ - int ret = 0; - ret = check_sample_stat_inputs(num_sample_sets, 2, num_index_tuples, index_tuples); - if (ret != 0) { - goto out; + /* Iterate over the trees */ + tj = 0; + tk = 0; + t_left = 0; + tree_index = 0; + window_index = 0; + while (tj < num_edges || t_left < sequence_length) { + while (tk < num_edges && edge_right[O[tk]] == t_left) { + h = O[tk]; + tk++; + u = 
edge_child[h]; + v = edge_parent[h]; + while (v != TSK_NULL) { + update_state(counts, K, v, u, -1); + v = parent[v]; + } + parent[u] = TSK_NULL; + } + + while (tj < num_edges && edge_left[I[tj]] == t_left) { + h = I[tj]; + tj++; + u = edge_child[h]; + v = edge_parent[h]; + parent[u] = v; + while (v != TSK_NULL) { + update_state(counts, K, v, u, +1); + v = parent[v]; + } + } + t_right = sequence_length; + if (tj < num_edges) { + t_right = TSK_MIN(t_right, edge_left[I[tj]]); + } + if (tk < num_edges) { + t_right = TSK_MIN(t_right, edge_right[O[tk]]); + } + + /* Update the sites */ + for (tree_site = 0; tree_site < self->tree_sites_length[tree_index]; + tree_site++) { + site = self->tree_sites[tree_index] + tree_site; + while (windows[window_index + 1] <= site->position) { + window_index++; + tsk_bug_assert(window_index < num_windows); + } + ret = tsk_treeseq_update_site_afs(self, site, total_counts, counts, + num_sample_sets, window_index, result_dims, options, result); + if (ret != 0) { + goto out; + } + tsk_bug_assert(windows[window_index] <= site->position); + tsk_bug_assert(site->position < windows[window_index + 1]); + } + tree_index++; + t_left = t_right; } - ret = tsk_treeseq_sample_count_stat(self, num_sample_sets, sample_set_sizes, - sample_sets, num_index_tuples, index_tuples, genetic_relatedness_summary_func, - num_windows, windows, options, result); out: + /* Can't use msp_safe_free here because of restrict */ + if (parent != NULL) { + free(parent); + } + tsk_safe_free(total_counts); return ret; } -static int -Y2_summary_func(tsk_size_t TSK_UNUSED(state_dim), const double *state, - tsk_size_t result_dim, double *result, void *params) +static void +tsk_treeseq_update_branch_afs(const tsk_treeseq_t *self, tsk_id_t u, double right, + double *restrict last_update, const double *restrict time, tsk_id_t *restrict parent, + tsk_size_t *restrict coordinate, const double *counts, tsk_size_t num_sample_sets, + tsk_size_t num_time_windows, const double 
*time_windows, tsk_size_t window_index, + const tsk_size_t *result_dims, tsk_flags_t options, double *result) { - sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; - const double *x = state; - double ni, nj, denom; - tsk_id_t i, j; + tsk_size_t afs_size; tsk_size_t k; - - for (k = 0; k < result_dim; k++) { - i = args.set_indexes[2 * k]; - j = args.set_indexes[2 * k + 1]; - ni = (double) args.sample_set_sizes[i]; - nj = (double) args.sample_set_sizes[j]; - denom = ni * nj * (nj - 1); - result[k] = x[i] * (nj - x[j]) * (nj - x[j] - 1) / denom; + tsk_size_t time_window_index; + double *afs; + bool polarised = !!(options & TSK_STAT_POLARISED); + const double *count_row = GET_2D_ROW(counts, num_sample_sets + 1, u); + double x = 0; + double t_u, t_v; + double tw_branch_length = 0; + const tsk_size_t all_samples = (tsk_size_t) count_row[num_sample_sets]; + if (parent[u] != TSK_NULL) { + t_u = time[u]; + t_v = time[parent[u]]; + if (0 < all_samples && all_samples < self->num_samples) { + time_window_index = 0; + afs_size = result_dims[num_sample_sets]; + while (time_window_index < num_time_windows + && time_windows[time_window_index] < t_v) { + afs = result + + afs_size * (window_index * num_time_windows + time_window_index); + for (k = 0; k < num_sample_sets; k++) { + coordinate[k] = (tsk_size_t) count_row[k]; + } + if (!polarised) { + fold(coordinate, result_dims, num_sample_sets); + } + tw_branch_length + = TSK_MAX(0.0, TSK_MIN(time_windows[time_window_index + 1], t_v) + - TSK_MAX(time_windows[time_window_index], t_u)); + x = (right - last_update[u]) * tw_branch_length; + increment_nd_array_value( + afs, num_sample_sets, result_dims, coordinate, x); + time_window_index++; + } + } } - return 0; + last_update[u] = right; } -int -tsk_treeseq_Y2(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, - const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, - tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_windows, - 
const double *windows, tsk_flags_t options, double *result) +static int +tsk_treeseq_branch_allele_frequency_spectrum(const tsk_treeseq_t *self, + tsk_size_t num_sample_sets, double *counts, tsk_size_t num_windows, + const double *windows, tsk_size_t num_time_windows, const double *time_windows, + const tsk_size_t *result_dims, tsk_flags_t options, double *result) { int ret = 0; - ret = check_sample_stat_inputs(num_sample_sets, 2, num_index_tuples, index_tuples); - if (ret != 0) { - goto out; - } - ret = tsk_treeseq_sample_count_stat(self, num_sample_sets, sample_set_sizes, - sample_sets, num_index_tuples, index_tuples, Y2_summary_func, num_windows, - windows, options, result); -out: - return ret; -} - -static int -f2_summary_func(tsk_size_t TSK_UNUSED(state_dim), const double *state, - tsk_size_t result_dim, double *result, void *params) -{ - sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; - const double *x = state; - double ni, nj, denom, numer; - tsk_id_t i, j; - tsk_size_t k; + tsk_id_t u, v; + tsk_size_t window_index; + tsk_size_t num_nodes = self->tables->nodes.num_rows; + const tsk_id_t num_edges = (tsk_id_t) self->tables->edges.num_rows; + const tsk_id_t *restrict I = self->tables->indexes.edge_insertion_order; + const tsk_id_t *restrict O = self->tables->indexes.edge_removal_order; + const double *restrict edge_left = self->tables->edges.left; + const double *restrict edge_right = self->tables->edges.right; + const tsk_id_t *restrict edge_parent = self->tables->edges.parent; + const tsk_id_t *restrict edge_child = self->tables->edges.child; + const double *restrict node_time = self->tables->nodes.time; + const double sequence_length = self->tables->sequence_length; + tsk_id_t *restrict parent = tsk_malloc(num_nodes * sizeof(*parent)); + double *restrict last_update = tsk_calloc(num_nodes, sizeof(*last_update)); + double *restrict branch_length = tsk_calloc(num_nodes, sizeof(*branch_length)); + tsk_size_t *restrict coordinate = 
tsk_malloc(num_sample_sets * sizeof(*coordinate)); + tsk_id_t tj, tk, h; + double t_left, t_right, w_right; + const tsk_size_t K = num_sample_sets + 1; - for (k = 0; k < result_dim; k++) { - i = args.set_indexes[2 * k]; - j = args.set_indexes[2 * k + 1]; - ni = (double) args.sample_set_sizes[i]; - nj = (double) args.sample_set_sizes[j]; - denom = ni * (ni - 1) * nj * (nj - 1); - numer = x[i] * (x[i] - 1) * (nj - x[j]) * (nj - x[j] - 1) - - x[i] * (ni - x[i]) * (nj - x[j]) * x[j]; - result[k] = numer / denom; + if (self->time_uncalibrated && !(options & TSK_STAT_ALLOW_TIME_UNCALIBRATED)) { + ret = tsk_trace_error(TSK_ERR_TIME_UNCALIBRATED); + goto out; } - return 0; + + if (parent == NULL || last_update == NULL || coordinate == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; + } + tsk_memset(parent, 0xff, num_nodes * sizeof(*parent)); + + /* Iterate over the trees */ + tj = 0; + tk = 0; + t_left = 0; + window_index = 0; + while (tj < num_edges || t_left < sequence_length) { + tsk_bug_assert(window_index < num_windows); + while (tk < num_edges && edge_right[O[tk]] == t_left) { + h = O[tk]; + tk++; + u = edge_child[h]; + v = edge_parent[h]; + tsk_treeseq_update_branch_afs(self, u, t_left, last_update, node_time, + parent, coordinate, counts, num_sample_sets, num_time_windows, + time_windows, window_index, result_dims, options, result); + while (v != TSK_NULL) { + tsk_treeseq_update_branch_afs(self, v, t_left, last_update, node_time, + parent, coordinate, counts, num_sample_sets, num_time_windows, + time_windows, window_index, result_dims, options, result); + update_state(counts, K, v, u, -1); + v = parent[v]; + } + parent[u] = TSK_NULL; + branch_length[u] = 0; + } + + while (tj < num_edges && edge_left[I[tj]] == t_left) { + h = I[tj]; + tj++; + u = edge_child[h]; + v = edge_parent[h]; + parent[u] = v; + branch_length[u] = node_time[v] - node_time[u]; + while (v != TSK_NULL) { + tsk_treeseq_update_branch_afs(self, v, t_left, last_update, node_time, + 
parent, coordinate, counts, num_sample_sets, num_time_windows, + time_windows, window_index, result_dims, options, result); + update_state(counts, K, v, u, +1); + v = parent[v]; + } + } + + t_right = sequence_length; + if (tj < num_edges) { + t_right = TSK_MIN(t_right, edge_left[I[tj]]); + } + if (tk < num_edges) { + t_right = TSK_MIN(t_right, edge_right[O[tk]]); + } + + while (window_index < num_windows && windows[window_index + 1] <= t_right) { + w_right = windows[window_index + 1]; + /* Flush the contributions of all nodes to the current window */ + for (u = 0; u < (tsk_id_t) num_nodes; u++) { + tsk_bug_assert(last_update[u] < w_right); + tsk_treeseq_update_branch_afs(self, u, w_right, last_update, node_time, + parent, coordinate, counts, num_sample_sets, num_time_windows, + time_windows, window_index, result_dims, options, result); + } + window_index++; + } + + t_left = t_right; + } +out: + /* Can't use msp_safe_free here because of restrict */ + if (parent != NULL) { + free(parent); + } + if (last_update != NULL) { + free(last_update); + } + if (branch_length != NULL) { + free(branch_length); + } + if (coordinate != NULL) { + free(coordinate); + } + return ret; } int -tsk_treeseq_f2(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, - const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, - tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_windows, - const double *windows, tsk_flags_t options, double *result) +tsk_treeseq_allele_frequency_spectrum(const tsk_treeseq_t *self, + tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, + const tsk_id_t *sample_sets, tsk_size_t num_windows, const double *windows, + tsk_size_t num_time_windows, const double *time_windows, tsk_flags_t options, + double *result) { int ret = 0; - ret = check_sample_stat_inputs(num_sample_sets, 2, num_index_tuples, index_tuples); + bool stat_site = !!(options & TSK_STAT_SITE); + bool stat_branch = !!(options & TSK_STAT_BRANCH); + bool 
stat_node = !!(options & TSK_STAT_NODE); + const double default_windows[] = { 0, self->tables->sequence_length }; + const double default_time_windows[] = { 0, INFINITY }; + const tsk_size_t num_nodes = self->tables->nodes.num_rows; + const tsk_size_t K = num_sample_sets + 1; + tsk_size_t j, k, l, afs_size; + tsk_id_t u; + tsk_size_t *result_dims = NULL; + /* These counts should really be ints, but we use doubles so that we can + * reuse code from the general_stats code paths. */ + double *counts = NULL; + double *count_row; + if (stat_node) { + ret = tsk_trace_error(TSK_ERR_UNSUPPORTED_STAT_MODE); + goto out; + } + /* If no mode is specified, we default to site mode */ + if (!(stat_site || stat_branch)) { + stat_site = true; + } + /* It's an error to specify more than one mode */ + if (stat_site + stat_branch > 1) { + ret = tsk_trace_error(TSK_ERR_MULTIPLE_STAT_MODES); + goto out; + } + if (windows == NULL) { + num_windows = 1; + windows = default_windows; + } else { + ret = tsk_treeseq_check_windows( + self, num_windows, windows, TSK_REQUIRE_FULL_SPAN); + if (ret != 0) { + goto out; + } + } + if (time_windows == NULL) { + num_time_windows = 1; + time_windows = default_time_windows; + } else { + ret = tsk_treeseq_check_time_windows(num_time_windows, time_windows); + if (ret != 0) { + goto out; + } + // Site mode does not support time windows + if (stat_site && !(time_windows[0] == 0.0 && isinf(time_windows[1]))) { + ret = tsk_trace_error(TSK_ERR_UNSUPPORTED_STAT_MODE); + goto out; + } + } + ret = tsk_treeseq_check_sample_sets( + self, num_sample_sets, sample_set_sizes, sample_sets); if (ret != 0) { goto out; } - ret = tsk_treeseq_sample_count_stat(self, num_sample_sets, sample_set_sizes, - sample_sets, num_index_tuples, index_tuples, f2_summary_func, num_windows, - windows, options, result); + + /* the last element of result_dims stores the total size of the dimensions */ + result_dims = tsk_malloc((num_sample_sets + 1) * sizeof(*result_dims)); + counts = 
tsk_calloc(num_nodes * K, sizeof(*counts)); + if (counts == NULL || result_dims == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; + } + afs_size = 1; + j = 0; + for (k = 0; k < num_sample_sets; k++) { + result_dims[k] = 1 + sample_set_sizes[k]; + afs_size *= result_dims[k]; + for (l = 0; l < sample_set_sizes[k]; l++) { + u = sample_sets[j]; + count_row = GET_2D_ROW(counts, K, u); + if (count_row[k] != 0) { + ret = tsk_trace_error(TSK_ERR_DUPLICATE_SAMPLE); + goto out; + } + count_row[k] = 1; + j++; + } + } + for (j = 0; j < self->num_samples; j++) { + u = self->samples[j]; + count_row = GET_2D_ROW(counts, K, u); + count_row[num_sample_sets] = 1; + } + result_dims[num_sample_sets] = (tsk_size_t) afs_size; + tsk_memset(result, 0, num_windows * num_time_windows * afs_size * sizeof(*result)); + + if (stat_site) { + ret = tsk_treeseq_site_allele_frequency_spectrum(self, num_sample_sets, + sample_set_sizes, counts, num_windows, windows, result_dims, options, + result); + } else { + ret = tsk_treeseq_branch_allele_frequency_spectrum(self, num_sample_sets, counts, + num_windows, windows, num_time_windows, time_windows, result_dims, options, + result); + } + + if (options & TSK_STAT_SPAN_NORMALISE) { + span_normalise(num_windows, windows, afs_size * num_time_windows, result); + } out: + tsk_safe_free(counts); + tsk_safe_free(result_dims); return ret; } /*********************************** - * Three way stats + * One way stats ***********************************/ static int -Y3_summary_func(tsk_size_t TSK_UNUSED(state_dim), const double *state, - tsk_size_t result_dim, double *result, void *params) +diversity_summary_func(tsk_size_t state_dim, const double *state, + tsk_size_t TSK_UNUSED(result_dim), double *result, void *params) { sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; const double *x = state; - double ni, nj, nk, denom, numer; - tsk_id_t i, j, k; - tsk_size_t tuple_index; + double n; + tsk_size_t j; - for (tuple_index = 
0; tuple_index < result_dim; tuple_index++) { - i = args.set_indexes[3 * tuple_index]; - j = args.set_indexes[3 * tuple_index + 1]; - k = args.set_indexes[3 * tuple_index + 2]; - ni = (double) args.sample_set_sizes[i]; - nj = (double) args.sample_set_sizes[j]; - nk = (double) args.sample_set_sizes[k]; - denom = ni * nj * nk; - numer = x[i] * (nj - x[j]) * (nk - x[k]); - result[tuple_index] = numer / denom; + for (j = 0; j < state_dim; j++) { + n = (double) args.sample_set_sizes[j]; + result[j] = x[j] * (n - x[j]) / (n * (n - 1)); } return 0; } int -tsk_treeseq_Y3(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, +tsk_treeseq_diversity(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, - tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_windows, - const double *windows, tsk_flags_t options, double *result) + tsk_size_t num_windows, const double *windows, tsk_flags_t options, double *result) { - int ret = 0; - ret = check_sample_stat_inputs(num_sample_sets, 3, num_index_tuples, index_tuples); - if (ret != 0) { - goto out; - } - ret = tsk_treeseq_sample_count_stat(self, num_sample_sets, sample_set_sizes, - sample_sets, num_index_tuples, index_tuples, Y3_summary_func, num_windows, - windows, options, result); -out: - return ret; + return tsk_treeseq_sample_count_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_sample_sets, NULL, diversity_summary_func, num_windows, windows, + options, result); } static int -f3_summary_func(tsk_size_t TSK_UNUSED(state_dim), const double *state, - tsk_size_t result_dim, double *result, void *params) +trait_covariance_summary_func(tsk_size_t state_dim, const double *state, + tsk_size_t TSK_UNUSED(result_dim), double *result, void *params) { - sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; + weight_stat_params_t args = *(weight_stat_params_t *) params; + const double n = (double) args.num_samples; 
const double *x = state; - double ni, nj, nk, denom, numer; - tsk_id_t i, j, k; - tsk_size_t tuple_index; + tsk_size_t j; - for (tuple_index = 0; tuple_index < result_dim; tuple_index++) { - i = args.set_indexes[3 * tuple_index]; - j = args.set_indexes[3 * tuple_index + 1]; - k = args.set_indexes[3 * tuple_index + 2]; - ni = (double) args.sample_set_sizes[i]; - nj = (double) args.sample_set_sizes[j]; - nk = (double) args.sample_set_sizes[k]; - denom = ni * (ni - 1) * nj * nk; - numer = x[i] * (x[i] - 1) * (nj - x[j]) * (nk - x[k]) - - x[i] * (ni - x[i]) * (nj - x[j]) * x[k]; - result[tuple_index] = numer / denom; + for (j = 0; j < state_dim; j++) { + result[j] = (x[j] * x[j]) / (2 * (n - 1) * (n - 1)); } return 0; } int -tsk_treeseq_f3(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, - const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, - tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_windows, - const double *windows, tsk_flags_t options, double *result) +tsk_treeseq_trait_covariance(const tsk_treeseq_t *self, tsk_size_t num_weights, + const double *weights, tsk_size_t num_windows, const double *windows, + tsk_flags_t options, double *result) { - int ret = 0; - ret = check_sample_stat_inputs(num_sample_sets, 3, num_index_tuples, index_tuples); - if (ret != 0) { + tsk_size_t num_samples = self->num_samples; + tsk_size_t j, k; + int ret; + const double *row; + double *new_row; + double *means = tsk_calloc(num_weights, sizeof(double)); + double *new_weights = tsk_malloc((num_weights + 1) * num_samples * sizeof(double)); + weight_stat_params_t args = { num_samples = self->num_samples }; + + if (new_weights == NULL || means == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } - ret = tsk_treeseq_sample_count_stat(self, num_sample_sets, sample_set_sizes, - sample_sets, num_index_tuples, index_tuples, f3_summary_func, num_windows, - windows, options, result); + if (num_weights == 0) { + ret = 
tsk_trace_error(TSK_ERR_INSUFFICIENT_WEIGHTS); + goto out; + } + + // center weights + for (j = 0; j < num_samples; j++) { + row = GET_2D_ROW(weights, num_weights, j); + for (k = 0; k < num_weights; k++) { + means[k] += row[k]; + } + } + for (k = 0; k < num_weights; k++) { + means[k] /= (double) num_samples; + } + for (j = 0; j < num_samples; j++) { + row = GET_2D_ROW(weights, num_weights, j); + new_row = GET_2D_ROW(new_weights, num_weights, j); + for (k = 0; k < num_weights; k++) { + new_row[k] = row[k] - means[k]; + } + } + + ret = tsk_treeseq_general_stat(self, num_weights, new_weights, num_weights, + trait_covariance_summary_func, &args, num_windows, windows, options, result); + out: + tsk_safe_free(means); + tsk_safe_free(new_weights); return ret; } -/*********************************** - * Four way stats - ***********************************/ - static int -f4_summary_func(tsk_size_t TSK_UNUSED(state_dim), const double *state, - tsk_size_t result_dim, double *result, void *params) +trait_correlation_summary_func(tsk_size_t state_dim, const double *state, + tsk_size_t TSK_UNUSED(result_dim), double *result, void *params) { - sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; + weight_stat_params_t args = *(weight_stat_params_t *) params; + const double n = (double) args.num_samples; const double *x = state; - double ni, nj, nk, nl, denom, numer; - tsk_id_t i, j, k, l; - tsk_size_t tuple_index; + double p; + tsk_size_t j; - for (tuple_index = 0; tuple_index < result_dim; tuple_index++) { - i = args.set_indexes[4 * tuple_index]; - j = args.set_indexes[4 * tuple_index + 1]; - k = args.set_indexes[4 * tuple_index + 2]; - l = args.set_indexes[4 * tuple_index + 3]; - ni = (double) args.sample_set_sizes[i]; - nj = (double) args.sample_set_sizes[j]; - nk = (double) args.sample_set_sizes[k]; - nl = (double) args.sample_set_sizes[l]; - denom = ni * nj * nk * nl; - numer = x[i] * x[k] * (nj - x[j]) * (nl - x[l]) - - x[i] * x[l] * (nj - x[j]) * (nk - 
x[k]); - result[tuple_index] = numer / denom; + p = x[state_dim - 1]; + for (j = 0; j < state_dim - 1; j++) { + if ((p > 0.0) && (p < 1.0)) { + result[j] = (x[j] * x[j]) / (2 * (p * (1 - p)) * n * (n - 1)); + } else { + result[j] = 0.0; + } } return 0; } int -tsk_treeseq_f4(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, - const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, - tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_windows, - const double *windows, tsk_flags_t options, double *result) +tsk_treeseq_trait_correlation(const tsk_treeseq_t *self, tsk_size_t num_weights, + const double *weights, tsk_size_t num_windows, const double *windows, + tsk_flags_t options, double *result) { - int ret = 0; - ret = check_sample_stat_inputs(num_sample_sets, 4, num_index_tuples, index_tuples); - if (ret != 0) { + tsk_size_t num_samples = self->num_samples; + tsk_size_t j, k; + int ret; + double *means = tsk_calloc(num_weights, sizeof(double)); + double *meansqs = tsk_calloc(num_weights, sizeof(double)); + double *sds = tsk_calloc(num_weights, sizeof(double)); + const double *row; + double *new_row; + double *new_weights = tsk_malloc((num_weights + 1) * num_samples * sizeof(double)); + weight_stat_params_t args = { num_samples = self->num_samples }; + + if (new_weights == NULL || means == NULL || meansqs == NULL || sds == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } - ret = tsk_treeseq_sample_count_stat(self, num_sample_sets, sample_set_sizes, - sample_sets, num_index_tuples, index_tuples, f4_summary_func, num_windows, - windows, options, result); -out: - return ret; -} -/* Error-raising getter functions */ + if (num_weights < 1) { + ret = tsk_trace_error(TSK_ERR_INSUFFICIENT_WEIGHTS); + goto out; + } -int TSK_WARN_UNUSED -tsk_treeseq_get_node(const tsk_treeseq_t *self, tsk_id_t index, tsk_node_t *node) -{ - return tsk_node_table_get_row(&self->tables->nodes, index, node); -} + // center and scale weights + for 
(j = 0; j < num_samples; j++) { + row = GET_2D_ROW(weights, num_weights, j); + for (k = 0; k < num_weights; k++) { + means[k] += row[k]; + meansqs[k] += row[k] * row[k]; + } + } + for (k = 0; k < num_weights; k++) { + means[k] /= (double) num_samples; + meansqs[k] -= means[k] * means[k] * (double) num_samples; + meansqs[k] /= (double) (num_samples - 1); + sds[k] = sqrt(meansqs[k]); + } + for (j = 0; j < num_samples; j++) { + row = GET_2D_ROW(weights, num_weights, j); + new_row = GET_2D_ROW(new_weights, num_weights + 1, j); + for (k = 0; k < num_weights; k++) { + new_row[k] = (row[k] - means[k]) / sds[k]; + } + // set final row to 1/n to compute frequency + new_row[num_weights] = 1.0 / (double) num_samples; + } -int TSK_WARN_UNUSED -tsk_treeseq_get_edge(const tsk_treeseq_t *self, tsk_id_t index, tsk_edge_t *edge) -{ - return tsk_edge_table_get_row(&self->tables->edges, index, edge); -} + ret = tsk_treeseq_general_stat(self, num_weights + 1, new_weights, num_weights, + trait_correlation_summary_func, &args, num_windows, windows, options, result); -int TSK_WARN_UNUSED -tsk_treeseq_get_migration( - const tsk_treeseq_t *self, tsk_id_t index, tsk_migration_t *migration) -{ - return tsk_migration_table_get_row(&self->tables->migrations, index, migration); +out: + tsk_safe_free(means); + tsk_safe_free(meansqs); + tsk_safe_free(sds); + tsk_safe_free(new_weights); + return ret; } -int TSK_WARN_UNUSED -tsk_treeseq_get_mutation( - const tsk_treeseq_t *self, tsk_id_t index, tsk_mutation_t *mutation) +static int +trait_linear_model_summary_func(tsk_size_t state_dim, const double *state, + tsk_size_t result_dim, double *result, void *params) { - int ret = 0; + covariates_stat_params_t args = *(covariates_stat_params_t *) params; + const double num_samples = (double) args.num_samples; + const tsk_size_t k = args.num_covariates; + const double *V = args.V; + ; + const double *x = state; + const double *v; + double m, a, denom, z; + tsk_size_t i, j; + // x[0], ..., x[result_dim - 1] 
contains the traits, W + // x[result_dim], ..., x[state_dim - 2] contains the covariates, Z + // x[state_dim - 1] has the number of samples below the node - ret = tsk_mutation_table_get_row(&self->tables->mutations, index, mutation); - if (ret != 0) { - goto out; + m = x[state_dim - 1]; + for (i = 0; i < result_dim; i++) { + if ((m > 0.0) && (m < num_samples)) { + v = GET_2D_ROW(V, k, i); + a = x[i]; + denom = m; + for (j = 0; j < k; j++) { + z = x[result_dim + j]; + a -= z * v[j]; + denom -= z * z; + } + // denom is the length of projection of the trait onto the subspace + // spanned by the covariates, so if it is zero then the system is + // singular and the solution is nonunique. This numerical tolerance + // could be smaller without hitting floating-point error, but being + // a tiny bit conservative about when the trait is almost in the + // span of the covariates is probably good. + if (denom < 1e-8) { + result[i] = 0.0; + } else { + result[i] = (a * a) / (2 * denom * denom); + } + } else { + result[i] = 0.0; + } } - mutation->edge = self->site_mutations_mem[index].edge; -out: - return ret; + return 0; } -int TSK_WARN_UNUSED -tsk_treeseq_get_site(const tsk_treeseq_t *self, tsk_id_t index, tsk_site_t *site) +int +tsk_treeseq_trait_linear_model(const tsk_treeseq_t *self, tsk_size_t num_weights, + const double *weights, tsk_size_t num_covariates, const double *covariates, + tsk_size_t num_windows, const double *windows, tsk_flags_t options, double *result) { - int ret = 0; + tsk_size_t num_samples = self->num_samples; + tsk_size_t i, j, k; + int ret; + const double *w, *z; + double *v, *new_row; + double *V = tsk_calloc(num_covariates * num_weights, sizeof(double)); + double *new_weights + = tsk_malloc((num_weights + num_covariates + 1) * num_samples * sizeof(double)); - ret = tsk_site_table_get_row(&self->tables->sites, index, site); - if (ret != 0) { + covariates_stat_params_t args + = { .num_samples = self->num_samples, .num_covariates = num_covariates, .V = 
V }; + + // We assume that the covariates have been *already standardised*, + // so that (a) 1 is in the span of the columns, and + // (b) their crossproduct is the identity. + // We could do this instead here with gsl linalg. + + if (new_weights == NULL || V == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } - site->mutations = self->site_mutations[index]; - site->mutations_length = self->site_mutations_length[index]; -out: - return ret; -} -int TSK_WARN_UNUSED -tsk_treeseq_get_individual( - const tsk_treeseq_t *self, tsk_id_t index, tsk_individual_t *individual) -{ - int ret = 0; - - ret = tsk_individual_table_get_row(&self->tables->individuals, index, individual); - if (ret != 0) { + if (num_weights < 1) { + ret = tsk_trace_error(TSK_ERR_INSUFFICIENT_WEIGHTS); goto out; } - individual->nodes = self->individual_nodes[index]; - individual->nodes_length = self->individual_nodes_length[index]; + + // V = weights^T (matrix mult) covariates + for (k = 0; k < num_samples; k++) { + w = GET_2D_ROW(weights, num_weights, k); + z = GET_2D_ROW(covariates, num_covariates, k); + for (i = 0; i < num_weights; i++) { + v = GET_2D_ROW(V, num_covariates, i); + for (j = 0; j < num_covariates; j++) { + v[j] += w[i] * z[j]; + } + } + } + + for (k = 0; k < num_samples; k++) { + w = GET_2D_ROW(weights, num_weights, k); + z = GET_2D_ROW(covariates, num_covariates, k); + new_row = GET_2D_ROW(new_weights, num_covariates + num_weights + 1, k); + for (i = 0; i < num_weights; i++) { + new_row[i] = w[i]; + } + for (i = 0; i < num_covariates; i++) { + new_row[i + num_weights] = z[i]; + } + // set final row to 1 to count alleles + new_row[num_weights + num_covariates] = 1.0; + } + + ret = tsk_treeseq_general_stat(self, num_weights + num_covariates + 1, new_weights, + num_weights, trait_linear_model_summary_func, &args, num_windows, windows, + options, result); + out: + tsk_safe_free(V); + tsk_safe_free(new_weights); return ret; } -int TSK_WARN_UNUSED -tsk_treeseq_get_population( 
- const tsk_treeseq_t *self, tsk_id_t index, tsk_population_t *population) +static int +segregating_sites_summary_func(tsk_size_t state_dim, const double *state, + tsk_size_t TSK_UNUSED(result_dim), double *result, void *params) { - return tsk_population_table_get_row(&self->tables->populations, index, population); + sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; + const double *x = state; + double n; + tsk_size_t j; + + // this works because sum_{i=1}^k (1-p_i) = k-1 + for (j = 0; j < state_dim; j++) { + n = (double) args.sample_set_sizes[j]; + result[j] = (x[j] > 0) * (1 - x[j] / n); + } + return 0; } -int TSK_WARN_UNUSED -tsk_treeseq_get_provenance( - const tsk_treeseq_t *self, tsk_id_t index, tsk_provenance_t *provenance) +int +tsk_treeseq_segregating_sites(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, + tsk_size_t num_windows, const double *windows, tsk_flags_t options, double *result) { - return tsk_provenance_table_get_row(&self->tables->provenances, index, provenance); + return tsk_treeseq_sample_count_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_sample_sets, NULL, segregating_sites_summary_func, num_windows, + windows, options, result); } -int TSK_WARN_UNUSED -tsk_treeseq_simplify(const tsk_treeseq_t *self, const tsk_id_t *samples, - tsk_size_t num_samples, tsk_flags_t options, tsk_treeseq_t *output, - tsk_id_t *node_map) +static int +Y1_summary_func(tsk_size_t TSK_UNUSED(state_dim), const double *state, + tsk_size_t result_dim, double *result, void *params) { - int ret = 0; - tsk_table_collection_t *tables = tsk_malloc(sizeof(*tables)); + sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; + const double *x = state; + double ni, denom, numer; + tsk_size_t i; - if (tables == NULL) { - ret = TSK_ERR_NO_MEMORY; - goto out; - } - ret = tsk_treeseq_copy_tables(self, tables, 0); - if (ret != 0) { - goto out; - } - ret 
= tsk_table_collection_simplify(tables, samples, num_samples, options, node_map); - if (ret != 0) { - goto out; - } - ret = tsk_treeseq_init( - output, tables, TSK_TS_INIT_BUILD_INDEXES | TSK_TAKE_OWNERSHIP); - /* Once tsk_tree_init has returned ownership of tables is transferred */ - tables = NULL; -out: - if (tables != NULL) { - tsk_table_collection_free(tables); - tsk_safe_free(tables); + for (i = 0; i < result_dim; i++) { + ni = (double) args.sample_set_sizes[i]; + denom = ni * (ni - 1) * (ni - 2); + numer = x[i] * (ni - x[i]) * (ni - x[i] - 1); + result[i] = numer / denom; } - return ret; + return 0; } -int TSK_WARN_UNUSED -tsk_treeseq_split_edges(const tsk_treeseq_t *self, double time, tsk_flags_t flags, - tsk_id_t population, const char *metadata, tsk_size_t metadata_length, - tsk_flags_t TSK_UNUSED(options), tsk_treeseq_t *output) +int +tsk_treeseq_Y1(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, + tsk_size_t num_windows, const double *windows, tsk_flags_t options, double *result) { - int ret = 0; - tsk_table_collection_t *tables = tsk_malloc(sizeof(*tables)); - const double *restrict node_time = self->tables->nodes.time; - const tsk_size_t num_edges = self->tables->edges.num_rows; - const tsk_size_t num_mutations = self->tables->mutations.num_rows; - tsk_id_t *split_edge = tsk_malloc(num_edges * sizeof(*split_edge)); - tsk_id_t j, u, mapped_node, ret_id; - double mutation_time; - tsk_edge_t edge; - tsk_mutation_t mutation; - tsk_bookmark_t sort_start; - - memset(output, 0, sizeof(*output)); - if (split_edge == NULL) { - ret = TSK_ERR_NO_MEMORY; - goto out; - } - ret = tsk_treeseq_copy_tables(self, tables, 0); - if (ret != 0) { - goto out; - } - if (tables->migrations.num_rows > 0) { - ret = TSK_ERR_MIGRATIONS_NOT_SUPPORTED; - goto out; - } - /* We could catch this below in add_row, but it's simpler to guarantee - * that we always catch the error in corner cases where the values - * 
aren't used. */ - if (population < -1 || population >= (tsk_id_t) self->tables->populations.num_rows) { - ret = TSK_ERR_POPULATION_OUT_OF_BOUNDS; - goto out; - } - if (!tsk_isfinite(time)) { - ret = TSK_ERR_TIME_NONFINITE; - goto out; - } - - tsk_edge_table_clear(&tables->edges); - tsk_memset(split_edge, TSK_NULL, num_edges * sizeof(*split_edge)); + return tsk_treeseq_sample_count_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_sample_sets, NULL, Y1_summary_func, num_windows, windows, + options, result); +} - for (j = 0; j < (tsk_id_t) num_edges; j++) { - /* Would prefer to use tsk_edge_table_get_row_unsafe, but it's - * currently static to tables.c */ - ret = tsk_edge_table_get_row(&self->tables->edges, j, &edge); - tsk_bug_assert(ret == 0); - if (node_time[edge.child] < time && time < node_time[edge.parent]) { - u = tsk_node_table_add_row(&tables->nodes, flags, time, population, TSK_NULL, - metadata, metadata_length); - if (u < 0) { - ret = (int) u; - goto out; - } - ret_id = tsk_edge_table_add_row(&tables->edges, edge.left, edge.right, u, - edge.child, edge.metadata, edge.metadata_length); - if (ret_id < 0) { - ret = (int) ret_id; - goto out; - } - edge.child = u; - split_edge[j] = u; - } - ret_id = tsk_edge_table_add_row(&tables->edges, edge.left, edge.right, - edge.parent, edge.child, edge.metadata, edge.metadata_length); - if (ret_id < 0) { - ret = (int) ret_id; - goto out; - } - } +static int +D_summary_func(tsk_size_t state_dim, const double *state, + tsk_size_t TSK_UNUSED(result_dim), double *result, void *params) +{ + sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; + double n; + const double *state_row; + tsk_size_t j; - for (j = 0; j < (tsk_id_t) num_mutations; j++) { - /* Note: we could speed this up a bit by accessing the local - * memory for mutations directly. 
*/ - ret = tsk_treeseq_get_mutation(self, j, &mutation); - tsk_bug_assert(ret == 0); - mapped_node = TSK_NULL; - if (mutation.edge != TSK_NULL) { - mapped_node = split_edge[mutation.edge]; - } - mutation_time = tsk_is_unknown_time(mutation.time) ? node_time[mutation.node] - : mutation.time; - if (mapped_node != TSK_NULL && mutation_time >= time) { - /* Update the column in-place to save a bit of time. */ - tables->mutations.node[j] = mapped_node; - } - } + for (j = 0; j < state_dim; j++) { + n = (double) args.sample_set_sizes[j]; + state_row = GET_2D_ROW(state, 3, j); + double p_AB = state_row[0] / n; + double p_Ab = state_row[1] / n; + double p_aB = state_row[2] / n; - /* Skip mutations and sites as they haven't been altered */ - /* Note we can probably optimise the edge sort a bit here also by - * reasoning about when the first edge gets altered in the table. - */ - memset(&sort_start, 0, sizeof(sort_start)); - sort_start.sites = tables->sites.num_rows; - sort_start.mutations = tables->mutations.num_rows; - ret = tsk_table_collection_sort(tables, &sort_start, 0); - if (ret != 0) { - goto out; + double p_A = p_AB + p_Ab; + double p_B = p_AB + p_aB; + result[j] = p_AB - (p_A * p_B); } - ret = tsk_treeseq_init( - output, tables, TSK_TS_INIT_BUILD_INDEXES | TSK_TAKE_OWNERSHIP); - tables = NULL; -out: - if (tables != NULL) { - tsk_table_collection_free(tables); - tsk_safe_free(tables); - } - tsk_safe_free(split_edge); - return ret; + return 0; } -/* ======================================================== * - * Tree - * ======================================================== */ +int +tsk_treeseq_D(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result) +{ + options |= TSK_STAT_POLARISED; // TODO: allow user 
to pick? + return tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_sample_sets, NULL, D_summary_func, norm_total_weighted, + num_rows, row_sites, row_positions, num_cols, col_sites, col_positions, options, + result); +} -int TSK_WARN_UNUSED -tsk_tree_init(tsk_tree_t *self, const tsk_treeseq_t *tree_sequence, tsk_flags_t options) +static int +D2_summary_func(tsk_size_t state_dim, const double *state, + tsk_size_t TSK_UNUSED(result_dim), double *result, void *params) { - int ret = TSK_ERR_NO_MEMORY; - tsk_size_t num_samples, num_nodes, N; + sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; + double n; + const double *state_row; + tsk_size_t j; - tsk_memset(self, 0, sizeof(tsk_tree_t)); - if (tree_sequence == NULL) { - ret = TSK_ERR_BAD_PARAM_VALUE; - goto out; + for (j = 0; j < state_dim; j++) { + n = (double) args.sample_set_sizes[j]; + state_row = GET_2D_ROW(state, 3, j); + double p_AB = state_row[0] / n; + double p_Ab = state_row[1] / n; + double p_aB = state_row[2] / n; + + double p_A = p_AB + p_Ab; + double p_B = p_AB + p_aB; + result[j] = p_AB - (p_A * p_B); + result[j] *= result[j]; } - num_nodes = tree_sequence->tables->nodes.num_rows; - num_samples = tree_sequence->num_samples; - self->num_nodes = num_nodes; - self->virtual_root = (tsk_id_t) num_nodes; - self->tree_sequence = tree_sequence; - self->samples = tree_sequence->samples; - self->options = options; - self->root_threshold = 1; - /* Allocate space in the quintuply linked tree for the virtual root */ - N = num_nodes + 1; - self->parent = tsk_malloc(N * sizeof(*self->parent)); - self->left_child = tsk_malloc(N * sizeof(*self->left_child)); - self->right_child = tsk_malloc(N * sizeof(*self->right_child)); - self->left_sib = tsk_malloc(N * sizeof(*self->left_sib)); - self->right_sib = tsk_malloc(N * sizeof(*self->right_sib)); - self->num_children = tsk_calloc(N, sizeof(*self->num_children)); - self->edge = tsk_malloc(N * 
sizeof(*self->edge)); - if (self->parent == NULL || self->left_child == NULL || self->right_child == NULL - || self->left_sib == NULL || self->right_sib == NULL - || self->num_children == NULL || self->edge == NULL) { - goto out; + return 0; +} + +int +tsk_treeseq_D2(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result) +{ + return tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_sample_sets, NULL, D2_summary_func, norm_total_weighted, + num_rows, row_sites, row_positions, num_cols, col_sites, col_positions, options, + result); +} + +static int +r2_summary_func(tsk_size_t state_dim, const double *state, + tsk_size_t TSK_UNUSED(result_dim), double *result, void *params) +{ + sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; + double n; + const double *state_row; + tsk_size_t j; + + for (j = 0; j < state_dim; j++) { + n = (double) args.sample_set_sizes[j]; + state_row = GET_2D_ROW(state, 3, j); + double p_AB = state_row[0] / n; + double p_Ab = state_row[1] / n; + double p_aB = state_row[2] / n; + + double p_A = p_AB + p_Ab; + double p_B = p_AB + p_aB; + + double D = p_AB - (p_A * p_B); + double denom = p_A * p_B * (1 - p_A) * (1 - p_B); + + result[j] = (D * D) / denom; } - if (!(self->options & TSK_NO_SAMPLE_COUNTS)) { - self->num_samples = tsk_calloc(N, sizeof(*self->num_samples)); - self->num_tracked_samples = tsk_calloc(N, sizeof(*self->num_tracked_samples)); - if (self->num_samples == NULL || self->num_tracked_samples == NULL) { - goto out; + return 0; +} + +int +tsk_treeseq_r2(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, + const tsk_id_t 
*row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result) +{ + return tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_sample_sets, NULL, r2_summary_func, norm_hap_weighted, num_rows, + row_sites, row_positions, num_cols, col_sites, col_positions, options, result); +} + +static int +D_prime_summary_func(tsk_size_t state_dim, const double *state, + tsk_size_t TSK_UNUSED(result_dim), double *result, void *params) +{ + sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; + double n; + const double *state_row; + tsk_size_t j; + + for (j = 0; j < state_dim; j++) { + n = (double) args.sample_set_sizes[j]; + state_row = GET_2D_ROW(state, 3, j); + double p_AB = state_row[0] / n; + double p_Ab = state_row[1] / n; + double p_aB = state_row[2] / n; + + double p_A = p_AB + p_Ab; + double p_B = p_AB + p_aB; + + double D = p_AB - (p_A * p_B); + + if (D >= 0) { + result[j] = D / TSK_MIN(p_A * (1 - p_B), (1 - p_A) * p_B); + } else if (D < 0) { + result[j] = D / TSK_MIN(p_A * p_B, (1 - p_A) * (1 - p_B)); } } - if (self->options & TSK_SAMPLE_LISTS) { - self->left_sample = tsk_malloc(N * sizeof(*self->left_sample)); - self->right_sample = tsk_malloc(N * sizeof(*self->right_sample)); - self->next_sample = tsk_malloc(num_samples * sizeof(*self->next_sample)); - if (self->left_sample == NULL || self->right_sample == NULL - || self->next_sample == NULL) { - goto out; - } + return 0; +} + +int +tsk_treeseq_D_prime(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result) +{ + options |= TSK_STAT_POLARISED; // TODO: allow user to pick? 
+ return tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_sample_sets, NULL, D_prime_summary_func, norm_total_weighted, + num_rows, row_sites, row_positions, num_cols, col_sites, col_positions, options, + result); +} + +static int +r_summary_func(tsk_size_t state_dim, const double *state, + tsk_size_t TSK_UNUSED(result_dim), double *result, void *params) +{ + sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; + double n; + const double *state_row; + tsk_size_t j; + + for (j = 0; j < state_dim; j++) { + n = (double) args.sample_set_sizes[j]; + state_row = GET_2D_ROW(state, 3, j); + double p_AB = state_row[0] / n; + double p_Ab = state_row[1] / n; + double p_aB = state_row[2] / n; + + double p_A = p_AB + p_Ab; + double p_B = p_AB + p_aB; + + double D = p_AB - (p_A * p_B); + double denom = p_A * p_B * (1 - p_A) * (1 - p_B); + + result[j] = D / sqrt(denom); } - ret = tsk_tree_clear(self); -out: - return ret; + return 0; } int -tsk_tree_set_root_threshold(tsk_tree_t *self, tsk_size_t root_threshold) +tsk_treeseq_r(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result) { - int ret = 0; + options |= TSK_STAT_POLARISED; // TODO: allow user to pick? 
+ return tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_sample_sets, NULL, r_summary_func, norm_total_weighted, + num_rows, row_sites, row_positions, num_cols, col_sites, col_positions, options, + result); +} - if (root_threshold == 0) { - ret = TSK_ERR_BAD_PARAM_VALUE; - goto out; +static int +Dz_summary_func(tsk_size_t state_dim, const double *state, + tsk_size_t TSK_UNUSED(result_dim), double *result, void *params) +{ + sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; + double n; + const double *state_row; + tsk_size_t j; + + for (j = 0; j < state_dim; j++) { + n = (double) args.sample_set_sizes[j]; + state_row = GET_2D_ROW(state, 3, j); + double p_AB = state_row[0] / n; + double p_Ab = state_row[1] / n; + double p_aB = state_row[2] / n; + + double p_A = p_AB + p_Ab; + double p_B = p_AB + p_aB; + + double D = p_AB - (p_A * p_B); + + result[j] = D * (1 - 2 * p_A) * (1 - 2 * p_B); } - /* Don't allow the value to be set when the tree is out of the null - * state */ - if (self->index != -1) { - ret = TSK_ERR_UNSUPPORTED_OPERATION; - goto out; + return 0; +} + +int +tsk_treeseq_Dz(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result) +{ + return tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_sample_sets, NULL, Dz_summary_func, norm_total_weighted, + num_rows, row_sites, row_positions, num_cols, col_sites, col_positions, options, + result); +} + +static int +pi2_summary_func(tsk_size_t state_dim, const double *state, + tsk_size_t TSK_UNUSED(result_dim), double *result, void *params) +{ + sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; + double n; + const double 
*state_row; + tsk_size_t j; + + for (j = 0; j < state_dim; j++) { + n = (double) args.sample_set_sizes[j]; + state_row = GET_2D_ROW(state, 3, j); + double p_AB = state_row[0] / n; + double p_Ab = state_row[1] / n; + double p_aB = state_row[2] / n; + + double p_A = p_AB + p_Ab; + double p_B = p_AB + p_aB; + result[j] = p_A * (1 - p_A) * p_B * (1 - p_B); } - self->root_threshold = root_threshold; - /* Reset the roots */ - ret = tsk_tree_clear(self); -out: - return ret; + return 0; } -tsk_size_t -tsk_tree_get_root_threshold(const tsk_tree_t *self) +int +tsk_treeseq_pi2(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result) { - return self->root_threshold; + return tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_sample_sets, NULL, pi2_summary_func, norm_total_weighted, + num_rows, row_sites, row_positions, num_cols, col_sites, col_positions, options, + result); +} + +static int +D2_unbiased_summary_func(tsk_size_t state_dim, const double *state, + tsk_size_t TSK_UNUSED(result_dim), double *result, void *params) +{ + sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; + double n; + const double *state_row; + tsk_size_t j; + + for (j = 0; j < state_dim; j++) { + n = (double) args.sample_set_sizes[j]; + state_row = GET_2D_ROW(state, 3, j); + double w_AB = state_row[0]; + double w_Ab = state_row[1]; + double w_aB = state_row[2]; + double w_ab = n - (w_AB + w_Ab + w_aB); + result[j] = (1 / (n * (n - 1) * (n - 2) * (n - 3))) + * ((w_aB * w_aB * (w_Ab - 1) * w_Ab) + + ((w_ab - 1) * w_ab * (w_AB - 1) * w_AB) + - (w_aB * w_Ab * (w_Ab + (2 * w_ab * w_AB) - 1))); + } + return 0; } int -tsk_tree_free(tsk_tree_t *self) +tsk_treeseq_D2_unbiased(const 
tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result) { - tsk_safe_free(self->parent); - tsk_safe_free(self->left_child); - tsk_safe_free(self->right_child); - tsk_safe_free(self->left_sib); - tsk_safe_free(self->right_sib); - tsk_safe_free(self->num_samples); - tsk_safe_free(self->num_tracked_samples); - tsk_safe_free(self->left_sample); - tsk_safe_free(self->right_sample); - tsk_safe_free(self->next_sample); - tsk_safe_free(self->num_children); - tsk_safe_free(self->edge); + return tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_sample_sets, NULL, D2_unbiased_summary_func, + norm_total_weighted, num_rows, row_sites, row_positions, num_cols, col_sites, + col_positions, options, result); +} + +static int +Dz_unbiased_summary_func(tsk_size_t state_dim, const double *state, + tsk_size_t TSK_UNUSED(result_dim), double *result, void *params) +{ + sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; + double n; + const double *state_row; + tsk_size_t j; + + for (j = 0; j < state_dim; j++) { + n = (double) args.sample_set_sizes[j]; + state_row = GET_2D_ROW(state, 3, j); + double w_AB = state_row[0]; + double w_Ab = state_row[1]; + double w_aB = state_row[2]; + double w_ab = n - (w_AB + w_Ab + w_aB); + result[j] = (1 / (n * (n - 1) * (n - 2) * (n - 3))) + * ((((w_AB * w_ab) - (w_Ab * w_aB)) * (w_aB + w_ab - w_AB - w_Ab) + * (w_Ab + w_ab - w_AB - w_aB)) + - ((w_AB * w_ab) * (w_AB + w_ab - w_Ab - w_aB - 2)) + - ((w_Ab * w_aB) * (w_Ab + w_aB - w_AB - w_ab - 2))); + } return 0; } -bool -tsk_tree_has_sample_lists(const tsk_tree_t *self) +int +tsk_treeseq_Dz_unbiased(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, 
const tsk_id_t *sample_sets, tsk_size_t num_rows, + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result) { - return !!(self->options & TSK_SAMPLE_LISTS); + return tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_sample_sets, NULL, Dz_unbiased_summary_func, + norm_total_weighted, num_rows, row_sites, row_positions, num_cols, col_sites, + col_positions, options, result); } -bool -tsk_tree_has_sample_counts(const tsk_tree_t *self) +static int +pi2_unbiased_summary_func(tsk_size_t state_dim, const double *state, + tsk_size_t TSK_UNUSED(result_dim), double *result, void *params) { - return !(self->options & TSK_NO_SAMPLE_COUNTS); + sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; + double n; + const double *state_row; + tsk_size_t j; + + for (j = 0; j < state_dim; j++) { + n = (double) args.sample_set_sizes[j]; + state_row = GET_2D_ROW(state, 3, j); + double w_AB = state_row[0]; + double w_Ab = state_row[1]; + double w_aB = state_row[2]; + double w_ab = n - (w_AB + w_Ab + w_aB); + result[j] + = (1 / (n * (n - 1) * (n - 2) * (n - 3))) + * (((w_AB + w_Ab) * (w_aB + w_ab) * (w_AB + w_aB) * (w_Ab + w_ab)) + - ((w_AB * w_ab) * (w_AB + w_ab + (3 * w_Ab) + (3 * w_aB) - 1)) + - ((w_Ab * w_aB) * (w_Ab + w_aB + (3 * w_AB) + (3 * w_ab) - 1))); + } + return 0; } -static int TSK_WARN_UNUSED -tsk_tree_reset_tracked_samples(tsk_tree_t *self) +int +tsk_treeseq_pi2_unbiased(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result) +{ + return tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, 
num_sample_sets, NULL, pi2_unbiased_summary_func, + norm_total_weighted, num_rows, row_sites, row_positions, num_cols, col_sites, + col_positions, options, result); +} + +/*********************************** + * Two way stats + ***********************************/ + +static int +check_sample_stat_inputs(tsk_size_t num_sample_sets, tsk_size_t tuple_size, + tsk_size_t num_index_tuples, const tsk_id_t *index_tuples) { int ret = 0; - if (!tsk_tree_has_sample_counts(self)) { - ret = TSK_ERR_UNSUPPORTED_OPERATION; + if (num_sample_sets < 1) { + ret = tsk_trace_error(TSK_ERR_INSUFFICIENT_SAMPLE_SETS); + goto out; + } + if (num_index_tuples < 1) { + ret = tsk_trace_error(TSK_ERR_INSUFFICIENT_INDEX_TUPLES); + goto out; + } + ret = check_set_indexes( + num_sample_sets, tuple_size * num_index_tuples, index_tuples); + if (ret != 0) { goto out; } - tsk_memset(self->num_tracked_samples, 0, - (self->num_nodes + 1) * sizeof(*self->num_tracked_samples)); out: return ret; } -int TSK_WARN_UNUSED -tsk_tree_set_tracked_samples( - tsk_tree_t *self, tsk_size_t num_tracked_samples, const tsk_id_t *tracked_samples) +static int +divergence_summary_func(tsk_size_t TSK_UNUSED(state_dim), const double *state, + tsk_size_t result_dim, double *result, void *params) { - int ret = TSK_ERR_GENERIC; - tsk_size_t *tree_num_tracked_samples = self->num_tracked_samples; - const tsk_id_t *parent = self->parent; - tsk_size_t j; - tsk_id_t u; + sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; + const double *x = state; + double ni, nj, denom; + tsk_id_t i, j; + tsk_size_t k; - /* TODO This is not needed when the tree is new. We should use the - * state machine to check and only reset the tracked samples when needed. 
- */ - ret = tsk_tree_reset_tracked_samples(self); - if (ret != 0) { - goto out; - } - self->num_tracked_samples[self->virtual_root] = num_tracked_samples; - for (j = 0; j < num_tracked_samples; j++) { - u = tracked_samples[j]; - if (u < 0 || u >= (tsk_id_t) self->num_nodes) { - ret = TSK_ERR_NODE_OUT_OF_BOUNDS; - goto out; - } - if (!tsk_treeseq_is_sample(self->tree_sequence, u)) { - ret = TSK_ERR_BAD_SAMPLES; - goto out; - } - if (self->num_tracked_samples[u] != 0) { - ret = TSK_ERR_DUPLICATE_SAMPLE; - goto out; - } - /* Propagate this upwards */ - while (u != TSK_NULL) { - tree_num_tracked_samples[u]++; - u = parent[u]; - } + for (k = 0; k < result_dim; k++) { + i = args.set_indexes[2 * k]; + j = args.set_indexes[2 * k + 1]; + ni = (double) args.sample_set_sizes[i]; + nj = (double) args.sample_set_sizes[j]; + denom = ni * (nj - (i == j)); + result[k] = x[i] * (nj - x[j]) / denom; } -out: - return ret; + return 0; } -int TSK_WARN_UNUSED -tsk_tree_track_descendant_samples(tsk_tree_t *self, tsk_id_t node) +int +tsk_treeseq_divergence(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, + tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_windows, + const double *windows, tsk_flags_t options, double *result) { int ret = 0; - tsk_id_t *nodes = tsk_malloc(tsk_tree_get_size_bound(self) * sizeof(*nodes)); - const tsk_id_t *restrict parent = self->parent; - const tsk_id_t *restrict left_child = self->left_child; - const tsk_id_t *restrict right_sib = self->right_sib; - const tsk_flags_t *restrict flags = self->tree_sequence->tables->nodes.flags; - tsk_size_t *num_tracked_samples = self->num_tracked_samples; - tsk_size_t n, j, num_nodes; - tsk_id_t u, v; - - if (nodes == NULL) { - ret = TSK_ERR_NO_MEMORY; - goto out; - } - ret = tsk_tree_postorder_from(self, node, nodes, &num_nodes); - if (ret != 0) { - goto out; - } - ret = tsk_tree_reset_tracked_samples(self); + ret = 
check_sample_stat_inputs(num_sample_sets, 2, num_index_tuples, index_tuples); if (ret != 0) { goto out; } - u = 0; /* keep the compiler happy */ - for (j = 0; j < num_nodes; j++) { - u = nodes[j]; - for (v = left_child[u]; v != TSK_NULL; v = right_sib[v]) { - num_tracked_samples[u] += num_tracked_samples[v]; - } - num_tracked_samples[u] += flags[u] & TSK_NODE_IS_SAMPLE ? 1 : 0; - } - n = num_tracked_samples[u]; - u = parent[u]; - while (u != TSK_NULL) { - num_tracked_samples[u] = n; - u = parent[u]; - } - num_tracked_samples[self->virtual_root] = n; + ret = tsk_treeseq_sample_count_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_index_tuples, index_tuples, divergence_summary_func, + num_windows, windows, options, result); out: - tsk_safe_free(nodes); return ret; } -int TSK_WARN_UNUSED -tsk_tree_copy(const tsk_tree_t *self, tsk_tree_t *dest, tsk_flags_t options) +static int +genetic_relatedness_summary_func(tsk_size_t state_dim, const double *state, + tsk_size_t result_dim, double *result, void *params) { - int ret = TSK_ERR_GENERIC; - tsk_size_t N = self->num_nodes + 1; + sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; + const double *x = state; + tsk_id_t i, j; + tsk_size_t k; + double sumx = 0; + double meanx, ni, nj; - if (!(options & TSK_NO_INIT)) { - ret = tsk_tree_init(dest, self->tree_sequence, options); - if (ret != 0) { - goto out; - } - } - if (self->tree_sequence != dest->tree_sequence) { - ret = TSK_ERR_BAD_PARAM_VALUE; - goto out; + for (k = 0; k < state_dim; k++) { + sumx += x[k] / (double) args.sample_set_sizes[k]; } - dest->interval = self->interval; - dest->left_index = self->left_index; - dest->right_index = self->right_index; - dest->direction = self->direction; - dest->index = self->index; - dest->sites = self->sites; - dest->sites_length = self->sites_length; - dest->root_threshold = self->root_threshold; - dest->num_edges = self->num_edges; - tsk_memcpy(dest->parent, self->parent, N * 
sizeof(*self->parent)); - tsk_memcpy(dest->left_child, self->left_child, N * sizeof(*self->left_child)); - tsk_memcpy(dest->right_child, self->right_child, N * sizeof(*self->right_child)); - tsk_memcpy(dest->left_sib, self->left_sib, N * sizeof(*self->left_sib)); - tsk_memcpy(dest->right_sib, self->right_sib, N * sizeof(*self->right_sib)); - tsk_memcpy(dest->num_children, self->num_children, N * sizeof(*self->num_children)); - tsk_memcpy(dest->edge, self->edge, N * sizeof(*self->edge)); - if (!(dest->options & TSK_NO_SAMPLE_COUNTS)) { - if (self->options & TSK_NO_SAMPLE_COUNTS) { - ret = TSK_ERR_UNSUPPORTED_OPERATION; - goto out; - } - tsk_memcpy(dest->num_samples, self->num_samples, N * sizeof(*self->num_samples)); - tsk_memcpy(dest->num_tracked_samples, self->num_tracked_samples, - N * sizeof(*self->num_tracked_samples)); - } - if (dest->options & TSK_SAMPLE_LISTS) { - if (!(self->options & TSK_SAMPLE_LISTS)) { - ret = TSK_ERR_UNSUPPORTED_OPERATION; - goto out; - } - tsk_memcpy(dest->left_sample, self->left_sample, N * sizeof(*self->left_sample)); - tsk_memcpy( - dest->right_sample, self->right_sample, N * sizeof(*self->right_sample)); - tsk_memcpy(dest->next_sample, self->next_sample, - self->tree_sequence->num_samples * sizeof(*self->next_sample)); + meanx = sumx / (double) state_dim; + for (k = 0; k < result_dim; k++) { + i = args.set_indexes[2 * k]; + j = args.set_indexes[2 * k + 1]; + ni = (double) args.sample_set_sizes[i]; + nj = (double) args.sample_set_sizes[j]; + result[k] = (x[i] / ni - meanx) * (x[j] / nj - meanx); } - ret = 0; -out: - return ret; + return 0; } -bool TSK_WARN_UNUSED -tsk_tree_equals(const tsk_tree_t *self, const tsk_tree_t *other) +static int +genetic_relatedness_noncentred_summary_func(tsk_size_t TSK_UNUSED(state_dim), + const double *state, tsk_size_t result_dim, double *result, void *params) { - bool ret = false; + sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; + const double *x = state; + tsk_id_t i, j; + 
tsk_size_t k; + double ni, nj; - if (self->tree_sequence == other->tree_sequence) { - ret = self->index == other->index; + for (k = 0; k < result_dim; k++) { + i = args.set_indexes[2 * k]; + j = args.set_indexes[2 * k + 1]; + ni = (double) args.sample_set_sizes[i]; + nj = (double) args.sample_set_sizes[j]; + result[k] = x[i] * x[j] / (ni * nj); } - return ret; + return 0; } -static int -tsk_tree_check_node(const tsk_tree_t *self, tsk_id_t u) +int +tsk_treeseq_genetic_relatedness(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, + tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_windows, + const double *windows, tsk_flags_t options, double *result) { int ret = 0; - if (u < 0 || u > (tsk_id_t) self->num_nodes) { - ret = TSK_ERR_NODE_OUT_OF_BOUNDS; + ret = check_sample_stat_inputs(num_sample_sets, 2, num_index_tuples, index_tuples); + if (ret != 0) { + goto out; + } + if (!(options & TSK_STAT_NONCENTRED)) { + ret = tsk_treeseq_sample_count_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_index_tuples, index_tuples, + genetic_relatedness_summary_func, num_windows, windows, options, result); + } else { + ret = tsk_treeseq_sample_count_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_index_tuples, index_tuples, + genetic_relatedness_noncentred_summary_func, num_windows, windows, options, + result); } +out: return ret; } -bool -tsk_tree_is_descendant(const tsk_tree_t *self, tsk_id_t u, tsk_id_t v) +static int +genetic_relatedness_weighted_summary_func(tsk_size_t state_dim, const double *state, + tsk_size_t result_dim, double *result, void *params) { - bool ret = false; - tsk_id_t w = u; - tsk_id_t *restrict parent = self->parent; + indexed_weight_stat_params_t args = *(indexed_weight_stat_params_t *) params; + const double *x = state; + tsk_id_t i, j; + tsk_size_t k; + double pn, ni, nj; - if (tsk_tree_check_node(self, u) == 0 && 
tsk_tree_check_node(self, v) == 0) { - while (w != v && w != TSK_NULL) { - w = parent[w]; - } - ret = w == v; + pn = state[state_dim - 1]; + for (k = 0; k < result_dim; k++) { + i = args.index_tuples[2 * k]; + j = args.index_tuples[2 * k + 1]; + ni = args.total_weights[i]; + nj = args.total_weights[j]; + result[k] = (x[i] - ni * pn) * (x[j] - nj * pn); } - return ret; + return 0; } -int TSK_WARN_UNUSED -tsk_tree_get_mrca(const tsk_tree_t *self, tsk_id_t u, tsk_id_t v, tsk_id_t *mrca) +static int +genetic_relatedness_weighted_noncentred_summary_func(tsk_size_t TSK_UNUSED(state_dim), + const double *state, tsk_size_t result_dim, double *result, void *params) +{ + indexed_weight_stat_params_t args = *(indexed_weight_stat_params_t *) params; + const double *x = state; + tsk_id_t i, j; + tsk_size_t k; + + for (k = 0; k < result_dim; k++) { + i = args.index_tuples[2 * k]; + j = args.index_tuples[2 * k + 1]; + result[k] = x[i] * x[j]; + } + return 0; +} + +int +tsk_treeseq_genetic_relatedness_weighted(const tsk_treeseq_t *self, + tsk_size_t num_weights, const double *weights, tsk_size_t num_index_tuples, + const tsk_id_t *index_tuples, tsk_size_t num_windows, const double *windows, + double *result, tsk_flags_t options) { int ret = 0; - double tu, tv; - const tsk_id_t *restrict parent = self->parent; - const double *restrict time = self->tree_sequence->tables->nodes.time; + tsk_size_t num_samples = self->num_samples; + size_t j, k; + indexed_weight_stat_params_t args; + const double *row; + double *new_row; + double *total_weights = tsk_calloc((num_weights + 1), sizeof(*total_weights)); + double *new_weights + = tsk_malloc((num_weights + 1) * num_samples * sizeof(*new_weights)); - ret = tsk_tree_check_node(self, u); - if (ret != 0) { + if (total_weights == NULL || new_weights == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } - ret = tsk_tree_check_node(self, v); - if (ret != 0) { + if (num_weights == 0) { + ret = 
tsk_trace_error(TSK_ERR_INSUFFICIENT_WEIGHTS); goto out; } - /* Simplest to make the virtual_root a special case here to avoid - * doing the time lookup. */ - if (u == self->virtual_root || v == self->virtual_root) { - *mrca = self->virtual_root; - return 0; + // Add a column of ones to W + for (j = 0; j < num_samples; j++) { + row = GET_2D_ROW(weights, num_weights, j); + new_row = GET_2D_ROW(new_weights, num_weights + 1, j); + for (k = 0; k < num_weights; k++) { + new_row[k] = row[k]; + total_weights[k] += row[k]; + } + new_row[num_weights] = 1.0 / (double) num_samples; } + total_weights[num_weights] = 1.0; - tu = time[u]; - tv = time[v]; - while (u != v) { - if (tu < tv) { - u = parent[u]; - if (u == TSK_NULL) { - break; - } - tu = time[u]; - } else { - v = parent[v]; - if (v == TSK_NULL) { - break; - } - tv = time[v]; + args.total_weights = total_weights; + args.index_tuples = index_tuples; + if (!(options & TSK_STAT_NONCENTRED)) { + ret = tsk_treeseq_general_stat(self, num_weights + 1, new_weights, + num_index_tuples, genetic_relatedness_weighted_summary_func, &args, + num_windows, windows, options, result); + if (ret != 0) { + goto out; + } + } else { + ret = tsk_treeseq_general_stat(self, num_weights + 1, new_weights, + num_index_tuples, genetic_relatedness_weighted_noncentred_summary_func, + &args, num_windows, windows, options, result); + if (ret != 0) { + goto out; } } - *mrca = u == v ? 
u : TSK_NULL; + out: + tsk_safe_free(total_weights); + tsk_safe_free(new_weights); return ret; } static int -tsk_tree_get_num_samples_by_traversal( - const tsk_tree_t *self, tsk_id_t u, tsk_size_t *num_samples) +Y2_summary_func(tsk_size_t TSK_UNUSED(state_dim), const double *state, + tsk_size_t result_dim, double *result, void *params) { - int ret = 0; - tsk_size_t num_nodes, j; - tsk_size_t count = 0; - const tsk_flags_t *restrict flags = self->tree_sequence->tables->nodes.flags; - tsk_id_t *nodes = tsk_malloc(tsk_tree_get_size_bound(self) * sizeof(*nodes)); - tsk_id_t v; + sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; + const double *x = state; + double ni, nj, denom; + tsk_id_t i, j; + tsk_size_t k; - if (nodes == NULL) { - ret = TSK_ERR_NO_MEMORY; - goto out; + for (k = 0; k < result_dim; k++) { + i = args.set_indexes[2 * k]; + j = args.set_indexes[2 * k + 1]; + ni = (double) args.sample_set_sizes[i]; + nj = (double) args.sample_set_sizes[j]; + denom = ni * nj * (nj - 1); + result[k] = x[i] * (nj - x[j]) * (nj - x[j] - 1) / denom; } - ret = tsk_tree_preorder_from(self, u, nodes, &num_nodes); + return 0; +} + +int +tsk_treeseq_Y2(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, + tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_windows, + const double *windows, tsk_flags_t options, double *result) +{ + int ret = 0; + ret = check_sample_stat_inputs(num_sample_sets, 2, num_index_tuples, index_tuples); if (ret != 0) { goto out; } - for (j = 0; j < num_nodes; j++) { - v = nodes[j]; - if (flags[v] & TSK_NODE_IS_SAMPLE) { - count++; - } - } - *num_samples = count; + ret = tsk_treeseq_sample_count_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_index_tuples, index_tuples, Y2_summary_func, num_windows, + windows, options, result); out: - tsk_safe_free(nodes); return ret; } -int TSK_WARN_UNUSED -tsk_tree_get_num_samples(const 
tsk_tree_t *self, tsk_id_t u, tsk_size_t *num_samples) +static int +f2_summary_func(tsk_size_t TSK_UNUSED(state_dim), const double *state, + tsk_size_t result_dim, double *result, void *params) { - int ret = 0; - - ret = tsk_tree_check_node(self, u); - if (ret != 0) { - goto out; - } + sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; + const double *x = state; + double ni, nj, denom, numer; + tsk_id_t i, j; + tsk_size_t k; - if (!(self->options & TSK_NO_SAMPLE_COUNTS)) { - *num_samples = (tsk_size_t) self->num_samples[u]; - } else { - ret = tsk_tree_get_num_samples_by_traversal(self, u, num_samples); + for (k = 0; k < result_dim; k++) { + i = args.set_indexes[2 * k]; + j = args.set_indexes[2 * k + 1]; + ni = (double) args.sample_set_sizes[i]; + nj = (double) args.sample_set_sizes[j]; + denom = ni * (ni - 1) * nj * (nj - 1); + numer = x[i] * (x[i] - 1) * (nj - x[j]) * (nj - x[j] - 1) + - x[i] * (ni - x[i]) * (nj - x[j]) * x[j]; + result[k] = numer / denom; } -out: - return ret; + return 0; } -int TSK_WARN_UNUSED -tsk_tree_get_num_tracked_samples( - const tsk_tree_t *self, tsk_id_t u, tsk_size_t *num_tracked_samples) +int +tsk_treeseq_f2(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, + tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_windows, + const double *windows, tsk_flags_t options, double *result) { int ret = 0; - - ret = tsk_tree_check_node(self, u); + ret = check_sample_stat_inputs(num_sample_sets, 2, num_index_tuples, index_tuples); if (ret != 0) { goto out; } - if (self->options & TSK_NO_SAMPLE_COUNTS) { - ret = TSK_ERR_UNSUPPORTED_OPERATION; - goto out; - } - *num_tracked_samples = self->num_tracked_samples[u]; + ret = tsk_treeseq_sample_count_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_index_tuples, index_tuples, f2_summary_func, num_windows, + windows, options, result); out: return ret; } -bool 
-tsk_tree_is_sample(const tsk_tree_t *self, tsk_id_t u) +static int +D2_ij_summary_func(tsk_size_t TSK_UNUSED(state_dim), const double *state, + tsk_size_t result_dim, double *result, void *params) { - return tsk_treeseq_is_sample(self->tree_sequence, u); -} + sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; + const double *state_row; + double n; + tsk_size_t k; + tsk_id_t i, j; + double p_A, p_B, p_AB, p_Ab, p_aB, D_i, D_j; -tsk_id_t -tsk_tree_get_left_root(const tsk_tree_t *self) -{ - return self->left_child[self->virtual_root]; -} + for (k = 0; k < result_dim; k++) { + i = args.set_indexes[2 * k]; + j = args.set_indexes[2 * k + 1]; -tsk_id_t -tsk_tree_get_right_root(const tsk_tree_t *self) -{ - return self->right_child[self->virtual_root]; -} + n = (double) args.sample_set_sizes[i]; + state_row = GET_2D_ROW(state, 3, i); + p_AB = state_row[0] / n; + p_Ab = state_row[1] / n; + p_aB = state_row[2] / n; + p_A = p_AB + p_Ab; + p_B = p_AB + p_aB; + D_i = p_AB - (p_A * p_B); -tsk_size_t -tsk_tree_get_num_roots(const tsk_tree_t *self) -{ - return (tsk_size_t) self->num_children[self->virtual_root]; + n = (double) args.sample_set_sizes[j]; + state_row = GET_2D_ROW(state, 3, j); + p_AB = state_row[0] / n; + p_Ab = state_row[1] / n; + p_aB = state_row[2] / n; + p_A = p_AB + p_Ab; + p_B = p_AB + p_aB; + D_j = p_AB - (p_A * p_B); + + result[k] = D_i * D_j; + } + + return 0; } -int TSK_WARN_UNUSED -tsk_tree_get_parent(const tsk_tree_t *self, tsk_id_t u, tsk_id_t *parent) +int +tsk_treeseq_D2_ij(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, + tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_rows, + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result) { int ret = 0; - - ret = tsk_tree_check_node(self, u); + ret = 
check_sample_stat_inputs(num_sample_sets, 2, num_index_tuples, index_tuples); if (ret != 0) { goto out; } - *parent = self->parent[u]; + ret = tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_index_tuples, index_tuples, D2_ij_summary_func, + norm_total_weighted, num_rows, row_sites, row_positions, num_cols, col_sites, + col_positions, options, result); out: return ret; } -int TSK_WARN_UNUSED -tsk_tree_get_time(const tsk_tree_t *self, tsk_id_t u, double *t) +static int +D2_ij_unbiased_summary_func(tsk_size_t TSK_UNUSED(state_dim), const double *state, + tsk_size_t result_dim, double *result, void *params) { - int ret = 0; - tsk_node_t node; + sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; + const double *state_row; + tsk_size_t k; + tsk_id_t i, j; + double n_i, n_j; + double w_AB_i, w_Ab_i, w_aB_i, w_ab_i; + double w_AB_j, w_Ab_j, w_aB_j, w_ab_j; - if (u == self->virtual_root) { - *t = INFINITY; - } else { - ret = tsk_treeseq_get_node(self->tree_sequence, u, &node); - if (ret != 0) { - goto out; + for (k = 0; k < result_dim; k++) { + i = args.set_indexes[2 * k]; + j = args.set_indexes[2 * k + 1]; + if (i == j) { + // We require disjoint sample sets because we test equality here + n_i = (double) args.sample_set_sizes[i]; + state_row = GET_2D_ROW(state, 3, i); + w_AB_i = state_row[0]; + w_Ab_i = state_row[1]; + w_aB_i = state_row[2]; + w_ab_i = n_i - (w_AB_i + w_Ab_i + w_aB_i); + result[k] = (w_AB_i * (w_AB_i - 1) * w_ab_i * (w_ab_i - 1) + + w_Ab_i * (w_Ab_i - 1) * w_aB_i * (w_aB_i - 1) + - 2 * w_AB_i * w_Ab_i * w_aB_i * w_ab_i) + / n_i / (n_i - 1) / (n_i - 2) / (n_i - 3); + } + + else { + n_i = (double) args.sample_set_sizes[i]; + state_row = GET_2D_ROW(state, 3, i); + w_AB_i = state_row[0]; + w_Ab_i = state_row[1]; + w_aB_i = state_row[2]; + w_ab_i = n_i - (w_AB_i + w_Ab_i + w_aB_i); + + n_j = (double) args.sample_set_sizes[j]; + state_row = GET_2D_ROW(state, 3, j); + w_AB_j = state_row[0]; + 
w_Ab_j = state_row[1]; + w_aB_j = state_row[2]; + w_ab_j = n_j - (w_AB_j + w_Ab_j + w_aB_j); + + result[k] = (w_Ab_i * w_aB_i - w_AB_i * w_ab_i) + * (w_Ab_j * w_aB_j - w_AB_j * w_ab_j) / n_i / (n_i - 1) / n_j + / (n_j - 1); } - *t = node.time; } -out: - return ret; -} - -static inline double -tsk_tree_get_branch_length_unsafe(const tsk_tree_t *self, tsk_id_t u) -{ - const double *times = self->tree_sequence->tables->nodes.time; - const tsk_id_t parent = self->parent[u]; - return parent == TSK_NULL ? 0 : times[parent] - times[u]; + return 0; } -int TSK_WARN_UNUSED -tsk_tree_get_branch_length(const tsk_tree_t *self, tsk_id_t u, double *ret_branch_length) +int +tsk_treeseq_D2_ij_unbiased(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, + tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_rows, + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result) { int ret = 0; - - ret = tsk_tree_check_node(self, u); + ret = check_sample_stat_inputs(num_sample_sets, 2, num_index_tuples, index_tuples); if (ret != 0) { goto out; } - *ret_branch_length = tsk_tree_get_branch_length_unsafe(self, u); + ret = tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_index_tuples, index_tuples, D2_ij_unbiased_summary_func, + norm_total_weighted, num_rows, row_sites, row_positions, num_cols, col_sites, + col_positions, options, result); out: return ret; } -int -tsk_tree_get_total_branch_length(const tsk_tree_t *self, tsk_id_t node, double *ret_tbl) +static int +r2_ij_summary_func(tsk_size_t TSK_UNUSED(state_dim), const double *state, + tsk_size_t result_dim, double *result, void *params) { - int ret = 0; - tsk_size_t j, num_nodes; - tsk_id_t u, v; - const tsk_id_t *restrict parent = self->parent; - const double *restrict time = 
self->tree_sequence->tables->nodes.time; - tsk_id_t *nodes = tsk_malloc(tsk_tree_get_size_bound(self) * sizeof(*nodes)); - double sum = 0; + sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; + const double *state_row; + tsk_size_t k; + tsk_id_t i, j; + double n, pAB, pAb, paB, pA, pB, D_i, D_j, denom_i, denom_j; - if (nodes == NULL) { - ret = TSK_ERR_NO_MEMORY; - goto out; - } - ret = tsk_tree_preorder_from(self, node, nodes, &num_nodes); - if (ret != 0) { - goto out; - } - /* We always skip the first node because we don't return the branch length - * over the input node. */ - for (j = 1; j < num_nodes; j++) { - u = nodes[j]; - v = parent[u]; - if (v != TSK_NULL) { - sum += time[v] - time[u]; - } - } - *ret_tbl = sum; -out: - tsk_safe_free(nodes); - return ret; -} + for (k = 0; k < result_dim; k++) { + i = args.set_indexes[2 * k]; + j = args.set_indexes[2 * k + 1]; -int TSK_WARN_UNUSED -tsk_tree_get_sites( - const tsk_tree_t *self, const tsk_site_t **sites, tsk_size_t *sites_length) -{ - *sites = self->sites; - *sites_length = self->sites_length; - return 0; -} + n = (double) args.sample_set_sizes[i]; + state_row = GET_2D_ROW(state, 3, i); + pAB = state_row[0] / n; + pAb = state_row[1] / n; + paB = state_row[2] / n; + pA = pAB + pAb; + pB = pAB + paB; + D_i = pAB - (pA * pB); + denom_i = sqrt(pA * (1 - pA) * pB * (1 - pB)); -/* u must be a valid node in the tree. 
For internal use */ -static int -tsk_tree_get_depth_unsafe(const tsk_tree_t *self, tsk_id_t u) -{ - tsk_id_t v; - const tsk_id_t *restrict parent = self->parent; - int depth = 0; + n = (double) args.sample_set_sizes[j]; + state_row = GET_2D_ROW(state, 3, j); + pAB = state_row[0] / n; + pAb = state_row[1] / n; + paB = state_row[2] / n; + pA = pAB + pAb; + pB = pAB + paB; + D_j = pAB - (pA * pB); + denom_j = sqrt(pA * (1 - pA) * pB * (1 - pB)); - if (u == self->virtual_root) { - return -1; - } - for (v = parent[u]; v != TSK_NULL; v = parent[v]) { - depth++; + result[k] = (D_i * D_j) / (denom_i * denom_j); } - return depth; + return 0; } -int TSK_WARN_UNUSED -tsk_tree_get_depth(const tsk_tree_t *self, tsk_id_t u, int *depth_ret) +int +tsk_treeseq_r2_ij(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, + tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_rows, + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result) { int ret = 0; - - ret = tsk_tree_check_node(self, u); + ret = check_sample_stat_inputs(num_sample_sets, 2, num_index_tuples, index_tuples); if (ret != 0) { goto out; } - - *depth_ret = tsk_tree_get_depth_unsafe(self, u); + ret = tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_index_tuples, index_tuples, r2_ij_summary_func, + norm_hap_weighted_ij, num_rows, row_sites, row_positions, num_cols, col_sites, + col_positions, options, result); out: return ret; } -static tsk_id_t -tsk_tree_node_root(tsk_tree_t *self, tsk_id_t u) -{ - tsk_id_t v = u; - while (self->parent[v] != TSK_NULL) { - v = self->parent[v]; - } +/*********************************** + * Three way stats + ***********************************/ + +static int +Y3_summary_func(tsk_size_t TSK_UNUSED(state_dim), const double *state, + tsk_size_t 
result_dim, double *result, void *params) +{ + sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; + const double *x = state; + double ni, nj, nk, denom, numer; + tsk_id_t i, j, k; + tsk_size_t tuple_index; + + for (tuple_index = 0; tuple_index < result_dim; tuple_index++) { + i = args.set_indexes[3 * tuple_index]; + j = args.set_indexes[3 * tuple_index + 1]; + k = args.set_indexes[3 * tuple_index + 2]; + ni = (double) args.sample_set_sizes[i]; + nj = (double) args.sample_set_sizes[j]; + nk = (double) args.sample_set_sizes[k]; + denom = ni * nj * nk; + numer = x[i] * (nj - x[j]) * (nk - x[k]); + result[tuple_index] = numer / denom; + } + return 0; +} + +int +tsk_treeseq_Y3(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, + tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_windows, + const double *windows, tsk_flags_t options, double *result) +{ + int ret = 0; + ret = check_sample_stat_inputs(num_sample_sets, 3, num_index_tuples, index_tuples); + if (ret != 0) { + goto out; + } + ret = tsk_treeseq_sample_count_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_index_tuples, index_tuples, Y3_summary_func, num_windows, + windows, options, result); +out: + return ret; +} + +static int +f3_summary_func(tsk_size_t TSK_UNUSED(state_dim), const double *state, + tsk_size_t result_dim, double *result, void *params) +{ + sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; + const double *x = state; + double ni, nj, nk, denom, numer; + tsk_id_t i, j, k; + tsk_size_t tuple_index; + + for (tuple_index = 0; tuple_index < result_dim; tuple_index++) { + i = args.set_indexes[3 * tuple_index]; + j = args.set_indexes[3 * tuple_index + 1]; + k = args.set_indexes[3 * tuple_index + 2]; + ni = (double) args.sample_set_sizes[i]; + nj = (double) args.sample_set_sizes[j]; + nk = (double) args.sample_set_sizes[k]; + denom = ni * (ni - 1) * 
nj * nk; + numer = x[i] * (x[i] - 1) * (nj - x[j]) * (nk - x[k]) + - x[i] * (ni - x[i]) * (nj - x[j]) * x[k]; + result[tuple_index] = numer / denom; + } + return 0; +} + +int +tsk_treeseq_f3(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, + tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_windows, + const double *windows, tsk_flags_t options, double *result) +{ + int ret = 0; + ret = check_sample_stat_inputs(num_sample_sets, 3, num_index_tuples, index_tuples); + if (ret != 0) { + goto out; + } + ret = tsk_treeseq_sample_count_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_index_tuples, index_tuples, f3_summary_func, num_windows, + windows, options, result); +out: + return ret; +} + +/*********************************** + * Four way stats + ***********************************/ + +static int +f4_summary_func(tsk_size_t TSK_UNUSED(state_dim), const double *state, + tsk_size_t result_dim, double *result, void *params) +{ + sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; + const double *x = state; + double ni, nj, nk, nl, denom, numer; + tsk_id_t i, j, k, l; + tsk_size_t tuple_index; + + for (tuple_index = 0; tuple_index < result_dim; tuple_index++) { + i = args.set_indexes[4 * tuple_index]; + j = args.set_indexes[4 * tuple_index + 1]; + k = args.set_indexes[4 * tuple_index + 2]; + l = args.set_indexes[4 * tuple_index + 3]; + ni = (double) args.sample_set_sizes[i]; + nj = (double) args.sample_set_sizes[j]; + nk = (double) args.sample_set_sizes[k]; + nl = (double) args.sample_set_sizes[l]; + denom = ni * nj * nk * nl; + numer = x[i] * x[k] * (nj - x[j]) * (nl - x[l]) + - x[i] * x[l] * (nj - x[j]) * (nk - x[k]); + result[tuple_index] = numer / denom; + } + return 0; +} + +int +tsk_treeseq_f4(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, + tsk_size_t 
num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_windows, + const double *windows, tsk_flags_t options, double *result) +{ + int ret = 0; + ret = check_sample_stat_inputs(num_sample_sets, 4, num_index_tuples, index_tuples); + if (ret != 0) { + goto out; + } + ret = tsk_treeseq_sample_count_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_index_tuples, index_tuples, f4_summary_func, num_windows, + windows, options, result); +out: + return ret; +} + +/* Error-raising getter functions */ + +int TSK_WARN_UNUSED +tsk_treeseq_get_node(const tsk_treeseq_t *self, tsk_id_t index, tsk_node_t *node) +{ + return tsk_node_table_get_row(&self->tables->nodes, index, node); +} + +int TSK_WARN_UNUSED +tsk_treeseq_get_edge(const tsk_treeseq_t *self, tsk_id_t index, tsk_edge_t *edge) +{ + return tsk_edge_table_get_row(&self->tables->edges, index, edge); +} + +int TSK_WARN_UNUSED +tsk_treeseq_get_migration( + const tsk_treeseq_t *self, tsk_id_t index, tsk_migration_t *migration) +{ + return tsk_migration_table_get_row(&self->tables->migrations, index, migration); +} + +int TSK_WARN_UNUSED +tsk_treeseq_get_mutation( + const tsk_treeseq_t *self, tsk_id_t index, tsk_mutation_t *mutation) +{ + int ret = 0; + + ret = tsk_mutation_table_get_row(&self->tables->mutations, index, mutation); + if (ret != 0) { + goto out; + } + mutation->edge = self->site_mutations_mem[index].edge; + mutation->inherited_state = self->site_mutations_mem[index].inherited_state; + mutation->inherited_state_length + = self->site_mutations_mem[index].inherited_state_length; +out: + return ret; +} + +int TSK_WARN_UNUSED +tsk_treeseq_get_site(const tsk_treeseq_t *self, tsk_id_t index, tsk_site_t *site) +{ + int ret = 0; + + ret = tsk_site_table_get_row(&self->tables->sites, index, site); + if (ret != 0) { + goto out; + } + site->mutations = self->site_mutations[index]; + site->mutations_length = self->site_mutations_length[index]; +out: + return ret; +} + +int TSK_WARN_UNUSED 
+tsk_treeseq_get_individual( + const tsk_treeseq_t *self, tsk_id_t index, tsk_individual_t *individual) +{ + int ret = 0; + + ret = tsk_individual_table_get_row(&self->tables->individuals, index, individual); + if (ret != 0) { + goto out; + } + individual->nodes = self->individual_nodes[index]; + individual->nodes_length = self->individual_nodes_length[index]; +out: + return ret; +} + +int TSK_WARN_UNUSED +tsk_treeseq_get_population( + const tsk_treeseq_t *self, tsk_id_t index, tsk_population_t *population) +{ + return tsk_population_table_get_row(&self->tables->populations, index, population); +} + +int TSK_WARN_UNUSED +tsk_treeseq_get_provenance( + const tsk_treeseq_t *self, tsk_id_t index, tsk_provenance_t *provenance) +{ + return tsk_provenance_table_get_row(&self->tables->provenances, index, provenance); +} + +int TSK_WARN_UNUSED +tsk_treeseq_simplify(const tsk_treeseq_t *self, const tsk_id_t *samples, + tsk_size_t num_samples, tsk_flags_t options, tsk_treeseq_t *output, + tsk_id_t *node_map) +{ + int ret = 0; + tsk_table_collection_t *tables = tsk_malloc(sizeof(*tables)); + + if (tables == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; + } + ret = tsk_treeseq_copy_tables(self, tables, 0); + if (ret != 0) { + goto out; + } + ret = tsk_table_collection_simplify(tables, samples, num_samples, options, node_map); + if (ret != 0) { + goto out; + } + ret = tsk_treeseq_init( + output, tables, TSK_TS_INIT_BUILD_INDEXES | TSK_TAKE_OWNERSHIP); + /* Once tsk_treeseq_init has returned ownership of tables is transferred */ + tables = NULL; +out: + if (tables != NULL) { + tsk_table_collection_free(tables); + tsk_safe_free(tables); + } + return ret; +} + +int TSK_WARN_UNUSED +tsk_treeseq_split_edges(const tsk_treeseq_t *self, double time, tsk_flags_t flags, + tsk_id_t population, const char *metadata, tsk_size_t metadata_length, + tsk_flags_t TSK_UNUSED(options), tsk_treeseq_t *output) +{ + int ret = 0; + tsk_table_collection_t *tables = 
tsk_malloc(sizeof(*tables)); + const double *restrict node_time = self->tables->nodes.time; + const tsk_size_t num_edges = self->tables->edges.num_rows; + const tsk_size_t num_mutations = self->tables->mutations.num_rows; + tsk_id_t *split_edge = tsk_malloc(num_edges * sizeof(*split_edge)); + tsk_id_t j, u, mapped_node, ret_id; + double mutation_time; + tsk_edge_t edge; + tsk_mutation_t mutation; + tsk_bookmark_t sort_start; + + memset(output, 0, sizeof(*output)); + if (split_edge == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; + } + ret = tsk_treeseq_copy_tables(self, tables, 0); + if (ret != 0) { + goto out; + } + if (tables->migrations.num_rows > 0) { + ret = tsk_trace_error(TSK_ERR_MIGRATIONS_NOT_SUPPORTED); + goto out; + } + /* We could catch this below in add_row, but it's simpler to guarantee + * that we always catch the error in corner cases where the values + * aren't used. */ + if (population < -1 || population >= (tsk_id_t) self->tables->populations.num_rows) { + ret = tsk_trace_error(TSK_ERR_POPULATION_OUT_OF_BOUNDS); + goto out; + } + if (!tsk_isfinite(time)) { + ret = tsk_trace_error(TSK_ERR_TIME_NONFINITE); + goto out; + } + + tsk_edge_table_clear(&tables->edges); + tsk_memset(split_edge, TSK_NULL, num_edges * sizeof(*split_edge)); + + for (j = 0; j < (tsk_id_t) num_edges; j++) { + /* Would prefer to use tsk_edge_table_get_row_unsafe, but it's + * currently static to tables.c */ + ret = tsk_edge_table_get_row(&self->tables->edges, j, &edge); + tsk_bug_assert(ret == 0); + if (node_time[edge.child] < time && time < node_time[edge.parent]) { + u = tsk_node_table_add_row(&tables->nodes, flags, time, population, TSK_NULL, + metadata, metadata_length); + if (u < 0) { + ret = (int) u; + goto out; + } + ret_id = tsk_edge_table_add_row(&tables->edges, edge.left, edge.right, u, + edge.child, edge.metadata, edge.metadata_length); + if (ret_id < 0) { + ret = (int) ret_id; + goto out; + } + edge.child = u; + split_edge[j] = u; + } + ret_id = 
tsk_edge_table_add_row(&tables->edges, edge.left, edge.right, + edge.parent, edge.child, edge.metadata, edge.metadata_length); + if (ret_id < 0) { + ret = (int) ret_id; + goto out; + } + } + + for (j = 0; j < (tsk_id_t) num_mutations; j++) { + /* Note: we could speed this up a bit by accessing the local + * memory for mutations directly. */ + ret = tsk_treeseq_get_mutation(self, j, &mutation); + tsk_bug_assert(ret == 0); + mapped_node = TSK_NULL; + if (mutation.edge != TSK_NULL) { + mapped_node = split_edge[mutation.edge]; + } + mutation_time = tsk_is_unknown_time(mutation.time) ? node_time[mutation.node] + : mutation.time; + if (mapped_node != TSK_NULL && mutation_time >= time) { + /* Update the column in-place to save a bit of time. */ + tables->mutations.node[j] = mapped_node; + } + } + + /* Skip mutations and sites as they haven't been altered */ + /* Note we can probably optimise the edge sort a bit here also by + * reasoning about when the first edge gets altered in the table. + */ + memset(&sort_start, 0, sizeof(sort_start)); + sort_start.sites = tables->sites.num_rows; + sort_start.mutations = tables->mutations.num_rows; + ret = tsk_table_collection_sort(tables, &sort_start, 0); + if (ret != 0) { + goto out; + } + + ret = tsk_treeseq_init( + output, tables, TSK_TS_INIT_BUILD_INDEXES | TSK_TAKE_OWNERSHIP); + tables = NULL; +out: + if (tables != NULL) { + tsk_table_collection_free(tables); + tsk_safe_free(tables); + } + tsk_safe_free(split_edge); + return ret; +} + +/* ======================================================== * + * tree_position + * ======================================================== */ + +static void +tsk_tree_position_set_null(tsk_tree_position_t *self) +{ + self->index = -1; + self->interval.left = 0; + self->interval.right = 0; +} + +int +tsk_tree_position_init(tsk_tree_position_t *self, const tsk_treeseq_t *tree_sequence, + tsk_flags_t TSK_UNUSED(options)) +{ + memset(self, 0, sizeof(*self)); + self->tree_sequence = tree_sequence; + 
tsk_tree_position_set_null(self); + return 0; +} + +int +tsk_tree_position_free(tsk_tree_position_t *TSK_UNUSED(self)) +{ + return 0; +} + +int +tsk_tree_position_print_state(const tsk_tree_position_t *self, FILE *out) +{ + fprintf(out, "Tree position state\n"); + fprintf(out, "index = %d\n", (int) self->index); + fprintf(out, "interval = [%f,\t%f)\n", self->interval.left, self->interval.right); + fprintf( + out, "out = start=%d\tstop=%d\n", (int) self->out.start, (int) self->out.stop); + fprintf( + out, "in = start=%d\tstop=%d\n", (int) self->in.start, (int) self->in.stop); + return 0; +} + +bool +tsk_tree_position_next(tsk_tree_position_t *self) +{ + const tsk_table_collection_t *tables = self->tree_sequence->tables; + const tsk_id_t M = (tsk_id_t) tables->edges.num_rows; + const tsk_id_t num_trees = (tsk_id_t) self->tree_sequence->num_trees; + const double *restrict left_coords = tables->edges.left; + const tsk_id_t *restrict left_order = tables->indexes.edge_insertion_order; + const double *restrict right_coords = tables->edges.right; + const tsk_id_t *restrict right_order = tables->indexes.edge_removal_order; + const double *restrict breakpoints = self->tree_sequence->breakpoints; + tsk_id_t j, left_current_index, right_current_index; + double left; + + if (self->index == -1) { + self->interval.right = 0; + self->in.stop = 0; + self->out.stop = 0; + self->direction = TSK_DIR_FORWARD; + } + + if (self->direction == TSK_DIR_FORWARD) { + left_current_index = self->in.stop; + right_current_index = self->out.stop; + } else { + left_current_index = self->out.stop + 1; + right_current_index = self->in.stop + 1; + } + + left = self->interval.right; + + j = right_current_index; + self->out.start = j; + while (j < M && right_coords[right_order[j]] == left) { + j++; + } + self->out.stop = j; + self->out.order = right_order; + + j = left_current_index; + self->in.start = j; + while (j < M && left_coords[left_order[j]] == left) { + j++; + } + self->in.stop = j; + 
self->in.order = left_order; + + self->direction = TSK_DIR_FORWARD; + self->index++; + if (self->index == num_trees) { + tsk_tree_position_set_null(self); + } else { + self->interval.left = left; + self->interval.right = breakpoints[self->index + 1]; + } + return self->index != -1; +} + +bool +tsk_tree_position_prev(tsk_tree_position_t *self) +{ + const tsk_table_collection_t *tables = self->tree_sequence->tables; + const tsk_id_t M = (tsk_id_t) tables->edges.num_rows; + const double sequence_length = tables->sequence_length; + const tsk_id_t num_trees = (tsk_id_t) self->tree_sequence->num_trees; + const double *restrict left_coords = tables->edges.left; + const tsk_id_t *restrict left_order = tables->indexes.edge_insertion_order; + const double *restrict right_coords = tables->edges.right; + const tsk_id_t *restrict right_order = tables->indexes.edge_removal_order; + const double *restrict breakpoints = self->tree_sequence->breakpoints; + tsk_id_t j, left_current_index, right_current_index; + double right; + + if (self->index == -1) { + self->index = num_trees; + self->interval.left = sequence_length; + self->in.stop = M - 1; + self->out.stop = M - 1; + self->direction = TSK_DIR_REVERSE; + } + + if (self->direction == TSK_DIR_REVERSE) { + left_current_index = self->out.stop; + right_current_index = self->in.stop; + } else { + left_current_index = self->in.stop - 1; + right_current_index = self->out.stop - 1; + } + + right = self->interval.left; + + j = left_current_index; + self->out.start = j; + while (j >= 0 && left_coords[left_order[j]] == right) { + j--; + } + self->out.stop = j; + self->out.order = left_order; + + j = right_current_index; + self->in.start = j; + while (j >= 0 && right_coords[right_order[j]] == right) { + j--; + } + self->in.stop = j; + self->in.order = right_order; + + self->index--; + self->direction = TSK_DIR_REVERSE; + if (self->index == -1) { + tsk_tree_position_set_null(self); + } else { + self->interval.left = breakpoints[self->index]; 
+ self->interval.right = right; + } + return self->index != -1; +} + +int TSK_WARN_UNUSED +tsk_tree_position_seek_forward(tsk_tree_position_t *self, tsk_id_t index) +{ + int ret = 0; + const tsk_table_collection_t *tables = self->tree_sequence->tables; + const tsk_id_t M = (tsk_id_t) tables->edges.num_rows; + const tsk_id_t num_trees = (tsk_id_t) self->tree_sequence->num_trees; + const double *restrict left_coords = tables->edges.left; + const tsk_id_t *restrict left_order = tables->indexes.edge_insertion_order; + const double *restrict right_coords = tables->edges.right; + const tsk_id_t *restrict right_order = tables->indexes.edge_removal_order; + const double *restrict breakpoints = self->tree_sequence->breakpoints; + tsk_id_t j, left_current_index, right_current_index; + double left; + + tsk_bug_assert(index >= self->index && index < num_trees); + + if (self->index == -1) { + self->interval.right = 0; + self->in.stop = 0; + self->out.stop = 0; + self->direction = TSK_DIR_FORWARD; + } + + if (self->direction == TSK_DIR_FORWARD) { + left_current_index = self->in.stop; + right_current_index = self->out.stop; + } else { + left_current_index = self->out.stop + 1; + right_current_index = self->in.stop + 1; + } + + self->direction = TSK_DIR_FORWARD; + left = breakpoints[index]; + + j = right_current_index; + self->out.start = j; + while (j < M && right_coords[right_order[j]] <= left) { + j++; + } + self->out.stop = j; + + if (self->index == -1) { + self->out.start = self->out.stop; + } + + j = left_current_index; + while (j < M && right_coords[left_order[j]] <= left) { + j++; + } + self->in.start = j; + while (j < M && left_coords[left_order[j]] <= left) { + j++; + } + self->in.stop = j; + + self->interval.left = left; + self->interval.right = breakpoints[index + 1]; + self->out.order = right_order; + self->in.order = left_order; + self->index = index; + return ret; +} + +int TSK_WARN_UNUSED +tsk_tree_position_seek_backward(tsk_tree_position_t *self, tsk_id_t index) 
+{ + int ret = 0; + const tsk_table_collection_t *tables = self->tree_sequence->tables; + const tsk_id_t M = (tsk_id_t) tables->edges.num_rows; + const double sequence_length = tables->sequence_length; + const tsk_id_t num_trees = (tsk_id_t) self->tree_sequence->num_trees; + const double *restrict left_coords = tables->edges.left; + const tsk_id_t *restrict left_order = tables->indexes.edge_insertion_order; + const double *restrict right_coords = tables->edges.right; + const tsk_id_t *restrict right_order = tables->indexes.edge_removal_order; + const double *restrict breakpoints = self->tree_sequence->breakpoints; + tsk_id_t j, left_current_index, right_current_index; + double right; + + if (self->index == -1) { + self->index = num_trees; + self->interval.left = sequence_length; + self->in.stop = M - 1; + self->out.stop = M - 1; + self->direction = TSK_DIR_REVERSE; + } + tsk_bug_assert(index <= self->index); + + if (self->direction == TSK_DIR_REVERSE) { + left_current_index = self->out.stop; + right_current_index = self->in.stop; + } else { + left_current_index = self->in.stop - 1; + right_current_index = self->out.stop - 1; + } + + self->direction = TSK_DIR_REVERSE; + right = breakpoints[index + 1]; + + j = left_current_index; + self->out.start = j; + while (j >= 0 && left_coords[left_order[j]] >= right) { + j--; + } + self->out.stop = j; + + if (self->index == num_trees) { + self->out.start = self->out.stop; + } + + j = right_current_index; + while (j >= 0 && left_coords[right_order[j]] >= right) { + j--; + } + self->in.start = j; + while (j >= 0 && right_coords[right_order[j]] >= right) { + j--; + } + self->in.stop = j; + + self->interval.right = right; + self->interval.left = breakpoints[index]; + self->out.order = left_order; + self->in.order = right_order; + self->index = index; + + return ret; +} + +/* ======================================================== * + * Tree + * ======================================================== */ + +/* Return the root for 
the specified node. + * NOTE: no bounds checking is done here. + */ +static tsk_id_t +tsk_tree_get_node_root(const tsk_tree_t *self, tsk_id_t u) +{ + const tsk_id_t *restrict parent = self->parent; + + while (parent[u] != TSK_NULL) { + u = parent[u]; + } + return u; +} + +int TSK_WARN_UNUSED +tsk_tree_init(tsk_tree_t *self, const tsk_treeseq_t *tree_sequence, tsk_flags_t options) +{ + int ret = 0; + tsk_size_t num_samples, num_nodes, N; + + tsk_memset(self, 0, sizeof(tsk_tree_t)); + if (tree_sequence == NULL) { + ret = tsk_trace_error(TSK_ERR_BAD_PARAM_VALUE); + goto out; + } + num_nodes = tree_sequence->tables->nodes.num_rows; + num_samples = tree_sequence->num_samples; + self->num_nodes = num_nodes; + self->virtual_root = (tsk_id_t) num_nodes; + self->tree_sequence = tree_sequence; + self->samples = tree_sequence->samples; + self->options = options; + self->root_threshold = 1; + + /* Allocate space in the quintuply linked tree for the virtual root */ + N = num_nodes + 1; + self->parent = tsk_malloc(N * sizeof(*self->parent)); + self->left_child = tsk_malloc(N * sizeof(*self->left_child)); + self->right_child = tsk_malloc(N * sizeof(*self->right_child)); + self->left_sib = tsk_malloc(N * sizeof(*self->left_sib)); + self->right_sib = tsk_malloc(N * sizeof(*self->right_sib)); + self->num_children = tsk_calloc(N, sizeof(*self->num_children)); + self->edge = tsk_malloc(N * sizeof(*self->edge)); + if (self->parent == NULL || self->left_child == NULL || self->right_child == NULL + || self->left_sib == NULL || self->right_sib == NULL + || self->num_children == NULL || self->edge == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; + } + if (!(self->options & TSK_NO_SAMPLE_COUNTS)) { + self->num_samples = tsk_calloc(N, sizeof(*self->num_samples)); + self->num_tracked_samples = tsk_calloc(N, sizeof(*self->num_tracked_samples)); + if (self->num_samples == NULL || self->num_tracked_samples == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; + } + 
} + if (self->options & TSK_SAMPLE_LISTS) { + self->left_sample = tsk_malloc(N * sizeof(*self->left_sample)); + self->right_sample = tsk_malloc(N * sizeof(*self->right_sample)); + self->next_sample = tsk_malloc(num_samples * sizeof(*self->next_sample)); + if (self->left_sample == NULL || self->right_sample == NULL + || self->next_sample == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; + } + } + + ret = tsk_tree_position_init(&self->tree_pos, tree_sequence, 0); + if (ret != 0) { + goto out; + } + ret = tsk_tree_clear(self); +out: + return ret; +} + +int +tsk_tree_set_root_threshold(tsk_tree_t *self, tsk_size_t root_threshold) +{ + int ret = 0; + + if (root_threshold == 0) { + ret = tsk_trace_error(TSK_ERR_BAD_PARAM_VALUE); + goto out; + } + /* Don't allow the value to be set when the tree is out of the null + * state */ + if (self->index != -1) { + ret = tsk_trace_error(TSK_ERR_UNSUPPORTED_OPERATION); + goto out; + } + self->root_threshold = root_threshold; + /* Reset the roots */ + ret = tsk_tree_clear(self); +out: + return ret; +} + +tsk_size_t +tsk_tree_get_root_threshold(const tsk_tree_t *self) +{ + return self->root_threshold; +} + +int +tsk_tree_free(tsk_tree_t *self) +{ + tsk_safe_free(self->parent); + tsk_safe_free(self->left_child); + tsk_safe_free(self->right_child); + tsk_safe_free(self->left_sib); + tsk_safe_free(self->right_sib); + tsk_safe_free(self->num_samples); + tsk_safe_free(self->num_tracked_samples); + tsk_safe_free(self->left_sample); + tsk_safe_free(self->right_sample); + tsk_safe_free(self->next_sample); + tsk_safe_free(self->num_children); + tsk_safe_free(self->edge); + tsk_tree_position_free(&self->tree_pos); + return 0; +} + +bool +tsk_tree_has_sample_lists(const tsk_tree_t *self) +{ + return !!(self->options & TSK_SAMPLE_LISTS); +} + +bool +tsk_tree_has_sample_counts(const tsk_tree_t *self) +{ + return !(self->options & TSK_NO_SAMPLE_COUNTS); +} + +static int TSK_WARN_UNUSED +tsk_tree_reset_tracked_samples(tsk_tree_t 
*self) +{ + int ret = 0; + + if (!tsk_tree_has_sample_counts(self)) { + ret = tsk_trace_error(TSK_ERR_UNSUPPORTED_OPERATION); + goto out; + } + tsk_memset(self->num_tracked_samples, 0, + (self->num_nodes + 1) * sizeof(*self->num_tracked_samples)); +out: + return ret; +} + +int TSK_WARN_UNUSED +tsk_tree_set_tracked_samples( + tsk_tree_t *self, tsk_size_t num_tracked_samples, const tsk_id_t *tracked_samples) +{ + int ret = TSK_ERR_GENERIC; + tsk_size_t *tree_num_tracked_samples = self->num_tracked_samples; + const tsk_id_t *parent = self->parent; + tsk_size_t j; + tsk_id_t u; + + /* TODO This is not needed when the tree is new. We should use the + * state machine to check and only reset the tracked samples when needed. + */ + ret = tsk_tree_reset_tracked_samples(self); + if (ret != 0) { + goto out; + } + self->num_tracked_samples[self->virtual_root] = num_tracked_samples; + for (j = 0; j < num_tracked_samples; j++) { + u = tracked_samples[j]; + if (u < 0 || u >= (tsk_id_t) self->num_nodes) { + ret = tsk_trace_error(TSK_ERR_NODE_OUT_OF_BOUNDS); + goto out; + } + if (!tsk_treeseq_is_sample(self->tree_sequence, u)) { + ret = tsk_trace_error(TSK_ERR_BAD_SAMPLES); + goto out; + } + if (self->num_tracked_samples[u] != 0) { + ret = tsk_trace_error(TSK_ERR_DUPLICATE_SAMPLE); + goto out; + } + /* Propagate this upwards */ + while (u != TSK_NULL) { + tree_num_tracked_samples[u]++; + u = parent[u]; + } + } +out: + return ret; +} + +int TSK_WARN_UNUSED +tsk_tree_track_descendant_samples(tsk_tree_t *self, tsk_id_t node) +{ + int ret = 0; + tsk_id_t *nodes = tsk_malloc(tsk_tree_get_size_bound(self) * sizeof(*nodes)); + const tsk_id_t *restrict parent = self->parent; + const tsk_id_t *restrict left_child = self->left_child; + const tsk_id_t *restrict right_sib = self->right_sib; + const tsk_flags_t *restrict flags = self->tree_sequence->tables->nodes.flags; + tsk_size_t *num_tracked_samples = self->num_tracked_samples; + tsk_size_t n, j, num_nodes; + tsk_id_t u, v; + + if (nodes == 
NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; + } + ret = tsk_tree_postorder_from(self, node, nodes, &num_nodes); + if (ret != 0) { + goto out; + } + ret = tsk_tree_reset_tracked_samples(self); + if (ret != 0) { + goto out; + } + u = 0; /* keep the compiler happy */ + for (j = 0; j < num_nodes; j++) { + u = nodes[j]; + for (v = left_child[u]; v != TSK_NULL; v = right_sib[v]) { + num_tracked_samples[u] += num_tracked_samples[v]; + } + num_tracked_samples[u] += flags[u] & TSK_NODE_IS_SAMPLE ? 1 : 0; + } + n = num_tracked_samples[u]; + u = parent[u]; + while (u != TSK_NULL) { + num_tracked_samples[u] = n; + u = parent[u]; + } + num_tracked_samples[self->virtual_root] = n; +out: + tsk_safe_free(nodes); + return ret; +} + +int TSK_WARN_UNUSED +tsk_tree_copy(const tsk_tree_t *self, tsk_tree_t *dest, tsk_flags_t options) +{ + int ret = TSK_ERR_GENERIC; + tsk_size_t N = self->num_nodes + 1; + + if (!(options & TSK_NO_INIT)) { + ret = tsk_tree_init(dest, self->tree_sequence, options); + if (ret != 0) { + goto out; + } + } + if (self->tree_sequence != dest->tree_sequence) { + ret = tsk_trace_error(TSK_ERR_BAD_PARAM_VALUE); + goto out; + } + dest->interval = self->interval; + dest->left_index = self->left_index; + dest->right_index = self->right_index; + dest->direction = self->direction; + dest->index = self->index; + dest->sites = self->sites; + dest->sites_length = self->sites_length; + dest->root_threshold = self->root_threshold; + dest->num_edges = self->num_edges; + dest->tree_pos = self->tree_pos; + + tsk_memcpy(dest->parent, self->parent, N * sizeof(*self->parent)); + tsk_memcpy(dest->left_child, self->left_child, N * sizeof(*self->left_child)); + tsk_memcpy(dest->right_child, self->right_child, N * sizeof(*self->right_child)); + tsk_memcpy(dest->left_sib, self->left_sib, N * sizeof(*self->left_sib)); + tsk_memcpy(dest->right_sib, self->right_sib, N * sizeof(*self->right_sib)); + tsk_memcpy(dest->num_children, self->num_children, N * 
sizeof(*self->num_children)); + tsk_memcpy(dest->edge, self->edge, N * sizeof(*self->edge)); + if (!(dest->options & TSK_NO_SAMPLE_COUNTS)) { + if (self->options & TSK_NO_SAMPLE_COUNTS) { + ret = tsk_trace_error(TSK_ERR_UNSUPPORTED_OPERATION); + goto out; + } + tsk_memcpy(dest->num_samples, self->num_samples, N * sizeof(*self->num_samples)); + tsk_memcpy(dest->num_tracked_samples, self->num_tracked_samples, + N * sizeof(*self->num_tracked_samples)); + } + if (dest->options & TSK_SAMPLE_LISTS) { + if (!(self->options & TSK_SAMPLE_LISTS)) { + ret = tsk_trace_error(TSK_ERR_UNSUPPORTED_OPERATION); + goto out; + } + tsk_memcpy(dest->left_sample, self->left_sample, N * sizeof(*self->left_sample)); + tsk_memcpy( + dest->right_sample, self->right_sample, N * sizeof(*self->right_sample)); + tsk_memcpy(dest->next_sample, self->next_sample, + self->tree_sequence->num_samples * sizeof(*self->next_sample)); + } + ret = 0; +out: + return ret; +} + +bool TSK_WARN_UNUSED +tsk_tree_equals(const tsk_tree_t *self, const tsk_tree_t *other) +{ + bool ret = false; + + if (self->tree_sequence == other->tree_sequence) { + ret = self->index == other->index; + } + return ret; +} + +static int +tsk_tree_check_node(const tsk_tree_t *self, tsk_id_t u) +{ + int ret = 0; + if (u < 0 || u > (tsk_id_t) self->num_nodes) { + ret = tsk_trace_error(TSK_ERR_NODE_OUT_OF_BOUNDS); + } + return ret; +} + +bool +tsk_tree_is_descendant(const tsk_tree_t *self, tsk_id_t u, tsk_id_t v) +{ + bool ret = false; + tsk_id_t w = u; + tsk_id_t *restrict parent = self->parent; + + if (tsk_tree_check_node(self, u) == 0 && tsk_tree_check_node(self, v) == 0) { + while (w != v && w != TSK_NULL) { + w = parent[w]; + } + ret = w == v; + } + return ret; +} + +int TSK_WARN_UNUSED +tsk_tree_get_mrca(const tsk_tree_t *self, tsk_id_t u, tsk_id_t v, tsk_id_t *mrca) +{ + int ret = 0; + double tu, tv; + const tsk_id_t *restrict parent = self->parent; + const double *restrict time = self->tree_sequence->tables->nodes.time; + + ret = 
tsk_tree_check_node(self, u); + if (ret != 0) { + goto out; + } + ret = tsk_tree_check_node(self, v); + if (ret != 0) { + goto out; + } + + /* Simplest to make the virtual_root a special case here to avoid + * doing the time lookup. */ + if (u == self->virtual_root || v == self->virtual_root) { + *mrca = self->virtual_root; + return 0; + } + + tu = time[u]; + tv = time[v]; + while (u != v) { + if (tu < tv) { + u = parent[u]; + if (u == TSK_NULL) { + break; + } + tu = time[u]; + } else { + v = parent[v]; + if (v == TSK_NULL) { + break; + } + tv = time[v]; + } + } + *mrca = u == v ? u : TSK_NULL; +out: + return ret; +} + +static int +tsk_tree_get_num_samples_by_traversal( + const tsk_tree_t *self, tsk_id_t u, tsk_size_t *num_samples) +{ + int ret = 0; + tsk_size_t num_nodes, j; + tsk_size_t count = 0; + const tsk_flags_t *restrict flags = self->tree_sequence->tables->nodes.flags; + tsk_id_t *nodes = tsk_malloc(tsk_tree_get_size_bound(self) * sizeof(*nodes)); + tsk_id_t v; + + if (nodes == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; + } + ret = tsk_tree_preorder_from(self, u, nodes, &num_nodes); + if (ret != 0) { + goto out; + } + for (j = 0; j < num_nodes; j++) { + v = nodes[j]; + if (flags[v] & TSK_NODE_IS_SAMPLE) { + count++; + } + } + *num_samples = count; +out: + tsk_safe_free(nodes); + return ret; +} + +int TSK_WARN_UNUSED +tsk_tree_get_num_samples(const tsk_tree_t *self, tsk_id_t u, tsk_size_t *num_samples) +{ + int ret = 0; + + ret = tsk_tree_check_node(self, u); + if (ret != 0) { + goto out; + } + + if (!(self->options & TSK_NO_SAMPLE_COUNTS)) { + *num_samples = (tsk_size_t) self->num_samples[u]; + } else { + ret = tsk_tree_get_num_samples_by_traversal(self, u, num_samples); + } +out: + return ret; +} + +int TSK_WARN_UNUSED +tsk_tree_get_num_tracked_samples( + const tsk_tree_t *self, tsk_id_t u, tsk_size_t *num_tracked_samples) +{ + int ret = 0; + + ret = tsk_tree_check_node(self, u); + if (ret != 0) { + goto out; + } + if (self->options & 
TSK_NO_SAMPLE_COUNTS) { + ret = tsk_trace_error(TSK_ERR_UNSUPPORTED_OPERATION); + goto out; + } + *num_tracked_samples = self->num_tracked_samples[u]; +out: + return ret; +} + +bool +tsk_tree_is_sample(const tsk_tree_t *self, tsk_id_t u) +{ + return tsk_treeseq_is_sample(self->tree_sequence, u); +} + +tsk_id_t +tsk_tree_get_left_root(const tsk_tree_t *self) +{ + return self->left_child[self->virtual_root]; +} + +tsk_id_t +tsk_tree_get_right_root(const tsk_tree_t *self) +{ + return self->right_child[self->virtual_root]; +} + +tsk_size_t +tsk_tree_get_num_roots(const tsk_tree_t *self) +{ + return (tsk_size_t) self->num_children[self->virtual_root]; +} + +int TSK_WARN_UNUSED +tsk_tree_get_parent(const tsk_tree_t *self, tsk_id_t u, tsk_id_t *parent) +{ + int ret = 0; + + ret = tsk_tree_check_node(self, u); + if (ret != 0) { + goto out; + } + *parent = self->parent[u]; +out: + return ret; +} + +int TSK_WARN_UNUSED +tsk_tree_get_time(const tsk_tree_t *self, tsk_id_t u, double *t) +{ + int ret = 0; + tsk_node_t node; + + if (u == self->virtual_root) { + *t = INFINITY; + } else { + ret = tsk_treeseq_get_node(self->tree_sequence, u, &node); + if (ret != 0) { + goto out; + } + *t = node.time; + } +out: + return ret; +} + +static inline double +tsk_tree_get_branch_length_unsafe(const tsk_tree_t *self, tsk_id_t u) +{ + const double *times = self->tree_sequence->tables->nodes.time; + const tsk_id_t parent = self->parent[u]; + + return parent == TSK_NULL ? 
0 : times[parent] - times[u]; +} + +int TSK_WARN_UNUSED +tsk_tree_get_branch_length(const tsk_tree_t *self, tsk_id_t u, double *ret_branch_length) +{ + int ret = 0; + + ret = tsk_tree_check_node(self, u); + if (ret != 0) { + goto out; + } + *ret_branch_length = tsk_tree_get_branch_length_unsafe(self, u); +out: + return ret; +} + +int +tsk_tree_get_total_branch_length(const tsk_tree_t *self, tsk_id_t node, double *ret_tbl) +{ + int ret = 0; + tsk_size_t j, num_nodes; + tsk_id_t u, v; + const tsk_id_t *restrict parent = self->parent; + const double *restrict time = self->tree_sequence->tables->nodes.time; + tsk_id_t *nodes = tsk_malloc(tsk_tree_get_size_bound(self) * sizeof(*nodes)); + double sum = 0; + + if (nodes == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; + } + ret = tsk_tree_preorder_from(self, node, nodes, &num_nodes); + if (ret != 0) { + goto out; + } + /* We always skip the first node because we don't return the branch length + * over the input node. */ + for (j = 1; j < num_nodes; j++) { + u = nodes[j]; + v = parent[u]; + if (v != TSK_NULL) { + sum += time[v] - time[u]; + } + } + *ret_tbl = sum; +out: + tsk_safe_free(nodes); + return ret; +} + +int TSK_WARN_UNUSED +tsk_tree_get_sites( + const tsk_tree_t *self, const tsk_site_t **sites, tsk_size_t *sites_length) +{ + *sites = self->sites; + *sites_length = self->sites_length; + return 0; +} + +/* u must be a valid node in the tree. 
For internal use */ +static int +tsk_tree_get_depth_unsafe(const tsk_tree_t *self, tsk_id_t u) +{ + tsk_id_t v; + const tsk_id_t *restrict parent = self->parent; + int depth = 0; + + if (u == self->virtual_root) { + return -1; + } + for (v = parent[u]; v != TSK_NULL; v = parent[v]) { + depth++; + } + return depth; +} + +int TSK_WARN_UNUSED +tsk_tree_get_depth(const tsk_tree_t *self, tsk_id_t u, int *depth_ret) +{ + int ret = 0; + + ret = tsk_tree_check_node(self, u); + if (ret != 0) { + goto out; + } + + *depth_ret = tsk_tree_get_depth_unsafe(self, u); +out: + return ret; +} + +static tsk_id_t +tsk_tree_node_root(tsk_tree_t *self, tsk_id_t u) +{ + tsk_id_t v = u; + while (self->parent[v] != TSK_NULL) { + v = self->parent[v]; + } + + return v; +} + +static void +tsk_tree_check_state(const tsk_tree_t *self) +{ + tsk_id_t u, v; + tsk_size_t j, num_samples; + int err, c; + tsk_site_t site; + tsk_id_t *children = tsk_malloc(self->num_nodes * sizeof(tsk_id_t)); + bool *is_root = tsk_calloc(self->num_nodes, sizeof(bool)); + + tsk_bug_assert(children != NULL); + + /* Check the virtual root properties */ + tsk_bug_assert(self->parent[self->virtual_root] == TSK_NULL); + tsk_bug_assert(self->left_sib[self->virtual_root] == TSK_NULL); + tsk_bug_assert(self->right_sib[self->virtual_root] == TSK_NULL); + + for (j = 0; j < self->tree_sequence->num_samples; j++) { + u = self->samples[j]; + while (self->parent[u] != TSK_NULL) { + u = self->parent[u]; + } + is_root[u] = true; + } + if (self->tree_sequence->num_samples == 0) { + tsk_bug_assert(self->left_child[self->virtual_root] == TSK_NULL); + } + + /* Iterate over the roots and make sure they are set */ + for (u = tsk_tree_get_left_root(self); u != TSK_NULL; u = self->right_sib[u]) { + tsk_bug_assert(is_root[u]); + is_root[u] = false; + } + for (u = 0; u < (tsk_id_t) self->num_nodes; u++) { + tsk_bug_assert(!is_root[u]); + c = 0; + for (v = self->left_child[u]; v != TSK_NULL; v = self->right_sib[v]) { + 
tsk_bug_assert(self->parent[v] == u); + children[c] = v; + c++; + } + for (v = self->right_child[u]; v != TSK_NULL; v = self->left_sib[v]) { + tsk_bug_assert(c > 0); + c--; + tsk_bug_assert(v == children[c]); + } + } + for (j = 0; j < self->sites_length; j++) { + site = self->sites[j]; + tsk_bug_assert(self->interval.left <= site.position); + tsk_bug_assert(site.position < self->interval.right); + } + + if (!(self->options & TSK_NO_SAMPLE_COUNTS)) { + tsk_bug_assert(self->num_samples != NULL); + tsk_bug_assert(self->num_tracked_samples != NULL); + for (u = 0; u < (tsk_id_t) self->num_nodes; u++) { + err = tsk_tree_get_num_samples_by_traversal(self, u, &num_samples); + tsk_bug_assert(err == 0); + tsk_bug_assert(num_samples == (tsk_size_t) self->num_samples[u]); + } + } else { + tsk_bug_assert(self->num_samples == NULL); + tsk_bug_assert(self->num_tracked_samples == NULL); + } + if (self->options & TSK_SAMPLE_LISTS) { + tsk_bug_assert(self->right_sample != NULL); + tsk_bug_assert(self->left_sample != NULL); + tsk_bug_assert(self->next_sample != NULL); + } else { + tsk_bug_assert(self->right_sample == NULL); + tsk_bug_assert(self->left_sample == NULL); + tsk_bug_assert(self->next_sample == NULL); + } + + free(children); + free(is_root); +} + +void +tsk_tree_print_state(const tsk_tree_t *self, FILE *out) +{ + tsk_size_t j; + tsk_site_t site; + + fprintf(out, "Tree state:\n"); + fprintf(out, "options = %d\n", self->options); + fprintf(out, "root_threshold = %lld\n", (long long) self->root_threshold); + fprintf(out, "left = %f\n", self->interval.left); + fprintf(out, "right = %f\n", self->interval.right); + fprintf(out, "index = %lld\n", (long long) self->index); + fprintf(out, "num_edges = %d\n", (int) self->num_edges); + fprintf(out, "node\tedge\tparent\tlchild\trchild\tlsib\trsib"); + if (self->options & TSK_SAMPLE_LISTS) { + fprintf(out, "\thead\ttail"); + } + fprintf(out, "\n"); + + for (j = 0; j < self->num_nodes + 1; j++) { + fprintf(out, 
"%lld\t%lld\t%lld\t%lld\t%lld\t%lld\t%lld", (long long) j, + (long long) self->edge[j], (long long) self->parent[j], + (long long) self->left_child[j], (long long) self->right_child[j], + (long long) self->left_sib[j], (long long) self->right_sib[j]); + if (self->options & TSK_SAMPLE_LISTS) { + fprintf(out, "\t%lld\t%lld\t", (long long) self->left_sample[j], + (long long) self->right_sample[j]); + } + if (!(self->options & TSK_NO_SAMPLE_COUNTS)) { + fprintf(out, "\t%lld\t%lld", (long long) self->num_samples[j], + (long long) self->num_tracked_samples[j]); + } + fprintf(out, "\n"); + } + fprintf(out, "sites = \n"); + for (j = 0; j < self->sites_length; j++) { + site = self->sites[j]; + fprintf(out, "\t%lld\t%f\n", (long long) site.id, site.position); + } + tsk_tree_check_state(self); +} + +/* Methods for positioning the tree along the sequence */ + +/* The following methods are performance sensitive and so we use a + * lot of restrict pointers. Because we are saying that we don't have + * any aliases to these pointers, we pass around the reference to parent + * since it's used in all the functions. 
*/ +static inline void +tsk_tree_update_sample_lists( + tsk_tree_t *self, tsk_id_t node, const tsk_id_t *restrict parent) +{ + tsk_id_t u, v, sample_index; + tsk_id_t *restrict left_child = self->left_child; + tsk_id_t *restrict right_sib = self->right_sib; + tsk_id_t *restrict left = self->left_sample; + tsk_id_t *restrict right = self->right_sample; + tsk_id_t *restrict next = self->next_sample; + const tsk_id_t *restrict sample_index_map = self->tree_sequence->sample_index_map; + + for (u = node; u != TSK_NULL; u = parent[u]) { + sample_index = sample_index_map[u]; + if (sample_index != TSK_NULL) { + right[u] = left[u]; + } else { + left[u] = TSK_NULL; + right[u] = TSK_NULL; + } + for (v = left_child[u]; v != TSK_NULL; v = right_sib[v]) { + if (left[v] != TSK_NULL) { + tsk_bug_assert(right[v] != TSK_NULL); + if (left[u] == TSK_NULL) { + left[u] = left[v]; + right[u] = right[v]; + } else { + next[right[u]] = left[v]; + right[u] = right[v]; + } + } + } + } +} + +static inline void +tsk_tree_remove_branch( + tsk_tree_t *self, tsk_id_t p, tsk_id_t c, tsk_id_t *restrict parent) +{ + tsk_id_t *restrict left_child = self->left_child; + tsk_id_t *restrict right_child = self->right_child; + tsk_id_t *restrict left_sib = self->left_sib; + tsk_id_t *restrict right_sib = self->right_sib; + tsk_id_t *restrict num_children = self->num_children; + tsk_id_t lsib = left_sib[c]; + tsk_id_t rsib = right_sib[c]; + + if (lsib == TSK_NULL) { + left_child[p] = rsib; + } else { + right_sib[lsib] = rsib; + } + if (rsib == TSK_NULL) { + right_child[p] = lsib; + } else { + left_sib[rsib] = lsib; + } + parent[c] = TSK_NULL; + left_sib[c] = TSK_NULL; + right_sib[c] = TSK_NULL; + num_children[p]--; +} + +static inline void +tsk_tree_insert_branch( + tsk_tree_t *self, tsk_id_t p, tsk_id_t c, tsk_id_t *restrict parent) +{ + tsk_id_t *restrict left_child = self->left_child; + tsk_id_t *restrict right_child = self->right_child; + tsk_id_t *restrict left_sib = self->left_sib; + tsk_id_t *restrict 
right_sib = self->right_sib; + tsk_id_t *restrict num_children = self->num_children; + tsk_id_t u; + + parent[c] = p; + u = right_child[p]; + if (u == TSK_NULL) { + left_child[p] = c; + left_sib[c] = TSK_NULL; + right_sib[c] = TSK_NULL; + } else { + right_sib[u] = c; + left_sib[c] = u; + right_sib[c] = TSK_NULL; + } + right_child[p] = c; + num_children[p]++; +} + +static inline void +tsk_tree_insert_root(tsk_tree_t *self, tsk_id_t root, tsk_id_t *restrict parent) +{ + tsk_tree_insert_branch(self, self->virtual_root, root, parent); + parent[root] = TSK_NULL; +} + +static inline void +tsk_tree_remove_root(tsk_tree_t *self, tsk_id_t root, tsk_id_t *restrict parent) +{ + tsk_tree_remove_branch(self, self->virtual_root, root, parent); +} + +static void +tsk_tree_remove_edge( + tsk_tree_t *self, tsk_id_t p, tsk_id_t c, tsk_id_t TSK_UNUSED(edge_id)) +{ + tsk_id_t *restrict parent = self->parent; + tsk_size_t *restrict num_samples = self->num_samples; + tsk_size_t *restrict num_tracked_samples = self->num_tracked_samples; + tsk_id_t *restrict edge = self->edge; + const tsk_size_t root_threshold = self->root_threshold; + tsk_id_t u; + tsk_id_t path_end = TSK_NULL; + bool path_end_was_root = false; + +#define POTENTIAL_ROOT(U) (num_samples[U] >= root_threshold) + + tsk_tree_remove_branch(self, p, c, parent); + self->num_edges--; + edge[c] = TSK_NULL; + + if (!(self->options & TSK_NO_SAMPLE_COUNTS)) { + u = p; + while (u != TSK_NULL) { + path_end = u; + path_end_was_root = POTENTIAL_ROOT(u); + num_samples[u] -= num_samples[c]; + num_tracked_samples[u] -= num_tracked_samples[c]; + u = parent[u]; + } + + if (path_end_was_root && !POTENTIAL_ROOT(path_end)) { + tsk_tree_remove_root(self, path_end, parent); + } + if (POTENTIAL_ROOT(c)) { + tsk_tree_insert_root(self, c, parent); + } + } + + if (self->options & TSK_SAMPLE_LISTS) { + tsk_tree_update_sample_lists(self, p, parent); + } +} + +static void +tsk_tree_insert_edge(tsk_tree_t *self, tsk_id_t p, tsk_id_t c, tsk_id_t edge_id) 
+{ + tsk_id_t *restrict parent = self->parent; + tsk_size_t *restrict num_samples = self->num_samples; + tsk_size_t *restrict num_tracked_samples = self->num_tracked_samples; + tsk_id_t *restrict edge = self->edge; + const tsk_size_t root_threshold = self->root_threshold; + tsk_id_t u; + tsk_id_t path_end = TSK_NULL; + bool path_end_was_root = false; + +#define POTENTIAL_ROOT(U) (num_samples[U] >= root_threshold) + + if (!(self->options & TSK_NO_SAMPLE_COUNTS)) { + u = p; + while (u != TSK_NULL) { + path_end = u; + path_end_was_root = POTENTIAL_ROOT(u); + num_samples[u] += num_samples[c]; + num_tracked_samples[u] += num_tracked_samples[c]; + u = parent[u]; + } + + if (POTENTIAL_ROOT(c)) { + tsk_tree_remove_root(self, c, parent); + } + if (POTENTIAL_ROOT(path_end) && !path_end_was_root) { + tsk_tree_insert_root(self, path_end, parent); + } + } + + tsk_tree_insert_branch(self, p, c, parent); + self->num_edges++; + edge[c] = edge_id; + + if (self->options & TSK_SAMPLE_LISTS) { + tsk_tree_update_sample_lists(self, p, parent); + } +} + +int TSK_WARN_UNUSED +tsk_tree_first(tsk_tree_t *self) +{ + int ret = TSK_TREE_OK; + + ret = tsk_tree_clear(self); + if (ret != 0) { + goto out; + } + ret = tsk_tree_next(self); +out: + return ret; +} + +int TSK_WARN_UNUSED +tsk_tree_last(tsk_tree_t *self) +{ + int ret = TSK_TREE_OK; + + ret = tsk_tree_clear(self); + if (ret != 0) { + goto out; + } + ret = tsk_tree_prev(self); +out: + return ret; +} + +static void +tsk_tree_update_index_and_interval(tsk_tree_t *self) +{ + tsk_table_collection_t *tables = self->tree_sequence->tables; + + self->index = self->tree_pos.index; + self->interval.left = self->tree_pos.interval.left; + self->interval.right = self->tree_pos.interval.right; + + if (tables->sites.num_rows > 0) { + self->sites = self->tree_sequence->tree_sites[self->index]; + self->sites_length = self->tree_sequence->tree_sites_length[self->index]; + } +} + +int TSK_WARN_UNUSED +tsk_tree_next(tsk_tree_t *self) +{ + int ret = 0; + 
tsk_table_collection_t *tables = self->tree_sequence->tables; + const tsk_id_t *restrict edge_parent = tables->edges.parent; + const tsk_id_t *restrict edge_child = tables->edges.child; + tsk_id_t j, e; + tsk_tree_position_t tree_pos; + bool valid; + + valid = tsk_tree_position_next(&self->tree_pos); + tree_pos = self->tree_pos; + + if (valid) { + for (j = tree_pos.out.start; j != tree_pos.out.stop; j++) { + e = tree_pos.out.order[j]; + tsk_tree_remove_edge(self, edge_parent[e], edge_child[e], e); + } + + for (j = tree_pos.in.start; j != tree_pos.in.stop; j++) { + e = tree_pos.in.order[j]; + tsk_tree_insert_edge(self, edge_parent[e], edge_child[e], e); + } + ret = TSK_TREE_OK; + tsk_tree_update_index_and_interval(self); + } else { + ret = tsk_tree_clear(self); + } + return ret; +} + +int TSK_WARN_UNUSED +tsk_tree_prev(tsk_tree_t *self) +{ + int ret = 0; + tsk_table_collection_t *tables = self->tree_sequence->tables; + const tsk_id_t *restrict edge_parent = tables->edges.parent; + const tsk_id_t *restrict edge_child = tables->edges.child; + tsk_id_t j, e; + tsk_tree_position_t tree_pos; + bool valid; + + valid = tsk_tree_position_prev(&self->tree_pos); + tree_pos = self->tree_pos; + + if (valid) { + for (j = tree_pos.out.start; j != tree_pos.out.stop; j--) { + e = tree_pos.out.order[j]; + tsk_tree_remove_edge(self, edge_parent[e], edge_child[e], e); + } + + for (j = tree_pos.in.start; j != tree_pos.in.stop; j--) { + e = tree_pos.in.order[j]; + tsk_tree_insert_edge(self, edge_parent[e], edge_child[e], e); + } + ret = TSK_TREE_OK; + tsk_tree_update_index_and_interval(self); + } else { + ret = tsk_tree_clear(self); + } + return ret; +} + +static inline bool +tsk_tree_position_in_interval(const tsk_tree_t *self, double x) +{ + return self->interval.left <= x && x < self->interval.right; +} + +static int +tsk_tree_seek_from_null(tsk_tree_t *self, double x, tsk_flags_t TSK_UNUSED(options)) +{ + int ret = 0; + tsk_table_collection_t *tables = self->tree_sequence->tables; + 
const tsk_id_t *restrict edge_parent = tables->edges.parent; + const tsk_id_t *restrict edge_child = tables->edges.child; + const double *restrict edge_left = tables->edges.left; + const double *restrict edge_right = tables->edges.right; + double interval_left, interval_right; + const double *restrict breakpoints = self->tree_sequence->breakpoints; + const tsk_size_t num_trees = self->tree_sequence->num_trees; + const double L = tsk_treeseq_get_sequence_length(self->tree_sequence); + tsk_id_t j, e, index; + tsk_tree_position_t tree_pos; + + index = (tsk_id_t) tsk_search_sorted(breakpoints, num_trees + 1, x); + if (breakpoints[index] > x) { + index--; + } + + if (x <= L / 2.0) { + ret = tsk_tree_position_seek_forward(&self->tree_pos, index); + if (ret != 0) { + goto out; + } + // Since we are seeking from null, there are no edges to remove + tree_pos = self->tree_pos; + interval_left = tree_pos.interval.left; + for (j = tree_pos.in.start; j != tree_pos.in.stop; j++) { + e = tree_pos.in.order[j]; + if (edge_left[e] <= interval_left && interval_left < edge_right[e]) { + tsk_tree_insert_edge(self, edge_parent[e], edge_child[e], e); + } + } + } else { + ret = tsk_tree_position_seek_backward(&self->tree_pos, index); + if (ret != 0) { + goto out; + } + tree_pos = self->tree_pos; + interval_right = tree_pos.interval.right; + for (j = tree_pos.in.start; j != tree_pos.in.stop; j--) { + e = tree_pos.in.order[j]; + if (edge_right[e] >= interval_right && interval_right > edge_left[e]) { + tsk_tree_insert_edge(self, edge_parent[e], edge_child[e], e); + } + } + } + tsk_tree_update_index_and_interval(self); +out: + return ret; +} + +static int TSK_WARN_UNUSED +tsk_tree_seek_forward(tsk_tree_t *self, tsk_id_t index) +{ + int ret = 0; + tsk_table_collection_t *tables = self->tree_sequence->tables; + const tsk_id_t *restrict edge_parent = tables->edges.parent; + const tsk_id_t *restrict edge_child = tables->edges.child; + const double *restrict edge_left = tables->edges.left; + const 
double *restrict edge_right = tables->edges.right; + double interval_left, e_left; + const double old_right = self->interval.right; + tsk_id_t j, e; + tsk_tree_position_t tree_pos; + + ret = tsk_tree_position_seek_forward(&self->tree_pos, index); + if (ret != 0) { + goto out; + } + tree_pos = self->tree_pos; + interval_left = tree_pos.interval.left; + + for (j = tree_pos.out.start; j != tree_pos.out.stop; j++) { + e = tree_pos.out.order[j]; + e_left = edge_left[e]; + if (e_left < old_right) { + tsk_bug_assert(edge_parent[e] != TSK_NULL); + tsk_tree_remove_edge(self, edge_parent[e], edge_child[e], e); + } + tsk_bug_assert(e_left < interval_left); + } + + for (j = tree_pos.in.start; j != tree_pos.in.stop; j++) { + e = tree_pos.in.order[j]; + if (edge_left[e] <= interval_left && interval_left < edge_right[e]) { + tsk_tree_insert_edge(self, edge_parent[e], edge_child[e], e); + } + } + tsk_tree_update_index_and_interval(self); +out: + return ret; +} + +static int TSK_WARN_UNUSED +tsk_tree_seek_backward(tsk_tree_t *self, tsk_id_t index) +{ + int ret = 0; + tsk_table_collection_t *tables = self->tree_sequence->tables; + const tsk_id_t *restrict edge_parent = tables->edges.parent; + const tsk_id_t *restrict edge_child = tables->edges.child; + const double *restrict edge_left = tables->edges.left; + const double *restrict edge_right = tables->edges.right; + double interval_right, e_right; + const double old_right = self->interval.right; + tsk_id_t j, e; + tsk_tree_position_t tree_pos; + + ret = tsk_tree_position_seek_backward(&self->tree_pos, index); + if (ret != 0) { + goto out; + } + tree_pos = self->tree_pos; + interval_right = tree_pos.interval.right; + + for (j = tree_pos.out.start; j != tree_pos.out.stop; j--) { + e = tree_pos.out.order[j]; + e_right = edge_right[e]; + if (e_right >= old_right) { + tsk_bug_assert(edge_parent[e] != TSK_NULL); + tsk_tree_remove_edge(self, edge_parent[e], edge_child[e], e); + } + tsk_bug_assert(e_right > interval_right); + } + + for (j = 
tree_pos.in.start; j != tree_pos.in.stop; j--) { + e = tree_pos.in.order[j]; + if (edge_right[e] >= interval_right && interval_right > edge_left[e]) { + tsk_tree_insert_edge(self, edge_parent[e], edge_child[e], e); + } + } + tsk_tree_update_index_and_interval(self); +out: + return ret; +} + +int TSK_WARN_UNUSED +tsk_tree_seek_index(tsk_tree_t *self, tsk_id_t tree, tsk_flags_t options) +{ + int ret = 0; + double x; + + if (tree < 0 || tree >= (tsk_id_t) self->tree_sequence->num_trees) { + ret = tsk_trace_error(TSK_ERR_SEEK_OUT_OF_BOUNDS); + goto out; + } + x = self->tree_sequence->breakpoints[tree]; + ret = tsk_tree_seek(self, x, options); +out: + return ret; +} + +static int TSK_WARN_UNUSED +tsk_tree_seek_linear(tsk_tree_t *self, double x) +{ + const double L = tsk_treeseq_get_sequence_length(self->tree_sequence); + const double t_l = self->interval.left; + const double t_r = self->interval.right; + int ret = 0; + double distance_left, distance_right; + + if (x < t_l) { + /* |-----|-----|========|---------| */ + /* 0 x t_l t_r L */ + distance_left = t_l - x; + distance_right = L - t_r + x; + } else { + /* |------|========|------|-------| */ + /* 0 t_l t_r x L */ + distance_right = x - t_r; + distance_left = t_l + L - x; + } + if (distance_right <= distance_left) { + while (!tsk_tree_position_in_interval(self, x)) { + ret = tsk_tree_next(self); + if (ret < 0) { + goto out; + } + } + } else { + while (!tsk_tree_position_in_interval(self, x)) { + ret = tsk_tree_prev(self); + if (ret < 0) { + goto out; + } + } + } + ret = 0; +out: + return ret; +} + +static int TSK_WARN_UNUSED +tsk_tree_seek_skip(tsk_tree_t *self, double x) +{ + const double t_l = self->interval.left; + int ret = 0; + tsk_id_t index; + const tsk_size_t num_trees = self->tree_sequence->num_trees; + const double *restrict breakpoints = self->tree_sequence->breakpoints; + + index = (tsk_id_t) tsk_search_sorted(breakpoints, num_trees + 1, x); + if (breakpoints[index] > x) { + index--; + } + + if (x < t_l) 
{ + ret = tsk_tree_seek_backward(self, index); + } else { + ret = tsk_tree_seek_forward(self, index); + } + tsk_bug_assert(tsk_tree_position_in_interval(self, x)); + return ret; +} + +int TSK_WARN_UNUSED +tsk_tree_seek(tsk_tree_t *self, double x, tsk_flags_t options) +{ + int ret = 0; + const double L = tsk_treeseq_get_sequence_length(self->tree_sequence); + + if (x < 0 || x >= L) { + ret = tsk_trace_error(TSK_ERR_SEEK_OUT_OF_BOUNDS); + goto out; + } + + if (self->index == -1) { + ret = tsk_tree_seek_from_null(self, x, options); + } else { + if (options & TSK_SEEK_SKIP) { + ret = tsk_tree_seek_skip(self, x); + } else { + ret = tsk_tree_seek_linear(self, x); + } + } + +out: + return ret; +} + +int TSK_WARN_UNUSED +tsk_tree_clear(tsk_tree_t *self) +{ + int ret = 0; + tsk_size_t j; + tsk_id_t u; + const tsk_size_t N = self->num_nodes + 1; + const tsk_size_t num_samples = self->tree_sequence->num_samples; + const bool sample_counts = !(self->options & TSK_NO_SAMPLE_COUNTS); + const bool sample_lists = !!(self->options & TSK_SAMPLE_LISTS); + const tsk_flags_t *flags = self->tree_sequence->tables->nodes.flags; + + self->interval.left = 0; + self->interval.right = 0; + self->num_edges = 0; + self->index = -1; + tsk_tree_position_set_null(&self->tree_pos); + /* TODO we should profile this method to see if just doing a single loop over + * the nodes would be more efficient than multiple memsets. 
+ */ + tsk_memset(self->parent, 0xff, N * sizeof(*self->parent)); + tsk_memset(self->left_child, 0xff, N * sizeof(*self->left_child)); + tsk_memset(self->right_child, 0xff, N * sizeof(*self->right_child)); + tsk_memset(self->left_sib, 0xff, N * sizeof(*self->left_sib)); + tsk_memset(self->right_sib, 0xff, N * sizeof(*self->right_sib)); + tsk_memset(self->num_children, 0, N * sizeof(*self->num_children)); + tsk_memset(self->edge, 0xff, N * sizeof(*self->edge)); + + if (sample_counts) { + tsk_memset(self->num_samples, 0, N * sizeof(*self->num_samples)); + /* We can't reset the tracked samples via memset because we don't + * know where the tracked samples are. + */ + for (j = 0; j < self->num_nodes; j++) { + if (!(flags[j] & TSK_NODE_IS_SAMPLE)) { + self->num_tracked_samples[j] = 0; + } + } + /* The total tracked_samples gets set in set_tracked_samples */ + self->num_samples[self->virtual_root] = num_samples; + } + if (sample_lists) { + tsk_memset(self->left_sample, 0xff, N * sizeof(tsk_id_t)); + tsk_memset(self->right_sample, 0xff, N * sizeof(tsk_id_t)); + tsk_memset(self->next_sample, 0xff, num_samples * sizeof(tsk_id_t)); + } + /* Set the sample attributes */ + for (j = 0; j < num_samples; j++) { + u = self->samples[j]; + if (sample_counts) { + self->num_samples[u] = 1; + } + if (sample_lists) { + /* We are mapping to *indexes* into the list of samples here */ + self->left_sample[u] = (tsk_id_t) j; + self->right_sample[u] = (tsk_id_t) j; + } + } + if (sample_counts && self->root_threshold == 1 && num_samples > 0) { + for (j = 0; j < num_samples; j++) { + /* Set initial roots */ + if (self->root_threshold == 1) { + tsk_tree_insert_root(self, self->samples[j], self->parent); + } + } + } + return ret; +} + +tsk_size_t +tsk_tree_get_size_bound(const tsk_tree_t *self) +{ + tsk_size_t bound = 0; + + if (self->tree_sequence != NULL) { + /* This is a safe upper bound which can be computed cheaply. 
+ * We have at most n roots and each edge adds at most one new + * node to the tree. We also allow space for the virtual root, + * to simplify client code. + * + * In the common case of a binary tree with a single root, we have + * 2n - 1 nodes in total, and 2n - 2 edges. Therefore, we return + * 3n - 1, which is an over-estimate of 1/2 and we allocate + * 1.5 times as much memory as we need. + * + * Since tracking the exact number of nodes in the tree would require + * storing the number of nodes beneath every node and complicate + * the tree transition method, this seems like a good compromise + * and will result in less memory usage overall in nearly all cases. + */ + bound = 1 + self->tree_sequence->num_samples + self->num_edges; + } + return bound; +} + +/* Traversal orders */ +static tsk_id_t * +tsk_tree_alloc_node_stack(const tsk_tree_t *self) +{ + return tsk_malloc(tsk_tree_get_size_bound(self) * sizeof(tsk_id_t)); +} + +int +tsk_tree_preorder(const tsk_tree_t *self, tsk_id_t *nodes, tsk_size_t *num_nodes_ret) +{ + return tsk_tree_preorder_from(self, -1, nodes, num_nodes_ret); +} + +int +tsk_tree_preorder_from( + const tsk_tree_t *self, tsk_id_t root, tsk_id_t *nodes, tsk_size_t *num_nodes_ret) +{ + int ret = 0; + const tsk_id_t *restrict right_child = self->right_child; + const tsk_id_t *restrict left_sib = self->left_sib; + tsk_id_t *stack = tsk_tree_alloc_node_stack(self); + tsk_size_t num_nodes = 0; + tsk_id_t u, v; + int stack_top; + + if (stack == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; + } + + if ((root == -1 || root == self->virtual_root) + && !tsk_tree_has_sample_counts(self)) { + ret = tsk_trace_error(TSK_ERR_UNSUPPORTED_OPERATION); + goto out; + } + if (root == -1) { + stack_top = -1; + for (u = right_child[self->virtual_root]; u != TSK_NULL; u = left_sib[u]) { + stack_top++; + stack[stack_top] = u; + } + } else { + ret = tsk_tree_check_node(self, root); + if (ret != 0) { + goto out; + } + stack_top = 0; + stack[stack_top] 
= root; + } + + while (stack_top >= 0) { + u = stack[stack_top]; + stack_top--; + nodes[num_nodes] = u; + num_nodes++; + for (v = right_child[u]; v != TSK_NULL; v = left_sib[v]) { + stack_top++; + stack[stack_top] = v; + } + } + *num_nodes_ret = num_nodes; +out: + tsk_safe_free(stack); + return ret; +} + +/* We could implement this using the preorder function, but since it's + * going to be performance critical we want to avoid the overhead + * of mallocing the intermediate node list (which will be bigger than + * the number of samples). */ +int +tsk_tree_preorder_samples_from( + const tsk_tree_t *self, tsk_id_t root, tsk_id_t *nodes, tsk_size_t *num_nodes_ret) +{ + int ret = 0; + const tsk_id_t *restrict right_child = self->right_child; + const tsk_id_t *restrict left_sib = self->left_sib; + const tsk_flags_t *restrict flags = self->tree_sequence->tables->nodes.flags; + tsk_id_t *stack = tsk_tree_alloc_node_stack(self); + tsk_size_t num_nodes = 0; + tsk_id_t u, v; + int stack_top; + + if (stack == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; + } + + /* We could push the virtual_root onto the stack directly to simplify + * the code a little, but then we'd have to check put an extra check + * when looking up the flags array (which isn't defined for virtual_root). 
+ */ + if (root == -1 || root == self->virtual_root) { + if (!tsk_tree_has_sample_counts(self)) { + ret = tsk_trace_error(TSK_ERR_UNSUPPORTED_OPERATION); + goto out; + } + stack_top = -1; + for (u = right_child[self->virtual_root]; u != TSK_NULL; u = left_sib[u]) { + stack_top++; + stack[stack_top] = u; + } + } else { + ret = tsk_tree_check_node(self, root); + if (ret != 0) { + goto out; + } + stack_top = 0; + stack[stack_top] = root; + } + + while (stack_top >= 0) { + u = stack[stack_top]; + stack_top--; + if (flags[u] & TSK_NODE_IS_SAMPLE) { + nodes[num_nodes] = u; + num_nodes++; + } + for (v = right_child[u]; v != TSK_NULL; v = left_sib[v]) { + stack_top++; + stack[stack_top] = v; + } + } + *num_nodes_ret = num_nodes; +out: + tsk_safe_free(stack); + return ret; +} + +int +tsk_tree_postorder(const tsk_tree_t *self, tsk_id_t *nodes, tsk_size_t *num_nodes_ret) +{ + return tsk_tree_postorder_from(self, -1, nodes, num_nodes_ret); +} +int +tsk_tree_postorder_from( + const tsk_tree_t *self, tsk_id_t root, tsk_id_t *nodes, tsk_size_t *num_nodes_ret) +{ + int ret = 0; + const tsk_id_t *restrict right_child = self->right_child; + const tsk_id_t *restrict left_sib = self->left_sib; + const tsk_id_t *restrict parent = self->parent; + tsk_id_t *stack = tsk_tree_alloc_node_stack(self); + tsk_size_t num_nodes = 0; + tsk_id_t u, v, postorder_parent; + int stack_top; + bool is_virtual_root = root == self->virtual_root; + + if (stack == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; + } + + if (root == -1 || is_virtual_root) { + if (!tsk_tree_has_sample_counts(self)) { + ret = tsk_trace_error(TSK_ERR_UNSUPPORTED_OPERATION); + goto out; + } + stack_top = -1; + for (u = right_child[self->virtual_root]; u != TSK_NULL; u = left_sib[u]) { + stack_top++; + stack[stack_top] = u; + } + } else { + ret = tsk_tree_check_node(self, root); + if (ret != 0) { + goto out; + } + stack_top = 0; + stack[stack_top] = root; + } + + postorder_parent = TSK_NULL; + while (stack_top >= 
0) { + u = stack[stack_top]; + if (right_child[u] != TSK_NULL && u != postorder_parent) { + for (v = right_child[u]; v != TSK_NULL; v = left_sib[v]) { + stack_top++; + stack[stack_top] = v; + } + } else { + stack_top--; + postorder_parent = parent[u]; + nodes[num_nodes] = u; + num_nodes++; + } + } + if (is_virtual_root) { + nodes[num_nodes] = root; + num_nodes++; + } + *num_nodes_ret = num_nodes; +out: + tsk_safe_free(stack); + return ret; +} + +/* Balance/imbalance metrics */ + +/* Result is a tsk_size_t value here because we could imagine the total + * depth overflowing a 32bit integer for a large tree. */ +int +tsk_tree_sackin_index(const tsk_tree_t *self, tsk_size_t *result) +{ + /* Keep the size of the stack elements to 8 bytes in total in the + * standard case. A tsk_id_t depth value is always safe, since + * depth counts the number of nodes encountered on a path. + */ + struct stack_elem { + tsk_id_t node; + tsk_id_t depth; + }; + int ret = 0; + const tsk_id_t *restrict right_child = self->right_child; + const tsk_id_t *restrict left_sib = self->left_sib; + struct stack_elem *stack + = tsk_malloc(tsk_tree_get_size_bound(self) * sizeof(*stack)); + int stack_top; + tsk_size_t total_depth; + tsk_id_t u; + struct stack_elem s = { .node = TSK_NULL, .depth = 0 }; + + if (stack == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; + } + + stack_top = -1; + for (u = right_child[self->virtual_root]; u != TSK_NULL; u = left_sib[u]) { + stack_top++; + s.node = u; + stack[stack_top] = s; + } + total_depth = 0; + while (stack_top >= 0) { + s = stack[stack_top]; + stack_top--; + u = right_child[s.node]; + if (u == TSK_NULL) { + total_depth += (tsk_size_t) s.depth; + } else { + s.depth++; + while (u != TSK_NULL) { + stack_top++; + s.node = u; + stack[stack_top] = s; + u = left_sib[u]; + } + } + } + *result = total_depth; +out: + tsk_safe_free(stack); + return ret; +} - return v; +int +tsk_tree_colless_index(const tsk_tree_t *self, tsk_size_t *result) +{ + int 
ret = 0; + const tsk_id_t *restrict right_child = self->right_child; + const tsk_id_t *restrict left_sib = self->left_sib; + tsk_id_t *nodes = tsk_malloc(tsk_tree_get_size_bound(self) * sizeof(*nodes)); + tsk_id_t *num_leaves = tsk_calloc(self->num_nodes, sizeof(*num_leaves)); + tsk_size_t j, num_nodes, total; + tsk_id_t num_children, u, v; + + if (nodes == NULL || num_leaves == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; + } + if (tsk_tree_get_num_roots(self) != 1) { + ret = tsk_trace_error(TSK_ERR_UNDEFINED_MULTIROOT); + goto out; + } + ret = tsk_tree_postorder(self, nodes, &num_nodes); + if (ret != 0) { + goto out; + } + + total = 0; + for (j = 0; j < num_nodes; j++) { + u = nodes[j]; + /* Cheaper to compute this on the fly than to access the num_children array. + * since we're already iterating over the children. */ + num_children = 0; + for (v = right_child[u]; v != TSK_NULL; v = left_sib[v]) { + num_children++; + num_leaves[u] += num_leaves[v]; + } + if (num_children == 0) { + num_leaves[u] = 1; + } else if (num_children == 2) { + v = right_child[u]; + total += (tsk_size_t) llabs(num_leaves[v] - num_leaves[left_sib[v]]); + } else { + ret = tsk_trace_error(TSK_ERR_UNDEFINED_NONBINARY); + goto out; + } + } + *result = total; +out: + tsk_safe_free(nodes); + tsk_safe_free(num_leaves); + return ret; } -static void -tsk_tree_check_state(const tsk_tree_t *self) +int +tsk_tree_b1_index(const tsk_tree_t *self, double *result) +{ + int ret = 0; + const tsk_id_t *restrict parent = self->parent; + const tsk_id_t *restrict right_child = self->right_child; + const tsk_id_t *restrict left_sib = self->left_sib; + tsk_id_t *nodes = tsk_malloc(tsk_tree_get_size_bound(self) * sizeof(*nodes)); + tsk_size_t *max_path_length = tsk_calloc(self->num_nodes, sizeof(*max_path_length)); + tsk_size_t j, num_nodes, mpl; + double total = 0.0; + tsk_id_t u, v; + + if (nodes == NULL || max_path_length == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; + } 
+ ret = tsk_tree_postorder(self, nodes, &num_nodes); + if (ret != 0) { + goto out; + } + + for (j = 0; j < num_nodes; j++) { + u = nodes[j]; + if (parent[u] != TSK_NULL && right_child[u] != TSK_NULL) { + mpl = 0; + for (v = right_child[u]; v != TSK_NULL; v = left_sib[v]) { + mpl = TSK_MAX(mpl, max_path_length[v]); + } + max_path_length[u] = mpl + 1; + total += 1 / (double) max_path_length[u]; + } + } + *result = total; +out: + tsk_safe_free(nodes); + tsk_safe_free(max_path_length); + return ret; +} + +static double +general_log(double x, double base) +{ + return log(x) / log(base); +} + +int +tsk_tree_b2_index(const tsk_tree_t *self, double base, double *result) +{ + struct stack_elem { + tsk_id_t node; + double path_product; + }; + int ret = 0; + const tsk_id_t *restrict right_child = self->right_child; + const tsk_id_t *restrict left_sib = self->left_sib; + struct stack_elem *stack + = tsk_malloc(tsk_tree_get_size_bound(self) * sizeof(*stack)); + int stack_top; + double total_proba = 0; + double num_children; + tsk_id_t u; + struct stack_elem s = { .node = TSK_NULL, .path_product = 1 }; + + if (stack == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; + } + if (tsk_tree_get_num_roots(self) != 1) { + ret = tsk_trace_error(TSK_ERR_UNDEFINED_MULTIROOT); + goto out; + } + + stack_top = 0; + s.node = tsk_tree_get_left_root(self); + stack[stack_top] = s; + + while (stack_top >= 0) { + s = stack[stack_top]; + stack_top--; + u = right_child[s.node]; + if (u == TSK_NULL) { + total_proba -= s.path_product * general_log(s.path_product, base); + } else { + num_children = 0; + for (; u != TSK_NULL; u = left_sib[u]) { + num_children++; + } + s.path_product *= 1 / num_children; + for (u = right_child[s.node]; u != TSK_NULL; u = left_sib[u]) { + stack_top++; + s.node = u; + stack[stack_top] = s; + } + } + } + *result = total_proba; +out: + tsk_safe_free(stack); + return ret; +} + +int +tsk_tree_num_lineages(const tsk_tree_t *self, double t, tsk_size_t *result) +{ 
+ int ret = 0; + const tsk_id_t *restrict right_child = self->right_child; + const tsk_id_t *restrict left_sib = self->left_sib; + const double *restrict time = self->tree_sequence->tables->nodes.time; + tsk_id_t *stack = tsk_tree_alloc_node_stack(self); + tsk_size_t num_lineages = 0; + int stack_top; + tsk_id_t u, v; + double child_time, parent_time; + + if (stack == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; + } + if (!tsk_isfinite(t)) { + ret = tsk_trace_error(TSK_ERR_TIME_NONFINITE); + goto out; + } + /* Push the roots onto the stack */ + stack_top = -1; + for (u = right_child[self->virtual_root]; u != TSK_NULL; u = left_sib[u]) { + stack_top++; + stack[stack_top] = u; + } + + while (stack_top >= 0) { + u = stack[stack_top]; + parent_time = time[u]; + stack_top--; + for (v = right_child[u]; v != TSK_NULL; v = left_sib[v]) { + child_time = time[v]; + /* Only traverse down the tree as far as we need to */ + if (child_time > t) { + stack_top++; + stack[stack_top] = v; + } else if (t < parent_time) { + num_lineages++; + } + } + } + *result = num_lineages; +out: + tsk_safe_free(stack); + return ret; +} + +/* Parsimony methods */ + +static inline uint64_t +set_bit(uint64_t value, int32_t bit) +{ + return value | (1ULL << bit); +} + +static inline bool +bit_is_set(uint64_t value, int32_t bit) +{ + return (value & (1ULL << bit)) != 0; +} + +static inline int8_t +get_smallest_set_bit(uint64_t v) +{ + /* This is an inefficient implementation, there are several better + * approaches. On GCC we can use + * return (uint8_t) (__builtin_ffsll((long long) v) - 1); + */ + uint64_t t = 1; + int8_t r = 0; + + assert(v != 0); + while ((v & t) == 0) { + t <<= 1; + r++; + } + return r; +} + +#define HARTIGAN_MAX_ALLELES 64 + +/* This interface is experimental. In the future, we should provide the option to + * use a general cost matrix, in which case we'll use the Sankoff algorithm. For + * now this is unused. 
+ * + * We should also vectorise the function so that several sites can be processed + * at once. + * + * The algorithm used here is Hartigan parsimony, "Minimum Mutation Fits to a + * Given Tree", Biometrics 1973. + */ +int TSK_WARN_UNUSED +tsk_tree_map_mutations(tsk_tree_t *self, int32_t *genotypes, + double *TSK_UNUSED(cost_matrix), tsk_flags_t options, int32_t *r_ancestral_state, + tsk_size_t *r_num_transitions, tsk_state_transition_t **r_transitions) { + int ret = 0; + struct stack_elem { + tsk_id_t node; + tsk_id_t transition_parent; + int32_t state; + }; + const tsk_size_t num_samples = self->tree_sequence->num_samples; + const tsk_id_t *restrict left_child = self->left_child; + const tsk_id_t *restrict right_sib = self->right_sib; + const tsk_size_t N = tsk_treeseq_get_num_nodes(self->tree_sequence); + const tsk_flags_t *restrict node_flags = self->tree_sequence->tables->nodes.flags; + tsk_id_t *nodes = tsk_malloc(tsk_tree_get_size_bound(self) * sizeof(*nodes)); + /* Note: to use less memory here and to improve cache performance we should + * probably change to allocating exactly the number of nodes returned by + * a preorder traversal, and then lay the memory out in this order. So, we'd + * need a map from node ID to its index in the preorder traversal, but this + * is trivial to compute. Probably doesn't matter so much at the moment + * when we're doing a single site, but it would make a big difference if + * we were vectorising over lots of sites. 
*/ + uint64_t *restrict optimal_set = tsk_calloc(N + 1, sizeof(*optimal_set)); + struct stack_elem *restrict preorder_stack + = tsk_malloc(tsk_tree_get_size_bound(self) * sizeof(*preorder_stack)); tsk_id_t u, v; - tsk_size_t j, num_samples; - int err, c; - tsk_site_t site; - tsk_id_t *children = tsk_malloc(self->num_nodes * sizeof(tsk_id_t)); - bool *is_root = tsk_calloc(self->num_nodes, sizeof(bool)); - - tsk_bug_assert(children != NULL); - - /* Check the virtual root properties */ - tsk_bug_assert(self->parent[self->virtual_root] == TSK_NULL); - tsk_bug_assert(self->left_sib[self->virtual_root] == TSK_NULL); - tsk_bug_assert(self->right_sib[self->virtual_root] == TSK_NULL); + /* The largest possible number of transitions is one over every sample */ + tsk_state_transition_t *transitions = tsk_malloc(num_samples * sizeof(*transitions)); + int32_t allele, ancestral_state; + int stack_top; + struct stack_elem s; + tsk_size_t j, num_transitions, max_allele_count, num_nodes; + tsk_size_t allele_count[HARTIGAN_MAX_ALLELES]; + tsk_size_t non_missing = 0; + int32_t num_alleles = 0; - for (j = 0; j < self->tree_sequence->num_samples; j++) { - u = self->samples[j]; - while (self->parent[u] != TSK_NULL) { - u = self->parent[u]; - } - is_root[u] = true; + if (optimal_set == NULL || preorder_stack == NULL || transitions == NULL + || nodes == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; } - if (self->tree_sequence->num_samples == 0) { - tsk_bug_assert(self->left_child[self->virtual_root] == TSK_NULL); + for (j = 0; j < num_samples; j++) { + if (genotypes[j] >= HARTIGAN_MAX_ALLELES || genotypes[j] < TSK_MISSING_DATA) { + ret = tsk_trace_error(TSK_ERR_BAD_GENOTYPE); + goto out; + } + u = self->tree_sequence->samples[j]; + if (genotypes[j] == TSK_MISSING_DATA) { + /* All bits set */ + optimal_set[u] = UINT64_MAX; + } else { + optimal_set[u] = set_bit(optimal_set[u], genotypes[j]); + num_alleles = TSK_MAX(genotypes[j], num_alleles); + non_missing++; + } } - /* 
Iterate over the roots and make sure they are set */ - for (u = tsk_tree_get_left_root(self); u != TSK_NULL; u = self->right_sib[u]) { - tsk_bug_assert(is_root[u]); - is_root[u] = false; + if (non_missing == 0) { + ret = tsk_trace_error(TSK_ERR_GENOTYPES_ALL_MISSING); + goto out; } - for (u = 0; u < (tsk_id_t) self->num_nodes; u++) { - tsk_bug_assert(!is_root[u]); - c = 0; - for (v = self->left_child[u]; v != TSK_NULL; v = self->right_sib[v]) { - tsk_bug_assert(self->parent[v] == u); - children[c] = v; - c++; - } - for (v = self->right_child[u]; v != TSK_NULL; v = self->left_sib[v]) { - tsk_bug_assert(c > 0); - c--; - tsk_bug_assert(v == children[c]); + num_alleles++; + + ancestral_state = 0; /* keep compiler happy */ + if (options & TSK_MM_FIXED_ANCESTRAL_STATE) { + ancestral_state = *r_ancestral_state; + if ((ancestral_state < 0) || (ancestral_state >= HARTIGAN_MAX_ALLELES)) { + ret = tsk_trace_error(TSK_ERR_BAD_ANCESTRAL_STATE); + goto out; + } else if (ancestral_state >= num_alleles) { + num_alleles = (int32_t)(ancestral_state + 1); } } - for (j = 0; j < self->sites_length; j++) { - site = self->sites[j]; - tsk_bug_assert(self->interval.left <= site.position); - tsk_bug_assert(site.position < self->interval.right); - } - if (!(self->options & TSK_NO_SAMPLE_COUNTS)) { - tsk_bug_assert(self->num_samples != NULL); - tsk_bug_assert(self->num_tracked_samples != NULL); - for (u = 0; u < (tsk_id_t) self->num_nodes; u++) { - err = tsk_tree_get_num_samples_by_traversal(self, u, &num_samples); - tsk_bug_assert(err == 0); - tsk_bug_assert(num_samples == (tsk_size_t) self->num_samples[u]); + ret = tsk_tree_postorder_from(self, self->virtual_root, nodes, &num_nodes); + if (ret != 0) { + goto out; + } + for (j = 0; j < num_nodes; j++) { + u = nodes[j]; + tsk_memset(allele_count, 0, ((size_t) num_alleles) * sizeof(*allele_count)); + for (v = left_child[u]; v != TSK_NULL; v = right_sib[v]) { + for (allele = 0; allele < num_alleles; allele++) { + allele_count[allele] += 
bit_is_set(optimal_set[v], allele); + } + } + /* the virtual root has no flags defined */ + if (u == (tsk_id_t) N || !(node_flags[u] & TSK_NODE_IS_SAMPLE)) { + max_allele_count = 0; + for (allele = 0; allele < num_alleles; allele++) { + max_allele_count = TSK_MAX(max_allele_count, allele_count[allele]); + } + for (allele = 0; allele < num_alleles; allele++) { + if (allele_count[allele] == max_allele_count) { + optimal_set[u] = set_bit(optimal_set[u], allele); + } + } } - } else { - tsk_bug_assert(self->num_samples == NULL); - tsk_bug_assert(self->num_tracked_samples == NULL); } - if (self->options & TSK_SAMPLE_LISTS) { - tsk_bug_assert(self->right_sample != NULL); - tsk_bug_assert(self->left_sample != NULL); - tsk_bug_assert(self->next_sample != NULL); + if (!(options & TSK_MM_FIXED_ANCESTRAL_STATE)) { + ancestral_state = get_smallest_set_bit(optimal_set[self->virtual_root]); } else { - tsk_bug_assert(self->right_sample == NULL); - tsk_bug_assert(self->left_sample == NULL); - tsk_bug_assert(self->next_sample == NULL); + optimal_set[self->virtual_root] = UINT64_MAX; } - free(children); - free(is_root); -} - -void -tsk_tree_print_state(const tsk_tree_t *self, FILE *out) -{ - tsk_size_t j; - tsk_site_t site; + num_transitions = 0; - fprintf(out, "Tree state:\n"); - fprintf(out, "options = %d\n", self->options); - fprintf(out, "root_threshold = %lld\n", (long long) self->root_threshold); - fprintf(out, "left = %f\n", self->interval.left); - fprintf(out, "right = %f\n", self->interval.right); - fprintf(out, "index = %lld\n", (long long) self->index); - fprintf(out, "node\tparent\tlchild\trchild\tlsib\trsib"); - if (self->options & TSK_SAMPLE_LISTS) { - fprintf(out, "\thead\ttail"); - } - fprintf(out, "\n"); + /* Do a preorder traversal */ + preorder_stack[0].node = self->virtual_root; + preorder_stack[0].state = ancestral_state; + preorder_stack[0].transition_parent = TSK_NULL; + stack_top = 0; + while (stack_top >= 0) { + s = preorder_stack[stack_top]; + stack_top--; - 
for (j = 0; j < self->num_nodes + 1; j++) { - fprintf(out, "%lld\t%lld\t%lld\t%lld\t%lld\t%lld", (long long) j, - (long long) self->parent[j], (long long) self->left_child[j], - (long long) self->right_child[j], (long long) self->left_sib[j], - (long long) self->right_sib[j]); - if (self->options & TSK_SAMPLE_LISTS) { - fprintf(out, "\t%lld\t%lld\t", (long long) self->left_sample[j], - (long long) self->right_sample[j]); + if (!bit_is_set(optimal_set[s.node], s.state)) { + s.state = get_smallest_set_bit(optimal_set[s.node]); + transitions[num_transitions].node = s.node; + transitions[num_transitions].parent = s.transition_parent; + transitions[num_transitions].state = s.state; + s.transition_parent = (tsk_id_t) num_transitions; + num_transitions++; } - if (!(self->options & TSK_NO_SAMPLE_COUNTS)) { - fprintf(out, "\t%lld\t%lld", (long long) self->num_samples[j], - (long long) self->num_tracked_samples[j]); + for (v = left_child[s.node]; v != TSK_NULL; v = right_sib[v]) { + stack_top++; + s.node = v; + preorder_stack[stack_top] = s; } - fprintf(out, "\n"); } - fprintf(out, "sites = \n"); - for (j = 0; j < self->sites_length; j++) { - site = self->sites[j]; - fprintf(out, "\t%lld\t%f\n", (long long) site.id, site.position); + + *r_transitions = transitions; + *r_num_transitions = num_transitions; + *r_ancestral_state = ancestral_state; + transitions = NULL; +out: + tsk_safe_free(transitions); + /* Cannot safe_free because of 'restrict' */ + if (optimal_set != NULL) { + free(optimal_set); } - tsk_tree_check_state(self); + if (preorder_stack != NULL) { + free(preorder_stack); + } + if (nodes != NULL) { + free(nodes); + } + return ret; } -/* Methods for positioning the tree along the sequence */ +/* ======================================================== * + * KC Distance + * ======================================================== */ -/* The following methods are performance sensitive and so we use a - * lot of restrict pointers. 
Because we are saying that we don't have - * any aliases to these pointers, we pass around the reference to parent - * since it's used in all the functions. */ -static inline void -tsk_tree_update_sample_lists( - tsk_tree_t *self, tsk_id_t node, const tsk_id_t *restrict parent) +typedef struct { + tsk_size_t *m; + double *M; + tsk_id_t n; + tsk_id_t N; +} kc_vectors; + +static int +kc_vectors_alloc(kc_vectors *self, tsk_id_t n) { - tsk_id_t u, v, sample_index; - tsk_id_t *restrict left_child = self->left_child; - tsk_id_t *restrict right_sib = self->right_sib; - tsk_id_t *restrict left = self->left_sample; - tsk_id_t *restrict right = self->right_sample; - tsk_id_t *restrict next = self->next_sample; - const tsk_id_t *restrict sample_index_map = self->tree_sequence->sample_index_map; + int ret = 0; - for (u = node; u != TSK_NULL; u = parent[u]) { - sample_index = sample_index_map[u]; - if (sample_index != TSK_NULL) { - right[u] = left[u]; - } else { - left[u] = TSK_NULL; - right[u] = TSK_NULL; - } - for (v = left_child[u]; v != TSK_NULL; v = right_sib[v]) { - if (left[v] != TSK_NULL) { - tsk_bug_assert(right[v] != TSK_NULL); - if (left[u] == TSK_NULL) { - left[u] = left[v]; - right[u] = right[v]; - } else { - next[right[u]] = left[v]; - right[u] = right[v]; - } - } - } + self->n = n; + self->N = (n * (n - 1)) / 2; + self->m = tsk_calloc((size_t)(self->N + self->n), sizeof(*self->m)); + self->M = tsk_calloc((size_t)(self->N + self->n), sizeof(*self->M)); + if (self->m == NULL || self->M == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; } + +out: + return ret; } -static inline void -tsk_tree_remove_branch( - tsk_tree_t *self, tsk_id_t p, tsk_id_t c, tsk_id_t *restrict parent) +static void +kc_vectors_free(kc_vectors *self) { - tsk_id_t *restrict left_child = self->left_child; - tsk_id_t *restrict right_child = self->right_child; - tsk_id_t *restrict left_sib = self->left_sib; - tsk_id_t *restrict right_sib = self->right_sib; - tsk_id_t *restrict 
num_children = self->num_children; - tsk_id_t lsib = left_sib[c]; - tsk_id_t rsib = right_sib[c]; - - if (lsib == TSK_NULL) { - left_child[p] = rsib; - } else { - right_sib[lsib] = rsib; - } - if (rsib == TSK_NULL) { - right_child[p] = lsib; - } else { - left_sib[rsib] = lsib; - } - parent[c] = TSK_NULL; - left_sib[c] = TSK_NULL; - right_sib[c] = TSK_NULL; - num_children[p]--; + tsk_safe_free(self->m); + tsk_safe_free(self->M); } static inline void -tsk_tree_insert_branch( - tsk_tree_t *self, tsk_id_t p, tsk_id_t c, tsk_id_t *restrict parent) +update_kc_vectors_single_sample( + const tsk_treeseq_t *ts, kc_vectors *kc_vecs, tsk_id_t u, double time) { - tsk_id_t *restrict left_child = self->left_child; - tsk_id_t *restrict right_child = self->right_child; - tsk_id_t *restrict left_sib = self->left_sib; - tsk_id_t *restrict right_sib = self->right_sib; - tsk_id_t *restrict num_children = self->num_children; - tsk_id_t u; + const tsk_id_t *sample_index_map = ts->sample_index_map; + tsk_id_t u_index = sample_index_map[u]; - parent[c] = p; - u = right_child[p]; - if (u == TSK_NULL) { - left_child[p] = c; - left_sib[c] = TSK_NULL; - right_sib[c] = TSK_NULL; - } else { - right_sib[u] = c; - left_sib[c] = u; - right_sib[c] = TSK_NULL; - } - right_child[p] = c; - num_children[p]++; + kc_vecs->m[kc_vecs->N + u_index] = 1; + kc_vecs->M[kc_vecs->N + u_index] = time; } static inline void -tsk_tree_insert_root(tsk_tree_t *self, tsk_id_t root, tsk_id_t *restrict parent) +update_kc_vectors_all_pairs(const tsk_tree_t *tree, kc_vectors *kc_vecs, tsk_id_t u, + tsk_id_t v, tsk_size_t depth, double time) { - tsk_tree_insert_branch(self, self->virtual_root, root, parent); - parent[root] = TSK_NULL; -} + tsk_id_t sample1_index, sample2_index, n1, n2, tmp, pair_index; + const tsk_id_t *restrict left_sample = tree->left_sample; + const tsk_id_t *restrict right_sample = tree->right_sample; + const tsk_id_t *restrict next_sample = tree->next_sample; + tsk_size_t *restrict kc_m = kc_vecs->m; + 
double *restrict kc_M = kc_vecs->M; -static inline void -tsk_tree_remove_root(tsk_tree_t *self, tsk_id_t root, tsk_id_t *restrict parent) -{ - tsk_tree_remove_branch(self, self->virtual_root, root, parent); + sample1_index = left_sample[u]; + while (sample1_index != TSK_NULL) { + sample2_index = left_sample[v]; + while (sample2_index != TSK_NULL) { + n1 = sample1_index; + n2 = sample2_index; + if (n1 > n2) { + tmp = n1; + n1 = n2; + n2 = tmp; + } + + /* We spend ~40% of our time here because these accesses + * are not in order and gets very poor cache behavior */ + pair_index = n2 - n1 - 1 + (-1 * n1 * (n1 - 2 * kc_vecs->n + 1)) / 2; + kc_m[pair_index] = depth; + kc_M[pair_index] = time; + + if (sample2_index == right_sample[v]) { + break; + } + sample2_index = next_sample[sample2_index]; + } + if (sample1_index == right_sample[u]) { + break; + } + sample1_index = next_sample[sample1_index]; + } } -static void -tsk_tree_remove_edge(tsk_tree_t *self, tsk_id_t p, tsk_id_t c) +struct kc_stack_elmt { + tsk_id_t node; + tsk_size_t depth; +}; + +static int +fill_kc_vectors(const tsk_tree_t *t, kc_vectors *kc_vecs) { - tsk_id_t *restrict parent = self->parent; - tsk_size_t *restrict num_samples = self->num_samples; - tsk_size_t *restrict num_tracked_samples = self->num_tracked_samples; - tsk_id_t *restrict edge = self->edge; - const tsk_size_t root_threshold = self->root_threshold; - tsk_id_t u; - tsk_id_t path_end = TSK_NULL; - bool path_end_was_root = false; + int stack_top; + tsk_size_t depth; + double time; + const double *times; + struct kc_stack_elmt *stack; + tsk_id_t root, u, c1, c2; + int ret = 0; + const tsk_treeseq_t *ts = t->tree_sequence; -#define POTENTIAL_ROOT(U) (num_samples[U] >= root_threshold) + stack = tsk_malloc(tsk_tree_get_size_bound(t) * sizeof(*stack)); + if (stack == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; + } - tsk_tree_remove_branch(self, p, c, parent); - self->num_edges--; - edge[c] = TSK_NULL; + times = 
t->tree_sequence->tables->nodes.time; - if (!(self->options & TSK_NO_SAMPLE_COUNTS)) { - u = p; - while (u != TSK_NULL) { - path_end = u; - path_end_was_root = POTENTIAL_ROOT(u); - num_samples[u] -= num_samples[c]; - num_tracked_samples[u] -= num_tracked_samples[c]; - u = parent[u]; - } + for (root = tsk_tree_get_left_root(t); root != TSK_NULL; root = t->right_sib[root]) { + stack_top = 0; + stack[stack_top].node = root; + stack[stack_top].depth = 0; + while (stack_top >= 0) { + u = stack[stack_top].node; + depth = stack[stack_top].depth; + stack_top--; - if (path_end_was_root && !POTENTIAL_ROOT(path_end)) { - tsk_tree_remove_root(self, path_end, parent); - } - if (POTENTIAL_ROOT(c)) { - tsk_tree_insert_root(self, c, parent); + if (tsk_tree_is_sample(t, u)) { + time = tsk_tree_get_branch_length_unsafe(t, u); + update_kc_vectors_single_sample(ts, kc_vecs, u, time); + } + + /* Don't bother going deeper if there are no samples under this node */ + if (t->left_sample[u] != TSK_NULL) { + for (c1 = t->left_child[u]; c1 != TSK_NULL; c1 = t->right_sib[c1]) { + stack_top++; + stack[stack_top].node = c1; + stack[stack_top].depth = depth + 1; + + for (c2 = t->right_sib[c1]; c2 != TSK_NULL; c2 = t->right_sib[c2]) { + time = times[root] - times[u]; + update_kc_vectors_all_pairs(t, kc_vecs, c1, c2, depth, time); + } + } + } } } - if (self->options & TSK_SAMPLE_LISTS) { - tsk_tree_update_sample_lists(self, p, parent); - } +out: + tsk_safe_free(stack); + return ret; } -static void -tsk_tree_insert_edge(tsk_tree_t *self, tsk_id_t p, tsk_id_t c, tsk_id_t edge_id) +static double +norm_kc_vectors(kc_vectors *self, kc_vectors *other, double lambda) { - tsk_id_t *restrict parent = self->parent; - tsk_size_t *restrict num_samples = self->num_samples; - tsk_size_t *restrict num_tracked_samples = self->num_tracked_samples; - tsk_id_t *restrict edge = self->edge; - const tsk_size_t root_threshold = self->root_threshold; - tsk_id_t u; - tsk_id_t path_end = TSK_NULL; - bool path_end_was_root 
= false; + double vT1, vT2, distance_sum; + tsk_id_t i; -#define POTENTIAL_ROOT(U) (num_samples[U] >= root_threshold) + distance_sum = 0; + for (i = 0; i < self->n + self->N; i++) { + vT1 = ((double) self->m[i] * (1 - lambda)) + (lambda * self->M[i]); + vT2 = ((double) other->m[i] * (1 - lambda)) + (lambda * other->M[i]); + distance_sum += (vT1 - vT2) * (vT1 - vT2); + } - if (!(self->options & TSK_NO_SAMPLE_COUNTS)) { - u = p; - while (u != TSK_NULL) { - path_end = u; - path_end_was_root = POTENTIAL_ROOT(u); - num_samples[u] += num_samples[c]; - num_tracked_samples[u] += num_tracked_samples[c]; - u = parent[u]; - } + return sqrt(distance_sum); +} - if (POTENTIAL_ROOT(c)) { - tsk_tree_remove_root(self, c, parent); - } - if (POTENTIAL_ROOT(path_end) && !path_end_was_root) { - tsk_tree_insert_root(self, path_end, parent); +static int +check_kc_distance_tree_inputs(const tsk_tree_t *self) +{ + tsk_id_t u, num_nodes, left_child; + int ret = 0; + + if (tsk_tree_get_num_roots(self) != 1) { + ret = tsk_trace_error(TSK_ERR_MULTIPLE_ROOTS); + goto out; + } + if (!tsk_tree_has_sample_lists(self)) { + ret = tsk_trace_error(TSK_ERR_NO_SAMPLE_LISTS); + goto out; + } + + num_nodes = (tsk_id_t) tsk_treeseq_get_num_nodes(self->tree_sequence); + for (u = 0; u < num_nodes; u++) { + left_child = self->left_child[u]; + if (left_child != TSK_NULL && left_child == self->right_child[u]) { + ret = tsk_trace_error(TSK_ERR_UNARY_NODES); + goto out; } } +out: + return ret; +} - tsk_tree_insert_branch(self, p, c, parent); - self->num_edges++; - edge[c] = edge_id; +static int +check_kc_distance_samples_inputs(const tsk_treeseq_t *self, const tsk_treeseq_t *other) +{ + const tsk_id_t *samples, *other_samples; + tsk_id_t i, n; + int ret = 0; - if (self->options & TSK_SAMPLE_LISTS) { - tsk_tree_update_sample_lists(self, p, parent); + if (self->num_samples != other->num_samples) { + ret = tsk_trace_error(TSK_ERR_SAMPLE_SIZE_MISMATCH); + goto out; + } + + samples = self->samples; + other_samples = 
other->samples; + n = (tsk_id_t) self->num_samples; + for (i = 0; i < n; i++) { + if (samples[i] != other_samples[i]) { + ret = tsk_trace_error(TSK_ERR_SAMPLES_NOT_EQUAL); + goto out; + } } +out: + return ret; } -static int -tsk_tree_advance(tsk_tree_t *self, int direction, const double *restrict out_breakpoints, - const tsk_id_t *restrict out_order, tsk_id_t *out_index, - const double *restrict in_breakpoints, const tsk_id_t *restrict in_order, - tsk_id_t *in_index) +int +tsk_tree_kc_distance( + const tsk_tree_t *self, const tsk_tree_t *other, double lambda, double *result) { + tsk_id_t n, i; + kc_vectors vecs[2]; + const tsk_tree_t *trees[2] = { self, other }; int ret = 0; - const int direction_change = direction * (direction != self->direction); - tsk_id_t in = *in_index + direction_change; - tsk_id_t out = *out_index + direction_change; - tsk_id_t k; - const tsk_table_collection_t *tables = self->tree_sequence->tables; - const double sequence_length = tables->sequence_length; - const tsk_id_t num_edges = (tsk_id_t) tables->edges.num_rows; - const tsk_id_t *restrict edge_parent = tables->edges.parent; - const tsk_id_t *restrict edge_child = tables->edges.child; - double x; - if (direction == TSK_DIR_FORWARD) { - x = self->interval.right; - } else { - x = self->interval.left; - } - while (out >= 0 && out < num_edges && out_breakpoints[out_order[out]] == x) { - tsk_bug_assert(out < num_edges); - k = out_order[out]; - out += direction; - tsk_tree_remove_edge(self, edge_parent[k], edge_child[k]); + for (i = 0; i < 2; i++) { + tsk_memset(&vecs[i], 0, sizeof(kc_vectors)); } - while (in >= 0 && in < num_edges && in_breakpoints[in_order[in]] == x) { - k = in_order[in]; - in += direction; - tsk_tree_insert_edge(self, edge_parent[k], edge_child[k], k); + ret = check_kc_distance_samples_inputs(self->tree_sequence, other->tree_sequence); + if (ret != 0) { + goto out; } - - self->direction = direction; - self->index = self->index + direction; - if (direction == 
TSK_DIR_FORWARD) { - self->interval.left = x; - self->interval.right = sequence_length; - if (out >= 0 && out < num_edges) { - self->interval.right - = TSK_MIN(self->interval.right, out_breakpoints[out_order[out]]); - } - if (in >= 0 && in < num_edges) { - self->interval.right - = TSK_MIN(self->interval.right, in_breakpoints[in_order[in]]); + for (i = 0; i < 2; i++) { + ret = check_kc_distance_tree_inputs(trees[i]); + if (ret != 0) { + goto out; } - } else { - self->interval.right = x; - self->interval.left = 0; - if (out >= 0 && out < num_edges) { - self->interval.left - = TSK_MAX(self->interval.left, out_breakpoints[out_order[out]]); + } + + n = (tsk_id_t) self->tree_sequence->num_samples; + for (i = 0; i < 2; i++) { + ret = kc_vectors_alloc(&vecs[i], n); + if (ret != 0) { + goto out; } - if (in >= 0 && in < num_edges) { - self->interval.left - = TSK_MAX(self->interval.left, in_breakpoints[in_order[in]]); + ret = fill_kc_vectors(trees[i], &vecs[i]); + if (ret != 0) { + goto out; } } - tsk_bug_assert(self->interval.left < self->interval.right); - *out_index = out; - *in_index = in; - if (tables->sites.num_rows > 0) { - self->sites = self->tree_sequence->tree_sites[self->index]; - self->sites_length = self->tree_sequence->tree_sites_length[self->index]; + + *result = norm_kc_vectors(&vecs[0], &vecs[1], lambda); +out: + for (i = 0; i < 2; i++) { + kc_vectors_free(&vecs[i]); } - ret = TSK_TREE_OK; return ret; } -int TSK_WARN_UNUSED -tsk_tree_first(tsk_tree_t *self) +static int +check_kc_distance_tree_sequence_inputs( + const tsk_treeseq_t *self, const tsk_treeseq_t *other) { - int ret = TSK_TREE_OK; - tsk_table_collection_t *tables = self->tree_sequence->tables; + int ret = 0; - self->interval.left = 0; - self->index = 0; - self->interval.right = tables->sequence_length; - self->sites = self->tree_sequence->tree_sites[0]; - self->sites_length = self->tree_sequence->tree_sites_length[0]; - - if (tables->edges.num_rows > 0) { - /* TODO this is redundant if this is the 
first usage of the tree. We - * should add a state machine here so we know what state the tree is - * in and can take the appropriate actions. - */ - ret = tsk_tree_clear(self); - if (ret != 0) { - goto out; - } - self->index = -1; - self->left_index = 0; - self->right_index = 0; - self->direction = TSK_DIR_FORWARD; - self->interval.right = 0; + if (self->tables->sequence_length != other->tables->sequence_length) { + ret = tsk_trace_error(TSK_ERR_SEQUENCE_LENGTH_MISMATCH); + goto out; + } - ret = tsk_tree_advance(self, TSK_DIR_FORWARD, tables->edges.right, - tables->indexes.edge_removal_order, &self->right_index, tables->edges.left, - tables->indexes.edge_insertion_order, &self->left_index); + ret = check_kc_distance_samples_inputs(self, other); + if (ret != 0) { + goto out; } + out: return ret; } -int TSK_WARN_UNUSED -tsk_tree_last(tsk_tree_t *self) +static void +update_kc_pair_with_sample(const tsk_tree_t *self, kc_vectors *kc, tsk_id_t sample, + tsk_size_t *depths, double root_time) { - int ret = TSK_TREE_OK; - const tsk_treeseq_t *ts = self->tree_sequence; - const tsk_table_collection_t *tables = ts->tables; + tsk_id_t c, p, sib; + double time; + tsk_size_t depth; + double *times = self->tree_sequence->tables->nodes.time; - self->interval.left = 0; - self->interval.right = tables->sequence_length; - self->index = 0; - self->sites = ts->tree_sites[0]; - self->sites_length = ts->tree_sites_length[0]; - - if (tables->edges.num_rows > 0) { - /* TODO this is redundant if this is the first usage of the tree. We - * should add a state machine here so we know what state the tree is - * in and can take the appropriate actions. 
- */ - ret = tsk_tree_clear(self); - if (ret != 0) { - goto out; + c = sample; + for (p = self->parent[sample]; p != TSK_NULL; p = self->parent[p]) { + time = root_time - times[p]; + depth = depths[p]; + for (sib = self->left_child[p]; sib != TSK_NULL; sib = self->right_sib[sib]) { + if (sib != c) { + update_kc_vectors_all_pairs(self, kc, sample, sib, depth, time); + } } - self->index = (tsk_id_t) tsk_treeseq_get_num_trees(ts); - self->left_index = (tsk_id_t) tables->edges.num_rows - 1; - self->right_index = (tsk_id_t) tables->edges.num_rows - 1; - self->direction = TSK_DIR_REVERSE; - self->interval.left = tables->sequence_length; - self->interval.right = 0; - - ret = tsk_tree_advance(self, TSK_DIR_REVERSE, tables->edges.left, - tables->indexes.edge_insertion_order, &self->left_index, tables->edges.right, - tables->indexes.edge_removal_order, &self->right_index); + c = p; } -out: - return ret; } -int TSK_WARN_UNUSED -tsk_tree_next(tsk_tree_t *self) +static int +update_kc_subtree_state( + tsk_tree_t *t, kc_vectors *kc, tsk_id_t u, tsk_size_t *depths, double root_time) { + int stack_top; + tsk_id_t v, c; + tsk_id_t *stack = NULL; int ret = 0; - const tsk_treeseq_t *ts = self->tree_sequence; - const tsk_table_collection_t *tables = ts->tables; - tsk_id_t num_trees = (tsk_id_t) tsk_treeseq_get_num_trees(ts); - if (self->index == -1) { - ret = tsk_tree_first(self); - } else if (self->index < num_trees - 1) { - ret = tsk_tree_advance(self, TSK_DIR_FORWARD, tables->edges.right, - tables->indexes.edge_removal_order, &self->right_index, tables->edges.left, - tables->indexes.edge_insertion_order, &self->left_index); - } else { - ret = tsk_tree_clear(self); + stack = tsk_malloc(tsk_tree_get_size_bound(t) * sizeof(*stack)); + if (stack == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; } - return ret; -} -int TSK_WARN_UNUSED -tsk_tree_prev(tsk_tree_t *self) -{ - int ret = 0; - const tsk_table_collection_t *tables = self->tree_sequence->tables; + stack_top = 0; 
+ stack[stack_top] = u; + while (stack_top >= 0) { + v = stack[stack_top]; + stack_top--; - if (self->index == -1) { - ret = tsk_tree_last(self); - } else if (self->index > 0) { - ret = tsk_tree_advance(self, TSK_DIR_REVERSE, tables->edges.left, - tables->indexes.edge_insertion_order, &self->left_index, tables->edges.right, - tables->indexes.edge_removal_order, &self->right_index); - } else { - ret = tsk_tree_clear(self); + if (tsk_tree_is_sample(t, v)) { + update_kc_pair_with_sample(t, kc, v, depths, root_time); + } + for (c = t->left_child[v]; c != TSK_NULL; c = t->right_sib[c]) { + if (depths[c] != 0) { + depths[c] = depths[v] + 1; + stack_top++; + stack[stack_top] = c; + } + } } - return ret; -} -static inline bool -tsk_tree_position_in_interval(const tsk_tree_t *self, double x) -{ - return self->interval.left <= x && x < self->interval.right; +out: + tsk_safe_free(stack); + return ret; } -/* NOTE: - * - * Notes from Kevin Thornton: - * - * This method inserts the edges for an arbitrary tree - * in linear time and requires no additional memory. - * - * During design, the following alternatives were tested - * (in a combination of rust + C): - * 1. Indexing edge insertion/removal locations by tree. - * The indexing can be done in O(n) time, giving O(1) - * access to the first edge in a tree. We can then add - * edges to the tree in O(e) time, where e is the number - * of edges. This apparoach requires O(n) additional memory - * and is only marginally faster than the implementation below. - * 2. Building an interval tree mapping edge id -> span. - * This approach adds a lot of complexity and wasn't any faster - * than the indexing described above. 
- */ static int -tsk_tree_seek_from_null(tsk_tree_t *self, double x, tsk_flags_t TSK_UNUSED(options)) +update_kc_incremental(tsk_tree_t *tree, kc_vectors *kc, tsk_size_t *depths) { int ret = 0; - tsk_size_t edge; - tsk_id_t p, c, e, j, k, tree_index; - const double L = tsk_treeseq_get_sequence_length(self->tree_sequence); - const tsk_treeseq_t *treeseq = self->tree_sequence; - const tsk_table_collection_t *tables = treeseq->tables; - const tsk_id_t *restrict edge_parent = tables->edges.parent; - const tsk_id_t *restrict edge_child = tables->edges.child; - const tsk_size_t num_edges = tables->edges.num_rows; - const tsk_size_t num_trees = self->tree_sequence->num_trees; - const double *restrict edge_left = tables->edges.left; - const double *restrict edge_right = tables->edges.right; - const double *restrict breakpoints = treeseq->breakpoints; - const tsk_id_t *restrict insertion = tables->indexes.edge_insertion_order; - const tsk_id_t *restrict removal = tables->indexes.edge_removal_order; - - // NOTE: it may be better to get the - // index first and then ask if we are - // searching in the first or last 1/2 - // of trees. 
- j = -1; - if (x <= L / 2.0) { - for (edge = 0; edge < num_edges; edge++) { - e = insertion[edge]; - if (edge_left[e] > x) { - j = (tsk_id_t) edge; - break; - } - if (x >= edge_left[e] && x < edge_right[e]) { - p = edge_parent[e]; - c = edge_child[e]; - tsk_tree_insert_edge(self, p, c, e); - } - } - } else { - for (edge = 0; edge < num_edges; edge++) { - e = removal[num_edges - edge - 1]; - if (edge_right[e] < x) { - j = (tsk_id_t)(num_edges - edge - 1); - while (j < (tsk_id_t) num_edges && edge_left[insertion[j]] <= x) { - j++; - } - break; - } - if (x >= edge_left[e] && x < edge_right[e]) { - p = edge_parent[e]; - c = edge_child[e]; - tsk_tree_insert_edge(self, p, c, e); + tsk_id_t u, v, e, j; + double root_time, time; + const double *restrict times = tree->tree_sequence->tables->nodes.time; + const tsk_id_t *restrict edges_child = tree->tree_sequence->tables->edges.child; + const tsk_id_t *restrict edges_parent = tree->tree_sequence->tables->edges.parent; + tsk_tree_position_t tree_pos = tree->tree_pos; + + /* Update state of detached subtrees */ + for (j = tree_pos.out.stop - 1; j >= tree_pos.out.start; j--) { + e = tree_pos.out.order[j]; + u = edges_child[e]; + depths[u] = 0; + + if (tree->parent[u] == TSK_NULL) { + root_time = times[tsk_tree_node_root(tree, u)]; + ret = update_kc_subtree_state(tree, kc, u, depths, root_time); + if (ret != 0) { + goto out; } } } - if (j == -1) { - j = 0; - while (j < (tsk_id_t) num_edges && edge_left[insertion[j]] <= x) { - j++; + /* Propagate state change down into reattached subtrees. 
*/ + for (j = tree_pos.in.stop - 1; j >= tree_pos.in.start; j--) { + e = tree_pos.in.order[j]; + u = edges_child[e]; + v = edges_parent[e]; + + tsk_bug_assert(depths[u] == 0); + depths[u] = depths[v] + 1; + + root_time = times[tsk_tree_node_root(tree, u)]; + ret = update_kc_subtree_state(tree, kc, u, depths, root_time); + if (ret != 0) { + goto out; } - } - k = 0; - while (k < (tsk_id_t) num_edges && edge_right[removal[k]] <= x) { - k++; - } - /* NOTE: tsk_search_sorted finds the first the first - * insertion locatiom >= the query point, which - * finds a RIGHT value for queries not at the left edge. - */ - tree_index = (tsk_id_t) tsk_search_sorted(breakpoints, num_trees + 1, x); - if (breakpoints[tree_index] > x) { - tree_index--; - } - self->index = tree_index; - self->interval.left = breakpoints[tree_index]; - self->interval.right = breakpoints[tree_index + 1]; - self->left_index = j; - self->right_index = k; - self->direction = TSK_DIR_FORWARD; - self->num_nodes = tables->nodes.num_rows; - if (tables->sites.num_rows > 0) { - self->sites = treeseq->tree_sites[self->index]; - self->sites_length = treeseq->tree_sites_length[self->index]; + if (tsk_tree_is_sample(tree, u)) { + time = tsk_tree_get_branch_length_unsafe(tree, u); + update_kc_vectors_single_sample(tree->tree_sequence, kc, u, time); + } } - +out: return ret; } -int TSK_WARN_UNUSED -tsk_tree_seek_index(tsk_tree_t *self, tsk_id_t tree, tsk_flags_t options) +int +tsk_treeseq_kc_distance(const tsk_treeseq_t *self, const tsk_treeseq_t *other, + double lambda_, double *result) { + int i; + tsk_id_t n; + tsk_size_t num_nodes; + double left, span, total; + const tsk_treeseq_t *treeseqs[2] = { self, other }; + tsk_tree_t trees[2]; + kc_vectors kcs[2]; + tsk_size_t *depths[2]; int ret = 0; - double x; - if (tree < 0 || tree >= (tsk_id_t) self->tree_sequence->num_trees) { - ret = TSK_ERR_SEEK_OUT_OF_BOUNDS; + for (i = 0; i < 2; i++) { + tsk_memset(&trees[i], 0, sizeof(trees[i])); + tsk_memset(&kcs[i], 0, 
sizeof(kcs[i])); + depths[i] = NULL; + } + + ret = check_kc_distance_tree_sequence_inputs(self, other); + if (ret != 0) { + goto out; + } + + n = (tsk_id_t) self->num_samples; + for (i = 0; i < 2; i++) { + ret = tsk_tree_init(&trees[i], treeseqs[i], TSK_SAMPLE_LISTS); + if (ret != 0) { + goto out; + } + ret = kc_vectors_alloc(&kcs[i], n); + if (ret != 0) { + goto out; + } + num_nodes = tsk_treeseq_get_num_nodes(treeseqs[i]); + depths[i] = tsk_calloc(num_nodes, sizeof(*depths[i])); + if (depths[i] == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; + } + } + + total = 0; + left = 0; + + ret = tsk_tree_first(&trees[0]); + if (ret != TSK_TREE_OK) { + goto out; + } + ret = check_kc_distance_tree_inputs(&trees[0]); + if (ret != 0) { + goto out; + } + + ret = update_kc_incremental(&trees[0], &kcs[0], depths[0]); + if (ret != 0) { goto out; } - x = self->tree_sequence->breakpoints[tree]; - ret = tsk_tree_seek(self, x, options); -out: - return ret; -} + while ((ret = tsk_tree_next(&trees[1])) == TSK_TREE_OK) { + ret = check_kc_distance_tree_inputs(&trees[1]); + if (ret != 0) { + goto out; + } -static int TSK_WARN_UNUSED -tsk_tree_seek_linear(tsk_tree_t *self, double x, tsk_flags_t TSK_UNUSED(options)) -{ - const double L = tsk_treeseq_get_sequence_length(self->tree_sequence); - const double t_l = self->interval.left; - const double t_r = self->interval.right; - int ret = 0; - double distance_left, distance_right; + ret = update_kc_incremental(&trees[1], &kcs[1], depths[1]); + if (ret != 0) { + goto out; + } + while (trees[0].interval.right < trees[1].interval.right) { + span = trees[0].interval.right - left; + total += norm_kc_vectors(&kcs[0], &kcs[1], lambda_) * span; - if (x < t_l) { - /* |-----|-----|========|---------| */ - /* 0 x t_l t_r L */ - distance_left = t_l - x; - distance_right = L - t_r + x; - } else { - /* |------|========|------|-------| */ - /* 0 t_l t_r x L */ - distance_right = x - t_r; - distance_left = t_l + L - x; - } - if 
(distance_right <= distance_left) { - while (!tsk_tree_position_in_interval(self, x)) { - ret = tsk_tree_next(self); - if (ret < 0) { + left = trees[0].interval.right; + ret = tsk_tree_next(&trees[0]); + tsk_bug_assert(ret == TSK_TREE_OK); + ret = check_kc_distance_tree_inputs(&trees[0]); + if (ret != 0) { goto out; } - } - } else { - while (!tsk_tree_position_in_interval(self, x)) { - ret = tsk_tree_prev(self); - if (ret < 0) { + ret = update_kc_incremental(&trees[0], &kcs[0], depths[0]); + if (ret != 0) { goto out; } } + span = trees[1].interval.right - left; + left = trees[1].interval.right; + total += norm_kc_vectors(&kcs[0], &kcs[1], lambda_) * span; } - ret = 0; + if (ret != 0) { + goto out; + } + + *result = total / self->tables->sequence_length; out: + for (i = 0; i < 2; i++) { + tsk_tree_free(&trees[i]); + kc_vectors_free(&kcs[i]); + tsk_safe_free(depths[i]); + } return ret; } -int TSK_WARN_UNUSED -tsk_tree_seek(tsk_tree_t *self, double x, tsk_flags_t options) +/* + * Divergence matrix + */ + +typedef struct { + /* Note it's a waste storing the triply linked tree here, but the code + * is written on the assumption of 1-based trees and the algorithm is + * frighteningly subtle, so it doesn't seem worth messing with it + * unless we really need to save some memory */ + tsk_id_t *parent; + tsk_id_t *child; + tsk_id_t *sib; + tsk_id_t *lambda; + tsk_id_t *pi; + tsk_id_t *tau; + tsk_id_t *beta; + tsk_id_t *alpha; +} sv_tables_t; + +static int +sv_tables_init(sv_tables_t *self, tsk_size_t n) { int ret = 0; - const double L = tsk_treeseq_get_sequence_length(self->tree_sequence); - if (x < 0 || x >= L) { - ret = TSK_ERR_SEEK_OUT_OF_BOUNDS; + self->parent = tsk_malloc(n * sizeof(*self->parent)); + self->child = tsk_malloc(n * sizeof(*self->child)); + self->sib = tsk_malloc(n * sizeof(*self->sib)); + self->pi = tsk_malloc(n * sizeof(*self->pi)); + self->lambda = tsk_malloc(n * sizeof(*self->lambda)); + self->tau = tsk_malloc(n * sizeof(*self->tau)); + self->beta = 
tsk_malloc(n * sizeof(*self->beta)); + self->alpha = tsk_malloc(n * sizeof(*self->alpha)); + if (self->parent == NULL || self->child == NULL || self->sib == NULL + || self->lambda == NULL || self->tau == NULL || self->beta == NULL + || self->alpha == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } - - if (self->index == -1) { - ret = tsk_tree_seek_from_null(self, x, options); - } else { - ret = tsk_tree_seek_linear(self, x, options); - } - out: return ret; } -int TSK_WARN_UNUSED -tsk_tree_clear(tsk_tree_t *self) +static int +sv_tables_free(sv_tables_t *self) { - int ret = 0; + tsk_safe_free(self->parent); + tsk_safe_free(self->child); + tsk_safe_free(self->sib); + tsk_safe_free(self->lambda); + tsk_safe_free(self->pi); + tsk_safe_free(self->tau); + tsk_safe_free(self->beta); + tsk_safe_free(self->alpha); + return 0; +} +static void +sv_tables_reset(sv_tables_t *self, tsk_tree_t *tree) +{ + const tsk_size_t n = 1 + tree->num_nodes; + tsk_memset(self->parent, 0, n * sizeof(*self->parent)); + tsk_memset(self->child, 0, n * sizeof(*self->child)); + tsk_memset(self->sib, 0, n * sizeof(*self->sib)); + tsk_memset(self->pi, 0, n * sizeof(*self->pi)); + tsk_memset(self->lambda, 0, n * sizeof(*self->lambda)); + tsk_memset(self->tau, 0, n * sizeof(*self->tau)); + tsk_memset(self->beta, 0, n * sizeof(*self->beta)); + tsk_memset(self->alpha, 0, n * sizeof(*self->alpha)); +} + +static void +sv_tables_convert_tree(sv_tables_t *self, tsk_tree_t *tree) +{ + const tsk_size_t n = 1 + tree->num_nodes; + const tsk_id_t *restrict tsk_parent = tree->parent; + tsk_id_t *restrict child = self->child; + tsk_id_t *restrict parent = self->parent; + tsk_id_t *restrict sib = self->sib; tsk_size_t j; - tsk_id_t u; - const tsk_size_t N = self->num_nodes + 1; - const tsk_size_t num_samples = self->tree_sequence->num_samples; - const bool sample_counts = !(self->options & TSK_NO_SAMPLE_COUNTS); - const bool sample_lists = !!(self->options & TSK_SAMPLE_LISTS); - const tsk_flags_t 
*flags = self->tree_sequence->tables->nodes.flags; + tsk_id_t u, v; - self->interval.left = 0; - self->interval.right = 0; - self->num_edges = 0; - self->index = -1; - /* TODO we should profile this method to see if just doing a single loop over - * the nodes would be more efficient than multiple memsets. - */ - tsk_memset(self->parent, 0xff, N * sizeof(*self->parent)); - tsk_memset(self->left_child, 0xff, N * sizeof(*self->left_child)); - tsk_memset(self->right_child, 0xff, N * sizeof(*self->right_child)); - tsk_memset(self->left_sib, 0xff, N * sizeof(*self->left_sib)); - tsk_memset(self->right_sib, 0xff, N * sizeof(*self->right_sib)); - tsk_memset(self->num_children, 0, N * sizeof(*self->num_children)); - tsk_memset(self->edge, 0xff, N * sizeof(*self->edge)); + for (j = 0; j < n - 1; j++) { + u = (tsk_id_t) j + 1; + v = tsk_parent[j] + 1; + sib[u] = child[v]; + child[v] = u; + parent[u] = v; + } +} - if (sample_counts) { - tsk_memset(self->num_samples, 0, N * sizeof(*self->num_samples)); - /* We can't reset the tracked samples via memset because we don't - * know where the tracked samples are. 
- */ - for (j = 0; j < self->num_nodes; j++) { - if (!(flags[j] & TSK_NODE_IS_SAMPLE)) { - self->num_tracked_samples[j] = 0; +#define LAMBDA 0 + +static void +sv_tables_build_index(sv_tables_t *self) +{ + const tsk_id_t *restrict child = self->child; + const tsk_id_t *restrict parent = self->parent; + const tsk_id_t *restrict sib = self->sib; + tsk_id_t *restrict lambda = self->lambda; + tsk_id_t *restrict pi = self->pi; + tsk_id_t *restrict tau = self->tau; + tsk_id_t *restrict beta = self->beta; + tsk_id_t *restrict alpha = self->alpha; + tsk_id_t a, n, p, h; + + p = child[LAMBDA]; + n = 0; + lambda[0] = -1; + while (p != LAMBDA) { + while (true) { + n++; + pi[p] = n; + tau[n] = LAMBDA; + lambda[n] = 1 + lambda[n >> 1]; + if (child[p] != LAMBDA) { + p = child[p]; + } else { + break; } } - /* The total tracked_samples gets set in set_tracked_samples */ - self->num_samples[self->virtual_root] = num_samples; - } - if (sample_lists) { - tsk_memset(self->left_sample, 0xff, N * sizeof(tsk_id_t)); - tsk_memset(self->right_sample, 0xff, N * sizeof(tsk_id_t)); - tsk_memset(self->next_sample, 0xff, num_samples * sizeof(tsk_id_t)); - } - /* Set the sample attributes */ - for (j = 0; j < num_samples; j++) { - u = self->samples[j]; - if (sample_counts) { - self->num_samples[u] = 1; - } - if (sample_lists) { - /* We are mapping to *indexes* into the list of samples here */ - self->left_sample[u] = (tsk_id_t) j; - self->right_sample[u] = (tsk_id_t) j; + beta[p] = n; + while (true) { + tau[beta[p]] = parent[p]; + if (sib[p] != LAMBDA) { + p = sib[p]; + break; + } else { + p = parent[p]; + if (p != LAMBDA) { + h = lambda[n & -pi[p]]; + beta[p] = ((n >> h) | 1) << h; + } else { + break; + } + } } } - if (sample_counts && self->root_threshold == 1 && num_samples > 0) { - for (j = 0; j < num_samples; j++) { - /* Set initial roots */ - if (self->root_threshold == 1) { - tsk_tree_insert_root(self, self->samples[j], self->parent); + + /* Begin the second traversal */ + lambda[0] = 
lambda[n]; + pi[LAMBDA] = 0; + beta[LAMBDA] = 0; + alpha[LAMBDA] = 0; + p = child[LAMBDA]; + while (p != LAMBDA) { + while (true) { + a = alpha[parent[p]] | (beta[p] & -beta[p]); + alpha[p] = a; + if (child[p] != LAMBDA) { + p = child[p]; + } else { + break; + } + } + while (true) { + if (sib[p] != LAMBDA) { + p = sib[p]; + break; + } else { + p = parent[p]; + if (p == LAMBDA) { + break; + } } } } - return ret; } -tsk_size_t -tsk_tree_get_size_bound(const tsk_tree_t *self) +static void +sv_tables_build(sv_tables_t *self, tsk_tree_t *tree) { - tsk_size_t bound = 0; - - if (self->tree_sequence != NULL) { - /* This is a safe upper bound which can be computed cheaply. - * We have at most n roots and each edge adds at most one new - * node to the tree. We also allow space for the virtual root, - * to simplify client code. - * - * In the common case of a binary tree with a single root, we have - * 2n - 1 nodes in total, and 2n - 2 edges. Therefore, we return - * 3n - 1, which is an over-estimate of 1/2 and we allocate - * 1.5 times as much memory as we need. - * - * Since tracking the exact number of nodes in the tree would require - * storing the number of nodes beneath every node and complicate - * the tree transition method, this seems like a good compromise - * and will result in less memory usage overall in nearly all cases. 
- */ - bound = 1 + self->tree_sequence->num_samples + self->num_edges; - } - return bound; + sv_tables_reset(self, tree); + sv_tables_convert_tree(self, tree); + sv_tables_build_index(self); } -/* Traversal orders */ -static tsk_id_t * -tsk_tree_alloc_node_stack(const tsk_tree_t *self) +static tsk_id_t +sv_tables_mrca_one_based(const sv_tables_t *self, tsk_id_t x, tsk_id_t y) { - return tsk_malloc(tsk_tree_get_size_bound(self) * sizeof(tsk_id_t)); + const tsk_id_t *restrict lambda = self->lambda; + const tsk_id_t *restrict pi = self->pi; + const tsk_id_t *restrict tau = self->tau; + const tsk_id_t *restrict beta = self->beta; + const tsk_id_t *restrict alpha = self->alpha; + tsk_id_t h, k, xhat, yhat, ell, j, z; + + if (beta[x] <= beta[y]) { + h = lambda[beta[y] & -beta[x]]; + } else { + h = lambda[beta[x] & -beta[y]]; + } + k = alpha[x] & alpha[y] & -(1 << h); + h = lambda[k & -k]; + j = ((beta[x] >> h) | 1) << h; + if (j == beta[x]) { + xhat = x; + } else { + ell = lambda[alpha[x] & ((1 << h) - 1)]; + xhat = tau[((beta[x] >> ell) | 1) << ell]; + } + if (j == beta[y]) { + yhat = y; + } else { + ell = lambda[alpha[y] & ((1 << h) - 1)]; + yhat = tau[((beta[y] >> ell) | 1) << ell]; + } + if (pi[xhat] <= pi[yhat]) { + z = xhat; + } else { + z = yhat; + } + return z; } -int -tsk_tree_preorder(const tsk_tree_t *self, tsk_id_t *nodes, tsk_size_t *num_nodes_ret) +static tsk_id_t +sv_tables_mrca(const sv_tables_t *self, tsk_id_t x, tsk_id_t y) { - return tsk_tree_preorder_from(self, -1, nodes, num_nodes_ret); + /* Convert to 1-based indexes and back */ + return sv_tables_mrca_one_based(self, x + 1, y + 1) - 1; } -int -tsk_tree_preorder_from( - const tsk_tree_t *self, tsk_id_t root, tsk_id_t *nodes, tsk_size_t *num_nodes_ret) +static int +tsk_treeseq_divergence_matrix_branch(const tsk_treeseq_t *self, + tsk_size_t num_sample_sets, const tsk_size_t *restrict sample_set_sizes, + const tsk_id_t *restrict sample_sets, tsk_size_t num_windows, + const double *restrict windows, 
tsk_flags_t options, double *restrict result) { int ret = 0; - const tsk_id_t *restrict right_child = self->right_child; - const tsk_id_t *restrict left_sib = self->left_sib; - tsk_id_t *stack = tsk_tree_alloc_node_stack(self); - tsk_size_t num_nodes = 0; - tsk_id_t u, v; - int stack_top; - - if (stack == NULL) { - ret = TSK_ERR_NO_MEMORY; + tsk_tree_t tree; + const double *restrict nodes_time = self->tables->nodes.time; + const tsk_size_t N = num_sample_sets; + tsk_size_t i, j, k, offset, sj, sk; + tsk_id_t u, v, w, u_root, v_root; + double tu, tv, d, span, left, right, span_left, span_right; + double *restrict D; + sv_tables_t sv; + tsk_size_t *ss_offsets = tsk_malloc((num_sample_sets + 1) * sizeof(*ss_offsets)); + + memset(&sv, 0, sizeof(sv)); + ret = tsk_tree_init(&tree, self, 0); + if (ret != 0) { goto out; } - - if ((root == -1 || root == self->virtual_root) - && !tsk_tree_has_sample_counts(self)) { - ret = TSK_ERR_UNSUPPORTED_OPERATION; + ret = sv_tables_init(&sv, self->tables->nodes.num_rows + 1); + if (ret != 0) { goto out; } - if (root == -1) { - stack_top = -1; - for (u = right_child[self->virtual_root]; u != TSK_NULL; u = left_sib[u]) { - stack_top++; - stack[stack_top] = u; - } - } else { - ret = tsk_tree_check_node(self, root); + if (ss_offsets == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; + } + if (self->time_uncalibrated && !(options & TSK_STAT_ALLOW_TIME_UNCALIBRATED)) { + ret = tsk_trace_error(TSK_ERR_TIME_UNCALIBRATED); + goto out; + } + + ss_offsets[0] = 0; + offset = 0; + for (j = 0; j < N; j++) { + offset += sample_set_sizes[j]; + ss_offsets[j + 1] = offset; + } + + for (i = 0; i < num_windows; i++) { + left = windows[i]; + right = windows[i + 1]; + D = result + i * N * N; + ret = tsk_tree_seek(&tree, left, 0); if (ret != 0) { goto out; } - stack_top = 0; - stack[stack_top] = root; - } - - while (stack_top >= 0) { - u = stack[stack_top]; - stack_top--; - nodes[num_nodes] = u; - num_nodes++; - for (v = right_child[u]; v != 
TSK_NULL; v = left_sib[v]) { - stack_top++; - stack[stack_top] = v; + while (tree.interval.left < right && tree.index != -1) { + span_left = TSK_MAX(tree.interval.left, left); + span_right = TSK_MIN(tree.interval.right, right); + span = span_right - span_left; + sv_tables_build(&sv, &tree); + for (sj = 0; sj < N; sj++) { + for (j = ss_offsets[sj]; j < ss_offsets[sj + 1]; j++) { + u = sample_sets[j]; + for (sk = sj; sk < N; sk++) { + for (k = ss_offsets[sk]; k < ss_offsets[sk + 1]; k++) { + v = sample_sets[k]; + if (u == v) { + /* This case contributes zero to divergence, so + * short-circuit to save time. + * TODO is there a better way to do this? */ + continue; + } + w = sv_tables_mrca(&sv, u, v); + if (w != TSK_NULL) { + u_root = w; + v_root = w; + } else { + /* Slow path - only happens for nodes in disconnected + * subtrees in a tree with multiple roots */ + u_root = tsk_tree_get_node_root(&tree, u); + v_root = tsk_tree_get_node_root(&tree, v); + } + tu = nodes_time[u_root] - nodes_time[u]; + tv = nodes_time[v_root] - nodes_time[v]; + d = (tu + tv) * span; + D[sj * N + sk] += d; + } + } + } + } + ret = tsk_tree_next(&tree); + if (ret < 0) { + goto out; + } } } - *num_nodes_ret = num_nodes; + ret = 0; out: - tsk_safe_free(stack); + tsk_tree_free(&tree); + sv_tables_free(&sv); + tsk_safe_free(ss_offsets); return ret; } -/* We could implement this using the preorder function, but since it's - * going to be performance critical we want to avoid the overhead - * of mallocing the intermediate node list (which will be bigger than - * the number of samples). 
*/ -int -tsk_tree_preorder_samples_from( - const tsk_tree_t *self, tsk_id_t root, tsk_id_t *nodes, tsk_size_t *num_nodes_ret) -{ - int ret = 0; - const tsk_id_t *restrict right_child = self->right_child; - const tsk_id_t *restrict left_sib = self->left_sib; - const tsk_flags_t *restrict flags = self->tree_sequence->tables->nodes.flags; - tsk_id_t *stack = tsk_tree_alloc_node_stack(self); - tsk_size_t num_nodes = 0; - tsk_id_t u, v; - int stack_top; +// FIXME see #2817 +// Just including this here for now as it's the simplest option. Everything +// will probably move to stats.[c,h] in the near future though, and it +// can pull in ``genotypes.h`` without issues. +#include - if (stack == NULL) { - ret = TSK_ERR_NO_MEMORY; - goto out; - } +static void +update_site_divergence(const tsk_variant_t *var, const tsk_id_t *restrict A, + const tsk_size_t *restrict offsets, const tsk_size_t num_sample_sets, double *D) - /* We could push the virtual_root onto the stack directly to simplify - * the code a little, but then we'd have to check put an extra check - * when looking up the flags array (which isn't defined for virtual_root). 
- */ - if (root == -1 || root == self->virtual_root) { - if (!tsk_tree_has_sample_counts(self)) { - ret = TSK_ERR_UNSUPPORTED_OPERATION; - goto out; - } - stack_top = -1; - for (u = right_child[self->virtual_root]; u != TSK_NULL; u = left_sib[u]) { - stack_top++; - stack[stack_top] = u; - } - } else { - ret = tsk_tree_check_node(self, root); - if (ret != 0) { - goto out; +{ + const tsk_size_t num_alleles = var->num_alleles; + tsk_size_t a, b, j, k; + tsk_id_t u, v; + double increment; + + for (a = 0; a < num_alleles; a++) { + for (b = a + 1; b < num_alleles; b++) { + for (j = offsets[a]; j < offsets[a + 1]; j++) { + for (k = offsets[b]; k < offsets[b + 1]; k++) { + u = A[j]; + v = A[k]; + /* Only increment the upper triangle to (hopefully) improve memory + * access patterns */ + if (u > v) { + u = A[k]; + v = A[j]; + } + increment = 1; + if (u == v) { + increment = 2; + } + D[u * (tsk_id_t) num_sample_sets + v] += increment; + } + } } - stack_top = 0; - stack[stack_top] = root; } +} - while (stack_top >= 0) { - u = stack[stack_top]; - stack_top--; - if (flags[u] & TSK_NODE_IS_SAMPLE) { - nodes[num_nodes] = u; - num_nodes++; - } - for (v = right_child[u]; v != TSK_NULL; v = left_sib[v]) { - stack_top++; - stack[stack_top] = v; +static void +group_alleles(const tsk_variant_t *var, tsk_id_t *restrict A, tsk_size_t *offsets) +{ + const tsk_size_t n = var->num_samples; + const int32_t *restrict genotypes = var->genotypes; + tsk_id_t a; + tsk_size_t j, k; + + k = 0; + offsets[0] = 0; + for (a = 0; a < (tsk_id_t) var->num_alleles; a++) { + offsets[a + 1] = offsets[a]; + for (j = 0; j < n; j++) { + if (genotypes[j] == a) { + offsets[a + 1]++; + A[k] = (tsk_id_t) j; + k++; + } } } - *num_nodes_ret = num_nodes; -out: - tsk_safe_free(stack); - return ret; } -int -tsk_tree_postorder(const tsk_tree_t *self, tsk_id_t *nodes, tsk_size_t *num_nodes_ret) +static void +remap_to_sample_sets(const tsk_size_t num_samples, const tsk_id_t *restrict samples, + const tsk_id_t *restrict 
sample_set_index_map, tsk_id_t *restrict A) { - return tsk_tree_postorder_from(self, -1, nodes, num_nodes_ret); + tsk_size_t j; + tsk_id_t u; + for (j = 0; j < num_samples; j++) { + u = samples[A[j]]; + tsk_bug_assert(u >= 0); + tsk_bug_assert(sample_set_index_map[u] >= 0); + A[j] = sample_set_index_map[u]; + } } -int -tsk_tree_postorder_from( - const tsk_tree_t *self, tsk_id_t root, tsk_id_t *nodes, tsk_size_t *num_nodes_ret) + +static int +tsk_treeseq_divergence_matrix_site(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_id_t *restrict sample_set_index_map, const tsk_size_t num_samples, + const tsk_id_t *restrict samples, tsk_size_t num_windows, + const double *restrict windows, tsk_flags_t TSK_UNUSED(options), + double *restrict result) { int ret = 0; - const tsk_id_t *restrict right_child = self->right_child; - const tsk_id_t *restrict left_sib = self->left_sib; - const tsk_id_t *restrict parent = self->parent; - tsk_id_t *stack = tsk_tree_alloc_node_stack(self); - tsk_size_t num_nodes = 0; - tsk_id_t u, v, postorder_parent; - int stack_top; - bool is_virtual_root = root == self->virtual_root; - - if (stack == NULL) { - ret = TSK_ERR_NO_MEMORY; + tsk_size_t i; + tsk_id_t site_id; + double left, right; + double *restrict D; + const tsk_id_t num_sites = (tsk_id_t) self->tables->sites.num_rows; + const double *restrict sites_position = self->tables->sites.position; + tsk_id_t *A = tsk_malloc(num_samples * sizeof(*A)); + /* Allocate the allele offsets at the first variant */ + tsk_size_t max_alleles = 0; + tsk_size_t *allele_offsets = NULL; + tsk_variant_t variant; + + /* FIXME it's not clear that using TSK_ISOLATED_NOT_MISSING is + * correct here */ + ret = tsk_variant_init( + &variant, self, samples, num_samples, NULL, TSK_ISOLATED_NOT_MISSING); + if (ret != 0) { + goto out; + } + if (A == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } - if (root == -1 || is_virtual_root) { - if (!tsk_tree_has_sample_counts(self)) { - ret = 
TSK_ERR_UNSUPPORTED_OPERATION; - goto out; - } - stack_top = -1; - for (u = right_child[self->virtual_root]; u != TSK_NULL; u = left_sib[u]) { - stack_top++; - stack[stack_top] = u; - } - } else { - ret = tsk_tree_check_node(self, root); - if (ret != 0) { - goto out; - } - stack_top = 0; - stack[stack_top] = root; + site_id = 0; + while (site_id < num_sites && sites_position[site_id] < windows[0]) { + site_id++; } - postorder_parent = TSK_NULL; - while (stack_top >= 0) { - u = stack[stack_top]; - if (right_child[u] != TSK_NULL && u != postorder_parent) { - for (v = right_child[u]; v != TSK_NULL; v = left_sib[v]) { - stack_top++; - stack[stack_top] = v; + for (i = 0; i < num_windows; i++) { + left = windows[i]; + right = windows[i + 1]; + D = result + i * num_sample_sets * num_sample_sets; + + if (site_id < num_sites) { + tsk_bug_assert(sites_position[site_id] >= left); + } + while (site_id < num_sites && sites_position[site_id] < right) { + ret = tsk_variant_decode(&variant, site_id, 0); + if (ret != 0) { + goto out; } - } else { - stack_top--; - postorder_parent = parent[u]; - nodes[num_nodes] = u; - num_nodes++; + if (variant.num_alleles > max_alleles) { + /* could do some kind of doubling here, but there's no + * point - just keep it simple for testing. 
*/ + max_alleles = variant.num_alleles; + tsk_safe_free(allele_offsets); + allele_offsets = tsk_malloc((max_alleles + 1) * sizeof(*allele_offsets)); + if (allele_offsets == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; + } + } + group_alleles(&variant, A, allele_offsets); + remap_to_sample_sets(num_samples, samples, sample_set_index_map, A); + update_site_divergence(&variant, A, allele_offsets, num_sample_sets, D); + site_id++; } } - if (is_virtual_root) { - nodes[num_nodes] = root; - num_nodes++; - } - *num_nodes_ret = num_nodes; + ret = 0; out: - tsk_safe_free(stack); + tsk_variant_free(&variant); + tsk_safe_free(A); + tsk_safe_free(allele_offsets); return ret; } -/* Balance/imbalance metrics */ - -/* Result is a tsk_size_t value here because we could imagine the total - * depth overflowing a 32bit integer for a large tree. */ -int -tsk_tree_sackin_index(const tsk_tree_t *self, tsk_size_t *result) +/* Return the mapping from node IDs to the index of the sample set + * they belong to, or -1 of none. Error if a node is in more than one + * set. + */ +static int +get_sample_set_index_map(const tsk_treeseq_t *self, const tsk_size_t num_sample_sets, + const tsk_size_t *restrict sample_set_sizes, const tsk_id_t *restrict sample_sets, + tsk_size_t *ret_total_samples, tsk_id_t *restrict node_index_map) { - /* Keep the size of the stack elements to 8 bytes in total in the - * standard case. A tsk_id_t depth value is always safe, since - * depth counts the number of nodes encountered on a path. 
- */ - struct stack_elem { - tsk_id_t node; - tsk_id_t depth; - }; int ret = 0; - const tsk_id_t *restrict right_child = self->right_child; - const tsk_id_t *restrict left_sib = self->left_sib; - struct stack_elem *stack - = tsk_malloc(tsk_tree_get_size_bound(self) * sizeof(*stack)); - int stack_top; - tsk_size_t total_depth; + tsk_size_t i, j, k; tsk_id_t u; - struct stack_elem s = { .node = TSK_NULL, .depth = 0 }; - - if (stack == NULL) { - ret = TSK_ERR_NO_MEMORY; - goto out; - } + tsk_size_t total_samples = 0; + const tsk_size_t num_nodes = self->tables->nodes.num_rows; + const tsk_flags_t *restrict node_flags = self->tables->nodes.flags; - stack_top = -1; - for (u = right_child[self->virtual_root]; u != TSK_NULL; u = left_sib[u]) { - stack_top++; - s.node = u; - stack[stack_top] = s; + for (j = 0; j < num_nodes; j++) { + node_index_map[j] = TSK_NULL; } - total_depth = 0; - while (stack_top >= 0) { - s = stack[stack_top]; - stack_top--; - u = right_child[s.node]; - if (u == TSK_NULL) { - total_depth += (tsk_size_t) s.depth; - } else { - s.depth++; - while (u != TSK_NULL) { - stack_top++; - s.node = u; - stack[stack_top] = s; - u = left_sib[u]; + i = 0; + for (j = 0; j < num_sample_sets; j++) { + total_samples += sample_set_sizes[j]; + for (k = 0; k < sample_set_sizes[j]; k++) { + u = sample_sets[i]; + i++; + if (u < 0 || u >= (tsk_id_t) num_nodes) { + ret = tsk_trace_error(TSK_ERR_NODE_OUT_OF_BOUNDS); + goto out; + } + /* Note: we require nodes to be samples because we have to think + * about how to normalise by the length of genome that the node + * is 'in' the tree for each window otherwise. 
*/ + if (!(node_flags[u] & TSK_NODE_IS_SAMPLE)) { + ret = tsk_trace_error(TSK_ERR_BAD_SAMPLES); + goto out; + } + if (node_index_map[u] != TSK_NULL) { + ret = tsk_trace_error(TSK_ERR_DUPLICATE_SAMPLE); + goto out; } + node_index_map[u] = (tsk_id_t) j; } } - *result = total_depth; + *ret_total_samples = total_samples; out: - tsk_safe_free(stack); return ret; } +static void +fill_lower_triangle_count_normalise(const tsk_size_t num_windows, const tsk_size_t n, + const tsk_size_t *set_sizes, double *restrict result) +{ + tsk_size_t i, j, k; + double denom; + double *restrict D; + + /* TODO there's probably a better striding pattern that could be used here */ + for (i = 0; i < num_windows; i++) { + D = result + i * n * n; + for (j = 0; j < n; j++) { + denom = (double) set_sizes[j] * (double) (set_sizes[j] - 1); + if (denom != 0) { + D[j * n + j] /= denom; + } + for (k = j + 1; k < n; k++) { + denom = (double) set_sizes[j] * (double) set_sizes[k]; + D[j * n + k] /= denom; + D[k * n + j] = D[j * n + k]; + } + } + } +} + int -tsk_tree_colless_index(const tsk_tree_t *self, tsk_size_t *result) +tsk_treeseq_divergence_matrix(const tsk_treeseq_t *self, tsk_size_t num_sample_sets_in, + const tsk_size_t *sample_set_sizes_in, const tsk_id_t *sample_sets_in, + tsk_size_t num_windows, const double *windows, tsk_flags_t options, double *result) { int ret = 0; - const tsk_id_t *restrict right_child = self->right_child; - const tsk_id_t *restrict left_sib = self->left_sib; - tsk_id_t *nodes = tsk_malloc(tsk_tree_get_size_bound(self) * sizeof(*nodes)); - tsk_id_t *num_leaves = tsk_calloc(self->num_nodes, sizeof(*num_leaves)); - tsk_size_t j, num_nodes, total; - tsk_id_t num_children, u, v; + tsk_size_t N, total_samples; + const tsk_size_t *sample_set_sizes; + const tsk_id_t *sample_sets; + tsk_size_t *tmp_sample_set_sizes = NULL; + const double default_windows[] = { 0, self->tables->sequence_length }; + const tsk_size_t num_nodes = self->tables->nodes.num_rows; + bool stat_site = 
!!(options & TSK_STAT_SITE); + bool stat_branch = !!(options & TSK_STAT_BRANCH); + bool stat_node = !!(options & TSK_STAT_NODE); + tsk_id_t *sample_set_index_map + = tsk_malloc(num_nodes * sizeof(*sample_set_index_map)); + tsk_size_t j; - if (nodes == NULL || num_leaves == NULL) { - ret = TSK_ERR_NO_MEMORY; + if (stat_node) { + ret = tsk_trace_error(TSK_ERR_UNSUPPORTED_STAT_MODE); goto out; } - if (tsk_tree_get_num_roots(self) != 1) { - ret = TSK_ERR_UNDEFINED_MULTIROOT; + /* If no mode is specified, we default to site mode */ + if (!(stat_site || stat_branch)) { + stat_site = true; + } + /* It's an error to specify more than one mode */ + if (stat_site + stat_branch > 1) { + ret = tsk_trace_error(TSK_ERR_MULTIPLE_STAT_MODES); goto out; } - ret = tsk_tree_postorder(self, nodes, &num_nodes); - if (ret != 0) { + + if (options & TSK_STAT_POLARISED) { + ret = tsk_trace_error(TSK_ERR_STAT_POLARISED_UNSUPPORTED); goto out; } - total = 0; - for (j = 0; j < num_nodes; j++) { - u = nodes[j]; - /* Cheaper to compute this on the fly than to access the num_children array. - * since we're already iterating over the children. 
*/ - num_children = 0; - for (v = right_child[u]; v != TSK_NULL; v = left_sib[v]) { - num_children++; - num_leaves[u] += num_leaves[v]; + if (windows == NULL) { + num_windows = 1; + windows = default_windows; + } else { + ret = tsk_treeseq_check_windows(self, num_windows, windows, 0); + if (ret != 0) { + goto out; + } + } + + /* If sample_sets is NULL, use self->samples and ignore input + * num_sample_sets */ + sample_sets = sample_sets_in; + N = num_sample_sets_in; + if (sample_sets_in == NULL) { + sample_sets = self->samples; + if (sample_set_sizes_in == NULL) { + N = self->num_samples; } - if (num_children == 0) { - num_leaves[u] = 1; - } else if (num_children == 2) { - v = right_child[u]; - total += (tsk_size_t) llabs(num_leaves[v] - num_leaves[left_sib[v]]); - } else { - ret = TSK_ERR_UNDEFINED_NONBINARY; + } + sample_set_sizes = sample_set_sizes_in; + /* If sample_set_sizes is NULL, assume its N 1S */ + if (sample_set_sizes_in == NULL) { + tmp_sample_set_sizes = tsk_malloc(N * sizeof(*tmp_sample_set_sizes)); + if (tmp_sample_set_sizes == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } + for (j = 0; j < N; j++) { + tmp_sample_set_sizes[j] = 1; + } + sample_set_sizes = tmp_sample_set_sizes; } - *result = total; -out: - tsk_safe_free(nodes); - tsk_safe_free(num_leaves); - return ret; -} - -int -tsk_tree_b1_index(const tsk_tree_t *self, double *result) -{ - int ret = 0; - const tsk_id_t *restrict parent = self->parent; - const tsk_id_t *restrict right_child = self->right_child; - const tsk_id_t *restrict left_sib = self->left_sib; - tsk_id_t *nodes = tsk_malloc(tsk_tree_get_size_bound(self) * sizeof(*nodes)); - tsk_size_t *max_path_length = tsk_calloc(self->num_nodes, sizeof(*max_path_length)); - tsk_size_t j, num_nodes, mpl; - double total = 0.0; - tsk_id_t u, v; - if (nodes == NULL || max_path_length == NULL) { - ret = TSK_ERR_NO_MEMORY; + ret = get_sample_set_index_map( + self, N, sample_set_sizes, sample_sets, &total_samples, 
sample_set_index_map); + if (ret != 0) { goto out; } - ret = tsk_tree_postorder(self, nodes, &num_nodes); + + tsk_memset(result, 0, num_windows * N * N * sizeof(*result)); + + if (stat_branch) { + ret = tsk_treeseq_divergence_matrix_branch(self, N, sample_set_sizes, + sample_sets, num_windows, windows, options, result); + } else { + tsk_bug_assert(stat_site); + ret = tsk_treeseq_divergence_matrix_site(self, N, sample_set_index_map, + total_samples, sample_sets, num_windows, windows, options, result); + } if (ret != 0) { goto out; } + fill_lower_triangle_count_normalise(num_windows, N, sample_set_sizes, result); - for (j = 0; j < num_nodes; j++) { - u = nodes[j]; - if (parent[u] != TSK_NULL && right_child[u] != TSK_NULL) { - mpl = 0; - for (v = right_child[u]; v != TSK_NULL; v = left_sib[v]) { - mpl = TSK_MAX(mpl, max_path_length[v]); - } - max_path_length[u] = mpl + 1; - total += 1 / (double) max_path_length[u]; - } + if (options & TSK_STAT_SPAN_NORMALISE) { + span_normalise(num_windows, windows, N * N, result); } - *result = total; out: - tsk_safe_free(nodes); - tsk_safe_free(max_path_length); + tsk_safe_free(sample_set_index_map); + tsk_safe_free(tmp_sample_set_sizes); return ret; } -static double -general_log(double x, double base) -{ - return log(x) / log(base); +/* ======================================================== * + * Extend haplotypes + * ======================================================== */ + +typedef struct _edge_list_t { + tsk_id_t edge; + // the `extended` flags records whether we have decided to extend + // this entry to the current tree? 
+ int extended; + struct _edge_list_t *next; +} edge_list_t; + +static void +edge_list_print(edge_list_t **head, tsk_edge_table_t *edges, FILE *out) +{ + int n = 0; + edge_list_t *px; + fprintf(out, "Edge list:\n"); + for (px = *head; px != NULL; px = px->next) { + fprintf(out, " %d: %d (%d); ", n, (int) px->edge, px->extended); + if (px->edge >= 0 && edges != NULL) { + fprintf(out, "%d->%d on [%.1f, %.1f)", (int) edges->child[px->edge], + (int) edges->parent[px->edge], edges->left[px->edge], + edges->right[px->edge]); + } else { + fprintf(out, "(null)"); + } + fprintf(out, "\n"); + n += 1; + } + fprintf(out, "length = %d\n", n); } -int -tsk_tree_b2_index(const tsk_tree_t *self, double base, double *result) +static void +edge_list_append_entry( + edge_list_t **head, edge_list_t **tail, edge_list_t *x, tsk_id_t edge, int extended) { - struct stack_elem { - tsk_id_t node; - double path_product; - }; - int ret = 0; - const tsk_id_t *restrict right_child = self->right_child; - const tsk_id_t *restrict left_sib = self->left_sib; - struct stack_elem *stack - = tsk_malloc(tsk_tree_get_size_bound(self) * sizeof(*stack)); - int stack_top; - double total_proba = 0; - double num_children; - tsk_id_t u; - struct stack_elem s = { .node = TSK_NULL, .path_product = 1 }; + x->edge = edge; + x->extended = extended; + x->next = NULL; - if (stack == NULL) { - ret = TSK_ERR_NO_MEMORY; - goto out; - } - if (tsk_tree_get_num_roots(self) != 1) { - ret = TSK_ERR_UNDEFINED_MULTIROOT; - goto out; + if (*tail == NULL) { + *head = x; + } else { + (*tail)->next = x; } + *tail = x; +} - stack_top = 0; - s.node = tsk_tree_get_left_root(self); - stack[stack_top] = s; - - while (stack_top >= 0) { - s = stack[stack_top]; - stack_top--; - u = right_child[s.node]; - if (u == TSK_NULL) { - total_proba -= s.path_product * general_log(s.path_product, base); - } else { - num_children = 0; - for (; u != TSK_NULL; u = left_sib[u]) { - num_children++; - } - s.path_product *= 1 / num_children; - for (u = 
right_child[s.node]; u != TSK_NULL; u = left_sib[u]) { - stack_top++; - s.node = u; - stack[stack_top] = s; +static void +remove_unextended(edge_list_t **head, edge_list_t **tail) +{ + edge_list_t *px, *x; + + px = *head; + while (px != NULL && px->extended == 0) { + px = px->next; + } + *head = px; + if (px != NULL) { + px->extended = 0; + x = px->next; + while (x != NULL) { + if (x->extended > 0) { + x->extended = 0; + px->next = x; + px = x; } + x = x->next; } + px->next = NULL; } - *result = total_proba; -out: - tsk_safe_free(stack); - return ret; + *tail = px; } -int -tsk_tree_num_lineages(const tsk_tree_t *self, double t, tsk_size_t *result) +static void +edge_list_set_extended(edge_list_t **head, tsk_id_t edge_id) +{ + // finds the entry with edge 'edge_id' + // and sets its 'extended' flag to 1 + edge_list_t *px; + px = *head; + tsk_bug_assert(px != NULL); + while (px->edge != edge_id) { + px = px->next; + tsk_bug_assert(px != NULL); + } + tsk_bug_assert(px->edge == edge_id); + px->extended = 1; +} + +static int +tsk_treeseq_slide_mutation_nodes_up( + const tsk_treeseq_t *self, tsk_mutation_table_t *mutations) { int ret = 0; - const tsk_id_t *restrict right_child = self->right_child; - const tsk_id_t *restrict left_sib = self->left_sib; - const double *restrict time = self->tree_sequence->tables->nodes.time; - tsk_id_t *stack = tsk_tree_alloc_node_stack(self); - tsk_size_t num_lineages = 0; - int stack_top; - tsk_id_t u, v; - double child_time, parent_time; + double t; + tsk_id_t c, p, next_mut; + const tsk_size_t num_nodes = self->tables->nodes.num_rows; + const double *sites_position = self->tables->sites.position; + const double *nodes_time = self->tables->nodes.time; + tsk_tree_t tree; - if (stack == NULL) { - ret = TSK_ERR_NO_MEMORY; - goto out; - } - if (!tsk_isfinite(t)) { - ret = TSK_ERR_TIME_NONFINITE; + ret = tsk_tree_init(&tree, self, TSK_NO_SAMPLE_COUNTS); + if (ret != 0) { goto out; } - /* Push the roots onto the stack */ - stack_top = -1; - 
for (u = right_child[self->virtual_root]; u != TSK_NULL; u = left_sib[u]) { - stack_top++; - stack[stack_top] = u; - } - while (stack_top >= 0) { - u = stack[stack_top]; - parent_time = time[u]; - stack_top--; - for (v = right_child[u]; v != TSK_NULL; v = left_sib[v]) { - child_time = time[v]; - /* Only traverse down the tree as far as we need to */ - if (child_time > t) { - stack_top++; - stack[stack_top] = v; - } else if (t < parent_time) { - num_lineages++; + next_mut = 0; + for (ret = tsk_tree_first(&tree); ret == TSK_TREE_OK; ret = tsk_tree_next(&tree)) { + while (next_mut < (tsk_id_t) mutations->num_rows + && sites_position[mutations->site[next_mut]] < tree.interval.right) { + t = mutations->time[next_mut]; + if (tsk_is_unknown_time(t)) { + ret = tsk_trace_error(TSK_ERR_DISALLOWED_UNKNOWN_MUTATION_TIME); + goto out; + } + c = mutations->node[next_mut]; + tsk_bug_assert(c < (tsk_id_t) num_nodes); + p = tree.parent[c]; + while (p != TSK_NULL && nodes_time[p] <= t) { + c = p; + p = tree.parent[c]; } + tsk_bug_assert(nodes_time[c] <= t); + mutations->node[next_mut] = c; + next_mut++; } } - *result = num_lineages; + if (ret != 0) { + goto out; + } + out: - tsk_safe_free(stack); + tsk_tree_free(&tree); + return ret; } -/* Parsimony methods */ +typedef struct { + const tsk_treeseq_t *ts; + tsk_edge_table_t *edges; + int direction; + tsk_id_t *last_degree, *next_degree; + tsk_id_t *last_nodes_edge, *next_nodes_edge; + tsk_id_t *parent_out, *parent_in; + bool *not_sample; + double *near_side, *far_side; + edge_list_t *edges_out_head, *edges_out_tail; + edge_list_t *edges_in_head, *edges_in_tail; + tsk_blkalloc_t edge_list_heap; +} haplotype_extender_t; -static inline uint64_t -set_bit(uint64_t value, int32_t bit) +static int +haplotype_extender_init(haplotype_extender_t *self, const tsk_treeseq_t *ts, + int direction, tsk_edge_table_t *edges) { - return value | (1ULL << bit); -} + int ret = 0; + tsk_id_t tj; + tsk_size_t num_nodes = tsk_treeseq_get_num_nodes(ts); 
-static inline bool -bit_is_set(uint64_t value, int32_t bit) -{ - return (value & (1ULL << bit)) != 0; -} + tsk_memset(self, 0, sizeof(haplotype_extender_t)); -static inline int8_t -get_smallest_set_bit(uint64_t v) -{ - /* This is an inefficient implementation, there are several better - * approaches. On GCC we can use - * return (uint8_t) (__builtin_ffsll((long long) v) - 1); - */ - uint64_t t = 1; - int8_t r = 0; + self->ts = ts; + self->edges = edges; + ret = tsk_edge_table_copy(&ts->tables->edges, self->edges, TSK_NO_INIT); + if (ret != 0) { + goto out; + } - assert(v != 0); - while ((v & t) == 0) { - t <<= 1; - r++; + self->direction = direction; + if (direction == TSK_DIR_FORWARD) { + self->near_side = self->edges->left; + self->far_side = self->edges->right; + } else { + self->near_side = self->edges->right; + self->far_side = self->edges->left; } - return r; -} -#define HARTIGAN_MAX_ALLELES 64 + self->edges_in_head = NULL; + self->edges_in_tail = NULL; + self->edges_out_head = NULL; + self->edges_out_tail = NULL; -/* This interface is experimental. In the future, we should provide the option to - * use a general cost matrix, in which case we'll use the Sankoff algorithm. For - * now this is unused. - * - * We should also vectorise the function so that several sites can be processed - * at once. - * - * The algorithm used here is Hartigan parsimony, "Minimum Mutation Fits to a - * Given Tree", Biometrics 1973. 
- */ -int TSK_WARN_UNUSED -tsk_tree_map_mutations(tsk_tree_t *self, int32_t *genotypes, - double *TSK_UNUSED(cost_matrix), tsk_flags_t options, int32_t *r_ancestral_state, - tsk_size_t *r_num_transitions, tsk_state_transition_t **r_transitions) -{ - int ret = 0; - struct stack_elem { - tsk_id_t node; - tsk_id_t transition_parent; - int32_t state; - }; - const tsk_size_t num_samples = self->tree_sequence->num_samples; - const tsk_id_t *restrict left_child = self->left_child; - const tsk_id_t *restrict right_sib = self->right_sib; - const tsk_size_t N = tsk_treeseq_get_num_nodes(self->tree_sequence); - const tsk_flags_t *restrict node_flags = self->tree_sequence->tables->nodes.flags; - tsk_id_t *nodes = tsk_malloc(tsk_tree_get_size_bound(self) * sizeof(*nodes)); - /* Note: to use less memory here and to improve cache performance we should - * probably change to allocating exactly the number of nodes returned by - * a preorder traversal, and then lay the memory out in this order. So, we'd - * need a map from node ID to its index in the preorder traversal, but this - * is trivial to compute. Probably doesn't matter so much at the moment - * when we're doing a single site, but it would make a big difference if - * we were vectorising over lots of sites. 
*/ - uint64_t *restrict optimal_set = tsk_calloc(N + 1, sizeof(*optimal_set)); - struct stack_elem *restrict preorder_stack - = tsk_malloc(tsk_tree_get_size_bound(self) * sizeof(*preorder_stack)); - tsk_id_t u, v; - /* The largest possible number of transitions is one over every sample */ - tsk_state_transition_t *transitions = tsk_malloc(num_samples * sizeof(*transitions)); - int32_t allele, ancestral_state; - int stack_top; - struct stack_elem s; - tsk_size_t j, num_transitions, max_allele_count, num_nodes; - tsk_size_t allele_count[HARTIGAN_MAX_ALLELES]; - tsk_size_t non_missing = 0; - int32_t num_alleles = 0; + ret = tsk_blkalloc_init(&self->edge_list_heap, 8192); + if (ret != 0) { + goto out; + } - if (optimal_set == NULL || preorder_stack == NULL || transitions == NULL - || nodes == NULL) { - ret = TSK_ERR_NO_MEMORY; + self->last_degree = tsk_calloc(num_nodes, sizeof(*self->last_degree)); + self->next_degree = tsk_calloc(num_nodes, sizeof(*self->next_degree)); + self->last_nodes_edge = tsk_malloc(num_nodes * sizeof(*self->last_nodes_edge)); + self->next_nodes_edge = tsk_malloc(num_nodes * sizeof(*self->next_nodes_edge)); + self->parent_out = tsk_malloc(num_nodes * sizeof(*self->parent_out)); + self->parent_in = tsk_malloc(num_nodes * sizeof(*self->parent_in)); + self->not_sample = tsk_malloc(num_nodes * sizeof(*self->not_sample)); + + if (self->last_degree == NULL || self->next_degree == NULL + || self->last_nodes_edge == NULL || self->next_nodes_edge == NULL + || self->parent_out == NULL || self->parent_in == NULL + || self->not_sample == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } - for (j = 0; j < num_samples; j++) { - if (genotypes[j] >= HARTIGAN_MAX_ALLELES || genotypes[j] < TSK_MISSING_DATA) { - ret = TSK_ERR_BAD_GENOTYPE; - goto out; + tsk_memset(self->last_nodes_edge, 0xff, num_nodes * sizeof(*self->last_nodes_edge)); + tsk_memset(self->next_nodes_edge, 0xff, num_nodes * sizeof(*self->next_nodes_edge)); + 
tsk_memset(self->parent_out, 0xff, num_nodes * sizeof(*self->parent_out)); + tsk_memset(self->parent_in, 0xff, num_nodes * sizeof(*self->parent_in)); + + for (tj = 0; tj < (tsk_id_t) num_nodes; tj++) { + self->not_sample[tj] = ((ts->tables->nodes.flags[tj] & TSK_NODE_IS_SAMPLE) == 0); + } + +out: + return ret; +} + +static void +haplotype_extender_print_state(haplotype_extender_t *self, FILE *out) +{ + fprintf(out, "\n======= haplotype extender ===========\n"); + fprintf(out, "parent in:\n"); + for (int j = 0; j < (int) self->ts->tables->nodes.num_rows; j++) { + fprintf(out, " %d: %d\n", j, (int) self->parent_in[j]); + } + fprintf(out, "parent out:\n"); + for (int j = 0; j < (int) self->ts->tables->nodes.num_rows; j++) { + fprintf(out, " %d: %d\n", j, (int) self->parent_out[j]); + } + fprintf(out, "last nodes edge:\n"); + for (int j = 0; j < (int) self->ts->tables->nodes.num_rows; j++) { + tsk_id_t ej = self->last_nodes_edge[j]; + fprintf(out, " %d: %d, ", j, (int) ej); + if (self->last_nodes_edge[j] != TSK_NULL) { + fprintf(out, "(%d->%d, %.1f-%.1f)", (int) self->edges->child[ej], + (int) self->edges->parent[ej], self->edges->left[ej], + self->edges->right[ej]); + } else { + fprintf(out, "(null);"); } - u = self->tree_sequence->samples[j]; - if (genotypes[j] == TSK_MISSING_DATA) { - /* All bits set */ - optimal_set[u] = UINT64_MAX; + fprintf(out, "\n"); + } + fprintf(out, "next nodes edge:\n"); + for (int j = 0; j < (int) self->ts->tables->nodes.num_rows; j++) { + tsk_id_t ej = self->next_nodes_edge[j]; + fprintf(out, " %d: %d, ", j, (int) ej); + if (self->next_nodes_edge[j] != TSK_NULL) { + fprintf(out, "(%d->%d, %.1f-%.1f)", (int) self->edges->child[ej], + (int) self->edges->parent[ej], self->edges->left[ej], + self->edges->right[ej]); } else { - optimal_set[u] = set_bit(optimal_set[u], genotypes[j]); - num_alleles = TSK_MAX(genotypes[j], num_alleles); - non_missing++; + fprintf(out, "(null);"); } + fprintf(out, "\n"); } + fprintf(out, "edges out:\n"); + 
edge_list_print(&self->edges_out_head, self->edges, out); + fprintf(out, "edges in:\n"); + edge_list_print(&self->edges_in_head, self->edges, out); +} - if (non_missing == 0) { - ret = TSK_ERR_GENOTYPES_ALL_MISSING; - goto out; +static int +haplotype_extender_free(haplotype_extender_t *self) +{ + tsk_blkalloc_free(&self->edge_list_heap); + tsk_safe_free(self->last_degree); + tsk_safe_free(self->next_degree); + tsk_safe_free(self->last_nodes_edge); + tsk_safe_free(self->next_nodes_edge); + tsk_safe_free(self->parent_out); + tsk_safe_free(self->parent_in); + tsk_safe_free(self->not_sample); + return 0; +} + +static int +haplotype_extender_next_tree(haplotype_extender_t *self, tsk_tree_position_t *tree_pos) +{ + int ret = 0; + tsk_id_t tj, e; + edge_list_t *ex_out, *ex_in; + edge_list_t *new_ex; + const tsk_id_t *edges_child = self->edges->child; + const tsk_id_t *edges_parent = self->edges->parent; + + for (ex_out = self->edges_out_head; ex_out != NULL; ex_out = ex_out->next) { + e = ex_out->edge; + self->parent_out[edges_child[e]] = TSK_NULL; + // note we only adjust near_side of edges_in, not edges_out, + // so no need to check for zero-length edges + if (ex_out->extended > 1) { + // this is needed to catch newly-created edges + self->last_nodes_edge[edges_child[e]] = e; + self->last_degree[edges_child[e]] += 1; + self->last_degree[edges_parent[e]] += 1; + } else if (ex_out->extended == 0) { + self->last_nodes_edge[edges_child[e]] = TSK_NULL; + self->last_degree[edges_child[e]] -= 1; + self->last_degree[edges_parent[e]] -= 1; + } + } + remove_unextended(&self->edges_out_head, &self->edges_out_tail); + for (ex_in = self->edges_in_head; ex_in != NULL; ex_in = ex_in->next) { + e = ex_in->edge; + self->parent_in[edges_child[e]] = TSK_NULL; + if (ex_in->extended == 0 && self->near_side[e] != self->far_side[e]) { + self->last_nodes_edge[edges_child[e]] = e; + self->last_degree[edges_child[e]] += 1; + self->last_degree[edges_parent[e]] += 1; + } + } + 
remove_unextended(&self->edges_in_head, &self->edges_in_tail); + + // done cleanup from last tree transition; + // now we set the state up for this tree transition + for (tj = tree_pos->out.start; tj != tree_pos->out.stop; tj += self->direction) { + e = tree_pos->out.order[tj]; + if (self->near_side[e] != self->far_side[e]) { + new_ex = tsk_blkalloc_get(&self->edge_list_heap, sizeof(*new_ex)); + if (new_ex == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; + } + edge_list_append_entry( + &self->edges_out_head, &self->edges_out_tail, new_ex, e, 0); + } + } + for (ex_out = self->edges_out_head; ex_out != NULL; ex_out = ex_out->next) { + e = ex_out->edge; + self->parent_out[edges_child[e]] = edges_parent[e]; + self->next_nodes_edge[edges_child[e]] = TSK_NULL; + self->next_degree[edges_child[e]] -= 1; + self->next_degree[edges_parent[e]] -= 1; } - num_alleles++; - ancestral_state = 0; /* keep compiler happy */ - if (options & TSK_MM_FIXED_ANCESTRAL_STATE) { - ancestral_state = *r_ancestral_state; - if ((ancestral_state < 0) || (ancestral_state >= HARTIGAN_MAX_ALLELES)) { - ret = TSK_ERR_BAD_ANCESTRAL_STATE; + for (tj = tree_pos->in.start; tj != tree_pos->in.stop; tj += self->direction) { + e = tree_pos->in.order[tj]; + // add edge to pending_in + new_ex = tsk_blkalloc_get(&self->edge_list_heap, sizeof(*new_ex)); + if (new_ex == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; - } else if (ancestral_state >= num_alleles) { - num_alleles = (int32_t)(ancestral_state + 1); } + edge_list_append_entry(&self->edges_in_head, &self->edges_in_tail, new_ex, e, 0); + } + for (ex_in = self->edges_in_head; ex_in != NULL; ex_in = ex_in->next) { + e = ex_in->edge; + self->parent_in[edges_child[e]] = edges_parent[e]; + self->next_nodes_edge[edges_child[e]] = e; + self->next_degree[edges_child[e]] += 1; + self->next_degree[edges_parent[e]] += 1; } - ret = tsk_tree_postorder_from(self, self->virtual_root, nodes, &num_nodes); - if (ret != 0) { - goto out; 
+out: + return ret; +} + +static int +haplotype_extender_add_or_extend_edge(haplotype_extender_t *self, tsk_id_t new_parent, + tsk_id_t child, double left, double right) +{ + int ret = 0; + double there; + tsk_id_t old_edge, e_out, old_parent; + edge_list_t *ex_in; + edge_list_t *new_ex = NULL; + tsk_id_t e_in; + + there = (self->direction == TSK_DIR_FORWARD) ? right : left; + old_edge = self->next_nodes_edge[child]; + if (old_edge != TSK_NULL) { + old_parent = self->edges->parent[old_edge]; + } else { + old_parent = TSK_NULL; + } + if (new_parent != old_parent) { + if (self->parent_out[child] == new_parent) { + // if our new edge is in edges_out, it should be extended + e_out = self->last_nodes_edge[child]; + self->far_side[e_out] = there; + edge_list_set_extended(&self->edges_out_head, e_out); + } else { + e_out = tsk_edge_table_add_row( + self->edges, left, right, new_parent, child, NULL, 0); + if (e_out < 0) { + ret = (int) e_out; + goto out; + } + /* pointers to left/right might have changed! 
*/ + if (self->direction == TSK_DIR_FORWARD) { + self->near_side = self->edges->left; + self->far_side = self->edges->right; + } else { + self->near_side = self->edges->right; + self->far_side = self->edges->left; + } + new_ex = tsk_blkalloc_get(&self->edge_list_heap, sizeof(*new_ex)); + if (new_ex == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; + } + edge_list_append_entry( + &self->edges_out_head, &self->edges_out_tail, new_ex, e_out, 2); + } + self->next_nodes_edge[child] = e_out; + self->next_degree[child] += 1; + self->next_degree[new_parent] += 1; + self->parent_out[child] = TSK_NULL; + if (old_edge != TSK_NULL) { + for (ex_in = self->edges_in_head; ex_in != NULL; ex_in = ex_in->next) { + e_in = ex_in->edge; + if (e_in == old_edge) { + self->near_side[e_in] = there; + if (self->far_side[e_in] != there) { + ex_in->extended = 1; + } + self->next_degree[child] -= 1; + self->next_degree[self->parent_in[child]] -= 1; + self->parent_in[child] = TSK_NULL; + } + } + } } - for (j = 0; j < num_nodes; j++) { - u = nodes[j]; - tsk_memset(allele_count, 0, ((size_t) num_alleles) * sizeof(*allele_count)); - for (v = left_child[u]; v != TSK_NULL; v = right_sib[v]) { - for (allele = 0; allele < num_alleles; allele++) { - allele_count[allele] += bit_is_set(optimal_set[v], allele); +out: + return ret; +} + +static float +haplotype_extender_mergeable(haplotype_extender_t *self, tsk_id_t c) +{ + // returns the number of new edges needed + // if the paths in parent_in and parent_out + // up through nodes that aren't in the other tree + // end at the same place and don't have conflicting times; + // otherwise, return infinity + tsk_id_t p_in, p_out, child; + float num_new_edges; // needs to be float so we can have infinity + int num_extended; + double t_in, t_out; + bool climb_in, climb_out; + const double *nodes_time = self->ts->tables->nodes.time; + + p_out = self->parent_out[c]; + p_in = self->parent_in[c]; + t_out = (p_out == TSK_NULL) ? 
INFINITY : nodes_time[p_out]; + t_in = (p_in == TSK_NULL) ? INFINITY : nodes_time[p_in]; + child = c; + num_new_edges = 0; + num_extended = 0; + while (true) { + climb_in = (p_in != TSK_NULL && self->last_degree[p_in] == 0 + && self->not_sample[p_in] && t_in < t_out); + climb_out = (p_out != TSK_NULL && self->next_degree[p_out] == 0 + && self->not_sample[p_out] && t_out < t_in); + if (climb_in) { + if (self->parent_in[child] != p_in) { + num_new_edges += 1; } + child = p_in; + p_in = self->parent_in[p_in]; + t_in = (p_in == TSK_NULL) ? INFINITY : nodes_time[p_in]; + } else if (climb_out) { + if (self->parent_out[child] != p_out) { + num_new_edges += 1; + } + child = p_out; + p_out = self->parent_out[p_out]; + t_out = (p_out == TSK_NULL) ? INFINITY : nodes_time[p_out]; + num_extended += 1; + } else { + break; } - /* the virtual root has no flags defined */ - if (u == (tsk_id_t) N || !(node_flags[u] & TSK_NODE_IS_SAMPLE)) { - max_allele_count = 0; - for (allele = 0; allele < num_alleles; allele++) { - max_allele_count = TSK_MAX(max_allele_count, allele_count[allele]); + } + if ((num_extended == 0) || (p_in != p_out) || (p_in == TSK_NULL)) { + num_new_edges = INFINITY; + } + return num_new_edges; +} + +static int +haplotype_extender_merge_paths( + haplotype_extender_t *self, tsk_id_t c, double left, double right) +{ + int ret = 0; + tsk_id_t p_in, p_out, child; + double t_in, t_out; + bool climb_in, climb_out; + const double *nodes_time = self->ts->tables->nodes.time; + + p_out = self->parent_out[c]; + p_in = self->parent_in[c]; + t_out = nodes_time[p_out]; + t_in = nodes_time[p_in]; + child = c; + while (true) { + climb_in = (p_in != TSK_NULL && self->last_degree[p_in] == 0 + && self->not_sample[p_in] && t_in < t_out); + climb_out = (p_out != TSK_NULL && self->next_degree[p_out] == 0 + && self->not_sample[p_out] && t_out < t_in); + if (climb_in) { + ret = haplotype_extender_add_or_extend_edge(self, p_in, child, left, right); + if (ret != 0) { + goto out; } - for 
(allele = 0; allele < num_alleles; allele++) { - if (allele_count[allele] == max_allele_count) { - optimal_set[u] = set_bit(optimal_set[u], allele); - } + child = p_in; + p_in = self->parent_in[p_in]; + t_in = (p_in == TSK_NULL) ? INFINITY : nodes_time[p_in]; + } else if (climb_out) { + ret = haplotype_extender_add_or_extend_edge(self, p_out, child, left, right); + if (ret != 0) { + goto out; } + child = p_out; + p_out = self->parent_out[p_out]; + t_out = (p_out == TSK_NULL) ? INFINITY : nodes_time[p_out]; + } else { + break; } } - if (!(options & TSK_MM_FIXED_ANCESTRAL_STATE)) { - ancestral_state = get_smallest_set_bit(optimal_set[self->virtual_root]); - } else { - optimal_set[self->virtual_root] = UINT64_MAX; + tsk_bug_assert(p_out == p_in); + ret = haplotype_extender_add_or_extend_edge(self, p_out, child, left, right); + if (ret != 0) { + goto out; } +out: + return ret; +} - num_transitions = 0; +static int +haplotype_extender_extend_paths(haplotype_extender_t *self) +{ + int ret = 0; + bool valid; + double left, right; + float ne, max_new_edges, next_max_new_edges; + tsk_tree_position_t tree_pos; + edge_list_t *ex_in; + tsk_id_t e_in, c, e; + tsk_size_t num_edges; + tsk_bool_t *keep = NULL; + + tsk_memset(&tree_pos, 0, sizeof(tree_pos)); + ret = tsk_tree_position_init(&tree_pos, self->ts, 0); + if (ret != 0) { + goto out; + } - /* Do a preorder traversal */ - preorder_stack[0].node = self->virtual_root; - preorder_stack[0].state = ancestral_state; - preorder_stack[0].transition_parent = TSK_NULL; - stack_top = 0; - while (stack_top >= 0) { - s = preorder_stack[stack_top]; - stack_top--; + if (self->direction == TSK_DIR_FORWARD) { + valid = tsk_tree_position_next(&tree_pos); + } else { + valid = tsk_tree_position_prev(&tree_pos); + } - if (!bit_is_set(optimal_set[s.node], s.state)) { - s.state = get_smallest_set_bit(optimal_set[s.node]); - transitions[num_transitions].node = s.node; - transitions[num_transitions].parent = s.transition_parent; - 
transitions[num_transitions].state = s.state; - s.transition_parent = (tsk_id_t) num_transitions; - num_transitions++; + while (valid) { + left = tree_pos.interval.left; + right = tree_pos.interval.right; + ret = haplotype_extender_next_tree(self, &tree_pos); + if (ret != 0) { + goto out; } - for (v = left_child[s.node]; v != TSK_NULL; v = right_sib[v]) { - stack_top++; - s.node = v; - preorder_stack[stack_top] = s; + max_new_edges = 0; + next_max_new_edges = INFINITY; + while (max_new_edges < INFINITY) { + for (ex_in = self->edges_in_head; ex_in != NULL; ex_in = ex_in->next) { + e_in = ex_in->edge; + c = self->edges->child[e_in]; + if (self->last_degree[c] > 0) { + ne = haplotype_extender_mergeable(self, c); + if (ne <= max_new_edges) { + ret = haplotype_extender_merge_paths(self, c, left, right); + if (ret != 0) { + goto out; + } + } else { + next_max_new_edges = TSK_MIN(ne, next_max_new_edges); + } + } + } + max_new_edges = next_max_new_edges; + next_max_new_edges = INFINITY; + } + if (self->direction == TSK_DIR_FORWARD) { + valid = tsk_tree_position_next(&tree_pos); + } else { + valid = tsk_tree_position_prev(&tree_pos); } } - - *r_transitions = transitions; - *r_num_transitions = num_transitions; - *r_ancestral_state = ancestral_state; - transitions = NULL; -out: - tsk_safe_free(transitions); - /* Cannot safe_free because of 'restrict' */ - if (optimal_set != NULL) { - free(optimal_set); + /* Get rid of adjacent, identical edges */ + /* note: we need to calloc this here instead of at the start + * because we don't know how big it will need to be until now */ + num_edges = self->edges->num_rows; + keep = tsk_calloc(num_edges, sizeof(*keep)); + if (keep == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; } - if (preorder_stack != NULL) { - free(preorder_stack); + for (e = 0; e < (tsk_id_t) num_edges - 1; e++) { + if (self->edges->parent[e] == self->edges->parent[e + 1] + && self->edges->child[e] == self->edges->child[e + 1] + && 
self->edges->right[e] == self->edges->left[e + 1]) { + self->edges->right[e] = self->edges->right[e + 1]; + self->edges->left[e + 1] = self->edges->right[e + 1]; + } } - if (nodes != NULL) { - free(nodes); + for (e = 0; e < (tsk_id_t) num_edges; e++) { + keep[e] = self->edges->left[e] < self->edges->right[e]; } + ret = tsk_edge_table_keep_rows(self->edges, keep, 0, NULL); +out: + tsk_tree_position_free(&tree_pos); + tsk_safe_free(keep); return ret; } -/* Compatibility shim for initialising the diff iterator from a tree sequence. We are - * using this function in a small number of places internally, so simplest to keep it - * until a more satisfactory "diff" API comes along. - */ -int TSK_WARN_UNUSED -tsk_diff_iter_init_from_ts( - tsk_diff_iter_t *self, const tsk_treeseq_t *tree_sequence, tsk_flags_t options) +static int +extend_haplotypes_iter(const tsk_treeseq_t *self, int direction, tsk_edge_table_t *edges, + tsk_flags_t options) { - return tsk_diff_iter_init( - self, tree_sequence->tables, (tsk_id_t) tree_sequence->num_trees, options); -} + int ret = 0; + haplotype_extender_t haplotype_extender; + tsk_memset(&haplotype_extender, 0, sizeof(haplotype_extender)); + ret = haplotype_extender_init(&haplotype_extender, self, direction, edges); + if (ret != 0) { + goto out; + } -/* ======================================================== * - * KC Distance - * ======================================================== */ + ret = haplotype_extender_extend_paths(&haplotype_extender); + if (ret != 0) { + goto out; + } -typedef struct { - tsk_size_t *m; - double *M; - tsk_id_t n; - tsk_id_t N; -} kc_vectors; + if (!!(options & TSK_DEBUG)) { + haplotype_extender_print_state(&haplotype_extender, tsk_get_debug_stream()); + } -static int -kc_vectors_alloc(kc_vectors *self, tsk_id_t n) +out: + haplotype_extender_free(&haplotype_extender); + return ret; +} + +int TSK_WARN_UNUSED +tsk_treeseq_extend_haplotypes( + const tsk_treeseq_t *self, int max_iter, tsk_flags_t options, 
tsk_treeseq_t *output) { int ret = 0; + tsk_table_collection_t tables; + tsk_treeseq_t ts; + int iter, j; + tsk_size_t last_num_edges; + tsk_bookmark_t sort_start; + const int direction[] = { TSK_DIR_FORWARD, TSK_DIR_REVERSE }; - self->n = n; - self->N = (n * (n - 1)) / 2; - self->m = tsk_calloc((size_t)(self->N + self->n), sizeof(*self->m)); - self->M = tsk_calloc((size_t)(self->N + self->n), sizeof(*self->M)); - if (self->m == NULL || self->M == NULL) { - ret = TSK_ERR_NO_MEMORY; + tsk_memset(&tables, 0, sizeof(tables)); + tsk_memset(&ts, 0, sizeof(ts)); + tsk_memset(output, 0, sizeof(*output)); + + if (max_iter <= 0) { + ret = tsk_trace_error(TSK_ERR_EXTEND_EDGES_BAD_MAXITER); + goto out; + } + if (tsk_treeseq_get_num_migrations(self) != 0) { + ret = tsk_trace_error(TSK_ERR_MIGRATIONS_NOT_SUPPORTED); + goto out; + } + + /* Note: there is a fair bit of copying of table data in this implementation + * currently, as we create a new tree sequence for each iteration, which + * takes a full copy of the input tables. We could streamline this by + * adding a flag to treeseq_init which says "steal a reference to these + * tables and *don't* free them at the end". Then, we would only need + * one copy of the full tables, and could pass in a standalone edge + * table to use for in-place updating. 
+ */ + ret = tsk_table_collection_copy(self->tables, &tables, 0); + if (ret != 0) { + goto out; + } + ret = tsk_mutation_table_clear(&tables.mutations); + if (ret != 0) { + goto out; + } + ret = tsk_treeseq_init(&ts, &tables, 0); + if (ret != 0) { + goto out; + } + + last_num_edges = tsk_treeseq_get_num_edges(&ts); + for (iter = 0; iter < max_iter; iter++) { + for (j = 0; j < 2; j++) { + ret = extend_haplotypes_iter(&ts, direction[j], &tables.edges, options); + if (ret != 0) { + goto out; + } + /* We're done with the current ts now */ + tsk_treeseq_free(&ts); + /* no need to sort sites and mutations */ + memset(&sort_start, 0, sizeof(sort_start)); + sort_start.sites = tables.sites.num_rows; + sort_start.mutations = tables.mutations.num_rows; + ret = tsk_table_collection_sort(&tables, &sort_start, 0); + if (ret != 0) { + goto out; + } + ret = tsk_treeseq_init(&ts, &tables, TSK_TS_INIT_BUILD_INDEXES); + if (ret != 0) { + goto out; + } + } + if (last_num_edges == tsk_treeseq_get_num_edges(&ts)) { + break; + } + last_num_edges = tsk_treeseq_get_num_edges(&ts); + } + + /* Remap mutation nodes */ + ret = tsk_mutation_table_copy( + &self->tables->mutations, &tables.mutations, TSK_NO_INIT); + if (ret != 0) { + goto out; + } + /* Note: to allow migrations we'd also have to do this same operation + * on the migration nodes; however it's a can of worms because the interval + * covering the migration might no longer make sense. 
*/ + ret = tsk_treeseq_slide_mutation_nodes_up(&ts, &tables.mutations); + if (ret != 0) { + goto out; + } + tsk_treeseq_free(&ts); + ret = tsk_treeseq_init(&ts, &tables, TSK_TS_INIT_BUILD_INDEXES); + if (ret != 0) { goto out; } + /* Hand ownership of the tree sequence to the calling code */ + tsk_memcpy(output, &ts, sizeof(ts)); + tsk_memset(&ts, 0, sizeof(*output)); out: + tsk_treeseq_free(&ts); + tsk_table_collection_free(&tables); return ret; } -static void -kc_vectors_free(kc_vectors *self) +/* ======================================================== * + * Pair coalescence + * ======================================================== */ + +static int +check_node_bin_map( + const tsk_size_t num_nodes, const tsk_size_t num_bins, const tsk_id_t *node_bin_map) { - tsk_safe_free(self->m); - tsk_safe_free(self->M); + int ret = 0; + tsk_id_t max_index, index; + tsk_size_t i; + + max_index = TSK_NULL; + for (i = 0; i < num_nodes; i++) { + index = node_bin_map[i]; + if (index < TSK_NULL) { + ret = tsk_trace_error(TSK_ERR_BAD_NODE_BIN_MAP); + goto out; + } + if (index > max_index) { + max_index = index; + } + } + if (num_bins < 1 || (tsk_id_t) num_bins < max_index + 1) { + ret = tsk_trace_error(TSK_ERR_BAD_NODE_BIN_MAP_DIM); + goto out; + } +out: + return ret; } static inline void -update_kc_vectors_single_sample( - const tsk_treeseq_t *ts, kc_vectors *kc_vecs, tsk_id_t u, double time) +TRANSPOSE_2D(tsk_size_t rows, tsk_size_t cols, const double *source, double *dest) { - const tsk_id_t *sample_index_map = ts->sample_index_map; - tsk_id_t u_index = sample_index_map[u]; - - kc_vecs->m[kc_vecs->N + u_index] = 1; - kc_vecs->M[kc_vecs->N + u_index] = time; + tsk_size_t i, j; + for (i = 0; i < rows; ++i) { + for (j = 0; j < cols; ++j) { + dest[j * rows + i] = source[i * cols + j]; + } + } } static inline void -update_kc_vectors_all_pairs(const tsk_tree_t *tree, kc_vectors *kc_vecs, tsk_id_t u, - tsk_id_t v, tsk_size_t depth, double time) +pair_coalescence_count(tsk_size_t 
num_set_indexes, const tsk_id_t *set_indexes, + tsk_size_t num_sample_sets, const double *parent_count, const double *child_count, + const double *parent_state, const double *inside, double *outside, double *result) { - tsk_id_t sample1_index, sample2_index, n1, n2, tmp, pair_index; - const tsk_id_t *restrict left_sample = tree->left_sample; - const tsk_id_t *restrict right_sample = tree->right_sample; - const tsk_id_t *restrict next_sample = tree->next_sample; - tsk_size_t *restrict kc_m = kc_vecs->m; - double *restrict kc_M = kc_vecs->M; - - sample1_index = left_sample[u]; - while (sample1_index != TSK_NULL) { - sample2_index = left_sample[v]; - while (sample2_index != TSK_NULL) { - n1 = sample1_index; - n2 = sample2_index; - if (n1 > n2) { - tmp = n1; - n1 = n2; - n2 = tmp; - } - - /* We spend ~40% of our time here because these accesses - * are not in order and gets very poor cache behavior */ - pair_index = n2 - n1 - 1 + (-1 * n1 * (n1 - 2 * kc_vecs->n + 1)) / 2; - kc_m[pair_index] = depth; - kc_M[pair_index] = time; - - if (sample2_index == right_sample[v]) { - break; - } - sample2_index = next_sample[sample2_index]; - } - if (sample1_index == right_sample[u]) { - break; + tsk_size_t i; + tsk_id_t j, k; + for (i = 0; i < num_sample_sets; i++) { + outside[i] = parent_count[i] - child_count[i] - parent_state[i]; + } + for (i = 0; i < num_set_indexes; i++) { + j = set_indexes[2 * i]; + k = set_indexes[2 * i + 1]; + result[i] = outside[j] * inside[k]; + if (j != k) { + result[i] += outside[k] * inside[j]; } - sample1_index = next_sample[sample1_index]; } } -struct kc_stack_elmt { - tsk_id_t node; - tsk_size_t depth; -}; - -static int -fill_kc_vectors(const tsk_tree_t *t, kc_vectors *kc_vecs) +int +tsk_treeseq_pair_coalescence_stat(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, + tsk_size_t num_set_indexes, const tsk_id_t *set_indexes, tsk_size_t num_windows, + const double *windows, 
tsk_size_t num_bins, const tsk_id_t *node_bin_map, + pair_coalescence_stat_func_t *summary_func, tsk_size_t summary_func_dim, + void *summary_func_args, tsk_flags_t options, double *result) { - int stack_top; - tsk_size_t depth; - double time; - const double *times; - struct kc_stack_elmt *stack; - tsk_id_t root, u, c1, c2; int ret = 0; - const tsk_treeseq_t *ts = t->tree_sequence; + double left, right, remaining_span, missing_span, window_span, denominator, x, t; + tsk_id_t e, p, c, u, v, w, i, j; + tsk_size_t num_samples, num_edges; + tsk_tree_position_t tree_pos; + const tsk_table_collection_t *tables = self->tables; + const tsk_size_t num_nodes = tables->nodes.num_rows; + const double *restrict nodes_time = self->tables->nodes.time; + const double sequence_length = tables->sequence_length; + const tsk_size_t num_outputs = summary_func_dim; + + /* buffers */ + bool *visited = NULL; + tsk_id_t *nodes_sample_set = NULL; + tsk_id_t *nodes_parent = NULL; + double *coalescing_pairs = NULL; + double *coalescence_time = NULL; + double *nodes_sample = NULL; + double *sample_count = NULL; + double *bin_weight = NULL; + double *bin_values = NULL; + double *pair_count = NULL; + double *total_pair = NULL; + double *outside = NULL; + + /* row pointers */ + double *inside = NULL; + double *weight = NULL; + double *values = NULL; + double *output = NULL; + double *above = NULL; + double *below = NULL; + double *state = NULL; + double *pairs = NULL; + double *times = NULL; + + tsk_memset(&tree_pos, 0, sizeof(tree_pos)); + + /* check inputs */ + ret = tsk_treeseq_check_windows(self, num_windows, windows, TSK_REQUIRE_FULL_SPAN); + if (ret != 0) { + goto out; + } + ret = check_set_indexes(num_sample_sets, 2 * num_set_indexes, set_indexes); + if (ret != 0) { + goto out; + } + ret = tsk_treeseq_check_sample_sets( + self, num_sample_sets, sample_set_sizes, sample_sets); + if (ret != 0) { + goto out; + } + ret = check_node_bin_map(num_nodes, num_bins, node_bin_map); + if (ret != 0) { 
+ goto out; + } - stack = tsk_malloc(tsk_tree_get_size_bound(t) * sizeof(*stack)); - if (stack == NULL) { - ret = TSK_ERR_NO_MEMORY; + /* map nodes to sample sets */ + nodes_sample_set = tsk_malloc(num_nodes * sizeof(*nodes_sample_set)); + if (nodes_sample_set == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; + } + ret = get_sample_set_index_map(self, num_sample_sets, sample_set_sizes, sample_sets, + &num_samples, nodes_sample_set); + if (ret != 0) { goto out; } - times = t->tree_sequence->tables->nodes.time; + visited = tsk_malloc(num_nodes * sizeof(*visited)); + outside = tsk_malloc(num_sample_sets * sizeof(*outside)); + nodes_parent = tsk_malloc(num_nodes * sizeof(*nodes_parent)); + nodes_sample = tsk_calloc(num_nodes * num_sample_sets, sizeof(*nodes_sample)); + sample_count = tsk_malloc(num_nodes * num_sample_sets * sizeof(*sample_count)); + coalescing_pairs = tsk_calloc(num_bins * num_set_indexes, sizeof(*coalescing_pairs)); + coalescence_time = tsk_calloc(num_bins * num_set_indexes, sizeof(*coalescence_time)); + bin_weight = tsk_malloc(num_bins * num_set_indexes * sizeof(*bin_weight)); + bin_values = tsk_malloc(num_bins * num_set_indexes * sizeof(*bin_values)); + pair_count = tsk_malloc(num_set_indexes * sizeof(*pair_count)); + total_pair = tsk_malloc(num_set_indexes * sizeof(*total_pair)); + if (nodes_parent == NULL || nodes_sample == NULL || sample_count == NULL + || coalescing_pairs == NULL || bin_weight == NULL || bin_values == NULL + || outside == NULL || pair_count == NULL || visited == NULL + || total_pair == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; + } - for (root = tsk_tree_get_left_root(t); root != TSK_NULL; root = t->right_sib[root]) { - stack_top = 0; - stack[stack_top].node = root; - stack[stack_top].depth = 0; - while (stack_top >= 0) { - u = stack[stack_top].node; - depth = stack[stack_top].depth; - stack_top--; + for (i = 0; i < (tsk_id_t) num_set_indexes; i++) { + u = set_indexes[2 * i]; + v = 
set_indexes[2 * i + 1]; + total_pair[i] = (double) sample_set_sizes[u] * (double) sample_set_sizes[v]; + if (u == v) { + total_pair[i] -= (double) sample_set_sizes[v]; + total_pair[i] /= 2; + } + } - if (tsk_tree_is_sample(t, u)) { - time = tsk_tree_get_branch_length_unsafe(t, u); - update_kc_vectors_single_sample(ts, kc_vecs, u, time); - } + /* initialize internal state */ + for (c = 0; c < (tsk_id_t) num_nodes; c++) { + i = nodes_sample_set[c]; + if (i != TSK_NULL) { + state = GET_2D_ROW(nodes_sample, num_sample_sets, c); + state[i] = 1.0; + } + nodes_parent[c] = TSK_NULL; + visited[c] = false; + } + tsk_memcpy( + sample_count, nodes_sample, num_nodes * num_sample_sets * sizeof(*sample_count)); - /* Don't bother going deeper if there are no samples under this node */ - if (t->left_sample[u] != TSK_NULL) { - for (c1 = t->left_child[u]; c1 != TSK_NULL; c1 = t->right_sib[c1]) { - stack_top++; - stack[stack_top].node = c1; - stack[stack_top].depth = depth + 1; + ret = tsk_tree_position_init(&tree_pos, self, 0); + if (ret != 0) { + goto out; + } - for (c2 = t->right_sib[c1]; c2 != TSK_NULL; c2 = t->right_sib[c2]) { - time = times[root] - times[u]; - update_kc_vectors_all_pairs(t, kc_vecs, c1, c2, depth, time); + num_edges = 0; + missing_span = 0.0; + w = 0; + while (true) { + tsk_tree_position_next(&tree_pos); + if (tree_pos.index == TSK_NULL) { + break; + } + + left = tree_pos.interval.left; + right = tree_pos.interval.right; + remaining_span = sequence_length - left; + + for (u = tree_pos.out.start; u != tree_pos.out.stop; u++) { + e = tree_pos.out.order[u]; + p = tables->edges.parent[e]; + c = tables->edges.child[e]; + nodes_parent[c] = TSK_NULL; + inside = GET_2D_ROW(sample_count, num_sample_sets, c); + while (p != TSK_NULL) { /* downdate statistic */ + v = node_bin_map[p]; + t = nodes_time[p]; + if (v != TSK_NULL) { + above = GET_2D_ROW(sample_count, num_sample_sets, p); + below = GET_2D_ROW(sample_count, num_sample_sets, c); + state = GET_2D_ROW(nodes_sample, 
num_sample_sets, p); + pairs = GET_2D_ROW(coalescing_pairs, num_set_indexes, v); + times = GET_2D_ROW(coalescence_time, num_set_indexes, v); + pair_coalescence_count(num_set_indexes, set_indexes, num_sample_sets, + above, below, state, inside, outside, pair_count); + for (i = 0; i < (tsk_id_t) num_set_indexes; i++) { + x = pair_count[i] * remaining_span; + pairs[i] -= x; + times[i] -= t * x; + } + } + c = p; + p = nodes_parent[c]; + } + p = tables->edges.parent[e]; + while (p != TSK_NULL) { /* downdate state */ + above = GET_2D_ROW(sample_count, num_sample_sets, p); + for (i = 0; i < (tsk_id_t) num_sample_sets; i++) { + above[i] -= inside[i]; + } + p = nodes_parent[p]; + } + num_edges -= 1; + } + + for (u = tree_pos.in.start; u != tree_pos.in.stop; u++) { + e = tree_pos.in.order[u]; + p = tables->edges.parent[e]; + c = tables->edges.child[e]; + nodes_parent[c] = p; + inside = GET_2D_ROW(sample_count, num_sample_sets, c); + while (p != TSK_NULL) { /* update state */ + above = GET_2D_ROW(sample_count, num_sample_sets, p); + for (i = 0; i < (tsk_id_t) num_sample_sets; i++) { + above[i] += inside[i]; + } + p = nodes_parent[p]; + } + p = tables->edges.parent[e]; + while (p != TSK_NULL) { /* update statistic */ + v = node_bin_map[p]; + t = nodes_time[p]; + if (v != TSK_NULL) { + above = GET_2D_ROW(sample_count, num_sample_sets, p); + below = GET_2D_ROW(sample_count, num_sample_sets, c); + state = GET_2D_ROW(nodes_sample, num_sample_sets, p); + pairs = GET_2D_ROW(coalescing_pairs, num_set_indexes, v); + times = GET_2D_ROW(coalescence_time, num_set_indexes, v); + pair_coalescence_count(num_set_indexes, set_indexes, num_sample_sets, + above, below, state, inside, outside, pair_count); + for (i = 0; i < (tsk_id_t) num_set_indexes; i++) { + x = pair_count[i] * remaining_span; + pairs[i] += x; + times[i] += t * x; + } + } + c = p; + p = nodes_parent[c]; + } + num_edges += 1; + } + + if (num_edges == 0) { + missing_span += right - left; + } + + /* flush windows */ + while (w < 
(tsk_id_t) num_windows && windows[w + 1] <= right) { + TRANSPOSE_2D(num_bins, num_set_indexes, coalescing_pairs, bin_weight); + TRANSPOSE_2D(num_bins, num_set_indexes, coalescence_time, bin_values); + tsk_memset(coalescing_pairs, 0, + num_bins * num_set_indexes * sizeof(*coalescing_pairs)); + tsk_memset(coalescence_time, 0, + num_bins * num_set_indexes * sizeof(*coalescence_time)); + remaining_span = sequence_length - windows[w + 1]; + for (j = 0; j < (tsk_id_t) num_samples; j++) { /* truncate at tree */ + c = sample_sets[j]; + p = nodes_parent[c]; + while (!visited[c] && p != TSK_NULL) { + v = node_bin_map[p]; + t = nodes_time[p]; + if (v != TSK_NULL) { + above = GET_2D_ROW(sample_count, num_sample_sets, p); + below = GET_2D_ROW(sample_count, num_sample_sets, c); + state = GET_2D_ROW(nodes_sample, num_sample_sets, p); + pairs = GET_2D_ROW(coalescing_pairs, num_set_indexes, v); + times = GET_2D_ROW(coalescence_time, num_set_indexes, v); + pair_coalescence_count(num_set_indexes, set_indexes, + num_sample_sets, above, below, state, below, outside, + pair_count); + for (i = 0; i < (tsk_id_t) num_set_indexes; i++) { + weight = GET_2D_ROW(bin_weight, num_bins, i); + values = GET_2D_ROW(bin_values, num_bins, i); + x = pair_count[i] * remaining_span / 2; + pairs[i] += x; + times[i] += t * x; + weight[v] -= x; + values[v] -= t * x; + } + } + visited[c] = true; + c = p; + p = nodes_parent[c]; + } + } + for (j = 0; j < (tsk_id_t) num_samples; j++) { /* reset tree */ + c = sample_sets[j]; + p = nodes_parent[c]; + while (visited[c] && p != TSK_NULL) { + visited[c] = false; + c = p; + p = nodes_parent[c]; + } + } + for (i = 0; i < (tsk_id_t) num_set_indexes; i++) { /* normalise values */ + weight = GET_2D_ROW(bin_weight, num_bins, i); + values = GET_2D_ROW(bin_values, num_bins, i); + for (v = 0; v < (tsk_id_t) num_bins; v++) { + values[v] /= weight[v]; + } + } + /* normalise weights */ + if (options & (TSK_STAT_SPAN_NORMALISE | TSK_STAT_PAIR_NORMALISE)) { + window_span = 
windows[w + 1] - windows[w] - missing_span; + missing_span = 0.0; + if (num_edges == 0) { + /* missing interval, so remove overcounted missing span */ + remaining_span = right - windows[w + 1]; + window_span += remaining_span; + missing_span += remaining_span; + } + for (i = 0; i < (tsk_id_t) num_set_indexes; i++) { + denominator = 1.0; + if (options & TSK_STAT_SPAN_NORMALISE) { + denominator *= window_span; + } + if (options & TSK_STAT_PAIR_NORMALISE) { + denominator *= total_pair[i]; + } + weight = GET_2D_ROW(bin_weight, num_bins, i); + for (v = 0; v < (tsk_id_t) num_bins; v++) { + weight[v] *= denominator == 0.0 ? 0.0 : 1 / denominator; } } } + for (i = 0; i < (tsk_id_t) num_set_indexes; i++) { /* summarise bins */ + weight = GET_2D_ROW(bin_weight, num_bins, i); + values = GET_2D_ROW(bin_values, num_bins, i); + output = GET_3D_ROW( + result, num_set_indexes, num_outputs, (tsk_size_t) w, i); + ret = summary_func( + num_bins, weight, values, num_outputs, output, summary_func_args); + if (ret != 0) { + goto out; + } + }; + w += 1; } } - out: - tsk_safe_free(stack); + tsk_tree_position_free(&tree_pos); + tsk_safe_free(nodes_sample_set); + tsk_safe_free(coalescing_pairs); + tsk_safe_free(coalescence_time); + tsk_safe_free(nodes_parent); + tsk_safe_free(nodes_sample); + tsk_safe_free(sample_count); + tsk_safe_free(bin_weight); + tsk_safe_free(bin_values); + tsk_safe_free(pair_count); + tsk_safe_free(total_pair); + tsk_safe_free(visited); + tsk_safe_free(outside); return ret; } -static double -norm_kc_vectors(kc_vectors *self, kc_vectors *other, double lambda) +static int +pair_coalescence_weights(tsk_size_t TSK_UNUSED(input_dim), const double *weight, + const double *TSK_UNUSED(values), tsk_size_t output_dim, double *output, + void *TSK_UNUSED(params)) { - double vT1, vT2, distance_sum; - tsk_id_t i; - - distance_sum = 0; - for (i = 0; i < self->n + self->N; i++) { - vT1 = ((double) self->m[i] * (1 - lambda)) + (lambda * self->M[i]); - vT2 = ((double) other->m[i] * (1 
- lambda)) + (lambda * other->M[i]); - distance_sum += (vT1 - vT2) * (vT1 - vT2); - } + int ret = 0; + tsk_memcpy(output, weight, output_dim * sizeof(*output)); + return ret; +} - return sqrt(distance_sum); +int +tsk_treeseq_pair_coalescence_counts(const tsk_treeseq_t *self, + tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, + const tsk_id_t *sample_sets, tsk_size_t num_set_indexes, const tsk_id_t *set_indexes, + tsk_size_t num_windows, const double *windows, tsk_size_t num_bins, + const tsk_id_t *node_bin_map, tsk_flags_t options, double *result) +{ + return tsk_treeseq_pair_coalescence_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_set_indexes, set_indexes, num_windows, windows, num_bins, + node_bin_map, pair_coalescence_weights, num_bins, NULL, options, result); } static int -check_kc_distance_tree_inputs(const tsk_tree_t *self) +pair_coalescence_quantiles(tsk_size_t input_dim, const double *weight, + const double *values, tsk_size_t output_dim, double *output, void *params) { - tsk_id_t u, num_nodes, left_child; int ret = 0; - - if (tsk_tree_get_num_roots(self) != 1) { - ret = TSK_ERR_MULTIPLE_ROOTS; - goto out; + double coalesced, timepoint; + double *quantiles = (double *) params; + tsk_size_t i, j; + j = 0; + coalesced = 0.0; + timepoint = TSK_UNKNOWN_TIME; + for (i = 0; i < output_dim; i++) { + output[i] = NAN; + } + for (i = 0; i < input_dim; i++) { + if (weight[i] > 0) { + coalesced += weight[i]; + timepoint = values[i]; + while (j < output_dim && quantiles[j] <= coalesced) { + output[j] = timepoint; + j += 1; + } + } } - if (!tsk_tree_has_sample_lists(self)) { - ret = TSK_ERR_NO_SAMPLE_LISTS; - goto out; + if (quantiles[output_dim - 1] == 1.0) { + output[output_dim - 1] = timepoint; } + return ret; +} - num_nodes = (tsk_id_t) tsk_treeseq_get_num_nodes(self->tree_sequence); - for (u = 0; u < num_nodes; u++) { - left_child = self->left_child[u]; - if (left_child != TSK_NULL && left_child == self->right_child[u]) { - ret = 
TSK_ERR_UNARY_NODES; +static int +check_quantiles(const tsk_size_t num_quantiles, const double *quantiles) +{ + int ret = 0; + tsk_size_t i; + double last = -INFINITY; + for (i = 0; i < num_quantiles; i++) { + if (quantiles[i] <= last || quantiles[i] < 0.0 || quantiles[i] > 1.0) { + ret = tsk_trace_error(TSK_ERR_BAD_QUANTILES); goto out; } + last = quantiles[i]; } out: return ret; } static int -check_kc_distance_samples_inputs(const tsk_treeseq_t *self, const tsk_treeseq_t *other) +check_sorted_node_bin_map( + const tsk_treeseq_t *self, tsk_size_t num_bins, const tsk_id_t *node_bin_map) { - const tsk_id_t *samples, *other_samples; - tsk_id_t i, n; int ret = 0; - - if (self->num_samples != other->num_samples) { - ret = TSK_ERR_SAMPLE_SIZE_MISMATCH; + tsk_size_t num_nodes = self->tables->nodes.num_rows; + const double *nodes_time = self->tables->nodes.time; + double last; + tsk_id_t i, j; + double *min_time = tsk_malloc(num_bins * sizeof(*min_time)); + double *max_time = tsk_malloc(num_bins * sizeof(*max_time)); + if (min_time == NULL || max_time == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } - - samples = self->samples; - other_samples = other->samples; - n = (tsk_id_t) self->num_samples; - for (i = 0; i < n; i++) { - if (samples[i] != other_samples[i]) { - ret = TSK_ERR_SAMPLES_NOT_EQUAL; + for (j = 0; j < (tsk_id_t) num_bins; j++) { + min_time[j] = TSK_UNKNOWN_TIME; + max_time[j] = TSK_UNKNOWN_TIME; + } + for (i = 0; i < (tsk_id_t) num_nodes; i++) { + j = node_bin_map[i]; + if (j < 0 || j >= (tsk_id_t) num_bins) { + continue; + } + if (tsk_is_unknown_time(max_time[j]) || nodes_time[i] > max_time[j]) { + max_time[j] = nodes_time[i]; + } + if (tsk_is_unknown_time(min_time[j]) || nodes_time[i] < min_time[j]) { + min_time[j] = nodes_time[i]; + } + } + last = -INFINITY; + for (j = 0; j < (tsk_id_t) num_bins; j++) { + if (tsk_is_unknown_time(min_time[j])) { + continue; + } + if (min_time[j] < last) { + ret = tsk_trace_error(TSK_ERR_UNSORTED_TIMES); 
goto out; + } else { + last = max_time[j]; } } out: + tsk_safe_free(min_time); + tsk_safe_free(max_time); return ret; } int -tsk_tree_kc_distance( - const tsk_tree_t *self, const tsk_tree_t *other, double lambda, double *result) +tsk_treeseq_pair_coalescence_quantiles(const tsk_treeseq_t *self, + tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, + const tsk_id_t *sample_sets, tsk_size_t num_set_indexes, const tsk_id_t *set_indexes, + tsk_size_t num_windows, const double *windows, tsk_size_t num_bins, + const tsk_id_t *node_bin_map, tsk_size_t num_quantiles, double *quantiles, + tsk_flags_t options, double *result) { - tsk_id_t n, i; - kc_vectors vecs[2]; - const tsk_tree_t *trees[2] = { self, other }; int ret = 0; + void *params = (void *) quantiles; + ret = check_quantiles(num_quantiles, quantiles); + if (ret != 0) { + goto out; + } + ret = check_sorted_node_bin_map(self, num_bins, node_bin_map); + if (ret != 0) { + goto out; + } + options |= TSK_STAT_SPAN_NORMALISE | TSK_STAT_PAIR_NORMALISE; + ret = tsk_treeseq_pair_coalescence_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_set_indexes, set_indexes, num_windows, windows, num_bins, + node_bin_map, pair_coalescence_quantiles, num_quantiles, params, options, + result); + if (ret != 0) { + goto out; + } +out: + return ret; +} - for (i = 0; i < 2; i++) { - tsk_memset(&vecs[i], 0, sizeof(kc_vectors)); +static int +pair_coalescence_rates(tsk_size_t input_dim, const double *weight, const double *values, + tsk_size_t output_dim, double *output, void *params) +{ + int ret = 0; + double coalesced, rate, waiting_time, a, b; + double *time_windows = (double *) params; + tsk_id_t i, j; + tsk_bug_assert(input_dim == output_dim); + for (j = (tsk_id_t) output_dim; j > 0; j--) { /* find last window with data */ + if (weight[j - 1] == 0) { + output[j - 1] = NAN; /* TODO: should fill value be zero instead? 
*/ + } else { + break; + } + } + coalesced = 0.0; + for (i = 0; i < j; i++) { + a = time_windows[i]; + b = time_windows[i + 1]; + if (i + 1 == j) { + waiting_time = values[i] < a ? 0.0 : values[i] - a; + rate = 1 / waiting_time; + } else { + rate = log(1 - weight[i] / (1 - coalesced)) / (a - b); + } + // avoid tiny negative values from fp error + output[i] = rate > 0 ? rate : 0; + coalesced += weight[i]; } + return ret; +} - ret = check_kc_distance_samples_inputs(self->tree_sequence, other->tree_sequence); - if (ret != 0) { +static int +check_coalescence_rate_time_windows(const tsk_treeseq_t *self, + tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, + const tsk_id_t *sample_sets, tsk_size_t num_time_windows, + const tsk_id_t *node_time_window, const double *time_windows) +{ + int ret = 0; + double timepoint; + const double *nodes_time = self->tables->nodes.time; + tsk_size_t num_nodes = self->tables->nodes.num_rows; + tsk_id_t i, j, k; + tsk_id_t n; + if (num_time_windows == 0) { + ret = tsk_trace_error(TSK_ERR_BAD_TIME_WINDOWS_DIM); + goto out; + } + /* time windows are sorted */ + timepoint = time_windows[0]; + for (i = 0; i < (tsk_id_t) num_time_windows; i++) { + if (time_windows[i + 1] <= timepoint) { + ret = tsk_trace_error(TSK_ERR_BAD_TIME_WINDOWS); + goto out; + } + timepoint = time_windows[i + 1]; + } + if (timepoint != INFINITY) { + ret = tsk_trace_error(TSK_ERR_BAD_TIME_WINDOWS_END); goto out; } - for (i = 0; i < 2; i++) { - ret = check_kc_distance_tree_inputs(trees[i]); - if (ret != 0) { - goto out; + /* all sample times align with start of first time window */ + k = 0; + for (i = 0; i < (tsk_id_t) num_sample_sets; i++) { + for (j = 0; j < (tsk_id_t) sample_set_sizes[i]; j++) { + n = sample_sets[k++]; + if (nodes_time[n] != time_windows[0]) { + ret = tsk_trace_error(TSK_ERR_BAD_SAMPLE_PAIR_TIMES); + goto out; + } } } - - n = (tsk_id_t) self->tree_sequence->num_samples; - for (i = 0; i < 2; i++) { - ret = kc_vectors_alloc(&vecs[i], n); - if 
(ret != 0) { + /* nodes are correctly assigned to time windows */ + for (i = 0; i < (tsk_id_t) num_nodes; i++) { + j = node_time_window[i]; + if (j < 0) { + continue; + } + if (j >= (tsk_id_t) num_time_windows) { + ret = tsk_trace_error(TSK_ERR_BAD_NODE_BIN_MAP_DIM); goto out; } - ret = fill_kc_vectors(trees[i], &vecs[i]); - if (ret != 0) { + if (nodes_time[i] < time_windows[j] || nodes_time[i] >= time_windows[j + 1]) { + ret = tsk_trace_error(TSK_ERR_BAD_NODE_TIME_WINDOW); goto out; } } - - *result = norm_kc_vectors(&vecs[0], &vecs[1], lambda); out: - for (i = 0; i < 2; i++) { - kc_vectors_free(&vecs[i]); - } return ret; } -static int -check_kc_distance_tree_sequence_inputs( - const tsk_treeseq_t *self, const tsk_treeseq_t *other) +int +tsk_treeseq_pair_coalescence_rates(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, + tsk_size_t num_set_indexes, const tsk_id_t *set_indexes, tsk_size_t num_windows, + const double *windows, tsk_size_t num_time_windows, const tsk_id_t *node_time_window, + double *time_windows, tsk_flags_t options, double *result) { int ret = 0; - - if (self->tables->sequence_length != other->tables->sequence_length) { - ret = TSK_ERR_SEQUENCE_LENGTH_MISMATCH; + void *params = (void *) time_windows; + ret = check_coalescence_rate_time_windows(self, num_sample_sets, sample_set_sizes, + sample_sets, num_time_windows, node_time_window, time_windows); + if (ret != 0) { goto out; } - - ret = check_kc_distance_samples_inputs(self, other); + options |= TSK_STAT_SPAN_NORMALISE | TSK_STAT_PAIR_NORMALISE; + ret = tsk_treeseq_pair_coalescence_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_set_indexes, set_indexes, num_windows, windows, + num_time_windows, node_time_window, pair_coalescence_rates, num_time_windows, + params, options, result); if (ret != 0) { goto out; } - out: return ret; } +/* ======================================================== * + * Relatedness 
matrix-vector product + * ======================================================== */ + +typedef struct { + const tsk_treeseq_t *ts; + tsk_size_t num_weights; + const double *weights; + tsk_size_t num_windows; + const double *windows; + tsk_size_t num_focal_nodes; + const tsk_id_t *focal_nodes; + tsk_flags_t options; + double *result; + tsk_tree_position_t tree_pos; + double position; + tsk_size_t num_nodes; + tsk_id_t *parent; + double *x; + double *w; + double *v; +} tsk_matvec_calculator_t; + static void -update_kc_pair_with_sample(const tsk_tree_t *self, kc_vectors *kc, tsk_id_t sample, - tsk_size_t *depths, double root_time) +tsk_matvec_calculator_print_state(const tsk_matvec_calculator_t *self, FILE *out) { - tsk_id_t c, p, sib; - double time; - tsk_size_t depth; - double *times = self->tree_sequence->tables->nodes.time; + tsk_id_t j; + tsk_size_t num_samples = tsk_treeseq_get_num_samples(self->ts); - c = sample; - for (p = self->parent[sample]; p != TSK_NULL; p = self->parent[p]) { - time = root_time - times[p]; - depth = depths[p]; - for (sib = self->left_child[p]; sib != TSK_NULL; sib = self->right_sib[sib]) { - if (sib != c) { - update_kc_vectors_all_pairs(self, kc, sample, sib, depth, time); - } - } - c = p; + fprintf(out, "Matvec state:\n"); + fprintf(out, "options = %d\n", self->options); + fprintf(out, "position = %f\n", self->position); + fprintf(out, "focal nodes = %lld: [", (long long) self->num_focal_nodes); + fprintf(out, "tree_pos:\n"); + tsk_tree_position_print_state(&self->tree_pos, out); + fprintf(out, "samples = %lld: [", (long long) num_samples); + fprintf(out, "]\n"); + fprintf(out, "node\tparent\tx\tv\tw"); + fprintf(out, "\n"); + + for (j = 0; j < (tsk_id_t) self->num_nodes; j++) { + fprintf(out, "%lld\t", (long long) j); + fprintf(out, "%lld\t%g\t%g\t%g\n", (long long) self->parent[j], self->x[j], + self->v[j], self->w[j]); } } static int -update_kc_subtree_state( - tsk_tree_t *t, kc_vectors *kc, tsk_id_t u, tsk_size_t *depths, double 
root_time) +tsk_matvec_calculator_init(tsk_matvec_calculator_t *self, const tsk_treeseq_t *ts, + tsk_size_t num_weights, const double *weights, tsk_size_t num_windows, + const double *windows, tsk_size_t num_focal_nodes, const tsk_id_t *focal_nodes, + tsk_flags_t options, double *result) { - int stack_top; - tsk_id_t v, c; - tsk_id_t *stack = NULL; int ret = 0; + tsk_size_t num_samples = tsk_treeseq_get_num_samples(ts); + const tsk_size_t num_nodes = ts->tables->nodes.num_rows; + const double *row; + double *new_row; + tsk_size_t k; + tsk_id_t index, u, j; + double *weight_means = tsk_malloc(num_weights * sizeof(*weight_means)); + const tsk_size_t num_trees = ts->num_trees; + const double *restrict breakpoints = ts->breakpoints; + + self->ts = ts; + self->num_weights = num_weights; + self->weights = weights; + self->num_windows = num_windows; + self->windows = windows; + self->num_focal_nodes = num_focal_nodes; + self->focal_nodes = focal_nodes; + self->options = options; + self->result = result; + self->num_nodes = num_nodes; + self->position = windows[0]; - stack = tsk_malloc(tsk_tree_get_size_bound(t) * sizeof(*stack)); - if (stack == NULL) { - ret = TSK_ERR_NO_MEMORY; + self->parent = tsk_malloc(num_nodes * sizeof(*self->parent)); + self->x = tsk_calloc(num_nodes, sizeof(*self->x)); + self->v = tsk_calloc(num_nodes * num_weights, sizeof(*self->v)); + self->w = tsk_calloc(num_nodes * num_weights, sizeof(*self->w)); + + if (self->parent == NULL || self->x == NULL || self->w == NULL || self->v == NULL + || weight_means == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } - stack_top = 0; - stack[stack_top] = u; - while (stack_top >= 0) { - v = stack[stack_top]; - stack_top--; + tsk_memset(result, 0, num_windows * num_focal_nodes * num_weights * sizeof(*result)); + tsk_memset(self->parent, TSK_NULL, num_nodes * sizeof(*self->parent)); - if (tsk_tree_is_sample(t, v)) { - update_kc_pair_with_sample(t, kc, v, depths, root_time); + for (j = 0; j < 
(tsk_id_t) num_focal_nodes; j++) { + if (focal_nodes[j] < 0 || (tsk_size_t) focal_nodes[j] >= num_nodes) { + ret = tsk_trace_error(TSK_ERR_NODE_OUT_OF_BOUNDS); + goto out; } - for (c = t->left_child[v]; c != TSK_NULL; c = t->right_sib[c]) { - if (depths[c] != 0) { - depths[c] = depths[v] + 1; - stack_top++; - stack[stack_top] = c; + } + + ret = tsk_tree_position_init(&self->tree_pos, ts, 0); + if (ret != 0) { + goto out; + } + /* seek to the first window */ + index = (tsk_id_t) tsk_search_sorted(breakpoints, num_trees + 1, windows[0]); + if (breakpoints[index] > windows[0]) { + index--; + } + ret = tsk_tree_position_seek_forward(&self->tree_pos, index); + if (ret != 0) { + goto out; + } + + for (k = 0; k < num_weights; k++) { + weight_means[k] = 0.0; + } + /* centre the input */ + if (!(options & TSK_STAT_NONCENTRED)) { + for (j = 0; j < (tsk_id_t) num_samples; j++) { + row = GET_2D_ROW(weights, num_weights, j); + for (k = 0; k < num_weights; k++) { + weight_means[k] += row[k]; } } + for (k = 0; k < num_weights; k++) { + weight_means[k] /= (double) num_samples; + } } + /* set the initial state */ + for (j = 0; j < (tsk_id_t) num_samples; j++) { + u = ts->samples[j]; + row = GET_2D_ROW(weights, num_weights, j); + new_row = GET_2D_ROW(self->w, num_weights, u); + for (k = 0; k < num_weights; k++) { + new_row[k] = row[k] - weight_means[k]; + } + } out: - tsk_safe_free(stack); + tsk_safe_free(weight_means); return ret; } static int -update_kc_incremental(tsk_tree_t *self, kc_vectors *kc, tsk_edge_list_t *edges_out, - tsk_edge_list_t *edges_in, tsk_size_t *depths) +tsk_matvec_calculator_free(tsk_matvec_calculator_t *self) { - int ret = 0; - tsk_edge_list_node_t *record; - tsk_edge_t *e; - tsk_id_t u; - double root_time, time; - const double *times = self->tree_sequence->tables->nodes.time; + tsk_safe_free(self->parent); + tsk_safe_free(self->x); + tsk_safe_free(self->w); + tsk_safe_free(self->v); + tsk_tree_position_free(&self->tree_pos); - /* Update state of detached 
subtrees */ - for (record = edges_out->tail; record != NULL; record = record->prev) { - e = &record->edge; - u = e->child; - depths[u] = 0; + /* Make this safe for multiple free calls */ + memset(self, 0, sizeof(*self)); + return 0; +} - if (self->parent[u] == TSK_NULL) { - root_time = times[tsk_tree_node_root(self, u)]; - ret = update_kc_subtree_state(self, kc, u, depths, root_time); - if (ret != 0) { - goto out; - } +static inline void +tsk_matvec_calculator_add_z(tsk_id_t u, tsk_id_t p, const double position, + double *restrict x, const tsk_size_t num_weights, double *restrict w, + double *restrict v, const double *restrict nodes_time) +{ + double t, span; + tsk_size_t j; + double *restrict v_row, *restrict w_row; + + if (p != TSK_NULL) { + t = nodes_time[p] - nodes_time[u]; + span = position - x[u]; + // do this: self->v[u] += t * span * self->w[u]; + w_row = GET_2D_ROW(w, num_weights, u); + v_row = GET_2D_ROW(v, num_weights, u); + for (j = 0; j < num_weights; j++) { + v_row[j] += t * span * w_row[j]; } } + x[u] = position; +} - /* Propagate state change down into reattached subtrees. 
*/ - for (record = edges_in->tail; record != NULL; record = record->prev) { - e = &record->edge; - u = e->child; - - tsk_bug_assert(depths[e->child] == 0); - depths[u] = depths[e->parent] + 1; +static void +tsk_matvec_calculator_adjust_path_up( + tsk_matvec_calculator_t *self, tsk_id_t p, tsk_id_t c, double sign) +{ + tsk_size_t j; + double *p_row, *c_row; + const tsk_id_t *restrict parent = self->parent; + const double position = self->position; + double *restrict x = self->x; + const tsk_size_t num_weights = self->num_weights; + double *restrict w = self->w; + double *restrict v = self->v; + const double *restrict nodes_time = self->ts->tables->nodes.time; - root_time = times[tsk_tree_node_root(self, u)]; - ret = update_kc_subtree_state(self, kc, u, depths, root_time); - if (ret != 0) { - goto out; + // sign = -1 for removing edges, +1 for adding + while (p != TSK_NULL) { + tsk_matvec_calculator_add_z( + p, parent[p], position, x, num_weights, w, v, nodes_time); + // do this: self->v[c] -= sign * self->v[p]; + p_row = GET_2D_ROW(v, num_weights, p); + c_row = GET_2D_ROW(v, num_weights, c); + for (j = 0; j < num_weights; j++) { + c_row[j] -= sign * p_row[j]; } - - if (tsk_tree_is_sample(self, u)) { - time = tsk_tree_get_branch_length_unsafe(self, u); - update_kc_vectors_single_sample(self->tree_sequence, kc, u, time); + // do this: self->w[p] += sign * self->w[c]; + p_row = GET_2D_ROW(w, num_weights, p); + c_row = GET_2D_ROW(w, num_weights, c); + for (j = 0; j < num_weights; j++) { + p_row[j] += sign * c_row[j]; } + p = parent[p]; } +} -out: - return ret; +static void +tsk_matvec_calculator_remove_edge(tsk_matvec_calculator_t *self, tsk_id_t p, tsk_id_t c) +{ + tsk_id_t *parent = self->parent; + const double position = self->position; + double *restrict x = self->x; + const tsk_size_t num_weights = self->num_weights; + double *restrict w = self->w; + double *restrict v = self->v; + const double *restrict nodes_time = self->ts->tables->nodes.time; + + 
tsk_matvec_calculator_add_z( + c, parent[c], position, x, num_weights, w, v, nodes_time); + parent[c] = TSK_NULL; + tsk_matvec_calculator_adjust_path_up(self, p, c, -1); } -int -tsk_treeseq_kc_distance(const tsk_treeseq_t *self, const tsk_treeseq_t *other, - double lambda_, double *result) +static void +tsk_matvec_calculator_insert_edge(tsk_matvec_calculator_t *self, tsk_id_t p, tsk_id_t c) +{ + tsk_id_t *parent = self->parent; + + tsk_matvec_calculator_adjust_path_up(self, p, c, +1); + self->x[c] = self->position; + parent[c] = p; +} + +static int +tsk_matvec_calculator_write_output(tsk_matvec_calculator_t *self, double *restrict y) { - int i; - tsk_id_t n; - tsk_size_t num_nodes; - double left, span, total; - const tsk_treeseq_t *treeseqs[2] = { self, other }; - tsk_tree_t trees[2]; - kc_vectors kcs[2]; - tsk_diff_iter_t diff_iters[2]; - tsk_edge_list_t edges_out[2]; - tsk_edge_list_t edges_in[2]; - tsk_size_t *depths[2]; - double t0_left, t0_right, t1_left, t1_right; int ret = 0; + tsk_id_t u; + tsk_size_t j, k; + const tsk_size_t n = self->num_focal_nodes; + const tsk_size_t num_weights = self->num_weights; + const double position = self->position; + double *u_row, *out_row; + double *out_means = tsk_malloc(num_weights * sizeof(*out_means)); + const tsk_id_t *restrict parent = self->parent; + const double *restrict nodes_time = self->ts->tables->nodes.time; + double *restrict x = self->x; + double *restrict w = self->w; + double *restrict v = self->v; + const tsk_id_t *restrict focal_nodes = self->focal_nodes; - for (i = 0; i < 2; i++) { - tsk_memset(&trees[i], 0, sizeof(trees[i])); - tsk_memset(&diff_iters[i], 0, sizeof(diff_iters[i])); - tsk_memset(&kcs[i], 0, sizeof(kcs[i])); - tsk_memset(&edges_out[i], 0, sizeof(edges_out[i])); - tsk_memset(&edges_in[i], 0, sizeof(edges_in[i])); - depths[i] = NULL; + if (out_means == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; } - ret = check_kc_distance_tree_sequence_inputs(self, other); - if (ret != 0) 
{ - goto out; + for (j = 0; j < n; j++) { + out_row = GET_2D_ROW(y, num_weights, j); + u = focal_nodes[j]; + while (u != TSK_NULL) { + if (x[u] != position) { + tsk_matvec_calculator_add_z( + u, parent[u], position, x, num_weights, w, v, nodes_time); + } + u_row = GET_2D_ROW(v, num_weights, u); + for (k = 0; k < num_weights; k++) { + out_row[k] += u_row[k]; + } + u = parent[u]; + } } - n = (tsk_id_t) self->num_samples; - for (i = 0; i < 2; i++) { - ret = tsk_tree_init(&trees[i], treeseqs[i], TSK_SAMPLE_LISTS); - if (ret != 0) { - goto out; + if (!(self->options & TSK_STAT_NONCENTRED)) { + for (k = 0; k < num_weights; k++) { + out_means[k] = 0.0; } - ret = tsk_diff_iter_init_from_ts(&diff_iters[i], treeseqs[i], false); - if (ret != 0) { - goto out; + for (j = 0; j < n; j++) { + out_row = GET_2D_ROW(y, num_weights, j); + for (k = 0; k < num_weights; k++) { + out_means[k] += out_row[k]; + } } - ret = kc_vectors_alloc(&kcs[i], n); - if (ret != 0) { - goto out; + for (k = 0; k < num_weights; k++) { + out_means[k] /= (double) n; } - num_nodes = tsk_treeseq_get_num_nodes(treeseqs[i]); - depths[i] = tsk_calloc(num_nodes, sizeof(*depths[i])); - if (depths[i] == NULL) { - ret = TSK_ERR_NO_MEMORY; - goto out; + for (j = 0; j < n; j++) { + out_row = GET_2D_ROW(y, num_weights, j); + for (k = 0; k < num_weights; k++) { + out_row[k] -= out_means[k]; + } } } + /* zero out v */ + tsk_memset(self->v, 0, self->num_nodes * num_weights * sizeof(*self->v)); +out: + tsk_safe_free(out_means); + return ret; +} - total = 0; - left = 0; - - ret = tsk_tree_first(&trees[0]); - if (ret != TSK_TREE_OK) { - goto out; +static int +tsk_matvec_calculator_run(tsk_matvec_calculator_t *self) +{ + int ret = 0; + tsk_size_t j, k, m; + tsk_id_t e, p, c; + const tsk_size_t out_size = self->num_weights * self->num_focal_nodes; + const tsk_size_t num_edges = self->ts->tables->edges.num_rows; + const double *restrict edge_right = self->ts->tables->edges.right; + const double *restrict edge_left = 
self->ts->tables->edges.left; + const tsk_id_t *restrict edge_child = self->ts->tables->edges.child; + const tsk_id_t *restrict edge_parent = self->ts->tables->edges.parent; + const double *restrict windows = self->windows; + double *restrict out; + tsk_tree_position_t tree_pos = self->tree_pos; + const tsk_id_t *restrict in_order = tree_pos.in.order; + const tsk_id_t *restrict out_order = tree_pos.out.order; + bool valid; + double next_position; + + m = 0; + self->position = windows[0]; + + for (j = (tsk_size_t) tree_pos.in.start; j != (tsk_size_t) tree_pos.in.stop; j++) { + e = in_order[j]; + tsk_bug_assert(edge_left[e] <= self->position); + if (self->position < edge_right[e]) { + p = edge_parent[e]; + c = edge_child[e]; + tsk_matvec_calculator_insert_edge(self, p, c); + } + } + + valid = tsk_tree_position_next(&tree_pos); + j = (tsk_size_t) tree_pos.in.start; + k = (tsk_size_t) tree_pos.out.start; + while (m < self->num_windows) { + if (valid && self->position == tree_pos.interval.left) { + for (k = (tsk_size_t) tree_pos.out.start; + k != (tsk_size_t) tree_pos.out.stop; k++) { + e = out_order[k]; + p = edge_parent[e]; + c = edge_child[e]; + tsk_matvec_calculator_remove_edge(self, p, c); + } + for (j = (tsk_size_t) tree_pos.in.start; j != (tsk_size_t) tree_pos.in.stop; + j++) { + e = in_order[j]; + p = edge_parent[e]; + c = edge_child[e]; + tsk_matvec_calculator_insert_edge(self, p, c); + } + valid = tsk_tree_position_next(&tree_pos); + } + next_position = windows[m + 1]; + if (j < num_edges) { + next_position = TSK_MIN(next_position, edge_left[in_order[j]]); + } + if (k < num_edges) { + next_position = TSK_MIN(next_position, edge_right[out_order[k]]); + } + tsk_bug_assert(self->position < next_position); + self->position = next_position; + if (self->position == windows[m + 1]) { + out = GET_2D_ROW(self->result, out_size, m); + tsk_matvec_calculator_write_output(self, out); + m += 1; + } + if (self->options & TSK_DEBUG) { + tsk_matvec_calculator_print_state(self, 
tsk_get_debug_stream()); + } } - ret = check_kc_distance_tree_inputs(&trees[0]); - if (ret != 0) { + + /* out: */ + return ret; +} + +int +tsk_treeseq_genetic_relatedness_vector(const tsk_treeseq_t *self, tsk_size_t num_weights, + const double *weights, tsk_size_t num_windows, const double *windows, + tsk_size_t num_focal_nodes, const tsk_id_t *focal_nodes, double *result, + tsk_flags_t options) +{ + int ret = 0; + bool stat_site = !!(options & TSK_STAT_SITE); + bool stat_node = !!(options & TSK_STAT_NODE); + tsk_matvec_calculator_t calc; + + memset(&calc, 0, sizeof(calc)); + + if (stat_node || stat_site) { + ret = tsk_trace_error(TSK_ERR_UNSUPPORTED_STAT_MODE); goto out; } - ret = tsk_diff_iter_next( - &diff_iters[0], &t0_left, &t0_right, &edges_out[0], &edges_in[0]); - tsk_bug_assert(ret == TSK_TREE_OK); - ret = update_kc_incremental( - &trees[0], &kcs[0], &edges_out[0], &edges_in[0], depths[0]); + ret = tsk_treeseq_check_windows(self, num_windows, windows, 0); if (ret != 0) { goto out; } - while ((ret = tsk_tree_next(&trees[1])) == TSK_TREE_OK) { - ret = check_kc_distance_tree_inputs(&trees[1]); - if (ret != 0) { - goto out; - } - ret = tsk_diff_iter_next( - &diff_iters[1], &t1_left, &t1_right, &edges_out[1], &edges_in[1]); - tsk_bug_assert(ret == TSK_TREE_OK); - - ret = update_kc_incremental( - &trees[1], &kcs[1], &edges_out[1], &edges_in[1], depths[1]); - if (ret != 0) { - goto out; - } - while (t0_right < t1_right) { - span = t0_right - left; - total += norm_kc_vectors(&kcs[0], &kcs[1], lambda_) * span; - left = t0_right; - ret = tsk_tree_next(&trees[0]); - tsk_bug_assert(ret == TSK_TREE_OK); - ret = check_kc_distance_tree_inputs(&trees[0]); - if (ret != 0) { - goto out; - } - ret = tsk_diff_iter_next( - &diff_iters[0], &t0_left, &t0_right, &edges_out[0], &edges_in[0]); - tsk_bug_assert(ret == TSK_TREE_OK); - ret = update_kc_incremental( - &trees[0], &kcs[0], &edges_out[0], &edges_in[0], depths[0]); - if (ret != 0) { - goto out; - } - } - span = t1_right - 
left; - left = t1_right; - total += norm_kc_vectors(&kcs[0], &kcs[1], lambda_) * span; - } + ret = tsk_matvec_calculator_init(&calc, self, num_weights, weights, num_windows, + windows, num_focal_nodes, focal_nodes, options, result); if (ret != 0) { goto out; } - - *result = total / self->tables->sequence_length; -out: - for (i = 0; i < 2; i++) { - tsk_tree_free(&trees[i]); - tsk_diff_iter_free(&diff_iters[i]); - kc_vectors_free(&kcs[i]); - tsk_safe_free(depths[i]); + if (options & TSK_DEBUG) { + tsk_matvec_calculator_print_state(&calc, tsk_get_debug_stream()); } + ret = tsk_matvec_calculator_run(&calc); +out: + tsk_matvec_calculator_free(&calc); return ret; } diff --git a/subprojects/tskit/tskit/trees.h b/subprojects/tskit/tskit/trees.h index efe998007..21495edbf 100644 --- a/subprojects/tskit/tskit/trees.h +++ b/subprojects/tskit/tskit/trees.h @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2019-2023 Tskit Developers + * Copyright (c) 2019-2024 Tskit Developers * Copyright (c) 2015-2018 University of Oxford * * Permission is hereby granted, free of charge, to any person obtaining a copy @@ -52,6 +52,8 @@ extern "C" { #define TSK_STAT_POLARISED (1 << 10) #define TSK_STAT_SPAN_NORMALISE (1 << 11) #define TSK_STAT_ALLOW_TIME_UNCALIBRATED (1 << 12) +#define TSK_STAT_PAIR_NORMALISE (1 << 13) +#define TSK_STAT_NONCENTRED (1 << 14) /* Options for map_mutations */ #define TSK_MM_FIXED_ANCESTRAL_STATE (1 << 0) @@ -69,6 +71,11 @@ when the tree sequence is initialised. Indexes are required for a valid tree sequence, and are not built by default for performance reasons. */ #define TSK_TS_INIT_BUILD_INDEXES (1 << 0) +/** +If specified, mutation parents in the table collection will be overwritten +with those computed from the topology when the tree sequence is initialised. 
+*/ +#define TSK_TS_INIT_COMPUTE_MUTATION_PARENTS (1 << 1) /** @} */ // clang-format on @@ -111,6 +118,28 @@ typedef struct { tsk_table_collection_t *tables; } tsk_treeseq_t; +typedef struct { + tsk_id_t index; + struct { + double left; + double right; + } interval; + struct { + tsk_id_t start; + tsk_id_t stop; + const tsk_id_t *order; + } in; + struct { + tsk_id_t start; + tsk_id_t stop; + const tsk_id_t *order; + } out; + tsk_id_t left_current_index; + tsk_id_t right_current_index; + int direction; + const tsk_treeseq_t *tree_sequence; +} tsk_tree_position_t; + /** @brief A single tree in a tree sequence. @@ -256,6 +285,7 @@ typedef struct { int direction; tsk_id_t left_index; tsk_id_t right_index; + tsk_tree_position_t tree_pos; } tsk_tree_t; /****************************************************************************/ @@ -891,6 +921,43 @@ int tsk_treeseq_simplify(const tsk_treeseq_t *self, const tsk_id_t *samples, tsk_size_t num_samples, tsk_flags_t options, tsk_treeseq_t *output, tsk_id_t *node_map); +/** +@brief Extends haplotypes + +Returns a new tree sequence in which the span covered by ancestral nodes +is "extended" to regions of the genome according to the following rule: +If an ancestral segment corresponding to node `n` has ancestor `p` and +descendant `c` on some portion of the genome, and on an adjacent segment of +genome `p` is still an ancestor of `c`, then `n` is inserted into the +path from `p` to `c`. For instance, if `p` is the parent of `n` and `n` +is the parent of `c`, then the span of the edges from `p` to `n` and +`n` to `c` are extended, and the span of the edge from `p` to `c` is +reduced. However, any edges whose child node is a sample are not +modified. The `node` of certain mutations may also be remapped; to do this +unambiguously we need to know mutation times. If mutation times are unknown, +use `tsk_table_collection_compute_mutation_times` first. 
+ +The method will not affect any tables except the edge table, or the node +column in the mutation table. + +The method works by iterating over the genome to look for edges that can +be extended in this way; the maximum number of such iterations is +controlled by ``max_iter``. + +@rst + +**Options**: None currently defined. +@endrst + +@param self A pointer to a tsk_treeseq_t object. +@param max_iter The maximum number of iterations over the tree sequence. +@param options Bitwise option flags. (UNUSED) +@param output A pointer to an uninitialised tsk_treeseq_t object. +@return Return 0 on success or a negative value on failure. +*/ +int tsk_treeseq_extend_haplotypes( + const tsk_treeseq_t *self, int max_iter, tsk_flags_t options, tsk_treeseq_t *output); + /** @} */ int tsk_treeseq_split_edges(const tsk_treeseq_t *self, double time, tsk_flags_t flags, @@ -920,6 +987,17 @@ int tsk_treeseq_general_stat(const tsk_treeseq_t *self, tsk_size_t K, const doub tsk_size_t M, general_stat_func_t *f, void *f_params, tsk_size_t num_windows, const double *windows, tsk_flags_t options, double *result); +typedef int norm_func_t(tsk_size_t result_dim, const double *hap_weights, tsk_size_t n_a, + tsk_size_t n_b, double *result, void *params); + +int tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, + tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, + const tsk_id_t *sample_sets, tsk_size_t result_dim, const tsk_id_t *set_indexes, + general_stat_func_t *f, norm_func_t *norm_f, tsk_size_t out_rows, + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t out_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result); + /* One way weighted stats */ typedef int one_way_weighted_method(const tsk_treeseq_t *self, tsk_size_t num_weights, @@ -943,6 +1021,29 @@ int tsk_treeseq_trait_linear_model(const tsk_treeseq_t *self, tsk_size_t num_wei const double *weights, tsk_size_t num_covariates, const double 
*covariates, tsk_size_t num_windows, const double *windows, tsk_flags_t options, double *result); +/* Two way weighted stats with covariates */ + +typedef int two_way_weighted_method(const tsk_treeseq_t *self, tsk_size_t num_weights, + const double *weights, tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, + tsk_size_t num_windows, const double *windows, double *result, tsk_flags_t options); + +int tsk_treeseq_genetic_relatedness_weighted(const tsk_treeseq_t *self, + tsk_size_t num_weights, const double *weights, tsk_size_t num_index_tuples, + const tsk_id_t *index_tuples, tsk_size_t num_windows, const double *windows, + double *result, tsk_flags_t options); + +/* One way weighted stats with vector output */ + +typedef int weighted_vector_method(const tsk_treeseq_t *self, tsk_size_t num_weights, + const double *weights, tsk_size_t num_windows, const double *windows, + tsk_size_t num_focal_nodes, const tsk_id_t *focal_nodes, double *result, + tsk_flags_t options); + +int tsk_treeseq_genetic_relatedness_vector(const tsk_treeseq_t *self, + tsk_size_t num_weights, const double *weights, tsk_size_t num_windows, + const double *windows, tsk_size_t num_focal_nodes, const tsk_id_t *focal_nodes, + double *result, tsk_flags_t options); + /* One way sample set stats */ typedef int one_way_sample_stat_method(const tsk_treeseq_t *self, @@ -962,13 +1063,80 @@ int tsk_treeseq_Y1(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, int tsk_treeseq_allele_frequency_spectrum(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_windows, const double *windows, - tsk_flags_t options, double *result); + tsk_size_t num_time_windows, const double *time_windows, tsk_flags_t options, + double *result); typedef int general_sample_stat_method(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_indexes, const tsk_id_t 
*indexes, tsk_size_t num_windows, const double *windows, tsk_flags_t options, double *result); +typedef int two_locus_count_stat_method(const tsk_treeseq_t *self, + tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, + const tsk_id_t *sample_sets, tsk_size_t num_rows, const tsk_id_t *row_sites, + const double *row_positions, tsk_size_t num_cols, const tsk_id_t *col_sites, + const double *col_positions, tsk_flags_t options, double *result); + +int tsk_treeseq_D(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result); +int tsk_treeseq_D2(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result); +int tsk_treeseq_r2(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result); +int tsk_treeseq_D_prime(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result); +int tsk_treeseq_r(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, + const 
tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result); +int tsk_treeseq_Dz(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result); +int tsk_treeseq_pi2(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result); +int tsk_treeseq_D2_unbiased(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result); +int tsk_treeseq_Dz_unbiased(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result); +int tsk_treeseq_pi2_unbiased(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result); + +typedef int k_way_two_locus_count_stat_method(const tsk_treeseq_t *self, + tsk_size_t 
num_sample_sets, const tsk_size_t *sample_set_sizes, + const tsk_id_t *sample_sets, tsk_size_t num_index_tuples, + const tsk_id_t *index_tuples, tsk_size_t num_rows, const tsk_id_t *row_sites, + const double *row_positions, tsk_size_t num_cols, const tsk_id_t *col_sites, + const double *col_positions, tsk_flags_t options, double *result); + +/* Two way sample set stats */ + int tsk_treeseq_divergence(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_windows, @@ -986,6 +1154,24 @@ int tsk_treeseq_genetic_relatedness(const tsk_treeseq_t *self, const tsk_id_t *sample_sets, tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_windows, const double *windows, tsk_flags_t options, double *result); +int tsk_treeseq_D2_ij(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, + tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_rows, + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result); +int tsk_treeseq_D2_ij_unbiased(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, + tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_rows, + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result); +int tsk_treeseq_r2_ij(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, + tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_rows, + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const 
tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result); /* Three way sample set stats */ int tsk_treeseq_Y3(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, @@ -1003,6 +1189,38 @@ int tsk_treeseq_f4(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_windows, const double *windows, tsk_flags_t options, double *result); +int tsk_treeseq_divergence_matrix(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, + tsk_size_t num_windows, const double *windows, tsk_flags_t options, double *result); + +/* Coalescence rates */ +typedef int pair_coalescence_stat_func_t(tsk_size_t input_dim, const double *atoms, + const double *weights, tsk_size_t result_dim, double *result, void *params); +int tsk_treeseq_pair_coalescence_stat(const tsk_treeseq_t *self, + tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, + const tsk_id_t *sample_sets, tsk_size_t num_set_indexes, const tsk_id_t *set_indexes, + tsk_size_t num_windows, const double *windows, tsk_size_t num_bins, + const tsk_id_t *node_bin_map, pair_coalescence_stat_func_t *summary_func, + tsk_size_t summary_func_dim, void *summary_func_args, tsk_flags_t options, + double *result); +int tsk_treeseq_pair_coalescence_counts(const tsk_treeseq_t *self, + tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, + const tsk_id_t *sample_sets, tsk_size_t num_set_indexes, const tsk_id_t *set_indexes, + tsk_size_t num_windows, const double *windows, tsk_size_t num_bins, + const tsk_id_t *node_bin_map, tsk_flags_t options, double *result); +int tsk_treeseq_pair_coalescence_quantiles(const tsk_treeseq_t *self, + tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, + const tsk_id_t *sample_sets, tsk_size_t num_set_indexes, const tsk_id_t *set_indexes, + tsk_size_t num_windows, const double *windows, tsk_size_t num_bins, + const 
tsk_id_t *node_bin_map, tsk_size_t num_quantiles, double *quantiles, + tsk_flags_t options, double *result); +int tsk_treeseq_pair_coalescence_rates(const tsk_treeseq_t *self, + tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, + const tsk_id_t *sample_sets, tsk_size_t num_set_indexes, const tsk_id_t *set_indexes, + tsk_size_t num_windows, const double *windows, tsk_size_t num_time_windows, + const tsk_id_t *node_time_window, double *time_windows, tsk_flags_t options, + double *result); + /****************************************************************************/ /* Tree */ /****************************************************************************/ @@ -1087,6 +1305,13 @@ int tsk_tree_copy(const tsk_tree_t *self, tsk_tree_t *dest, tsk_flags_t options) @{ */ +/** @brief Option to seek by skipping to the target tree, adding and removing as few + edges as possible. If not specified, a linear time algorithm is used instead. + + @ingroup TREE_API_SEEKING_GROUP +*/ +#define TSK_SEEK_SKIP (1 << 0) + /** @brief Seek to the first tree in the sequence. @@ -1192,12 +1417,22 @@ we will have ``position < tree.interval.right``. Seeking to a position currently covered by the tree is a constant time operation. + +Seeking to a position from a non-null tree uses a linear time +algorithm by default, unless the option :c:macro:`TSK_SEEK_SKIP` +is specified. In this case, a faster algorithm is employed which skips +to the target tree by removing and adding the minimal number of edges +possible. However, this approach does not guarantee that edges are +inserted and removed in time-sorted order. + +.. warning:: Using the :c:macro:`TSK_SEEK_SKIP` option + may lead to edges not being inserted or removed in time-sorted order. + @endrst @param self A pointer to an initialised tsk_tree_t object. @param position The position in genome coordinates -@param options Seek options. Currently unused. Set to 0 for compatibility - with future versions of tskit. 
+@param options Seek options. See the notes above for details. @return Return 0 on success or a negative value on failure. */ int tsk_tree_seek(tsk_tree_t *self, double position, tsk_flags_t options); @@ -1721,8 +1956,14 @@ bool tsk_tree_is_sample(const tsk_tree_t *self, tsk_id_t u); */ bool tsk_tree_equals(const tsk_tree_t *self, const tsk_tree_t *other); -int tsk_diff_iter_init_from_ts( - tsk_diff_iter_t *self, const tsk_treeseq_t *tree_sequence, tsk_flags_t options); +int tsk_tree_position_init( + tsk_tree_position_t *self, const tsk_treeseq_t *tree_sequence, tsk_flags_t options); +int tsk_tree_position_free(tsk_tree_position_t *self); +int tsk_tree_position_print_state(const tsk_tree_position_t *self, FILE *out); +bool tsk_tree_position_next(tsk_tree_position_t *self); +bool tsk_tree_position_prev(tsk_tree_position_t *self); +int tsk_tree_position_seek_forward(tsk_tree_position_t *self, tsk_id_t index); +int tsk_tree_position_seek_backward(tsk_tree_position_t *self, tsk_id_t index); #ifdef __cplusplus }