In [None]:
# %load init.py
# Shared notebook initialisation: imports, plotting backend/style and printing setup.
import os
import pickle
import sys
# Enable module import from the parent directory from notebooks
sys.path.append(os.path.abspath('..'))
import time

import matplotlib as mpl
# Select plotting backend
mpl.use('nbAgg')

import matplotlib.pyplot as plt
# Customize plotting.
# Matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*' and newer
# releases no longer ship the old names, so use whichever name is installed.
plt.style.use('seaborn-paper' if 'seaborn-paper' in plt.style.available
              else 'seaborn-v0_8-paper')
plt.rcParams['axes.labelsize'] = 11.0
plt.rcParams['axes.titlesize'] = 12.0
plt.rcParams['errorbar.capsize'] = 3.0
plt.rcParams['figure.dpi'] = 72.0
plt.rcParams['figure.titlesize'] = 12.0
plt.rcParams['legend.fontsize'] = 10.
plt.rcParams['lines.linewidth'] = 1.
plt.rcParams['xtick.labelsize'] = 11.0
plt.rcParams['ytick.labelsize'] = 11.0

import numpy as np
import sympy as sp
# Render sympy expressions with LaTeX in the notebook
sp.init_printing(euler=True, use_latex=True)

from IPython import display
from scipy import io, optimize
from sklearn import metrics

# Project-local modules (found through the sys.path entry added above)
import core
import dynamicals
import kernels
import numericals
import utils

In [None]:
# Example setup to run the inference
dynamical = dynamicals.Lorenz96(10)
num_states = dynamical.num_x

# (start, end, frequency) of the sample path, the observations and the estimates
spl_t_0, spl_t_T, spl_freq = 0, 4, 80
obs_t_0, obs_t_T, obs_freq = 0, 4, 8
est_t_0, est_t_T, est_freq = 0, 4, 8
time_grids = utils.create_time(spl_t_0, spl_t_T, spl_freq,
                               obs_t_0, obs_t_T, obs_freq,
                               est_t_0, est_t_T, est_freq)
spl_tps, obs_tps, obs_t_indices, est_tps, est_t_indices = time_grids

# Random initial state, scaled into [0, 8)
X_0 = 8. * np.random.random(num_states)
# Model parameter (the plotting helper in this notebook labels it F)
theta = np.array([8.])
# Passed to generate_sample_path; presumably the per-state diffusion variance -- TODO confirm
rho_2 = np.full(num_states, 4.)
# One (kernel name, kernel parameters) entry per state; all entries share one tuple
phi = [('rbf', np.array([4.2, 0.1]))] * num_states
# Passed to collect_observations; presumably the observation noise variance -- TODO confirm
sigma_2 = np.ones(num_states)

# delta marks the observed states: a random 35% of the states is left unobserved
delta = np.ones(num_states, dtype=bool)
num_unobserved = int(0.35 * num_states)
delta[np.random.permutation(num_states)[:num_unobserved]] = False
# gamma is 1e-1 for observed states and 5e-2 for unobserved ones
gamma = np.where(delta, 1e-1, 5e-2)

# Optimizer settings
opt_method = 'Newton-CG'
opt_tol = 1e-6
max_init_iter = 10
max_iter = 1000

# Plotting settings forwarded to the inference routine
plotting_enabled = True
plotting_freq = 50

# Simulate a sample path and collect observations from it
spl_X = dynamical.generate_sample_path(theta, rho_2, X_0, spl_tps)
obs_Y = utils.collect_observations(spl_X, obs_t_indices, sigma_2)

data = core.laplace_mean_field(
    dynamical,
    spl_X, spl_tps,
    obs_Y, obs_tps, obs_t_indices,
    est_tps, est_t_indices,
    theta, rho_2, phi, sigma_2, delta, gamma,
    opt_method, opt_tol, max_init_iter, max_iter,
    plotting_enabled, plotting_freq)

In [None]:
# Shared configuration for the stored Lorenz 96 (500 states) experiments
dynamical = dynamicals.Lorenz96(500)

# Per-repetition result directory; '{}' is filled with the repetition number
directory = '../data/tars/sde-lorenz-96-vgpamf/{}/'
config_filename = utils.CONFIG_FILENAME
data_filename = utils.DATA_FILENAME

# Location of the VGPA-MF reference results. The original machine-specific path is
# kept as the default; set VGPAMF_RESULTS_DIR to run the notebook elsewhere.
vgpamf_directory = os.environ.get(
    'VGPAMF_RESULTS_DIR',
    '/Users/ruifengxu/Development/ruiixu23/VGPA_MF/Results/')
vgpamf_filename = '{}-VGPA.mat'
vgpamf_config_filename = '{}-config.mat'

# Number of experiment repetitions and of result files loaded per repetition
num_repetitions = 10
num_rodes = 100

In [None]:
# Helper to plot Lorenz 96 ODE trajectories
dynamical = dynamicals.Lorenz96(10)

spl_t_0, spl_t_T, spl_freq = 0, 15, 100
obs_t_0, obs_t_T, obs_freq = 0, 15, 10
est_t_0, est_t_T, est_freq = 0, 15, 10
time_grids = utils.create_time(spl_t_0, spl_t_T, spl_freq,
                               obs_t_0, obs_t_T, obs_freq,
                               est_t_0, est_t_T, est_freq)
spl_tps, obs_tps, obs_t_indices, est_tps, est_t_indices = time_grids
X_0 = np.random.random(dynamical.num_x)
# One parameter value per panel, with the matching panel title
thetas = [np.array([value]) for value in (0.5, 2., 8.)]
rho_2 = None
titles = ['F = 0.5', 'F = 2', 'F = 8']
# Legend labels of the first three states, which are the ones drawn
state_labels = ('State $1$', 'State $2$', 'State $3$')

figure = plt.figure(figsize=plt.figaspect(3 / 9))
for i, (theta, title) in enumerate(zip(thetas, titles)):
    spl_X = dynamical.generate_sample_path(theta, rho_2, X_0, spl_tps)
    ax = figure.add_subplot(1, 3, i + 1)
    for state, state_label in enumerate(state_labels):
        ax.plot(spl_tps, spl_X[state], label=state_label, linewidth=1.5)
    ax.set_xlabel('Time', fontsize=13.)
    # Only the left-most panel carries the y-axis label
    if i == 0:
        ax.set_ylabel('State', fontsize=13.)
    ax.set_title(title, fontsize=15.)
    handles, labels = ax.get_legend_handles_labels()
    ax.legend(handles=handles, labels=labels, loc=0)
    ax.set_xlim([0, 15])

plt.tight_layout()
plt.show()
figure.savefig('lorenz-96-trajectories.eps', format='eps', dpi=1000, bbox_inches='tight')

In [None]:
# Helper to print which states are observed
repetition = 6

# Restore the stored setup of the selected repetition
loaded_config = utils.load_sde_config(directory.format(repetition), config_filename)
(spl_t_0, spl_t_T, spl_tps, spl_freq,
 obs_t_0, obs_t_T, obs_freq, obs_tps, obs_t_indices,
 est_t_0, est_t_T, est_freq, est_tps, est_t_indices,
 X_0, theta, rho_2, phi, sigma_2, delta, gamma,
 opt_method, opt_tol, max_init_iter, max_iter, plotting_enabled, plotting_freq,
 spl_X, obs_Y) = loaded_config

# Print the observed state indices, starting a new line whenever the tens group changes
current_group = 0
for state_index in np.where(delta == True)[0]:
    group = int(state_index / 10)
    if state_index != 0 and group != current_group:
        current_group = group
        print()
    print(state_index, end=' ')

In [None]:
# Plotting of the example state and parameter estimation
repetition = 6
# States shown in the state-estimation figure (0-based indices, one row each)
example_states = [95, 357, 119, 212]

# Load data
(spl_t_0, spl_t_T, spl_tps, spl_freq,
 obs_t_0, obs_t_T, obs_freq, obs_tps, obs_t_indices,
 est_t_0, est_t_T, est_freq, est_tps, est_t_indices,
 X_0, theta, rho_2, phi, sigma_2, delta, gamma,
 opt_method, opt_tol, max_init_iter, max_iter, plotting_enabled, plotting_freq,
 spl_X, obs_Y) = utils.load_sde_config(directory.format(repetition), config_filename)

data = []
for i in range(1, num_rodes + 1):
    tmp = utils.load_data(directory.format(repetition), data_filename.format(i))
    # np.all replaces np.alltrue, which was removed in NumPy 2.0
    if not np.all(tmp['eta_theta'] > 0):
        raise RuntimeError('Negative theta value encountered for rode {}'.format(i))
    data.append(tmp)

X_mean = utils.get_X_mean(data)
X_var = utils.get_X_var(data)
theta_mean = utils.get_theta_mean(data)
theta_var = utils.get_theta_var(data)

# VGPA-MF reference result, restricted to the estimation time points
with open(os.path.join(vgpamf_directory, vgpamf_filename.format(repetition)), 'rb') as infile:
    vgpamf_data = io.loadmat(infile)
vgpamf_X_mean = vgpamf_data['mt'][:, est_t_indices]
vgpamf_X_var = vgpamf_data['st'][:, est_t_indices]


def _plot_state_panel(ax, state, est_mean, est_var, title=None, ylabel=None):
    """Draw sample path, observations (if any) and estimate of one state on ``ax``.

    Reads the sample path, observations and time grids loaded above from the
    notebook namespace.

    Parameters
    ----------
    ax : matplotlib axes to draw on.
    state : 0-based index of the state to plot.
    est_mean, est_var : estimated mean and variance of that state at ``est_tps``;
        the error bars show one standard deviation.
    title, ylabel : optional axes title and y-axis label; omitted when ``None``.
    """
    ax.plot(spl_tps, spl_X[state], color='C0', linestyle='-', linewidth=1.5,
            label='Sample path')
    # Observations are only drawn for observed states
    if delta[state]:
        ax.scatter(obs_tps, obs_Y[state], color='C1', marker='x', label='Observation')
    ax.errorbar(est_tps, est_mean, color='C2', linestyle='--', linewidth=1.5,
                label='Estimation', yerr=np.sqrt(est_var), ecolor='0',
                elinewidth=1., capsize=3., capthick=.5)
    ax.set_xlabel('Time')
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    ax.set_xlim([spl_t_0, spl_t_T])
    ax.set_ylim([-12, 16])
    ax.legend(loc=0)
    if title is not None:
        ax.set_title(title)


# Plotting state estimation result: LPMF-SDE on the left, VGPA-MF on the right
figure = plt.figure(figsize=[10, 12])
num_rows = len(example_states)
for idx, i in enumerate(example_states):
    first_row = idx == 0
    # Plotting LPMF-SDE result
    _plot_state_panel(figure.add_subplot(num_rows, 2, idx * 2 + 1), i,
                      X_mean[i], X_var[i],
                      title='LPMF-SDE' if first_row else None,
                      ylabel='State {}'.format(i + 1))
    # Plotting VGPA-MF result (shares the y-axis label of the left panel)
    _plot_state_panel(figure.add_subplot(num_rows, 2, idx * 2 + 2), i,
                      vgpamf_X_mean[i], vgpamf_X_var[i],
                      title='VGPA-MF' if first_row else None)
figure.tight_layout()
plt.show()
figure.savefig('lorenz-96-states.eps', format='eps', dpi=1000, bbox_inches='tight')

# Plotting parameter estimation result
figure = plt.figure(figsize=plt.figaspect(1))
bar_width = 0.15
bar_indices = np.arange(theta.size)
ax = plt.gca()
ax.bar(bar_indices, theta, bar_width, color='C0', edgecolor='black', label='Truth')
ax.bar(bar_indices + bar_width, theta_mean, bar_width, yerr=np.sqrt(theta_var),
       color='C2', edgecolor='black', label='LPMF-SDE',
       error_kw=dict(ecolor='0', elinewidth=1., capsize=3., capthick=.5))
ax.set_ylabel('Value')
ax.set_xlabel('Parameter')
ax.set_xlim([-0.35, 0.55])
ax.set_ylim([0, 10])
# Centre each tick between the truth bar and the estimate bar
ax.set_xticks(bar_indices + bar_width / 2)
ax.set_xticklabels([r'${}$'.format(label) for label in dynamical.theta_labels])
ax.legend(loc=0)
figure.tight_layout()
plt.show()
figure.savefig('lorenz-96-parameters.eps', format='eps', dpi=1000, bbox_inches='tight')

In [None]:
# Helper to calculate RMSE of state estimation, gather parameter estimation and runtime result
# Generate the box plots
rmse = []
vgpamf_rmse = []
# Per-repetition RMSE averaged over the observed / unobserved states
rmse_obs, rmse_unobs = [], []
vgpamf_rmse_obs, vgpamf_rmse_unobs = [], []

theta_mean = []

runtime_mean = []
vgpamf_runtime = []

for repetition in range(1, num_repetitions + 1):
    # Load data
    (spl_t_0, spl_t_T, spl_tps, spl_freq,
     obs_t_0, obs_t_T, obs_freq, obs_tps, obs_t_indices,
     est_t_0, est_t_T, est_freq, est_tps, est_t_indices,
     X_0, theta, rho_2, phi, sigma_2, delta, gamma,
     opt_method, opt_tol, max_init_iter, max_iter, plotting_enabled, plotting_freq,
     spl_X, obs_Y) = utils.load_sde_config(directory.format(repetition), config_filename)
    # Each repetition stores its own observation mask, so the observed/unobserved
    # split is done here rather than with the mask of the last repetition loaded.
    observed = np.asarray(delta, dtype=bool)

    # Load data from LPMF-SDE
    data = []
    for i in range(1, num_rodes + 1):
        tmp = utils.load_data(directory.format(repetition), data_filename.format(i))
        # np.all replaces np.alltrue, which was removed in NumPy 2.0
        if not np.all(tmp['eta_theta'] > 0):
            raise RuntimeError('Negative theta value encountered for rode {}'.format(i))
        data.append(tmp)

    X_mean = utils.get_X_mean(data)
    state_rmse = np.sqrt([
        metrics.mean_squared_error(X_mean[i], spl_X[i, est_t_indices])
        for i in range(spl_X.shape[0])
    ])
    rmse.append(state_rmse)
    rmse_obs.append(np.mean(state_rmse[observed]))
    rmse_unobs.append(np.mean(state_rmse[~observed]))
    theta_mean.append(list(utils.get_theta_mean(data)))
    runtime_mean.append(utils.get_runtime_mean(data))

    # Load data from VGPA-MF
    with open(os.path.join(vgpamf_directory, vgpamf_filename.format(repetition)), 'rb') as infile:
        vgpamf_data = io.loadmat(infile)
    vgpamf_X_mean = vgpamf_data['mt'][:, est_t_indices]
    vgpamf_state_rmse = np.sqrt([
        metrics.mean_squared_error(vgpamf_X_mean[i], spl_X[i, est_t_indices])
        for i in range(spl_X.shape[0])
    ])
    vgpamf_rmse.append(vgpamf_state_rmse)
    vgpamf_rmse_obs.append(np.mean(vgpamf_state_rmse[observed]))
    vgpamf_rmse_unobs.append(np.mean(vgpamf_state_rmse[~observed]))
    vgpamf_runtime.append(vgpamf_data['runtime'].ravel()[0])

# Rows are states (or parameters), columns are repetitions
rmse = np.array(rmse).T
theta_mean = np.array(theta_mean).T
runtime_mean = np.array(runtime_mean)

vgpamf_rmse = np.array(vgpamf_rmse).T
vgpamf_runtime = np.array(vgpamf_runtime)

boxprops = dict(linestyle='-', linewidth=1., color='0')
medianprops = dict(linestyle='-', linewidth=1.2, color='red')
meanpointprops = dict(marker='D', markersize=6., markeredgecolor='green', markerfacecolor='green')

# Box plot for state estimation RMSE
figure = plt.figure(figsize=plt.figaspect(0.6))
ax = plt.gca()
rmse_data = [
    np.array(rmse_obs),
    np.array(vgpamf_rmse_obs),
    np.array(rmse_unobs),
    np.array(vgpamf_rmse_unobs)
]
# The second label previously had an unbalanced '}' inside the math expression
labels = [
    'LPMF-SDE\nRMSE$_{obs}$',
    'VGPA-MF\nRMSE$_{obs}$',
    'LPMF-SDE\nRMSE$_{unobs}$',
    'VGPA-MF\nRMSE$_{unobs}$'
]
ax.boxplot(rmse_data, labels=labels, notch=False, showfliers=False, showmeans=True,
           boxprops=boxprops, medianprops=medianprops, meanprops=meanpointprops, whis=[5, 95])
ax.set_ylabel('RMSE')
ax.set_xlabel('Method')
figure.tight_layout()
plt.show()
figure.savefig('lorenz-96-states-boxplot.eps', format='eps', dpi=1000, bbox_inches='tight')

# Box plot for parameter estimation
figure = plt.figure(figsize=plt.figaspect(1))
ax = plt.gca()
# Pass one dataset per parameter: a 2-D array would be split into one box per
# column (i.e. per repetition), which does not match the single label.
theta_data = list(theta_mean)
labels = ['$F$']
ax.boxplot(theta_data, labels=labels, notch=False, showfliers=False, showmeans=True,
           boxprops=boxprops, medianprops=medianprops, meanprops=meanpointprops, whis=[5, 95])
# Horizontal reference line at the true parameter value
ax.plot(np.arange(3), np.full(3, 8), linestyle='--', label='Truth')
ax.set_ylim([7.5, 8.2])
ax.set_xlabel('Parameter')
# ax.set_ylabel('Value') The y_label is shared with another plot
handles, _ = ax.get_legend_handles_labels()
ax.legend(handles=handles, labels=['Truth'], loc=0)
figure.tight_layout()
plt.show()
figure.savefig('lorenz-96-parameters-boxplot.eps', format='eps', dpi=1000, bbox_inches='tight')

# Box plot for runtime
figure = plt.figure(figsize=plt.figaspect(1))
ax = plt.gca()
runtime_data = np.array([
    runtime_mean,
    vgpamf_runtime
])
labels = ['LPMF-SDE\n', 'VGPA-MF\n']
# Transposed so that each method is one column, i.e. one box
ax.boxplot(runtime_data.T, labels=labels, notch=False, showfliers=False, showmeans=True,
           boxprops=boxprops, medianprops=medianprops, meanprops=meanpointprops, whis=[5, 95])
ax.set_xlabel('Method')
ax.set_ylabel('Runtime (s)')
ax.set_ylim([2000, 3400])
figure.tight_layout()
plt.show()
figure.savefig('lorenz-96-runtime-boxplot.eps', format='eps', dpi=1000, bbox_inches='tight')

In [None]:
# Helper to transform VGPA sample path into our settings
for repetition in range(1, num_repetitions + 1):
    with open(os.path.join(directory.format(repetition), 'config.mat'), 'rb') as infile:
        config_mat = io.loadmat(infile)

    spl_t_0, spl_t_T, spl_freq = 0, 4, 100
    obs_t_0, obs_t_T, obs_freq = 0, 4, 1
    est_t_0, est_t_T, est_freq = 0, 4, 1
    spl_tps, obs_tps, obs_t_indices, est_tps, est_t_indices = utils.create_time(
        spl_t_0, spl_t_T, spl_freq, obs_t_0, obs_t_T, obs_freq, est_t_0, est_t_T, est_freq)

    # Observation times are taken from the VGPA configuration instead of the grid above.
    # The builtin int/float replace np.int/np.float, which were removed in NumPy 1.24.
    obs_freq = 8
    obs_t_indices = np.array(config_mat['obsX'].ravel(), dtype=int)
    obs_tps = spl_tps[obs_t_indices]

    # Estimates are produced at the observation times
    est_freq = 8
    est_t_indices = obs_t_indices.copy()
    est_tps = spl_tps[est_t_indices]

    # The first column of the stored path is the initial state
    X_0 = np.array(config_mat['Xt'][:, 0], dtype=float)
    num_x = X_0.shape[0]
    theta = np.array([8.])
    rho_2 = np.full(num_x, 4.)
    phi = [
        # (Kernel name, Kernel parameters)
        ('rbf', np.array([4.2, 0.1]))
    ] * num_x
    sigma_2 = np.full(num_x, 1.)
    delta = np.full(num_x, False)
    # dMask is shifted by one before indexing (presumably 1-based MATLAB indices);
    # flatten it and cast to int so it is a valid index whatever dtype loadmat returned.
    observed_states = np.asarray(config_mat['dMask']).ravel().astype(int) - 1
    delta[observed_states] = True
    gamma = np.full(num_x, 5e-2)
    gamma[delta] = 1e-1

    opt_method = 'Newton-CG'
    opt_tol = 1e-6
    max_init_iter = 10
    max_iter = 1000

    plotting_enabled = False
    plotting_freq = 50

    spl_X = np.array(config_mat['Xt'])
    # Observations are only available for observed states; the other rows stay zero
    obs_Y = np.zeros((num_x, obs_tps.size))
    obs_Y[delta, :] = np.array(config_mat['obsY'])

    utils.save_sde_config(directory.format(repetition), config_filename,
                          spl_t_0, spl_t_T, spl_freq, spl_tps,
                          obs_t_0, obs_t_T, obs_freq, obs_tps, obs_t_indices,
                          est_t_0, est_t_T, est_freq, est_tps, est_t_indices,
                          X_0, theta, rho_2, phi, sigma_2, delta, gamma,
                          opt_method, opt_tol, max_init_iter, max_iter,
                          plotting_enabled, plotting_freq, spl_X, obs_Y)

In [None]:
# Helper to check that the results are matching
for repetition in range(1, num_repetitions + 1):
    with open(os.path.join(directory.format(repetition), 'config.mat'), 'rb') as infile:
        config_mat = io.loadmat(infile)

    config_i = utils.load_data(directory.format(repetition), config_filename)

    # The converted configuration must hold the same sample path as the .mat file.
    # np.all replaces np.alltrue, which was removed in NumPy 2.0.
    assert np.all(config_mat['Xt'] == config_i['spl_X'])

    # Format the filename template only, consistent with the other cells
    original_path = os.path.join(vgpamf_directory,
                                 vgpamf_config_filename.format(repetition))
    with open(original_path, 'rb') as infile:
        config_mat_original = io.loadmat(infile)

    # The local copy of the .mat file must match the original VGPA-MF configuration
    assert np.all(config_mat['Xt'] == config_mat_original['Xt'])

    # A repetition equals itself and differs from every later repetition
    for j in range(repetition, num_repetitions + 1):
        config_j = utils.load_data(directory.format(j), config_filename)
        if repetition == j:
            assert np.all(config_i['spl_X'] == config_j['spl_X'])
        else:
            assert np.any(config_i['spl_X'] != config_j['spl_X'])