Fix unit test and posteriors
Previously we claimed to return posteriors that were probabilities, but the testing was broken, so we didn't notice that the posteriors were never scaled to sum to one (meaning they weren't actually probabilities).
hyanwong committed Jan 16, 2023
1 parent e9bd098 commit 5cc7fad
Showing 5 changed files with 113 additions and 51 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.rst
@@ -8,6 +8,12 @@
individuals, populations, or sites, aiming to change the tree sequence tables as
little as possible.

**Bugfixes**

- The posteriors returned when ``return_posteriors=True`` now contain actual
  probabilities (scaled so that they sum to one) rather than normalised
  values whose maximum is one.

--------------------
[0.1.5] - 2022-06-07
--------------------
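In practice, this fix means the values returned via ``return_posteriors=True`` now sum to one for each dated node. A minimal sketch (not part of the commit), assuming a tree sequence ``ts`` with undated non-sample nodes and illustrative ``Ne`` and ``mutation_rate`` values:

import numpy as np
import tsdate

dated_ts, posteriors = tsdate.date(
    ts, Ne=10_000, mutation_rate=1e-8, return_posteriors=True
)
for key, value in posteriors.items():
    if key in ("start_time", "end_time") or value is None:
        continue  # skip the time-slice bound arrays and any undated nodes
    assert np.isclose(np.sum(value), 1)  # each posterior is now a probability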
61 changes: 45 additions & 16 deletions tests/test_functions.py
Expand Up @@ -37,7 +37,7 @@
import utility_functions

import tsdate
from tsdate.base import NodeGridValues
from tsdate import base
from tsdate.core import constrain_ages_topo
from tsdate.core import date
from tsdate.core import get_dates
@@ -797,14 +797,14 @@ def test_init(self):
num_nodes = 5
ids = np.array([3, 4])
timepoints = np.array(range(10))
store = NodeGridValues(num_nodes, ids, timepoints, fill_value=6)
store = base.NodeGridValues(num_nodes, ids, timepoints, fill_value=6)
assert store.grid_data.shape == (len(ids), len(timepoints))
assert len(store.fixed_data) == (num_nodes - len(ids))
assert np.all(store.grid_data == 6)
assert np.all(store.fixed_data == 6)

ids = np.array([3, 4], dtype=np.int32)
store = NodeGridValues(num_nodes, ids, timepoints, fill_value=5)
store = base.NodeGridValues(num_nodes, ids, timepoints, fill_value=5)
assert store.grid_data.shape == (len(ids), len(timepoints))
assert len(store.fixed_data) == num_nodes - len(ids)
assert np.all(store.fixed_data == 5)
@@ -815,7 +815,7 @@ def test_set_and_get(self):
fill = {}
for ids in ([3, 4], []):
np.random.seed(1)
store = NodeGridValues(
store = base.NodeGridValues(
num_nodes, np.array(ids, dtype=np.int32), np.array(range(grid_size))
)
for i in range(num_nodes):
@@ -829,48 +829,52 @@ def test_set_and_get(self):
def test_bad_init(self):
ids = [3, 4]
with pytest.raises(ValueError):
NodeGridValues(3, np.array(ids), np.array([0, 1.2, 2]))
base.NodeGridValues(3, np.array(ids), np.array([0, 1.2, 2]))
with pytest.raises(AttributeError):
NodeGridValues(5, np.array(ids), -1)
base.NodeGridValues(5, np.array(ids), -1)
with pytest.raises(ValueError):
NodeGridValues(5, np.array([-1]), np.array([0, 1.2, 2]))
base.NodeGridValues(5, np.array([-1]), np.array([0, 1.2, 2]))

def test_clone(self):
num_nodes = 10
grid_size = 2
ids = [3, 4]
orig = NodeGridValues(num_nodes, np.array(ids), np.array(range(grid_size)))
orig = base.NodeGridValues(num_nodes, np.array(ids), np.array(range(grid_size)))
orig[3] = np.array([1, 2])
orig[4] = np.array([4, 3])
orig[0] = 1.5
orig[9] = 2.5
# test with np.zeros
clone = NodeGridValues.clone_with_new_data(orig, 0)
clone = base.NodeGridValues.clone_with_new_data(orig, 0)
assert clone.grid_data.shape == orig.grid_data.shape
assert clone.fixed_data.shape == orig.fixed_data.shape
assert np.all(clone.grid_data == 0)
assert np.all(clone.fixed_data == 0)
# test with something else
clone = NodeGridValues.clone_with_new_data(orig, 5)
clone = base.NodeGridValues.clone_with_new_data(orig, 5)
assert clone.grid_data.shape == orig.grid_data.shape
assert clone.fixed_data.shape == orig.fixed_data.shape
assert np.all(clone.grid_data == 5)
assert np.all(clone.fixed_data == 5)
# test with different
scalars = np.arange(num_nodes - len(ids))
clone = NodeGridValues.clone_with_new_data(orig, 0, scalars)
clone = base.NodeGridValues.clone_with_new_data(orig, 0, scalars)
assert clone.grid_data.shape == orig.grid_data.shape
assert clone.fixed_data.shape == orig.fixed_data.shape
assert np.all(clone.grid_data == 0)
assert np.all(clone.fixed_data == scalars)

clone = NodeGridValues.clone_with_new_data(orig, np.array([[1, 2], [4, 3]]))
clone = base.NodeGridValues.clone_with_new_data(
orig, np.array([[1, 2], [4, 3]])
)
for i in range(num_nodes):
if i in ids:
assert np.all(clone[i] == orig[i])
else:
assert np.isnan(clone[i])
clone = NodeGridValues.clone_with_new_data(orig, np.array([[1, 2], [4, 3]]), 0)
clone = base.NodeGridValues.clone_with_new_data(
orig, np.array([[1, 2], [4, 3]]), 0
)
for i in range(num_nodes):
if i in ids:
assert np.all(clone[i] == orig[i])
@@ -880,19 +884,44 @@ def test_clone(self):
def test_bad_clone(self):
num_nodes = 10
ids = [3, 4]
orig = NodeGridValues(num_nodes, np.array(ids), np.array([0, 1.2]))
orig = base.NodeGridValues(num_nodes, np.array(ids), np.array([0, 1.2]))
with pytest.raises(ValueError):
NodeGridValues.clone_with_new_data(
base.NodeGridValues.clone_with_new_data(
orig,
np.array([[1, 2, 3], [4, 5, 6]]),
)
with pytest.raises(ValueError):
NodeGridValues.clone_with_new_data(
base.NodeGridValues.clone_with_new_data(
orig,
0,
np.array([[1, 2], [4, 5]]),
)

def test_convert_to_probs(self):
num_nodes = 10
ids = [3, 4]
make_nan_row = 4
orig = base.NodeGridValues(num_nodes, np.array(ids), np.array([0, 1.2]), 1)
orig[make_nan_row][0] = np.nan
assert np.all(np.isnan(orig[make_nan_row]) == [True, False])
orig.force_probability_space(base.LIN)
orig.to_probabilities()
for n in orig.nonfixed_nodes:
if n == make_nan_row:
assert np.all(np.isnan(orig[n]))
else:
assert np.allclose(np.sum(orig[n]), 1)
assert np.all(orig[n] >= 0)

def test_cannot_convert_to_probs(self):
# No class implementation of logsumexp to convert to probabilities in log space
num_nodes = 10
ids = [3, 4]
orig = base.NodeGridValues(num_nodes, np.array(ids), np.array([0, 1.2]))
orig.force_probability_space(base.LOG)
with pytest.raises(NotImplementedError, match="linear space"):
orig.to_probabilities()


class TestAlgorithmClass:
def test_nonmatching_prior_vs_lik_timepoints(self):
5 changes: 3 additions & 2 deletions tests/test_inference.py
@@ -1,5 +1,6 @@
# MIT License
#
# Copyright (c) 2021-23 Tskit Developers
# Copyright (c) 2020 University of Oxford
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
@@ -109,7 +110,7 @@ def test_no_posteriors(self):
assert len(posteriors["start_time"]) == len(posteriors["end_time"])
assert len(posteriors["start_time"]) > 0
for node in ts.nodes():
if not node.is_sample:
if not node.is_sample():
assert node.id in posteriors
assert posteriors[node.id] is None

@@ -122,7 +123,7 @@ def test_posteriors(self):
assert len(posteriors["start_time"]) == len(posteriors["end_time"])
assert len(posteriors["start_time"]) > 0
for node in ts.nodes():
if not node.is_sample:
if not node.is_sample():
assert node.id in posteriors
assert len(posteriors[node.id]) == len(posteriors["start_time"])
assert np.isclose(np.sum(posteriors[node.id]), 1)
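The ``node.is_sample`` → ``node.is_sample()`` fix above is the broken testing mentioned in the commit message: a bound method object is always truthy, so ``if not node.is_sample:`` could never fire, and the assertions it guarded were silently skipped. A minimal illustration with a hypothetical ``Node`` class (not tskit's):

class Node:
    def is_sample(self):
        return False

node = Node()
# The bound method object is truthy regardless of its return value ...
assert bool(node.is_sample) is True
# ... so `if not node.is_sample:` could never enter its body,
assert (not node.is_sample) is False
# whereas calling the method yields the intended boolean.
assert node.is_sample() is False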
24 changes: 19 additions & 5 deletions tsdate/base.py
@@ -38,10 +38,9 @@

class NodeGridValues:
"""
A class to store grid values for node ids. For some nodes (fixed ones), only a single
value needs to be stored. For non-fixed nodes, an array of grid_size variables
is required, e.g. in order to store all the possible values for each of the hidden
states in the grid
A class to store times or discretised distributions of times for node ids. For nodes
with fixed times, only a single time value needs to be stored. For non-fixed nodes,
an array of len(timepoints) probabilities is required.
:ivar num_nodes: The number of nodes that will be stored in this object
:vartype num_nodes: int
@@ -130,7 +129,10 @@ def force_probability_space(self, probability_space):

def normalize(self):
"""
normalize grid and fixed data so the max is one
normalize grid data so the max is one (in linear space) or zero
(in logarithmic space)
TODO - is it clear why we omit the first element of the grid when taking the row maximum?
"""
rowmax = self.grid_data[:, 1:].max(axis=1)
if self.probability_space == LIN:
@@ -140,6 +142,18 @@ def normalize(self):
else:
raise RuntimeError("Probability space is not", LIN, "or", LOG)

def to_probabilities(self):
"""
Change grid data into probabilities (i.e. scale each row so that it sums to one
in linear space, or has a logsumexp of zero in logarithmic space)
"""
if self.probability_space != LIN:
raise NotImplementedError(
"Can only convert to probabilities in linear space"
)
assert not np.any(self.grid_data < 0)
self.grid_data = self.grid_data / self.grid_data.sum(axis=1)[:, np.newaxis]

def __getitem__(self, node_id):
index = self.row_lookup[node_id]
if index < 0:
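The distinction between ``normalize`` and the new ``to_probabilities`` is just two different per-row scalings; a plain-numpy sketch in linear space (ignoring ``normalize``'s exclusion of the first timepoint when taking the row maximum):

import numpy as np

grid = np.array([[0.2, 0.6, 0.2],
                 [1.0, 3.0, 1.0]])

# normalize(): scale each row so its maximum is one -- not a probability
max_scaled = grid / grid.max(axis=1, keepdims=True)

# to_probabilities(): scale each row so it sums to one -- a probability mass
probs = grid / grid.sum(axis=1, keepdims=True)

assert np.allclose(max_scaled.max(axis=1), 1)
assert np.allclose(probs.sum(axis=1), 1)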
68 changes: 40 additions & 28 deletions tsdate/core.py
@@ -66,7 +66,7 @@ def __init__(
eps=0,
fixed_node_set=None,
normalize=True,
progress=False
progress=False,
):
self.ts = ts
self.timepoints = timepoints
@@ -694,13 +694,14 @@ def outside_pass(
normalize=False,
ignore_oldest_root=False,
progress=None,
probability_space_returned=base.LIN
):
"""
Computes the full posterior distribution on nodes.
Computes the full posterior distribution on nodes, returning the
posterior values. These are *not* probabilities, as they do not sum to one:
to convert to probabilities, call posterior.to_probabilities()
Normalising may be necessary if there is overflow, but means that we cannot
check the total functional value at each node
Normalising *during* the outside process may be necessary if there is overflow,
but means that we cannot check the total functional value at each node
Ignoring the oldest root may also be necessary when the oldest root node
causes numerical stability issues.
@@ -769,13 +770,11 @@ def outside_pass(
outside[child] = self.lik.reduce(val, self.norm[child])
if normalize:
outside[child] = self.lik.reduce(val, np.max(val))
self.outside = outside
posterior = outside.clone_with_new_data(
grid_data=self.lik.combine(self.inside.grid_data, outside.grid_data),
fixed_data=np.nan,
) # We should never use the posterior for a fixed node
posterior.normalize()
posterior.force_probability_space(probability_space_returned)
self.outside = outside
return posterior

def outside_maximization(self, *, eps, progress=None):
@@ -857,12 +856,12 @@

def posterior_mean_var(ts, posterior, *, fixed_node_set=None):
"""
Mean and variance of node age in unscaled time. Fixed nodes will be given a mean
Mean and variance of node age. Fixed nodes will be given a mean
of their exact time in the tree sequence, and zero variance (as long as they are
identified by the fixed_node_set
identified by the fixed_node_set).
If fixed_node_set is None, we attempt to date all the non-sample nodes
Also assigns the estimated mean and variance of the age of each node, in unscaled
time, as metadata in the tree sequence.
Also assigns the estimated mean and variance of the age of each node
as metadata in the tree sequence.
"""
mn_post = np.full(ts.num_nodes, np.nan) # Fill with NaNs so we detect when there's
vr_post = np.full(ts.num_nodes, np.nan) # been an error
@@ -936,7 +935,7 @@ def date(
*,
return_posteriors=None,
progress=False,
**kwargs
**kwargs,
):
"""
Take a tree sequence (which could have
Expand All @@ -948,6 +947,19 @@ def date(
mutations and non-sample nodes in the input tree sequence are not used in inference
and will be removed.
.. note::
If posteriors are returned via the ``return_posteriors`` option, the output will
be a tuple ``(ts, posteriors)``, where ``posteriors`` is a dictionary suitable
for reading as a pandas ``DataFrame`` object, using ``pd.DataFrame(posteriors)``.
Each node whose time was inferred corresponds to an item in this dictionary,
with the key being the node ID and the value a 1D array of probabilities of the
node being in a given time slice (or ``None`` if the "inside_outside" method
was not used). The start and end times of each time slice are given as 1D
arrays in the dictionary, under keys named ``"start_time"`` and ``"end_time"``.
As time slices may not be of uniform width, it is important to divide the
posterior probabilities by ``end_time - start_time`` when assessing the shape
of the probability density function over time.
:param TreeSequence tree_sequence: The input :class:`tskit.TreeSequence`, treated as
one whose non-sample nodes are undated.
:param float Ne: The estimated (diploid) effective population size used to construct
@@ -974,9 +986,7 @@
conditional coalescent prior with a standard set of time points as given by
:func:`build_prior_grid`.
:param bool return_posteriors: If ``True``, instead of returning just a dated tree
sequence, return a tuple of ``(dated_ts, posteriors)``. Note that the dictionary
returned in ``posteriors`` (described below) is suitable for reading as a pandas
``DataFrame`` object, using ``pd.DataFrame(posteriors)``.
sequence, return a tuple of ``(dated_ts, posteriors)`` (see note above).
:param float eps: Specify minimum distance separating time points. Also specifies
the error factor in time difference calculations. Default: 1e-6
:param int num_threads: The number of threads to use. A simpler unthreaded algorithm
@@ -996,11 +1006,6 @@
:return: A copy of the input tree sequence but with altered node times, or (if
``return_posteriors`` is True) a tuple of that tree sequence plus a dictionary
of posterior probabilities from the "inside_outside" estimation ``method``.
Each node whose time was inferred corresponds to an item in this dictionary,
with the key being the node ID and the value a 1D array of probabilities of the
node being in a given time slice (or ``None`` if the "inside_outside" method
was not used). The start and end times of each time slice are given as 1D
arrays in the dictionary, under keys named ``"start_time"`` and ``end_time"``.
:rtype: tskit.TreeSequence or (tskit.TreeSequence, dict)
"""
if time_units is None:
@@ -1012,7 +1017,7 @@
recombination_rate=recombination_rate,
priors=priors,
progress=progress,
**kwargs
**kwargs,
)
constrained = constrain_ages_topo(tree_sequence, dates, eps, nds, progress)
tables = tree_sequence.dump_tables()
@@ -1028,12 +1033,12 @@
Ne=Ne,
recombination_rate=recombination_rate,
progress=progress,
**kwargs
**kwargs,
)
if return_posteriors:
pst = {"start_time": timepoints, "end_time": np.append(timepoints[1:], np.inf)}
for i, n in enumerate(nds):
pst[n] = None if posteriors is None else posteriors.grid_data[i, :]
for n in nds:
pst[n] = None if posteriors is None else posteriors[n]
return tables.tree_sequence(), pst
else:
return tables.tree_sequence()
@@ -1053,15 +1058,18 @@ def get_dates(
ignore_oldest_root=False,
progress=False,
cache_inside=False,
probability_space=base.LOG
probability_space=base.LOG,
):
"""
Infer dates for the nodes in a tree sequence, returning an array of inferred dates
for nodes, plus other variables such as the distribution of posterior probabilities
for nodes, plus other variables such as the posteriors object. Parameters are
identical to the date() method, which calls this method, then
injects the resulting date estimates into the tree sequence
:return: tuple(mn_post, posterior, timepoints, eps, nodes_to_date)
:return: a tuple of ``(mn_post, posteriors, timepoints, eps, nodes_to_date)``.
If the "inside_outside" method is used, ``posteriors`` will contain the
posterior probabilities for each node in each time slice, else the returned
variable will be ``None``.
"""
# Stuff yet to be implemented. These can be deleted once fixed
for sample in tree_sequence.samples():
@@ -1128,6 +1136,10 @@
posterior = dynamic_prog.outside_pass(
normalize=outside_normalize, ignore_oldest_root=ignore_oldest_root
)
# Turn the posterior into probabilities
posterior.normalize() # Just to make sure there are no floating point issues
posterior.force_probability_space(base.LIN)
posterior.to_probabilities()
tree_sequence, mn_post, _ = posterior_mean_var(
tree_sequence, posterior, fixed_node_set=fixed_nodes
)
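As the new docstring note explains, the returned posteriors read directly into pandas, and because the time slices are not of uniform width the probabilities should be divided by the slice widths when inspecting the density over time. A sketch, assuming ``posteriors`` came from ``tsdate.date(..., return_posteriors=True)``:

import pandas as pd

df = pd.DataFrame(posteriors)
widths = df["end_time"] - df["start_time"]  # the final slice is infinitely wide
node_columns = [c for c in df.columns if c not in ("start_time", "end_time")]
# Probability density per time slice (the final row becomes zero because
# the last slice width is infinite)
density = df[node_columns].div(widths, axis=0)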
