
Commit

saving progress while I migrate systems on computer with 3070
mginoya committed Feb 1, 2024
1 parent 95e6467 commit ff8c00e
Showing 12 changed files with 565 additions and 48 deletions.
108 changes: 73 additions & 35 deletions alfredo/agents/A1/alfredo_1.py
@@ -15,6 +15,8 @@
from alfredo.rewards import rSpeed_X
from alfredo.rewards import rControl_act_ss
from alfredo.rewards import rTorques
from alfredo.rewards import rTracking_lin_vel
from alfredo.rewards import rTracking_yaw_vel

class Alfredo(PipelineEnv):
# pyformat: disable
@@ -88,6 +90,9 @@ def __init__(
kwargs["n_frames"] = kwargs.get("n_frames", n_frames)

super().__init__(sys=sys, backend=backend, **kwargs)

#torso_idx = self.sys.link_names.index('alfredo')
#print(self.sys.link_names)

self._forward_reward_weight = forward_reward_weight
self._ctrl_cost_weight = ctrl_cost_weight
@@ -103,9 +108,11 @@ def __init__(

def reset(self, rng: jp.ndarray) -> State:
"""Resets the environment to an initial state."""
rng, rng1, rng2 = jax.random.split(rng, 3)

rng, rng1, rng2, rng3 = jax.random.split(rng, 4)
low, hi = -self._reset_noise_scale, self._reset_noise_scale

jcmd = self._sample_command(rng3)

qpos = self.sys.init_q + jax.random.uniform(
rng1, (self.sys.q_size(),), minval=low, maxval=hi
@@ -115,20 +122,28 @@ def reset(self, rng: jp.ndarray) -> State:
pipeline_state = self.pipeline_init(qpos, qvel)

obs = self._get_obs(pipeline_state, jp.zeros(self.sys.act_size()))
com, *_ = self._com(pipeline_state)

state_info = {
'jcmd':jcmd,
'CoM': com,
}

reward, done, zero = jp.zeros(3)
metrics = {
"reward_ctrl": zero,
"reward_alive": zero,
"reward_velocity": zero,
#"reward_velocity": zero,
"reward_torque":zero,
"agent_x_position": zero,
"agent_y_position": zero,
"agent_x_velocity": zero,
"agent_y_velocity": zero,
#"agent_x_velocity": zero,
#"agent_y_velocity": zero,
"lin_vel_reward":zero,
#"yaw_vel_reward":zero,
}

return State(pipeline_state, obs, reward, done, metrics)
return State(pipeline_state, obs, reward, done, metrics, state_info)

def step(self, state: State, action: jp.ndarray) -> State:
"""Runs one timestep of the environment's dynamics."""
@@ -138,6 +153,21 @@ def step(self, state: State, action: jp.ndarray) -> State:

com_before, *_ = self._com(prev_pipeline_state)
com_after, *_ = self._com(pipeline_state)

lin_vel_reward = rTracking_lin_vel(self.sys,
state.pipeline_state,
com_before,
com_after,
self.dt,
state.info['jcmd'],
weight=10.0,
focus_idx_range=(0,0))

yaw_vel_reward = rTracking_yaw_vel(self.sys,
state.pipeline_state,
state.info['jcmd'],
weight=0.8,
focus_idx_range=(0,0))

x_speed_reward = rSpeed_X(self.sys,
state.pipeline_state,
@@ -160,22 +190,31 @@ def step(self, state: State, action: jp.ndarray) -> State:
state.pipeline_state,
self._healthy_z_range,
early_terminate=self._terminate_when_unhealthy,
weight=self._healthy_reward,
weight=0.2,
focus_idx_range=(0, 2))

reward = healthy_reward[0] + ctrl_cost + x_speed_reward[0] + torque_cost
reward = 0.0
reward = healthy_reward[0]
reward += ctrl_cost
#reward += x_speed_reward[0]
reward += torque_cost
reward += lin_vel_reward
#reward += yaw_vel_reward

state.info['CoM'] = com_after

done = 1.0 - healthy_reward[1] if self._terminate_when_unhealthy else 0.0

state.metrics.update(
reward_ctrl=ctrl_cost,
reward_alive=healthy_reward[0],
reward_velocity=x_speed_reward[0],
#reward_velocity=x_speed_reward[0],
reward_torque=torque_cost,
agent_x_position=com_after[0],
agent_y_position=com_after[1],
agent_x_velocity=x_speed_reward[1],
agent_y_velocity=x_speed_reward[2],
#agent_x_velocity=x_speed_reward[1],
#agent_y_velocity=x_speed_reward[2],
lin_vel_reward=lin_vel_reward,
#yaw_vel_reward=yaw_vel_reward,
)

return state.replace(
@@ -187,27 +226,6 @@ def _get_obs(self, pipeline_state: base.State, action: jp.ndarray) -> jp.ndarray

a_positions = pipeline_state.q
a_velocities = pipeline_state.qd
#print(f"a_positions = {a_positions}")
#print(f"a_velocities = {a_velocities}")

if self._exclude_current_positions_from_observation:
a_positions = a_positions[2:]

com, inertia, mass_sum, x_i = self._com(pipeline_state)
cinr = x_i.replace(pos=x_i.pos - com).vmap().do(inertia)
com_inertia = jp.hstack(
[cinr.i.reshape((cinr.i.shape[0], -1)), inertia.mass[:, None]]
)

xd_i = (
base.Transform.create(pos=x_i.pos - pipeline_state.x.pos)
.vmap()
.do(pipeline_state.xd)
)

com_vel = inertia.mass[:, None] * xd_i.vel / mass_sum
com_ang = xd_i.ang
com_velocity = jp.hstack([com_vel, com_ang])

qfrc_actuator = actuator.to_tau(
self.sys, action, pipeline_state.q, pipeline_state.qd
Expand All @@ -218,8 +236,6 @@ def _get_obs(self, pipeline_state: base.State, action: jp.ndarray) -> jp.ndarray
[
a_positions,
a_velocities,
com_inertia.ravel(),
com_velocity.ravel(),
qfrc_actuator,
]
)
@@ -249,4 +265,26 @@ def _com(self, pipeline_state: base.State) -> jp.ndarray:
mass_sum,
x_i,
) # pytype: disable=bad-return-type # jax-ndarray

def _sample_command(self, rng: jax.Array) -> jax.Array:
lin_vel_x_range = [-0.6, 1.5] #[m/s]
lin_vel_y_range = [-0.6, 1.5] #[m/s]
yaw_vel_range = [-0.7, 0.7] #[rad/s]

_, key1, key2, key3 = jax.random.split(rng, 4)

lin_vel_x = jax.random.uniform(
key1, (1,), minval=lin_vel_x_range[0], maxval=lin_vel_x_range[1]
)

lin_vel_y = jax.random.uniform(
key2, (1,), minval=lin_vel_y_range[0], maxval=lin_vel_y_range[1]
)

yaw_vel = jax.random.uniform(
key3, (1,), minval=yaw_vel_range[0], maxval=yaw_vel_range[1]
)

jcmd = jp.array([lin_vel_x[0], lin_vel_y[0], yaw_vel[0]])

return jcmd
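
Note: the alfredo.rewards implementations of rTracking_lin_vel and rTracking_yaw_vel are not part of this diff. As a rough mental model only, velocity-tracking terms of this shape (cf. Brax locomotion examples) typically score the gap between the commanded and measured velocity with an exponential kernel. A minimal sketch under that assumption, reusing the CoM finite-difference velocity that step() passes in; sigma is a hypothetical tolerance parameter, not a confirmed part of the alfredo API:

import jax.numpy as jp

def tracking_lin_vel_sketch(com_before, com_after, dt, jcmd,
                            weight=10.0, sigma=0.25):
    # Finite-difference estimate of planar CoM velocity, matching how
    # step() supplies com_before, com_after, and self.dt.
    measured_vel = (com_after[:2] - com_before[:2]) / dt
    # jcmd = [lin_vel_x, lin_vel_y, yaw_vel]; compare the planar part.
    err = jp.sum(jp.square(jcmd[:2] - measured_vel))
    # Exponential kernel: reward peaks at `weight` for perfect tracking.
    return weight * jp.exp(-err / sigma)

With the command sampled once per reset and carried in state.info['jcmd'], the agent is rewarded for matching a randomized target velocity rather than maximizing raw forward speed, which is why the rSpeed_X term is commented out above.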
1 change: 1 addition & 0 deletions alfredo/agents/__init__.py
@@ -1 +1,2 @@
from . import A1
from . import aant
1 change: 1 addition & 0 deletions alfredo/agents/aant/__init__.py
@@ -0,0 +1 @@
from .aant import *
155 changes: 155 additions & 0 deletions alfredo/agents/aant/aant.py
@@ -0,0 +1,155 @@
from brax import base
from brax import math
from brax.envs.base import PipelineEnv, State
from brax.io import mjcf
from etils import epath
import jax
from jax import numpy as jp

from alfredo.tools import compose_scene

class AAnt(PipelineEnv):
""" """

def __init__(self,
ctrl_cost_weight=0.5,
use_contact_forces=False,
contact_cost_weight=5e-4,
healthy_reward=1.0,
terminate_when_unhealthy=True,
healthy_z_range=(0.2, 1.0),
contact_force_range=(-1.0, 1.0),
reset_noise_scale=0.1,
exclude_current_positions_from_observation=True,
backend='generalized',
**kwargs,):

# forcing this model to need an input scene_xml_path or
# the combination of env_xml_path and agent_xml_path;
# if neither option is present, an error is thrown

if "env_xml_path" in kwargs and "agent_xml_path" in kwargs:
env_xp = kwargs["env_xml_path"]
agent_xp = kwargs["agent_xml_path"]
xml_scene = compose_scene(env_xp, agent_xp)
del kwargs["env_xml_path"]
del kwargs["agent_xml_path"]

sys = mjcf.loads(xml_scene)
else:
raise ValueError("AAnt needs both env_xml_path and agent_xml_path "
"(scene_xml_path handling is not implemented yet)")

n_frames = 5

if backend in ['spring', 'positional']:
sys = sys.replace(dt=0.005)
n_frames = 10

if backend == 'positional':
# TODO: does the same actuator strength work as in spring
sys = sys.replace(
actuator=sys.actuator.replace(
gear=200 * jp.ones_like(sys.actuator.gear)
)
)

kwargs['n_frames'] = kwargs.get('n_frames', n_frames)

super().__init__(sys=sys, backend=backend, **kwargs)

print(sys)

self._ctrl_cost_weight = ctrl_cost_weight
self._use_contact_forces = use_contact_forces
self._contact_cost_weight = contact_cost_weight
self._healthy_reward = healthy_reward
self._terminate_when_unhealthy = terminate_when_unhealthy
self._healthy_z_range = healthy_z_range
self._contact_force_range = contact_force_range
self._reset_noise_scale = reset_noise_scale
self._exclude_current_positions_from_observation = (
exclude_current_positions_from_observation
)

if self._use_contact_forces:
raise NotImplementedError('use_contact_forces not implemented.')


def reset(self, rng: jax.Array) -> State:
rng, rng1, rng2 = jax.random.split(rng, 3)

low, hi = -self._reset_noise_scale, self._reset_noise_scale

q = self.sys.init_q + jax.random.uniform(
rng1, (self.sys.q_size(),), minval=low, maxval=hi
)

qd = hi * jax.random.normal(rng2, (self.sys.qd_size(),))

pipeline_state = self.pipeline_init(q, qd)
obs = self._get_obs(pipeline_state)

reward, done, zero = jp.zeros(3)
metrics = {
'reward_forward': zero,
'reward_survive': zero,
'reward_ctrl': zero,
'reward_contact': zero,
'x_position': zero,
'y_position': zero,
'distance_from_origin': zero,
'x_velocity': zero,
'y_velocity': zero,
}

return State(pipeline_state, obs, reward, done, metrics)

def step(self, state: State, action: jax.Array) -> State:
"""Run one timestep of the environment's dynamics."""
pipeline_state0 = state.pipeline_state
pipeline_state = self.pipeline_step(pipeline_state0, action)

velocity = (pipeline_state.x.pos[0] - pipeline_state0.x.pos[0]) / self.dt
forward_reward = velocity[0]

min_z, max_z = self._healthy_z_range
is_healthy = jp.where(pipeline_state.x.pos[0, 2] < min_z, 0.0, 1.0)
is_healthy = jp.where(pipeline_state.x.pos[0, 2] > max_z, 0.0, is_healthy)

if self._terminate_when_unhealthy:
healthy_reward = self._healthy_reward
else:
healthy_reward = self._healthy_reward * is_healthy

ctrl_cost = self._ctrl_cost_weight * jp.sum(jp.square(action))

contact_cost = 0.0

obs = self._get_obs(pipeline_state)
reward = forward_reward + healthy_reward - ctrl_cost - contact_cost
done = 1.0 - is_healthy if self._terminate_when_unhealthy else 0.0
state.metrics.update(
reward_forward=forward_reward,
reward_survive=healthy_reward,
reward_ctrl=-ctrl_cost,
reward_contact=-contact_cost,
x_position=pipeline_state.x.pos[0, 0],
y_position=pipeline_state.x.pos[0, 1],
distance_from_origin=math.safe_norm(pipeline_state.x.pos[0]),
x_velocity=velocity[0],
y_velocity=velocity[1],
)

return state.replace(
pipeline_state=pipeline_state, obs=obs, reward=reward, done=done
)

def _get_obs(self, pipeline_state: base.State) -> jax.Array:
"""Observe ant body position and velocities."""
qpos = pipeline_state.q
qvel = pipeline_state.qd

if self._exclude_current_positions_from_observation:
qpos = pipeline_state.q[2:]

return jp.concatenate([qpos, qvel])
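
A minimal usage sketch for the new environment. The XML paths below are hypothetical placeholders (no scene or agent XML files appear in this diff), and compose_scene is assumed to return a single MJCF string, as the constructor above expects:

import jax
from jax import numpy as jp

from alfredo.agents.aant import AAnt

# Hypothetical paths; the actual scene/agent XMLs are not part of this commit.
env = AAnt(backend='generalized',
           env_xml_path='scenes/flatworld.xml',
           agent_xml_path='agents/aant.xml')

rng = jax.random.PRNGKey(0)
state = jax.jit(env.reset)(rng)

# One zero-action step; the action size comes from the loaded system.
action = jp.zeros(env.sys.act_size())
state = jax.jit(env.step)(state, action)
print(state.reward, state.done)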
