In [None]:
import numpy as np
import os
import tensorflow as tf
import torch
import maxent
from maxent.sbi_gravitation import GravitySimulator, sim_wrapper, get_observation_points
from torch.distributions.multivariate_normal import MultivariateNormal
from sbi.inference import infer

In [None]:
# set up true parameters
m1 = 100.  # solar masses
m2 = 50.   # solar masses
m3 = 75.   # solar masses (float, consistent with m1 / m2)
G = 1.90809e5  # solar radius / solar mass * (km/s)^2
v0 = np.array([15., -40.])  # km/s

# parameter order used everywhere below: m1, m2, m3, v0[0], v0[1]
true_params = [m1, m2, m3, v0[0], v0[1]]

# set prior means (same order as true_params)
prior_means = [85., 40., 70., 12., -30.]
# isotropic Gaussian prior: the same variance on every parameter.
# Both the SNLE prior and the maxent prior samples are built from prior_cov.
PRIOR_VARIANCE = 50.
prior_cov = np.eye(len(prior_means)) * PRIOR_VARIANCE

In [None]:
# generate true trajectory and apply some noise to it
def _load_or_simulate(path, random_noise):
    """Load a cached trajectory from `path`, or simulate it with the true parameters.

    On a cache miss the trajectory is simulated with the true-parameter globals
    (m1, m2, m3, v0, G) and written to `path` with np.savetxt.

    Returns:
        (trajectory, simulator): `simulator` is the GravitySimulator that produced
        the trajectory, or None when the trajectory was read from the cache.
    """
    if os.path.exists(path):
        return np.genfromtxt(path), None
    simulator = GravitySimulator(m1, m2, m3, v0, G, random_noise=random_noise)
    trajectory = simulator.run()
    np.savetxt(path, trajectory)
    return trajectory, simulator


true_traj, true_sim = _load_or_simulate('true_trajectory.txt', random_noise=False)
traj, noisy_sim = _load_or_simulate('noisy_trajectory.txt', random_noise=True)

observed_points = get_observation_points(traj)
observation_summary_stats = observed_points.flatten()

# Only a simulator that was actually run in this session can plot. If the noisy
# trajectory came from the cache there is no simulator to call plot_traj() on
# (previously this raised NameError, or plotted the noise-free run instead).
if noisy_sim is not None:
    noisy_sim.plot_traj()

In [None]:
# perform SNL inference
# Build the torch prior from the same prior_means / prior_cov that the maxent
# prior samples are drawn from, so both methods start from an identical prior.
# float32 matches torch's default dtype.
prior = MultivariateNormal(
    loc=torch.as_tensor(prior_means, dtype=torch.float32),
    covariance_matrix=torch.as_tensor(prior_cov, dtype=torch.float32),
)

posterior = infer(sim_wrapper, prior, method='SNLE', num_simulations=2048, num_workers=16)

In [None]:
# sample from SNL posterior
# The prior (and torch networks by default) are float32. Pass the observation as
# a float32 tensor rather than the float64 numpy array from np.genfromtxt/run().
x_obs = torch.as_tensor(observation_summary_stats, dtype=torch.float32)
samples = posterior.sample((2000,), x=x_obs)

# explicit tensor -> numpy conversion (cpu() is a no-op for CPU tensors)
np.savetxt('wide_prior_samples.txt', samples.detach().cpu().numpy())

In [None]:
# set up restraints for maxent
# restraint structure: [value, uncertainty, indices... ]
# Every observed point contributes two restraints, one per coordinate
# (0 then 1), both at the same trajectory time index.
RESTRAINT_UNCERTAINTY = 25
restraints = []
for obs_num, point in enumerate(observed_points):
    time_index = 20 * obs_num + 19  # based on how we slice in get_observation_points()
    for coord in (0, 1):
        restraints.append([point[coord], RESTRAINT_UNCERTAINTY, time_index, coord])

In [None]:
# set up laplace restraints
# Wrap each [value, uncertainty, indices...] entry as a maxent.Restraint whose
# function reads traj[indices] (indices is a tuple, e.g. (time_index, coord)).
# Every restraint currently gets a maxent.EmptyPrior, so `uncertainty` is unused.
# NOTE(review): the "laplace" name suggests a Laplace prior built from
# `uncertainty` was intended at some point — confirm which prior is wanted.
laplace_restraints = []
for entry in restraints:
    value, uncertainty = entry[0], entry[1]
    key = tuple(entry[2:])  # must be a tuple: a list would trigger fancy indexing
    prior_fn = maxent.EmptyPrior()
    # bind `key` as a default argument so each lambda keeps its own index
    laplace_restraints.append(
        maxent.Restraint(lambda traj, idx=key: traj[idx], value, prior_fn)
    )

In [None]:
# sample from prior for maxent
# The legacy global NumPy RNG is seeded so the prior samples (and the saved .npy)
# are reproducible across runs.
MAXENT_PRIOR_SEED = 12656
N_MAXENT_PRIOR_SAMPLES = 2048
np.random.seed(MAXENT_PRIOR_SEED)
prior_dist = np.random.multivariate_normal(
    mean=prior_means, cov=prior_cov, size=N_MAXENT_PRIOR_SAMPLES
)
np.save('maxent_prior_samples.npy', prior_dist)

In [None]:
# generate trajectories for maxent from prior samples
try:
    from tqdm.auto import tqdm
except ImportError:  # the progress bar is optional; fall back to a plain iterator
    def tqdm(iterable, **kwargs):
        return iterable


def _simulate_prior_trajectories(samples, grav_const):
    """Run a noise-free GravitySimulator for every prior sample.

    Each row of `samples` is (m1, m2, m3, v0[0], v0[1]), the same order as
    `true_params`. The work is done with function locals, so this cell no longer
    overwrites the true-parameter globals (m1, m2, m3, v0) or the noisy `traj`.

    Args:
        samples: 2-D array of parameter samples, one row per simulation.
        grav_const: gravitational constant, passed exactly as for the true run.

    Returns:
        float64 array that stacks one trajectory per sample along axis 0.
    """
    trajectories = []
    for sample in tqdm(samples):
        mass1, mass2, mass3, vel0 = sample[0], sample[1], sample[2], sample[3:]
        simulator = GravitySimulator(mass1, mass2, mass3, vel0, grav_const,
                                     random_noise=False)
        trajectories.append(simulator.run())
    # shape comes from the simulator instead of a hard-coded (…, 100, 2)
    return np.array(trajectories, dtype=np.float64)


trajs = _simulate_prior_trajectories(prior_dist, G)
np.save('maxent_raw_trajectories.npy', trajs)

In [None]:
# run maxent on trajectories
# full-batch training: every prior trajectory is in each batch
batch_size = prior_dist.shape[0]

# the package is imported as `maxent` (the old name `maxentep` is undefined here)
model = maxent.MaxentModel(laplace_restraints)
model.compile(tf.keras.optimizers.Adam(1e-4), 'mean_squared_error')
# short burn-in
burn_in_history = model.fit(trajs, batch_size=batch_size, epochs=5000, verbose=1)
# NOTE(review): this second fit() continues with the same compiled Adam optimizer.
# It does not reset the learning rate or optimizer state. Only this run's loss
# history is saved below.
h = model.fit(trajs, batch_size=batch_size, epochs=25000, verbose=1)

np.savetxt('maxent_loss.txt', h.history['loss'])

# traj_weights may be a TF tensor (possibly float32). Convert once to a float64
# numpy array so the weighted sum below stays in numpy, with trajs' dtype.
weights = np.asarray(model.traj_weights, dtype=np.float64)
np.savetxt('maxent_traj_weights.txt', weights)

# weighted average over the trajectory (sample) axis 0
avg_traj = np.sum(trajs * weights[:, np.newaxis, np.newaxis], axis=0)
np.savetxt('maxent_avg_traj.txt', avg_traj)

**TODO:** add the plotting code for the results.