In [None]:
## Mount google drive. Uncomment these 2 lines if you want to train new models and save them in your drive. Saving in Colab is volatile.
# from google.colab import drive
# drive.mount('/content/drive')

## Clone the data and the files from github
%cd '/content'
GIT_USERNAME = "zeyangding96"
GIT_REPOSITORY = "ral20-pue"
GIT_PATH = "https://github.com/" + GIT_USERNAME + "/" + GIT_REPOSITORY + ".git"
!rm -rf "{GIT_REPOSITORY}"
!git clone "{GIT_PATH}"
%cd "{GIT_REPOSITORY}"

import matplotlib.pyplot as plt
import numpy as np
import os
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from myData import *
from myUtils import *
from myModel import *

device = check_gpu(); print('Using', device)

In [2]:
## Configuration
train_dir = 'data/train'; test_dir = 'data/test'
input_dim = 2; output_dim = 18
norm_method = 'scale'
shuffle = True; lr = 7e-3; wd = 1e-4
epochs = 2; seq_len = 100; stride = seq_len
criterion = 'nll'; logvar_dim = 18; epistemic = 'gmm'; num_ensemble = 2; num_MC = 50
core = 'lstm'; dropout = 0.5
# savepath = '/content/drive/My Drive/???' ## Specify the savepath if you wish to save the trained models to your google drive.

## Per-member hyperparameters: a single model uses fixed values; ensemble members get batch sizes
## drawn from 25..35 and hidden sizes from 120..140, seeded so the draw is reproducible.
if num_ensemble == 1:
    bs = [32]; hidden_dim = [128]
else:
    np.random.seed(0)
    bs = np.random.permutation(range(25,36))[:num_ensemble].tolist()
    hidden_dim = np.random.permutation(range(120,141))[:num_ensemble].tolist()
    # One permutation supplies at most 11 batch sizes / 21 hidden sizes, so keep topping up until
    # every member has a value (a single top-up left `bs` short for num_ensemble > 22, and `bs[j]`
    # then raised IndexError during training). The first pass draws in the same order as a single
    # top-up (bs, then hidden_dim), so smaller ensembles get exactly the same hyperparameters.
    while len(bs) < num_ensemble or len(hidden_dim) < num_ensemble:
        if len(bs) < num_ensemble: bs += np.random.permutation(range(25,36))[:num_ensemble-len(bs)].tolist()
        if len(hidden_dim) < num_ensemble: hidden_dim += np.random.permutation(range(120,141))[:num_ensemble-len(hidden_dim)].tolist()
assert len(bs) == len(hidden_dim) == num_ensemble, 'need one batch size and one hidden size per ensemble member'

## Training data: windows of length seq_len every `stride` steps; prepare_data also returns the
## normalization parameters (norm_param) that are reused for the test set below.
X, Y, norm_param = prepare_data(train_dir, input_dim, output_dim, seq_len, stride)
X_norm, Y_norm = normalize_data(X, Y, norm_param, norm_method)
data_train = myDataset(X_norm, Y_norm)

## Test data, normalized with the *training* norm_param.
X_Test, Y_Test = load_data_to_timseries(test_dir, input_dim, output_dim, 1, 1)
x_test, y_test = normalize_data(X_Test, Y_Test, norm_param, norm_method)

## Simulated drift for OOD analysis
# interval = range(0,900); N = len(interval)
# x_test[300:600,:,0] += -0.003 * np.arange(300)[:,None]
# x_test[300:600,:,1] += 0.007 * np.arange(300)[:,None]
# x_test[interval,:,0] = np.clip(x_test[interval,:,0], -1.5, 1.5); x_test[interval,:,1] = np.clip(x_test[interval,:,1], -1.5, 1.5)
# x_test = x_test[:1000]; y_test = y_test[:1000]

# transpose(1,0,2) swaps the first two axes; presumably this turns (time, 1, dim) into a single
# (1, time, dim) sequence so the whole test series is fed as one batch — TODO confirm against load_data_to_timseries.
data_test = myDataset(x_test.transpose(1,0,2), y_test.transpose(1,0,2))
loader_test = DataLoader(data_test)
# Flattened (rows, output_dim) test targets on CPU, compared against the combined predictions in the evaluation cells.
target = torch.tensor(y_test.reshape(-1, output_dim), dtype=torch.float, device='cpu')

In [None]:
## Train `num_ensemble` models per seed, combine their test predictions and score them (RMSE, NLL).
num_seeds = 1  ## increase to average the metrics over several random seeds
d = {'rmse': [], 'nll': []}
for ii in range(num_seeds):
    print('Seed: ', ii)
    set_rng_seed(ii)

    ## Training: one model per member, each with its own hidden size and batch size
    Mu_test_ensemble = torch.empty(num_ensemble, 1, x_test.shape[0], output_dim)
    Var_test_ensemble = torch.empty(num_ensemble, 1, x_test.shape[0], logvar_dim)
    models = []
    for j in range(num_ensemble):
        print('Ensemble: ', j)
        model = myRnnPue(input_dim, hidden_dim[j], output_dim, logvar_dim, core=core, dropout=dropout, device=device)
        optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
        loader_train = DataLoader(data_train, batch_size=bs[j], shuffle=shuffle)
        model.init_train(optimizer, loader_train, loader_val=loader_test)
        model.Train(epochs=epochs, criterion=criterion, check_train=False, check_val=False, verbose=True)
        models += [model]
        Mu_test_ensemble[j], Var_test_ensemble[j], _ = model.validate(loader_test, loss=False)
    # torch.save(models, savepath+'-seed'+str(ii))

    ## Testing/Inference
    # Squeeze only the leading singleton axis (as for the variances below), so a length-1 test
    # series cannot lose its time axis as well.
    Mu_test = Mu_test_ensemble.mean(dim=0).squeeze(dim=0)
    if epistemic == 'gmm':  # gaussian mixture model
        # Moment-matched mixture variance: mean_j(var_j + (mu_j - mu)^2). Algebraically equal to
        # mean_j(var_j + mu_j^2) - mu^2, but it cannot turn negative through floating-point
        # cancellation, which would make log(Var_test) NaN in the NLL.
        Var_test = (Var_test_ensemble + (Mu_test_ensemble - Mu_test)**2).mean(dim=0).squeeze(dim=0)
    elif epistemic in ('gal', 'kendall'):  # MC dropout, applied to the last trained model
        Mu_test_mc = torch.empty(num_MC, 1, x_test.shape[0], output_dim)
        Var_test_mc = torch.empty(num_MC, 1, x_test.shape[0], logvar_dim)
        for k in range(num_MC):
            Mu_test_mc[k], Var_test_mc[k], _ = model.validate(loader_test, loss=False, mcdropout=True)
        Mu_test = Mu_test_mc.mean(dim=0).squeeze(dim=0)
        # Spread of the MC means (two-pass form, never negative) ...
        Var_test = ((Mu_test_mc - Mu_test)**2).mean(dim=0).squeeze(dim=0)
        if epistemic == 'kendall':
            # ... plus the mean predicted (aleatoric) variance.
            Var_test = Var_test + Var_test_mc.mean(dim=0).squeeze(dim=0)
    else:
        # Spread of the member means only (torch.var defaults to the unbiased N-1 estimator,
        # so this is NaN for a single-member ensemble).
        Var_test = Mu_test_ensemble.var(dim=0).squeeze(dim=0)
    rmse = F.mse_loss(Mu_test, target).sqrt().item()
    nll = model.nll_loss(Mu_test, torch.log(Var_test), target).item()
    d['rmse'].append(rmse)
    d['nll'].append(nll)

## Mean and std of each metric over the seeds
{key: (np.mean(vals), np.std(vals)) for key, vals in d.items()}

In [None]:
## Code to generate some of the figures in the paper. The models trained for the paper are saved in the
## "model_saves" folder and are loaded here. New models can also be trained to generate the figures, but the
## figures will then differ slightly from the paper, because every trained network depends on the random
## seed and the machine.

## Figure 4: performance vs. number of models in the ensemble. Slow on CPU.
import inspect

max_M = 15  # largest ensemble size evaluated; every checkpoint must contain at least this many models
names = ['lstm-15-nll-gmm-seed' + str(s) for s in range(5)]

# The checkpoints pickle whole model objects, which torch >= 2.6 refuses under its default
# weights_only=True. Opt out explicitly where the argument exists (torch >= 1.13).
# Only load checkpoints you trust: unpickling can execute arbitrary code.
load_kwargs = {'weights_only': False} if 'weights_only' in inspect.signature(torch.load).parameters else {}

rmse_per_M = [[] for _ in range(max_M)]
nll_per_M = [[] for _ in range(max_M)]
for name in names:
    print(name)
    # Load each checkpoint and run each model once; the ensemble of size M is just the first M
    # predictions, so nothing needs to be reloaded or re-evaluated for every M.
    models = torch.load('model_saves/' + name, map_location=device, **load_kwargs)
    assert len(models) >= max_M, name + ' holds ' + str(len(models)) + ' models, need ' + str(max_M)
    Mu_all = torch.empty(max_M, 1, x_test.shape[0], output_dim)
    Var_all = torch.empty(max_M, 1, x_test.shape[0], logvar_dim)
    for j in range(max_M):
        models[j].device = device
        Mu_all[j], Var_all[j], _ = models[j].validate(loader_test, loss=False)
    for M in range(1, max_M + 1):
        Mu_ens, Var_ens = Mu_all[:M], Var_all[:M]
        Mu_test = Mu_ens.mean(dim=0).squeeze(dim=0)
        # Gaussian-mixture variance in two-pass form (equal to E[var + mu^2] - mu^2, never negative)
        Var_test = (Var_ens + (Mu_ens - Mu_test)**2).mean(dim=0).squeeze(dim=0)
        rmse_per_M[M - 1].append(F.mse_loss(Mu_test, target).sqrt().item())
        nll_per_M[M - 1].append(models[M - 1].nll_loss(Mu_test, torch.log(Var_test), target).item())

# Statistics over seeds for each ensemble size (std kept for reporting; only the means are plotted)
rmse_plot_mean = [np.mean(v) for v in rmse_per_M]; rmse_plot_std = [np.std(v) for v in rmse_per_M]
nll_plot_mean = [np.mean(v) for v in nll_per_M]; nll_plot_std = [np.std(v) for v in nll_per_M]

## Plot
# The 'seaborn-*' styles were renamed 'seaborn-v0_8-*' in matplotlib 3.6 and the old names later removed.
style = 'seaborn-v0_8-darkgrid' if 'seaborn-v0_8-darkgrid' in plt.style.available else 'seaborn-darkgrid'
plt.style.use(style)
fig = plt.figure(figsize=(7,3))
ax1 = fig.add_subplot()
plot1 = ax1.plot(rmse_plot_mean, marker='o', markeredgecolor='C0', color='C0', label='RMSE')
ax1.tick_params(axis='y', labelcolor='C0')
ax1.set_xlabel('Number of models in the ensemble, $M$', fontsize=14)
# x position i is ensemble size M = i + 1: one tick per M (for the grid), labels on even M only.
# Exactly one label per tick; recent matplotlib raises ValueError on a count mismatch.
ax1.set_xticks(list(range(max_M)))
ax1.set_xticklabels([str(i + 1) if (i + 1) % 2 == 0 else '' for i in range(max_M)])
ax2 = ax1.twinx()
plot2 = ax2.plot(nll_plot_mean, marker='^', markeredgecolor='C1', color='C1', label='NLL')
ax2.tick_params(axis='y', labelcolor='C1')
plots = plot1+plot2
labels = [p.get_label() for p in plots]
ax1.legend(plots, labels, fontsize=13, loc='best')
ax1.set_yticks(np.round(np.linspace(ax1.get_ybound()[0], ax1.get_ybound()[1], 5), decimals=4))
ax2.set_yticks(np.round(np.linspace(ax2.get_ybound()[0], ax2.get_ybound()[1], 5), decimals=4))
ax1.tick_params(axis='both', labelsize=13)
ax2.tick_params(axis='y', labelsize=13)

plt.show()

In [None]:
## Figure: actual vs. predicted outputs with +-2 sigma bands, from the first 10 models of one saved ensemble.
import inspect

# Separate name so this cell does not overwrite the config's `num_ensemble` (re-running the training
# cell afterwards would otherwise index past `bs` / `hidden_dim`).
num_plot_ensemble = 10
# Whole pickled models: opt out of torch >= 2.6's weights_only default where supported (trusted files only).
load_kwargs = {'weights_only': False} if 'weights_only' in inspect.signature(torch.load).parameters else {}
# Load the checkpoint once and map it onto the active device (it was reloaded per model, without map_location).
plot_models = torch.load('model_saves/lstm-15-nll-gmm-seed0', map_location=device, **load_kwargs)
Mu_test_ensemble = torch.empty(num_plot_ensemble, 1, x_test.shape[0], output_dim)
Var_test_ensemble = torch.empty(num_plot_ensemble, 1, x_test.shape[0], logvar_dim)
for j in range(num_plot_ensemble):
    plot_models[j].device = device
    Mu_test_ensemble[j], Var_test_ensemble[j], _ = plot_models[j].validate(loader_test, loss=False)
Mu_test = Mu_test_ensemble.mean(dim=0).squeeze(dim=0)
# Gaussian-mixture variance in two-pass form (equal to E[var + mu^2] - mu^2, never negative before sqrt)
Var_test = (Var_test_ensemble + (Mu_test_ensemble - Mu_test)**2).mean(dim=0).squeeze(dim=0)
Std_test = Var_test.sqrt()
Mu_denorm, Std_denorm = denormalize_data(Mu_test.numpy(), norm_param, norm_method, Std_test.numpy())
colorTrue = 'red'; colorHat = 'blue'; colorFill = 'cornflowerblue'; alpha = 0.5
lower = (Mu_denorm - 2*Std_denorm).squeeze()
upper = (Mu_denorm + 2*Std_denorm).squeeze()
N = range(9000,11000)  # window of test samples to plot (e.g. range(8000,12000) for a wider view)

# The 'seaborn-*' styles were renamed 'seaborn-v0_8-*' in matplotlib 3.6 and the old names later removed.
style = 'seaborn-v0_8-darkgrid' if 'seaborn-v0_8-darkgrid' in plt.style.available else 'seaborn-darkgrid'
plt.style.use(style)
fig = plt.figure(figsize=(10,10))
gs = fig.add_gridspec(4,1)
fig.subplots_adjust(hspace=0.05)

y_true = Y_Test.squeeze()

def _plot_panel(ax, idx, ylabel, ylabel_size=17):
    """Draw the actual value, the predicted mean and the +-2 sigma band of output column `idx` over window `N`."""
    ax.plot(y_true[N, idx], color=colorTrue)
    ax.plot(Mu_denorm[N, idx], color=colorHat)
    ax.fill_between(range(len(N)), lower[N, idx], upper[N, idx], alpha=alpha, color=colorFill)
    ax.set_ylabel(ylabel, fontsize=ylabel_size)

# Output columns: 0 = force magnitude, 1 = force location, 2..9 = x_1..x_8, 10..17 = y_1..y_8.
ax1 = fig.add_subplot(gs[0, :])
_plot_panel(ax1, 9, '$x_8$ (m)')
ax2 = fig.add_subplot(gs[1, :], sharex=ax1)
_plot_panel(ax2, -1, '$y_8$ (m)')
ax2.legend(['Actual', 'Predict', r'$\pm 2\sigma$'], fontsize=14)
ax3 = fig.add_subplot(gs[2, :], sharex=ax1)
_plot_panel(ax3, 0, 'Force Mag. (g)', ylabel_size=16)
ax4 = fig.add_subplot(gs[3, :], sharex=ax1)
_plot_panel(ax4, 1, 'Force Loc. (m)')
ax4.set_xlabel('Time (s)', fontsize=17)

# A tick every 500 samples, labelled in seconds (the original labels 0, 50, ... at positions 0, 500, ...
# imply 10 samples per second). Derived from the window length instead of hard-coded up to 4000, which
# matched the old range(8000,12000) window rather than the plotted one.
tick_pos = list(range(0, len(N) + 1, 500))
ax1.set_xticks(tick_pos)
ax1.set_xticklabels([t // 10 for t in tick_pos])
for ax in (ax1, ax2, ax3):
    ax.tick_params(axis='x', labelbottom=False)  # only the bottom panel shows time labels

ax2.set_yticks(np.round(np.linspace(ax2.get_ybound()[0], ax2.get_ybound()[1], 4), decimals=3))
for ax in (ax1, ax2, ax3):
    ax.tick_params(axis='y', labelsize=16)
ax4.tick_params(axis='both', labelsize=16)

fig.align_ylabels()
plt.show()