In [None]:
# %load init.py
# Notebook bootstrap: imports, plotting backend/style and sympy printing setup.
import os
import pickle
import sys
# Make the parent directory importable so the project modules imported at the
# bottom of this cell resolve when the notebook runs from its own folder.
sys.path.append(os.path.abspath('..'))
import time

import matplotlib as mpl
# Select plotting backend (done before pyplot is imported below)
mpl.use('nbAgg')

import matplotlib.pyplot as plt
# Customize plotting
# NOTE(review): Matplotlib 3.6 renamed the bundled seaborn style sheets to
# 'seaborn-v0_8-<name>' and later releases dropped the old names -- confirm
# 'seaborn-paper' still resolves with the installed Matplotlib version.
plt.style.use('seaborn-paper')
plt.rcParams['axes.labelsize'] = 11.0
plt.rcParams['axes.titlesize'] = 12.0
plt.rcParams['errorbar.capsize'] = 3.0
plt.rcParams['figure.dpi'] = 72.0
plt.rcParams['figure.titlesize'] = 12.0
plt.rcParams['legend.fontsize'] = 10.
plt.rcParams['lines.linewidth'] = 1.
plt.rcParams['xtick.labelsize'] = 11.0
plt.rcParams['ytick.labelsize'] = 11.0

import numpy as np
import sympy as sp
# Render sympy expressions through LaTeX in notebook output
sp.init_printing(euler=True, use_latex=True)

from IPython import display
from scipy import io, optimize
from sklearn import metrics

# Project-local modules (found through the sys.path entry added above)
import core
import dynamicals
import kernels
import numericals
import utils

In [None]:
# Example setup to run the inference
dynamical = dynamicals.Lorenz63()
config = core.Config()
# NOTE(review): the nine arguments presumably follow the order used with
# utils.create_time later in this notebook, i.e. (t_0, t_T, freq) for the
# sample path, the observations and the estimation grid -- confirm.
config.create_time(0, 20, 100, 0, 20, 5, 0, 20, 5)
# Initial state drawn uniformly from [0, 10) per state dimension.
# NOTE(review): no random seed is set, so every run yields a different
# initial state (and hence a different sample path).
config.X_0 = np.random.random(dynamical.num_x) * 10.
# Drift parameters used to generate the sample path
config.theta = np.array([10., 28., 8. / 3])
config.rho_2 = np.full(dynamical.num_x, 10.)

# One kernel specification per state dimension
config.phi = [
    #  (Kernel name, Kernel parameters)
    ('rbf', np.array([3.6, 0.15])),
    ('rbf', np.array([3.6, 0.15])),
    ('rbf', np.array([3.6, 0.15]))
]
# sigma_2 is also passed to utils.collect_observations below
config.sigma_2 = np.full(dynamical.num_x, 2.)
# All entries True; the analysis cells use delta[i] to decide whether
# observations of state i are plotted
config.delta = np.full(dynamical.num_x, True)
config.gamma = np.full(dynamical.num_x, 1e-4)

# Optimizer settings
config.opt_method = 'Newton-CG'
config.opt_tol = 1e-6
config.max_init_iter = 5
config.max_iter = 2000

config.plotting_enabled = True
config.plotting_freq = 50

# Generate one sample path from the dynamical system, then collect the
# observations at the configured observation indices.
config.spl_X = dynamical.generate_sample_path(config.theta, config.rho_2, config.X_0, config.spl_tps)
config.obs_Y = utils.collect_observations(config.spl_X, config.obs_t_indices, config.sigma_2)

In [None]:
# Run the Gaussian process regression first; the resulting object is handed to
# the Laplace mean-field SDE inference as its third constructor argument.
gp = core.GaussianProcessRegression(dynamical, config)
gp.run()

lpmf = core.LaplaceMeanFieldSDE(dynamical, config, gp)
lpmf.run()

In [None]:
# Shared settings for the result-analysis cells below.
dynamical = dynamicals.Lorenz63()

# Per-repetition result directories; '{}' is filled with the repetition number.
# Both templates end with '/' so they behave identically regardless of how the
# utils loaders combine directory and file name (the partial template used to
# lack the trailing separator).
full_directory = '../data/tars/sde-lorenz-63/full/{}/'
partial_directory = '../data/tars/sde-lorenz-63/partial/{}/'

config_filename = utils.CONFIG_FILENAME
# Formatted with the result index (1..num_rodes) by the loading code
data_filename = utils.DATA_FILENAME

# NOTE(review): machine-specific absolute path -- the VGPA cells only run on a
# machine with this layout; consider moving the results under ../data.
vgpa_directory = '/Users/ruifengxu/Development/ruiixu23/VGPA/results/parameter-lorenz-63/{}/'
vgpa_filename = 'result.pickle'
vgpa_config_filename = 'sample.pickle'

# Mask used when plotting the partially observed run: index 1 is False, which
# matches the partial configs written by the conversion cell further down.
partial_delta = np.array([True, False, True])

# Number of repetitions (directories 1..num_repetitions) and number of result
# files loaded per repetition (indices 1..num_rodes)
num_repetitions = 10
num_rodes = 100

In [None]:
# Plotting of the state estimation for one SDE sample path and the 3D trajectory
repetition = 10
x_index = 1


def load_lpmf_posterior(directory):
    """Load all LPMF result files from `directory` and summarize them.

    Reads `data_filename.format(i)` for i = 1..num_rodes via utils.load_data.

    Returns:
        Tuple (X_mean, X_var, theta_mean, theta_var) produced by the
        corresponding utils.get_* helpers over the loaded results.

    Raises:
        RuntimeError: if a loaded 'eta_theta' has an entry that is not > 0.
    """
    results = []
    for i in range(1, num_rodes + 1):
        result = utils.load_data(directory, data_filename.format(i))
        # np.all replaces np.alltrue, which was removed in NumPy 2.0
        if not np.all(result['eta_theta'] > 0):
            raise RuntimeError('Negative theta value encountered for rode {}'.format(i))
        results.append(result)
    return (utils.get_X_mean(results), utils.get_X_var(results),
            utils.get_theta_mean(results), utils.get_theta_var(results))


(spl_t_0, spl_t_T, spl_tps, spl_freq,
 obs_t_0, obs_t_T, obs_freq, obs_tps, obs_t_indices,
 est_t_0, est_t_T, est_freq, est_tps, est_t_indices,
 X_0, theta, rho_2, phi, sigma_2, delta, gamma,
 opt_method, opt_tol, max_init_iter, max_iter, plotting_enabled, plotting_freq,
 spl_X, obs_Y) = utils.load_sde_config(full_directory.format(repetition), config_filename)

# Load LPMF full observation
full_X_mean, full_X_var, full_theta_mean, full_theta_var = load_lpmf_posterior(
    full_directory.format(repetition))

# Load LPMF partial observation
partial_X_mean, partial_X_var, partial_theta_mean, partial_theta_var = load_lpmf_posterior(
    partial_directory.format(repetition))

# Load VGPA full observation
data = utils.load_data(vgpa_directory.format(repetition), vgpa_filename)
vgpa_X_mean = np.array(data['mt']).T[:, est_t_indices]
vgpa_St = data['St']
# Diagonal entry (i, i) of every per-time-step 3x3 block of 'St', arranged as
# (state, time) and then restricted to the estimation time indices.
vgpa_X_var = np.array([
    [vgpa_St[j, :, :][i, i] for j in range(spl_tps.size)]
    for i in range(3)
])[:, est_t_indices]
vgpa_theta_mean = data['theta_Drift']
# No variance is read from the VGPA result; zeros give empty error bars
vgpa_theta_var = np.zeros(3)

# Plotting state estimation result: one panel per method, identical layout
state_panels = [
    # (title, estimated mean, estimated variance, observation mask)
    ('LPMF-SDE-F', full_X_mean, full_X_var, delta),
    ('LPMF-SDE-P', partial_X_mean, partial_X_var, partial_delta),
    ('VGPA-MAP', vgpa_X_mean, vgpa_X_var, delta),
]
figure = plt.figure(figsize=[10, 12])
for panel_index, (title, X_mean, X_var, observed) in enumerate(state_panels, start=1):
    ax = figure.add_subplot(len(state_panels), 1, panel_index)
    ax.plot(spl_tps, spl_X[x_index], color='C0', linestyle='-', linewidth=1.5, label='Sample path')
    # Observations are only drawn when the mask marks this state as observed
    if observed[x_index]:
        ax.scatter(obs_tps, obs_Y[x_index], color='C1', marker='x', label='Observation')
    ax.errorbar(est_tps, X_mean[x_index], color='C2', linestyle='--', linewidth=1.5, label='Estimation',
                yerr=np.sqrt(X_var[x_index]), ecolor='0', elinewidth=1., capsize=3., capthick=.5)
    ax.set_xlabel('Time')
    ax.set_ylabel('State y')
    ax.set_title(title)
    ax.set_xlim([spl_t_0, spl_t_T])
    handles, labels = ax.get_legend_handles_labels()
    ax.legend(handles=handles, labels=labels, loc=1)
figure.tight_layout()
plt.show()
figure.savefig('lorenz-63-states.eps', format='eps', dpi=1000, bbox_inches='tight')

# Plotting parameter estimation result
figure = plt.figure(figsize=plt.figaspect(1))
ax = plt.gca()
bar_indices = np.arange(3)
bar_width = 0.15
error_style = dict(ecolor='0', elinewidth=1., capsize=3., capthick=.5)
ax.bar(bar_indices + 0.2, theta, bar_width, color='C0', label='Truth', edgecolor='black')
# NOTE(review): the error bars below use the theta variance directly, whereas
# the state plots above use the square root of the variance -- confirm which
# one is intended.
theta_estimates = [
    # (legend label, bar color, estimated mean, estimated variance)
    ('LPMF-SDE-F', 'C1', full_theta_mean, full_theta_var),
    ('LPMF-SDE-P', 'C2', partial_theta_mean, partial_theta_var),
    ('VGPA-MAP', 'C3', vgpa_theta_mean, vgpa_theta_var),
]
for position, (label, color, theta_mean, theta_var) in enumerate(theta_estimates, start=1):
    ax.bar(bar_indices + 0.2 + bar_width * position, theta_mean, bar_width, yerr=theta_var,
           label=label, color=color, edgecolor='black', error_kw=dict(error_style))
ax.set_xlabel('Parameter')
ax.set_ylabel('Value')
ax.set_xticks(bar_indices + bar_width / 2 + 0.35)
ax.set_xticklabels([r'${}$'.format(label) for label in dynamical.theta_labels])
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles=handles, labels=labels, loc=0)
figure.tight_layout()
plt.show()
figure.savefig('lorenz-63-parameters.eps', format='eps', dpi=1000, bbox_inches='tight')

# Plot the sample path
# Importing Axes3D registers the '3d' projection on older Matplotlib releases;
# it is harmless on newer ones.
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401

state_labels = dynamical.x_labels
figure = plt.figure(figsize=plt.figaspect(.75 / 1.) * 1.5)
# Figure.gca() no longer accepts a projection keyword (removed in
# Matplotlib 3.6); add_subplot works on old and new releases alike.
ax = figure.add_subplot(111, projection='3d')
ax.view_init(elev=15)
ax.plot(spl_X[0], spl_X[1], spl_X[2], color='C0', linestyle='-', linewidth=1.5, label='Sample path')

ax.set_xlabel(r'State ${}$'.format(state_labels[0]))
ax.set_ylabel(r'State ${}$'.format(state_labels[1]))
ax.set_zlabel(r'State ${}$'.format(state_labels[2]))
plt.tight_layout()
plt.show()

figure.savefig('lorenz-63-sample-path.eps', format='eps', dpi=1000, bbox_inches='tight')

In [None]:
# Helper to calculate RMSE of state estimation, gather parameter estimation and runtime result
# Generate the box plots


def summarize_lpmf_run(directory, spl_X, est_t_indices):
    """Load all LPMF result files from `directory` and summarize one repetition.

    Reads `data_filename.format(i)` for i = 1..num_rodes via utils.load_data.

    Args:
        directory: directory holding the result files of one repetition.
        spl_X: true sample path, indexed as spl_X[state, time index].
        est_t_indices: time indices at which the estimates are compared.

    Returns:
        Tuple (X_mean, mse, theta_mean, runtime_mean): the utils.get_X_mean
        result, the per-state mean squared error of X_mean against the sample
        path (a list), utils.get_theta_mean as a list, and the mean of the
        'runtime' entries of the loaded results.

    Raises:
        RuntimeError: if a loaded 'eta_theta' has an entry that is not > 0.
    """
    results = []
    for i in range(1, num_rodes + 1):
        result = utils.load_data(directory, data_filename.format(i))
        # np.all replaces np.alltrue, which was removed in NumPy 2.0
        if not np.all(result['eta_theta'] > 0):
            raise RuntimeError('Negative theta value encountered for rode {}'.format(i))
        results.append(result)

    X_mean = utils.get_X_mean(results)
    mse = [
        metrics.mean_squared_error(X_mean[k], spl_X[k, est_t_indices])
        for k in range(spl_X.shape[0])
    ]
    theta_mean = list(utils.get_theta_mean(results))
    runtime_mean = np.mean([item['runtime'] for item in results])
    return X_mean, mse, theta_mean, runtime_mean


# Accumulators: one entry per repetition (these hold MSE values until the
# square root is taken after the loop)
full_rmse = []
partial_rmse = []
vgpa_rmse = []

full_theta_mean = []
partial_theta_mean = []
vgpa_theta_mean = []

full_runtime_mean = []
partial_runtime_mean = []
vgpa_runtime = []

for repetition in range(1, num_repetitions + 1):
    (spl_t_0, spl_t_T, spl_tps, spl_freq,
     obs_t_0, obs_t_T, obs_freq, obs_tps, obs_t_indices,
     est_t_0, est_t_T, est_freq, est_tps, est_t_indices,
     X_0, theta, rho_2, phi, sigma_2, delta, gamma,
     opt_method, opt_tol, max_init_iter, max_iter, plotting_enabled, plotting_freq,
     spl_X, obs_Y) = utils.load_sde_config(full_directory.format(repetition), config_filename)

    # Process LPMF full observation
    full_X_mean, mse, theta_mean, runtime_mean = summarize_lpmf_run(
        full_directory.format(repetition), spl_X, est_t_indices)
    full_rmse.append(mse)
    full_theta_mean.append(theta_mean)
    full_runtime_mean.append(runtime_mean)

    # Process LPMF partial observation
    partial_X_mean, mse, theta_mean, runtime_mean = summarize_lpmf_run(
        partial_directory.format(repetition), spl_X, est_t_indices)
    partial_rmse.append(mse)
    partial_theta_mean.append(theta_mean)
    partial_runtime_mean.append(runtime_mean)

    # Process VGPA full observation
    data = utils.load_data(vgpa_directory.format(repetition), vgpa_filename)
    vgpa_X_mean = np.array(data['mt']).T[:, est_t_indices]
    vgpa_rmse.append([
        metrics.mean_squared_error(vgpa_X_mean[i], spl_X[i, est_t_indices])
        for i in range(spl_X.shape[0])
    ])
    vgpa_theta_mean.append(list(data['theta_Drift']))
    vgpa_runtime.append(data['time'])

# Convert to arrays; after the transpose the first axis is the state (or
# parameter) index and the second axis the repetition.
full_rmse = np.sqrt(full_rmse).T
full_theta_mean = np.array(full_theta_mean).T
full_runtime_mean = np.array(full_runtime_mean)

partial_rmse = np.sqrt(partial_rmse).T
partial_theta_mean = np.array(partial_theta_mean).T
# Previously assigned to the misspelled name `parital_runtime_mean`, which
# left `partial_runtime_mean` as a plain list.
partial_runtime_mean = np.array(partial_runtime_mean)

vgpa_rmse = np.sqrt(vgpa_rmse).T
vgpa_theta_mean = np.array(vgpa_theta_mean).T
vgpa_runtime = np.array(vgpa_runtime)

boxprops = dict(linestyle='-', linewidth=1., color='0')
medianprops = dict(linestyle='-', linewidth=1.2, color='red')
meanpointprops = dict(marker='D', markersize=6., markeredgecolor='green', markerfacecolor='green')
# Keyword arguments shared by every box plot in this cell
boxplot_style = dict(notch=False, showfliers=False, showmeans=True,
                     boxprops=boxprops, medianprops=medianprops, meanprops=meanpointprops,
                     whis=[5, 95])

# Box plot for state estimation RMSE (averaged over the states per repetition)
figure = plt.figure(figsize=plt.figaspect(0.6))
ax = plt.gca()
rmse_data = [
    np.mean(full_rmse, axis=0),
    np.mean(partial_rmse, axis=0),
    np.mean(vgpa_rmse, axis=0),
]
labels = ['LPMF-SDE-F', 'LPMF-SDE-P', 'VGPA-MAP']
ax.boxplot(rmse_data, labels=labels, **boxplot_style)
ax.set_xlabel('Method')
ax.set_ylabel('RMSE')
figure.tight_layout()
plt.show()
figure.savefig('lorenz-63-states-boxplot.eps', format='eps', dpi=1000, bbox_inches='tight')

# Box plot for parameter estimation
labels = ['LPMF-SDE-F', 'LPMF-SDE-P', 'VGPA-MAP']
# Hand-picked y-axis limits, one [lower, upper] pair per parameter; None leaves
# that bound to Matplotlib.
y_lims = [
    [0, 15],
    [None, 29],
    [0, 3]
]
for i in range(3):
    figure = plt.figure(figsize=plt.figaspect(1))
    ax = plt.gca()
    theta_data = [
        full_theta_mean[i],
        partial_theta_mean[i],
        vgpa_theta_mean[i]
    ]
    ax.boxplot(theta_data, labels=labels, **boxplot_style)
    # Horizontal reference line at the theta value of the last loaded config
    ax.plot(np.arange(5), np.full(5, theta[i]), linestyle='--', label='Truth')
    ax.set_xlabel('Method')
    ax.set_ylabel(r'${}$'.format(dynamical.theta_labels[i]))
    ax.set_ylim(y_lims[i])
    handles, _ = ax.get_legend_handles_labels()
    ax.legend(handles=handles, labels=['Truth'], loc=0)
    figure.tight_layout()
    plt.show()
    # The first character of the label is dropped for the file name
    # (presumably a leading backslash of a LaTeX symbol -- confirm).
    figure.savefig('lorenz-63-parameters-{}-boxplot.eps'.format(dynamical.theta_labels[i][1:]),
                   format='eps', dpi=1000, bbox_inches='tight')

# Box plot for runtime
figure = plt.figure(figsize=plt.figaspect(1))
ax = plt.gca()
runtime_data = [
    full_runtime_mean,
    partial_runtime_mean,
    vgpa_runtime
]
labels = ['LPMF-SDE-F', 'LPMF-SDE-P', 'VGPA-MAP']
ax.boxplot(runtime_data, labels=labels, **boxplot_style)
ax.set_xlabel('Method')
ax.set_ylabel('Runtime (s)')
figure.tight_layout()
plt.show()
figure.savefig('lorenz-63-runtime-boxplot.eps', format='eps', dpi=1000, bbox_inches='tight')

In [None]:
# Helper to transform the VGPA sample paths into our settings: for every
# repetition, read 'sample.pickle' from the full-observation directory and
# write one config into the full directory and one into the partial directory.
for repetition in range(1, num_repetitions + 1):
    sample = utils.load_data(full_directory.format(repetition), 'sample.pickle')
    
    spl_t_0, spl_t_T, spl_freq = 0, 20, 100
    obs_t_0, obs_t_T, obs_freq = 0, 20, 5
    est_t_0, est_t_T, est_freq = 0, 20, 5
    # NOTE(review): all five values returned here are overwritten right below
    # with values taken from the loaded sample, so this call looks redundant --
    # confirm it has no required side effect before removing it.
    spl_tps, obs_tps, obs_t_indices, est_tps, est_t_indices = utils.create_time(
        spl_t_0, spl_t_T, spl_freq, obs_t_0, obs_t_T, obs_freq, est_t_0, est_t_T, est_freq)
    
    # Take the time grid and the observation indices from the sample; the
    # estimation grid is set to a copy of the observation grid.
    spl_tps = sample['Tw']
    obs_t_indices = sample['obsX']
    obs_tps = sample['Tw'][sample['obsX']]
    est_t_indices = obs_t_indices.copy()
    est_tps = obs_tps.copy()
    
    # Initial state is the first entry of the sample's true path
    X_0 = sample['xt_true'][0]
    num_x = X_0.shape[0]
    theta = np.array([10., 28., 8. / 3])
    rho_2 = np.full(num_x, 10.)
    sigma_2 = np.full(num_x, 2.)
    phi = [
        # (Kernel name, Kernel parameters)
        ('rbf', np.array([3.6, 0.15])),
        ('rbf', np.array([3.6, 0.15])),
        ('rbf', np.array([3.6, 0.15]))
    ]
    delta = np.full(num_x, True)
    gamma = np.full(num_x, 1e-4) 

    opt_method = 'Newton-CG'
    opt_tol = 1e-6
    max_init_iter = 5
    max_iter = 2000

    plotting_enabled = False
    plotting_freq = 50
    
    # Transposed so the first axis is the state, matching how spl_X and obs_Y
    # are indexed in the plotting cells.
    spl_X = sample['xt_true'].T
    obs_Y = sample['obsY'].T

    # Full-observation config: delta is all True at this point
    utils.save_sde_config(full_directory.format(repetition), config_filename,
                          spl_t_0, spl_t_T, spl_freq, spl_tps,
                          obs_t_0, obs_t_T, obs_freq, obs_tps, obs_t_indices,
                          est_t_0, est_t_T, est_freq, est_tps, est_t_indices,
                          X_0, theta, rho_2, phi, sigma_2, delta, gamma,
                          opt_method, opt_tol, max_init_iter, max_iter,
                          plotting_enabled, plotting_freq, spl_X, obs_Y)
    
    # Partial-observation config: identical except that index 1 of delta is
    # switched off. delta is modified in place, so this must stay after the
    # full config has been saved.
    delta[1] = False
    utils.save_sde_config(partial_directory.format(repetition), config_filename,
                          spl_t_0, spl_t_T, spl_freq, spl_tps,
                          obs_t_0, obs_t_T, obs_freq, obs_tps, obs_t_indices,
                          est_t_0, est_t_T, est_freq, est_tps, est_t_indices,
                          X_0, theta, rho_2, phi, sigma_2, delta, gamma,
                          opt_method, opt_tol, max_init_iter, max_iter,
                          plotting_enabled, plotting_freq, spl_X, obs_Y)  

In [None]:
# Helper to check that the saved configs match the samples they were built from
for repetition in range(1, num_repetitions + 1):
    full_sample = utils.load_data(full_directory.format(repetition), 'sample.pickle')
    partial_sample = utils.load_data(partial_directory.format(repetition), 'sample.pickle')

    full_config = utils.load_data(full_directory.format(repetition), config_filename)
    partial_config = utils.load_data(partial_directory.format(repetition), config_filename)

    # np.all replaces np.alltrue (removed in NumPy 2.0) and matches the
    # np.all / np.any calls already used further down in this cell.
    # Full configs have every delta entry set; partial configs have index 1 off.
    assert np.all(full_config['delta']), \
        'full config {} has an unset delta entry'.format(repetition)
    assert not np.all(partial_config['delta']) and not partial_config['delta'][1], \
        'partial config {} does not have delta[1] switched off'.format(repetition)

    # The stored sample path is the transposed true path of the sample file
    assert np.all(full_sample['xt_true'].T == full_config['spl_X']), \
        'full config {} differs from its sample'.format(repetition)
    assert np.all(partial_sample['xt_true'].T == partial_config['spl_X']), \
        'partial config {} differs from its sample'.format(repetition)

    original_sample = utils.load_data(vgpa_directory.format(repetition), vgpa_config_filename)

    # Both copies of the sample must equal the one in the VGPA directory
    assert np.all(full_sample['xt_true'] == original_sample['xt_true']), \
        'full sample {} differs from the VGPA sample'.format(repetition)
    assert np.all(partial_sample['xt_true'] == original_sample['xt_true']), \
        'partial sample {} differs from the VGPA sample'.format(repetition)

    # Compare against this and every later repetition: the path must equal
    # itself and differ from the path of any other repetition.
    for j in range(repetition, num_repetitions + 1):
        config_j = utils.load_data(full_directory.format(j), config_filename)
        if repetition == j:
            assert np.all(full_config['spl_X'] == config_j['spl_X'])
        else:
            assert np.any(full_config['spl_X'] != config_j['spl_X']), \
                'full configs {} and {} share the same sample path'.format(repetition, j)

        config_j = utils.load_data(partial_directory.format(j), config_filename)
        if repetition == j:
            assert np.all(partial_config['spl_X'] == config_j['spl_X'])
        else:
            assert np.any(partial_config['spl_X'] != config_j['spl_X']), \
                'partial configs {} and {} share the same sample path'.format(repetition, j)