diff --git a/alfredo/agents/aant/aant.py b/alfredo/agents/aant/aant.py
index b8dd332..157ac5d 100644
--- a/alfredo/agents/aant/aant.py
+++ b/alfredo/agents/aant/aant.py
@@ -14,6 +14,7 @@
 from alfredo.rewards import rTorques
 from alfredo.rewards import rTracking_lin_vel
 from alfredo.rewards import rTracking_yaw_vel
+from alfredo.rewards import rUpright
 
 class AAnt(PipelineEnv):
     """ """
@@ -105,9 +106,11 @@ def reset(self, rng: jax.Array) -> State:
             'reward_ctrl': zero,
             'reward_alive': zero,
             'reward_torque': zero,
-            'reward_lin_vel': zero,
+            'reward_lin_vel': zero,
+            'reward_yaw_vel': zero,
+            'reward_upright': zero,
         }
-        
+
         return State(pipeline_state, obs, reward, done, metrics, state_info)
 
     def step(self, state: State, action: jax.Array) -> State:
@@ -122,9 +125,14 @@ def step(self, state: State, action: jax.Array) -> State:
                                          jp.array([0, 0, 0]), #dummy values for current CoM
                                          self.dt,
                                          state.info['jcmd'],
-                                         weight=10.0,
+                                         weight=15.5,
                                          focus_idx_range=(0,0))
 
+        yaw_vel_reward = rTracking_yaw_vel(self.sys,
+                                           state.pipeline_state,
+                                           state.info['jcmd'],
+                                           weight=10.8,
+                                           focus_idx_range=(0,0))
 
         ctrl_cost = rControl_act_ss(self.sys,
                                     state.pipeline_state,
@@ -135,6 +143,10 @@ def step(self, state: State, action: jax.Array) -> State:
                                state.pipeline_state,
                                action,
                                weight=-0.0003)
+
+        upright_reward = rUpright(self.sys,
+                                  state.pipeline_state,
+                                  weight=1.0)
 
         healthy_reward = rHealthy_simple_z(self.sys,
                                            state.pipeline_state,
@@ -146,7 +158,9 @@ def step(self, state: State, action: jax.Array) -> State:
         reward = healthy_reward[0]
         reward += ctrl_cost
         reward += torque_cost
+        reward += upright_reward
         reward += lin_vel_reward
+        reward += yaw_vel_reward
 
         obs = self._get_obs(pipeline_state, state.info)
         done = 1.0 - healthy_reward[1] if self._terminate_when_unhealthy else 0.0
@@ -155,7 +169,9 @@ def step(self, state: State, action: jax.Array) -> State:
             reward_ctrl = ctrl_cost,
             reward_alive = healthy_reward[0],
             reward_torque = torque_cost,
-            reward_lin_vel = lin_vel_reward
+            reward_upright = upright_reward,
+            reward_lin_vel = lin_vel_reward,
+            reward_yaw_vel = yaw_vel_reward
         )
 
         return state.replace(
@@ -176,7 +192,7 @@ def _get_obs(self, pipeline_state, state_info) -> jax.Array:
 
     def _sample_command(self, rng: jax.Array) -> jax.Array:
         lin_vel_x_range = [-0.6, 1.5] #[m/s]
         lin_vel_y_range = [-0.6, 1.5] #[m/s]
-        yaw_vel_range = [-0.7, 0.7]   #[rad/s]
+        yaw_vel_range = [-0.0, 0.0]   #[rad/s]
 
         _, key1, key2, key3 = jax.random.split(rng, 4)
diff --git a/alfredo/agents/aant/aant.xml b/alfredo/agents/aant/aant.xml
index b35b3f2..e82b6d2 100644
--- a/alfredo/agents/aant/aant.xml
+++ b/alfredo/agents/aant/aant.xml
@@ -1,7 +1,7 @@
-
+
diff --git a/alfredo/rewards/__init__.py b/alfredo/rewards/__init__.py
index 1e0d929..761a96e 100644
--- a/alfredo/rewards/__init__.py
+++ b/alfredo/rewards/__init__.py
@@ -3,3 +3,4 @@
 from .rHealthy import *
 from .rControl import *
 from .rEnergy import *
+from .rOrientation import *
diff --git a/alfredo/rewards/rOrientation.py b/alfredo/rewards/rOrientation.py
new file mode 100644
index 0000000..2c67881
--- /dev/null
+++ b/alfredo/rewards/rOrientation.py
@@ -0,0 +1,18 @@
+from typing import Tuple
+
+import jax
+from brax import actuator, base, math
+from brax.envs import PipelineEnv, State
+from brax.io import mjcf
+from etils import epath
+from jax import numpy as jp
+
+def rUpright(sys: base.System,
+             pipeline_state: base.State,
+             weight = 1.0,
+             focus_idx_range = (0,0)) -> jp.ndarray:
+
+    up = jp.array([0.0, 0.0, 1.0])
+    rot_up = math.rotate(up, pipeline_state.x.rot[0])
+
+    return weight*jp.dot(up, rot_up)
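The new rUpright term is the dot product between the world up axis and the torso up axis after rotation by the root quaternion, so it is 1.0 when the torso is level and shrinks as the torso tilts. A standalone sketch of that math (not part of the patch; it assumes only brax.math and jax.numpy):

    from brax import math
    from jax import numpy as jp

    up = jp.array([0.0, 0.0, 1.0])
    upright_q = jp.array([1.0, 0.0, 0.0, 0.0])                           # identity quaternion: torso level
    tilted_q = math.quat_rot_axis(jp.array([1.0, 0.0, 0.0]), jp.pi / 4)  # 45 degree roll about x

    print(jp.dot(up, math.rotate(up, upright_q)))  # 1.0 -> full reward at weight=1.0
    print(jp.dot(up, math.rotate(up, tilted_q)))   # ~0.707 -> reward decays with tilt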
diff --git a/experiments/AAnt-locomotion/one_physics_step.py b/experiments/AAnt-locomotion/one_physics_step.py
new file mode 100644
index 0000000..61c78c0
--- /dev/null
+++ b/experiments/AAnt-locomotion/one_physics_step.py
@@ -0,0 +1,59 @@
+import functools
+import os
+import re
+import sys
+from datetime import datetime
+
+import brax
+import jax
+import matplotlib.pyplot as plt
+from brax import envs, math
+from brax.envs.wrappers import training
+from brax.io import html, json, model
+from brax.training.acme import running_statistics
+from brax.training.agents.ppo import networks as ppo_networks
+from jax import numpy as jp
+
+from alfredo.agents.aant import AAnt
+
+backend = "positional"
+
+# Load desired model xml and trained param set
+# get filepaths from commandline args
+cwd = os.getcwd()
+
+# get the filepath to the env and agent xmls
+import alfredo.scenes as scenes
+
+import alfredo.agents as agents
+agents_fp = os.path.dirname(agents.__file__)
+agent_xml_path = f"{agents_fp}/aant/aant.xml"
+
+scenes_fp = os.path.dirname(scenes.__file__)
+
+env_xml_paths = [f"{scenes_fp}/flatworld/flatworld_A1_env.xml"]
+
+# create an env and initial state
+env = AAnt(backend=backend,
+           env_xml_path=env_xml_paths[0],
+           agent_xml_path=agent_xml_path)
+
+state = env.reset(rng=jax.random.PRNGKey(seed=0))
+#state = jax.jit(env.reset)(rng=jax.random.PRNGKey(seed=0))
+
+x_vel = 0.8   # m/s
+y_vel = 0.0   # m/s
+yaw_vel = 0.0 # rad/s
+jcmd = jp.array([x_vel, y_vel, yaw_vel])
+state.info['jcmd'] = jcmd
+
+up = jp.array([0.0, 0.0, 1])
+# rot_up = math.rotate(up, jp.array([1, 0, 0, 0]))
+rot_up = math.rotate(up, state.pipeline_state.x.rot[6])
+rew = jp.dot(up, rot_up)
+print(f"x.rot = {state.pipeline_state.x.rot}")
+print(f"up: {up}, rot_up: {rot_up}")
+print(rew)
+
+print(f"\n-----------------------------------------------------------------\n")
+state = env.step(state, jp.zeros(env.action_size))
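Note that rUpright reads pipeline_state.x.rot[0] while this script probes x.rot[6]; which index corresponds to the torso depends on the link ordering in aant.xml. A quick way to check, assuming a brax version that exposes sys.link_names (the 'torso' name below is a guess; substitute the root body name from the agent XML):

    # Runs against the `env` and `state` created above.
    for i, name in enumerate(env.sys.link_names):
        print(i, name)

    torso_idx = env.sys.link_names.index('torso')  # hypothetical link name
    print(state.pipeline_state.x.rot[torso_idx])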
diff --git a/experiments/AAnt-locomotion/training.py b/experiments/AAnt-locomotion/training.py
index 0927655..3bbaca5 100644
--- a/experiments/AAnt-locomotion/training.py
+++ b/experiments/AAnt-locomotion/training.py
@@ -41,8 +41,10 @@ def progress(num_steps, metrics):
             "step": num_steps,
             "Total Reward": metrics["eval/episode_reward"],
             "Lin Vel Reward": metrics["eval/episode_reward_lin_vel"],
+            "Yaw Vel Reward": metrics["eval/episode_reward_yaw_vel"],
             "Alive Reward": metrics["eval/episode_reward_alive"],
             "Ctrl Reward": metrics["eval/episode_reward_ctrl"],
+            "Upright Reward": metrics["eval/episode_reward_upright"],
             "Torque Reward": metrics["eval/episode_reward_torque"],
         }
     )
@@ -91,7 +93,7 @@ def progress(num_steps, metrics):
 # ============================
 # Training & Saving Params
 # ============================
-i = 0
+i = 3
 
 for p in env_xml_paths:
 
@@ -106,7 +108,7 @@ def progress(num_steps, metrics):
 
     d_and_t = datetime.now()
     print(f"[{d_and_t}] jitting start for model: {i}")
-    state = jax.jit(env.reset)(rng=jax.random.PRNGKey(seed=0))
+    state = jax.jit(env.reset)(rng=jax.random.PRNGKey(seed=1))
     d_and_t = datetime.now()
     print(f"[{d_and_t}] jitting end for model: {i}")
 
@@ -114,7 +116,7 @@ def progress(num_steps, metrics):
     train_fn = functools.partial(
         ppo.train,
         num_timesteps=wandb.config.len_training,
-        num_evals=100,
+        num_evals=300,
         reward_scaling=0.1,
         episode_length=1000,
         normalize_observations=True,
diff --git a/experiments/AAnt-locomotion/vis_traj.py b/experiments/AAnt-locomotion/vis_traj.py
index c4b17a9..b081cfb 100644
--- a/experiments/AAnt-locomotion/vis_traj.py
+++ b/experiments/AAnt-locomotion/vis_traj.py
@@ -75,7 +75,7 @@ jit_inference_fn = jax.jit(inference_fn)
 
 x_vel = 0.0   # m/s
-y_vel = 1.0   # m/s
+y_vel = -1.5  # m/s
 yaw_vel = 0.0 # rad/s
 
 jcmd = jp.array([x_vel, y_vel, yaw_vel])
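vis_traj.py fixes a single jcmd for the whole rollout. If the tracking command needs to change mid-rollout, a rough sketch (a sketch only; it assumes the env, jit_inference_fn, and the velocity variables already defined in this script, and jits reset/step the same way training.py does):

    rng = jax.random.PRNGKey(0)
    jit_env_reset = jax.jit(env.reset)
    jit_env_step = jax.jit(env.step)

    state = jit_env_reset(rng=rng)
    for t in range(1000):
        if t % 200 == 0:  # re-issue the tracking command every 200 steps
            state.info['jcmd'] = jp.array([x_vel, y_vel, yaw_vel])
        rng, act_rng = jax.random.split(rng)
        act, _ = jit_inference_fn(state.obs, act_rng)
        state = jit_env_step(state, act)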