In [1]:
# Move the working directory two levels up so that the `examples` and
# `combinators` packages imported below, and the relative save path at the
# end of the notebook, resolve from there.
# NOTE(review): the path is relative, so re-running this cell without a
# kernel restart climbs two more levels; run it exactly once per kernel.
%cd ../..

/home/work/AnacondaProjects/combinators


In [2]:
import logging

import numpy as np
import probtorch
import torch

from examples.fep_control import fep_control
from combinators.model import active
from combinators.model import compose, foldable
from combinators.inference import importance, mcmc
from combinators import utils

In [3]:
# Timestamped INFO-level logging; the training cell's progress lines
# ("<date> <time> ELBO=...") are rendered with this format.
logging.basicConfig(
    level=logging.INFO,
    datefmt='%m/%d/%Y %H:%M:%S',
    format='%(asctime)s %(message)s',
)

In [4]:
# Dimensions shared by the target (generative) and proposal (recognition)
# models. They are defined once so the two sides cannot silently drift apart.
STATE_DIM = 12        # latent state size given to both actors and the encoder
OBSERVATION_DIM = 24  # BipedalWalker observation vector has 24 entries
ACTION_DIM = 4        # BipedalWalker action vector has 4 entries
BATCH_SHAPE = (10,)   # presumably the number of particles -- TODO confirm

# Target: non-trainable generative model (observer composed with actor).
target_actor = fep_control.BipedalWalkerActor(
    state_dim=STATE_DIM,
    batch_shape=BATCH_SHAPE,
    trainable=False,
)
target_observer = fep_control.GenerativeObserver(
    observation_dim=OBSERVATION_DIM,
    batch_shape=BATCH_SHAPE,
    trainable=False,
)
target = compose(target_observer, target_actor)

# Proposal: trainable recognition model (encoder composed with actor).
# NOTE(review): the proposal actor is explicitly named 'BipedalWalkerActor',
# i.e. after the target actor's class -- presumably so that the proposal's
# random variables line up with the target's; confirm against fep_control.
proposal_actor = fep_control.RecognitionActor(
    state_dim=STATE_DIM,
    action_dim=ACTION_DIM,
    observation_dim=OBSERVATION_DIM,
    batch_shape=BATCH_SHAPE,
    discrete_actions=False,
    name='BipedalWalkerActor',
    trainable=True,
)
proposal_encoder = fep_control.RecognitionEncoder(
    state_dim=STATE_DIM,
    observation_dim=OBSERVATION_DIM,
    batch_shape=BATCH_SHAPE,
    trainable=True,
)
proposal = compose(proposal_encoder, proposal_actor)

# Importance-sampling agent: the proposal is used to propose for the target.
agent = importance.propose(target, proposal)

In [5]:
# Baseline rollout of the freshly constructed agent in BipedalWalker-v2,
# executed before the training cell. The returned triple is overwritten by
# later cells that reuse the same names.
theta, graph, log_weight = active.active_inference_test(
    agent,
    'BipedalWalker-v2',
    online_inference=False,
    iterations=1000,
)

In [6]:
# Train the agent on BipedalWalker-v2: 100 episodes, each capped at 1000
# steps, learning rate 5e-3, with dreaming disabled. Progress is logged
# per episode as "ELBO=... at episode i of length n".
theta, graph, log_weight = active.active_inference(
    agent,
    'BipedalWalker-v2',
    lr=5e-3,
    episodes=100,
    episode_length=1000,
    dream=False,
)

06/06/2019 16:57:23 ELBO=-6.83765033e+12 at episode 0 of length 105
06/06/2019 16:57:38 ELBO=-2.94854250e+12 at episode 1 of length 47
06/06/2019 16:57:53 ELBO=-1.65828782e+12 at episode 2 of length 85
06/06/2019 16:58:09 ELBO=-5.33672821e+11 at episode 3 of length 741
06/06/2019 16:58:24 ELBO=-3.39013193e+12 at episode 4 of length 74
06/06/2019 16:58:40 ELBO=-4.99447267e+11 at episode 5 of length 49
06/06/2019 16:58:56 ELBO=-8.84089867e+12 at episode 6 of length 67
06/06/2019 16:59:12 ELBO=-4.50738415e+12 at episode 7 of length 98
06/06/2019 16:59:28 ELBO=-3.00886209e+12 at episode 8 of length 54
06/06/2019 16:59:43 ELBO=-1.75727693e+12 at episode 9 of length 48
06/06/2019 16:59:59 ELBO=-4.28541051e+11 at episode 10 of length 74
06/06/2019 17:00:15 ELBO=-9.38139402e+10 at episode 11 of length 1000
06/06/2019 17:00:31 ELBO=-1.72031043e+12 at episode 12 of length 43
06/06/2019 17:00:47 ELBO=-2.67388416e+10 at episode 13 of length 1000
06/06/2019 17:01:02 ELBO=-5.33913108e+11 at episode 

In [7]:
# Inspect the log weights returned by the training call; the recorded output
# holds 10 values, matching batch_shape=(10,).
log_weight

tensor([-6.3714e+10, -2.9943e+10, -5.5353e+11, -8.5832e+09, -8.3724e+10,
        -1.1613e+12, -7.1121e+09, -2.0174e+10, -5.3515e+10, -7.6192e+11],
       device='cuda:0', grad_fn=<AddBackward0>)

In [8]:
# Wrap the trained agent with mcmc.resample_move_smc using moves=5
# (presumably the number of MCMC moves per step -- TODO confirm).
# NOTE(review): this rebinds `agent` to a function of itself, so the cell is
# not idempotent: re-running it wraps the already wrapped agent again.
# Re-run the agent construction and training cells before re-running this.
agent = mcmc.resample_move_smc(agent, moves=5)

In [13]:
# Evaluation rollout of the SMC-wrapped agent, using the same settings as the
# pre-training rollout.
# NOTE(review): this cell's execution count ([13]) is out of sequence with
# its neighbours, so the outputs displayed below may predate its latest run;
# verify with Restart Kernel -> Run All.
theta, graph, log_weight = active.active_inference_test(
    agent,
    'BipedalWalker-v2',
    online_inference=False,
    iterations=1000,
)

In [10]:
# Display `theta` as returned by the most recent rollout. The recorded output
# is a tuple of tensors and is long; consider showing only a slice or summary.
theta

(tensor([[-0.2444, -0.1347,  0.1068, -0.1160,  0.0289,  0.0774,  0.5279, -0.4575,
          -0.0143,  0.1037, -0.2604, -0.4259],
         [ 0.5905,  0.1118,  0.1545, -0.0925, -0.2623,  0.0334,  0.1364, -0.5172,
           0.2354, -0.0140,  0.2086, -0.6469],
         [ 0.5218,  0.0843,  0.2776, -0.0996, -0.1614,  0.0172,  0.4568, -0.8193,
          -0.1028, -0.7541,  0.4037, -0.2714],
         [ 0.5905,  0.1118,  0.1545, -0.0925, -0.2623,  0.0334,  0.1364, -0.5172,
           0.2354, -0.0140,  0.2086, -0.6469],
         [-0.2444, -0.1347,  0.1068, -0.1160,  0.0289,  0.0774,  0.5279, -0.4575,
          -0.0143,  0.1037, -0.2604, -0.4259],
         [ 0.5905,  0.1118,  0.1545, -0.0925, -0.2623,  0.0334,  0.1364, -0.5172,
           0.2354, -0.0140,  0.2086, -0.6469],
         [-0.2444, -0.1347,  0.1068, -0.1160,  0.0289,  0.0774,  0.5279, -0.4575,
          -0.0143,  0.1037, -0.2604, -0.4259],
         [ 0.5905,  0.1118,  0.1545, -0.0925, -0.2623,  0.0334,  0.1364, -0.5172,
           0.23

In [11]:
# Display the log weights from the most recent rollout (10 values in the
# recorded output) for comparison with the values shown after training.
log_weight

tensor([-3682839.5000, -3068946.2500, -5431378.0000, -3591961.5000,
        -5023071.0000, -6094434.0000, -4884586.0000, -4503638.0000,
        -2264719.5000, -4564189.5000], device='cuda:0', grad_fn=<AddBackward0>)

In [12]:
# Persist the whole agent object (not just a state_dict). The path is
# relative, so it resolves against the directory selected by `%cd ../..` at
# the top of the notebook.
# NOTE(review): torch.save pickles the object by default, so loading this
# file requires the same class definitions to be importable, and it should
# only be loaded from trusted sources.
torch.save(agent, 'examples/fep_control/fep_bipedal_walker_agent.dat')