Skip to content

Commit

Permalink
Add xml_file argument for custom mujoco envs
Browse files Browse the repository at this point in the history
  • Loading branch information
hartikainen committed Feb 4, 2019
1 parent cb7caa2 commit a646fb4
Show file tree
Hide file tree
Showing 6 changed files with 12 additions and 6 deletions.
3 changes: 2 additions & 1 deletion softlearning/environments/gym/mujoco/ant_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
def __init__(self,
xml_file='ant.xml',
ctrl_cost_weight=0.5,
contact_cost_weight=5e-4,
healthy_reward=1.0,
Expand All @@ -34,7 +35,7 @@ def __init__(self,
self._exclude_current_positions_from_observation = (
exclude_current_positions_from_observation)

mujoco_env.MujocoEnv.__init__(self, 'ant.xml', 5)
mujoco_env.MujocoEnv.__init__(self, xml_file, 5)

@property
def healthy_reward(self):
Expand Down
3 changes: 2 additions & 1 deletion softlearning/environments/gym/mujoco/half_cheetah_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle):
def __init__(self,
xml_file='half_cheetah.xml',
forward_reward_weight=1.0,
ctrl_cost_weight=0.1,
reset_noise_scale=0.1,
Expand All @@ -25,7 +26,7 @@ def __init__(self,
self._exclude_current_positions_from_observation = (
exclude_current_positions_from_observation)

mujoco_env.MujocoEnv.__init__(self, 'half_cheetah.xml', 5)
mujoco_env.MujocoEnv.__init__(self, xml_file, 5)

def control_cost(self, action):
control_cost = self._ctrl_cost_weight * np.sum(np.square(action))
Expand Down
3 changes: 2 additions & 1 deletion softlearning/environments/gym/mujoco/hopper_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

class HopperEnv(mujoco_env.MujocoEnv, utils.EzPickle):
def __init__(self,
xml_file='hopper.xml',
forward_reward_weight=1.0,
ctrl_cost_weight=1e-3,
healthy_reward=1.0,
Expand Down Expand Up @@ -40,7 +41,7 @@ def __init__(self,
self._exclude_current_positions_from_observation = (
exclude_current_positions_from_observation)

mujoco_env.MujocoEnv.__init__(self, 'hopper.xml', 4)
mujoco_env.MujocoEnv.__init__(self, xml_file, 4)

@property
def healthy_reward(self):
Expand Down
3 changes: 2 additions & 1 deletion softlearning/environments/gym/mujoco/humanoid_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def mass_center(model, sim):

class HumanoidEnv(mujoco_env.MujocoEnv, utils.EzPickle):
def __init__(self,
xml_file='humanoid.xml',
forward_reward_weight=0.25,
ctrl_cost_weight=0.1,
contact_cost_weight=5e-7,
Expand All @@ -43,7 +44,7 @@ def __init__(self,
self._exclude_current_positions_from_observation = (
exclude_current_positions_from_observation)

mujoco_env.MujocoEnv.__init__(self, 'humanoid.xml', 5)
mujoco_env.MujocoEnv.__init__(self, xml_file, 5)

@property
def healthy_reward(self):
Expand Down
3 changes: 2 additions & 1 deletion softlearning/environments/gym/mujoco/swimmer_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

class SwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle):
def __init__(self,
xml_file='swimmer.xml',
forward_reward_weight=1.0,
ctrl_cost_weight=1e-4,
reset_noise_scale=0.1,
Expand All @@ -19,7 +20,7 @@ def __init__(self,
self._exclude_current_positions_from_observation = (
exclude_current_positions_from_observation)

mujoco_env.MujocoEnv.__init__(self, 'swimmer.xml', 4)
mujoco_env.MujocoEnv.__init__(self, xml_file, 4)

def control_cost(self, action):
control_cost = self._ctrl_cost_weight * np.sum(np.square(action))
Expand Down
3 changes: 2 additions & 1 deletion softlearning/environments/gym/mujoco/walker2d_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

class Walker2dEnv(mujoco_env.MujocoEnv, utils.EzPickle):
def __init__(self,
xml_file='walker2d.xml',
forward_reward_weight=1.0,
ctrl_cost_weight=1e-3,
healthy_reward=1.0,
Expand All @@ -37,7 +38,7 @@ def __init__(self,
self._exclude_current_positions_from_observation = (
exclude_current_positions_from_observation)

mujoco_env.MujocoEnv.__init__(self, "walker2d.xml", 4)
mujoco_env.MujocoEnv.__init__(self, xml_file, 4)

@property
def healthy_reward(self):
Expand Down

0 comments on commit a646fb4

Please sign in to comment.