
AssertionError: kl-min: bad initial condition #289

Closed
hyanwong opened this issue Jul 11, 2023 · 3 comments

Comments

@hyanwong
Member

hyanwong commented Jul 11, 2023

For the attached (inferred) tree sequence, which contains coalescent nodes that are unary over some regions, I get an AssertionError when dating with method="variational_gamma", for example with the following code:

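```python
import tskit
import tsdate

mu = 1e-8  # mutation rate per base pair per generation
Ne = 1e4  # effective population size
ts = tskit.load("debug.trees")  # inferred tree sequence with unary regions
# Building the priors with allow_unary=True takes a while for this file
prior = tsdate.prior.parameter_grid(
    ts, population_size=Ne, allow_unary=True, progress=True
)
dts = tsdate.date(
    ts, priors=prior, mutation_rate=mu, progress=True, method="variational_gamma"
)
```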

Here's the error:

```
AssertionError                            Traceback (most recent call last)
Cell In[100], line 1
----> 1 dts = tsdate.date(semi_unary_ts, priors=prior, mutation_rate=mu, progress=True, method="variational_gamma")

File ~/Library/jupyterlab-desktop/jlab_server/lib/python3.10/site-packages/tsdate/core.py:1237, in date(tree_sequence, mutation_rate, population_size, recombination_rate, time_units, priors, method, Ne, return_posteriors, progress, **kwargs)
   1234     population_size = demography.PopulationSizeHistory(**population_size)
   1236 if method == "variational_gamma":
-> 1237     tree_sequence, dates, posteriors, timepoints, eps, nds = variational_dates(
   1238         tree_sequence,
   1239         population_size=population_size,
   1240         mutation_rate=mutation_rate,
   1241         recombination_rate=recombination_rate,
   1242         priors=priors,
   1243         progress=progress,
   1244         **kwargs,
   1245     )
   1246 else:
   1247     tree_sequence, dates, posteriors, timepoints, eps, nds = get_dates(
   1248         tree_sequence,
   1249         population_size=population_size,
   (...)
   1255         **kwargs,
   1256     )

File ~/Library/jupyterlab-desktop/jlab_server/lib/python3.10/site-packages/tsdate/core.py:1546, in variational_dates(tree_sequence, mutation_rate, population_size, recombination_rate, priors, max_iterations, global_prior, eps, progress, num_threads, probability_space, ignore_oldest_root)
   1544 dynamic_prog = ExpectationPropagation(priors, liklhd, progress=progress)
   1545 for it in range(max_iterations):
-> 1546     dynamic_prog.iterate(iter_num=it)
   1547 posterior = dynamic_prog.posterior
   1548 tree_sequence, mn_post, _ = variational_mean_var(
   1549     tree_sequence, posterior, fixed_node_set=fixed_nodes
   1550 )

File ~/Library/jupyterlab-desktop/jlab_server/lib/python3.10/site-packages/tsdate/core.py:1045, in ExpectationPropagation.iterate(self, iter_num, progress)
   1043 if iter_num:  # Show iteration number if not first iteration
   1044     desc = f"EP (iter {iter_num + 1:>2}, rootwards)"
-> 1045 self.propagate(
   1046     edges=self.edges_by_parent_asc(grouped=False), desc=desc, progress=progress
   1047 )
   1048 if iter_num:
   1049     desc = f"EP (iter {iter_num + 1:>2}, leafwards)"

File ~/Library/jupyterlab-desktop/jlab_server/lib/python3.10/site-packages/tsdate/core.py:1022, in ExpectationPropagation.propagate(self, edges, desc, progress)
   1012 child_cavity = self.lik.ratio(
   1013     self.posterior[edge.child], self.child_message[edge.id]
   1014 )
   1015 # Get the target posterior: that is, the cavity multiplied by the
   1016 # edge likelihood, and projected onto a gamma distribution via
   1017 # moment matching.
   1018 (
   1019     norm_const,
   1020     self.posterior[edge.parent],
   1021     self.posterior[edge.child],
-> 1022 ) = approx.gamma_projection(*parent_cavity, *child_cavity, *edge_lik)
   1023 # Get the messages: that is, the multiplicative difference between
   1024 # the target and cavity posteriors. This only involves updating the
   1025 # variational parameters for the parent and child on the edge.
   1026 self.parent_message[edge.id] = self.lik.ratio(
   1027     self.posterior[edge.parent], parent_cavity
   1028 )

File ~/Library/jupyterlab-desktop/jlab_server/lib/python3.10/site-packages/tsdate/approx.py:62, in approximate_gamma_kl()
     60 assert np.isfinite(x) and np.isfinite(logx)
     61 alpha = 0.5 / (np.log(x) - logx)  # lower bound on alpha
---> 62 assert alpha > 0, "kl-min: bad initial condition"
     63 last = np.inf
     64 itt = 0

AssertionError: kl-min: bad initial condition
```

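The failing check in approx.py computes alpha = 0.5 / (log(x) - logx). Assuming x and logx there are the mean E[x] and mean log E[log x] of the moment-matched target (an assumption about tsdate's internals, not confirmed in this thread), the denominator is strictly positive for any non-constant positive distribution by Jensen's inequality. For a gamma with shape alpha it equals log(alpha) - digamma(alpha), which lies between 1/(2 alpha) and 1/alpha, so 0.5 / gap is a valid lower bound on alpha. The assertion can therefore only fire when the two statistics are mutually inconsistent, presumably through numerical error upstream. The pure-Python sketch below is not tsdate code; every function name in it is hypothetical.

```python
import math

EULER_GAMMA = 0.5772156649015329


def sufficient_stats(samples):
    """Return (E[x], E[log x]) for a list of positive samples."""
    n = len(samples)
    mean_x = sum(samples) / n
    mean_logx = sum(math.log(s) for s in samples) / n
    return mean_x, mean_logx


def shape_lower_bound(mean_x, mean_logx):
    """Hypothetical helper: lower bound on a gamma shape from E[x] and E[log x].

    Mirrors alpha = 0.5 / (log(x) - logx) from the traceback, but raises
    instead of returning a non-positive (or infinite) alpha.
    """
    gap = math.log(mean_x) - mean_logx  # > 0 by Jensen unless x is constant
    if gap <= 0.0:
        raise ValueError(f"inconsistent statistics: log E[x] - E[log x] = {gap}")
    return 0.5 / gap


def digamma_int(n):
    """Digamma at a positive integer: psi(n) = -gamma + sum_{k<n} 1/k."""
    return -EULER_GAMMA + sum(1.0 / k for k in range(1, n))


# Exact gamma(shape=n, scale=1) statistics: E[x] = n, E[log x] = psi(n).
for n in (1, 2, 5, 20):
    print(n, shape_lower_bound(n, digamma_int(n)))  # bound stays below n

# Statistics computed from non-constant positive data are always consistent.
print(shape_lower_bound(*sufficient_stats([1.0, 2.0, 4.0])))

# E[log x] > log E[x] cannot happen for real data; this is the failing case.
try:
    shape_lower_bound(1.0, 0.1)
except ValueError as err:
    print("rejected:", err)
```

The last call corresponds to a negative alpha in the original check; raising ValueError rather than asserting is only a choice made for the sketch.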
@nspope
Contributor

nspope commented Jul 12, 2023

Looks like the attachment didn't come through here @hyanwong -- do you still have it handy?

@hyanwong
Member Author

Oh, sorry. Will have a look.

@hyanwong
Member Author

It's this one. It takes a while to create the priors though, sorry!

debug.trees.zip
