debugging why waypoint reward doesn't work with humanoid model now - similar to aant, it overfits?
mginoya committed Apr 14, 2024
1 parent fbe9679 commit c0b0b90
Showing 4 changed files with 72 additions and 79 deletions.
125 changes: 60 additions & 65 deletions alfredo/agents/A1/alfredo_1.py
@@ -17,6 +17,7 @@
from alfredo.rewards import rTorques
from alfredo.rewards import rTracking_lin_vel
from alfredo.rewards import rTracking_yaw_vel
from alfredo.rewards import rTracking_Waypoint

class Alfredo(PipelineEnv):
# pyformat: disable
@@ -112,121 +113,91 @@ def reset(self, rng: jp.ndarray) -> State:

low, hi = -self._reset_noise_scale, self._reset_noise_scale

jcmd = self._sample_command(rng3)
wcmd = self._sample_waypoint(rng3)

qpos = self.sys.init_q + jax.random.uniform(
rng1, (self.sys.q_size(),), minval=low, maxval=hi
)
qvel = jax.random.uniform(rng2, (self.sys.qd_size(),), minval=low, maxval=hi)

pipeline_state = self.pipeline_init(qpos, qvel)

obs = self._get_obs(pipeline_state, jp.zeros(self.sys.act_size()))
com, *_ = self._com(pipeline_state)

state_info = {
'jcmd':jcmd,
'wcmd':wcmd,
'CoM': com,
}
}

obs = self._get_obs(pipeline_state, jp.zeros(self.sys.act_size()), state_info)

reward, done, zero = jp.zeros(3)
metrics = {
"reward_ctrl": zero,
"reward_alive": zero,
#"reward_velocity": zero,
"reward_torque":zero,
"agent_x_position": zero,
"agent_y_position": zero,
#"agent_x_velocity": zero,
#"agent_y_velocity": zero,
"lin_vel_reward":zero,
#"yaw_vel_reward":zero,
"reward_torque": zero,
"reward_waypoint": zero,
}

return State(pipeline_state, obs, reward, done, metrics, state_info)

def step(self, state: State, action: jp.ndarray) -> State:
"""Runs one timestep of the environment's dynamics."""
prev_pipeline_state = state.pipeline_state
pipeline_state = self.pipeline_step(prev_pipeline_state, action)
obs = self._get_obs(pipeline_state, action)

com_before, *_ = self._com(prev_pipeline_state)
com_after, *_ = self._com(pipeline_state)

lin_vel_reward = rTracking_lin_vel(self.sys,
state.pipeline_state,
com_before,
com_after,
self.dt,
state.info['jcmd'],
weight=10.0,
focus_idx_range=(0,0))

yaw_vel_reward = rTracking_yaw_vel(self.sys,
state.pipeline_state,
state.info['jcmd'],
weight=0.8,
focus_idx_range=(0,0))
pipeline_state0 = state.pipeline_state
pipeline_state = self.pipeline_step(pipeline_state0, action)

x_speed_reward = rSpeed_X(self.sys,
state.pipeline_state,
CoM_prev=com_before,
CoM_now=com_after,
dt=self.dt,
weight=self._forward_reward_weight)
waypoint_cost = rTracking_Waypoint(self.sys,
state.pipeline_state,
state.info['wcmd'],
weight=1.0,
focus_idx_range=0)

ctrl_cost = rControl_act_ss(self.sys,
state.pipeline_state,
action,
weight=-self._ctrl_cost_weight)

torque_cost = rTorques(self.sys,
state.pipeline_state,
action,
weight=-0.0003)
weight=-0.0003)

healthy_reward = rHealthy_simple_z(self.sys,
state.pipeline_state,
self._healthy_z_range,
early_terminate=self._terminate_when_unhealthy,
weight=0.2,
weight=1.0,
focus_idx_range=(0, 2))
reward = 0.0
reward = healthy_reward[0]
reward += ctrl_cost
#reward += x_speed_reward[0]


reward = healthy_reward[0]
reward += ctrl_cost
reward += torque_cost
reward += lin_vel_reward
#reward += yaw_vel_reward

state.info['CoM'] = com_after
reward += waypoint_cost


obs = self._get_obs(pipeline_state, action, state.info)
done = 1.0 - healthy_reward[1] if self._terminate_when_unhealthy else 0.0

state.metrics.update(
reward_ctrl=ctrl_cost,
reward_alive=healthy_reward[0],
#reward_velocity=x_speed_reward[0],
reward_torque=torque_cost,
agent_x_position=com_after[0],
agent_y_position=com_after[1],
#agent_x_velocity=x_speed_reward[1],
#agent_y_velocity=x_speed_reward[2],
lin_vel_reward=lin_vel_reward,
#yaw_vel_reward=yaw_vel_reward,
reward_ctrl = ctrl_cost,
reward_alive = healthy_reward[0],
reward_torque = torque_cost,
reward_waypoint = waypoint_cost,
)

return state.replace(
pipeline_state=pipeline_state, obs=obs, reward=reward, done=done
)

def _get_obs(self, pipeline_state: base.State, action: jp.ndarray) -> jp.ndarray:
def _get_obs(self, pipeline_state: base.State, action: jp.ndarray, state_info) -> jp.ndarray:
"""Observes Alfredo's body position, velocities, and angles."""

a_positions = pipeline_state.q
a_velocities = pipeline_state.qd

wcmd = state_info['wcmd']

qfrc_actuator = actuator.to_tau(
self.sys, action, pipeline_state.q, pipeline_state.qd
)
@@ -237,6 +208,7 @@ def _get_obs(self, pipeline_state: base.State, action: jp.ndarray) -> jp.ndarray
a_positions,
a_velocities,
qfrc_actuator,
wcmd
]
)

@@ -265,8 +237,31 @@ def _com(self, pipeline_state: base.State) -> jp.ndarray:
mass_sum,
x_i,
) # pytype: disable=bad-return-type # jax-ndarray

def _sample_waypoint(self, rng:jax.Array) -> jax.Array:
x_range = [-10, 10]
y_range = [-10, 10]
z_range = [0, 2]

_, key1, key2, key3 = jax.random.split(rng, 4)

x = jax.random.uniform(
key1, (1,), minval=x_range[0], maxval=x_range[1]
)

y = jax.random.uniform(
key2, (1,), minval=y_range[0], maxval=y_range[1]
)

z = jax.random.uniform(
key3, (1,), minval=z_range[0], maxval=z_range[1]
)

wcmd = jp.array([x[0], y[0], z[0]])

return wcmd

def _sample_command(self, rng: jax.Array) -> jax.Array:
def _sample_jcommand(self, rng: jax.Array) -> jax.Array:
lin_vel_x_range = [-0.6, 1.5] #[m/s]
lin_vel_y_range = [-0.6, 1.5] #[m/s]
yaw_vel_range = [-0.7, 0.7] #[rad/s]
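The alfredo_1.py changes above swap the joystick-tracking terms (lin_vel / yaw_vel) for a single waypoint term. rTracking_Waypoint itself is not shown in this diff; as a rough mental model only, such a term usually scores how close the tracked body is to the commanded waypoint wcmd. A minimal sketch under that assumption (the function name, distance metric, and scaling below are guesses, not the repo's implementation):

import jax.numpy as jp

def waypoint_reward_sketch(body_pos: jp.ndarray, wcmd: jp.ndarray,
                           weight: float = 1.0) -> jp.ndarray:
    # Hypothetical shaping: negative Euclidean distance to the waypoint,
    # scaled by `weight`. The real rTracking_Waypoint may use a different
    # kernel (e.g. exp(-d^2)) or ignore the z component.
    return -weight * jp.linalg.norm(body_pos - wcmd)

Note that the diff stores the term in a variable named waypoint_cost but adds it to the summed reward with a positive weight, so whether it behaves as a cost or a reward depends entirely on the sign convention inside rTracking_Waypoint.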
4 changes: 2 additions & 2 deletions alfredo/agents/aant/aant.py
@@ -214,11 +214,11 @@ def _sample_waypoint(self, rng: jax.Array) -> jax.Array:
)

y = jax.random.uniform(
key1, (1,), minval=y_range[0], maxval=y_range[1]
key2, (1,), minval=y_range[0], maxval=y_range[1]
)

z = jax.random.uniform(
key1, (1,), minval=z_range[0], maxval=z_range[1]
key3, (1,), minval=z_range[0], maxval=z_range[1]
)

wcmd = jp.array([x[0], y[0], z[0]])
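The aant.py fix gives the y and z samples their own PRNG keys. With the old code, key1 was reused for every axis, so the underlying uniform draw was identical and x, y, and z were perfectly correlated (each just a rescaling of the same number) rather than independent. A standalone illustration of the intended split-key pattern (a sketch, not the repo's code; the ranges are placeholders):

import jax
import jax.numpy as jp

rng = jax.random.PRNGKey(0)
_, key1, key2, key3 = jax.random.split(rng, 4)

# Distinct keys give independent draws; reusing key1 for all three axes
# would reproduce the same underlying uniform sample each time.
x = jax.random.uniform(key1, (1,), minval=-10.0, maxval=10.0)
y = jax.random.uniform(key2, (1,), minval=-10.0, maxval=10.0)
z = jax.random.uniform(key3, (1,), minval=0.0, maxval=2.0)
wcmd = jp.array([x[0], y[0], z[0]])  # independent per-axis waypoint command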
9 changes: 2 additions & 7 deletions experiments/Alfredo-simple-walk/seq_training.py
@@ -49,18 +49,13 @@ def progress(num_steps, metrics):
{
"step": num_steps,
"Total Reward": metrics["eval/episode_reward"],
#"Vel Reward": metrics["eval/episode_reward_velocity"],
"Alive Reward": metrics["eval/episode_reward_alive"],
"Ctrl Reward": metrics["eval/episode_reward_ctrl"],
"Torque Reward": metrics["eval/episode_reward_torque"],
#"a_vel_x": metrics["eval/episode_agent_x_velocity"],
#"a_vel_y": metrics["eval/episode_agent_y_velocity"],
"Linear Vel Reward": metrics["eval/episode_lin_vel_reward"],
#"Yaw Vel Reward": metrics["eval/episode_yaw_vel_reward"]
"Waypoint Reward": metrics["eval/episode_reward_waypoint"],
}
)


# ==============================
# General Variable Defs
# ==============================
@@ -136,7 +131,7 @@ def progress(num_steps, metrics):
train_fn = functools.partial(
ppo.train,
num_timesteps=wandb.config.len_training,
num_evals=500,
num_evals=20,
reward_scaling=0.1,
episode_length=1000,
normalize_observations=True,
13 changes: 8 additions & 5 deletions experiments/Alfredo-simple-walk/vis_traj.py
@@ -30,6 +30,7 @@
agent_xml_path = f"{agents_fp}/A1/a1.xml"

scenes_fp = os.path.dirname(scenes.__file__)
print(scenes_fp)

env_xml_path = f"{scenes_fp}/{sys.argv[-2]}"
tpf_path = f"{cwd}/{sys.argv[-1]}"
@@ -75,17 +76,19 @@

jit_inference_fn = jax.jit(inference_fn)

x_vel = -1.0 # m/s
y_vel = 0.0 # m/s
yaw_vel = 0.0 # rad/s
jcmd = jp.array([x_vel, y_vel, yaw_vel])
# x_vel = -1.0 # m/s
# y_vel = 0.0 # m/s
# yaw_vel = 0.0 # rad/s
# jcmd = jp.array([x_vel, y_vel, yaw_vel])

wcmd = jp.array([0.0, 10.0, 0.0])

# generate policy rollout
for _ in range(episode_length):
rollout.append(state.pipeline_state)
act_rng, rng = jax.random.split(rng)

state.info['jcmd'] = jcmd
state.info['wcmd'] = wcmd
act, _ = jit_inference_fn(state.obs, act_rng)
state = jit_env_step(state, act)
print(state.info)
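Given the commit's question of why the waypoint reward is not shaping behaviour, one cheap check during this visualization rollout is to log the agent's distance to the fixed waypoint each step and confirm it actually decreases. A hedged sketch (assumes body index 0 in pipeline_state.x.pos is the torso; this logging is not part of the commit):

# inside the rollout loop, after jit_env_step(state, act)
torso_pos = state.pipeline_state.x.pos[0]      # root body world position
dist_to_wp = jp.linalg.norm(torso_pos - wcmd)  # distance to commanded waypoint
print(f"distance to waypoint: {float(dist_to_wp):.3f}")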
