In [1]:
%pip install git+https://github.com/mariogeiger/hessian.git

Defaulting to user installation because normal site-packages is not writeable
Collecting git+https://github.com/mariogeiger/hessian.git
  Cloning https://github.com/mariogeiger/hessian.git to /tmp/pip-req-build-qqyygz3d
  Running command git clone -q https://github.com/mariogeiger/hessian.git /tmp/pip-req-build-qqyygz3d


In [2]:
%matplotlib inline

In [3]:
from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from moviepy.editor import ImageSequenceClip
import sys
import pendulum

In [4]:
# Initial state handed to the solver: four float32 components, the first two
# non-zero (3*pi/7, 3*pi/4) and the last two zero.
# NOTE(review): presumably [theta1, theta2, omega1, omega2] of a double pendulum —
# confirm against pendulum.solve_analytical.
x0 = np.array([3*np.pi/7, 3*np.pi/4, 0, 0], dtype=np.float32)
# Dense time grid: 1000 points on [0, 100].
t_demo = np.linspace(0, 100, num=1000, dtype=np.float32)
# Reference trajectory on the dense grid, solved with g=9.8.
demo_trajectory = pendulum.solve_analytical(x0, t_demo, g=9.8)




# Prepare a noisy observed trajectory

In [5]:
t = np.linspace(0, 10, num=5, dtype=np.float32)


def sample_trajectories_by_gs(key, gs, noise_std=0.5):
    """Simulate one noisy trajectory per value of g.

    Args:
        key: JAX PRNG key used to draw the Gaussian observation noise.
        gs: iterable of g values; one trajectory is solved for each entry,
            starting from the module-level ``x0`` on the module-level grid ``t``.
        noise_std: standard deviation of the additive Gaussian noise.

    Returns:
        A JAX array stacking the noisy trajectories along the first axis
        (one entry per element of ``gs``).
    """
    # A plain Python loop is used on purpose: odeint is bugged with vmap,
    # probably because of the control flow it uses to maintain tolerance.
    clean = jnp.array([pendulum.solve_analytical(x0, t, g=g_value) for g_value in gs])
    noise = jax.random.normal(key, clean.shape, clean.dtype)
    return clean + noise * noise_std

In [6]:
# Root JAX PRNG key; it is split before every use so that a subkey is never
# consumed twice.
rng = jax.random.PRNGKey(41231)
rng, subkey = jax.random.split(rng)
# The "observation": a single noisy trajectory generated with g=9.8
# (default noise_std of sample_trajectories_by_gs).
observed_trajectory = sample_trajectories_by_gs(subkey, [9.8])

In [7]:
# Convert the JAX array to a NumPy array; later cells pass it to torch.tensor()
# and subtract it from torch tensors.
observed_trajectory = np.array(observed_trajectory)

# What if we didn't know the likelihood?

# L-GSO

Paper: "Black-Box Optimization with Local Generative Surrogates" — https://arxiv.org/abs/2002.04632

In [8]:
import torch

In [9]:
class Simulator:
    """Stateful black-box simulator around ``sample_trajectories_by_gs``.

    Holds a JAX PRNG key and advances it on every call, so successive
    simulations use fresh noise.
    """

    def __init__(self, rng_key):
        """Store the JAX PRNG key that seeds all subsequent simulations."""
        self.rng_key = rng_key

    def simulate(self, gs):
        """Simulate one noisy trajectory per entry of ``gs``.

        Args:
            gs: array-like exposing ``.ravel()`` and ``.shape``; its flattened
                values are used as the g parameters.

        Returns:
            A torch tensor with ``gs.shape[0]`` rows, each row being one
            flattened trajectory.
        """
        # Advance the stored key and keep a separate subkey for this call's noise.
        next_key, noise_key = jax.random.split(self.rng_key)
        self.rng_key = next_key

        flat_gs = jax.numpy.array(gs.ravel(), dtype=jax.numpy.float32)
        noisy = sample_trajectories_by_gs(noise_key, flat_gs)

        n_rows = gs.shape[0]
        return torch.tensor(np.array(noisy).reshape((n_rows, -1)))

In [10]:
# Give the simulator its own subkey so its noise stream is independent of the
# key used to generate the observed trajectory.
rng, subkey = jax.random.split(rng)
simulator = Simulator(subkey)

In [11]:
from gan_model import GANModel
from experience_replay import ExperienceReplay
from optimizer import TorchOptimizer

In [12]:
class DoublePendulumLossModel:
    """
    Loss function between GAN surrogate output and the real (observed) output.

    The ground-truth data is flattened once into a single row so that it
    broadcasts against a batch of flattened surrogate outputs.
    """
    def __init__(self, ground_truth_data):
        """
        Args:
            ground_truth_data: torch tensor of any shape; stored as shape (1, N).
        """
        # reshape() instead of view(): identical for contiguous tensors, but it
        # also accepts non-contiguous input (e.g. a transposed tensor), where
        # view() raises a RuntimeError.
        self.ground_truth_data = ground_truth_data.reshape(1, -1)

    def loss(self, y, conditions):
        """
        Mean squared error between ``y`` and the stored ground truth.

        Args:
            y: tensor broadcastable against the (1, N) ground truth,
                e.g. shape (batch, N).
            conditions: unused; kept in the signature so callers can pass it.

        Returns:
            Scalar tensor: the mean over all elements of the squared difference.
        """
        return ((self.ground_truth_data - y)**2).mean()
    
# Loss model anchored on the observed trajectory (converted to a float32 torch tensor).
y_model = DoublePendulumLossModel(torch.tensor(observed_trajectory).float())

In [13]:
# Smoke test: one simulation at g=5 returns a single row of flattened trajectory values.
simulator.simulate(np.array([5.]))

tensor([[ 0.9438,  1.9554, -0.5243,  0.1454, -1.8628, -1.4289,  0.7781,  1.7277,
          1.0496,  4.2183, -1.6455,  4.2909,  0.2393, 14.4571,  1.1370, -1.4946,
         -1.3830,  9.8013,  0.7803, -0.7391]])

In [14]:
# Sanity check: MSE between a fresh noisy simulation at g=5 and the observation.
(simulator.simulate(np.array([5.])) - observed_trajectory.reshape(-1)).pow(2).mean()

tensor(19.8121)

In [15]:
# Same check at g=3.
(simulator.simulate(np.array([3.])) - observed_trajectory.reshape(-1)).pow(2).mean()

tensor(15.8944)

In [16]:
# Same check at g=9.8, the value used to generate the observation; the remaining
# error comes from the independent noise draws.
(simulator.simulate(np.array([9.8])) - observed_trajectory.reshape(-1)).pow(2).mean()

tensor(0.3326)

In [17]:
# Same check at g=20.8.
(simulator.simulate(np.array([20.8])) - observed_trajectory.reshape(-1)).pow(2).mean()

tensor(30.7193)

In [18]:
# Device string shared by the model config, the experience replay and GANModel.to().
device = "cpu"

# Keyword arguments forwarded verbatim to GANModel(**model_config, ...).
# NOTE(review): the meaning of most keys is defined in gan_model.py, which is not
# visible here — confirm there before tuning.
model_config = {
    'task': "CRAMER",  # alternatives listed by the original author: "REVERSED_KL", "WASSERSTEIN"
    'y_dim': 20,  # matches the 20 values per simulated trajectory (5 time points in `t`)
    'x_dim': 0,
    'psi_dim': 1,  # a single optimized parameter (g)
    'noise_dim': 150,
    'lr': 1e-4 * 8,  # i.e. 8e-4
    'batch_size': 512,
    'epochs': 15,
    'iters_discriminator': 1,
    'iters_generator': 1,
    'instance_noise_std': 0.01,
    'burn_in_period': None,
    'averaging_coeff': 0.,
    'dis_output_dim': 256,
    'grad_penalty': True,
    'attention_net_size': None,
    'gp_reg_coeff': 10,
    'device': device
    # 'predict_risk': False
}

In [19]:
# Buffer of previously simulated (psi, y) pairs; dimensions are taken from
# model_config so they stay in sync with the GAN.
exp_replay = ExperienceReplay(
    psi_dim=model_config['psi_dim'],
    y_dim=model_config['y_dim'],
    x_dim=model_config['x_dim'],
    device=device
)

In [20]:
def generate_local_data(simulator, current_psi, step=0.1, n_samples=5, rng=None):
    """
    Sample new parameter points uniformly around the current parameters and
    run the simulator on them.

    Args:
        simulator: object exposing ``simulate(psis_flat)``; it receives the
            flattened sampled parameters and returns the simulated outputs.
        current_psi: 1-D torch tensor (or array-like convertible with
            ``torch.as_tensor``) holding the current parameters.
        step: half-width of the sampling window; every coordinate is drawn from
            ``current_psi +/- step``.
        n_samples: number of points to draw.
        rng: optional NumPy random source exposing ``uniform`` (for example
            ``np.random.RandomState`` or ``np.random.default_rng()``). When
            ``None`` the global ``np.random`` state is used, as before.

    Returns:
        ``(data, psis)`` where ``data`` is whatever ``simulator.simulate``
        returned and ``psis`` is a float32 tensor of shape
        ``(n_samples, len(current_psi))`` with the sampled parameters.
    """
    # Work on an explicit tensor so the arithmetic below does not depend on
    # implicit Tensor/ndarray mixing and so array-like inputs are accepted.
    center = torch.as_tensor(current_psi)
    sampler = np.random if rng is None else rng
    offsets = sampler.uniform(low=-1., high=1., size=(n_samples, center.shape[0]))
    # offsets are float64, so the sum is promoted to float64 before simulation.
    psis = center.unsqueeze(0) + step * torch.from_numpy(offsets)
    data = simulator.simulate(psis.reshape(-1))
    return data, psis.float()

In [21]:
# half-width of the region around the current psi used to train the surrogate:
# passed as `step` both to generate_local_data and to exp_replay.extract
step_data_gen = 1.

# number of new simulator points drawn per call to generate_local_data
n_samples = 10

# initial psi: the starting guess for the optimized parameter
current_psi=torch.tensor([14.])

In [22]:
# Settings for the optimizer that is created once and re-used at every epoch;
# forwarded as TorchOptimizer(..., **optimizer_config).
# NOTE(review): key semantics live in optimizer.py (not shown). The first
# optimize() call below moves psi from 14.0 to 13.9, i.e. by `lr`.
optimizer_config = {
    'lr': 0.1,
    'num_repetitions': 10000,
    'max_iters': 1,
    'torch_model': 'Adam',
}

In [23]:
# Single-step walkthrough (1/4): sample new simulator data in a window of
# half-width step_data_gen around the current psi.
output, condition = generate_local_data(simulator, current_psi, step=step_data_gen, n_samples=n_samples)

In [24]:
# Single-step walkthrough (2/4): look up in experience replay.
# extract() runs before add(), so the freshly sampled points are not returned
# a second time by the lookup.
output_exp_replay, condition_exp_replay = exp_replay.extract(psi=current_psi, step=step_data_gen)
exp_replay.add(y=output, condition=condition)
# Training set = new samples + replayed samples near the current psi.
output = torch.cat([output, output_exp_replay], dim=0)
condition = torch.cat([condition, condition_exp_replay], dim=0)

In [25]:
# Single-step walkthrough (3/4): the surrogate is not reused between epochs,
# so a fresh GAN is initialized and fitted on the local data each time.
model = GANModel(**model_config, y_model=y_model).to(device)
model.fit(output, condition=condition)
# Switch to evaluation mode before the surrogate is used as the optimizer's oracle.
model.eval()

GANModel(
  (_generator): Generator(
    (fc1): Linear(in_features=151, out_features=100, bias=True)
    (fc2): Linear(in_features=100, out_features=100, bias=True)
    (fc3): Linear(in_features=100, out_features=20, bias=True)
    (fc4): Linear(in_features=100, out_features=100, bias=True)
  )
  (_discriminator): Discriminator(
    (fc1): Linear(in_features=21, out_features=100, bias=True)
    (fc2): Linear(in_features=100, out_features=100, bias=True)
    (fc3): Linear(in_features=100, out_features=100, bias=True)
    (fc4): Linear(in_features=100, out_features=256, bias=True)
  )
)

In [26]:
# Single-step walkthrough (4/4): optimize psi against the fitted surrogate.
previous_psi = current_psi.clone()
# This optimizer instance is created once here and re-used by the loop further down.
optimizer = TorchOptimizer(oracle=model, x=current_psi, **optimizer_config)

# NOTE(review): update() receives the same oracle and x that were just passed to
# the constructor — presumably redundant here; it mirrors the per-epoch call in
# the loop further down.
optimizer.update(oracle=model, x=current_psi)
current_psi, status, history = optimizer.optimize()

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 0.1
    weight_decay: 0
)


In [27]:
# Compare psi before and after the single optimization step.
previous_psi, current_psi

(tensor([14.]), tensor([13.9000]))

## Stack everything together

In [28]:
from tqdm.notebook import tqdm

# Full optimization loop: 200 iterations of
#   sample locally -> merge with experience replay -> refit GAN -> one optimizer update.
# Relies on `optimizer`, `exp_replay`, `simulator` and `current_psi` created in
# earlier cells; `epochs` is the iteration index.
for epochs in tqdm(range(200)):    
    # sample new data
    output, condition = generate_local_data(simulator, current_psi, step=step_data_gen, n_samples=n_samples)

    # look up in experience replay
    # (extract() runs before add(), so the new samples are not returned again by the lookup)
    output_exp_replay, condition_exp_replay = exp_replay.extract(psi=current_psi, step=step_data_gen)
    exp_replay.add(y=output, condition=condition)
    output = torch.cat([output, output_exp_replay], dim=0)
    condition = torch.cat([condition, condition_exp_replay], dim=0)

    # if not reusing model
    # then at each epoch re-initialize and re-fit GAN
    model = GANModel(**model_config, y_model=y_model).to(device)
    model.fit(output, condition=condition)
    model.eval()

    # optimize psi with surrogate
    # (the existing optimizer object is re-pointed at the new surrogate and current psi)
    previous_psi = current_psi.clone()
    optimizer.update(oracle=model, x=current_psi)
    current_psi, status, history = optimizer.optimize()
    # Report progress every 10th iteration.
    # NOTE(review): history['func'][0] is printed as the "Current MSE"; which
    # evaluation it corresponds to is defined in optimizer.py — confirm.
    if epochs % 10 == 0:
        print(
            f"Current optima: {current_psi.item()}, Current MSE: {history['func'][0]}, Dataset size: {condition.shape[0]}"
        )
        print()

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=200.0), HTML(value='')))

Current optima: 13.845582962036133, Current MSE: 17.962360382080078, Dataset size: 20

Current optima: 12.949688911437988, Current MSE: 13.462657928466797, Dataset size: 90

Current optima: 12.247157096862793, Current MSE: 11.016587257385254, Dataset size: 139

Current optima: 11.650197982788086, Current MSE: 9.92439079284668, Dataset size: 162

Current optima: 11.067352294921875, Current MSE: 9.12927532196045, Dataset size: 161

Current optima: 10.327651977539062, Current MSE: 6.482333183288574, Dataset size: 144

Current optima: 9.252219200134277, Current MSE: 0.6492361426353455, Dataset size: 100

Current optima: 8.979599952697754, Current MSE: 2.280441999435425, Dataset size: 152

Current optima: 9.635954856872559, Current MSE: 0.935620129108429, Dataset size: 284

Current optima: 9.891376495361328, Current MSE: 1.3760687112808228, Dataset size: 388

Current optima: 9.575937271118164, Current MSE: 1.0935181379318237, Dataset size: 476

Current optima: 9.629622459411621, Current MSE

In [29]:
# Final estimate of psi after the loop (the observation was generated with g=9.8).
current_psi

tensor([9.7366])