In [1]:
# Change the working directory two levels up from where the notebook was started.
# Presumably this is the repository root, so that the `examples` / `combinators`
# imports and the relative save path used in the last cell resolve -- TODO confirm.
%cd ../..

/home/eli/AnacondaProjects/combinators


In [2]:
import logging

import numpy as np
import probtorch
import torch

from examples.fep_control import fep_control
from combinators.model import active
from combinators.model import compose, foldable
from combinators.inference import importance, mcmc
from combinators import utils

In [3]:
# Configure the root logger once: INFO level, each record prefixed with a
# month/day/year hour:minute:second timestamp followed by the message.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
)

In [4]:
# Dimensions shared by the target and proposal models. They are named once here
# so the four constructors below cannot silently disagree with each other.
STATE_DIM = 12         # passed as state_dim to every component that takes one
OBSERVATION_DIM = 24   # BipedalWalker observations are 24-dimensional
ACTION_DIM = 4         # BipedalWalker actions are 4-dimensional
BATCH_SHAPE = (10,)    # presumably the number of importance samples -- TODO confirm

# Target: the non-trainable observer and actor, composed into one model.
target_actor = fep_control.BipedalWalkerActor(
    state_dim=STATE_DIM, batch_shape=BATCH_SHAPE, trainable=False)
target_observer = fep_control.GenerativeObserver(
    observation_dim=OBSERVATION_DIM, batch_shape=BATCH_SHAPE, trainable=False)
target = compose(target_observer, target_actor)

# Proposal: the trainable recognition actor and encoder, composed the same way.
# NOTE(review): the proposal actor is explicitly named 'BipedalWalkerActor', the
# class name of the target actor -- presumably so the two line up when the
# proposal is applied to the target; confirm against fep_control.
proposal_actor = fep_control.RecognitionActor(
    state_dim=STATE_DIM, action_dim=ACTION_DIM, observation_dim=OBSERVATION_DIM,
    batch_shape=BATCH_SHAPE, discrete_actions=False, name='BipedalWalkerActor',
    trainable=True)
proposal_encoder = fep_control.RecognitionEncoder(
    state_dim=STATE_DIM, observation_dim=OBSERVATION_DIM,
    batch_shape=BATCH_SHAPE, trainable=True)
proposal = compose(proposal_encoder, proposal_actor)

# The agent is the target model driven by the proposal via importance.propose.
agent = importance.propose(target, proposal)

In [5]:
# Evaluate the agent on BipedalWalker-v2 before the training cell below has run.
# Note that `theta`, `graph` and `log_weight` are overwritten by later cells.
theta, graph, log_weight = active.active_inference_test(
    agent,
    'BipedalWalker-v2',
    online_inference=False,
    iterations=1000,
)

In [6]:
# Train the agent on BipedalWalker-v2. Progress is reported through `logging`,
# one "ELBO=... at episode N of length L" record per episode (see output below).
# This overwrites `theta`, `graph` and `log_weight` from the previous cell.
theta, graph, log_weight = active.active_inference(
    agent,
    'BipedalWalker-v2',
    lr=3e-3,
    episodes=1000,
    episode_length=1000,
    dream=False,
)

06/03/2019 14:29:19 ELBO=-1.12157352e+11 at episode 0 of length 1000
06/03/2019 14:29:30 ELBO=-3.26939083e+11 at episode 1 of length 1000
06/03/2019 14:29:41 ELBO=-4.05376532e+11 at episode 2 of length 110
06/03/2019 14:29:52 ELBO=-2.23051416e+11 at episode 3 of length 1000
06/03/2019 14:30:06 ELBO=-1.70185933e+11 at episode 4 of length 1000
06/03/2019 14:30:19 ELBO=-3.91243858e+10 at episode 5 of length 1000
06/03/2019 14:30:32 ELBO=-6.94294200e+10 at episode 6 of length 1000
06/03/2019 14:30:46 ELBO=-4.12799140e+11 at episode 7 of length 79
06/03/2019 14:30:58 ELBO=-4.96400548e+10 at episode 8 of length 1000
06/03/2019 14:31:11 ELBO=-4.70753608e+10 at episode 9 of length 62
06/03/2019 14:31:24 ELBO=-2.50004452e+12 at episode 10 of length 81
06/03/2019 14:31:36 ELBO=-2.01920655e+10 at episode 11 of length 1000
06/03/2019 14:31:48 ELBO=-5.09553725e+12 at episode 12 of length 159
06/03/2019 14:32:00 ELBO=-1.27979553e+12 at episode 13 of length 55
06/03/2019 14:32:11 ELBO=-4.68054390e+12

In [7]:
# Display the log weights returned by the training cell above. The output has
# 10 entries, matching batch_shape=(10,) used when the agent was built.
log_weight

tensor([-8.1709e+10, -2.7672e+10, -2.3367e+09, -7.8521e+09, -2.5241e+09,
        -1.6180e+10, -4.4427e+09, -1.7998e+11, -6.3397e+09, -1.3984e+09],
       device='cuda:0', grad_fn=<AddBackward0>)

In [8]:
# Wrap the agent with mcmc.resample_move_smc, using moves=5.
# NOTE(review): this rebinds `agent` to a wrapper around its own previous value,
# so the cell is not idempotent -- running it a second time wraps the already
# wrapped agent again. Use Restart Kernel -> Run All rather than re-running this
# cell on its own.
agent = mcmc.resample_move_smc(agent, moves=5)

In [9]:
# Re-run the same evaluation as before training, now on the agent as rebound by
# the resample-move cell above. Overwrites `theta`, `graph` and `log_weight`.
theta, graph, log_weight = active.active_inference_test(
    agent,
    'BipedalWalker-v2',
    online_inference=False,
    iterations=1000,
)

In [10]:
# Display `theta` as returned by the evaluation run in the previous cell.
theta

(tensor([[-0.3046, -0.6031,  0.0839, -0.0646, -0.0953,  0.1359,  0.1024,  0.2257,
           0.2330,  0.3183,  0.4355,  0.0071],
         [ 0.2150, -0.5239, -0.6912, -0.7210, -0.9141, -0.3802, -0.6230,  0.6133,
          -0.0998, -0.5958, -0.1339,  0.8194],
         [-0.2766, -0.4952,  0.1299, -0.0589, -0.0343,  0.2224,  0.1721,  0.1120,
           0.1918,  0.2618,  0.3664, -0.0276],
         [-0.2585, -0.4366, -0.0673, -0.0950, -0.1417,  0.0734,  0.0766,  0.1289,
           0.2573,  0.0878,  0.3365,  0.2109],
         [-0.2766, -0.4952,  0.1299, -0.0589, -0.0343,  0.2224,  0.1721,  0.1120,
           0.1918,  0.2618,  0.3664, -0.0276],
         [-0.2585, -0.4366, -0.0673, -0.0950, -0.1417,  0.0734,  0.0766,  0.1289,
           0.2573,  0.0878,  0.3365,  0.2109],
         [-0.2766, -0.4952,  0.1299, -0.0589, -0.0343,  0.2224,  0.1721,  0.1120,
           0.1918,  0.2618,  0.3664, -0.0276],
         [-0.2766, -0.4952,  0.1299, -0.0589, -0.0343,  0.2224,  0.1721,  0.1120,
           0.19

In [11]:
# Display the log weights from the evaluation run after the resample-move
# wrapping. This is the `log_weight` reassigned two cells above, not the
# training-time value shown earlier.
log_weight

tensor([-43528.6484, -41808.1992, -41117.1953, -40729.6953, -43735.5000,
        -38172.0664, -40325.3203, -39722.0391, -24383.3672, -41158.8320],
       device='cuda:0', grad_fn=<AddBackward0>)

In [12]:
# Persist the current `agent` (the resample-move-wrapped one) to disk.
# The path is relative, so it resolves against the working directory set by the
# `%cd ../..` in the first cell.
# NOTE(review): torch.save pickles the whole object, so loading this file later
# requires the same class definitions to be importable. Confirm that is intended
# rather than saving only a state_dict.
torch.save(agent, 'examples/fep_control/fep_bipedal_walker_agent.dat')