# Imports and definitions

In [1]:
import jax
import jax.numpy as jnp
from jax import grad, jit, random
from functools import partial
import optax
import matplotlib.pyplot as plt

from numerical_integrators import leap_frog_harmonic_oscillator
from jx_pot import sliced_wasserstein_distance_CDiag

import import_ipynb
import FINAL_Data_Generation as data_gen

In [2]:
# Weight multiplying the KL-divergence regularisation term in the objective.
LAMBDA_REG = 0.1

In [25]:
# Reference values: target means and log of TAU_MAP, built with jnp.array from
# a two-element list (displayed further down as a 2-row array).
ALPHA_TRUE = jnp.array([data_gen.MU_TARGET, jnp.log(data_gen.TAU_MAP)])
# Starting point for the optimiser: means followed by log of TAU_PHI_MAP,
# concatenated into one flat vector (the layout sample_params slices).
ALPHA_PHI = jnp.concatenate([data_gen.MU_PHI, jnp.log(data_gen.TAU_PHI_MAP)])

[ 0.6931472  1.7917595  0.        -1.3217558 -0.9162907 -1.0296193]


In [5]:
# Display the target observations that the objective compares simulations against.
data_gen.target_observations_from_sampled_params

Array([[1.9902518 , 1.9248587 , 1.9184332 , 1.9450139 , 1.9086748 ,
        1.8405888 , 1.8068295 , 1.8218379 , 1.7758108 , 1.8066919 ,
        1.7111119 , 1.654546  , 1.6842153 , 1.6615462 , 1.5382918 ,
        1.4691916 , 1.5370142 , 1.5917947 , 1.5123758 , 1.4880695 ,
        1.4265585 , 1.4334753 , 1.4557508 , 1.3556144 , 1.258726  ,
        1.4008838 , 1.2735006 , 1.2836605 , 1.1904333 , 1.1640927 ,
        1.24054   , 1.213093  , 1.1714345 , 1.1459193 , 1.1441957 ,
        1.1173106 , 1.0870682 , 1.1290278 , 1.1379956 , 1.0519263 ,
        0.9684204 , 0.9921023 , 1.0164814 , 1.0086647 , 0.89739203,
        0.91374016, 0.84245306, 0.89235   , 0.83788896, 0.8467267 ],
       [1.961783  , 2.069918  , 2.0100641 , 1.9775637 , 2.0223405 ,
        1.9965476 , 2.0366576 , 2.058195  , 1.9210925 , 1.8751935 ,
        1.9799551 , 1.9204594 , 1.8449118 , 1.877304  , 1.8960203 ,
        1.910673  , 1.8432211 , 1.8735038 , 1.8524454 , 1.6835542 ,
        1.7402086 , 1.7689518 , 1.6844764 , 1.6

# Objective function

In [6]:
def loss_fn(alpha, iteration=0):
    """
    Sliced-Wasserstein objective between simulated and target observations.

    One parameter vector is drawn via `sample_params(key, alpha)`, a noisy
    observation vector is simulated from it, tiled to
    `data_gen.number_systems` rows, and compared with
    `data_gen.target_observations_from_sampled_params`.

    Parameters
    ----------
    alpha : array
        Distribution parameters, forwarded unchanged to `sample_params`.
    iteration : int, optional
        Seed index for every random draw in this evaluation. The default of 0
        reproduces what the previous body always did: it ran
        `try: iteration += 1` / `except: iteration = 0`, but the assignment
        made `iteration` a local name, so the increment raised
        UnboundLocalError on every call and the bare `except` silently reset
        the seed to 0. Pass the optimiser step here to get fresh draws per
        step.

    Returns
    -------
    The value returned by `sliced_wasserstein_distance_CDiag`.

    Notes
    -----
    No regularisation is applied. The previous body contained
    `sw_distance + LAMBDA_REG * kl_multivariate_gaussians(params)` after an
    unconditional `return sw_distance`, so that term never executed; the
    unreachable code has been removed.
    TODO: decide whether the KL term should be enabled.
    """
    key = random.key(iteration)
    # Still split three ways (only the first subkey is used) so the parameter
    # draw uses exactly the same subkey as before.
    param_sample_key, _, _ = random.split(key, 3)

    params = sample_params(param_sample_key, alpha)

    # Simulate one noisy observation vector and repeat it once per system.
    y_estimate = data_gen.noisy_observations_from_parameters(
        iteration, params, data_gen.OBSERVATION_NOISE
    )
    y_estimates = jnp.tile(y_estimate, (data_gen.number_systems, 1))

    # Diagonal covariance: squared observation noise, one entry per
    # observation (previously hard-coded to 50 entries).
    C = (data_gen.OBSERVATION_NOISE ** 2) * jnp.ones(y_estimate.shape[-1])
    return sliced_wasserstein_distance_CDiag(
        random.key(iteration + 1),
        y_estimates,
        data_gen.target_observations_from_sampled_params,
        C,
    )

## Sample parameters

In [7]:
def sample_params(sample_key, alpha):
    '''
    Draw one sample z ~ p(z | alpha) from a Gaussian with diagonal covariance.

    The first three entries of alpha are used as the mean; the remaining
    entries are exponentiated and placed on the diagonal of the covariance.
    '''
    mean_vec = alpha[:3]
    diag_cov = jnp.diag(jnp.exp(alpha[3:]))
    return random.multivariate_normal(sample_key, mean=mean_vec, cov=diag_cov)

In [8]:
# Smoke test: draw one parameter vector from sample_params with a hand-picked alpha.
test_key = random.key(31)
# First three entries are the mean; the last three are exponentiated by
# sample_params to form the covariance diagonal.
test_alpa = jnp.array([5., 7., 10., 1., 2., 3.])
test_params = sample_params(test_key, test_alpa)

In [9]:
# Simulate one noisy observation vector from the sampled parameters.
# The first argument (3) sits where loss_fn passes its iteration counter —
# presumably a seed; confirm in data_gen.
y_test = data_gen.noisy_observations_from_parameters(3, test_params, data_gen.OBSERVATION_NOISE)
y_test.shape

(50,)

In [10]:
# Repeat the single observation vector as 4 identical rows, mirroring the
# tiling loss_fn performs before computing the distance.
y_tests = jnp.tile(y_test, (4, 1))
y_tests.shape

(4, 50)

In [11]:
# Shape of the target observations, for comparison with y_tests above.
data_gen.target_observations_from_sampled_params.shape

(4, 50)

In [12]:
# Manual check of the distance call used inside loss_fn.
# C holds the squared observation noise repeated 50 times (one per observation
# entry) and is passed as the last argument of the distance function.
C = (data_gen.OBSERVATION_NOISE ** 2) * jnp.ones(50)
sw_distance = sliced_wasserstein_distance_CDiag(random.key(21), y_tests, data_gen.target_observations_from_sampled_params, C)
sw_distance

Array(18.35927, dtype=float32)

# Optimise objective function

In [26]:
# Minimise loss_fn over alpha with Adam, starting from ALPHA_PHI.
solver = optax.adam(learning_rate=.1)
alpha = ALPHA_PHI
opt_state = solver.init(alpha)
losses = []

# Build the gradient function once, outside the loop. It is deliberately not
# called `grad`: the previous loop assigned to `grad` and thereby overwrote
# the `grad` function imported from jax at the top of the notebook.
loss_grad_fn = jax.grad(loss_fn)

for iteration in range(1001):
    grads = loss_grad_fn(alpha)
    updates, opt_state = solver.update(grads, opt_state, alpha)
    alpha = optax.apply_updates(alpha, updates)

    # Evaluate the post-update loss once and reuse it for both the history
    # and the progress message (previously it was recomputed for each).
    loss = loss_fn(alpha)
    losses.append(loss)

    if iteration % 100 == 0:
        print(f'Loss function at step {iteration}: {loss}, with real parameters {alpha}')

Loss function at step 0: 4.813715934753418, with real parameters [ 0.48999932  1.9999993   0.40000066 -1.6000007  -1.0999993  -0.6000006 ]
Loss function at step 100: 4.404922008514404, with real parameters [ 0.7397204   2.0601144   0.32151186 -1.4579549  -1.251004   -0.6198238 ]
Loss function at step 200: 4.404911994934082, with real parameters [ 0.7399579  2.0596259  0.3219291 -1.4578112 -1.2505832 -0.620281 ]
Loss function at step 300: 4.404911994934082, with real parameters [ 0.7399546   2.0596285   0.32192793 -1.4578143  -1.250585   -0.62027967]
Loss function at step 400: 4.404911994934082, with real parameters [ 0.73995453  2.0596285   0.32192776 -1.4578148  -1.2505846  -0.6202795 ]


In [27]:
# Reference values for comparison: first row MU_TARGET, second row log(TAU_MAP).
ALPHA_TRUE

Array([[ 0.6931472,  1.7917595,  0.       ],
       [-1.3217558, -0.9162907, -1.0296193]], dtype=float32)

In [28]:
# Optimised alpha after the training loop (flat vector).
alpha

Array([ 0.73995453,  2.0596285 ,  0.32192776, -1.4578148 , -1.2505846 ,
       -0.6202795 ], dtype=float32)

In [29]:
# Split the optimised alpha into its first three entries and the exponential
# of the remaining ones, then stack them as two rows for display.
# NOTE(review): ALPHA_TRUE stores log(TAU_MAP) in its second row, whereas the
# second row here is exponentiated — compare the two on the same scale.
mu_alpha, tau_alpha = alpha[:3], jnp.exp(alpha[3:])
final_params = jnp.array([mu_alpha, tau_alpha])
final_params

Array([[0.73995453, 2.0596285 , 0.32192776],
       [0.2327443 , 0.28633735, 0.5377941 ]], dtype=float32)