making some more progress on joystick control - some weird kinks need to be worked out
mginoya committed Feb 11, 2024
1 parent c501fa8 commit 2fb5b5a
Showing 7 changed files with 106 additions and 10 deletions.
26 changes: 21 additions & 5 deletions alfredo/agents/aant/aant.py
@@ -14,6 +14,7 @@
from alfredo.rewards import rTorques
from alfredo.rewards import rTracking_lin_vel
from alfredo.rewards import rTracking_yaw_vel
from alfredo.rewards import rUpright

class AAnt(PipelineEnv):
""" """
@@ -105,9 +106,11 @@ def reset(self, rng: jax.Array) -> State:
'reward_ctrl': zero,
'reward_alive': zero,
'reward_torque': zero,
'reward_lin_vel': zero,
'reward_lin_vel': zero,
'reward_yaw_vel': zero,
'reward_upright': zero,
}

return State(pipeline_state, obs, reward, done, metrics, state_info)

def step(self, state: State, action: jax.Array) -> State:
@@ -122,9 +125,14 @@ def step(self, state: State, action: jax.Array) -> State:
jp.array([0, 0, 0]), #dummy values for current CoM
self.dt,
state.info['jcmd'],
weight=10.0,
weight=15.5,
focus_idx_range=(0,0))

yaw_vel_reward = rTracking_yaw_vel(self.sys,
state.pipeline_state,
state.info['jcmd'],
weight=10.8,
focus_idx_range=(0,0))

ctrl_cost = rControl_act_ss(self.sys,
state.pipeline_state,
@@ -135,6 +143,10 @@ def step(self, state: State, action: jax.Array) -> State:
state.pipeline_state,
action,
weight=-0.0003)

upright_reward = rUpright(self.sys,
state.pipeline_state,
weight=1.0)

healthy_reward = rHealthy_simple_z(self.sys,
state.pipeline_state,
@@ -146,7 +158,9 @@ def step(self, state: State, action: jax.Array) -> State:
reward = healthy_reward[0]
reward += ctrl_cost
reward += torque_cost
reward += upright_reward
reward += lin_vel_reward
reward += yaw_vel_reward

obs = self._get_obs(pipeline_state, state.info)
done = 1.0 - healthy_reward[1] if self._terminate_when_unhealthy else 0.0
@@ -155,7 +169,9 @@ def step(self, state: State, action: jax.Array) -> State:
reward_ctrl = ctrl_cost,
reward_alive = healthy_reward[0],
reward_torque = torque_cost,
reward_lin_vel = lin_vel_reward
reward_upright = upright_reward,
reward_lin_vel = lin_vel_reward,
reward_yaw_vel = yaw_vel_reward
)

return state.replace(
@@ -176,7 +192,7 @@ def _get_obs(self, pipeline_state, state_info) -> jax.Array:
def _sample_command(self, rng: jax.Array) -> jax.Array:
lin_vel_x_range = [-0.6, 1.5] #[m/s]
lin_vel_y_range = [-0.6, 1.5] #[m/s]
yaw_vel_range = [-0.7, 0.7] #[rad/s]
yaw_vel_range = [-0.0, 0.0] #[rad/s]

_, key1, key2, key3 = jax.random.split(rng, 4)

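The rest of _sample_command is collapsed in this diff. A minimal sketch of how the three command components could be drawn from the ranges above, assuming uniform sampling with the three split keys (the committed implementation may differ):

lin_vel_x = jax.random.uniform(key1, (1,), minval=lin_vel_x_range[0], maxval=lin_vel_x_range[1])
lin_vel_y = jax.random.uniform(key2, (1,), minval=lin_vel_y_range[0], maxval=lin_vel_y_range[1])
yaw_vel = jax.random.uniform(key3, (1,), minval=yaw_vel_range[0], maxval=yaw_vel_range[1])

# joystick command: [x linear velocity, y linear velocity, yaw rate]
jcmd = jp.array([lin_vel_x[0], lin_vel_y[0], yaw_vel[0]])
return jcmd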
2 changes: 1 addition & 1 deletion alfredo/agents/aant/aant.xml
@@ -1,7 +1,7 @@
<agent>

<custom>
<numeric data="0.0 0.0 0.55 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -1.0 0.0 -1.0 0.0 1.0" name="init_qpos"/>
<numeric data="0.0 0.0 0.75 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -1.0 0.0 -1.0 0.0 1.0" name="init_qpos"/>
<numeric data="1000" name="constraint_limit_stiffness"/>
<numeric data="4000" name="constraint_stiffness"/>
<numeric data="10" name="constraint_ang_damping"/>
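For reference, a hedged reading of the 15-value init_qpos vector for this ant-style agent (the exact joint ordering is an assumption, not stated in the commit):

# 0.0 0.0 0.75          -> root position (x, y, z); this change raises the starting height from 0.55 m to 0.75 m
# 1.0 0.0 0.0 0.0       -> root orientation quaternion (w, x, y, z), identity
# 0.0 1.0 ... 0.0 1.0   -> eight leg joint angles, two per leg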
1 change: 1 addition & 0 deletions alfredo/rewards/__init__.py
@@ -3,3 +3,4 @@
from .rHealthy import *
from .rControl import *
from .rEnergy import *
from .rOrientation import *
18 changes: 18 additions & 0 deletions alfredo/rewards/rOrientation.py
@@ -0,0 +1,18 @@
from typing import Tuple

import jax
from brax import actuator, base, math
from brax.envs import PipelineEnv, State
from brax.io import mjcf
from etils import epath
from jax import numpy as jp

def rUpright(sys: base.System,
pipeline_state: base.State,
weight = 1.0,
focus_idx_range = (0,0)) -> jp.ndarray:

up = jp.array([0.0, 0.0, 1.0])
rot_up = math.rotate(up, pipeline_state.x.rot[0])

return weight*jp.dot(up, rot_up)
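A minimal sketch (not part of the commit) of what rUpright computes: rotating the world up-vector [0, 0, 1] by the root quaternion gives the direction the torso's local up axis points in the world, and its dot product with world up is 1.0 when the torso is level and falls toward 0.0 as it tips over:

from brax import math
from jax import numpy as jp

up = jp.array([0.0, 0.0, 1.0])

# identity quaternion (w, x, y, z): body upright -> dot product ~ 1.0
print(jp.dot(up, math.rotate(up, jp.array([1.0, 0.0, 0.0, 0.0]))))

# 90-degree roll about x: body on its side -> dot product ~ 0.0
q_roll = jp.array([jp.cos(jp.pi / 4), jp.sin(jp.pi / 4), 0.0, 0.0])
print(jp.dot(up, math.rotate(up, q_roll)))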
59 changes: 59 additions & 0 deletions experiments/AAnt-locomotion/one_physics_step.py
@@ -0,0 +1,59 @@
import functools
import os
import re
import sys
from datetime import datetime

import brax
import jax
import matplotlib.pyplot as plt
from brax import envs, math
from brax.envs.wrappers import training
from brax.io import html, json, model
from brax.training.acme import running_statistics
from brax.training.agents.ppo import networks as ppo_networks
from jax import numpy as jp

from alfredo.agents.aant import AAnt

backend = "positional"

# Load desired model xml and trained param set
# get filepaths from commandline args
cwd = os.getcwd()

# get the filepath to the env and agent xmls
import alfredo.scenes as scenes

import alfredo.agents as agents
agents_fp = os.path.dirname(agents.__file__)
agent_xml_path = f"{agents_fp}/aant/aant.xml"

scenes_fp = os.path.dirname(scenes.__file__)

env_xml_paths = [f"{scenes_fp}/flatworld/flatworld_A1_env.xml"]

# create an env and initial state
env = AAnt(backend=backend,
env_xml_path=env_xml_paths[0],
agent_xml_path=agent_xml_path)

state = env.reset(rng=jax.random.PRNGKey(seed=0))
#state = jax.jit(env.reset)(rng=jax.random.PRNGKey(seed=0))

x_vel = 0.8 # m/s
y_vel = 0.0 # m/s
yaw_vel = 0.0 # rad/s
jcmd = jp.array([x_vel, y_vel, yaw_vel])
state.info['jcmd'] = jcmd

up = jp.array([0.0, 0.0, 1])
# rot_up = math.rotate(up, jp.array([1, 0, 0, 0]))
rot_up = math.rotate(up, state.pipeline_state.x.rot[6])
rew = jp.dot(up, rot_up)
print(f"x.rot = {state.pipeline_state.x.rot}")
print(f"up: {up}, rot_up: {rot_up}")
print(rew)

print(f"\n-----------------------------------------------------------------\n")
state = env.step(state, jp.zeros(env.action_size))
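The script above takes a single physics step without jit, which is convenient for printing intermediate values. A hedged sketch of rolling the same environment forward a few steps with jitted reset/step (the loop is an illustration, not part of the commit):

jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)

state = jit_reset(rng=jax.random.PRNGKey(seed=0))
state.info['jcmd'] = jcmd  # reuse the joystick command defined above

for _ in range(5):
    state = jit_step(state, jp.zeros(env.action_size))
    print(state.metrics['reward_upright'], state.metrics['reward_yaw_vel'])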
8 changes: 5 additions & 3 deletions experiments/AAnt-locomotion/training.py
@@ -41,8 +41,10 @@ def progress(num_steps, metrics):
"step": num_steps,
"Total Reward": metrics["eval/episode_reward"],
"Lin Vel Reward": metrics["eval/episode_reward_lin_vel"],
"Yaw Vel Reward": metrics["eval/episode_reward_yaw_vel"],
"Alive Reward": metrics["eval/episode_reward_alive"],
"Ctrl Reward": metrics["eval/episode_reward_ctrl"],
"Upright Reward": metrics["eval/episode_reward_upright"],
"Torque Reward": metrics["eval/episode_reward_torque"],
}
)
@@ -91,7 +93,7 @@ def progress(num_steps, metrics):
# ============================
# Training & Saving Params
# ============================
i = 0
i = 3

for p in env_xml_paths:

@@ -106,15 +108,15 @@

d_and_t = datetime.now()
print(f"[{d_and_t}] jitting start for model: {i}")
state = jax.jit(env.reset)(rng=jax.random.PRNGKey(seed=0))
state = jax.jit(env.reset)(rng=jax.random.PRNGKey(seed=1))
d_and_t = datetime.now()
print(f"[{d_and_t}] jitting end for model: {i}")

# define new training function
train_fn = functools.partial(
ppo.train,
num_timesteps=wandb.config.len_training,
num_evals=100,
num_evals=300,
reward_scaling=0.1,
episode_length=1000,
normalize_observations=True,
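The call that launches training is collapsed above. A hedged sketch of how train_fn and the progress callback are typically wired together with Brax PPO (argument names follow brax.training.agents.ppo.train; the exact call and save path in this repo may differ):

make_inference_fn, params, _ = train_fn(environment=env, progress_fn=progress)

# save the trained policy parameters for later visualization (path is illustrative)
model.save_params(f"param_files/aant_model_{i}", params)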
2 changes: 1 addition & 1 deletion experiments/AAnt-locomotion/vis_traj.py
@@ -75,7 +75,7 @@
jit_inference_fn = jax.jit(inference_fn)

x_vel = 0.0 # m/s
y_vel = 1.0 # m/s
y_vel = -1.5 # m/s
yaw_vel = 0.0 # rad/s
jcmd = jp.array([x_vel, y_vel, yaw_vel])

