In [1]:
import numpy as np
import torch
import torch.nn as nn
import pandas as pd
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import time
import pydde as d

In [2]:
#Parameters
samplenum = 1
input_size = 3
output_size = 3
time_length = 60; #seconds

In [3]:
# Generate simulation
dyn = d.PyDyn('test2.sim', time_length)
state_init = dyn.compute(dyn.p_init)
f = dyn.f(state_init, dyn.p_init)
df = dyn.df_dp(state_init, dyn.p_init)
dy = dyn.dy_dp(state_init, dyn.p_init)
print(state_init.y)
print(dy.shape)
#Sample targets only variables in z direction
y_target = np.zeros((samplenum, 3))
y_target[:,2] = np.random.rand(samplenum)
#x[:,0] = np.random.rand(samplenum)
y_target[:,1] = 2
print(y_target)
p = dyn.get_p(y_target.transpose(), dyn.p_init)
y_target= torch.tensor(y_target, requires_grad= True)
print(p.dtype)

[ 2.15365312e-05  1.99969848e+00  0.00000000e+00  2.15757288e-05
  1.99969848e+00  0.00000000e+00  2.15562034e-05  1.99969848e+00
  0.00000000e+00  2.14781676e-05  1.99969848e+00  0.00000000e+00
  2.13419924e-05  1.99969848e+00  0.00000000e+00  2.11482057e-05
  1.99969848e+00  0.00000000e+00  2.08974905e-05  1.99969848e+00
  0.00000000e+00  2.05906827e-05  1.99969848e+00  0.00000000e+00
  2.02287680e-05  1.99969848e+00  0.00000000e+00  1.98128796e-05
  1.99969848e+00  0.00000000e+00  1.93442938e-05  1.99969848e+00
  0.00000000e+00  1.88244269e-05  1.99969848e+00  0.00000000e+00
  1.82548306e-05  1.99969848e+00  0.00000000e+00  1.76371875e-05
  1.99969848e+00  0.00000000e+00  1.69733060e-05  1.99969848e+00
  0.00000000e+00  1.62651155e-05  1.99969848e+00  0.00000000e+00
  1.55146602e-05  1.99969848e+00  0.00000000e+00  1.47240939e-05
  1.99969848e+00  0.00000000e+00  1.38956733e-05  1.99969848e+00
  0.00000000e+00  1.30317519e-05  1.99969848e+00  0.00000000e+00
  1.21347732e-05  1.99969

## Building the custon Simulation Activation Function

In [4]:
class Simulate(torch.autograd.Function):
    
    @staticmethod
    def forward(ctx, input):
        #print(f'input: {input.shape}')
        p = input.clone().numpy().transpose()
        state = dyn.compute(p)
        y_pred = torch.tensor(state.y[-3:], requires_grad = True)
        #print(f'y_pred: {y_pred.dtype}')
        
        ctx.save_for_backward(input)
        
        return y_pred
    
    @staticmethod
    def backward(ctx, grad_output):
        #print(grad_output.shape)
        input, = ctx.saved_tensors
        p = input.clone().numpy().transpose()
        state= dyn.compute(p)
        dy_dp = dyn.dy_dp(state, p)
        dy_dp = dy_dp[-3:, :]
        #print(f'shape of dy/dp: {dy_dp.shape}')
        #print(f'shape of grad_output: {grad_output.shape}')
        grad_output = grad_output.unsqueeze(0).t()
        #print(f'shape of grad_output unsqueezed: {grad_output.shape}')
        
        grad_input = torch.tensor(dy_dp, requires_grad = True).t().mm(grad_output).t()
        #grad_input = torch.tensor(dy_dp, requires_grad = True)
        #print(f'shape of grad_input: {grad_input.dtype}')

        return grad_input, None

Simulate = Simulate.apply
class ActiveLearn(nn.Module):

    def __init__(self, n_in, out_sz):
        super(ActiveLearn, self).__init__()

        self.L_in = nn.Linear(n_in, 3*time_length).double()
        self.Relu = nn.ReLU(inplace=True).double()
        self.P = nn.Linear(3*time_length, 3*time_length).double()
        #self.L_out = nn.Linear(3, 3)
    
    def forward(self, input):
        x = self.L_in(input)
        x = self.Relu(x)
        x = self.P(x)
        x = self.Relu(x)
        x, p = Simulate(x)
        #x = self.L_out(x)
        return x, p
    
model = ActiveLearn(input_size, output_size)



In [5]:
from torch.autograd import gradcheck

# gradcheck takes a tuple of tensors as input, check if your gradient
# evaluated with these tensors are close enough to numerical
# approximations and returns True if they all verify this condition.
#p = dyn.p_init
p = torch.tensor(p, requires_grad = True)
input = (p.double())
#print(input)
test = gradcheck(Simulate, (input,), eps=1e-6, atol=1e-7, raise_exception = True)
print(test)

True
