In [48]:
import os
import argparse
import logging
import time
import numpy as np
import numpy.random as npr
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchdiffeq import odeint_adjoint as odeint
from torch.utils.data import DataLoader
os.environ['KMP_DUPLICATE_LIB_OK']='True'

In [49]:
samp_trajs_TE = torch.load('samp_trajs_TE_tau4k6_25.pt')
samp_trajs_val_TE = torch.load('samp_trajs_val_TE_tau4k6_25.pt')
tau = 4
k = 6
mesured_dim = 12

trial_num = 16

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

batch = 1000 #for lstm256

ts_num = 0.33
tot_num = 25

samp_ts = np.linspace(0, ts_num, num=tot_num)
samp_ts = torch.from_numpy(samp_ts).float().to(device)

orig_trajs_TE = np.load('orig_trajs_TE_Stereo_Stim_tau4k6.npy')
samp_trajs_TE_test = orig_trajs_TE[:, :tot_num, :]
samp_trajs_TE_test = torch.from_numpy(samp_trajs_TE_test).float().to(device).reshape(trial_num, tot_num, mesured_dim*(k+1))

#Load to Dataloader
train_loader = DataLoader(dataset = samp_trajs_TE, batch_size = batch, shuffle = True, drop_last = True)
val_loader = DataLoader(dataset = samp_trajs_val_TE, batch_size = batch, shuffle = True, drop_last = True)

In [50]:
if not os.path.exists('model'):
           os.makedirs('model')
        
class LatentODEfunc(nn.Module):

    def __init__(self, latent_dim=8, nhidden=50):
        super(LatentODEfunc, self).__init__()
        #self.tanh = nn.ELU(inplace= True)
        self.tanh = nn.Tanh()
        self.fc1 = nn.Linear(latent_dim, nhidden)
        self.fc2 = nn.Linear(nhidden, nhidden)
        self.fc3 = nn.Linear(nhidden, latent_dim)
        self.nfe = 0

    def forward(self, t, x):
        self.nfe += 1
        out = self.fc1(x)
        out = self.tanh(out)
        out = self.fc2(out)
        out = self.tanh(out)
        out = self.fc3(out)
        return out

class RecognitionRNN(nn.Module):

    def __init__(self, latent_dim=8, obs_dim=46, nhidden=50, nbatch=1):
        super(RecognitionRNN, self).__init__()
        self.nhidden = nhidden
        self.nbatch = nbatch
        #self.h1o = nn.Linear(obs_dim, 8)
        self.h1o = nn.Linear(obs_dim, 36)
        self.h3o = nn.Linear(36, latent_dim*2)
        self.lstm = nn.LSTMCell(latent_dim*2, nhidden)
        self.tanh = nn.Tanh()
        self.h2o = nn.Linear(nhidden, latent_dim*2)

    def forward(self, x, h, c):
        xo = self.h1o(x)
        xo = self.tanh(xo)
        xxo = self.h3o(xo)
        hn, cn = self.lstm(xxo, (h,c))
        hn = self.tanh(hn)
        out = self.h2o(hn)
        return out, hn, cn
    

    def initHidden(self):
        return torch.zeros(1, self.nbatch, self.nhidden)


class Decoder(nn.Module):

    def __init__(self, latent_dim=8, obs_dim=46, nhidden=50):
        super(Decoder, self).__init__()
        self.relu = nn.ReLU(inplace=True)
        self.tanh = nn.Tanh()
        self.fc1 = nn.Linear(latent_dim, nhidden)
        self.fc2 = nn.Linear(nhidden, nhidden*2)
        self.fc3 = nn.Linear(nhidden*2, obs_dim)

    def forward(self, z):
        out = self.fc1(z)
        out = self.tanh(out)
        out = self.fc2(out)
        out = self.tanh(out)
        out = self.fc3(out)
        return out


class RunningAverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, momentum=0.99):
        self.momentum = momentum
        self.reset()

    def reset(self):
        self.val = None
        self.avg = 0

    def update(self, val):
        if self.val is None:
            self.avg = val
        else:
            self.avg = self.avg * self.momentum + val * (1 - self.momentum)
        self.val = val


def log_normal_pdf(x, mean, logvar):
    const = torch.from_numpy(np.array([2. * np.pi])).float().to(x.device)
    const = torch.log(const)
    return -.5 * (const + logvar + (x - mean) ** 2. / torch.exp(logvar))

def mseloss(x, mean):
    loss = nn.MSELoss()
    return loss(x, mean)

def normal_kl(mu1, lv1, mu2, lv2):
    v1 = torch.exp(lv1)
    v2 = torch.exp(lv2)
    lstd1 = lv1 / 2.
    lstd2 = lv2 / 2.

    kl = lstd2 - lstd1 + ((v1 + (mu1 - mu2) ** 2.) / (2. * v2)) - .5
    return kl

def MSELoss(yhat, y):
    assert type(yhat) == torch.Tensor
    assert type(y) == torch.Tensor
    return torch.mean((yhat - y) ** 2)


def get_args():
    return {'latent_dim': latent_dim,
            'obs_dim': obs_dim,
            'nhidden': nhidden,
            'dec_nhidden' : dec_nhidden,
            'rnn_nhidden': rnn_nhidden,
            'device': device,
            'learning_rate': learning_rate,
            'tau': tau,
            'k': k}

def get_state_dicts():
    return {'odefunc_state_dict': func.state_dict(),
            'encoder_state_dict': rec.state_dict(),
            'decoder_state_dict': dec.state_dict()}

def data_get_dict():
    return {
        'samp_trajs_TE': samp_trajs_TE,
        'samp_trajs_val_TE': samp_trajs_val_TE,
        'samp_ts': samp_ts,
    }

def get_losses():
    return {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'val_losses_k1': val_losses_k1,
        'val_losses_k2': val_losses_k2,
        'val_losses_k3': val_losses_k3,
        'val_losses_k4': val_losses_k4,
        'val_losses_k5': val_losses_k5,
        'val_losses_k6': val_losses_k6,
        'val_losses_k7': val_losses_k7,
        'val_losses_k8': val_losses_k8,
        'val_losses_k9': val_losses_k9,
    }

def save_model(Training_Trial, rnn_nhidden, tau, k, lr, latent_dim, itr):
    if not os.path.exists('/model'):
        os.makedirs('/model')
    save_dict = {
        'model_args': get_args(),
        'optimizer_state_dict': optimizer.state_dict(),
        #'data': data_get_dict(),
        'train_loss': get_losses()
    }
    
    save_dict.update(get_state_dicts())
    
    torch.save(save_dict, 'model/ODE_Xcoord_Trial{}_TakenEmbedding_rnn2_lstm{}_tau{}k{}_LSTM_lr{}_latent{}_LSTMautoencoder_Dataloader_epoch{}.pth'.format(Training_Trial, rnn_nhidden, tau, k, lr, latent_dim, itr))

    
def data_for_plot_graph(gen_index):
    with torch.no_grad():
        # sample from trajectorys' approx. posterior

        ts_pos = np.linspace(0, ts_num*gen_index, num=tot_num*gen_index)
        ts_pos = torch.from_numpy(ts_pos).float().to(device)
    
        h = torch.zeros(1, samp_trajs_TE_test.shape[0], rnn_nhidden).to(device)
        c = torch.zeros(1, samp_trajs_TE_test.shape[0], rnn_nhidden).to(device)
    
        hn = h[0, :, :]
        cn = c[0, :, :]
    
        for t in reversed(range(samp_trajs_TE_test.size(1))):
            obs = samp_trajs_TE_test[:, t, :]
            out, hn, cn = rec.forward(obs, hn, cn)
        qz0_mean, qz0_logvar = out[:, :latent_dim], out[:, latent_dim:]
        epsilon = torch.randn(qz0_mean.size()).to(device)
        z0 = epsilon * torch.exp(.5 * qz0_logvar) + qz0_mean

        # forward in time and solve ode for reconstructions
        pred_z = odeint(func, z0, ts_pos).permute(1, 0, 2) #change time and batch with permute
        pred_x = dec(pred_z)
        
        return pred_x, pred_z
    
def plot_graph(gen_index, times_index, dataset_value, deriv_index, pred_x_forgraph, orig_trajs, itr, path):
    with torch.no_grad():
        orig_trajs_forgraph = orig_trajs
        ts_pos_combined = np.linspace(0, ts_num*gen_index, num=tot_num*gen_index) 
        
        fig, axes = plt.subplots(nrows=4, ncols=3, figsize=(15, 9)) #####MAKE SURE ROW COL MATCHES THE NUM OF FEATURES
        axes = axes.flatten()
        
        for i, ax in enumerate(axes):
            ax.scatter(ts_pos_combined[times_index:times_index+tot_num*gen_index], orig_trajs_forgraph[dataset_value,times_index:tot_num*gen_index, i*(k+1)+deriv_index], label='sampled data', s = 5)
            ax.plot(ts_pos_combined[times_index:times_index+tot_num*gen_index], pred_x_forgraph[dataset_value, times_index:times_index+tot_num*gen_index, i*(k+1)+deriv_index], 'r',
                 label='learned trajectory (t>0)')
            ax.set_ylim(-2.5, 2.5)

        plt.legend()
        plot_name = 'lstm_datasetnum{}_latent{}_gen{}_deriv{}_epoch{}.png'.format(dataset_value, latent_dim, gen_index, deriv_index, itr)
        save_path = os.path.join(path, plot_name)
        plt.savefig(save_path, dpi=500)
        plt.close()
    

def plot_z_graph(gen_index, times_index, dataset_value, deriv_index, pred_z_forgraph, orig_trajs, itr, path):
    with torch.no_grad():
        orig_trajs_forgraph = orig_trajs
        out, hn, cn = rec.forward(orig_trajs)
        qz_mean, qz_logvar = out[:, :latent_dim], out[:, latent_dim:]
        epsilon = torch.randn(qz_mean.size()).to(device)
        z = epsilon * torch.exp(.5 * qz_logvar) + qz_mean
        
        z_forgraph = z.detach().cpu().numpy()
        ts_pos_combined = np.linspace(0, ts_num*gen_index, num=tot_num*gen_index) 
        
        fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(15, 9))
        axes = axes.flatten()
        
        for i, ax in enumerate(axes):
            ax.scatter(ts_pos_combined[times_index:50*gen_index], z_forgraph[dataset_value, 0:50*gen_index, i], label='sampled data', s = 5)
            ax.plot(ts_pos_combined[times_index:+50*gen_index], pred_z_forgraph[dataset_value, times_index:+50*gen_index, i], 'r',
                 label='learned trajectory (t>0)')
            ax.set_ylim(-2.5, 2.5)

        plt.legend()
        plot_name = 'Zgraph_lstm_datasetnum{}_latent{}_gen{}_deriv{}_epoch{}.png'.format(dataset_value, latent_dim, gen_index, deriv_index, itr)
        save_path = os.path.join(path, plot_name)
        plt.savefig(save_path, dpi=500)
        plt.close()

In [51]:
Training_Trial = 1
latent_dim = 8
nhidden = 64 ##Trial1 = 64, Trial2 = 128, Trial3 = 128, Trial4 = 64, Trial5 = 64
dec_nhidden = 32
obs_dim = 12*(k+1)
rnn_nhidden = 256
nitrs = 1000
noise_std = 0.2
learning_rate = 0.008

func = LatentODEfunc(latent_dim, nhidden).to(device)
rec = RecognitionRNN(latent_dim, obs_dim, rnn_nhidden, batch).to(device)
dec = Decoder(latent_dim, obs_dim, dec_nhidden).to(device)
params = (list(func.parameters()) + list(dec.parameters()) + list(rec.parameters()))
optimizer = optim.Adam(params, lr=learning_rate)
loss_meter = RunningAverageMeter()

train_losses = []
val_losses = []
val_losses_k1 = []
val_losses_k2 = []
val_losses_k3 = []
val_losses_k4 = []
val_losses_k5 = []
val_losses_k6 = []
val_losses_k7 = []
val_losses_k8 = []
val_losses_k9 = []
torch.cuda.empty_cache()

In [52]:
for itr in range(1, nitrs + 1):
    for data in train_loader:
        optimizer.zero_grad()
        h = rec.initHidden().to(device)
        c = rec.initHidden().to(device)
        hn = h[0, :, :]
        cn = c[0, :, :]
        for t in reversed(range(data.size(1))):
            obs = data[:, t, :]
            out, hn, cn = rec.forward(obs, hn, cn)
        qz0_mean, qz0_logvar = out[:, :latent_dim], out[:, latent_dim:]
        epsilon = torch.randn(qz0_mean.size()).to(device)
        z0 = epsilon * torch.exp(.5 * qz0_logvar) + qz0_mean   
        
        # forward in time and solve ode for reconstructions
        pred_z = odeint(func, z0, samp_ts).permute(1, 0, 2)
        pred_x = dec(pred_z)

        # compute loss
        loss = MSELoss(pred_x, data)
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())
        
    with torch.no_grad():
        for data in val_loader:
            h = torch.zeros(1, batch, rnn_nhidden).to(device)
            c = torch.zeros(1, batch, rnn_nhidden).to(device)
            hn = h[0, :, :]
            cn = c[0, :, :]
            
            for t in reversed(range(data.size(1))):
                obs = data[:, t, :]
                out, hn, cn = rec.forward(obs, hn, cn)
            qz0_mean, qz0_logvar = out[:, :latent_dim], out[:, latent_dim:]
            epsilon = torch.randn(qz0_mean.size()).to(device)
            z0 = epsilon * torch.exp(.5 * qz0_logvar) + qz0_mean

            #forward in time and solve ode for reconstructions
            pred_z = odeint(func, z0, samp_ts).permute(1, 0, 2)
            pred_x = dec(pred_z)
        
            #val_loss = MSELoss(pred_x, samp_trajs_val_TE) + torch.mean(-0.5 * torch.sum(1 + qz0_logvar - qz0_mean**2 - torch.exp(qz0_logvar), dim = -1)/(31*(k+1)+2))
            val_loss = MSELoss(pred_x[:,:,::(k+1)], data[:,:,::(k+1)])
            val_loss_k1 = MSELoss(pred_x[:,:,1::(k+1)], data[:,:,1::(k+1)])
            val_loss_k2 = MSELoss(pred_x[:,:,2::(k+1)], data[:,:,2::(k+1)])
            val_loss_k3 = MSELoss(pred_x[:,:,3::(k+1)], data[:,:,3::(k+1)])
            val_loss_k4 = MSELoss(pred_x[:,:,4::(k+1)], data[:,:,4::(k+1)])
            val_loss_k5 = MSELoss(pred_x[:,:,5::(k+1)], data[:,:,5::(k+1)])
            val_loss_k6 = MSELoss(pred_x[:,:,6::(k+1)], data[:,:,6::(k+1)])

            val_losses.append(val_loss)
            val_losses_k1.append(val_loss_k1)
            val_losses_k2.append(val_loss_k2)
            val_losses_k3.append(val_loss_k3)
            val_losses_k4.append(val_loss_k4)
            val_losses_k5.append(val_loss_k5)
            val_losses_k6.append(val_loss_k6)

            V = [val_loss, val_loss_k1, val_loss_k2, val_loss_k3, val_loss_k4, val_loss_k5, val_loss_k6]
            lowest_val_loss = torch.asarray(V).min(0)[0]
            deriv_index = torch.asarray(V).min(0)[1]

    if ((itr > 100) and (itr % 10 == 0)):
        save_model(Training_Trial, rnn_nhidden, tau, k, learning_rate, latent_dim, itr)
        tot_index = 40
        times_index = 0
        deriv_index = deriv_index.numpy()
        
        orig_trajs = orig_trajs_TE[:, 0:0+tot_num*tot_index, :]

        pred_x, pred_z = data_for_plot_graph(tot_index)
        pred_x = pred_x.reshape(trial_num, tot_num*tot_index, mesured_dim*(k+1))
        pred_z = pred_z.reshape(trial_num, tot_num*tot_index, latent_dim)
        pred_x_forgraph = pred_x.detach().cpu().numpy()
        pred_z_forgraph = pred_z.detach().cpu().numpy()

        path = "Results_pic/tau{}k{}/latent{}/data_loader_rnn2layer_lstm{}_lr{}_Trial{}/epoch{}".format(tau, k, latent_dim, rnn_nhidden, learning_rate, Training_Trial, itr)

        if not os.path.exists(path):
           os.makedirs(path)

        plotgraph_index = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]

        gen_index = 40
        
        for i in range(len(plotgraph_index)):
            plot_graph(gen_index, times_index, plotgraph_index[i], deriv_index, pred_x_forgraph, orig_trajs, itr, path)
        
    print('Iter: {}, running avg mse: {:.4f} lowest val mse: {:.4f} at k {}'.format(itr, loss, lowest_val_loss, deriv_index))

Iter: 1, running avg mse: 1.0043 lowest val mse: 0.9891 at k 5
Iter: 2, running avg mse: 0.9792 lowest val mse: 0.9586 at k 5
Iter: 3, running avg mse: 0.9648 lowest val mse: 0.9479 at k 4
Iter: 4, running avg mse: 0.9475 lowest val mse: 0.9408 at k 4
Iter: 5, running avg mse: 0.9398 lowest val mse: 0.9364 at k 4
Iter: 6, running avg mse: 0.9351 lowest val mse: 0.9368 at k 4
Iter: 7, running avg mse: 0.9410 lowest val mse: 0.9347 at k 4
Iter: 8, running avg mse: 0.9300 lowest val mse: 0.9320 at k 4
Iter: 9, running avg mse: 0.9285 lowest val mse: 0.9242 at k 4
Iter: 10, running avg mse: 0.9073 lowest val mse: 0.9071 at k 4
Iter: 11, running avg mse: 0.9107 lowest val mse: 0.9014 at k 4
Iter: 12, running avg mse: 0.9036 lowest val mse: 0.8997 at k 4
Iter: 13, running avg mse: 0.9131 lowest val mse: 0.8960 at k 4
Iter: 14, running avg mse: 0.8930 lowest val mse: 0.8900 at k 4
Iter: 15, running avg mse: 0.8900 lowest val mse: 0.8846 at k 4
Iter: 16, running avg mse: 0.8909 lowest val mse:

Iter: 129, running avg mse: 0.3968 lowest val mse: 0.3603 at k 3
Iter: 130, running avg mse: 0.3941 lowest val mse: 0.3597 at k 3
Iter: 131, running avg mse: 0.3884 lowest val mse: 0.3574 at k 3
Iter: 132, running avg mse: 0.3923 lowest val mse: 0.3569 at k 3
Iter: 133, running avg mse: 0.3889 lowest val mse: 0.3560 at k 3
Iter: 134, running avg mse: 0.3905 lowest val mse: 0.3635 at k 3
Iter: 135, running avg mse: 0.4073 lowest val mse: 0.3719 at k 3
Iter: 136, running avg mse: 0.3931 lowest val mse: 0.3696 at k 3
Iter: 137, running avg mse: 0.4015 lowest val mse: 0.3615 at k 3
Iter: 138, running avg mse: 0.3921 lowest val mse: 0.3599 at k 3
Iter: 139, running avg mse: 0.3898 lowest val mse: 0.3567 at k 3
Iter: 140, running avg mse: 0.3908 lowest val mse: 0.3554 at k 3
Iter: 141, running avg mse: 0.3824 lowest val mse: 0.3555 at k 3
Iter: 142, running avg mse: 0.3872 lowest val mse: 0.3540 at k 3
Iter: 143, running avg mse: 0.3888 lowest val mse: 0.3525 at k 3
Iter: 144, running avg ms

Iter: 256, running avg mse: 0.3608 lowest val mse: 0.3383 at k 3
Iter: 257, running avg mse: 0.3561 lowest val mse: 0.3391 at k 3
Iter: 258, running avg mse: 0.3602 lowest val mse: 0.3368 at k 4
Iter: 259, running avg mse: 0.3626 lowest val mse: 0.3349 at k 3
Iter: 260, running avg mse: 0.3612 lowest val mse: 0.3340 at k 3
Iter: 261, running avg mse: 0.3569 lowest val mse: 0.3316 at k 3
Iter: 262, running avg mse: 0.3591 lowest val mse: 0.3304 at k 3
Iter: 263, running avg mse: 0.3619 lowest val mse: 0.3307 at k 3
Iter: 264, running avg mse: 0.3560 lowest val mse: 0.3320 at k 3
Iter: 265, running avg mse: 0.3540 lowest val mse: 0.3307 at k 3
Iter: 266, running avg mse: 0.3545 lowest val mse: 0.3311 at k 3
Iter: 267, running avg mse: 0.3537 lowest val mse: 0.3298 at k 3
Iter: 268, running avg mse: 0.3510 lowest val mse: 0.3304 at k 3
Iter: 269, running avg mse: 0.3552 lowest val mse: 0.3376 at k 3
Iter: 270, running avg mse: 0.3653 lowest val mse: 0.3392 at k 3
Iter: 271, running avg ms

Iter: 383, running avg mse: 0.3393 lowest val mse: 0.3277 at k 3
Iter: 384, running avg mse: 0.3507 lowest val mse: 0.3387 at k 3
Iter: 385, running avg mse: 0.3369 lowest val mse: 0.3346 at k 3
Iter: 386, running avg mse: 0.3461 lowest val mse: 0.3276 at k 3
Iter: 387, running avg mse: 0.3424 lowest val mse: 0.3311 at k 3
Iter: 388, running avg mse: 0.3448 lowest val mse: 0.3256 at k 3
Iter: 389, running avg mse: 0.3370 lowest val mse: 0.3265 at k 3
Iter: 390, running avg mse: 0.3503 lowest val mse: 0.3245 at k 3
Iter: 391, running avg mse: 0.3429 lowest val mse: 0.3249 at k 3
Iter: 392, running avg mse: 0.3319 lowest val mse: 0.3234 at k 3
Iter: 393, running avg mse: 0.3402 lowest val mse: 0.3241 at k 3
Iter: 394, running avg mse: 0.3374 lowest val mse: 0.3248 at k 3
Iter: 395, running avg mse: 0.3304 lowest val mse: 0.3230 at k 3
Iter: 396, running avg mse: 0.3413 lowest val mse: 0.3258 at k 3
Iter: 397, running avg mse: 0.3387 lowest val mse: 0.3263 at k 3
Iter: 398, running avg ms

Iter: 510, running avg mse: 0.3313 lowest val mse: 0.3211 at k 3
Iter: 511, running avg mse: 0.3362 lowest val mse: 0.3221 at k 3
Iter: 512, running avg mse: 0.3293 lowest val mse: 0.3247 at k 3
Iter: 513, running avg mse: 0.3493 lowest val mse: 0.3309 at k 3
Iter: 514, running avg mse: 0.3344 lowest val mse: 0.3247 at k 3
Iter: 515, running avg mse: 0.3291 lowest val mse: 0.3233 at k 3
Iter: 516, running avg mse: 0.3371 lowest val mse: 0.3219 at k 3
Iter: 517, running avg mse: 0.3352 lowest val mse: 0.3210 at k 3
Iter: 518, running avg mse: 0.3328 lowest val mse: 0.3245 at k 3
Iter: 519, running avg mse: 0.3422 lowest val mse: 0.3246 at k 3
Iter: 520, running avg mse: 0.3341 lowest val mse: 0.3208 at k 3
Iter: 521, running avg mse: 0.3272 lowest val mse: 0.3203 at k 3
Iter: 522, running avg mse: 0.3272 lowest val mse: 0.3198 at k 3
Iter: 523, running avg mse: 0.3300 lowest val mse: 0.3199 at k 3
Iter: 524, running avg mse: 0.3294 lowest val mse: 0.3211 at k 3
Iter: 525, running avg ms

Iter: 637, running avg mse: 0.3424 lowest val mse: 0.3221 at k 3
Iter: 638, running avg mse: 0.3384 lowest val mse: 0.3231 at k 3
Iter: 639, running avg mse: 0.3318 lowest val mse: 0.3239 at k 3
Iter: 640, running avg mse: 0.3410 lowest val mse: 0.3210 at k 3
Iter: 641, running avg mse: 0.3357 lowest val mse: 0.3220 at k 3
Iter: 642, running avg mse: 0.3299 lowest val mse: 0.3207 at k 3
Iter: 643, running avg mse: 0.3293 lowest val mse: 0.3220 at k 3
Iter: 644, running avg mse: 0.3342 lowest val mse: 0.3204 at k 3
Iter: 645, running avg mse: 0.3342 lowest val mse: 0.3204 at k 3
Iter: 646, running avg mse: 0.3284 lowest val mse: 0.3203 at k 3
Iter: 647, running avg mse: 0.3343 lowest val mse: 0.3195 at k 3
Iter: 648, running avg mse: 0.3360 lowest val mse: 0.3188 at k 3
Iter: 649, running avg mse: 0.3361 lowest val mse: 0.3195 at k 3
Iter: 650, running avg mse: 0.3310 lowest val mse: 0.3191 at k 3
Iter: 651, running avg mse: 0.3287 lowest val mse: 0.3198 at k 3
Iter: 652, running avg ms

Iter: 764, running avg mse: 0.3307 lowest val mse: 0.3221 at k 3
Iter: 765, running avg mse: 0.3370 lowest val mse: 0.3297 at k 3
Iter: 766, running avg mse: 0.3425 lowest val mse: 0.3292 at k 3
Iter: 767, running avg mse: 0.3426 lowest val mse: 0.3265 at k 3
Iter: 768, running avg mse: 0.3400 lowest val mse: 0.3256 at k 3
Iter: 769, running avg mse: 0.3407 lowest val mse: 0.3246 at k 3
Iter: 770, running avg mse: 0.3296 lowest val mse: 0.3267 at k 3
Iter: 771, running avg mse: 0.3353 lowest val mse: 0.3220 at k 3
Iter: 772, running avg mse: 0.3361 lowest val mse: 0.3274 at k 3
Iter: 773, running avg mse: 0.3340 lowest val mse: 0.3245 at k 3
Iter: 774, running avg mse: 0.3338 lowest val mse: 0.3237 at k 3
Iter: 775, running avg mse: 0.3314 lowest val mse: 0.3247 at k 3
Iter: 776, running avg mse: 0.3369 lowest val mse: 0.3270 at k 3
Iter: 777, running avg mse: 0.3468 lowest val mse: 0.3275 at k 3
Iter: 778, running avg mse: 0.3475 lowest val mse: 0.3298 at k 4
Iter: 779, running avg ms

Iter: 891, running avg mse: 0.3275 lowest val mse: 0.3166 at k 3
Iter: 892, running avg mse: 0.3274 lowest val mse: 0.3165 at k 3
Iter: 893, running avg mse: 0.3222 lowest val mse: 0.3162 at k 3
Iter: 894, running avg mse: 0.3236 lowest val mse: 0.3175 at k 3
Iter: 895, running avg mse: 0.3253 lowest val mse: 0.3163 at k 3
Iter: 896, running avg mse: 0.3282 lowest val mse: 0.3159 at k 3
Iter: 897, running avg mse: 0.3287 lowest val mse: 0.3174 at k 3
Iter: 898, running avg mse: 0.3285 lowest val mse: 0.3166 at k 3
Iter: 899, running avg mse: 0.3202 lowest val mse: 0.3164 at k 3
Iter: 900, running avg mse: 0.3189 lowest val mse: 0.3169 at k 3
Iter: 901, running avg mse: 0.3215 lowest val mse: 0.3166 at k 3
Iter: 902, running avg mse: 0.3250 lowest val mse: 0.3172 at k 3
Iter: 903, running avg mse: 0.3232 lowest val mse: 0.3206 at k 3
Iter: 904, running avg mse: 0.3260 lowest val mse: 0.3191 at k 3
Iter: 905, running avg mse: 0.3263 lowest val mse: 0.3173 at k 3
Iter: 906, running avg ms

In [53]:
Trial_tot_num = 10

In [None]:
for p in range(3, Trial_tot_num):
    Training_Trial = p
    
    latent_dim = 8
    nhidden = 64 ##Trial1 = 64, Trial2 = 128, Trial3 = 128, Trial4 = 64, Trial5 = 64
    dec_nhidden = 32
    obs_dim = 12*(k+1)
    rnn_nhidden = 256
    nitrs = 600
    noise_std = 0.2
    learning_rate = 0.008

    func = LatentODEfunc(latent_dim, nhidden).to(device)
    rec = RecognitionRNN(latent_dim, obs_dim, rnn_nhidden, batch).to(device)
    dec = Decoder(latent_dim, obs_dim, dec_nhidden).to(device)
    params = (list(func.parameters()) + list(dec.parameters()) + list(rec.parameters()))
    optimizer = optim.Adam(params, lr=learning_rate)
    loss_meter = RunningAverageMeter()

    train_losses = []
    val_losses = []
    val_losses_k1 = []
    val_losses_k2 = []
    val_losses_k3 = []
    val_losses_k4 = []
    val_losses_k5 = []
    val_losses_k6 = []
    val_losses_k7 = []
    val_losses_k8 = []
    val_losses_k9 = []
    torch.cuda.empty_cache()
    for itr in range(1, nitrs + 1):
        for data in train_loader:
            optimizer.zero_grad()
            h = rec.initHidden().to(device)
            c = rec.initHidden().to(device)
            hn = h[0, :, :]
            cn = c[0, :, :]
            for t in reversed(range(data.size(1))):
                obs = data[:, t, :]
                out, hn, cn = rec.forward(obs, hn, cn)
            qz0_mean, qz0_logvar = out[:, :latent_dim], out[:, latent_dim:]
            epsilon = torch.randn(qz0_mean.size()).to(device)
            z0 = epsilon * torch.exp(.5 * qz0_logvar) + qz0_mean   

            # forward in time and solve ode for reconstructions
            pred_z = odeint(func, z0, samp_ts).permute(1, 0, 2)
            pred_x = dec(pred_z)

            # compute loss
            loss = MSELoss(pred_x, data)
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())

        with torch.no_grad():
            for data in val_loader:
                h = torch.zeros(1, batch, rnn_nhidden).to(device)
                c = torch.zeros(1, batch, rnn_nhidden).to(device)
                hn = h[0, :, :]
                cn = c[0, :, :]

                for t in reversed(range(data.size(1))):
                    obs = data[:, t, :]
                    out, hn, cn = rec.forward(obs, hn, cn)
                qz0_mean, qz0_logvar = out[:, :latent_dim], out[:, latent_dim:]
                epsilon = torch.randn(qz0_mean.size()).to(device)
                z0 = epsilon * torch.exp(.5 * qz0_logvar) + qz0_mean

                #forward in time and solve ode for reconstructions
                pred_z = odeint(func, z0, samp_ts).permute(1, 0, 2)
                pred_x = dec(pred_z)

                #val_loss = MSELoss(pred_x, samp_trajs_val_TE) + torch.mean(-0.5 * torch.sum(1 + qz0_logvar - qz0_mean**2 - torch.exp(qz0_logvar), dim = -1)/(31*(k+1)+2))
                val_loss = MSELoss(pred_x[:,:,::(k+1)], data[:,:,::(k+1)])
                val_loss_k1 = MSELoss(pred_x[:,:,1::(k+1)], data[:,:,1::(k+1)])
                val_loss_k2 = MSELoss(pred_x[:,:,2::(k+1)], data[:,:,2::(k+1)])
                val_loss_k3 = MSELoss(pred_x[:,:,3::(k+1)], data[:,:,3::(k+1)])
                val_loss_k4 = MSELoss(pred_x[:,:,4::(k+1)], data[:,:,4::(k+1)])
                val_loss_k5 = MSELoss(pred_x[:,:,5::(k+1)], data[:,:,5::(k+1)])
                val_loss_k6 = MSELoss(pred_x[:,:,6::(k+1)], data[:,:,6::(k+1)])

                val_losses.append(val_loss)
                val_losses_k1.append(val_loss_k1)
                val_losses_k2.append(val_loss_k2)
                val_losses_k3.append(val_loss_k3)
                val_losses_k4.append(val_loss_k4)
                val_losses_k5.append(val_loss_k5)
                val_losses_k6.append(val_loss_k6)

                V = [val_loss, val_loss_k1, val_loss_k2, val_loss_k3, val_loss_k4, val_loss_k5, val_loss_k6]
                lowest_val_loss = torch.asarray(V).min(0)[0]
                deriv_index = torch.asarray(V).min(0)[1]

        if ((itr > 100) and (itr % 10 == 0)):
            save_model(Training_Trial, rnn_nhidden, tau, k, learning_rate, latent_dim, itr)
            tot_index = 40
            times_index = 0
            deriv_index = deriv_index.numpy()

            orig_trajs = orig_trajs_TE[:, 0:0+tot_num*tot_index, :]

            pred_x, pred_z = data_for_plot_graph(tot_index)
            pred_x = pred_x.reshape(trial_num, tot_num*tot_index, mesured_dim*(k+1))
            pred_z = pred_z.reshape(trial_num, tot_num*tot_index, latent_dim)
            pred_x_forgraph = pred_x.detach().cpu().numpy()
            pred_z_forgraph = pred_z.detach().cpu().numpy()

            path = "Results_pic/tau{}k{}/latent{}/data_loader_rnn2layer_lstm{}_lr{}_Trial{}/epoch{}".format(tau, k, latent_dim, rnn_nhidden, learning_rate, Training_Trial, itr)

            if not os.path.exists(path):
               os.makedirs(path)

            plotgraph_index = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]

            gen_index = 40

            for i in range(len(plotgraph_index)):
                plot_graph(gen_index, times_index, plotgraph_index[i], deriv_index, pred_x_forgraph, orig_trajs, itr, path)

        print('Iter: {}, running avg mse: {:.4f} lowest val mse: {:.4f} at k {}'.format(itr, loss, lowest_val_loss, deriv_index))

Iter: 1, running avg mse: 1.0067 lowest val mse: 0.9919 at k 6
Iter: 2, running avg mse: 0.9742 lowest val mse: 0.9663 at k 6
Iter: 3, running avg mse: 0.9682 lowest val mse: 0.9504 at k 5
Iter: 4, running avg mse: 0.9501 lowest val mse: 0.9428 at k 5
Iter: 5, running avg mse: 0.9524 lowest val mse: 0.9405 at k 5
Iter: 6, running avg mse: 0.9459 lowest val mse: 0.9377 at k 4
Iter: 7, running avg mse: 0.9277 lowest val mse: 0.9300 at k 4
Iter: 8, running avg mse: 0.9268 lowest val mse: 0.9212 at k 3
Iter: 9, running avg mse: 0.9160 lowest val mse: 0.9136 at k 4
Iter: 10, running avg mse: 0.9000 lowest val mse: 0.9082 at k 4
Iter: 11, running avg mse: 0.9081 lowest val mse: 0.9047 at k 4
Iter: 12, running avg mse: 0.8953 lowest val mse: 0.9024 at k 4
Iter: 13, running avg mse: 0.8979 lowest val mse: 0.8992 at k 4
Iter: 14, running avg mse: 0.8978 lowest val mse: 0.8967 at k 4
Iter: 15, running avg mse: 0.8981 lowest val mse: 0.8950 at k 4
Iter: 16, running avg mse: 0.9004 lowest val mse:

Iter: 129, running avg mse: 0.3968 lowest val mse: 0.3605 at k 3
Iter: 130, running avg mse: 0.3910 lowest val mse: 0.3607 at k 3
Iter: 131, running avg mse: 0.4014 lowest val mse: 0.3615 at k 3
Iter: 132, running avg mse: 0.3983 lowest val mse: 0.3627 at k 3
Iter: 133, running avg mse: 0.3934 lowest val mse: 0.3605 at k 3
Iter: 134, running avg mse: 0.3866 lowest val mse: 0.3608 at k 3
Iter: 135, running avg mse: 0.3925 lowest val mse: 0.3604 at k 3
Iter: 136, running avg mse: 0.3880 lowest val mse: 0.3594 at k 3
Iter: 137, running avg mse: 0.3776 lowest val mse: 0.3572 at k 3
Iter: 138, running avg mse: 0.3834 lowest val mse: 0.3581 at k 3
Iter: 139, running avg mse: 0.3893 lowest val mse: 0.3582 at k 3
Iter: 140, running avg mse: 0.3840 lowest val mse: 0.3580 at k 3
Iter: 141, running avg mse: 0.3845 lowest val mse: 0.3566 at k 3
Iter: 142, running avg mse: 0.3850 lowest val mse: 0.3553 at k 3
Iter: 143, running avg mse: 0.3810 lowest val mse: 0.3566 at k 3
Iter: 144, running avg ms

Iter: 256, running avg mse: 0.3632 lowest val mse: 0.3401 at k 3
Iter: 257, running avg mse: 0.3611 lowest val mse: 0.3370 at k 3
Iter: 258, running avg mse: 0.3644 lowest val mse: 0.3377 at k 3
Iter: 259, running avg mse: 0.3558 lowest val mse: 0.3345 at k 3
Iter: 260, running avg mse: 0.3590 lowest val mse: 0.3349 at k 3
Iter: 261, running avg mse: 0.3640 lowest val mse: 0.3351 at k 3
Iter: 262, running avg mse: 0.3575 lowest val mse: 0.3338 at k 3
Iter: 263, running avg mse: 0.3597 lowest val mse: 0.3347 at k 3
Iter: 264, running avg mse: 0.3543 lowest val mse: 0.3347 at k 3
Iter: 265, running avg mse: 0.3697 lowest val mse: 0.3392 at k 3
Iter: 266, running avg mse: 0.3565 lowest val mse: 0.3385 at k 3
Iter: 267, running avg mse: 0.3547 lowest val mse: 0.3352 at k 3
Iter: 268, running avg mse: 0.3590 lowest val mse: 0.3380 at k 3
Iter: 269, running avg mse: 0.3512 lowest val mse: 0.3342 at k 3
Iter: 270, running avg mse: 0.3572 lowest val mse: 0.3350 at k 3
Iter: 271, running avg ms

Iter: 383, running avg mse: 0.3480 lowest val mse: 0.3354 at k 3
Iter: 384, running avg mse: 0.3468 lowest val mse: 0.3317 at k 3
Iter: 385, running avg mse: 0.3387 lowest val mse: 0.3298 at k 3
Iter: 386, running avg mse: 0.3366 lowest val mse: 0.3286 at k 3
Iter: 387, running avg mse: 0.3364 lowest val mse: 0.3293 at k 3
Iter: 388, running avg mse: 0.3557 lowest val mse: 0.3305 at k 3
Iter: 389, running avg mse: 0.3441 lowest val mse: 0.3300 at k 3
Iter: 390, running avg mse: 0.3456 lowest val mse: 0.3324 at k 4
Iter: 391, running avg mse: 0.3388 lowest val mse: 0.3360 at k 4
Iter: 392, running avg mse: 0.3710 lowest val mse: 0.3459 at k 4
Iter: 393, running avg mse: 0.3568 lowest val mse: 0.3422 at k 3
Iter: 394, running avg mse: 0.3563 lowest val mse: 0.3336 at k 3
Iter: 395, running avg mse: 0.3617 lowest val mse: 0.3353 at k 3
Iter: 396, running avg mse: 0.3524 lowest val mse: 0.3376 at k 3
Iter: 397, running avg mse: 0.3426 lowest val mse: 0.3329 at k 4
Iter: 398, running avg ms

Iter: 510, running avg mse: 0.3417 lowest val mse: 0.3240 at k 3
Iter: 511, running avg mse: 0.3439 lowest val mse: 0.3238 at k 3
Iter: 512, running avg mse: 0.3355 lowest val mse: 0.3249 at k 3
Iter: 513, running avg mse: 0.3310 lowest val mse: 0.3254 at k 3
Iter: 514, running avg mse: 0.3428 lowest val mse: 0.3247 at k 3
Iter: 515, running avg mse: 0.3330 lowest val mse: 0.3259 at k 3
Iter: 516, running avg mse: 0.3427 lowest val mse: 0.3241 at k 3
Iter: 517, running avg mse: 0.3322 lowest val mse: 0.3246 at k 3
Iter: 518, running avg mse: 0.3347 lowest val mse: 0.3244 at k 3
Iter: 519, running avg mse: 0.3370 lowest val mse: 0.3254 at k 3
Iter: 520, running avg mse: 0.3374 lowest val mse: 0.3239 at k 3
Iter: 521, running avg mse: 0.3370 lowest val mse: 0.3242 at k 3
Iter: 522, running avg mse: 0.3418 lowest val mse: 0.3245 at k 3
Iter: 523, running avg mse: 0.3356 lowest val mse: 0.3238 at k 3
Iter: 524, running avg mse: 0.3383 lowest val mse: 0.3230 at k 3
Iter: 525, running avg ms

Iter: 37, running avg mse: 0.7397 lowest val mse: 0.6521 at k 4
Iter: 38, running avg mse: 0.7639 lowest val mse: 0.6214 at k 4
Iter: 39, running avg mse: 0.8267 lowest val mse: 0.7610 at k 4
Iter: 40, running avg mse: 0.7677 lowest val mse: 0.6610 at k 4
Iter: 41, running avg mse: 0.7025 lowest val mse: 0.6007 at k 4
Iter: 42, running avg mse: 0.6637 lowest val mse: 0.5645 at k 4
Iter: 43, running avg mse: 0.6512 lowest val mse: 0.6161 at k 4
Iter: 44, running avg mse: 0.7176 lowest val mse: 0.6111 at k 4
Iter: 45, running avg mse: 0.6580 lowest val mse: 0.5937 at k 4
Iter: 46, running avg mse: 0.6546 lowest val mse: 0.5618 at k 4
Iter: 47, running avg mse: 0.6524 lowest val mse: 0.5681 at k 4
Iter: 48, running avg mse: 0.6218 lowest val mse: 0.5324 at k 4
Iter: 49, running avg mse: 0.6229 lowest val mse: 0.5179 at k 4
Iter: 50, running avg mse: 0.5905 lowest val mse: 0.5103 at k 4
Iter: 51, running avg mse: 0.5909 lowest val mse: 0.4986 at k 4
Iter: 52, running avg mse: 0.5881 lowest

Iter: 165, running avg mse: 0.3724 lowest val mse: 0.3612 at k 3
Iter: 166, running avg mse: 0.3917 lowest val mse: 0.3600 at k 4
Iter: 167, running avg mse: 0.3852 lowest val mse: 0.3593 at k 3
Iter: 168, running avg mse: 0.3803 lowest val mse: 0.3582 at k 3
Iter: 169, running avg mse: 0.3852 lowest val mse: 0.3588 at k 3
Iter: 170, running avg mse: 0.3910 lowest val mse: 0.3687 at k 4
Iter: 171, running avg mse: 0.3896 lowest val mse: 0.3609 at k 4
Iter: 172, running avg mse: 0.3845 lowest val mse: 0.3597 at k 3
Iter: 173, running avg mse: 0.3806 lowest val mse: 0.3607 at k 4
Iter: 174, running avg mse: 0.3840 lowest val mse: 0.3595 at k 4
Iter: 175, running avg mse: 0.3777 lowest val mse: 0.3578 at k 3
Iter: 176, running avg mse: 0.3776 lowest val mse: 0.3559 at k 4
Iter: 177, running avg mse: 0.3775 lowest val mse: 0.3573 at k 4
Iter: 178, running avg mse: 0.3732 lowest val mse: 0.3555 at k 3
Iter: 179, running avg mse: 0.3775 lowest val mse: 0.3545 at k 3
Iter: 180, running avg ms

Iter: 292, running avg mse: 0.3569 lowest val mse: 0.3357 at k 3
Iter: 293, running avg mse: 0.3602 lowest val mse: 0.3365 at k 3
Iter: 294, running avg mse: 0.3575 lowest val mse: 0.3367 at k 3
Iter: 295, running avg mse: 0.3557 lowest val mse: 0.3356 at k 3
Iter: 296, running avg mse: 0.3567 lowest val mse: 0.3351 at k 3
Iter: 297, running avg mse: 0.3539 lowest val mse: 0.3352 at k 3
Iter: 298, running avg mse: 0.3563 lowest val mse: 0.3368 at k 3
Iter: 299, running avg mse: 0.3604 lowest val mse: 0.3406 at k 3
Iter: 300, running avg mse: 0.3693 lowest val mse: 0.3474 at k 3
Iter: 301, running avg mse: 0.3568 lowest val mse: 0.3402 at k 3
Iter: 302, running avg mse: 0.3518 lowest val mse: 0.3371 at k 3
Iter: 303, running avg mse: 0.3622 lowest val mse: 0.3391 at k 3
Iter: 304, running avg mse: 0.3578 lowest val mse: 0.3391 at k 3
Iter: 305, running avg mse: 0.3539 lowest val mse: 0.3351 at k 3
Iter: 306, running avg mse: 0.3551 lowest val mse: 0.3365 at k 3
Iter: 307, running avg ms

Iter: 419, running avg mse: 0.3599 lowest val mse: 0.3389 at k 3
Iter: 420, running avg mse: 0.3548 lowest val mse: 0.3382 at k 3
Iter: 421, running avg mse: 0.3620 lowest val mse: 0.3371 at k 3
Iter: 422, running avg mse: 0.3575 lowest val mse: 0.3396 at k 4
Iter: 423, running avg mse: 0.3510 lowest val mse: 0.3374 at k 3
Iter: 424, running avg mse: 0.3464 lowest val mse: 0.3366 at k 3
Iter: 425, running avg mse: 0.3495 lowest val mse: 0.3350 at k 3
Iter: 426, running avg mse: 0.3497 lowest val mse: 0.3336 at k 3
Iter: 427, running avg mse: 0.3526 lowest val mse: 0.3343 at k 3
Iter: 428, running avg mse: 0.3470 lowest val mse: 0.3338 at k 3
Iter: 429, running avg mse: 0.3445 lowest val mse: 0.3331 at k 3
Iter: 430, running avg mse: 0.3515 lowest val mse: 0.3330 at k 3
Iter: 431, running avg mse: 0.3501 lowest val mse: 0.3326 at k 3
Iter: 432, running avg mse: 0.3472 lowest val mse: 0.3325 at k 3
Iter: 433, running avg mse: 0.3404 lowest val mse: 0.3320 at k 3
Iter: 434, running avg ms

Iter: 546, running avg mse: 0.3648 lowest val mse: 0.3477 at k 4
Iter: 547, running avg mse: 0.3592 lowest val mse: 0.3482 at k 4
Iter: 548, running avg mse: 0.3719 lowest val mse: 0.3460 at k 4
Iter: 549, running avg mse: 0.3711 lowest val mse: 0.3481 at k 3
Iter: 550, running avg mse: 0.3529 lowest val mse: 0.3445 at k 3
Iter: 551, running avg mse: 0.3521 lowest val mse: 0.3426 at k 3
Iter: 552, running avg mse: 0.3558 lowest val mse: 0.3391 at k 3
Iter: 553, running avg mse: 0.3595 lowest val mse: 0.3372 at k 4
Iter: 554, running avg mse: 0.3591 lowest val mse: 0.3369 at k 3
Iter: 555, running avg mse: 0.3529 lowest val mse: 0.3360 at k 3
Iter: 556, running avg mse: 0.3538 lowest val mse: 0.3342 at k 3
Iter: 557, running avg mse: 0.3467 lowest val mse: 0.3336 at k 3
Iter: 558, running avg mse: 0.3445 lowest val mse: 0.3324 at k 3
Iter: 559, running avg mse: 0.3456 lowest val mse: 0.3324 at k 3
Iter: 560, running avg mse: 0.3399 lowest val mse: 0.3316 at k 3
Iter: 561, running avg ms

Iter: 74, running avg mse: 0.4520 lowest val mse: 0.4062 at k 3
Iter: 75, running avg mse: 0.4470 lowest val mse: 0.4050 at k 3
Iter: 76, running avg mse: 0.4459 lowest val mse: 0.4024 at k 3
Iter: 77, running avg mse: 0.4480 lowest val mse: 0.4014 at k 3
Iter: 78, running avg mse: 0.4352 lowest val mse: 0.3968 at k 3
Iter: 79, running avg mse: 0.4381 lowest val mse: 0.3944 at k 3
Iter: 80, running avg mse: 0.4471 lowest val mse: 0.3920 at k 3
Iter: 81, running avg mse: 0.4405 lowest val mse: 0.3923 at k 3
Iter: 82, running avg mse: 0.4343 lowest val mse: 0.3885 at k 3
Iter: 83, running avg mse: 0.4353 lowest val mse: 0.3868 at k 3
Iter: 84, running avg mse: 0.4199 lowest val mse: 0.3857 at k 3
Iter: 85, running avg mse: 0.4442 lowest val mse: 0.4015 at k 3
Iter: 86, running avg mse: 0.4451 lowest val mse: 0.3859 at k 3
Iter: 87, running avg mse: 0.4305 lowest val mse: 0.4046 at k 3
Iter: 88, running avg mse: 0.4420 lowest val mse: 0.4058 at k 3
Iter: 89, running avg mse: 0.4450 lowest

Iter: 201, running avg mse: 0.3809 lowest val mse: 0.3592 at k 3
Iter: 202, running avg mse: 0.3837 lowest val mse: 0.3668 at k 3
Iter: 203, running avg mse: 0.3920 lowest val mse: 0.3651 at k 3
Iter: 204, running avg mse: 0.3791 lowest val mse: 0.3646 at k 3
Iter: 205, running avg mse: 0.3807 lowest val mse: 0.3630 at k 3
Iter: 206, running avg mse: 0.3761 lowest val mse: 0.3595 at k 3
Iter: 207, running avg mse: 0.3798 lowest val mse: 0.3575 at k 3
Iter: 208, running avg mse: 0.3763 lowest val mse: 0.3610 at k 3
Iter: 209, running avg mse: 0.3835 lowest val mse: 0.3608 at k 3
Iter: 210, running avg mse: 0.3727 lowest val mse: 0.3598 at k 3
Iter: 211, running avg mse: 0.3910 lowest val mse: 0.3598 at k 3
Iter: 212, running avg mse: 0.3748 lowest val mse: 0.3593 at k 3
Iter: 213, running avg mse: 0.3836 lowest val mse: 0.3619 at k 3
Iter: 214, running avg mse: 0.3807 lowest val mse: 0.3626 at k 3
Iter: 215, running avg mse: 0.4017 lowest val mse: 0.3712 at k 3
Iter: 216, running avg ms

Iter: 328, running avg mse: 0.3502 lowest val mse: 0.3383 at k 3
Iter: 329, running avg mse: 0.3522 lowest val mse: 0.3375 at k 3
Iter: 330, running avg mse: 0.3524 lowest val mse: 0.3368 at k 3
Iter: 331, running avg mse: 0.3479 lowest val mse: 0.3356 at k 3
Iter: 332, running avg mse: 0.3498 lowest val mse: 0.3349 at k 3
Iter: 333, running avg mse: 0.3497 lowest val mse: 0.3345 at k 3
Iter: 334, running avg mse: 0.3516 lowest val mse: 0.3344 at k 3
Iter: 335, running avg mse: 0.3500 lowest val mse: 0.3351 at k 3
Iter: 336, running avg mse: 0.3560 lowest val mse: 0.3399 at k 3
Iter: 337, running avg mse: 0.3561 lowest val mse: 0.3379 at k 3
Iter: 338, running avg mse: 0.3597 lowest val mse: 0.3346 at k 3
Iter: 339, running avg mse: 0.3535 lowest val mse: 0.3367 at k 3
Iter: 340, running avg mse: 0.3529 lowest val mse: 0.3342 at k 3
Iter: 341, running avg mse: 0.3573 lowest val mse: 0.3342 at k 3
Iter: 342, running avg mse: 0.3491 lowest val mse: 0.3339 at k 3
Iter: 343, running avg ms

Iter: 455, running avg mse: 0.3420 lowest val mse: 0.3296 at k 3
Iter: 456, running avg mse: 0.3438 lowest val mse: 0.3288 at k 3
Iter: 457, running avg mse: 0.3432 lowest val mse: 0.3281 at k 3
Iter: 458, running avg mse: 0.3406 lowest val mse: 0.3284 at k 3
Iter: 459, running avg mse: 0.3424 lowest val mse: 0.3284 at k 3
Iter: 460, running avg mse: 0.3483 lowest val mse: 0.3286 at k 3
Iter: 461, running avg mse: 0.3379 lowest val mse: 0.3299 at k 3
Iter: 462, running avg mse: 0.3356 lowest val mse: 0.3337 at k 3
Iter: 463, running avg mse: 0.3434 lowest val mse: 0.3338 at k 3
Iter: 464, running avg mse: 0.3418 lowest val mse: 0.3302 at k 3
Iter: 465, running avg mse: 0.3455 lowest val mse: 0.3352 at k 3
Iter: 466, running avg mse: 0.3397 lowest val mse: 0.3293 at k 3
Iter: 467, running avg mse: 0.3386 lowest val mse: 0.3307 at k 3
Iter: 468, running avg mse: 0.3398 lowest val mse: 0.3283 at k 3
Iter: 469, running avg mse: 0.3439 lowest val mse: 0.3290 at k 3
Iter: 470, running avg ms

Iter: 582, running avg mse: 0.3324 lowest val mse: 0.3264 at k 3
Iter: 583, running avg mse: 0.3375 lowest val mse: 0.3266 at k 3
Iter: 584, running avg mse: 0.3294 lowest val mse: 0.3271 at k 3
Iter: 585, running avg mse: 0.3340 lowest val mse: 0.3256 at k 3
Iter: 586, running avg mse: 0.3314 lowest val mse: 0.3265 at k 3
Iter: 587, running avg mse: 0.3313 lowest val mse: 0.3262 at k 3
Iter: 588, running avg mse: 0.3415 lowest val mse: 0.3268 at k 3
Iter: 589, running avg mse: 0.3403 lowest val mse: 0.3268 at k 3
Iter: 590, running avg mse: 0.3313 lowest val mse: 0.3261 at k 3
Iter: 591, running avg mse: 0.3318 lowest val mse: 0.3257 at k 3
Iter: 592, running avg mse: 0.3387 lowest val mse: 0.3281 at k 3
Iter: 593, running avg mse: 0.3411 lowest val mse: 0.3357 at k 3
Iter: 594, running avg mse: 0.3422 lowest val mse: 0.3386 at k 3
Iter: 595, running avg mse: 0.3578 lowest val mse: 0.3604 at k 4
Iter: 596, running avg mse: 0.4237 lowest val mse: 0.3657 at k 3
Iter: 597, running avg ms

Iter: 110, running avg mse: 0.4183 lowest val mse: 0.3751 at k 3
Iter: 111, running avg mse: 0.4113 lowest val mse: 0.3732 at k 3
Iter: 112, running avg mse: 0.4062 lowest val mse: 0.3732 at k 3
Iter: 113, running avg mse: 0.4135 lowest val mse: 0.3697 at k 3
Iter: 114, running avg mse: 0.4155 lowest val mse: 0.3716 at k 3
Iter: 115, running avg mse: 0.4170 lowest val mse: 0.3690 at k 3
Iter: 116, running avg mse: 0.4041 lowest val mse: 0.3681 at k 3
Iter: 117, running avg mse: 0.4097 lowest val mse: 0.3656 at k 3
Iter: 118, running avg mse: 0.4149 lowest val mse: 0.3680 at k 3
Iter: 119, running avg mse: 0.4043 lowest val mse: 0.3661 at k 3
Iter: 120, running avg mse: 0.3947 lowest val mse: 0.3645 at k 3
Iter: 121, running avg mse: 0.3940 lowest val mse: 0.3642 at k 3
Iter: 122, running avg mse: 0.4063 lowest val mse: 0.3674 at k 3
Iter: 123, running avg mse: 0.4043 lowest val mse: 0.3682 at k 3
Iter: 124, running avg mse: 0.3994 lowest val mse: 0.3634 at k 3
Iter: 125, running avg ms

Iter: 237, running avg mse: 0.3792 lowest val mse: 0.3525 at k 3
Iter: 238, running avg mse: 0.3687 lowest val mse: 0.3522 at k 3
Iter: 239, running avg mse: 0.3738 lowest val mse: 0.3511 at k 3
Iter: 240, running avg mse: 0.3686 lowest val mse: 0.3506 at k 3
Iter: 241, running avg mse: 0.3734 lowest val mse: 0.3583 at k 3
Iter: 242, running avg mse: 0.3859 lowest val mse: 0.3675 at k 4
Iter: 243, running avg mse: 0.3841 lowest val mse: 0.3660 at k 3
Iter: 244, running avg mse: 0.3790 lowest val mse: 0.3530 at k 3
Iter: 245, running avg mse: 0.3721 lowest val mse: 0.3520 at k 3
Iter: 246, running avg mse: 0.3773 lowest val mse: 0.3520 at k 3
Iter: 247, running avg mse: 0.3702 lowest val mse: 0.3502 at k 3
Iter: 248, running avg mse: 0.3681 lowest val mse: 0.3488 at k 3
Iter: 249, running avg mse: 0.3662 lowest val mse: 0.3494 at k 3
Iter: 250, running avg mse: 0.3643 lowest val mse: 0.3476 at k 3
Iter: 251, running avg mse: 0.3637 lowest val mse: 0.3473 at k 3
Iter: 252, running avg ms

Iter: 364, running avg mse: 0.3494 lowest val mse: 0.3416 at k 3
Iter: 365, running avg mse: 0.3654 lowest val mse: 0.3489 at k 4
Iter: 366, running avg mse: 0.3652 lowest val mse: 0.3443 at k 3
Iter: 367, running avg mse: 0.3577 lowest val mse: 0.3455 at k 3
Iter: 368, running avg mse: 0.3572 lowest val mse: 0.3418 at k 3
Iter: 369, running avg mse: 0.3608 lowest val mse: 0.3417 at k 3
Iter: 370, running avg mse: 0.3522 lowest val mse: 0.3407 at k 3
Iter: 371, running avg mse: 0.3540 lowest val mse: 0.3405 at k 3
Iter: 372, running avg mse: 0.3490 lowest val mse: 0.3413 at k 3
Iter: 373, running avg mse: 0.3516 lowest val mse: 0.3402 at k 3
Iter: 374, running avg mse: 0.3526 lowest val mse: 0.3403 at k 3
Iter: 375, running avg mse: 0.3549 lowest val mse: 0.3394 at k 3
Iter: 376, running avg mse: 0.3494 lowest val mse: 0.3395 at k 3
Iter: 377, running avg mse: 0.3467 lowest val mse: 0.3392 at k 3
Iter: 378, running avg mse: 0.3561 lowest val mse: 0.3398 at k 3
Iter: 379, running avg ms

Iter: 491, running avg mse: 0.3538 lowest val mse: 0.3343 at k 3
Iter: 492, running avg mse: 0.3437 lowest val mse: 0.3337 at k 3
Iter: 493, running avg mse: 0.3414 lowest val mse: 0.3322 at k 3
Iter: 494, running avg mse: 0.3465 lowest val mse: 0.3327 at k 3
Iter: 495, running avg mse: 0.3478 lowest val mse: 0.3322 at k 3
Iter: 496, running avg mse: 0.3468 lowest val mse: 0.3321 at k 3
Iter: 497, running avg mse: 0.3405 lowest val mse: 0.3321 at k 3
Iter: 498, running avg mse: 0.3411 lowest val mse: 0.3311 at k 3
Iter: 499, running avg mse: 0.3372 lowest val mse: 0.3312 at k 3
Iter: 500, running avg mse: 0.3437 lowest val mse: 0.3320 at k 3
Iter: 501, running avg mse: 0.3467 lowest val mse: 0.3327 at k 3
Iter: 502, running avg mse: 0.3445 lowest val mse: 0.3340 at k 3
Iter: 503, running avg mse: 0.3406 lowest val mse: 0.3328 at k 3
Iter: 504, running avg mse: 0.3396 lowest val mse: 0.3316 at k 3
Iter: 505, running avg mse: 0.3391 lowest val mse: 0.3322 at k 3
Iter: 506, running avg ms

Iter: 18, running avg mse: 0.8800 lowest val mse: 0.8749 at k 4
Iter: 19, running avg mse: 0.8832 lowest val mse: 0.8737 at k 4
Iter: 20, running avg mse: 0.8823 lowest val mse: 0.8689 at k 4
Iter: 21, running avg mse: 0.8765 lowest val mse: 0.8679 at k 4
Iter: 22, running avg mse: 0.8667 lowest val mse: 0.8665 at k 4
Iter: 23, running avg mse: 0.8697 lowest val mse: 0.8639 at k 4
Iter: 24, running avg mse: 0.8588 lowest val mse: 0.8608 at k 4
Iter: 25, running avg mse: 0.8675 lowest val mse: 0.8586 at k 4
Iter: 26, running avg mse: 0.8665 lowest val mse: 0.8585 at k 4
Iter: 27, running avg mse: 0.8621 lowest val mse: 0.8539 at k 4
Iter: 28, running avg mse: 0.8496 lowest val mse: 0.8534 at k 4
Iter: 29, running avg mse: 0.8661 lowest val mse: 0.8512 at k 4
Iter: 30, running avg mse: 0.8574 lowest val mse: 0.8534 at k 4
Iter: 31, running avg mse: 0.8573 lowest val mse: 0.8464 at k 4
Iter: 32, running avg mse: 0.8562 lowest val mse: 0.8317 at k 4
Iter: 33, running avg mse: 0.8324 lowest

Iter: 146, running avg mse: 0.4322 lowest val mse: 0.3944 at k 3
Iter: 147, running avg mse: 0.4188 lowest val mse: 0.3891 at k 3
Iter: 148, running avg mse: 0.4139 lowest val mse: 0.3834 at k 3
Iter: 149, running avg mse: 0.4077 lowest val mse: 0.3788 at k 3
Iter: 150, running avg mse: 0.4086 lowest val mse: 0.3738 at k 3
Iter: 151, running avg mse: 0.3997 lowest val mse: 0.3715 at k 3
Iter: 152, running avg mse: 0.3998 lowest val mse: 0.3691 at k 3
Iter: 153, running avg mse: 0.3935 lowest val mse: 0.3687 at k 3
Iter: 154, running avg mse: 0.4048 lowest val mse: 0.3686 at k 3
Iter: 155, running avg mse: 0.4029 lowest val mse: 0.3712 at k 3
Iter: 156, running avg mse: 0.4058 lowest val mse: 0.3703 at k 3
Iter: 157, running avg mse: 0.4107 lowest val mse: 0.3712 at k 3
Iter: 158, running avg mse: 0.4059 lowest val mse: 0.3701 at k 3
Iter: 159, running avg mse: 0.4005 lowest val mse: 0.3706 at k 3
Iter: 160, running avg mse: 0.4084 lowest val mse: 0.3694 at k 3
Iter: 161, running avg ms

Iter: 273, running avg mse: 0.3653 lowest val mse: 0.3441 at k 3
Iter: 274, running avg mse: 0.3602 lowest val mse: 0.3416 at k 3
Iter: 275, running avg mse: 0.3610 lowest val mse: 0.3433 at k 3
Iter: 276, running avg mse: 0.3693 lowest val mse: 0.3524 at k 4
Iter: 277, running avg mse: 0.3644 lowest val mse: 0.3430 at k 3
Iter: 278, running avg mse: 0.3685 lowest val mse: 0.3463 at k 3
Iter: 279, running avg mse: 0.3520 lowest val mse: 0.3450 at k 3
Iter: 280, running avg mse: 0.3543 lowest val mse: 0.3433 at k 3
Iter: 281, running avg mse: 0.3616 lowest val mse: 0.3420 at k 3
Iter: 282, running avg mse: 0.3519 lowest val mse: 0.3416 at k 3
Iter: 283, running avg mse: 0.3649 lowest val mse: 0.3408 at k 3
Iter: 284, running avg mse: 0.3561 lowest val mse: 0.3406 at k 3
Iter: 285, running avg mse: 0.3541 lowest val mse: 0.3399 at k 3
Iter: 286, running avg mse: 0.3516 lowest val mse: 0.3413 at k 3
Iter: 287, running avg mse: 0.3519 lowest val mse: 0.3426 at k 3
Iter: 288, running avg ms

Iter: 400, running avg mse: 0.3544 lowest val mse: 0.3397 at k 3
Iter: 401, running avg mse: 0.3605 lowest val mse: 0.3383 at k 3
Iter: 402, running avg mse: 0.3453 lowest val mse: 0.3369 at k 3
Iter: 403, running avg mse: 0.3539 lowest val mse: 0.3366 at k 3
Iter: 404, running avg mse: 0.3490 lowest val mse: 0.3360 at k 3
Iter: 405, running avg mse: 0.3515 lowest val mse: 0.3366 at k 3
Iter: 406, running avg mse: 0.3456 lowest val mse: 0.3374 at k 3
Iter: 407, running avg mse: 0.3516 lowest val mse: 0.3357 at k 3
Iter: 408, running avg mse: 0.3542 lowest val mse: 0.3353 at k 3
Iter: 409, running avg mse: 0.3560 lowest val mse: 0.3354 at k 3
Iter: 410, running avg mse: 0.3478 lowest val mse: 0.3355 at k 3
Iter: 411, running avg mse: 0.3454 lowest val mse: 0.3362 at k 3
Iter: 412, running avg mse: 0.3499 lowest val mse: 0.3354 at k 3
Iter: 413, running avg mse: 0.3450 lowest val mse: 0.3347 at k 3
Iter: 414, running avg mse: 0.3547 lowest val mse: 0.3351 at k 3
Iter: 415, running avg ms

Iter: 527, running avg mse: 0.3361 lowest val mse: 0.3285 at k 3
Iter: 528, running avg mse: 0.3382 lowest val mse: 0.3283 at k 3
Iter: 529, running avg mse: 0.3344 lowest val mse: 0.3290 at k 3
Iter: 530, running avg mse: 0.3365 lowest val mse: 0.3316 at k 3
Iter: 531, running avg mse: 0.3441 lowest val mse: 0.3327 at k 3
Iter: 532, running avg mse: 0.3360 lowest val mse: 0.3332 at k 3
Iter: 533, running avg mse: 0.3441 lowest val mse: 0.3326 at k 3
Iter: 534, running avg mse: 0.3484 lowest val mse: 0.3317 at k 3
Iter: 535, running avg mse: 0.3383 lowest val mse: 0.3303 at k 3
Iter: 536, running avg mse: 0.3422 lowest val mse: 0.3309 at k 3
Iter: 537, running avg mse: 0.3424 lowest val mse: 0.3311 at k 3
Iter: 538, running avg mse: 0.3452 lowest val mse: 0.3306 at k 3
Iter: 539, running avg mse: 0.3425 lowest val mse: 0.3335 at k 3
Iter: 540, running avg mse: 0.3426 lowest val mse: 0.3327 at k 3
Iter: 541, running avg mse: 0.3487 lowest val mse: 0.3319 at k 3
Iter: 542, running avg ms

Iter: 54, running avg mse: 0.4633 lowest val mse: 0.4240 at k 3
Iter: 55, running avg mse: 0.4627 lowest val mse: 0.4262 at k 3
Iter: 56, running avg mse: 0.4604 lowest val mse: 0.4312 at k 3
Iter: 57, running avg mse: 0.4661 lowest val mse: 0.4192 at k 3
Iter: 58, running avg mse: 0.4599 lowest val mse: 0.4212 at k 3
Iter: 59, running avg mse: 0.4607 lowest val mse: 0.4089 at k 3
Iter: 60, running avg mse: 0.4583 lowest val mse: 0.4082 at k 3
Iter: 61, running avg mse: 0.4425 lowest val mse: 0.4017 at k 3
Iter: 62, running avg mse: 0.4463 lowest val mse: 0.3998 at k 3
Iter: 63, running avg mse: 0.4409 lowest val mse: 0.3963 at k 3
Iter: 64, running avg mse: 0.4380 lowest val mse: 0.3938 at k 3
Iter: 65, running avg mse: 0.4354 lowest val mse: 0.3917 at k 3
Iter: 66, running avg mse: 0.4291 lowest val mse: 0.3894 at k 3
Iter: 67, running avg mse: 0.4276 lowest val mse: 0.3879 at k 3
Iter: 68, running avg mse: 0.4164 lowest val mse: 0.4010 at k 3
Iter: 69, running avg mse: 0.4400 lowest

Iter: 181, running avg mse: 0.3741 lowest val mse: 0.3504 at k 3
Iter: 182, running avg mse: 0.3752 lowest val mse: 0.3490 at k 3
Iter: 183, running avg mse: 0.3610 lowest val mse: 0.3493 at k 3
Iter: 184, running avg mse: 0.3659 lowest val mse: 0.3503 at k 3
Iter: 185, running avg mse: 0.3777 lowest val mse: 0.3615 at k 3
Iter: 186, running avg mse: 0.3721 lowest val mse: 0.3536 at k 3
Iter: 187, running avg mse: 0.3765 lowest val mse: 0.3543 at k 3
Iter: 188, running avg mse: 0.3662 lowest val mse: 0.3531 at k 3
Iter: 189, running avg mse: 0.3656 lowest val mse: 0.3504 at k 3
Iter: 190, running avg mse: 0.3687 lowest val mse: 0.3509 at k 3
Iter: 191, running avg mse: 0.3555 lowest val mse: 0.3499 at k 3
Iter: 192, running avg mse: 0.3671 lowest val mse: 0.3513 at k 3
Iter: 193, running avg mse: 0.3604 lowest val mse: 0.3494 at k 3
Iter: 194, running avg mse: 0.3696 lowest val mse: 0.3506 at k 3
Iter: 195, running avg mse: 0.3667 lowest val mse: 0.3561 at k 3
Iter: 196, running avg ms

Iter: 308, running avg mse: 0.3510 lowest val mse: 0.3368 at k 3
Iter: 309, running avg mse: 0.3591 lowest val mse: 0.3373 at k 3
Iter: 310, running avg mse: 0.3560 lowest val mse: 0.3367 at k 3
Iter: 311, running avg mse: 0.3496 lowest val mse: 0.3367 at k 3
Iter: 312, running avg mse: 0.3565 lowest val mse: 0.3369 at k 3
Iter: 313, running avg mse: 0.3502 lowest val mse: 0.3363 at k 3
Iter: 314, running avg mse: 0.3472 lowest val mse: 0.3368 at k 3
Iter: 315, running avg mse: 0.3530 lowest val mse: 0.3361 at k 3
Iter: 316, running avg mse: 0.3447 lowest val mse: 0.3360 at k 3
Iter: 317, running avg mse: 0.3549 lowest val mse: 0.3368 at k 3
Iter: 318, running avg mse: 0.3489 lowest val mse: 0.3360 at k 3
Iter: 319, running avg mse: 0.3524 lowest val mse: 0.3388 at k 3
Iter: 320, running avg mse: 0.3659 lowest val mse: 0.3431 at k 3
Iter: 321, running avg mse: 0.3459 lowest val mse: 0.3372 at k 3
Iter: 322, running avg mse: 0.3508 lowest val mse: 0.3389 at k 3
Iter: 323, running avg ms

Iter: 435, running avg mse: 0.3434 lowest val mse: 0.3307 at k 3
Iter: 436, running avg mse: 0.3470 lowest val mse: 0.3306 at k 3
Iter: 437, running avg mse: 0.3373 lowest val mse: 0.3310 at k 3
Iter: 438, running avg mse: 0.3429 lowest val mse: 0.3310 at k 3
Iter: 439, running avg mse: 0.3414 lowest val mse: 0.3307 at k 3
Iter: 440, running avg mse: 0.3439 lowest val mse: 0.3306 at k 3
Iter: 441, running avg mse: 0.3407 lowest val mse: 0.3304 at k 3
Iter: 442, running avg mse: 0.3427 lowest val mse: 0.3300 at k 3
Iter: 443, running avg mse: 0.3453 lowest val mse: 0.3303 at k 3
Iter: 444, running avg mse: 0.3474 lowest val mse: 0.3341 at k 3
Iter: 445, running avg mse: 0.3624 lowest val mse: 0.3508 at k 4
Iter: 446, running avg mse: 0.3467 lowest val mse: 0.3380 at k 3
Iter: 447, running avg mse: 0.3613 lowest val mse: 0.3346 at k 3
Iter: 448, running avg mse: 0.3478 lowest val mse: 0.3351 at k 3
Iter: 449, running avg mse: 0.3544 lowest val mse: 0.3365 at k 3
Iter: 450, running avg ms

Iter: 562, running avg mse: 0.3547 lowest val mse: 0.3376 at k 3
Iter: 563, running avg mse: 0.3479 lowest val mse: 0.3368 at k 3
Iter: 564, running avg mse: 0.3445 lowest val mse: 0.3390 at k 3
Iter: 565, running avg mse: 0.3527 lowest val mse: 0.3357 at k 3
Iter: 566, running avg mse: 0.3531 lowest val mse: 0.3365 at k 3
Iter: 567, running avg mse: 0.3385 lowest val mse: 0.3340 at k 3
Iter: 568, running avg mse: 0.3449 lowest val mse: 0.3346 at k 3
Iter: 569, running avg mse: 0.3416 lowest val mse: 0.3331 at k 3
Iter: 570, running avg mse: 0.3442 lowest val mse: 0.3329 at k 3
Iter: 571, running avg mse: 0.3403 lowest val mse: 0.3333 at k 3
Iter: 572, running avg mse: 0.3404 lowest val mse: 0.3320 at k 3
Iter: 573, running avg mse: 0.3358 lowest val mse: 0.3324 at k 3
Iter: 574, running avg mse: 0.3463 lowest val mse: 0.3319 at k 3
Iter: 575, running avg mse: 0.3448 lowest val mse: 0.3328 at k 3
Iter: 576, running avg mse: 0.3557 lowest val mse: 0.3385 at k 3
Iter: 577, running avg ms

Iter: 90, running avg mse: 0.4360 lowest val mse: 0.3866 at k 3
Iter: 91, running avg mse: 0.4240 lowest val mse: 0.3870 at k 3
Iter: 92, running avg mse: 0.4146 lowest val mse: 0.3839 at k 3
Iter: 93, running avg mse: 0.4161 lowest val mse: 0.3822 at k 3
Iter: 94, running avg mse: 0.4174 lowest val mse: 0.3829 at k 3
Iter: 95, running avg mse: 0.4111 lowest val mse: 0.3803 at k 3
Iter: 96, running avg mse: 0.4198 lowest val mse: 0.3815 at k 3
Iter: 97, running avg mse: 0.4132 lowest val mse: 0.3785 at k 3
Iter: 98, running avg mse: 0.4152 lowest val mse: 0.3773 at k 3
Iter: 99, running avg mse: 0.4113 lowest val mse: 0.3766 at k 3
Iter: 100, running avg mse: 0.4079 lowest val mse: 0.3753 at k 3
Iter: 101, running avg mse: 0.4052 lowest val mse: 0.3758 at k 3
Iter: 102, running avg mse: 0.4081 lowest val mse: 0.3765 at k 3
Iter: 103, running avg mse: 0.4079 lowest val mse: 0.3855 at k 3
Iter: 104, running avg mse: 0.4145 lowest val mse: 0.3737 at k 3
Iter: 105, running avg mse: 0.4072 

Iter: 217, running avg mse: 0.3641 lowest val mse: 0.3444 at k 3
Iter: 218, running avg mse: 0.3683 lowest val mse: 0.3447 at k 3
Iter: 219, running avg mse: 0.3659 lowest val mse: 0.3446 at k 3
Iter: 220, running avg mse: 0.3652 lowest val mse: 0.3433 at k 3
Iter: 221, running avg mse: 0.3704 lowest val mse: 0.3430 at k 3
Iter: 222, running avg mse: 0.3621 lowest val mse: 0.3407 at k 3
Iter: 223, running avg mse: 0.3663 lowest val mse: 0.3408 at k 3
Iter: 224, running avg mse: 0.3728 lowest val mse: 0.3416 at k 3
Iter: 225, running avg mse: 0.3694 lowest val mse: 0.3406 at k 3
Iter: 226, running avg mse: 0.3626 lowest val mse: 0.3409 at k 3
Iter: 227, running avg mse: 0.3559 lowest val mse: 0.3408 at k 3
Iter: 228, running avg mse: 0.3618 lowest val mse: 0.3399 at k 3
Iter: 229, running avg mse: 0.3630 lowest val mse: 0.3424 at k 3
Iter: 230, running avg mse: 0.3800 lowest val mse: 0.3674 at k 3
Iter: 231, running avg mse: 0.3661 lowest val mse: 0.3491 at k 3
Iter: 232, running avg ms

Iter: 344, running avg mse: 0.3548 lowest val mse: 0.3367 at k 3
Iter: 345, running avg mse: 0.3602 lowest val mse: 0.3365 at k 3
Iter: 346, running avg mse: 0.3517 lowest val mse: 0.3353 at k 3


In [None]:
for i in range(len(plotgraph_index)):
    plot_graph(gen_index, times_index, plotgraph_index[i], deriv_index, pred_x_forgraph, orig_trajs, itr, path)

In [None]:
orig_trajs.shape

In [None]:
train_loss = np.array(train_losses)

plt.plot(train_loss, 'r')
plt.savefig('C:/Users/shiny/Documents/NeuralODE_RatTreadMill/Results_pic/TrainingLossGraph/tau6k3/trainloss_0.005_nonoise_latent8_lstm256.png')
np.save('C:/Users/shiny/Documents/NeuralODE_RatTreadMill/Results_pic/TrainingLossGraph/tau6k3/trainloss_0.005_nonoise_latent8_lstm256.npy', train_loss)

In [None]:
folder = 'runs/model_'

def save_model():
    folder = 'runs/model'
    folder = os.path.join(folder, 'ckpt')

    ckpt_path = os.path.join(folder, f'ODE_normalized_4_128_2tanh.pth')

    save_dict = {
        'model_args': get_args(),
        'optimizer_state_dict': optimizer.state_dict(),
        'data': data_get_dict(),
        'train_loss': get_losses()
    }
    
    save_dict.update(get_state_dicts())
    
    torch.save(save_dict, 'C:/Users/shiny/Documents/NeuralODE_RatTreadMill/model/All_rodent_ODE_TakenEmbedding_tau6k3_LSTM_lr0.008_latent8_LSTMautoencoder_epoch500.pth')

save_model()

In [None]:
checkpoint = torch.load('model/All_rodent_ODE_TakenEmbedding_tau6k3_LSTM_lr0.008_latent8_LSTMautoencoder_epoch380.pth')
rec.load_state_dict(checkpoint['encoder_state_dict'])
func.load_state_dict(checkpoint['odefunc_state_dict'])
dec.load_state_dict(checkpoint['decoder_state_dict'])

## Long time series generation

In [None]:
gen_index = 20
times_index = 0
deriv_index = 0
itr= 380
orig_trajs_TE = np.load('orig_trajs_TE_tau6k3.npy')
orig_trajs_TE = orig_trajs_TE.reshape(203, 200*34, 31*(k+1)+2)
samp_trajs_TE_test = orig_trajs_TE[:, :50, :]

samp_trajs_TE_test = torch.from_numpy(samp_trajs_TE_test).float().to(device).reshape(203, 50, 31*(k+1)+2)
orig_trajs = orig_trajs_TE[:, 0:0+50*gen_index, :]

pred_x = data_for_plot_graph(gen_index)

path = "Results_pic/tau6k3/longtimeseries/epoch{}".format(itr)

if not os.path.exists(path):
   os.makedirs(path)

plot_graph(gen_index, times_index, 0, deriv_index, pred_x, orig_trajs, itr, path)
plot_graph(gen_index, times_index, 4, deriv_index, pred_x, orig_trajs, itr, path)
plot_graph(gen_index, times_index, 20, deriv_index, pred_x, orig_trajs, itr, path)

In [None]:
with torch.no_grad():
    # sample from trajectorys' approx. posterior

    ts_pos = np.linspace(0, 0.25*gen_index, num=50*gen_index)
    ts_pos = torch.from_numpy(ts_pos).float().to(device)
    
    h = torch.zeros(1, samp_trajs_TE.shape[0], rnn_nhidden).to(device)
    c = torch.zeros(1, samp_trajs_TE.shape[0], rnn_nhidden).to(device)
    
    hn = h[0, :, :]
    cn = c[0, :, :]
    
    for t in reversed(range(samp_trajs_TE.size(1))):
        obs = samp_trajs_TE[:, t, :]
        out, hn, cn = rec.forward(obs, hn, cn)
    qz0_mean, qz0_logvar = out[:, :latent_dim], out[:, latent_dim:]
    epsilon = torch.randn(qz0_mean.size()).to(device)
    z0 = epsilon * torch.exp(.5 * qz0_logvar) + qz0_mean

    # forward in time and solve ode for reconstructions
    pred_z = odeint(func, z0, ts_pos).permute(1, 0, 2) #change time and batch with permute
    pred_x = dec(pred_z)

In [None]:
with torch.no_grad():
    samp_ts_forgraph = samp_ts.detach().cpu().numpy()
    pred_x = pred_x.reshape(203, 50*gen_index, 28*(k+1)+2)
    pred_x_forgraph = pred_x.detach().cpu().numpy()
    orig_trajs_forgraph = orig_trajs
    ts_pos_forgraph = ts_pos.detach().cpu().numpy()
    ts_pos_combined = np.linspace(0, 0.25*gen_index, num=50*gen_index) 
    
    times_index = 0
    dataset_value = 0
    deriv_index = 0
    fig, axes = plt.subplots(nrows=6, ncols=5, figsize=(15, 9))
    axes = axes.flatten()
    for i, ax in enumerate(axes):
        ax.scatter(ts_pos_combined[times_index:50*gen_index], orig_trajs_forgraph[dataset_value,0:50*gen_index, i], label='sampled data', s = 5)
        ax.plot(ts_pos_combined[times_index:+50*gen_index], pred_x_forgraph[dataset_value, times_index:+50*gen_index, i*(k+1)+deriv_index], 'r',
                 label='learned trajectory (t>0)')

    plt.legend()
    plt.savefig('./Results_pic/tau6k3/longtimeseries/lstm_Tied_latent8_gen10_deriv0_50.png', dpi=500)
    #plt.savefig('./minibatchfps200_take300_predict900_positionalvalue3.png', dpi=500)
    print('Saved visualization figure at {}'.format('./test.png'))

In [None]:
with torch.no_grad():
    samp_ts_forgraph = samp_ts.detach().cpu().numpy()
    pred_x = pred_x.reshape(203, 50*gen_index, 28*(k+1)+2)
    pred_x_forgraph = pred_x.detach().cpu().numpy()
    orig_trajs_forgraph = orig_trajs
    ts_pos_forgraph = ts_pos.detach().cpu().numpy()
    ts_pos_combined = np.linspace(0, 0.25*gen_index, num=50*gen_index) 
    
    times_index = 0
    dataset_value = 4
    deriv_index = 0
    fig, axes = plt.subplots(nrows=6, ncols=5, figsize=(15, 9))
    axes = axes.flatten()
    for i, ax in enumerate(axes):
        ax.scatter(ts_pos_combined[times_index:50*gen_index], orig_trajs_forgraph[dataset_value,0:50*gen_index, i], label='sampled data', s = 5)
        ax.plot(ts_pos_combined[times_index:+50*gen_index], pred_x_forgraph[dataset_value, times_index:+50*gen_index, i*(k+1)+deriv_index], 'r',
                 label='learned trajectory (t>0)')

    plt.legend()
    plt.savefig('./Results_pic/tau6k3/longtimeseries/lstm_split_latent8_gen10_deriv0_50.png', dpi=500)
    #plt.savefig('./minibatchfps200_take300_predict900_positionalvalue3.png', dpi=500)
    print('Saved visualization figure at {}'.format('./test.png'))

In [None]:
with torch.no_grad():
    samp_ts_forgraph = samp_ts.detach().cpu().numpy()
    pred_x = pred_x.reshape(203, 50*gen_index, 28*(k+1)+2)
    pred_x_forgraph = pred_x.detach().cpu().numpy()
    orig_trajs_forgraph = orig_trajs
    ts_pos_forgraph = ts_pos.detach().cpu().numpy()
    ts_pos_combined = np.linspace(0, 0.25*gen_index, num=50*gen_index) 
    
    times_index = 0
    dataset_value = 20
    deriv_index = 0
    fig, axes = plt.subplots(nrows=6, ncols=5, figsize=(15, 9))
    axes = axes.flatten()
    for i, ax in enumerate(axes):
        ax.scatter(ts_pos_combined[times_index:50*gen_index], orig_trajs_forgraph[dataset_value,0:50*gen_index, i], label='sampled data', s = 5)
        ax.plot(ts_pos_combined[times_index:+50*gen_index], pred_x_forgraph[dataset_value, times_index:+50*gen_index, i*(k+1)+deriv_index], 'r',
                 label='learned trajectory (t>0)')

    plt.legend()
    plt.savefig('./Results_pic/tau6k3/longtimeseries/lstm_Tied_latent8_gen10_deriv0_50_wash.png', dpi=500)
    #plt.savefig('./minibatchfps200_take300_predict900_positionalvalue3.png', dpi=500)
    print('Saved visualization figure at {}'.format('./test.png'))

In [None]:
with torch.no_grad():
    samp_ts_forgraph = samp_ts.detach().cpu().numpy()
    pred_x = pred_x.reshape(203, 50*gen_index, 31*(k+1)+2)
    pred_x_forgraph = pred_x.detach().cpu().numpy()
    orig_trajs_forgraph = orig_trajs
    ts_pos_forgraph = ts_pos.detach().cpu().numpy()
    ts_pos_combined = np.linspace(0, 0.25*gen_index, num=2500) 
    
    times_index = 0
    positional_value = 0
    fig, axes = plt.subplots(nrows=4, ncols=1, figsize=(15, 9))
    axes = axes.flatten()
    for i, ax in enumerate(axes):
        ax.scatter(ts_pos_combined[0:50*gen_index], orig_trajs_forgraph[i,0:50*gen_index, positional_value], label='sampled data', s = 5)
        ax.plot(ts_pos_combined[times_index:+50*gen_index], pred_x_forgraph[i, times_index:+50*gen_index, positional_value*(k+1)], 'r',
                 label='learned trajectory (t>0)')

    plt.legend()
    plt.savefig('./Results_pic/tau6k3/longtimeseries/Allrodent_Tau6k3_takenembedding_longepochgeneration_position0.png', dpi=500)
    #plt.savefig('./minibatchfps200_take300_predict900_positionalvalue3.png', dpi=500)
    print('Saved visualization figure at {}'.format('./test.png'))

## Predicting longer timescales with combining minibatches

In [None]:
with torch.no_grad():
    # sample from trajectorys' approx. posterior

    ts_pos = np.linspace(0, np.pi*2, num=50)
    ts_pos = torch.from_numpy(ts_pos).float().to(device)
    #ts_neg = np.linspace(-np.pi*20, 0., num=400)[::-1].copy()
    #ts_neg = torch.from_numpy(ts_neg).float().to(device)
    
    h = rec.initHidden().to(device)
    for t in reversed(range(samp_trajs.size(1))):
        obs = samp_trajs[:, t, :]
        out, h = rec.forward(obs, h)
    qz0_mean, qz0_logvar = out[:, :latent_dim], out[:, latent_dim:]
    epsilon = torch.randn(qz0_mean.size()).to(device)
    z0 = epsilon * torch.exp(.5 * qz0_logvar) + qz0_mean

    # forward in time and solve ode for reconstructions
    pred_z = odeint(func, z0, ts_pos).permute(1, 0, 2)
    pred_x = dec(pred_z)

In [None]:
with torch.no_grad():
    samp_ts_forgraph = samp_ts.detach().cpu().numpy()
    pred_x = pred_x.reshape(64, 1500, 46)
    pred_x_forgraph = pred_x.detach().cpu().numpy()
    samp_trajs = samp_trajs.reshape(64, 300, 46)
    samp_trajs_forgraph = samp_trajs.detach().cpu().numpy()
    ts_pos_forgraph = ts_pos.detach().cpu().numpy()
    samp_ts_combined = np.linspace(0, 299, num=300)
    ts_pos_combined = np.linspace(0, 1499, num=1500) 
    
    times_index = 0
    positional_value = 3
    fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(15, 9))
    axes = axes.flatten()
    for i, ax in enumerate(axes):
        ax.scatter(samp_ts_combined[times_index:+300], samp_trajs_forgraph[i,times_index:+300, positional_value], label='sampled data', s = 5)
        ax.plot(ts_pos_combined[times_index:+1200], pred_x_forgraph[i, times_index:+1200, positional_value], 'r',
                 label='learned trajectory (t>0)')

    plt.legend()
    plt.savefig('./1200.png', dpi=500)
    #plt.savefig('./minibatchfps200_take300_predict900_positionalvalue3.png', dpi=500)
    print('Saved visualization figure at {}'.format('./test.png'))

## Predicting longer timescales: looking at each one

In [None]:
with torch.no_grad():
    samp_ts_forgraph = samp_ts.detach().cpu().numpy()
    pred_x_forgraph = pred_x.detach().cpu().numpy()
    samp_trajs_forgraph = samp_trajs.detach().cpu().numpy()
    ts_pos_forgraph = ts_pos.detach().cpu().numpy()
    
    times_index = 0
    positional_value = 0
    fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(15, 9))
    axes = axes.flatten()
    for i, ax in enumerate(axes):
        ax.scatter(samp_ts_combined[times_index:+50], samp_trajs_forgraph[i,times_index:+50, positional_value], label='sampled data', s = 5)
        ax.plot(ts_pos_combined[times_index:+250], pred_x_forgraph[i, times_index:+250, positional_value], 'r',
                 label='learned trajectory (t>0)')

    plt.legend()
    plt.savefig('./test.png', dpi=500)
    print('Saved visualization figure at {}'.format('./test.png'))

In [None]:
z0.shape

## PCA for z0

In [None]:
from sklearn.decomposition import PCA
import pandas as pd

In [None]:
with torch.no_grad():
    ts_pos = np.linspace(0, np.pi*2*5, num=250)
    ts_pos = torch.from_numpy(ts_pos).float().to(device)
    
    h = torch.zeros(samp_trajs.shape[0], rnn_nhidden).to(device)
    
    for t in reversed(range(samp_trajs.size(1))):
        obs = samp_trajs[:, t, :]
        out, h = rec.forward(obs, h)
    qz0_mean, qz0_logvar = out[:, :latent_dim], out[:, latent_dim:]
    epsilon = torch.randn(qz0_mean.size()).to(device)
    z0 = epsilon * torch.exp(.5 * qz0_logvar) + qz0_mean
    z0 = z0.cpu()
    
    pca = PCA(n_components=2)
    pca.fit(z0.cpu())
    print("Explained variance:", pca.explained_variance_ratio_)
    z0_red = pca.fit_transform(z0)
    
    print(z0_red[:, 0].shape)

In [None]:
z0 = z0.cpu()

In [None]:
    pca = PCA(n_components=2)
    pca.fit(z0.cpu())
    print("Explained variance:", pca.explained_variance_ratio_)

In [None]:
pca_z = PCA(n_components=2)
pca.fit(z0)

z0_red = pca.fit_transform(z0)

d = {'PC1': z0_red[:, 0], 'PC2': z0_red[:, 1]}
df = pd.DataFrame(d)

plt.figure()
plt.plot(z0_red[:, 0], z0_red[:, 1], 'o', label='z0 samples in 2D', linewidth=2, zorder=1)
plt.legend()
plt.savefig('./PCAgraph.png', dpi=250)

In [None]:
from sklearn.manifold import TSNE

In [None]:
time_start = time.time()
tsne = TSNE(n_components=2, verbose=1, perplexity=40, n_iter=300)
tsne_results = tsne.fit_transform(data_subset)