In [1]:
%cd ../..

/home/eli/AnacondaProjects/combinators


In [2]:
import logging

import numpy as np
import probtorch
import torch

from examples.fep_control import fep_control
from combinators.model import active
from combinators.model import compose, foldable
from combinators.inference import importance, mcmc
from combinators import utils

In [3]:
# Route INFO-level log records to the cell output with a timestamp prefix, so
# progress lines such as "ELBO=... at episode ..." are visible during training.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
)

In [4]:
# Dimensions shared by every model component below. Defining them once keeps
# the target (generative) and proposal (recognition) halves in agreement;
# change them here rather than in the individual constructor calls.
BATCH_SHAPE = (10,)     # leading batch shape; log_weight later has 10 entries
STATE_DIM = 12          # latent state dimensionality
ACTION_DIM = 4          # presumably BipedalWalker's 4-dim action space -- confirm
OBSERVATION_DIM = 24    # presumably BipedalWalker's 24-dim observation -- confirm

# Target side: a trainable actor composed with an observer whose parameters
# are not trained (trainable=False).
target_actor = fep_control.BipedalWalkerActor(
    state_dim=STATE_DIM, batch_shape=BATCH_SHAPE, trainable=True,
)
target_observer = fep_control.GenerativeObserver(
    observation_dim=OBSERVATION_DIM, batch_shape=BATCH_SHAPE, trainable=False,
)
target = compose(target_observer, target_actor)

# Proposal side: recognition encoder composed with a recognition actor over
# continuous actions. The actor is given the name 'BipedalWalkerActor',
# presumably so its sampled variables line up with target_actor's when the
# two models are paired below -- confirm before renaming.
proposal_actor = fep_control.RecognitionActor(
    state_dim=STATE_DIM, action_dim=ACTION_DIM, observation_dim=OBSERVATION_DIM,
    batch_shape=BATCH_SHAPE, discrete_actions=False, name='BipedalWalkerActor',
    trainable=True,
)
proposal_encoder = fep_control.RecognitionEncoder(
    state_dim=STATE_DIM, observation_dim=OBSERVATION_DIM, batch_shape=BATCH_SHAPE,
    trainable=True,
)
proposal = compose(proposal_encoder, proposal_actor)

# Importance-sampling agent pairing the target with the proposal.
agent = importance.propose(target, proposal)

In [5]:
# Roll out the freshly built agent (before the training cell below) for 1000
# iterations with online_inference disabled. NOTE: theta, graph and log_weight
# are all reassigned by the next cell, so this run acts only as a smoke test.
theta, graph, log_weight = active.active_inference_test(
    agent,
    'BipedalWalker-v2',
    online_inference=False,
    iterations=1000,
)

In [6]:
# Run active inference on the environment: 100 episodes capped at 1000 steps
# each (logged lengths show most episodes end early), learning rate 1e-2,
# dream disabled.
# NOTE(review): the logged ELBO sits between roughly -1e10 and -1e14 and does
# not trend upward across episodes; lr=1e-2 may be too aggressive -- investigate.
theta, graph, log_weight = active.active_inference(
    agent,
    'BipedalWalker-v2',
    lr=1e-2,
    episodes=100,
    episode_length=1000,
    dream=False,
)

06/02/2019 18:18:00 ELBO=-3.49276242e+10 at episode 0 of length 1000
06/02/2019 18:18:12 ELBO=-1.41101902e+13 at episode 1 of length 68
06/02/2019 18:18:24 ELBO=-3.44590528e+14 at episode 2 of length 67
06/02/2019 18:18:37 ELBO=-3.31420716e+12 at episode 3 of length 62
06/02/2019 18:18:49 ELBO=-6.48390509e+11 at episode 4 of length 116
06/02/2019 18:19:00 ELBO=-3.98686585e+11 at episode 5 of length 77
06/02/2019 18:19:13 ELBO=-9.41829652e+11 at episode 6 of length 81
06/02/2019 18:19:25 ELBO=-1.37365566e+12 at episode 7 of length 60
06/02/2019 18:19:37 ELBO=-3.80222702e+11 at episode 8 of length 54
06/02/2019 18:19:48 ELBO=-5.61514611e+11 at episode 9 of length 139
06/02/2019 18:20:00 ELBO=-2.42427167e+11 at episode 10 of length 54
06/02/2019 18:20:13 ELBO=-1.03951919e+12 at episode 11 of length 97
06/02/2019 18:20:25 ELBO=-1.17920144e+12 at episode 12 of length 112
06/02/2019 18:20:36 ELBO=-2.02498140e+12 at episode 13 of length 134
06/02/2019 18:20:48 ELBO=-4.06397452e+11 at episode 

In [7]:
# Log weights returned by the active_inference run above, one entry per batch
# element. Their magnitude (up to ~-1e16) mirrors the logged ELBO values.
log_weight

tensor([-1.5191e+11, -1.9372e+11, -1.1713e+11, -1.8498e+13, -7.8495e+11,
        -8.2063e+11, -7.7951e+12, -6.8708e+16, -2.8796e+12, -3.2363e+11],
       device='cuda:0', grad_fn=<AddBackward0>)

In [8]:
# Wrap the trained agent with resample-move SMC (moves=5).
# NOTE(review): this rebinds `agent`, so executing the cell twice wraps the
# already-wrapped agent a second time. Re-run from the model-construction cell
# when iterating.
agent = mcmc.resample_move_smc(
    agent,
    moves=5,
)

In [13]:
# Roll out the SMC-wrapped agent with the same settings as the baseline run
# (1000 iterations, online_inference disabled); the results are inspected below.
theta, graph, log_weight = active.active_inference_test(
    agent,
    'BipedalWalker-v2',
    online_inference=False,
    iterations=1000,
)

In [10]:
# Inspect theta from the SMC rollout. It displays as a tuple of tensors whose
# rows have 12 columns (matching state_dim); identical rows are presumably
# samples duplicated by resampling -- confirm.
theta

(tensor([[-0.1081,  0.1715,  0.0919, -0.3152,  0.0275,  0.0248,  0.4536, -0.1008,
           0.3715, -0.0021,  0.2669,  0.1505],
         [-0.0420,  0.1028, -0.0468, -0.2645, -0.0529,  0.0652,  0.2222,  0.0269,
           0.1933, -0.1356,  0.2426,  0.2001],
         [-0.0420,  0.1028, -0.0468, -0.2645, -0.0529,  0.0652,  0.2222,  0.0269,
           0.1933, -0.1356,  0.2426,  0.2001],
         [-0.0420,  0.1028, -0.0468, -0.2645, -0.0529,  0.0652,  0.2222,  0.0269,
           0.1933, -0.1356,  0.2426,  0.2001],
         [-0.0420,  0.1028, -0.0468, -0.2645, -0.0529,  0.0652,  0.2222,  0.0269,
           0.1933, -0.1356,  0.2426,  0.2001],
         [-0.1081,  0.1715,  0.0919, -0.3152,  0.0275,  0.0248,  0.4536, -0.1008,
           0.3715, -0.0021,  0.2669,  0.1505],
         [-0.0420,  0.1028, -0.0468, -0.2645, -0.0529,  0.0652,  0.2222,  0.0269,
           0.1933, -0.1356,  0.2426,  0.2001],
         [-0.0847,  0.1888,  0.4589, -1.0179, -0.0095,  0.1721,  0.5825, -0.0686,
           0.62

In [11]:
# Log weights from the SMC rollout. Compare their scale (~-1e3) with the
# pre-SMC values shown earlier (~-1e11 and below).
log_weight

tensor([-3017.0237, -3661.3679, -2564.0474, -3260.7517, -2857.3252, -1928.5608,
        -2940.8086, -2261.2805, -3813.4050, -2186.5847], device='cuda:0',
       grad_fn=<AddBackward0>)

In [12]:
# Persist the whole agent object. torch.save pickles it, so loading the file
# needs the project's classes to be importable and should only be done on
# trusted files. The path is relative, so it assumes the kernel's cwd is the
# repository root (see the first cell).
torch.save(
    obj=agent,
    f='examples/fep_control/fep_bipedal_walker_agent.dat',
)