In [None]:
import sys

sys.path.append("../")
from hamilton_neural_network import (
    TrainTestData,
    LatentHamiltonianNeuralNetwork,
)
from hamilton_system import HamiltonianSystem
from pdf_models import NegLogIndepedentGaussians, NegLogNealFunnel
import tensorflow as tf
import matplotlib.pyplot as plt
from no_u_turn.nuts import NoUTurnSampling

# Fix TensorFlow's global random seed so the random ops used in this notebook
# (tf.random.normal, tf.random.shuffle) draw from a seeded sequence.
tf.random.set_seed(0)

In [None]:
# U and K are the two terms handed to TrainTestData below as `U=` and `K=`.
# NOTE(review): the two tensors given to NegLogIndepedentGaussians are presumably
# per-dimension means and standard deviations — confirm against pdf_models.
U = NegLogNealFunnel()
K = NegLogIndepedentGaussians(tf.constant([0.0, 0.0]), tf.constant([1.0, 1.0]))
# Initial position is the origin (shape (1, 2)); the initial momentum is a
# standard-normal draw with the same shape.
q0 = tf.constant([[0.0, 0.0]])
p0 = tf.random.normal(q0.shape)
# NOTE(review): T is presumably the trajectory length in time units and
# leap_frog_per_unit the number of integrator steps per unit time — confirm
# against TrainTestData.
T = 120.0
leap_frog_per_unit = 40
num_samples = 40
# Row index used later to cut the data set into train/test parts:
# 90% of num_samples * leap_frog_per_unit * T, truncated to an int.
num_train = int(0.9 * num_samples * leap_frog_per_unit * T)

In [None]:
# Build the data generator from the settings defined above and call it to
# produce the sample tensor.
# NOTE(review): presumably this integrates num_samples trajectories of length T
# starting from (q0, p0) — confirm against TrainTestData.
train_test_data = TrainTestData(num_samples, T, leap_frog_per_unit, q0, p0, U=U, K=K)
samples = train_test_data()
# Persist the samples so the training cell can reload them from disk. Despite the
# .txt extension, the file holds the binary output of tf.io.serialize_tensor.
tf.io.write_file("../exps/demo3_train_test_data.txt", tf.io.serialize_tensor(samples))

In [None]:
# Reload the serialized samples written by the data-generation cell.
# Note: `train_test_data` is rebound here from the generator object to a tensor.
file = tf.io.read_file("../exps/demo3_train_test_data.txt")
train_test_data = tf.io.parse_tensor(file, out_type=tf.float32)
# Shuffle along the first axis before slicing so the split is not tied to the
# order in which the rows were generated.
train_test_data = tf.random.shuffle(train_test_data)
# The first num_train rows are used for training, the remainder for testing.
train_data = train_test_data[:num_train, :]
test_data = train_test_data[num_train:, :]
print(train_data.shape, test_data.shape)
# NOTE(review): the constructor arguments (3, 100, 2) are presumably layer count,
# hidden width and latent/output size — confirm against hamilton_neural_network.
lhnn = LatentHamiltonianNeuralNetwork(3, 100, 2)
# Input width 4 presumably corresponds to 2 position + 2 momentum components —
# TODO confirm.
lhnn.build(input_shape=(1, 4))
# NOTE(review): the positional arguments (1000, 1000, 5e-4) are presumably epochs,
# batch size and learning rate — confirm against LatentHamiltonianNeuralNetwork.train.
# The path passed as save_dir is the same file the next cell loads weights from.
train_hist, test_hist = lhnn.train(
    1000, 1000, 5e-4, train_data, test_data, save_dir="../exps/demo3_lhnn.weights.h5"
)

In [None]:
# Rebuild the network with the same constructor arguments as the training cell
# and restore the saved weights, so this cell can run without retraining.
lhnn = LatentHamiltonianNeuralNetwork(3, 100, 2)
lhnn.build(input_shape=(1, 4))
lhnn.load_weights("../exps/demo3_lhnn.weights.h5")
# U, K and leap_frog_per_unit are re-declared with the same values as in the
# configuration cell so this cell does not depend on that cell having been run.
U = NegLogNealFunnel()
K = NegLogIndepedentGaussians(tf.constant([0.0, 0.0]), tf.constant([1.0, 1.0]))
# Start the comparison from a fixed state that differs from the origin used
# for data generation, with zero initial momentum.
q0 = tf.constant([[2.0, -10.0]])
p0 = tf.constant([[0.0, 0.0]])
leap_frog_per_unit = 40
n_steps = 1024
# Both integrators must use the same step size for the curves to be comparable.
step_size = 1 / leap_frog_per_unit
original_hamiltonian = HamiltonianSystem(U=U, K=K)

hist_original = original_hamiltonian.symplectic_integrate(q0, p0, step_size, n_steps)
hist_lhnn = lhnn.symplectic_integrate(q0, p0, step_size, n_steps)

# One panel per history column (0 and 1); each panel overlays the reference
# integrator and the learned model. Labels and a legend are required to tell
# the two curves apart.
# NOTE(review): columns 0 and 1 are presumably the two position components —
# confirm against symplectic_integrate before giving the panels physical names.
fig, ax = plt.subplots(1, 2)
for column, panel in enumerate(ax):
    panel.plot(hist_original[:, column], label="HamiltonianSystem")
    panel.plot(hist_lhnn[:, column], label="LHNN")
    panel.set_title(f"history column {column}")
    panel.set_xlabel("history row")
    panel.legend()
fig.tight_layout()
plt.show()

In [None]:
# Start the sampler from the origin (this rebinds q0 from the comparison cell).
q0 = tf.constant([[0.0, 0.0]])
# The sampler receives both the trained LHNN and the exact Hamiltonian system
# built from U and K.
# NOTE(review): Delta_lf / Delta_lhnn are presumably error thresholds for the
# leapfrog and LHNN integrators, and num_lf_steps a count of leapfrog steps used
# when falling back — confirm against no_u_turn.nuts.
# dt=0.025 equals 1 / 40, i.e. 1 / leap_frog_per_unit as set in earlier cells.
nuts = NoUTurnSampling(
    num_samples=25000,
    q0=q0,
    dt=0.025,
    lhnn=lhnn,
    Hamiltonian=HamiltonianSystem(U=U, K=K),
    Delta_lf=1000.0,
    Delta_lhnn=10.0,
    num_lf_steps=20,
)
# Calling the sampler object populates nuts.q_hist, which is read below.
# NOTE(review): print_every presumably sets the progress-logging interval — confirm.
nuts(print_every=2500)
# Join the per-sample entries along axis 0 and persist them. Despite the .txt
# extension, the file holds the binary output of tf.io.serialize_tensor.
q_hist = tf.concat(nuts.q_hist, axis=0)
tf.io.write_file("../exps/demo3_q_hist.txt", tf.io.serialize_tensor(q_hist))
# Optional quick look (disabled): histogram of column 0 of q_hist, skipping the
# first 10000 rows.
# plt.hist(q_hist.numpy()[10000:, 0].flatten(), bins=30, color="blue")
# plt.show()