In [2]:
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

In [3]:
!pip3 install pyro-ppl



In [8]:
import os
path = os.path.abspath(os.path.join(os.getcwd(),".."))
import sys
sys.path.append(os.path.dirname(os.getcwd()))
from dynamics_predict.defaults import DYNAMICS_PARAMS, HYPER_PARAMS

env_name = 'inverteddoublependulum'
data_path = path+'/data/dynamics_data/'+env_name+'/dynamics.npy'
param_dim = len(DYNAMICS_PARAMS[env_name+'dynamics'])
print('parameter dimension: ', param_dim)

train_data = np.load(data_path, allow_pickle=True)
print('number of samples in dest data: ', len(train_data))
# split data
data_s, data_a, data_param, data_s_ = [], [], [], []
for d in train_data:
    [s,a,param], s_ = d
    data_s.append(s)
    data_a.append(a)
    data_param.append(param)
    data_s_.append(s_)

data_s = np.array(data_s)
data_a = np.array(data_a)
data_param = np.array(data_param)
data_s_ = np.array(data_s_)

print(data_s.shape, data_a.shape, data_param.shape, data_s_.shape)

parameter dimension:  5
number of samples in dest data:  3549
(3549, 11) (3549, 1) (3549, 5) (3549, 11)


In [49]:
x = np.concatenate((data_s,data_a, data_s_), axis=-1)
y = data_param
print(x.shape, y.shape)

(3549, 23) (3549, 5)


In [50]:
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample
import torch.nn as nn
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.infer import SVI, Trace_ELBO, Predictive
from tqdm.auto import trange, tqdm

In [136]:
x_dim = x.shape[1]
y_dim = y.shape[1]
print(x_dim, y_dim)

class Model(PyroModule):
    def __init__(self, h1=20, h2=20):
        super().__init__()
        self.fc1 = PyroModule[nn.Linear](x_dim, h1)
        self.fc1.weight = PyroSample(dist.Normal(0., 1.).expand([h1, x_dim]).to_event(2))
        self.fc1.bias = PyroSample(dist.Normal(0., 1.).expand([h1]).to_event(1))
        self.fc2 = PyroModule[nn.Linear](h1, h2)
        self.fc2.weight = PyroSample(dist.Normal(0., 1.).expand([h2, h1]).to_event(2))
        self.fc2.bias = PyroSample(dist.Normal(0., 1.).expand([h2]).to_event(1))
        self.fc3 = PyroModule[nn.Linear](h2, y_dim)
        self.fc3.weight = PyroSample(dist.Normal(0., 1.).expand([y_dim, h2]).to_event(2))
        self.fc3.bias = PyroSample(dist.Normal(0., 1.).expand([y_dim]).to_event(1))
        self.relu = nn.ReLU()

    def forward(self, x, y=None):
        batch_size = x.shape[0]
        # x = x.reshape(-1, 1)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        mu = self.fc3(x).squeeze()
        sigma = pyro.sample("sigma", dist.Uniform(0., 1.).expand([y_dim]).to_event(1))  # the to_event(1) is necessary, you’ll need to call .to_event(1) to use scalar distributions like Normal as a joint diagonal distributions over multiple variables: see: https://forum.pyro.ai/t/simple-gmm-in-pyro/3047/3
        # print(mu.shape, sigma.shape, y.shape)

        with pyro.plate("data", batch_size):
            # sigma = pyro.sample("sigma", dist.Uniform(0., 1.).expand([y_dim]).to_event(1))  
            obs = pyro.sample("obs", dist.Normal(mu, sigma).to_event(1), obs=y) # the to_event(1) is necessary
        return mu

23 5


In [137]:
model = Model()
guide = AutoDiagonalNormal(model)
adam = pyro.optim.Adam({"lr": 1e-3})
svi = SVI(model, guide, adam, loss=Trace_ELBO())

pyro.clear_param_store()
bar = trange(20)
x_train = torch.from_numpy(x).float()
y_train = torch.from_numpy(y).float()

for epoch in bar:
    loss = svi.step(x_train, y_train)
    # print(loss)
    bar.set_postfix(loss=f'{loss / x.shape[0]:.3f}')

100%|██████████| 20/20 [00:00<00:00, 180.74it/s, loss=95.852]


In [138]:
predictive = Predictive(model, guide=guide, num_samples=500)
x_test = x_train[:10]
print(x_test.shape)
preds = predictive(x_test)

y_pred = preds['obs'].T.detach().numpy().mean(axis=1)
y_std = preds['obs'].T.detach().numpy().std(axis=1)

print(y_pred, y_std)

# fig, ax = plt.subplots(figsize=(10, 5))
# ax.plot(x, y, 'o', markersize=1)
# ax.plot(x_test, y_pred)
# ax.fill_between(x_test, y_pred - y_std, y_pred + y_std,
#                 alpha=0.5, color='#ffcd3c')

torch.Size([10, 23])
[[-1.4620235  -0.9292353  -2.3858504  ... -1.5019815  -1.8051643
  -2.8081913 ]
 [ 1.2608469   1.4969504   2.1020446  ...  0.26011953  1.8149033
   1.5758591 ]
 [ 0.79152566  0.92237675  1.3267851  ... -0.35175693  1.1348053
   1.0248308 ]
 [-1.4267539  -0.08124174 -0.70871776 ... -0.19416465 -1.2698132
  -1.2513993 ]
 [-0.23448806 -1.2992532  -0.70612466 ... -2.1950686  -1.2392956
   0.31301618]] [[1.4301293  1.3278846  1.8229095  ... 1.5404499  1.7940197  2.0547945 ]
 [1.4298494  1.1652166  1.8705648  ... 1.3980244  1.4551692  0.9440567 ]
 [1.0242914  0.63166404 1.1168156  ... 0.56255925 0.6149589  0.73265785]
 [0.57261705 0.900117   0.68755174 ... 0.5758168  0.8907142  0.9437713 ]
 [0.91220987 1.0359749  1.8830494  ... 1.6415218  0.8795121  0.97397715]]
