Skip to content

Commit

Permalink
Merge pull request #31 from theosanderson/horseshoe
Browse files Browse the repository at this point in the history
Horse shoe like model working branch
  • Loading branch information
theosanderson committed Nov 1, 2023
2 parents 19d66ce + 6f3ec31 commit 3461244
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 3 deletions.
5 changes: 3 additions & 2 deletions src/chronumental/__main__.py
Expand Up @@ -96,7 +96,7 @@ def get_parser():
help="Number of steps to use for the SVI. Increasing this number will make runtime increase, but yield more accurate results.")

parser.add_argument('--lr',
default=0.05,
default=0.01,
type=float,
help="Adam learning rate")

Expand Down Expand Up @@ -327,7 +327,8 @@ def main():
terminal_target_dates_array=terminal_target_dates_array,
terminal_target_errors_array=terminal_target_errors_array,
ref_point_distance=ref_point_distance,
model_configuration=model_configuration)
model_configuration=model_configuration,
terminal_names=terminal_names)

print("Performing SVI:")
optimiser = optim.ClippedAdam(
Expand Down
133 changes: 132 additions & 1 deletion src/chronumental/models.py
Expand Up @@ -19,6 +19,7 @@ def __init__(self, **kwargs):
self.ref_point_distance = kwargs['ref_point_distance']

self.set_initial_time()
self.terminal_names = kwargs['terminal_names']

def get_logging_results(self, params):
results = collections.OrderedDict()
Expand Down Expand Up @@ -149,5 +150,135 @@ def get_mutation_rate(self, params):
return self.clock_rate
return params['mutation_rate_mu']

class HorseShoeLike(ChronumentalModelBase):

models = {"DeltaGuideWithStrictLearntClock": DeltaGuideWithStrictLearntClock}
def __init__(self, **kwargs):

self.clock_rate = kwargs['model_configuration']["clock_rate"]

self.variance_dates = kwargs['model_configuration']['variance_dates']
self.enforce_exact_clock = kwargs['model_configuration'][
'enforce_exact_clock']
self.variance_on_clock_rate = kwargs['model_configuration'][
'variance_on_clock_rate']
self.expected_min_between_transmissions = kwargs[
'model_configuration']['expected_min_between_transmissions']

super().__init__(**kwargs)

def get_logging_results(self, params):
results = super().get_logging_results(params)
results['mutation_rate'] = self.get_mutation_rate(params)
results['argmax_variances_param'] = self.terminal_names[onp.argmax(params['variances_param'])]
results['median_variances_param'] = onp.median(params['variances_param'])
results['max_variances_param'] = onp.max(params['variances_param'])
results['tau'] = params['tau_param']
return results

def set_initial_time(self):
self.initial_time = jnp.maximum(
365 * (self.branch_distances_array) / self.clock_rate,
self.expected_min_between_transmissions)

def calc_dates(self, branch_lengths_array, root_date):

calc_dates = helpers.do_branch_matmul(
self.rows,
self.cols,
branch_lengths_array=branch_lengths_array,
final_size=self.terminal_target_dates_array.shape[0])
return calc_dates + root_date

def model(self):
root_date = numpyro.sample("root_date",
dist.Normal(loc=0.0, scale=1000.0))

branch_times = numpyro.sample(
"latent_time_length",
dist.Uniform(
low=onp.ones(self.branch_distances_array.shape[0]) * 0,
high=onp.ones(self.branch_distances_array.shape[0]) * 365 *
10000))

if self.enforce_exact_clock:
mutation_rate = self.clock_rate
else:
mutation_rate = numpyro.sample(
f"latent_mutation_rate",
dist.Uniform(
low=0.0,
high=self.clock_rate * 1000.0))

branch_distances = numpyro.sample("branch_distances",
dist.Poisson(mutation_rate *
branch_times / 365),
obs=self.branch_distances_array)

calced_dates = self.calc_dates(branch_times, root_date)

hs_scale = 1

tau = numpyro.sample("tau",
dist.HalfCauchy(hs_scale))

lambda_l = numpyro.sample("lambda",
dist.HalfCauchy(hs_scale),
sample_shape=self.terminal_target_dates_array.shape)


final_dates = numpyro.sample(f"final_dates",
dist.Normal(
calced_dates, lambda_l**2 * tau**2),
obs=self.terminal_target_dates_array)

def guide(self):
root_date_mu = numpyro.param(
"root_date_mu", -365 * self.ref_point_distance / self.clock_rate)

root_date = numpyro.sample("root_date", dist.Delta(root_date_mu))

time_length_mu = numpyro.param("time_length_mu",
self.initial_time,
constraint=dist.constraints.positive)

mutation_rate_mu = numpyro.param("mutation_rate_mu",
self.clock_rate,
constraint=dist.constraints.positive)
mutation_rate_sigma = numpyro.param(
"mutation_rate_sigma",
self.clock_rate,
constraint=dist.constraints.positive)

variances = numpyro.param("variances_param",
onp.ones(self.terminal_target_dates_array.shape) *0.2,
constraint=dist.constraints.positive)

tau_param = numpyro.param("tau_param",
1,
constraint=dist.constraints.positive)
tau = numpyro.sample("tau",
dist.Delta(tau_param))

sample_variances = numpyro.sample("lambda",
dist.Delta(variances))


branch_times = numpyro.sample("latent_time_length",
dist.Delta(time_length_mu))

if not self.variance_on_clock_rate:
mutation_rate = numpyro.sample("latent_mutation_rate",
dist.Delta(mutation_rate_mu))
else:
mutation_rate = numpyro.sample(
f"latent_mutation_rate",
dist.TruncatedNormal(0, mutation_rate_mu, mutation_rate_sigma))

def get_branch_times(self, params):
return params['time_length_mu']

def get_mutation_rate(self, params):
if self.enforce_exact_clock:
return self.clock_rate
return params['mutation_rate_mu']
models = {"DeltaGuideWithStrictLearntClock": DeltaGuideWithStrictLearntClock, "HorseShoeLike": HorseShoeLike}

0 comments on commit 3461244

Please sign in to comment.