-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
saving progress while I migrate systems on computer with 3070
- Loading branch information
Showing
12 changed files
with
565 additions
and
48 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
from . import A1 | ||
from . import aant |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .aant import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
from brax import base | ||
from brax import math | ||
from brax.envs.base import PipelineEnv, State | ||
from brax.io import mjcf | ||
from etils import epath | ||
import jax | ||
from jax import numpy as jp | ||
|
||
from alfredo.tools import compose_scene | ||
|
||
class AAnt(PipelineEnv): | ||
""" """ | ||
|
||
def __init__(self, | ||
ctrl_cost_weight=0.5, | ||
use_contact_forces=False, | ||
contact_cost_weight=5e-4, | ||
healthy_reward=1.0, | ||
terminate_when_unhealthy=True, | ||
healthy_z_range=(0.2, 1.0), | ||
contact_force_range=(-1.0, 1.0), | ||
reset_noise_scale=0.1, | ||
exclude_current_positions_from_observation=True, | ||
backend='generalized', | ||
**kwargs,): | ||
|
||
# forcing this model to need an input scene_xml_path or | ||
# the combination of env_xml_path and agent_xml_path | ||
# if none of these options are present, an error will be thrown | ||
path="" | ||
|
||
if "env_xml_path" and "agent_xml_path" in kwargs: | ||
env_xp = kwargs["env_xml_path"] | ||
agent_xp = kwargs["agent_xml_path"] | ||
xml_scene = compose_scene(env_xp, agent_xp) | ||
del kwargs["env_xml_path"] | ||
del kwargs["agent_xml_path"] | ||
|
||
sys = mjcf.loads(xml_scene) | ||
|
||
|
||
n_frames = 5 | ||
|
||
if backend in ['spring', 'positional']: | ||
sys = sys.replace(dt=0.005) | ||
n_frames = 10 | ||
|
||
if backend == 'positional': | ||
# TODO: does the same actuator strength work as in spring | ||
sys = sys.replace( | ||
actuator=sys.actuator.replace( | ||
gear=200 * jp.ones_like(sys.actuator.gear) | ||
) | ||
) | ||
|
||
kwargs['n_frames'] = kwargs.get('n_frames', n_frames) | ||
|
||
super().__init__(sys=sys, backend=backend, **kwargs) | ||
|
||
print(sys) | ||
|
||
self._ctrl_cost_weight = ctrl_cost_weight | ||
self._use_contact_forces = use_contact_forces | ||
self._contact_cost_weight = contact_cost_weight | ||
self._healthy_reward = healthy_reward | ||
self._terminate_when_unhealthy = terminate_when_unhealthy | ||
self._healthy_z_range = healthy_z_range | ||
self._contact_force_range = contact_force_range | ||
self._reset_noise_scale = reset_noise_scale | ||
self._exclude_current_positions_from_observation = ( | ||
exclude_current_positions_from_observation | ||
) | ||
|
||
if self._use_contact_forces: | ||
raise NotImplementedError('use_contact_forces not implemented.') | ||
|
||
|
||
def reset(self, rng: jax.Array) -> State: | ||
rng, rng1, rng2 = jax.random.split(rng, 3) | ||
|
||
low, hi = -self._reset_noise_scale, self._reset_noise_scale | ||
|
||
q = self.sys.init_q + jax.random.uniform( | ||
rng1, (self.sys.q_size(),), minval=low, maxval=hi | ||
) | ||
|
||
qd = hi * jax.random.normal(rng2, (self.sys.qd_size(),)) | ||
|
||
pipeline_state = self.pipeline_init(q, qd) | ||
obs = self._get_obs(pipeline_state) | ||
|
||
reward, done, zero = jp.zeros(3) | ||
metrics = { | ||
'reward_forward': zero, | ||
'reward_survive': zero, | ||
'reward_ctrl': zero, | ||
'reward_contact': zero, | ||
'x_position': zero, | ||
'y_position': zero, | ||
'distance_from_origin': zero, | ||
'x_velocity': zero, | ||
'y_velocity': zero, | ||
} | ||
|
||
return State(pipeline_state, obs, reward, done, metrics) | ||
|
||
def step(self, state: State, action: jax.Array) -> State: | ||
"""Run one timestep of the environment's dynamics.""" | ||
pipeline_state0 = state.pipeline_state | ||
pipeline_state = self.pipeline_step(pipeline_state0, action) | ||
|
||
velocity = (pipeline_state.x.pos[0] - pipeline_state0.x.pos[0]) / self.dt | ||
forward_reward = velocity[0] | ||
|
||
min_z, max_z = self._healthy_z_range | ||
is_healthy = jp.where(pipeline_state.x.pos[0, 2] < min_z, 0.0, 1.0) | ||
is_healthy = jp.where(pipeline_state.x.pos[0, 2] > max_z, 0.0, is_healthy) | ||
|
||
if self._terminate_when_unhealthy: | ||
healthy_reward = self._healthy_reward | ||
else: | ||
healthy_reward = self._healthy_reward * is_healthy | ||
|
||
ctrl_cost = self._ctrl_cost_weight * jp.sum(jp.square(action)) | ||
|
||
contact_cost = 0.0 | ||
|
||
obs = self._get_obs(pipeline_state) | ||
reward = forward_reward + healthy_reward - ctrl_cost - contact_cost | ||
done = 1.0 - is_healthy if self._terminate_when_unhealthy else 0.0 | ||
state.metrics.update( | ||
reward_forward=forward_reward, | ||
reward_survive=healthy_reward, | ||
reward_ctrl=-ctrl_cost, | ||
reward_contact=-contact_cost, | ||
x_position=pipeline_state.x.pos[0, 0], | ||
y_position=pipeline_state.x.pos[0, 1], | ||
distance_from_origin=math.safe_norm(pipeline_state.x.pos[0]), | ||
x_velocity=velocity[0], | ||
y_velocity=velocity[1], | ||
) | ||
|
||
return state.replace( | ||
pipeline_state=pipeline_state, obs=obs, reward=reward, done=done | ||
) | ||
|
||
def _get_obs(self, pipeline_state: base.State) -> jax.Array: | ||
"""Observe ant body position and velocities.""" | ||
qpos = pipeline_state.q | ||
qvel = pipeline_state.qd | ||
|
||
if self._exclude_current_positions_from_observation: | ||
qpos = pipeline_state.q[2:] | ||
|
||
return jp.concatenate([qpos] + [qvel]) |
Oops, something went wrong.