From a7025028c69cf18d85d1b577b2df73830047a277 Mon Sep 17 00:00:00 2001 From: Theo Sanderson Date: Wed, 26 Jul 2023 16:12:17 +0100 Subject: [PATCH 1/4] Horse shoe like model working branch --- src/chronumental/models.py | 132 ++++++++++++++++++++++++++++++++++++- 1 file changed, 131 insertions(+), 1 deletion(-) diff --git a/src/chronumental/models.py b/src/chronumental/models.py index 7433090..97eaec7 100644 --- a/src/chronumental/models.py +++ b/src/chronumental/models.py @@ -149,5 +149,135 @@ def get_mutation_rate(self, params): return self.clock_rate return params['mutation_rate_mu'] +class HorseShoeLike(ChronumentalModelBase): -models = {"DeltaGuideWithStrictLearntClock": DeltaGuideWithStrictLearntClock} + def __init__(self, **kwargs): + + self.clock_rate = kwargs['model_configuration']["clock_rate"] + + self.variance_dates = kwargs['model_configuration']['variance_dates'] + self.enforce_exact_clock = kwargs['model_configuration'][ + 'enforce_exact_clock'] + self.variance_on_clock_rate = kwargs['model_configuration'][ + 'variance_on_clock_rate'] + self.expected_min_between_transmissions = kwargs[ + 'model_configuration']['expected_min_between_transmissions'] + + super().__init__(**kwargs) + + def get_logging_results(self, params): + results = super().get_logging_results(params) + results['mutation_rate'] = self.get_mutation_rate(params) + results['argmax_variances_param'] = onp.argmax(params['variances_param']) + results['median_variances_param'] = onp.median(params['variances_param']) + results['max_variances_param'] = onp.max(params['variances_param']) + results['tau'] = params['tau_param'] + return results + + def set_initial_time(self): + self.initial_time = jnp.maximum( + 365 * (self.branch_distances_array) / self.clock_rate, + self.expected_min_between_transmissions) + + def calc_dates(self, branch_lengths_array, root_date): + + calc_dates = helpers.do_branch_matmul( + self.rows, + self.cols, + branch_lengths_array=branch_lengths_array, + final_size=self.terminal_target_dates_array.shape[0]) + return calc_dates + root_date + + def model(self): + root_date = numpyro.sample("root_date", + dist.Normal(loc=0.0, scale=1000.0)) + + branch_times = numpyro.sample( + "latent_time_length", + dist.Uniform( + low=onp.ones(self.branch_distances_array.shape[0]) * 0, + high=onp.ones(self.branch_distances_array.shape[0]) * 365 * + 10000)) + + if self.enforce_exact_clock: + mutation_rate = self.clock_rate + else: + mutation_rate = numpyro.sample( + f"latent_mutation_rate", + dist.Uniform( + low=0.0, + high=self.clock_rate * 1000.0)) + + branch_distances = numpyro.sample("branch_distances", + dist.Poisson(mutation_rate * + branch_times / 365), + obs=self.branch_distances_array) + + calced_dates = self.calc_dates(branch_times, root_date) + + hs_scale = 1 + + tau = numpyro.sample("tau", + dist.HalfCauchy(hs_scale)) + + lambda_l = numpyro.sample("lambda", + dist.HalfCauchy(hs_scale), + sample_shape=self.terminal_target_dates_array.shape) + + + final_dates = numpyro.sample(f"final_dates", + dist.Normal( + calced_dates, lambda_l**2 * tau**2), + obs=self.terminal_target_dates_array) + + def guide(self): + root_date_mu = numpyro.param( + "root_date_mu", -365 * self.ref_point_distance / self.clock_rate) + + root_date = numpyro.sample("root_date", dist.Delta(root_date_mu)) + + time_length_mu = numpyro.param("time_length_mu", + self.initial_time, + constraint=dist.constraints.positive) + + mutation_rate_mu = numpyro.param("mutation_rate_mu", + self.clock_rate, + constraint=dist.constraints.positive) + mutation_rate_sigma = numpyro.param( + "mutation_rate_sigma", + self.clock_rate, + constraint=dist.constraints.positive) + + variances = numpyro.param("variances_param", + onp.ones(self.terminal_target_dates_array.shape) *0.2, + constraint=dist.constraints.positive) + + tau_param = numpyro.param("tau_param", + 1000, + constraint=dist.constraints.positive) + tau = numpyro.sample("tau", + dist.Delta(tau_param)) + + sample_variances = numpyro.sample("lambda", + dist.Delta(variances)) + + + branch_times = numpyro.sample("latent_time_length", + dist.Delta(time_length_mu)) + + if not self.variance_on_clock_rate: + mutation_rate = numpyro.sample("latent_mutation_rate", + dist.Delta(mutation_rate_mu)) + else: + mutation_rate = numpyro.sample( + f"latent_mutation_rate", + dist.TruncatedNormal(0, mutation_rate_mu, mutation_rate_sigma)) + + def get_branch_times(self, params): + return params['time_length_mu'] + + def get_mutation_rate(self, params): + if self.enforce_exact_clock: + return self.clock_rate + return params['mutation_rate_mu'] +models = {"DeltaGuideWithStrictLearntClock": DeltaGuideWithStrictLearntClock, "HorseShoeLike": HorseShoeLike} From 6826b207593832577b3941526699afa5c8bac6a2 Mon Sep 17 00:00:00 2001 From: Theo Sanderson Date: Wed, 26 Jul 2023 20:12:27 +0100 Subject: [PATCH 2/4] Update models.py --- src/chronumental/models.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/chronumental/models.py b/src/chronumental/models.py index 97eaec7..ec1ab7b 100644 --- a/src/chronumental/models.py +++ b/src/chronumental/models.py @@ -19,6 +19,7 @@ def __init__(self, **kwargs): self.ref_point_distance = kwargs['ref_point_distance'] self.set_initial_time() + self.terminal_names = kwargs['terminal_names'] def get_logging_results(self, params): results = collections.OrderedDict() @@ -168,7 +169,7 @@ def __init__(self, **kwargs): def get_logging_results(self, params): results = super().get_logging_results(params) results['mutation_rate'] = self.get_mutation_rate(params) - results['argmax_variances_param'] = onp.argmax(params['variances_param']) + results['argmax_variances_param'] = self.terminal_names[onp.argmax(params['variances_param'])] results['median_variances_param'] = onp.median(params['variances_param']) results['max_variances_param'] = onp.max(params['variances_param']) results['tau'] = params['tau_param'] From c5fd2a5a302139f8d5e564eaf3f3918e40af8b79 Mon Sep 17 00:00:00 2001 From: Theo Sanderson Date: Wed, 26 Jul 2023 20:12:44 +0100 Subject: [PATCH 3/4] Update __main__.py --- src/chronumental/__main__.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/chronumental/__main__.py b/src/chronumental/__main__.py index ab89a94..6d071d3 100644 --- a/src/chronumental/__main__.py +++ b/src/chronumental/__main__.py @@ -96,7 +96,7 @@ def get_parser(): help="Number of steps to use for the SVI. Increasing this number will make runtime increase, but yield more accurate results.") parser.add_argument('--lr', - default=0.05, + default=0.01, type=float, help="Adam learning rate") @@ -327,7 +327,8 @@ def main(): terminal_target_dates_array=terminal_target_dates_array, terminal_target_errors_array=terminal_target_errors_array, ref_point_distance=ref_point_distance, - model_configuration=model_configuration) + model_configuration=model_configuration, + terminal_names=terminal_names) print("Performing SVI:") optimiser = optim.ClippedAdam( From 6f3ec315e814e578022f1c15b1338b72771f40dd Mon Sep 17 00:00:00 2001 From: Theo Sanderson Date: Fri, 28 Jul 2023 14:19:42 +0100 Subject: [PATCH 4/4] update --- src/chronumental/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/chronumental/models.py b/src/chronumental/models.py index ec1ab7b..8915fba 100644 --- a/src/chronumental/models.py +++ b/src/chronumental/models.py @@ -254,7 +254,7 @@ def guide(self): constraint=dist.constraints.positive) tau_param = numpyro.param("tau_param", - 1000, + 1, constraint=dist.constraints.positive) tau = numpyro.sample("tau", dist.Delta(tau_param))