In [None]:
# %load init.py
import os
import pickle
import sys
# Enable module import from the parent directory from notebooks
sys.path.append(os.path.abspath('..'))
import time

import matplotlib as mpl
# Select plotting backend
mpl.use('nbAgg')

import matplotlib.pyplot as plt
# Customize plotting.
# matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*' and
# matplotlib 3.8 removed the old names, so only fall back to 'seaborn-paper'
# on versions that do not ship the new name.
_PAPER_STYLE = 'seaborn-v0_8-paper'
plt.style.use(_PAPER_STYLE if _PAPER_STYLE in plt.style.available else 'seaborn-paper')
plt.rcParams.update({
    'axes.labelsize': 11.0,
    'axes.titlesize': 12.0,
    'errorbar.capsize': 3.0,
    'figure.dpi': 72.0,
    'figure.titlesize': 12.0,
    'legend.fontsize': 10.,
    'lines.linewidth': 1.,
    'xtick.labelsize': 11.0,
    'ytick.labelsize': 11.0,
})

import numpy as np
import sympy as sp
# Render sympy expressions with LaTeX in the notebook
sp.init_printing(euler=True, use_latex=True)

from IPython import display
from scipy import io, optimize
from sklearn import metrics

# Project-local modules, importable thanks to the sys.path entry above
import core
import dynamicals
import kernels
import numericals
import utils

In [None]:
# Example setup to run the inference on a 10-state Lorenz-96 system
dynamical = dynamicals.Lorenz96(10)

# (start, end, frequency) of the sample-path, observation and estimation time grids
spl_t_0, spl_t_T, spl_freq = 0, 4, 80
obs_t_0, obs_t_T, obs_freq = 0, 4, 8
est_t_0, est_t_T, est_freq = 0, 4, 8
spl_tps, obs_tps, obs_t_indices, est_tps, est_t_indices = utils.create_time(
    spl_t_0, spl_t_T, spl_freq, obs_t_0, obs_t_T, obs_freq, est_t_0, est_t_T, est_freq)
# Random initial state, uniform in [0, 8)
X_0 = np.random.random(dynamical.num_x) * 8.
theta = np.array([8.])
rho_2 = np.full(dynamical.num_x, 4.)
# One (kernel name, kernel parameters) pair per state. Each entry gets its own
# array: `[...] * n` would make every entry share a single ndarray, so an
# in-place change to one kernel's parameters would silently change them all.
phi = [('rbf', np.array([4.2, 0.1])) for _ in range(dynamical.num_x)]
sigma_2 = np.full(dynamical.num_x, 1.)
# Boolean mask over the states: a random 35% of them are set to False
delta = np.full(dynamical.num_x, True)
delta[np.random.permutation(dynamical.num_x)[:int(0.35 * dynamical.num_x)]] = False
# gamma is 1e-1 where delta is True and 5e-2 elsewhere
gamma = np.where(delta, 1e-1, 5e-2)

# Optimizer settings
opt_method = 'Newton-CG'
opt_tol = 1e-6
max_init_iter = 10
max_iter = 1000

# Live plotting during the inference
plotting_enabled = True
plotting_freq = 50

spl_X = dynamical.generate_sample_path(theta, rho_2, X_0, spl_tps)
obs_Y = utils.collect_observations(spl_X, obs_t_indices, sigma_2)

data = core.laplace_mean_field(dynamical,
                               spl_X, spl_tps,
                               obs_Y, obs_tps, obs_t_indices,
                               est_tps, est_t_indices,
                               theta, rho_2, phi, sigma_2, delta, gamma,
                               opt_method, opt_tol, max_init_iter, max_iter,
                               plotting_enabled, plotting_freq)

In [None]:
# Archived results, one sub-directory per number of states:
# sde-x-{} holds the LPMF-SDE runs, ode-x-{} the LPMF runs.
sde_directory, ode_directory = (
    '../data/tars/sde-lorenz-96-scalability/sde-x-{}/',
    '../data/tars/sde-lorenz-96-scalability/ode-x-{}/',
)
config_filename, data_filename = utils.CONFIG_FILENAME, utils.DATA_FILENAME

# System sizes double from 25 up to 800 states
num_states = [25 * 2 ** k for k in range(6)]
# Independent runs stored per system size
num_repetitions = 10

In [None]:
def load_runtime_stats(directory_template, method_label):
    """Load every repetition of one method and summarise the runtimes.

    For each entry of ``num_states``, this reads the result files
    ``data_filename.format(i)`` for ``i = 1 .. num_repetitions`` from
    ``directory_template.format(num_state)`` using ``utils.load_data``.

    Parameters
    ----------
    directory_template : str
        Directory path with one ``{}`` placeholder for the number of states.
    method_label : str
        Human-readable method name. It is used only in error messages.

    Returns
    -------
    (list, list)
        The per-``num_states`` results of ``utils.get_runtime_mean`` and
        ``utils.get_runtime_var``, in ``num_states`` order.

    Raises
    ------
    RuntimeError
        If any repetition has a non-positive entry in ``eta_theta``.
    """
    runtime_means, runtime_vars = [], []
    for num_state in num_states:
        directory = directory_template.format(num_state)
        runs = []
        for run_index in range(1, num_repetitions + 1):
            run = utils.load_data(directory, data_filename.format(run_index))
            # np.all instead of np.alltrue, which was removed in NumPy 2.0
            if not np.all(run['eta_theta'] > 0):
                raise RuntimeError(
                    'Non-positive theta value encountered for {} run {} ({} states)'.format(
                        method_label, run_index, num_state))
            runs.append(run)
        runtime_means.append(utils.get_runtime_mean(runs))
        runtime_vars.append(utils.get_runtime_var(runs))
    return runtime_means, runtime_vars


sde_runtime_mean, sde_runtime_var = load_runtime_stats(sde_directory, 'LPMF-SDE')
ode_runtime_mean, ode_runtime_var = load_runtime_stats(ode_directory, 'LPMF')

# Runtime vs. number of states; error bars are sqrt of the runtime variance
figure, ax = plt.subplots(figsize=plt.figaspect(0.4))
errorbar_style = dict(linewidth=1.5, ecolor='0', elinewidth=1., capsize=3., capthick=.5)
ax.errorbar(num_states, sde_runtime_mean, yerr=np.sqrt(sde_runtime_var),
            color='C0', linestyle='-', label='LPMF-SDE', **errorbar_style)
ax.errorbar(num_states, ode_runtime_mean, yerr=np.sqrt(ode_runtime_var),
            color='C1', linestyle='-.', label='LPMF', **errorbar_style)
ax.set_ylabel('Runtime (s)', fontsize=14)
ax.set_xlabel('Number of States', fontsize=14)
# Small right margin past the largest system size (810 for the default sizes)
ax.set_xlim([0, max(num_states) + 10])
ax.set_ylim([0, None])
ax.legend(loc=0)
figure.tight_layout()
# Save before plt.show() so the file is written independently of how the
# active backend handles showing/closing figures.
figure.savefig('lorenz-96-scalability.eps', format='eps', dpi=1000, bbox_inches='tight')
plt.show()

In [None]:
# Helper to generate config files for every entry of `num_states`.
# The configs are written to a location outside ../data/tars/, so the archived
# results loaded above are not touched. Distinct variable names keep the
# `sde_directory` / `ode_directory` globals used by the loading cell intact.
config_sde_directory = '../data/sde-lorenz-96-scalability/sde-x-{}/'
config_ode_directory = '../data/sde-lorenz-96-scalability/ode-x-{}/'

# Settings shared by every system size
spl_t_0, spl_t_T, spl_freq = 0, 4, 80
obs_t_0, obs_t_T, obs_freq = 0, 4, 8
est_t_0, est_t_T, est_freq = 0, 4, 8
spl_tps, obs_tps, obs_t_indices, est_tps, est_t_indices = utils.create_time(
    spl_t_0, spl_t_T, spl_freq, obs_t_0, obs_t_T, obs_freq, est_t_0, est_t_T, est_freq)
theta = np.array([8.])

opt_method = 'Newton-CG'
opt_tol = 1e-6
max_init_iter = 10
max_iter = 1000

plotting_enabled = True
plotting_freq = 50

for num_state in num_states:
    dynamical = dynamicals.Lorenz96(num_state)
    num_x = dynamical.num_x

    # Random initial state, uniform in [0, 8)
    X_0 = np.random.random(num_x) * 8.
    rho_2 = np.full(num_x, 4.)
    # One (kernel name, kernel parameters) pair per state, each with its own
    # array so that no two states share (and co-mutate) kernel parameters
    phi = [('rbf', np.array([4.2, 0.1])) for _ in range(num_x)]
    sigma_2 = np.full(num_x, 1.)
    # Boolean mask over the states: a random 35% of them are set to False
    delta = np.full(num_x, True)
    delta[np.random.permutation(num_x)[:int(0.35 * num_x)]] = False
    # gamma is 1e-1 where delta is True and 5e-2 elsewhere
    gamma = np.where(delta, 1e-1, 5e-2)

    spl_X = dynamical.generate_sample_path(theta, rho_2, X_0, spl_tps)
    obs_Y = utils.collect_observations(spl_X, obs_t_indices, sigma_2)

    utils.save_sde_config(config_sde_directory.format(num_state), config_filename,
                          spl_t_0, spl_t_T, spl_freq, spl_tps,
                          obs_t_0, obs_t_T, obs_freq, obs_tps, obs_t_indices,
                          est_t_0, est_t_T, est_freq, est_tps, est_t_indices,
                          X_0, theta, rho_2, phi, sigma_2, delta, gamma,
                          opt_method, opt_tol, max_init_iter, max_iter,
                          plotting_enabled, plotting_freq, spl_X, obs_Y)

    # The ODE config reuses the same sample path but is saved with None in the
    # rho_2 slot and without obs_Y (presumably: no diffusion term — confirm
    # against utils.save_ode_config)
    utils.save_ode_config(config_ode_directory.format(num_state), config_filename,
                          spl_t_0, spl_t_T, spl_freq, spl_tps,
                          obs_t_0, obs_t_T, obs_freq, obs_tps, obs_t_indices,
                          est_t_0, est_t_T, est_freq, est_tps, est_t_indices,
                          X_0, theta, None, phi, sigma_2, delta, gamma,
                          opt_method, opt_tol, max_init_iter, max_iter,
                          plotting_enabled, plotting_freq, spl_X)