Skip to content

Commit

Permalink
fix random policy
Browse files Browse the repository at this point in the history
  • Loading branch information
dementrock committed Oct 19, 2016
1 parent 3e1d630 commit 55b7df1
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 0 deletions.
26 changes: 26 additions & 0 deletions examples/nop_cartpole.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from rllab.algos.nop import NOP
from rllab.baselines.zero_baseline import ZeroBaseline
from rllab.envs.box2d.cartpole_env import CartpoleEnv
from rllab.envs.normalized_env import normalize
from rllab.policies.uniform_control_policy import UniformControlPolicy

env = normalize(CartpoleEnv())

policy = UniformControlPolicy(
env_spec=env.spec,
# The neural network policy should have two hidden layers, each with 32 hidden units.
)

baseline = ZeroBaseline(env_spec=env.spec)

algo = NOP(
env=env,
policy=policy,
baseline=baseline,
batch_size=4000,
max_path_length=100,
n_itr=40,
discount=0.99,
step_size=0.01,
)
algo.train()
15 changes: 15 additions & 0 deletions rllab/policies/uniform_control_policy.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from rllab.core.parameterized import Parameterized
from rllab.core.serializable import Serializable
from rllab.distributions.delta import Delta
from rllab.policies.base import Policy
from rllab.misc.overrides import overrides

Expand All @@ -19,3 +20,17 @@ def get_action(self, observation):
def get_params_internal(self, **tags):
return []

def get_actions(self, observations):
return self.action_space.sample_n(len(observations)), dict()

@property
def vectorized(self):
return True

def reset(self, dones=None):
pass

@property
def distribution(self):
# Just a placeholder
return Delta()

0 comments on commit 55b7df1

Please sign in to comment.