In [None]:
import os

import maxent
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import torch
from sbi.inference import infer
from torch.distributions import MultivariateNormal

from sbi_gravitation import GravitySimulator, sim_wrapper, get_observation_points, prior_means

In [None]:
# Set up the "true" parameters used to generate the reference trajectories.
# All masses are floats so they are handled uniformly downstream.
m1 = 100.  # solar masses
m2 = 50.  # solar masses
m3 = 75.  # solar masses
G = 1.90809e5  # solar radius / solar mass * (km/s)^2
# Two-component velocity in km/s; presumably the initial velocity of the
# simulated body -- confirm against GravitySimulator.
v0 = np.array([15., -40.])

In [None]:
# Load the cached "true" and noisy trajectories, or simulate and cache each one
# that is missing. The noisy trajectory is the observation used for inference.
# Caching it keeps the observation fixed across notebook re-runs.
TRUE_TRAJ_PATH = 'true_trajectory.txt'
NOISY_TRAJ_PATH = 'noisy_trajectory.txt'

if os.path.exists(TRUE_TRAJ_PATH):
    true_traj = np.genfromtxt(TRUE_TRAJ_PATH)
else:
    # make "true" path (random_noise=False), cache it, and plot it
    true_sim = GravitySimulator(m1, m2, m3, v0, random_noise=False)
    true_traj = true_sim.run()
    np.savetxt(TRUE_TRAJ_PATH, true_traj)
    true_sim.plot_traj()

if os.path.exists(NOISY_TRAJ_PATH):
    # Reuse the cached noisy trajectory instead of re-simulating. A fresh
    # simulation would draw new noise and overwrite the cached observation.
    traj = np.genfromtxt(NOISY_TRAJ_PATH)
else:
    # NOTE(review): `sim` is only defined when the noisy trajectory is
    # simulated in this run; later cells should not rely on it.
    sim = GravitySimulator(m1, m2, m3, v0, random_noise=True)
    traj = sim.run()
    np.savetxt(NOISY_TRAJ_PATH, traj)
    sim.plot_traj()

# Summary statistics of the observed (noisy) trajectory, flattened to 1-D.
observation_summary_stats = get_observation_points(traj).flatten()

In [None]:
# Isotropic Gaussian prior centred on `prior_means` with variance
# PRIOR_VARIANCE in every dimension. This is presumably the "wide" prior that
# the saved `wide_prior_samples.txt` file is named after.
PRIOR_VARIANCE = 50.

# Fix the dtype explicitly. A float64 `loc` combined with a float32 covariance
# makes MultivariateNormal's internal matmul fail.
prior_loc = torch.as_tensor(prior_means, dtype=torch.float32)
prior = MultivariateNormal(
    loc=prior_loc,
    # Derive the dimension from the prior means instead of hard-coding it.
    covariance_matrix=torch.eye(prior_loc.shape[-1], dtype=torch.float32) * PRIOR_VARIANCE,
)

# Fit the posterior with SNLE, using simulations from `sim_wrapper` with
# parameters drawn from the prior.
posterior = infer(sim_wrapper, prior, method='SNLE', num_simulations=2048, num_workers=16)

In [None]:
# Condition the fitted posterior on the observed summary statistics, draw
# 2000 parameter samples, and write them to disk as plain text.
samples = posterior.sample(
    (2000,),
    x=observation_summary_stats,
)
np.savetxt(
    'wide_prior_samples.txt',
    np.array(samples),
)