In [1]:
# Change the working directory two levels up from the notebook's location so
# that the `examples` / `combinators` imports and the relative save path used
# at the end of the notebook resolve (presumably this is the repository root).
# NOTE(review): the path is relative, so re-running this cell in the same
# kernel moves up two more levels -- run it once per kernel.
%cd ../..

/home/work/AnacondaProjects/combinators


In [2]:
import logging

import numpy as np
import probtorch
import torch

from examples.fep_control import fep_control
from combinators.model import active
from combinators.model import compose, foldable
from combinators.inference import importance
from combinators import utils

In [3]:
# Notebook-wide configuration: logging format and RNG seeds.
# Log records are emitted as "<MM/DD/YYYY HH:MM:SS> <message>" at INFO level and above.
logging.basicConfig(
    format='%(asctime)s %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
    level=logging.INFO,
)

# Seed the global PyTorch and NumPy generators so that sampling which draws
# from them repeats across "Restart Kernel -> Run All".
# NOTE(review): the Gym environment is not created in this notebook, so it is
# not seeded here -- confirm whether the `active.*` helpers expose a seed for it.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

In [4]:
# Dimensions shared by the target (generative) and proposal (recognition)
# models. They are defined once so the four constructors below cannot drift
# out of sync with each other.
# NOTE(review): 24 and 4 presumably mirror the BipedalWalker observation and
# action sizes; 12 and (10,) look like free modelling choices -- confirm.
STATE_DIM = 12
ACTION_DIM = 4
OBSERVATION_DIM = 24
BATCH_SHAPE = (10,)

# Target models: a trainable actor and a non-trainable observer.
target_actor = fep_control.BipedalWalkerActor(
    state_dim=STATE_DIM,
    batch_shape=BATCH_SHAPE,
    trainable=True,
)
target_observer = fep_control.GenerativeObserver(
    observation_dim=OBSERVATION_DIM,
    batch_shape=BATCH_SHAPE,
    trainable=False,
)

# Proposal models, both trainable. The proposal actor is explicitly named
# 'BipedalWalkerActor', the class name of the target actor -- presumably so
# the two are matched by name when paired below; confirm against
# importance.propose.
proposal_actor = fep_control.RecognitionActor(
    state_dim=STATE_DIM,
    action_dim=ACTION_DIM,
    observation_dim=OBSERVATION_DIM,
    batch_shape=BATCH_SHAPE,
    discrete_actions=False,
    name='BipedalWalkerActor',
    trainable=True,
)
proposal_encoder = fep_control.RecognitionEncoder(
    state_dim=STATE_DIM,
    observation_dim=OBSERVATION_DIM,
    batch_shape=BATCH_SHAPE,
    trainable=True,
)

In [5]:
# Pair each target model with its proposal via importance.propose (actor
# first, then observer, same order as before), then compose the two results
# into a single agent.
actor, observer = (
    importance.propose(target, proposal)
    for target, proposal in (
        (target_actor, proposal_actor),
        (target_observer, proposal_encoder),
    )
)

agent = compose(observer, actor)

In [6]:
# Call active_inference_test on the agent before any training call in this
# notebook, with online inference disabled, for 1000 iterations.
# NOTE(review): theta / graph / log_weight are reassigned by the next cell
# without being inspected, so this run's results are discarded.
theta, graph, log_weight = active.active_inference_test(
    agent,
    'BipedalWalker-v2',
    online_inference=False,
    iterations=1000,
)

In [7]:
# Run active_inference on 'BipedalWalker-v2' with lr=1e-2, 100 episodes and an
# episode_length of 1000 (the recorded log shows many episodes ending earlier).
# NOTE(review): the recorded ELBO values range from about -1e11 to -1e15 and
# show no clear upward trend -- check the objective's scale and the learning
# rate before relying on the trained agent.
theta, graph, log_weight = active.active_inference(
    agent,
    'BipedalWalker-v2',
    lr=1e-2,
    episodes=100,
    episode_length=1000,
    dream=False,
)

05/31/2019 14:27:48 ELBO=-1.50721298e+11 at episode 0 of length 104
05/31/2019 14:28:01 ELBO=-1.39589981e+13 at episode 1 of length 1000
05/31/2019 14:28:16 ELBO=-5.26194353e+13 at episode 2 of length 92
05/31/2019 14:28:32 ELBO=-2.62495969e+14 at episode 3 of length 52
05/31/2019 14:28:46 ELBO=-2.53405776e+13 at episode 4 of length 97
05/31/2019 14:28:58 ELBO=-7.50161085e+13 at episode 5 of length 73
05/31/2019 14:29:10 ELBO=-6.74843675e+13 at episode 6 of length 92
05/31/2019 14:29:23 ELBO=-1.01206891e+15 at episode 7 of length 115
05/31/2019 14:29:36 ELBO=-2.27674186e+14 at episode 8 of length 126
05/31/2019 14:29:49 ELBO=-5.19564610e+14 at episode 9 of length 59
05/31/2019 14:30:01 ELBO=-5.51150462e+13 at episode 10 of length 76
05/31/2019 14:30:15 ELBO=-3.21401429e+14 at episode 11 of length 1000
05/31/2019 14:30:28 ELBO=-3.02405732e+13 at episode 12 of length 1000
05/31/2019 14:30:47 ELBO=-1.28327306e+14 at episode 13 of length 1000
05/31/2019 14:31:12 ELBO=-6.86764641e+13 at epi

In [8]:
# Display the log_weight tensor returned by the active_inference call above.
# NOTE(review): the recorded output has 10 entries (matching batch_shape=(10,))
# with magnitudes between roughly 1e11 and 1e15 -- such extreme log weights
# suggest target and proposal disagree badly; confirm.
log_weight

tensor([-7.5268e+11, -1.2031e+11, -1.9584e+15, -2.1537e+11, -2.7224e+14,
        -2.5100e+11, -5.3390e+14, -1.2227e+12, -1.3600e+11, -2.8096e+13],
       grad_fn=<AddBackward0>)

In [9]:
# Rebind `agent` to the result of importance.smc(agent); every later cell
# (the test run and torch.save) operates on this wrapped agent.
# NOTE(review): the cell reads and reassigns the same name, so running it
# twice applies importance.smc to an already-wrapped agent -- run it exactly
# once per kernel.
agent = importance.smc(agent)

In [None]:
# Same active_inference_test call as the pre-training run, now applied to the
# agent rebound by importance.smc above.
# NOTE(review): this cell has no execution count in the saved notebook, so the
# outputs displayed below may come from an earlier execution -- re-run to confirm.
theta, graph, log_weight = active.active_inference_test(
    agent,
    'BipedalWalker-v2',
    online_inference=False,
    iterations=1000,
)

In [11]:
# Display `theta` from the most recent run.
# NOTE(review): the recorded output is a tuple whose first tensor has 12
# columns and identical rows; dumping the whole structure is long, so a
# slice or summary may read better -- confirm what is worth showing.
theta

(tensor([[-10.7290, -12.5822,   7.8599, -21.5004, -12.0851,   7.5835,  14.0497,
           -6.8767,   4.8493,  16.4500,   0.1189,  16.8472],
         [-10.7290, -12.5822,   7.8599, -21.5004, -12.0851,   7.5835,  14.0497,
           -6.8767,   4.8493,  16.4500,   0.1189,  16.8472],
         [-10.7290, -12.5822,   7.8599, -21.5004, -12.0851,   7.5835,  14.0497,
           -6.8767,   4.8493,  16.4500,   0.1189,  16.8472],
         [-10.7290, -12.5822,   7.8599, -21.5004, -12.0851,   7.5835,  14.0497,
           -6.8767,   4.8493,  16.4500,   0.1189,  16.8472],
         [-10.7290, -12.5822,   7.8599, -21.5004, -12.0851,   7.5835,  14.0497,
           -6.8767,   4.8493,  16.4500,   0.1189,  16.8472],
         [-10.7290, -12.5822,   7.8599, -21.5004, -12.0851,   7.5835,  14.0497,
           -6.8767,   4.8493,  16.4500,   0.1189,  16.8472],
         [-10.7290, -12.5822,   7.8599, -21.5004, -12.0851,   7.5835,  14.0497,
           -6.8767,   4.8493,  16.4500,   0.1189,  16.8472],
         [-10

In [12]:
# Display the log_weight tensor from the most recent run.
# NOTE(review): in the recorded output all 10 entries are the same value
# (-172047.4844), and the rows of `theta` shown above are identical too --
# possibly every particle descends from one ancestor after resampling; confirm.
log_weight

tensor([-172047.4844, -172047.4844, -172047.4844, -172047.4844, -172047.4844,
        -172047.4844, -172047.4844, -172047.4844, -172047.4844, -172047.4844],
       grad_fn=<AddBackward0>)

In [13]:
# Serialize the whole `agent` object to a path relative to the current working
# directory (set by the %cd at the top of the notebook).
# NOTE(review): torch.save on a full object pickles it, so loading the file
# requires the defining classes to be importable, and it should only be loaded
# from trusted sources. Consider saving a state_dict if the agent supports it.
torch.save(agent, 'examples/fep_control/fep_bipedal_walker_agent.dat')