In [61]:
import numpy as np
import torch
import torch.nn as nn
import pandas as pd
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import time
import pydde as d

In [62]:
# Data processing: the 3-component target the network is queried with,
# stored as a float32 tensor (the dtype nn.Linear weights use by default).
target = torch.tensor([0.0, 2.0, 0.9], dtype=torch.float32)

In [63]:
# Set up the simulation and load the trained model

# Generate simulation
# NOTE(review): the second constructor argument (60) is presumably the number
# of time steps in the sequence -- confirm against the pydde documentation.
sim = d.PySimSeq('test2.sim', 60)
# Run the simulation once with the simulator's own parameter set `sim.p`.
yseq = sim.compute(sim.p)
# Evaluate `sim.f` and `sim.df_dp` at the simulator's current y / ydot / yddot
# and its parameters.
# NOTE(review): these attributes presumably hold the state produced by the
# compute() call above, so the statement order matters -- confirm.
f = sim.f(sim.y, sim.ydot, sim.yddot, sim.p)
df = sim.df_dp(sim.y, sim.ydot, sim.yddot, sim.p)

class Model(nn.Module):
    """Fully connected network: a Linear + ReLU pair per hidden layer,
    followed by a final linear output layer.

    Args:
        n_in: Number of input features.
        out_sz: Number of output features.
        hlayers: Iterable of hidden-layer widths, in order. May be empty, in
            which case the network is a single ``Linear(n_in, out_sz)``.
    """

    def __init__(self, n_in, out_sz, hlayers):
        super().__init__()

        # Batch norm is registered but not applied in forward(). It stays
        # registered so that state dicts saved from this class (which contain
        # bnorm.* entries) keep loading with strict key matching.
        self.bnorm = nn.BatchNorm1d(n_in)

        layerlist = []
        width = n_in
        for hidden in hlayers:
            layerlist.append(nn.Linear(width, hidden))
            layerlist.append(nn.ReLU(inplace=True))
            width = hidden
        # `width` is the last hidden width, or n_in when there are no hidden
        # layers. Indexing hlayers[-1] here raised IndexError for an empty
        # list and did not work for non-indexable iterables.
        layerlist.append(nn.Linear(width, out_sz))

        self.layers = nn.Sequential(*layerlist)

    def forward(self, x):
        """Apply the stacked layers to ``x`` and return the result."""
        return self.layers(x)

# Load model: rebuild the network with the layer sizes the stored weights
# expect (load_state_dict checks keys and shapes), then restore the weights.
model_eval = Model(3, 180, [200,100])
# map_location='cpu' lets a checkpoint that was saved from GPU tensors load on
# a machine without CUDA; the model and its input stay on the CPU here.
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
model_eval.load_state_dict(torch.load('Trained_Model5000z.pt', map_location='cpu'));
# Switch to evaluation mode; as the last expression this also displays the model.
model_eval.eval()

Model(
  (bnorm): BatchNorm1d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layers): Sequential(
    (0): Linear(in_features=3, out_features=200, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=200, out_features=100, bias=True)
    (3): ReLU(inplace=True)
    (4): Linear(in_features=100, out_features=180, bias=True)
  )
)

In [64]:
# Query the trained network for the target; gradient tracking is switched off
# because this is inference only.
with torch.set_grad_enabled(False):
    p = model_eval(target)

In [65]:
# Store the predicted parameter trajectory `p` in the file 'p_traj.p' via the
# simulator's save_ptraj helper.
sim.save_ptraj('p_traj.p', p)

In [76]:
# Roll the simulation out with the predicted parameters and compare the tail
# of the result with the requested target.
ysim = sim.compute(p)

# NOTE(review): 59*3 presumably selects the last of 60 three-component states
# -- confirm against the layout of sim.compute's return value.
tail_start = 59 * 3
y_first = ysim[0:3]
y_last = ysim[tail_start:]
abs_diff = abs(y_last - np.array(target))

print(f'\ny_0: {y_first}')
print(f'\ny_end: {y_last}')
print(f'\nDiff: {(abs_diff)}')
# NOTE(review): this is the square of the summed absolute differences, not a
# sum of squared differences -- confirm that this is the intended error metric.
print(f'\nError: {sum(abs_diff)**2}')


y_0: [-8.34951326e-05  2.02299677e+00  1.69111837e-03]

y_end: [-0.00816879  2.02603096  0.89845623]

Diff: [0.00816879 0.02603096 0.00154374]

Error: 0.0012775972791816523
