In [1]:
import os
import argparse
import logging
import time
import numpy as np
import numpy.random as npr
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchdiffeq import odeint_adjoint as odeint
from torch.utils.data import DataLoader
os.environ['KMP_DUPLICATE_LIB_OK']='True'

In [2]:
samp_trajs_TE = torch.load('samp_trajs_TE_tau18k5_timestep500.pt')
samp_trajs_val_TE = torch.load('samp_trajs_val_TE_tau18k5_timestep500.pt')

tau = 18
k = 5
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

batch = 144

ts_num = 2.5
tot_num = 500

samp_ts = np.linspace(0, ts_num, num=tot_num)
samp_ts = torch.from_numpy(samp_ts).float().to(device)

orig_trajs_TE = np.load('orig_trajs_TE_tau18k6_timestep500.npy')
orig_trajs_TE = orig_trajs_TE.reshape(72, 1310, 6*(k+1))
samp_trajs_TE_test = orig_trajs_TE[:, :tot_num, :]
samp_trajs_TE_test = torch.from_numpy(samp_trajs_TE_test).float().to(device).reshape(72, tot_num, 6*(k+1))

#Load to Dataloader
train_loader = DataLoader(dataset = samp_trajs_TE, batch_size = batch, shuffle = True, drop_last = True)
val_loader = DataLoader(dataset = samp_trajs_TE, batch_size = 72, shuffle = True, drop_last = True)

In [4]:
class LatentODEfunc(nn.Module):

    def __init__(self, latent_dim=8, nhidden=50):
        super(LatentODEfunc, self).__init__()
        #self.tanh = nn.ELU(inplace= True)
        self.tanh = nn.Tanh()
        self.fc1 = nn.Linear(latent_dim, nhidden)
        self.fc2 = nn.Linear(nhidden, nhidden)
        self.fc3 = nn.Linear(nhidden, latent_dim)
        self.nfe = 0

    def forward(self, t, x):
        self.nfe += 1
        out = self.fc1(x)
        out = self.tanh(out)
        out = self.fc2(out)
        out = self.tanh(out)
        out = self.fc3(out)
        return out

class RecognitionRNN(nn.Module):

    def __init__(self, latent_dim=8, obs_dim=46, nhidden=50, nbatch=1):
        super(RecognitionRNN, self).__init__()
        self.nhidden = nhidden
        self.nbatch = nbatch
        #self.h1o = nn.Linear(obs_dim, latent_dim*4)
        #self.h3o = nn.Linear(latent_dim*4, latent_dim*2)
        #self.lstm = nn.LSTMCell(latent_dim*2, nhidden)
        self.h1o = nn.Linear(obs_dim, 12)
        self.h3o = nn.Linear(12, 6)
        self.lstm = nn.LSTMCell(6, nhidden)
        self.tanh = nn.Tanh()
        self.h2o = nn.Linear(nhidden, latent_dim*2)

    def forward(self, x, h, c):
        xo = self.h1o(x)
        xo = self.tanh(xo)
        xxo = self.h3o(xo)
        hn, cn = self.lstm(xxo, (h,c))
        hn = self.tanh(hn)
        out = self.h2o(hn)
        return out, hn, cn
    

    def initHidden(self):
        return torch.zeros(1, self.nbatch, self.nhidden)


class Decoder(nn.Module):

    def __init__(self, latent_dim=8, obs_dim=46, nhidden=50):
        super(Decoder, self).__init__()
        self.relu = nn.ReLU(inplace=True)
        self.tanh = nn.Tanh()
        self.fc1 = nn.Linear(latent_dim, nhidden)
        self.fc2 = nn.Linear(nhidden, nhidden*2)
        self.fc3 = nn.Linear(nhidden*2, obs_dim)

    def forward(self, z):
        out = self.fc1(z)
        out = self.tanh(out)
        out = self.fc2(out)
        out = self.tanh(out)
        out = self.fc3(out)
        return out


class RunningAverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, momentum=0.99):
        self.momentum = momentum
        self.reset()

    def reset(self):
        self.val = None
        self.avg = 0

    def update(self, val):
        if self.val is None:
            self.avg = val
        else:
            self.avg = self.avg * self.momentum + val * (1 - self.momentum)
        self.val = val


def log_normal_pdf(x, mean, logvar):
    const = torch.from_numpy(np.array([2. * np.pi])).float().to(x.device)
    const = torch.log(const)
    return -.5 * (const + logvar + (x - mean) ** 2. / torch.exp(logvar))

def mseloss(x, mean):
    loss = nn.MSELoss()
    return loss(x, mean)

def normal_kl(mu1, lv1, mu2, lv2):
    v1 = torch.exp(lv1)
    v2 = torch.exp(lv2)
    lstd1 = lv1 / 2.
    lstd2 = lv2 / 2.

    kl = lstd2 - lstd1 + ((v1 + (mu1 - mu2) ** 2.) / (2. * v2)) - .5
    return kl

def MSELoss(yhat, y):
    assert type(yhat) == torch.Tensor
    assert type(y) == torch.Tensor
    return torch.mean((yhat - y) ** 2)


def get_args():
    return {'latent_dim': latent_dim,
            'obs_dim': obs_dim,
            'nhidden': nhidden,
            'dec_nhidden' : dec_nhidden,
            'rnn_nhidden': rnn_nhidden,
            'device': device}

def get_state_dicts():
    return {'odefunc_state_dict': func.state_dict(),
            'encoder_state_dict': rec.state_dict(),
            'decoder_state_dict': dec.state_dict()}

def data_get_dict():
    return {
        'samp_trajs_TE': samp_trajs_TE,
        'samp_trajs_val_TE': samp_trajs_val_TE,
        'samp_ts': samp_ts,
    }

def get_losses():
    return {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'val_losses_k1': val_losses_k1,
        'val_losses_k2': val_losses_k2,
        'val_losses_k3': val_losses_k3,
        'val_losses_k4': val_losses_k4,
        'val_losses_k5': val_losses_k5
    }

def save_model(tau, k, latent_dim, itr):

    save_dict = {
        'model_args': get_args(),
        'optimizer_state_dict': optimizer.state_dict(),
        #'data': data_get_dict(),
        'train_loss': get_losses()
    }
    
    save_dict.update(get_state_dicts())
    
    torch.save(save_dict, 'model/ODE_TakenEmbedding_RLONG_rnn2_lstm{}_tau{}k{}_LSTM_lr0.008_latent{}_LSTMautoencoder_Dataloader_timestep{}_Trial2_epoch{}.pth'.format(rnn_nhidden, tau, k, latent_dim,tot_num, itr))

    
def data_for_plot_graph(gen_index):
    with torch.no_grad():
        # sample from trajectorys' approx. posterior

        ts_pos = np.linspace(0, ts_num*gen_index, num=tot_num*gen_index)
        ts_pos = torch.from_numpy(ts_pos).float().to(device)
    
        h = torch.zeros(1, samp_trajs_TE_test.shape[0], rnn_nhidden).to(device)
        c = torch.zeros(1, samp_trajs_TE_test.shape[0], rnn_nhidden).to(device)
    
        hn = h[0, :, :]
        cn = c[0, :, :]
    
        for t in reversed(range(samp_trajs_TE_test.size(1))):
            obs = samp_trajs_TE_test[:, t, :]
            out, hn, cn = rec.forward(obs, hn, cn)
        qz0_mean, qz0_logvar = out[:, :latent_dim], out[:, latent_dim:]
        epsilon = torch.randn(qz0_mean.size()).to(device)
        z0 = epsilon * torch.exp(.5 * qz0_logvar) + qz0_mean

        # forward in time and solve ode for reconstructions
        pred_z = odeint(func, z0, ts_pos).permute(1, 0, 2) #change time and batch with permute
        pred_x = dec(pred_z)
        
        return pred_x, pred_z
    
def plot_graph(gen_index, times_index, tot_index, dataset_value, deriv_index, pred_x_forgraph, orig_trajs, itr, path):
    with torch.no_grad():
        orig_trajs_forgraph = orig_trajs
        ts_pos_combined = np.linspace(0, ts_num*tot_index, num=tot_num*tot_index) 
        
        fig, axes = plt.subplots(nrows=6, ncols=1, figsize=(15, 9))
        axes = axes.flatten()
        
        for i, ax in enumerate(axes):
            ax.scatter(ts_pos_combined[times_index:times_index+tot_num*gen_index], orig_trajs_forgraph[dataset_value,times_index:tot_num*gen_index, i*(k+1)+deriv_index], s = 5)
            ax.plot(ts_pos_combined[times_index:times_index+tot_num*tot_index], pred_x_forgraph[dataset_value, times_index:times_index+tot_num*tot_index, i*(k+1)+deriv_index], 'r')
            ax.set_ylim(-3, 3)

        
        plot_name = 'lstm_datasetnum{}_latent{}_gen{}_deriv{}_epoch{}.png'.format(dataset_value, latent_dim, tot_index, deriv_index, itr)
        save_path = os.path.join(path, plot_name)
        plt.savefig(save_path, dpi=500)
        plt.close()
    

In [13]:
latent_dim = 12
nhidden = 64
dec_nhidden = 12
obs_dim = 6*(k+1)
rnn_nhidden = 256
nitrs = 8000
noise_std = 0.2

In [9]:
func = LatentODEfunc(latent_dim, nhidden).to(device)
rec = RecognitionRNN(latent_dim, obs_dim, rnn_nhidden, batch).to(device)
dec = Decoder(latent_dim, obs_dim, dec_nhidden).to(device)
params = (list(func.parameters()) + list(dec.parameters()) + list(rec.parameters()))
optimizer = optim.Adam(params, lr=0.008)
loss_meter = RunningAverageMeter()

train_losses = []
val_losses = []
val_losses_k1 = []
val_losses_k2 = []
val_losses_k3 = []
val_losses_k4 = []
val_losses_k5 = []
torch.cuda.empty_cache()

In [14]:
for itr in range(5001, nitrs+1):
    for data in train_loader:
        optimizer.zero_grad()
        h = rec.initHidden().to(device)
        c = rec.initHidden().to(device)
        hn = h[0, :, :]
        cn = c[0, :, :]
        for t in reversed(range(data.size(1))):
            obs = data[:, t, :]
            out, hn, cn = rec.forward(obs, hn, cn)
        qz0_mean, qz0_logvar = out[:, :latent_dim], out[:, latent_dim:]
        epsilon = torch.randn(qz0_mean.size()).to(device)
        z0 = epsilon * torch.exp(.5 * qz0_logvar) + qz0_mean   
        
        # forward in time and solve ode for reconstructions
        pred_z = odeint(func, z0, samp_ts).permute(1, 0, 2)
        pred_x = dec(pred_z)

        # compute loss
        loss = MSELoss(pred_x, data)
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())
        
    with torch.no_grad():
        for data_val in val_loader:
            h = torch.zeros(1, 72, rnn_nhidden).to(device)
            c = torch.zeros(1, 72, rnn_nhidden).to(device)
            hn = h[0, :, :]
            cn = c[0, :, :]
            
            for t in reversed(range(data_val.size(1))):
                obs = data_val[:, t, :]
                out, hn, cn = rec.forward(obs, hn, cn)
            qz0_mean, qz0_logvar = out[:, :latent_dim], out[:, latent_dim:]
            epsilon = torch.randn(qz0_mean.size()).to(device)
            z0 = epsilon * torch.exp(.5 * qz0_logvar) + qz0_mean

            #forward in time and solve ode for reconstructions
            pred_z = odeint(func, z0, samp_ts).permute(1, 0, 2)
            pred_x = dec(pred_z)
        
            #val_loss = MSELoss(pred_x, samp_trajs_val_TE) + torch.mean(-0.5 * torch.sum(1 + qz0_logvar - qz0_mean**2 - torch.exp(qz0_logvar), dim = -1)/(31*(k+1)+2))
            val_loss = MSELoss(pred_x[:,:,::(k+1)], data_val[:,:,::(k+1)])
            val_loss_k1 = MSELoss(pred_x[:,:,1::(k+1)], data_val[:,:,1::(k+1)])
            val_loss_k2 = MSELoss(pred_x[:,:,2::(k+1)], data_val[:,:,2::(k+1)])
            val_loss_k3 = MSELoss(pred_x[:,:,3::(k+1)], data_val[:,:,3::(k+1)])
            val_loss_k4 = MSELoss(pred_x[:,:,4::(k+1)], data_val[:,:,4::(k+1)])
            val_loss_k5 = MSELoss(pred_x[:,:,5::(k+1)], data_val[:,:,5::(k+1)])
            val_losses.append(val_loss)
            val_losses_k1.append(val_loss_k1)
            val_losses_k2.append(val_loss_k2)
            val_losses_k3.append(val_loss_k3)
            val_losses_k4.append(val_loss_k4)
            val_losses_k5.append(val_loss_k5)
            V = [val_loss, val_loss_k1, val_loss_k2, val_loss_k3, val_loss_k4, val_loss_k5]
            lowest_val_loss = torch.asarray(V).min(0)[0]
            deriv_index = torch.asarray(V).min(0)[1]

    if ((itr > 1000) and (itr % 15 == 0)):
        save_model(tau, k, latent_dim, itr)
        gen_index = 2
        times_index = 0
        tot_index = 10
        deriv_index = deriv_index.numpy()
        
        orig_trajs = orig_trajs_TE[:, 0:1300, :]

        pred_x, pred_z = data_for_plot_graph(tot_index)
        pred_x = pred_x.reshape(72, tot_num*tot_index, 6*(k+1))
        pred_z = pred_z.reshape(72, tot_num*tot_index, latent_dim)
        pred_x_forgraph = pred_x.detach().cpu().numpy()
        pred_z_forgraph = pred_z.detach().cpu().numpy()

        path = "Results_pic/tau{}k{}/longtimeseries/latent{}/Trial2_RLONG_data_loader_rnn2layer_lstm{}_lr0.008_timestep{}/epoch{}".format(tau, k, latent_dim, rnn_nhidden, tot_num, itr)

        if not os.path.exists(path):
           os.makedirs(path)

        
        tot_index = 2
        
        for i in range(72):
            plot_graph(2, times_index, tot_index, i, deriv_index, pred_x_forgraph, orig_trajs, itr, path)
            
        tot_index = 10
        
        for i in range(72):
            plot_graph(2, times_index, tot_index, i, deriv_index, pred_x_forgraph, orig_trajs, itr, path)


        
    print('Iter: {}, running avg mse: {:.4f} lowest val mse: {:.4f} at k {}'.format(itr, loss, lowest_val_loss, deriv_index))

Iter: 5001, running avg mse: 0.0880 lowest val mse: 0.0717 at k 3
Iter: 5002, running avg mse: 0.1071 lowest val mse: 0.0934 at k 3
Iter: 5003, running avg mse: 0.0893 lowest val mse: 0.0945 at k 2
Iter: 5004, running avg mse: 0.0943 lowest val mse: 0.1078 at k 3
Iter: 5005, running avg mse: 0.0883 lowest val mse: 0.0947 at k 3
Iter: 5006, running avg mse: 0.0998 lowest val mse: 0.0896 at k 3
Iter: 5007, running avg mse: 0.1154 lowest val mse: 0.0943 at k 4
Iter: 5008, running avg mse: 0.1108 lowest val mse: 0.1476 at k 3
Iter: 5009, running avg mse: 0.1265 lowest val mse: 0.0994 at k 3
Iter: 5010, running avg mse: 0.2003 lowest val mse: 0.3094 at k 3
Iter: 5011, running avg mse: 0.3873 lowest val mse: 0.8446 at k 4
Iter: 5012, running avg mse: 0.7296 lowest val mse: 0.7800 at k 1
Iter: 5013, running avg mse: 0.6595 lowest val mse: 0.6276 at k 3
Iter: 5014, running avg mse: 0.4279 lowest val mse: 0.3853 at k 3
Iter: 5015, running avg mse: 0.3488 lowest val mse: 0.3905 at k 0
Iter: 5016

Iter: 5126, running avg mse: 0.0833 lowest val mse: 0.0749 at k 3
Iter: 5127, running avg mse: 0.0839 lowest val mse: 0.0733 at k 3
Iter: 5128, running avg mse: 0.0862 lowest val mse: 0.0725 at k 3
Iter: 5129, running avg mse: 0.0844 lowest val mse: 0.0753 at k 3
Iter: 5130, running avg mse: 0.0817 lowest val mse: 0.0751 at k 3
Iter: 5131, running avg mse: 0.0855 lowest val mse: 0.0716 at k 3
Iter: 5132, running avg mse: 0.0845 lowest val mse: 0.0738 at k 3
Iter: 5133, running avg mse: 0.0806 lowest val mse: 0.0714 at k 3
Iter: 5134, running avg mse: 0.0820 lowest val mse: 0.0785 at k 3
Iter: 5135, running avg mse: 0.0860 lowest val mse: 0.0748 at k 3
Iter: 5136, running avg mse: 0.0837 lowest val mse: 0.0799 at k 3
Iter: 5137, running avg mse: 0.0793 lowest val mse: 0.0753 at k 3
Iter: 5138, running avg mse: 0.0784 lowest val mse: 0.0675 at k 3
Iter: 5139, running avg mse: 0.0830 lowest val mse: 0.0678 at k 3
Iter: 5140, running avg mse: 0.0814 lowest val mse: 0.0763 at k 3
Iter: 5141

Iter: 5251, running avg mse: 0.0847 lowest val mse: 0.0808 at k 3
Iter: 5252, running avg mse: 0.0857 lowest val mse: 0.0796 at k 3
Iter: 5253, running avg mse: 0.0901 lowest val mse: 0.0769 at k 3
Iter: 5254, running avg mse: 0.0826 lowest val mse: 0.0851 at k 3
Iter: 5255, running avg mse: 0.0839 lowest val mse: 0.0771 at k 3
Iter: 5256, running avg mse: 0.0919 lowest val mse: 0.0743 at k 3
Iter: 5257, running avg mse: 0.0937 lowest val mse: 0.0850 at k 3
Iter: 5258, running avg mse: 0.0904 lowest val mse: 0.0759 at k 3
Iter: 5259, running avg mse: 0.0856 lowest val mse: 0.0751 at k 3
Iter: 5260, running avg mse: 0.1044 lowest val mse: 0.0789 at k 3
Iter: 5261, running avg mse: 0.0900 lowest val mse: 0.0827 at k 3
Iter: 5262, running avg mse: 0.0894 lowest val mse: 0.0773 at k 2
Iter: 5263, running avg mse: 0.0834 lowest val mse: 0.0761 at k 3
Iter: 5264, running avg mse: 0.0896 lowest val mse: 0.0762 at k 3
Iter: 5265, running avg mse: 0.0886 lowest val mse: 0.0781 at k 3
Iter: 5266

Iter: 5376, running avg mse: 0.0900 lowest val mse: 0.0817 at k 3
Iter: 5377, running avg mse: 0.0981 lowest val mse: 0.0853 at k 3
Iter: 5378, running avg mse: 0.0925 lowest val mse: 0.0763 at k 3
Iter: 5379, running avg mse: 0.0902 lowest val mse: 0.0749 at k 3
Iter: 5380, running avg mse: 0.0918 lowest val mse: 0.0844 at k 3
Iter: 5381, running avg mse: 0.0910 lowest val mse: 0.0823 at k 3
Iter: 5382, running avg mse: 0.0917 lowest val mse: 0.0804 at k 3
Iter: 5383, running avg mse: 0.0965 lowest val mse: 0.0877 at k 4
Iter: 5384, running avg mse: 0.0860 lowest val mse: 0.0750 at k 3
Iter: 5385, running avg mse: 0.0856 lowest val mse: 0.0832 at k 3
Iter: 5386, running avg mse: 0.0876 lowest val mse: 0.0741 at k 3
Iter: 5387, running avg mse: 0.0896 lowest val mse: 0.0879 at k 3
Iter: 5388, running avg mse: 0.0838 lowest val mse: 0.0796 at k 3
Iter: 5389, running avg mse: 0.0817 lowest val mse: 0.0795 at k 3
Iter: 5390, running avg mse: 0.0865 lowest val mse: 0.0818 at k 3
Iter: 5391

Iter: 5501, running avg mse: 0.0869 lowest val mse: 0.0851 at k 3
Iter: 5502, running avg mse: 0.0829 lowest val mse: 0.0816 at k 3
Iter: 5503, running avg mse: 0.0800 lowest val mse: 0.0821 at k 3
Iter: 5504, running avg mse: 0.0876 lowest val mse: 0.0768 at k 2
Iter: 5505, running avg mse: 0.0901 lowest val mse: 0.0802 at k 3
Iter: 5506, running avg mse: 0.0939 lowest val mse: 0.0836 at k 3
Iter: 5507, running avg mse: 0.0853 lowest val mse: 0.0716 at k 3
Iter: 5508, running avg mse: 0.0883 lowest val mse: 0.0797 at k 3
Iter: 5509, running avg mse: 0.0862 lowest val mse: 0.0806 at k 3
Iter: 5510, running avg mse: 0.0810 lowest val mse: 0.0827 at k 3
Iter: 5511, running avg mse: 0.0868 lowest val mse: 0.0762 at k 3
Iter: 5512, running avg mse: 0.0910 lowest val mse: 0.0765 at k 3
Iter: 5513, running avg mse: 0.0871 lowest val mse: 0.0850 at k 3
Iter: 5514, running avg mse: 0.0865 lowest val mse: 0.0733 at k 3
Iter: 5515, running avg mse: 0.0911 lowest val mse: 0.0815 at k 3
Iter: 5516

Iter: 5626, running avg mse: 0.0802 lowest val mse: 0.0736 at k 3
Iter: 5627, running avg mse: 0.0892 lowest val mse: 0.0799 at k 3
Iter: 5628, running avg mse: 0.0876 lowest val mse: 0.0793 at k 3
Iter: 5629, running avg mse: 0.0827 lowest val mse: 0.0782 at k 3
Iter: 5630, running avg mse: 0.0841 lowest val mse: 0.0747 at k 3
Iter: 5631, running avg mse: 0.0842 lowest val mse: 0.0741 at k 2
Iter: 5632, running avg mse: 0.0825 lowest val mse: 0.0779 at k 3
Iter: 5633, running avg mse: 0.0820 lowest val mse: 0.0725 at k 3
Iter: 5634, running avg mse: 0.0845 lowest val mse: 0.0810 at k 3
Iter: 5635, running avg mse: 0.0866 lowest val mse: 0.0766 at k 3
Iter: 5636, running avg mse: 0.0847 lowest val mse: 0.0753 at k 3
Iter: 5637, running avg mse: 0.0847 lowest val mse: 0.0841 at k 3
Iter: 5638, running avg mse: 0.0857 lowest val mse: 0.0791 at k 3
Iter: 5639, running avg mse: 0.1052 lowest val mse: 0.0983 at k 3
Iter: 5640, running avg mse: 0.0968 lowest val mse: 0.0969 at k 3
Iter: 5641

Iter: 5751, running avg mse: 0.0842 lowest val mse: 0.0743 at k 3
Iter: 5752, running avg mse: 0.0849 lowest val mse: 0.0706 at k 3
Iter: 5753, running avg mse: 0.0798 lowest val mse: 0.0731 at k 3
Iter: 5754, running avg mse: 0.0844 lowest val mse: 0.0750 at k 3
Iter: 5755, running avg mse: 0.0809 lowest val mse: 0.0748 at k 3
Iter: 5756, running avg mse: 0.0804 lowest val mse: 0.0737 at k 3
Iter: 5757, running avg mse: 0.0834 lowest val mse: 0.0727 at k 2
Iter: 5758, running avg mse: 0.0842 lowest val mse: 0.0775 at k 3
Iter: 5759, running avg mse: 0.0849 lowest val mse: 0.0724 at k 3
Iter: 5760, running avg mse: 0.0822 lowest val mse: 0.0739 at k 3
Iter: 5761, running avg mse: 0.0806 lowest val mse: 0.0691 at k 3
Iter: 5762, running avg mse: 0.0843 lowest val mse: 0.0746 at k 3
Iter: 5763, running avg mse: 0.0828 lowest val mse: 0.0751 at k 3
Iter: 5764, running avg mse: 0.0803 lowest val mse: 0.0763 at k 3
Iter: 5765, running avg mse: 0.0811 lowest val mse: 0.0767 at k 3
Iter: 5766

Iter: 5876, running avg mse: 0.0899 lowest val mse: 0.0720 at k 3
Iter: 5877, running avg mse: 0.0871 lowest val mse: 0.0801 at k 3
Iter: 5878, running avg mse: 0.0879 lowest val mse: 0.0784 at k 2
Iter: 5879, running avg mse: 0.0849 lowest val mse: 0.0790 at k 3
Iter: 5880, running avg mse: 0.0911 lowest val mse: 0.0796 at k 3
Iter: 5881, running avg mse: 0.0843 lowest val mse: 0.0716 at k 3
Iter: 5882, running avg mse: 0.0859 lowest val mse: 0.0763 at k 3
Iter: 5883, running avg mse: 0.0874 lowest val mse: 0.0889 at k 3
Iter: 5884, running avg mse: 0.0893 lowest val mse: 0.0906 at k 2
Iter: 5885, running avg mse: 0.0825 lowest val mse: 0.0753 at k 2
Iter: 5886, running avg mse: 0.0912 lowest val mse: 0.0804 at k 3
Iter: 5887, running avg mse: 0.0942 lowest val mse: 0.0857 at k 3
Iter: 5888, running avg mse: 0.0979 lowest val mse: 0.0816 at k 3
Iter: 5889, running avg mse: 0.1012 lowest val mse: 0.0955 at k 4
Iter: 5890, running avg mse: 0.0887 lowest val mse: 0.0961 at k 3
Iter: 5891

Iter: 6001, running avg mse: 0.1421 lowest val mse: 0.1239 at k 4
Iter: 6002, running avg mse: 0.1250 lowest val mse: 0.0985 at k 3
Iter: 6003, running avg mse: 0.1037 lowest val mse: 0.0968 at k 2
Iter: 6004, running avg mse: 0.1116 lowest val mse: 0.1067 at k 4
Iter: 6005, running avg mse: 0.1101 lowest val mse: 0.1237 at k 3
Iter: 6006, running avg mse: 0.1281 lowest val mse: 0.1116 at k 4
Iter: 6007, running avg mse: 0.1182 lowest val mse: 0.1111 at k 3
Iter: 6008, running avg mse: 0.1107 lowest val mse: 0.1051 at k 3
Iter: 6009, running avg mse: 0.1026 lowest val mse: 0.0870 at k 3
Iter: 6010, running avg mse: 0.1031 lowest val mse: 0.0884 at k 3
Iter: 6011, running avg mse: 0.0993 lowest val mse: 0.0881 at k 3
Iter: 6012, running avg mse: 0.0926 lowest val mse: 0.0831 at k 3
Iter: 6013, running avg mse: 0.0991 lowest val mse: 0.0868 at k 3
Iter: 6014, running avg mse: 0.0920 lowest val mse: 0.0814 at k 3
Iter: 6015, running avg mse: 0.0865 lowest val mse: 0.0847 at k 3
Iter: 6016

Iter: 6126, running avg mse: 0.0812 lowest val mse: 0.0769 at k 3
Iter: 6127, running avg mse: 0.0813 lowest val mse: 0.0737 at k 3
Iter: 6128, running avg mse: 0.0786 lowest val mse: 0.0744 at k 3
Iter: 6129, running avg mse: 0.0817 lowest val mse: 0.0746 at k 3
Iter: 6130, running avg mse: 0.0794 lowest val mse: 0.0728 at k 3
Iter: 6131, running avg mse: 0.0814 lowest val mse: 0.0786 at k 3
Iter: 6132, running avg mse: 0.0799 lowest val mse: 0.0748 at k 3
Iter: 6133, running avg mse: 0.0795 lowest val mse: 0.0725 at k 3
Iter: 6134, running avg mse: 0.0821 lowest val mse: 0.0775 at k 3
Iter: 6135, running avg mse: 0.0824 lowest val mse: 0.0735 at k 3
Iter: 6136, running avg mse: 0.0831 lowest val mse: 0.0785 at k 3
Iter: 6137, running avg mse: 0.0792 lowest val mse: 0.0709 at k 3
Iter: 6138, running avg mse: 0.0814 lowest val mse: 0.0676 at k 3
Iter: 6139, running avg mse: 0.0798 lowest val mse: 0.0712 at k 3
Iter: 6140, running avg mse: 0.0781 lowest val mse: 0.0723 at k 3
Iter: 6141

Iter: 6251, running avg mse: 0.1002 lowest val mse: 0.0891 at k 3
Iter: 6252, running avg mse: 0.1041 lowest val mse: 0.0898 at k 2
Iter: 6253, running avg mse: 0.0995 lowest val mse: 0.0941 at k 4
Iter: 6254, running avg mse: 0.0937 lowest val mse: 0.0922 at k 3
Iter: 6255, running avg mse: 0.1011 lowest val mse: 0.0971 at k 4
Iter: 6256, running avg mse: 0.1221 lowest val mse: 0.0956 at k 3
Iter: 6257, running avg mse: 0.1089 lowest val mse: 0.0966 at k 3
Iter: 6258, running avg mse: 0.0981 lowest val mse: 0.0842 at k 3
Iter: 6259, running avg mse: 0.0908 lowest val mse: 0.0838 at k 3
Iter: 6260, running avg mse: 0.0915 lowest val mse: 0.0829 at k 4
Iter: 6261, running avg mse: 0.0931 lowest val mse: 0.0962 at k 3
Iter: 6262, running avg mse: 0.0915 lowest val mse: 0.1059 at k 3
Iter: 6263, running avg mse: 0.0940 lowest val mse: 0.0942 at k 3
Iter: 6264, running avg mse: 0.1112 lowest val mse: 0.1133 at k 3
Iter: 6265, running avg mse: 0.1067 lowest val mse: 0.0917 at k 3
Iter: 6266

Iter: 6376, running avg mse: 0.0880 lowest val mse: 0.0845 at k 3
Iter: 6377, running avg mse: 0.0933 lowest val mse: 0.0779 at k 3
Iter: 6378, running avg mse: 0.0803 lowest val mse: 0.0822 at k 3
Iter: 6379, running avg mse: 0.0870 lowest val mse: 0.0829 at k 3
Iter: 6380, running avg mse: 0.0845 lowest val mse: 0.0776 at k 4
Iter: 6381, running avg mse: 0.0832 lowest val mse: 0.0743 at k 2
Iter: 6382, running avg mse: 0.0841 lowest val mse: 0.0761 at k 3
Iter: 6383, running avg mse: 0.0871 lowest val mse: 0.0801 at k 3
Iter: 6384, running avg mse: 0.0845 lowest val mse: 0.0802 at k 3
Iter: 6385, running avg mse: 0.0852 lowest val mse: 0.0784 at k 3
Iter: 6386, running avg mse: 0.0835 lowest val mse: 0.0843 at k 3
Iter: 6387, running avg mse: 0.0945 lowest val mse: 0.0796 at k 4
Iter: 6388, running avg mse: 0.0848 lowest val mse: 0.0723 at k 3
Iter: 6389, running avg mse: 0.0814 lowest val mse: 0.0765 at k 4
Iter: 6390, running avg mse: 0.0846 lowest val mse: 0.0766 at k 3
Iter: 6391

Iter: 6501, running avg mse: 0.0751 lowest val mse: 0.0731 at k 3
Iter: 6502, running avg mse: 0.0828 lowest val mse: 0.0772 at k 3
Iter: 6503, running avg mse: 0.0791 lowest val mse: 0.0782 at k 3
Iter: 6504, running avg mse: 0.0771 lowest val mse: 0.0781 at k 3
Iter: 6505, running avg mse: 0.0762 lowest val mse: 0.0725 at k 3
Iter: 6506, running avg mse: 0.0767 lowest val mse: 0.0745 at k 3
Iter: 6507, running avg mse: 0.0791 lowest val mse: 0.0683 at k 3
Iter: 6508, running avg mse: 0.0814 lowest val mse: 0.0682 at k 3
Iter: 6509, running avg mse: 0.0810 lowest val mse: 0.0756 at k 3
Iter: 6510, running avg mse: 0.0812 lowest val mse: 0.0743 at k 3
Iter: 6511, running avg mse: 0.0824 lowest val mse: 0.0694 at k 3
Iter: 6512, running avg mse: 0.0801 lowest val mse: 0.0749 at k 3
Iter: 6513, running avg mse: 0.0787 lowest val mse: 0.0738 at k 3
Iter: 6514, running avg mse: 0.0805 lowest val mse: 0.0729 at k 3
Iter: 6515, running avg mse: 0.0811 lowest val mse: 0.0733 at k 3
Iter: 6516

Iter: 6626, running avg mse: 0.1503 lowest val mse: 0.1538 at k 4
Iter: 6627, running avg mse: 0.1752 lowest val mse: 0.1508 at k 3
Iter: 6628, running avg mse: 0.1552 lowest val mse: 0.1342 at k 3
Iter: 6629, running avg mse: 0.1418 lowest val mse: 0.1543 at k 4
Iter: 6630, running avg mse: 0.1594 lowest val mse: 0.1206 at k 3
Iter: 6631, running avg mse: 0.1420 lowest val mse: 0.1420 at k 2
Iter: 6632, running avg mse: 0.1732 lowest val mse: 0.1439 at k 4
Iter: 6633, running avg mse: 0.1618 lowest val mse: 0.1274 at k 3
Iter: 6634, running avg mse: 0.1383 lowest val mse: 0.1401 at k 3
Iter: 6635, running avg mse: 0.1534 lowest val mse: 0.1278 at k 3
Iter: 6636, running avg mse: 0.1484 lowest val mse: 0.1447 at k 3
Iter: 6637, running avg mse: 0.1420 lowest val mse: 0.1381 at k 3
Iter: 6638, running avg mse: 0.1858 lowest val mse: 0.1344 at k 4
Iter: 6639, running avg mse: 0.1373 lowest val mse: 0.1548 at k 4
Iter: 6640, running avg mse: 0.1515 lowest val mse: 0.1350 at k 3
Iter: 6641

Iter: 6751, running avg mse: 0.0928 lowest val mse: 0.0841 at k 3
Iter: 6752, running avg mse: 0.0883 lowest val mse: 0.0826 at k 3
Iter: 6753, running avg mse: 0.0905 lowest val mse: 0.0901 at k 3
Iter: 6754, running avg mse: 0.0886 lowest val mse: 0.0776 at k 2
Iter: 6755, running avg mse: 0.0874 lowest val mse: 0.0793 at k 3
Iter: 6756, running avg mse: 0.0871 lowest val mse: 0.0798 at k 2
Iter: 6757, running avg mse: 0.0875 lowest val mse: 0.0782 at k 3
Iter: 6758, running avg mse: 0.0901 lowest val mse: 0.0819 at k 3
Iter: 6759, running avg mse: 0.0902 lowest val mse: 0.0864 at k 3
Iter: 6760, running avg mse: 0.0873 lowest val mse: 0.0787 at k 3
Iter: 6761, running avg mse: 0.0916 lowest val mse: 0.0875 at k 3
Iter: 6762, running avg mse: 0.0841 lowest val mse: 0.0759 at k 3
Iter: 6763, running avg mse: 0.0865 lowest val mse: 0.0832 at k 3
Iter: 6764, running avg mse: 0.0897 lowest val mse: 0.0798 at k 3
Iter: 6765, running avg mse: 0.0898 lowest val mse: 0.0837 at k 3
Iter: 6766

Iter: 6876, running avg mse: 0.0908 lowest val mse: 0.0814 at k 3
Iter: 6877, running avg mse: 0.0874 lowest val mse: 0.0802 at k 3
Iter: 6878, running avg mse: 0.0841 lowest val mse: 0.0788 at k 2
Iter: 6879, running avg mse: 0.0882 lowest val mse: 0.0793 at k 3
Iter: 6880, running avg mse: 0.0915 lowest val mse: 0.0773 at k 3
Iter: 6881, running avg mse: 0.0906 lowest val mse: 0.0778 at k 2
Iter: 6882, running avg mse: 0.0841 lowest val mse: 0.0753 at k 3
Iter: 6883, running avg mse: 0.0877 lowest val mse: 0.0777 at k 3
Iter: 6884, running avg mse: 0.0838 lowest val mse: 0.0792 at k 3
Iter: 6885, running avg mse: 0.0853 lowest val mse: 0.0779 at k 3
Iter: 6886, running avg mse: 0.0798 lowest val mse: 0.0752 at k 3
Iter: 6887, running avg mse: 0.0846 lowest val mse: 0.0769 at k 3
Iter: 6888, running avg mse: 0.0852 lowest val mse: 0.0714 at k 4
Iter: 6889, running avg mse: 0.0849 lowest val mse: 0.0800 at k 3
Iter: 6890, running avg mse: 0.0846 lowest val mse: 0.0778 at k 3
Iter: 6891

Iter: 7001, running avg mse: 0.1726 lowest val mse: 0.1647 at k 4
Iter: 7002, running avg mse: 0.1525 lowest val mse: 0.1323 at k 3
Iter: 7003, running avg mse: 0.1393 lowest val mse: 0.1187 at k 2
Iter: 7004, running avg mse: 0.1320 lowest val mse: 0.1254 at k 4
Iter: 7005, running avg mse: 0.1314 lowest val mse: 0.1146 at k 3
Iter: 7006, running avg mse: 0.1339 lowest val mse: 0.1095 at k 3
Iter: 7007, running avg mse: 0.1170 lowest val mse: 0.1266 at k 3
Iter: 7008, running avg mse: 0.1207 lowest val mse: 0.1130 at k 3
Iter: 7009, running avg mse: 0.1274 lowest val mse: 0.1321 at k 3
Iter: 7010, running avg mse: 0.1281 lowest val mse: 0.1275 at k 2
Iter: 7011, running avg mse: 0.1228 lowest val mse: 0.1147 at k 4
Iter: 7012, running avg mse: 0.1204 lowest val mse: 0.1345 at k 3
Iter: 7013, running avg mse: 0.1274 lowest val mse: 0.1176 at k 3
Iter: 7014, running avg mse: 0.1302 lowest val mse: 0.1166 at k 4
Iter: 7015, running avg mse: 0.1270 lowest val mse: 0.1382 at k 3
Iter: 7016

Iter: 7126, running avg mse: 0.0907 lowest val mse: 0.0808 at k 2
Iter: 7127, running avg mse: 0.0873 lowest val mse: 0.0832 at k 3
Iter: 7128, running avg mse: 0.0881 lowest val mse: 0.0855 at k 3
Iter: 7129, running avg mse: 0.0913 lowest val mse: 0.0782 at k 3
Iter: 7130, running avg mse: 0.0934 lowest val mse: 0.0818 at k 2
Iter: 7131, running avg mse: 0.0867 lowest val mse: 0.0763 at k 3
Iter: 7132, running avg mse: 0.0938 lowest val mse: 0.0835 at k 2
Iter: 7133, running avg mse: 0.0899 lowest val mse: 0.0879 at k 4
Iter: 7134, running avg mse: 0.0939 lowest val mse: 0.0845 at k 3
Iter: 7135, running avg mse: 0.0903 lowest val mse: 0.0822 at k 3
Iter: 7136, running avg mse: 0.0911 lowest val mse: 0.0789 at k 3
Iter: 7137, running avg mse: 0.0878 lowest val mse: 0.0816 at k 3
Iter: 7138, running avg mse: 0.0908 lowest val mse: 0.0833 at k 3
Iter: 7139, running avg mse: 0.0912 lowest val mse: 0.0844 at k 3
Iter: 7140, running avg mse: 0.0910 lowest val mse: 0.0830 at k 3
Iter: 7141

KeyboardInterrupt: 

In [25]:
save_model(tau, k, latent_dim, itr)

In [6]:
checkpoint = torch.load('model/ODE_TakenEmbedding_RLONG_rnn2_lstm256_tau18k5_LSTM_lr0.008_latent12_LSTMautoencoder_Dataloader_timestep500_epoch1410.pth')
rec.load_state_dict(checkpoint['encoder_state_dict'])
func.load_state_dict(checkpoint['odefunc_state_dict'])
dec.load_state_dict(checkpoint['decoder_state_dict'])

<All keys matched successfully>

In [36]:
        save_model(tau, k, latent_dim, itr)
        gen_index = 3
        times_index = 0
        tot_index = 15
        deriv_index = deriv_index.numpy()
        
        orig_trajs = orig_trajs_TE[:, 0:0+150*6, :]

        pred_x, pred_z = data_for_plot_graph(tot_index)
        pred_x = pred_x.reshape(72, 300*tot_index, 6*(k+1))
        pred_z = pred_z.reshape(72, 300*tot_index, latent_dim)
        pred_x_forgraph = pred_x.detach().cpu().numpy()
        pred_z_forgraph = pred_z.detach().cpu().numpy()

        path = "Results_pic/tau{}k{}/longtimeseries/latent{}/RLONG_data_loader_rnn2layer_lstm{}_lr0.008/epoch{}".format(tau, k, latent_dim, rnn_nhidden, itr)

        if not os.path.exists(path):
           os.makedirs(path)

        
        plotgraph_index = [0, 1, 2, 3, 4, 5, 18, 19, 20, 21, 22, 23, 48, 49, 50, 51, 52, 53,  60, 61, 62, 63, 64, 65]
        
        tot_index = 3
        
        for i in range(72):
            plot_graph(3, times_index, tot_index, i, deriv_index, pred_x_forgraph, orig_trajs, itr, path)
            
        tot_index = 15
        
        for i in range(72):
            plot_graph(3, times_index, tot_index, i, deriv_index, pred_x_forgraph, orig_trajs, itr, path)


In [None]:
        gen_index = 50
        times_index = 0
        deriv_index = deriv_index.numpy()
        
        orig_trajs = orig_trajs_TE[:, 0:0+50*gen_index, :]

        pred_x, pred_z = data_for_plot_graph(gen_index)
        pred_x = pred_x.reshape(203, 50*gen_index, 31*(k+1)+2)
        pred_z = pred_z.reshape(203, 50*gen_index, latent_dim)
        pred_x_forgraph = pred_x.detach().cpu().numpy()
        pred_z_forgraph = pred_z.detach().cpu().numpy()

        path = "Results_pic/tau{}k{}/longtimeseries/latent{}/data_loader_rnn2layer_lstm{}_lr0.008/epoch{}".format(tau, k, latent_dim, rnn_nhidden, itr)

        if not os.path.exists(path):
           os.makedirs(path)

        gen_index = 15
        plotgraph_index = [0, 4, 20, 24, 28, 42, 43, 44, 45, 46, 47, 65]
        
        for i in range(len(plotgraph_index)):
            plot_graph(gen_index, times_index, plotgraph_index[i], deriv_index, pred_x_forgraph, orig_trajs, itr, path)
        
        gen_index = 30
        
        for i in range(len(plotgraph_index)):
            plot_graph(gen_index, times_index, plotgraph_index[i], deriv_index, pred_x_forgraph, orig_trajs, itr, path)
            
        gen_index = 50
        
        for i in range(len(plotgraph_index)):
            plot_graph(gen_index, times_index, plotgraph_index[i], deriv_index, pred_x_forgraph, orig_trajs, itr, path)

In [32]:
deriv_index

0

In [28]:
V

[array(0.35570502, dtype=float32),
 array(0.34477258, dtype=float32),
 array(0.33471537, dtype=float32),
 array(0.3282059, dtype=float32)]

In [9]:
train_loss = np.array(train_losses)

plt.plot(train_loss, 'r')
plt.savefig('C:/Users/shiny/Documents/NeuralODE_RatTreadMill/Results_pic/TrainingLossGraph/tau6k3/trainloss_0.005_nonoise_latent8_lstm256.png')
np.save('C:/Users/shiny/Documents/NeuralODE_RatTreadMill/Results_pic/TrainingLossGraph/tau6k3/trainloss_0.005_nonoise_latent8_lstm256.npy', train_loss)

In [117]:
folder = 'runs/model_'

def save_model():
    folder = 'runs/model'
    folder = os.path.join(folder, 'ckpt')

    ckpt_path = os.path.join(folder, f'ODE_normalized_4_128_2tanh.pth')

    save_dict = {
        'model_args': get_args(),
        'optimizer_state_dict': optimizer.state_dict(),
        'data': data_get_dict(),
        'train_loss': get_losses()
    }
    
    save_dict.update(get_state_dicts())
    
    torch.save(save_dict, 'C:/Users/shiny/Documents/NeuralODE_RatTreadMill/model/All_rodent_ODE_TakenEmbedding_tau6k3_LSTM_lr0.008_latent8_LSTMautoencoder_epoch500.pth')

save_model()

In [70]:
checkpoint = torch.load('model/All_rodent_ODE_TakenEmbedding_tau6k3_LSTM_lr0.008_latent8_LSTMautoencoder_epoch380.pth')
rec.load_state_dict(checkpoint['encoder_state_dict'])
func.load_state_dict(checkpoint['odefunc_state_dict'])
dec.load_state_dict(checkpoint['decoder_state_dict'])

<All keys matched successfully>

## Long time series generation

In [71]:
gen_index = 20
times_index = 0
deriv_index = 0
itr= 380
orig_trajs_TE = np.load('orig_trajs_TE_tau6k3.npy')
orig_trajs_TE = orig_trajs_TE.reshape(203, 200*34, 31*(k+1)+2)
samp_trajs_TE_test = orig_trajs_TE[:, :50, :]

samp_trajs_TE_test = torch.from_numpy(samp_trajs_TE_test).float().to(device).reshape(203, 50, 31*(k+1)+2)
orig_trajs = orig_trajs_TE[:, 0:0+50*gen_index, :]

pred_x = data_for_plot_graph(gen_index)

path = "Results_pic/tau6k3/longtimeseries/epoch{}".format(itr)

if not os.path.exists(path):
   os.makedirs(path)

plot_graph(gen_index, times_index, 0, deriv_index, pred_x, orig_trajs, itr, path)
plot_graph(gen_index, times_index, 4, deriv_index, pred_x, orig_trajs, itr, path)
plot_graph(gen_index, times_index, 20, deriv_index, pred_x, orig_trajs, itr, path)

In [35]:
with torch.no_grad():
    # sample from trajectorys' approx. posterior

    ts_pos = np.linspace(0, 0.25*gen_index, num=50*gen_index)
    ts_pos = torch.from_numpy(ts_pos).float().to(device)
    
    h = torch.zeros(1, samp_trajs_TE.shape[0], rnn_nhidden).to(device)
    c = torch.zeros(1, samp_trajs_TE.shape[0], rnn_nhidden).to(device)
    
    hn = h[0, :, :]
    cn = c[0, :, :]
    
    for t in reversed(range(samp_trajs_TE.size(1))):
        obs = samp_trajs_TE[:, t, :]
        out, hn, cn = rec.forward(obs, hn, cn)
    qz0_mean, qz0_logvar = out[:, :latent_dim], out[:, latent_dim:]
    epsilon = torch.randn(qz0_mean.size()).to(device)
    z0 = epsilon * torch.exp(.5 * qz0_logvar) + qz0_mean

    # forward in time and solve ode for reconstructions
    pred_z = odeint(func, z0, ts_pos).permute(1, 0, 2) #change time and batch with permute
    pred_x = dec(pred_z)

In [148]:
with torch.no_grad():
    samp_ts_forgraph = samp_ts.detach().cpu().numpy()
    pred_x = pred_x.reshape(203, 50*gen_index, 28*(k+1)+2)
    pred_x_forgraph = pred_x.detach().cpu().numpy()
    orig_trajs_forgraph = orig_trajs
    ts_pos_forgraph = ts_pos.detach().cpu().numpy()
    ts_pos_combined = np.linspace(0, 0.25*gen_index, num=50*gen_index) 
    
    times_index = 0
    dataset_value = 0
    deriv_index = 0
    fig, axes = plt.subplots(nrows=6, ncols=5, figsize=(15, 9))
    axes = axes.flatten()
    for i, ax in enumerate(axes):
        ax.scatter(ts_pos_combined[times_index:50*gen_index], orig_trajs_forgraph[dataset_value,0:50*gen_index, i], label='sampled data', s = 5)
        ax.plot(ts_pos_combined[times_index:+50*gen_index], pred_x_forgraph[dataset_value, times_index:+50*gen_index, i*(k+1)+deriv_index], 'r',
                 label='learned trajectory (t>0)')

    plt.legend()
    plt.savefig('./Results_pic/tau6k3/longtimeseries/lstm_Tied_latent8_gen10_deriv0_50.png', dpi=500)
    #plt.savefig('./minibatchfps200_take300_predict900_positionalvalue3.png', dpi=500)
    print('Saved visualization figure at {}'.format('./test.png'))

Saved visualization figure at ./test.png


In [149]:
with torch.no_grad():
    samp_ts_forgraph = samp_ts.detach().cpu().numpy()
    pred_x = pred_x.reshape(203, 50*gen_index, 28*(k+1)+2)
    pred_x_forgraph = pred_x.detach().cpu().numpy()
    orig_trajs_forgraph = orig_trajs
    ts_pos_forgraph = ts_pos.detach().cpu().numpy()
    ts_pos_combined = np.linspace(0, 0.25*gen_index, num=50*gen_index) 
    
    times_index = 0
    dataset_value = 4
    deriv_index = 0
    fig, axes = plt.subplots(nrows=6, ncols=5, figsize=(15, 9))
    axes = axes.flatten()
    for i, ax in enumerate(axes):
        ax.scatter(ts_pos_combined[times_index:50*gen_index], orig_trajs_forgraph[dataset_value,0:50*gen_index, i], label='sampled data', s = 5)
        ax.plot(ts_pos_combined[times_index:+50*gen_index], pred_x_forgraph[dataset_value, times_index:+50*gen_index, i*(k+1)+deriv_index], 'r',
                 label='learned trajectory (t>0)')

    plt.legend()
    plt.savefig('./Results_pic/tau6k3/longtimeseries/lstm_split_latent8_gen10_deriv0_50.png', dpi=500)
    #plt.savefig('./minibatchfps200_take300_predict900_positionalvalue3.png', dpi=500)
    print('Saved visualization figure at {}'.format('./test.png'))

Saved visualization figure at ./test.png


In [150]:
with torch.no_grad():
    samp_ts_forgraph = samp_ts.detach().cpu().numpy()
    pred_x = pred_x.reshape(203, 50*gen_index, 28*(k+1)+2)
    pred_x_forgraph = pred_x.detach().cpu().numpy()
    orig_trajs_forgraph = orig_trajs
    ts_pos_forgraph = ts_pos.detach().cpu().numpy()
    ts_pos_combined = np.linspace(0, 0.25*gen_index, num=50*gen_index) 
    
    times_index = 0
    dataset_value = 20
    deriv_index = 0
    fig, axes = plt.subplots(nrows=6, ncols=5, figsize=(15, 9))
    axes = axes.flatten()
    for i, ax in enumerate(axes):
        ax.scatter(ts_pos_combined[times_index:50*gen_index], orig_trajs_forgraph[dataset_value,0:50*gen_index, i], label='sampled data', s = 5)
        ax.plot(ts_pos_combined[times_index:+50*gen_index], pred_x_forgraph[dataset_value, times_index:+50*gen_index, i*(k+1)+deriv_index], 'r',
                 label='learned trajectory (t>0)')

    plt.legend()
    plt.savefig('./Results_pic/tau6k3/longtimeseries/lstm_Tied_latent8_gen10_deriv0_50_wash.png', dpi=500)
    #plt.savefig('./minibatchfps200_take300_predict900_positionalvalue3.png', dpi=500)
    print('Saved visualization figure at {}'.format('./test.png'))

Saved visualization figure at ./test.png


In [123]:
with torch.no_grad():
    samp_ts_forgraph = samp_ts.detach().cpu().numpy()
    pred_x = pred_x.reshape(203, 50*gen_index, 31*(k+1)+2)
    pred_x_forgraph = pred_x.detach().cpu().numpy()
    orig_trajs_forgraph = orig_trajs
    ts_pos_forgraph = ts_pos.detach().cpu().numpy()
    ts_pos_combined = np.linspace(0, 0.25*gen_index, num=2500) 
    
    times_index = 0
    positional_value = 0
    fig, axes = plt.subplots(nrows=4, ncols=1, figsize=(15, 9))
    axes = axes.flatten()
    for i, ax in enumerate(axes):
        ax.scatter(ts_pos_combined[0:50*gen_index], orig_trajs_forgraph[i,0:50*gen_index, positional_value], label='sampled data', s = 5)
        ax.plot(ts_pos_combined[times_index:+50*gen_index], pred_x_forgraph[i, times_index:+50*gen_index, positional_value*(k+1)], 'r',
                 label='learned trajectory (t>0)')

    plt.legend()
    plt.savefig('./Results_pic/tau6k3/longtimeseries/Allrodent_Tau6k3_takenembedding_longepochgeneration_position0.png', dpi=500)
    #plt.savefig('./minibatchfps200_take300_predict900_positionalvalue3.png', dpi=500)
    print('Saved visualization figure at {}'.format('./test.png'))

  fig, axes = plt.subplots(nrows=4, ncols=1, figsize=(15, 9))


Saved visualization figure at ./test.png


## Predicting longer timescales with combining minibatches

In [None]:
with torch.no_grad():
    # sample from trajectorys' approx. posterior

    ts_pos = np.linspace(0, np.pi*2, num=50)
    ts_pos = torch.from_numpy(ts_pos).float().to(device)
    #ts_neg = np.linspace(-np.pi*20, 0., num=400)[::-1].copy()
    #ts_neg = torch.from_numpy(ts_neg).float().to(device)
    
    h = rec.initHidden().to(device)
    for t in reversed(range(samp_trajs.size(1))):
        obs = samp_trajs[:, t, :]
        out, h = rec.forward(obs, h)
    qz0_mean, qz0_logvar = out[:, :latent_dim], out[:, latent_dim:]
    epsilon = torch.randn(qz0_mean.size()).to(device)
    z0 = epsilon * torch.exp(.5 * qz0_logvar) + qz0_mean

    # forward in time and solve ode for reconstructions
    pred_z = odeint(func, z0, ts_pos).permute(1, 0, 2)
    pred_x = dec(pred_z)

In [53]:
with torch.no_grad():
    samp_ts_forgraph = samp_ts.detach().cpu().numpy()
    pred_x = pred_x.reshape(64, 1500, 46)
    pred_x_forgraph = pred_x.detach().cpu().numpy()
    samp_trajs = samp_trajs.reshape(64, 300, 46)
    samp_trajs_forgraph = samp_trajs.detach().cpu().numpy()
    ts_pos_forgraph = ts_pos.detach().cpu().numpy()
    samp_ts_combined = np.linspace(0, 299, num=300)
    ts_pos_combined = np.linspace(0, 1499, num=1500) 
    
    times_index = 0
    positional_value = 3
    fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(15, 9))
    axes = axes.flatten()
    for i, ax in enumerate(axes):
        ax.scatter(samp_ts_combined[times_index:+300], samp_trajs_forgraph[i,times_index:+300, positional_value], label='sampled data', s = 5)
        ax.plot(ts_pos_combined[times_index:+1200], pred_x_forgraph[i, times_index:+1200, positional_value], 'r',
                 label='learned trajectory (t>0)')

    plt.legend()
    plt.savefig('./1200.png', dpi=500)
    #plt.savefig('./minibatchfps200_take300_predict900_positionalvalue3.png', dpi=500)
    print('Saved visualization figure at {}'.format('./test.png'))

Saved visualization figure at ./test.png


## Predicting longer timescales: looking at each one

In [67]:
with torch.no_grad():
    samp_ts_forgraph = samp_ts.detach().cpu().numpy()
    pred_x_forgraph = pred_x.detach().cpu().numpy()
    samp_trajs_forgraph = samp_trajs.detach().cpu().numpy()
    ts_pos_forgraph = ts_pos.detach().cpu().numpy()
    
    times_index = 0
    positional_value = 0
    fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(15, 9))
    axes = axes.flatten()
    for i, ax in enumerate(axes):
        ax.scatter(samp_ts_combined[times_index:+50], samp_trajs_forgraph[i,times_index:+50, positional_value], label='sampled data', s = 5)
        ax.plot(ts_pos_combined[times_index:+250], pred_x_forgraph[i, times_index:+250, positional_value], 'r',
                 label='learned trajectory (t>0)')

    plt.legend()
    plt.savefig('./test.png', dpi=500)
    print('Saved visualization figure at {}'.format('./test.png'))

Saved visualization figure at ./test.png


In [42]:
z0.shape

torch.Size([203, 4])

## PCA for z0

In [24]:
from sklearn.decomposition import PCA
import pandas as pd

In [18]:
with torch.no_grad():
    ts_pos = np.linspace(0, np.pi*2*5, num=250)
    ts_pos = torch.from_numpy(ts_pos).float().to(device)
    
    h = torch.zeros(samp_trajs.shape[0], rnn_nhidden).to(device)
    
    for t in reversed(range(samp_trajs.size(1))):
        obs = samp_trajs[:, t, :]
        out, h = rec.forward(obs, h)
    qz0_mean, qz0_logvar = out[:, :latent_dim], out[:, latent_dim:]
    epsilon = torch.randn(qz0_mean.size()).to(device)
    z0 = epsilon * torch.exp(.5 * qz0_logvar) + qz0_mean
    z0 = z0.cpu()
    
    pca = PCA(n_components=2)
    pca.fit(z0.cpu())
    print("Explained variance:", pca.explained_variance_ratio_)
    z0_red = pca.fit_transform(z0)
    
    print(z0_red[:, 0].shape)

Explained variance: [0.42682021 0.34515899]
(14616,)


In [32]:
z0 = z0.cpu()

In [38]:
    pca = PCA(n_components=2)
    pca.fit(z0.cpu())
    print("Explained variance:", pca.explained_variance_ratio_)

Explained variance: [0.42036639 0.36758968]


In [34]:
pca_z = PCA(n_components=2)
pca.fit(z0)

z0_red = pca.fit_transform(z0)

d = {'PC1': z0_red[:, 0], 'PC2': z0_red[:, 1]}
df = pd.DataFrame(d)

plt.figure()
plt.plot(z0_red[:, 0], z0_red[:, 1], 'o', label='z0 samples in 2D', linewidth=2, zorder=1)
plt.legend()
plt.savefig('./PCAgraph.png', dpi=250)

In [None]:
from sklearn.manifold import TSNE

In [None]:
time_start = time.time()
tsne = TSNE(n_components=2, verbose=1, perplexity=40, n_iter=300)
tsne_results = tsne.fit_transform(data_subset)