In [None]:
import numpy as np
import jax.numpy as jnp
import jax
import matplotlib.pyplot as plt
import tinygp
import corner

import spec_model
import utils
from generate_data import generate

In [None]:
# Gaussian line profile used only to simulate the data (numpy version; the
# model cell below defines its own profile with an extra offset term).
true_profile = lambda x, amp, mu, sig: amp * np.exp(-0.5 * (x - mu)**2 / sig**2)
t = np.linspace(0, 10, 100)

nbands = 4

# Simulation parameters handed to `generate`; their exact meaning is defined
# in generate_data.generate (not visible here).
offsets = np.array([0, 1, 3, 6]) * 100
amps = np.array([0.9, 0.8, 0.7])  # note: one entry fewer than nbands
diags = np.array([0.1, 0.2, 0.3, 0.4]) * 10

# Identical mean-function parameters for every band, presumably in the
# profile's argument order (amp, mu, sig) -- TODO confirm against generate().
mean_params = np.tile(np.array([20, 5, 0.1]), (nbands, 1))
# Per-band amplitudes can be tried with e.g. mean_params[:, 0] = [20, 50, 100, 200].

terms = [tinygp.kernels.quasisep.Matern32]
gp_params = np.array([[0.3, 2.0]])

data = generate(t, true_profile, terms, nbands, offsets, amps, diags, mean_params, gp_params, seed=42)

# Bands are interleaved in `data`: band i occupies every nbands-th entry.
fig, ax = plt.subplots(figsize=(12, 10))
for band in range(nbands):
    ax.plot(t, data[band::nbands], '.', label=f'band {band}')
ax.set_xlabel('t')
ax.set_ylabel('data')
ax.set_title('Simulated data')
ax.legend();

In [None]:
# Model profile: Gaussian plus a constant offset, written with jnp so it can be
# traced by jax (presumably required by spec_model.Model -- confirm).
f = lambda x, amp, mu, sig, off: amp * jnp.exp(- 0.5 * (x - mu)**2 / sig**2) + off
terms = [tinygp.kernels.quasisep.Matern32]

# Use the simulation's band count rather than a hard-coded 4 so the model stays
# in sync if `nbands` is changed in the data cell.
model = spec_model.Model(t, f, terms, nbands, hold_params=['mu', 'sig', 'amp'])
model.labels

In [None]:
# Starting point for training. Each group repeats a set of simulation
# parameters from the data cell; the overall order must match model.labels
# (TODO confirm against model.labels).
init_position = np.concatenate([
    [20, 5, 0.1],          # same values as each row of mean_params
    [0, 100, 300, 600],    # same values as offsets
    [0.3, 2.0],            # same values as gp_params
    [1, 2, 3, 4],          # same values as diags
    [0.9, 0.8, 0.7],       # same values as amps
])
model.train(data, init_position, n_loops=50)

In [None]:
# Samples from the trained normalizing flow, drawn on an explicit axes.
# `utils` comes from the import cell at the top of the notebook, so this cell
# also works on a fresh kernel (Restart & Run All).
fig, ax = plt.subplots()
utils.plot_nf_samples(ax, model);

In [None]:
# Normalizing-flow training diagnostics: utils.plot_nf_convergence fills a row
# of three axes. (`utils` is imported in the top import cell.)
fig, axs = plt.subplots(1, 3, figsize=(15, 5))
utils.plot_nf_convergence(axs, model)

In [None]:
# Production sampling, started from the same initial position used for training.
N_PRODUCTION_LOOPS = 500
PRODUCTION_STEP_SIZE = 5e-4
model.run_production(
    data,
    init_position,
    n_loops=N_PRODUCTION_LOOPS,
    step_size=PRODUCTION_STEP_SIZE,
)

In [None]:
# Corner plot of the production chains, thinned by a factor of 10, with the
# initial position (which matches the simulation parameters) drawn as truths.
production_state = model.production_sampler.get_sampler_state()
# Look the chains up by key: unpacking `.values()` positionally silently
# depends on the order of the keys in the returned dict.
production_chains = np.asarray(production_state['chains'])
samples = production_chains.reshape(-1, len(model.labels))[::10, :]
figure = corner.corner(samples, truths=init_position, labels=model.labels)

In [None]:
# Diagnostics of the tuning (training) phase: list what was logged, draw a 2x2
# panel of chain/loss/acceptance traces, then corner plots of the training
# chains and of samples from the trained flow. (`corner` is imported in the
# top import cell.)
out_train = model.trained_sampler.get_sampler_state(training=True)
print('Logged during tuning:', out_train.keys())

chains = np.array(out_train['chains'])
global_accs = np.array(out_train['global_accs'])
local_accs = np.array(out_train['local_accs'])
loss_vals = np.array(out_train['loss_vals'])
# sample_flow's result is indexed with [1]; presumably (key, samples) -- TODO confirm.
nf_samples = np.array(model.trained_sampler.sample_flow(1000)[1])


def show_corner(samples, labels, title, block=None):
    """Draw a 20x20-inch corner plot of `samples`, title it, show it, and return the figure.

    `block` is forwarded to plt.show (None is matplotlib's default).
    """
    corner_fig = corner.corner(samples, labels=labels)
    corner_fig.set_size_inches(20, 20)
    corner_fig.suptitle(title)
    plt.show(block=block)
    return corner_fig


fig, axs = plt.subplots(2, 2, figsize=(6, 6))
axs = axs.ravel()

# Two chains projected onto the first two coordinates, as a first visual check.
axs[0].set_title("2d proj of 2 chains")
axs[0].plot(chains[0, :, 0], chains[0, :, 1], 'o-', alpha=0.5, ms=2)
axs[0].plot(chains[1, :, 0], chains[1, :, 1], 'o-', alpha=0.5, ms=2)
axs[0].set_xlabel(model.labels[0])
axs[0].set_ylabel(model.labels[1])

# Remaining panels are 1-D traces against iteration; acceptances are averaged
# over axis 0 of the logged arrays.
trace_panels = [
    ("NF loss", loss_vals.reshape(-1)),
    ("Local Acceptance", local_accs.mean(0)),
    ("Global Acceptance", global_accs.mean(0)),
]
for ax, (title, trace) in zip(axs[1:], trace_panels):
    ax.set_title(title)
    ax.plot(trace)
    ax.set_xlabel("iteration")

fig.tight_layout()
plt.show(block=False)

# Corner plot of every training-chain sample, then of the NF samples.
figure = show_corner(
    chains.reshape(-1, len(model.labels)), model.labels, "Visualize samples", block=False
)
figure = show_corner(nf_samples, model.labels, "Visualize NF samples")