Restarted pianist (Python 3.10.12)

In [1]:
import os
os.environ['MUJOCO_GL'] = 'egl'

from mjx_env import PipelineEnv, EnvState
import mjcf
from pathlib import Path
import mujoco
import jax
import sim_math as utils
from jax import numpy as jp
from mujoco import mjx

_HERE = Path(__file__).resolve().parent
_SHADOW_HAND_DIR = _HERE / "third_party" / "shadow_hand"


################## Constants ##################
# Timestep of the physics simulation, in seconds.
_PHYSICS_TIMESTEP = 0.005
# Interval between agent actions, in seconds.
_CONTROL_TIMESTEP = 0.05  # 20 Hz. # TODO: whether to increase control timesteps
# Distance thresholds for the shaping reward.
_FINGER_CLOSE_ENOUGH_TO_KEY = 0.01
_KEY_CLOSE_ENOUGH_TO_PRESSED = 0.05
# Energy penalty coefficient.
_ENERGY_PENALTY_COEF = 5e-3
# Transparency of fingertip geoms.
_FINGERTIP_ALPHA = 1.0
# Bounds for the uniform distribution from which initial hand offset is sampled.
_POSITION_OFFSET = 0.05
###############################################


class RoboPianist(PipelineEnv):

  def __init__(
      self,

      **kwargs,
  ):
    path = _HERE/ 'third_party' / 'shadow_hand' / 'robopianist.xml'

    mj_model = mujoco.MjModel.from_xml_path(path.as_posix())
    mj_model.opt.timestep = _PHYSICS_TIMESTEP

    sys = mjcf.load_model(mj_model)

    physics_steps_per_control_step = utils.compute_n_steps(_CONTROL_TIMESTEP, _PHYSICS_TIMESTEP)
    kwargs['n_frames'] = kwargs.get(
        'n_frames', physics_steps_per_control_step)
    super().__init__(sys, **kwargs)


  def reset(self, rng: jp.ndarray) -> EnvState:
    """Resets the environment to an initial state."""
    # TODO: decide the correct initial state
    rng, rng1, rng2 = jax.random.split(rng, 3)

    low, hi = -0.01, 0.01
    qpos = self.sys.qpos0 + jax.random.uniform(
        rng1, (self.sys.nq,), minval=low, maxval=hi
    )
    qvel = jax.random.uniform(
        rng2, (self.sys.nv,), minval=low, maxval=hi
    )

    data = self.pipeline_init(qpos, qvel)
    obs = self._get_obs(data, jp.zeros(self.sys.nu))
    return EnvState(data, obs, jp.zeros(1), jp.zeros(1), {})

  def step(self, state: EnvState, action: jp.ndarray) -> EnvState:
    """Runs one timestep of the environment's dynamics."""
    data0 = state.pipeline_state
    data = self.pipeline_step(data0, action)
    obs = self._get_obs(data, action)
    return state.replace(
        pipeline_state=data, obs=obs, reward=jp.zeros(1), done=jp.zeros(1)
    )
  
  def _get_obs(
      self, data: mjx.Data, action: jp.ndarray
  ) -> jp.ndarray:
    """Observes humanoid body position, velocities, and angles."""
    position = data.qpos
    return jp.concatenate([position, action])



# envs.register_environment('humanoid', Humanoid)

In [2]:
# venv
from mjx_env import VmapWrapper, EpisodeWrapper
env = RoboPianist()
env = EpisodeWrapper(env, episode_length=1000, action_repeat=2)
env = VmapWrapper(env)

env_key = jax.random.PRNGKey(42)
num_envs = 2046
local_devices_to_use = jax.local_device_count()
env_keys = jax.random.split(env_key, num_envs // jax.process_count())
#   env_keys = jnp.reshape(env_keys,
#                          (local_devices_to_use, -1) + env_keys.shape[1:])
#   env_state = jax.pmap(env.reset)(env_keys)
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)
env_state = jit_reset(env_keys)
rollout = [env_state.pipeline_state]
for i in range(100):
  env_key, action_key = jax.random.split(env_key, 2)

  action = jax.random.uniform(action_key, (*env_keys.shape[:1], env.action_size))
  env_state = jit_step(env_state, action)
  rollout.append(env_state.pipeline_state)
  print(env_state.reward.shape)
  # if jnp.all(env_state.done):
  #   env_state = jax.pmap(env.reset)(env_keys) # TODO: check how to reset only the done envs
# import ipdb; ipdb.set_trace()
video = env.render(rollout, camera='side')

2024-03-14 15:35:06.060342: W external/tsl/tsl/framework/bfc_allocator.cc:296] Allocator (GPU_0_bfc) ran out of memory trying to allocate 5.18GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.


XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 5566756816 bytes.

In [3]:
# venv
from mjx_env import VmapWrapper, EpisodeWrapper
env = RoboPianist()
env = EpisodeWrapper(env, episode_length=1000, action_repeat=2)
env = VmapWrapper(env)

env_key = jax.random.PRNGKey(42)
num_envs = 1024
local_devices_to_use = jax.local_device_count()
env_keys = jax.random.split(env_key, num_envs // jax.process_count())
#   env_keys = jnp.reshape(env_keys,
#                          (local_devices_to_use, -1) + env_keys.shape[1:])
#   env_state = jax.pmap(env.reset)(env_keys)
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)
env_state = jit_reset(env_keys)
rollout = [env_state.pipeline_state]
for i in range(100):
  env_key, action_key = jax.random.split(env_key, 2)

  action = jax.random.uniform(action_key, (*env_keys.shape[:1], env.action_size))
  env_state = jit_step(env_state, action)
  rollout.append(env_state.pipeline_state)
  print(env_state.reward.shape)
  # if jnp.all(env_state.done):
  #   env_state = jax.pmap(env.reset)(env_keys) # TODO: check how to reset only the done envs
# import ipdb; ipdb.set_trace()
video = env.render(rollout, camera='side')

2024-03-14 15:36:14.768907: W external/tsl/tsl/framework/bfc_allocator.cc:296] Allocator (GPU_0_bfc) ran out of memory trying to allocate 131.58GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.


XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 141286260736 bytes.

Restarted pianist (Python 3.10.12)

In [1]:
import os
os.environ['MUJOCO_GL'] = 'egl'

from mjx_env import PipelineEnv, EnvState
import mjcf
from pathlib import Path
import mujoco
import jax
import sim_math as utils
from jax import numpy as jp
from mujoco import mjx

_HERE = Path(__file__).resolve().parent
_SHADOW_HAND_DIR = _HERE / "third_party" / "shadow_hand"


################## Constants ##################
# Timestep of the physics simulation, in seconds.
_PHYSICS_TIMESTEP = 0.005
# Interval between agent actions, in seconds.
_CONTROL_TIMESTEP = 0.05  # 20 Hz. # TODO: whether to increase control timesteps
# Distance thresholds for the shaping reward.
_FINGER_CLOSE_ENOUGH_TO_KEY = 0.01
_KEY_CLOSE_ENOUGH_TO_PRESSED = 0.05
# Energy penalty coefficient.
_ENERGY_PENALTY_COEF = 5e-3
# Transparency of fingertip geoms.
_FINGERTIP_ALPHA = 1.0
# Bounds for the uniform distribution from which initial hand offset is sampled.
_POSITION_OFFSET = 0.05
###############################################


class RoboPianist(PipelineEnv):

  def __init__(
      self,

      **kwargs,
  ):
    path = _HERE/ 'third_party' / 'shadow_hand' / 'robopianist.xml'

    mj_model = mujoco.MjModel.from_xml_path(path.as_posix())
    mj_model.opt.timestep = _PHYSICS_TIMESTEP

    sys = mjcf.load_model(mj_model)

    physics_steps_per_control_step = utils.compute_n_steps(_CONTROL_TIMESTEP, _PHYSICS_TIMESTEP)
    kwargs['n_frames'] = kwargs.get(
        'n_frames', physics_steps_per_control_step)
    super().__init__(sys, **kwargs)


  def reset(self, rng: jp.ndarray) -> EnvState:
    """Resets the environment to an initial state."""
    # TODO: decide the correct initial state
    rng, rng1, rng2 = jax.random.split(rng, 3)

    low, hi = -0.01, 0.01
    qpos = self.sys.qpos0 + jax.random.uniform(
        rng1, (self.sys.nq,), minval=low, maxval=hi
    )
    qvel = jax.random.uniform(
        rng2, (self.sys.nv,), minval=low, maxval=hi
    )

    data = self.pipeline_init(qpos, qvel)
    obs = self._get_obs(data, jp.zeros(self.sys.nu))
    return EnvState(data, obs, jp.zeros(1), jp.zeros(1), {})

  def step(self, state: EnvState, action: jp.ndarray) -> EnvState:
    """Runs one timestep of the environment's dynamics."""
    data0 = state.pipeline_state
    data = self.pipeline_step(data0, action)
    obs = self._get_obs(data, action)
    return state.replace(
        pipeline_state=data, obs=obs, reward=jp.zeros(1), done=jp.zeros(1)
    )
  
  def _get_obs(
      self, data: mjx.Data, action: jp.ndarray
  ) -> jp.ndarray:
    """Observes humanoid body position, velocities, and angles."""
    position = data.qpos
    return jp.concatenate([position, action])



# envs.register_environment('humanoid', Humanoid)

In [2]:
# venv
from mjx_env import VmapWrapper, EpisodeWrapper
env = RoboPianist()
env = EpisodeWrapper(env, episode_length=1000, action_repeat=2)
env = VmapWrapper(env)

env_key = jax.random.PRNGKey(42)
num_envs = 1024
local_devices_to_use = jax.local_device_count()
env_keys = jax.random.split(env_key, num_envs // jax.process_count())
#   env_keys = jnp.reshape(env_keys,
#                          (local_devices_to_use, -1) + env_keys.shape[1:])
#   env_state = jax.pmap(env.reset)(env_keys)
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)
env_state = jit_reset(env_keys)
rollout = [env_state.pipeline_state]
for i in range(100):
  env_key, action_key = jax.random.split(env_key, 2)

  action = jax.random.uniform(action_key, (*env_keys.shape[:1], env.action_size))
  env_state = jit_step(env_state, action)
  rollout.append(env_state.pipeline_state)
  print(env_state.reward.shape)
  # if jnp.all(env_state.done):
  #   env_state = jax.pmap(env.reset)(env_keys) # TODO: check how to reset only the done envs
# import ipdb; ipdb.set_trace()
video = env.render(rollout, camera='side')

2024-03-14 15:39:28.411917: W external/tsl/tsl/framework/bfc_allocator.cc:296] Allocator (GPU_0_bfc) ran out of memory trying to allocate 131.58GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.


XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 141286260736 bytes.

Restarted pianist (Python 3.10.12)

In [1]:
# venv
from mjx_env import VmapWrapper, EpisodeWrapper
env = RoboPianist()
env = EpisodeWrapper(env, episode_length=1000, action_repeat=2)
env = VmapWrapper(env)

env_key = jax.random.PRNGKey(42)
num_envs = 256
local_devices_to_use = jax.local_device_count()
env_keys = jax.random.split(env_key, num_envs // jax.process_count())
#   env_keys = jnp.reshape(env_keys,
#                          (local_devices_to_use, -1) + env_keys.shape[1:])
#   env_state = jax.pmap(env.reset)(env_keys)
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)
env_state = jit_reset(env_keys)
rollout = [env_state.pipeline_state]
for i in range(100):
  env_key, action_key = jax.random.split(env_key, 2)

  action = jax.random.uniform(action_key, (*env_keys.shape[:1], env.action_size))
  env_state = jit_step(env_state, action)
  rollout.append(env_state.pipeline_state)
  print(env_state.reward.shape)
  # if jnp.all(env_state.done):
  #   env_state = jax.pmap(env.reset)(env_keys) # TODO: check how to reset only the done envs
# import ipdb; ipdb.set_trace()
video = env.render(rollout, camera='side')

NameError: name 'RoboPianist' is not defined

In [2]:
import os
os.environ['MUJOCO_GL'] = 'egl'

from mjx_env import PipelineEnv, EnvState
import mjcf
from pathlib import Path
import mujoco
import jax
import sim_math as utils
from jax import numpy as jp
from mujoco import mjx

_HERE = Path(__file__).resolve().parent
_SHADOW_HAND_DIR = _HERE / "third_party" / "shadow_hand"


################## Constants ##################
# Timestep of the physics simulation, in seconds.
_PHYSICS_TIMESTEP = 0.005
# Interval between agent actions, in seconds.
_CONTROL_TIMESTEP = 0.05  # 20 Hz. # TODO: whether to increase control timesteps
# Distance thresholds for the shaping reward.
_FINGER_CLOSE_ENOUGH_TO_KEY = 0.01
_KEY_CLOSE_ENOUGH_TO_PRESSED = 0.05
# Energy penalty coefficient.
_ENERGY_PENALTY_COEF = 5e-3
# Transparency of fingertip geoms.
_FINGERTIP_ALPHA = 1.0
# Bounds for the uniform distribution from which initial hand offset is sampled.
_POSITION_OFFSET = 0.05
###############################################


class RoboPianist(PipelineEnv):

  def __init__(
      self,

      **kwargs,
  ):
    path = _HERE/ 'third_party' / 'shadow_hand' / 'robopianist.xml'

    mj_model = mujoco.MjModel.from_xml_path(path.as_posix())
    mj_model.opt.timestep = _PHYSICS_TIMESTEP

    sys = mjcf.load_model(mj_model)

    physics_steps_per_control_step = utils.compute_n_steps(_CONTROL_TIMESTEP, _PHYSICS_TIMESTEP)
    kwargs['n_frames'] = kwargs.get(
        'n_frames', physics_steps_per_control_step)
    super().__init__(sys, **kwargs)


  def reset(self, rng: jp.ndarray) -> EnvState:
    """Resets the environment to an initial state."""
    # TODO: decide the correct initial state
    rng, rng1, rng2 = jax.random.split(rng, 3)

    low, hi = -0.01, 0.01
    qpos = self.sys.qpos0 + jax.random.uniform(
        rng1, (self.sys.nq,), minval=low, maxval=hi
    )
    qvel = jax.random.uniform(
        rng2, (self.sys.nv,), minval=low, maxval=hi
    )

    data = self.pipeline_init(qpos, qvel)
    obs = self._get_obs(data, jp.zeros(self.sys.nu))
    return EnvState(data, obs, jp.zeros(1), jp.zeros(1), {})

  def step(self, state: EnvState, action: jp.ndarray) -> EnvState:
    """Runs one timestep of the environment's dynamics."""
    data0 = state.pipeline_state
    data = self.pipeline_step(data0, action)
    obs = self._get_obs(data, action)
    return state.replace(
        pipeline_state=data, obs=obs, reward=jp.zeros(1), done=jp.zeros(1)
    )
  
  def _get_obs(
      self, data: mjx.Data, action: jp.ndarray
  ) -> jp.ndarray:
    """Observes humanoid body position, velocities, and angles."""
    position = data.qpos
    return jp.concatenate([position, action])



# envs.register_environment('humanoid', Humanoid)

In [3]:
# venv
from mjx_env import VmapWrapper, EpisodeWrapper
env = RoboPianist()
env = EpisodeWrapper(env, episode_length=1000, action_repeat=2)
env = VmapWrapper(env)

env_key = jax.random.PRNGKey(42)
num_envs = 256
local_devices_to_use = jax.local_device_count()
env_keys = jax.random.split(env_key, num_envs // jax.process_count())
#   env_keys = jnp.reshape(env_keys,
#                          (local_devices_to_use, -1) + env_keys.shape[1:])
#   env_state = jax.pmap(env.reset)(env_keys)
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)
env_state = jit_reset(env_keys)
rollout = [env_state.pipeline_state]
for i in range(100):
  env_key, action_key = jax.random.split(env_key, 2)

  action = jax.random.uniform(action_key, (*env_keys.shape[:1], env.action_size))
  env_state = jit_step(env_state, action)
  rollout.append(env_state.pipeline_state)
  print(env_state.reward.shape)
  # if jnp.all(env_state.done):
  #   env_state = jax.pmap(env.reset)(env_keys) # TODO: check how to reset only the done envs
# import ipdb; ipdb.set_trace()
video = env.render(rollout, camera='side')

2024-03-14 15:41:31.663373: W external/tsl/tsl/framework/bfc_allocator.cc:296] Allocator (GPU_0_bfc) ran out of memory trying to allocate 32.91GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.


XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 35334148096 bytes.

Restarted pianist (Python 3.10.12)

In [1]:
import os
os.environ['MUJOCO_GL'] = 'egl'
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '4.0'

from mjx_env import PipelineEnv, EnvState
import mjcf
from pathlib import Path
import mujoco
import jax
import sim_math as utils
from jax import numpy as jp
from mujoco import mjx

_HERE = Path(__file__).resolve().parent
_SHADOW_HAND_DIR = _HERE / "third_party" / "shadow_hand"


################## Constants ##################
# Timestep of the physics simulation, in seconds.
_PHYSICS_TIMESTEP = 0.005
# Interval between agent actions, in seconds.
_CONTROL_TIMESTEP = 0.05  # 20 Hz. # TODO: whether to increase control timesteps
# Distance thresholds for the shaping reward.
_FINGER_CLOSE_ENOUGH_TO_KEY = 0.01
_KEY_CLOSE_ENOUGH_TO_PRESSED = 0.05
# Energy penalty coefficient.
_ENERGY_PENALTY_COEF = 5e-3
# Transparency of fingertip geoms.
_FINGERTIP_ALPHA = 1.0
# Bounds for the uniform distribution from which initial hand offset is sampled.
_POSITION_OFFSET = 0.05
###############################################


class RoboPianist(PipelineEnv):

  def __init__(
      self,

      **kwargs,
  ):
    path = _HERE/ 'third_party' / 'shadow_hand' / 'robopianist.xml'

    mj_model = mujoco.MjModel.from_xml_path(path.as_posix())
    mj_model.opt.timestep = _PHYSICS_TIMESTEP

    sys = mjcf.load_model(mj_model)

    physics_steps_per_control_step = utils.compute_n_steps(_CONTROL_TIMESTEP, _PHYSICS_TIMESTEP)
    kwargs['n_frames'] = kwargs.get(
        'n_frames', physics_steps_per_control_step)
    super().__init__(sys, **kwargs)


  def reset(self, rng: jp.ndarray) -> EnvState:
    """Resets the environment to an initial state."""
    # TODO: decide the correct initial state
    rng, rng1, rng2 = jax.random.split(rng, 3)

    low, hi = -0.01, 0.01
    qpos = self.sys.qpos0 + jax.random.uniform(
        rng1, (self.sys.nq,), minval=low, maxval=hi
    )
    qvel = jax.random.uniform(
        rng2, (self.sys.nv,), minval=low, maxval=hi
    )

    data = self.pipeline_init(qpos, qvel)
    obs = self._get_obs(data, jp.zeros(self.sys.nu))
    return EnvState(data, obs, jp.zeros(1), jp.zeros(1), {})

  def step(self, state: EnvState, action: jp.ndarray) -> EnvState:
    """Runs one timestep of the environment's dynamics."""
    data0 = state.pipeline_state
    data = self.pipeline_step(data0, action)
    obs = self._get_obs(data, action)
    return state.replace(
        pipeline_state=data, obs=obs, reward=jp.zeros(1), done=jp.zeros(1)
    )
  
  def _get_obs(
      self, data: mjx.Data, action: jp.ndarray
  ) -> jp.ndarray:
    """Observes humanoid body position, velocities, and angles."""
    position = data.qpos
    return jp.concatenate([position, action])



# envs.register_environment('humanoid', Humanoid)

In [2]:
# venv
from mjx_env import VmapWrapper, EpisodeWrapper
env = RoboPianist()
env = EpisodeWrapper(env, episode_length=1000, action_repeat=2)
env = VmapWrapper(env)

env_key = jax.random.PRNGKey(42)
num_envs = 256
local_devices_to_use = jax.local_device_count()
env_keys = jax.random.split(env_key, num_envs // jax.process_count())
#   env_keys = jnp.reshape(env_keys,
#                          (local_devices_to_use, -1) + env_keys.shape[1:])
#   env_state = jax.pmap(env.reset)(env_keys)
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)
env_state = jit_reset(env_keys)
rollout = [env_state.pipeline_state]
for i in range(100):
  env_key, action_key = jax.random.split(env_key, 2)

  action = jax.random.uniform(action_key, (*env_keys.shape[:1], env.action_size))
  env_state = jit_step(env_state, action)
  rollout.append(env_state.pipeline_state)
  print(env_state.reward.shape)
  # if jnp.all(env_state.done):
  #   env_state = jax.pmap(env.reset)(env_keys) # TODO: check how to reset only the done envs
# import ipdb; ipdb.set_trace()
video = env.render(rollout, camera='side')

2024-03-14 15:46:26.034261: W external/tsl/tsl/framework/bfc_allocator.cc:296] Allocator (GPU_0_bfc) ran out of memory trying to allocate 32.91GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.


XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 35334148096 bytes.

Restarted pianist (Python 3.10.12)

In [1]:
import os
os.environ['MUJOCO_GL'] = 'egl'
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]=".5"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"]="platform"

from mjx_env import PipelineEnv, EnvState
import mjcf
from pathlib import Path
import mujoco
import jax
import sim_math as utils
from jax import numpy as jp
from mujoco import mjx

_HERE = Path(__file__).resolve().parent
_SHADOW_HAND_DIR = _HERE / "third_party" / "shadow_hand"


################## Constants ##################
# Timestep of the physics simulation, in seconds.
_PHYSICS_TIMESTEP = 0.005
# Interval between agent actions, in seconds.
_CONTROL_TIMESTEP = 0.05  # 20 Hz. # TODO: whether to increase control timesteps
# Distance thresholds for the shaping reward.
_FINGER_CLOSE_ENOUGH_TO_KEY = 0.01
_KEY_CLOSE_ENOUGH_TO_PRESSED = 0.05
# Energy penalty coefficient.
_ENERGY_PENALTY_COEF = 5e-3
# Transparency of fingertip geoms.
_FINGERTIP_ALPHA = 1.0
# Bounds for the uniform distribution from which initial hand offset is sampled.
_POSITION_OFFSET = 0.05
###############################################


class RoboPianist(PipelineEnv):

  def __init__(
      self,

      **kwargs,
  ):
    path = _HERE/ 'third_party' / 'shadow_hand' / 'robopianist.xml'

    mj_model = mujoco.MjModel.from_xml_path(path.as_posix())
    mj_model.opt.timestep = _PHYSICS_TIMESTEP

    sys = mjcf.load_model(mj_model)

    physics_steps_per_control_step = utils.compute_n_steps(_CONTROL_TIMESTEP, _PHYSICS_TIMESTEP)
    kwargs['n_frames'] = kwargs.get(
        'n_frames', physics_steps_per_control_step)
    super().__init__(sys, **kwargs)


  def reset(self, rng: jp.ndarray) -> EnvState:
    """Resets the environment to an initial state."""
    # TODO: decide the correct initial state
    rng, rng1, rng2 = jax.random.split(rng, 3)

    low, hi = -0.01, 0.01
    qpos = self.sys.qpos0 + jax.random.uniform(
        rng1, (self.sys.nq,), minval=low, maxval=hi
    )
    qvel = jax.random.uniform(
        rng2, (self.sys.nv,), minval=low, maxval=hi
    )

    data = self.pipeline_init(qpos, qvel)
    obs = self._get_obs(data, jp.zeros(self.sys.nu))
    return EnvState(data, obs, jp.zeros(1), jp.zeros(1), {})

  def step(self, state: EnvState, action: jp.ndarray) -> EnvState:
    """Runs one timestep of the environment's dynamics."""
    data0 = state.pipeline_state
    data = self.pipeline_step(data0, action)
    obs = self._get_obs(data, action)
    return state.replace(
        pipeline_state=data, obs=obs, reward=jp.zeros(1), done=jp.zeros(1)
    )
  
  def _get_obs(
      self, data: mjx.Data, action: jp.ndarray
  ) -> jp.ndarray:
    """Observes humanoid body position, velocities, and angles."""
    position = data.qpos
    return jp.concatenate([position, action])



# envs.register_environment('humanoid', Humanoid)

In [2]:
# venv
from mjx_env import VmapWrapper, EpisodeWrapper
env = RoboPianist()
env = EpisodeWrapper(env, episode_length=1000, action_repeat=2)
env = VmapWrapper(env)

env_key = jax.random.PRNGKey(42)
num_envs = 256
local_devices_to_use = jax.local_device_count()
env_keys = jax.random.split(env_key, num_envs // jax.process_count())
#   env_keys = jnp.reshape(env_keys,
#                          (local_devices_to_use, -1) + env_keys.shape[1:])
#   env_state = jax.pmap(env.reset)(env_keys)
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)
env_state = jit_reset(env_keys)
rollout = [env_state.pipeline_state]
for i in range(100):
  env_key, action_key = jax.random.split(env_key, 2)

  action = jax.random.uniform(action_key, (*env_keys.shape[:1], env.action_size))
  env_state = jit_step(env_state, action)
  rollout.append(env_state.pipeline_state)
  print(env_state.reward.shape)
  # if jnp.all(env_state.done):
  #   env_state = jax.pmap(env.reset)(env_keys) # TODO: check how to reset only the done envs
# import ipdb; ipdb.set_trace()
video = env.render(rollout, camera='side')

XlaRuntimeError: RESOURCE_EXHAUSTED: Failed to allocate request for 32.91GiB (35334148096B) on device ordinal 0

In [3]:
# venv
from mjx_env import VmapWrapper, EpisodeWrapper
env = RoboPianist()
env = EpisodeWrapper(env, episode_length=1000, action_repeat=2)
env = VmapWrapper(env)

env_key = jax.random.PRNGKey(42)
num_envs = 1
local_devices_to_use = jax.local_device_count()
env_keys = jax.random.split(env_key, num_envs // jax.process_count())
#   env_keys = jnp.reshape(env_keys,
#                          (local_devices_to_use, -1) + env_keys.shape[1:])
#   env_state = jax.pmap(env.reset)(env_keys)
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)
env_state = jit_reset(env_keys)
rollout = [env_state.pipeline_state]
for i in range(100):
  env_key, action_key = jax.random.split(env_key, 2)

  action = jax.random.uniform(action_key, (*env_keys.shape[:1], env.action_size))
  env_state = jit_step(env_state, action)
  rollout.append(env_state.pipeline_state)
  print(env_state.reward.shape)
  # if jnp.all(env_state.done):
  #   env_state = jax.pmap(env.reset)(env_keys) # TODO: check how to reset only the done envs
# import ipdb; ipdb.set_trace()
video = env.render(rollout, camera='side')

(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)
(1, 1)


In [4]:
len(rollout)

101

In [5]:
import mediapy as media
media.show_video(video, fps=60)

0
This browser does not support the video tag.


In [6]:
import os
os.environ['MUJOCO_GL'] = 'egl'
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="true"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]=".5"

from mjx_env import PipelineEnv, EnvState
import mjcf
from pathlib import Path
import mujoco
import jax
import sim_math as utils
from jax import numpy as jp
from mujoco import mjx

_HERE = Path(__file__).resolve().parent
_SHADOW_HAND_DIR = _HERE / "third_party" / "shadow_hand"


################## Constants ##################
# Timestep of the physics simulation, in seconds.
_PHYSICS_TIMESTEP = 0.005
# Interval between agent actions, in seconds.
_CONTROL_TIMESTEP = 0.05  # 20 Hz. # TODO: whether to increase control timesteps
# Distance thresholds for the shaping reward.
_FINGER_CLOSE_ENOUGH_TO_KEY = 0.01
_KEY_CLOSE_ENOUGH_TO_PRESSED = 0.05
# Energy penalty coefficient.
_ENERGY_PENALTY_COEF = 5e-3
# Transparency of fingertip geoms.
_FINGERTIP_ALPHA = 1.0
# Bounds for the uniform distribution from which initial hand offset is sampled.
_POSITION_OFFSET = 0.05
###############################################


class RoboPianist(PipelineEnv):

  def __init__(
      self,

      **kwargs,
  ):
    path = _HERE/ 'third_party' / 'shadow_hand' / 'robopianist.xml'

    mj_model = mujoco.MjModel.from_xml_path(path.as_posix())
    mj_model.opt.timestep = _PHYSICS_TIMESTEP

    sys = mjcf.load_model(mj_model)

    physics_steps_per_control_step = utils.compute_n_steps(_CONTROL_TIMESTEP, _PHYSICS_TIMESTEP)
    kwargs['n_frames'] = kwargs.get(
        'n_frames', physics_steps_per_control_step)
    super().__init__(sys, **kwargs)


  def reset(self, rng: jp.ndarray) -> EnvState:
    """Resets the environment to an initial state."""
    # TODO: decide the correct initial state
    rng, rng1, rng2 = jax.random.split(rng, 3)

    low, hi = -0.01, 0.01
    qpos = self.sys.qpos0 + jax.random.uniform(
        rng1, (self.sys.nq,), minval=low, maxval=hi
    )
    qvel = jax.random.uniform(
        rng2, (self.sys.nv,), minval=low, maxval=hi
    )

    data = self.pipeline_init(qpos, qvel)
    obs = self._get_obs(data, jp.zeros(self.sys.nu))
    return EnvState(data, obs, jp.zeros(1), jp.zeros(1), {})

  def step(self, state: EnvState, action: jp.ndarray) -> EnvState:
    """Runs one timestep of the environment's dynamics."""
    data0 = state.pipeline_state
    data = self.pipeline_step(data0, action)
    obs = self._get_obs(data, action)
    return state.replace(
        pipeline_state=data, obs=obs, reward=jp.zeros(1), done=jp.zeros(1)
    )
  
  def _get_obs(
      self, data: mjx.Data, action: jp.ndarray
  ) -> jp.ndarray:
    """Observes humanoid body position, velocities, and angles."""
    position = data.qpos
    return jp.concatenate([position, action])



# envs.register_environment('humanoid', Humanoid)

Restarted pianist (Python 3.10.12)

In [1]:
import os
os.environ['MUJOCO_GL'] = 'egl'
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="true"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]=".5"

from mjx_env import PipelineEnv, EnvState
import mjcf
from pathlib import Path
import mujoco
import jax
import sim_math as utils
from jax import numpy as jp
from mujoco import mjx

_HERE = Path(__file__).resolve().parent
_SHADOW_HAND_DIR = _HERE / "third_party" / "shadow_hand"


################## Constants ##################
# Timestep of the physics simulation, in seconds.
_PHYSICS_TIMESTEP = 0.005
# Interval between agent actions, in seconds.
_CONTROL_TIMESTEP = 0.05  # 20 Hz. # TODO: whether to increase control timesteps
# Distance thresholds for the shaping reward.
_FINGER_CLOSE_ENOUGH_TO_KEY = 0.01
_KEY_CLOSE_ENOUGH_TO_PRESSED = 0.05
# Energy penalty coefficient.
_ENERGY_PENALTY_COEF = 5e-3
# Transparency of fingertip geoms.
_FINGERTIP_ALPHA = 1.0
# Bounds for the uniform distribution from which initial hand offset is sampled.
_POSITION_OFFSET = 0.05
###############################################


class RoboPianist(PipelineEnv):

  def __init__(
      self,

      **kwargs,
  ):
    path = _HERE/ 'third_party' / 'shadow_hand' / 'robopianist.xml'

    mj_model = mujoco.MjModel.from_xml_path(path.as_posix())
    mj_model.opt.timestep = _PHYSICS_TIMESTEP

    sys = mjcf.load_model(mj_model)

    physics_steps_per_control_step = utils.compute_n_steps(_CONTROL_TIMESTEP, _PHYSICS_TIMESTEP)
    kwargs['n_frames'] = kwargs.get(
        'n_frames', physics_steps_per_control_step)
    super().__init__(sys, **kwargs)


  def reset(self, rng: jp.ndarray) -> EnvState:
    """Resets the environment to an initial state."""
    # TODO: decide the correct initial state
    rng, rng1, rng2 = jax.random.split(rng, 3)

    low, hi = -0.01, 0.01
    qpos = self.sys.qpos0 + jax.random.uniform(
        rng1, (self.sys.nq,), minval=low, maxval=hi
    )
    qvel = jax.random.uniform(
        rng2, (self.sys.nv,), minval=low, maxval=hi
    )

    data = self.pipeline_init(qpos, qvel)
    obs = self._get_obs(data, jp.zeros(self.sys.nu))
    return EnvState(data, obs, jp.zeros(1), jp.zeros(1), {})

  def step(self, state: EnvState, action: jp.ndarray) -> EnvState:
    """Runs one timestep of the environment's dynamics."""
    data0 = state.pipeline_state
    data = self.pipeline_step(data0, action)
    obs = self._get_obs(data, action)
    return state.replace(
        pipeline_state=data, obs=obs, reward=jp.zeros(1), done=jp.zeros(1)
    )
  
  def _get_obs(
      self, data: mjx.Data, action: jp.ndarray
  ) -> jp.ndarray:
    """Observes humanoid body position, velocities, and angles."""
    position = data.qpos
    return jp.concatenate([position, action])



# envs.register_environment('humanoid', Humanoid)

In [2]:
# venv
from mjx_env import VmapWrapper, EpisodeWrapper
env = RoboPianist()
env = EpisodeWrapper(env, episode_length=1000, action_repeat=2)
env = VmapWrapper(env)

env_key = jax.random.PRNGKey(42)
num_envs = 10
local_devices_to_use = jax.local_device_count()
env_keys = jax.random.split(env_key, num_envs // jax.process_count())
#   env_keys = jnp.reshape(env_keys,
#                          (local_devices_to_use, -1) + env_keys.shape[1:])
#   env_state = jax.pmap(env.reset)(env_keys)
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)
env_state = jit_reset(env_keys)
rollout = [env_state.pipeline_state]
for i in range(100):
  env_key, action_key = jax.random.split(env_key, 2)

  action = jax.random.uniform(action_key, (*env_keys.shape[:1], env.action_size))
  env_state = jit_step(env_state, action)
  rollout.append(env_state.pipeline_state)
  print(env_state.reward.shape)
  # if jnp.all(env_state.done):
  #   env_state = jax.pmap(env.reset)(env_keys) # TODO: check how to reset only the done envs
# import ipdb; ipdb.set_trace()
video = env.render(rollout, camera='side')

(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)
(10, 1)


ValueError: could not broadcast input array from shape (10,134) into shape (134,)

In [3]:
len(video)

NameError: name 'video' is not defined

In [4]:
env_state.pipeline_state

MjxState(solver_niter=Array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int32), time=Array([10.000181, 10.000181, 10.000181, 10.000181, 10.000181, 10.000181,
       10.000181, 10.000181, 10.000181, 10.000181],      dtype=float32, weak_type=True), qpos=Array([[ 2.1983215e-01,  1.7256634e-02, -4.8518158e-03, ...,
        -2.4661754e-04, -1.0363663e-03, -2.4661754e-04],
       [ 1.7941296e-01,  1.4509780e-02,  1.0549298e-01, ...,
        -4.8228825e-04, -2.4422330e-03, -4.8228825e-04],
       [ 2.2563612e-01,  2.4252020e-02,  8.1276648e-02, ...,
        -1.5704670e-04, -1.3471110e-03, -1.5704670e-04],
       ...,
       [ 1.4686768e-01,  4.5323215e-02,  1.0895470e-01, ...,
        -4.3365324e-04, -2.1181267e-03, -4.3365324e-04],
       [ 1.7943002e-01,  1.5276319e-02,  2.0490275e-01, ...,
        -1.7357708e-04, -6.2435213e-04, -1.7357708e-04],
       [ 2.2216801e-01,  1.8880079e-02,  1.4731647e-01, ...,
        -1.7007545e-04, -6.1683502e-04, -1.7007545e-04]], dtype=float32), qvel=Array([[-6.8

In [5]:
env_state.pipeline_state.qvel.shape

(10, 134)

Restarted pianist (Python 3.10.12)

In [1]:
import os
os.environ['MUJOCO_GL'] = 'egl'
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]=".45"

from mjx_env import PipelineEnv, EnvState
import mjcf
from pathlib import Path
import mujoco
import jax
import sim_math as utils
from jax import numpy as jp
from mujoco import mjx

_HERE = Path(__file__).resolve().parent
_SHADOW_HAND_DIR = _HERE / "third_party" / "shadow_hand"


################## Constants ##################
# Timestep of the physics simulation, in seconds.
_PHYSICS_TIMESTEP = 0.005
# Interval between agent actions, in seconds.
_CONTROL_TIMESTEP = 0.05  # 20 Hz. # TODO: whether to increase control timesteps
# Distance thresholds for the shaping reward.
_FINGER_CLOSE_ENOUGH_TO_KEY = 0.01
_KEY_CLOSE_ENOUGH_TO_PRESSED = 0.05
# Energy penalty coefficient.
_ENERGY_PENALTY_COEF = 5e-3
# Transparency of fingertip geoms.
_FINGERTIP_ALPHA = 1.0
# Bounds for the uniform distribution from which initial hand offset is sampled.
_POSITION_OFFSET = 0.05
###############################################


class RoboPianist(PipelineEnv):

  def __init__(
      self,

      **kwargs,
  ):
    path = _HERE/ 'third_party' / 'shadow_hand' / 'robopianist.xml'

    mj_model = mujoco.MjModel.from_xml_path(path.as_posix())
    mj_model.opt.timestep = _PHYSICS_TIMESTEP

    sys = mjcf.load_model(mj_model)

    physics_steps_per_control_step = utils.compute_n_steps(_CONTROL_TIMESTEP, _PHYSICS_TIMESTEP)
    kwargs['n_frames'] = kwargs.get(
        'n_frames', physics_steps_per_control_step)
    super().__init__(sys, **kwargs)


  def reset(self, rng: jp.ndarray) -> EnvState:
    """Resets the environment to an initial state."""
    # TODO: decide the correct initial state
    rng, rng1, rng2 = jax.random.split(rng, 3)

    low, hi = -0.01, 0.01
    qpos = self.sys.qpos0 + jax.random.uniform(
        rng1, (self.sys.nq,), minval=low, maxval=hi
    )
    qvel = jax.random.uniform(
        rng2, (self.sys.nv,), minval=low, maxval=hi
    )

    data = self.pipeline_init(qpos, qvel)
    obs = self._get_obs(data, jp.zeros(self.sys.nu))
    return EnvState(data, obs, jp.zeros(1), jp.zeros(1), {})

  def step(self, state: EnvState, action: jp.ndarray) -> EnvState:
    """Runs one timestep of the environment's dynamics."""
    data0 = state.pipeline_state
    data = self.pipeline_step(data0, action)
    obs = self._get_obs(data, action)
    return state.replace(
        pipeline_state=data, obs=obs, reward=jp.zeros(1), done=jp.zeros(1)
    )
  
  def _get_obs(
      self, data: mjx.Data, action: jp.ndarray
  ) -> jp.ndarray:
    """Observes humanoid body position, velocities, and angles."""
    position = data.qpos
    return jp.concatenate([position, action])



# envs.register_environment('humanoid', Humanoid)

In [2]:
# venv
from mjx_env import VmapWrapper, EpisodeWrapper
env = RoboPianist()
env = EpisodeWrapper(env, episode_length=1000, action_repeat=2)
env = VmapWrapper(env)

env_key = jax.random.PRNGKey(42)
num_envs = 100
local_devices_to_use = jax.local_device_count()
env_keys = jax.random.split(env_key, num_envs // jax.process_count())
#   env_keys = jnp.reshape(env_keys,
#                          (local_devices_to_use, -1) + env_keys.shape[1:])
#   env_state = jax.pmap(env.reset)(env_keys)
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)
env_state = jit_reset(env_keys)
rollout = [env_state.pipeline_state]
for i in range(100):
  env_key, action_key = jax.random.split(env_key, 2)

  action = jax.random.uniform(action_key, (*env_keys.shape[:1], env.action_size))
  env_state = jit_step(env_state, action)
  rollout.append(env_state.pipeline_state)
  print(env_state.reward.shape)
  # if jnp.all(env_state.done):
  #   env_state = jax.pmap(env.reset)(env_keys) # TODO: check how to reset only the done envs
# import ipdb; ipdb.set_trace()
video = env.render(rollout, camera='side')

2024-03-14 16:15:34.410084: W external/tsl/tsl/framework/bfc_allocator.cc:296] Allocator (GPU_0_bfc) ran out of memory trying to allocate 12.86GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.


XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 13812625216 bytes.

In [3]:
# test the naive mujoco-mjx memory usage
path = _HERE/ 'third_party' / 'shadow_hand' / 'robopianist.xml'

mj_model = mujoco.MjModel.from_xml_path(path.as_posix())
mj_model.opt.timestep = _PHYSICS_TIMESTEP

mj_model = mujoco.MjModel.from_xml_path(path.as_posix())
mj_data = mujoco.MjData(mj_model)
renderer = mujoco.Renderer(mj_model, height=240, width=320)

mjx_model = mjx.put_model(mj_model)
mjx_data = mjx.put_data(mj_model, mj_data)

duration = 3.8
framerate = 60

rng = jax.random.PRNGKey(0)
rng = jax.random.split(rng, 100)
batch = jax.vmap(lambda rng: mjx_data.replace(qpos=jax.random.uniform(rng, (1,))))(rng)
jit_step = jax.vmap(mjx.step, in_axes=(None, 0))
batch = jit_step(mjx_model, batch)

IndexError: f argument "3" with type "q" has length "1" which does not match the in_types[3] expected length of "134".

In [4]:
batch

Data(solver_niter=Array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]], dtype=int32), time=Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],      dtype=float32, weak_type=True), qpos=Array([[0.83727205],
       [0.05812871],
       [0.06249678],
       [0.24408734],
       [0.57799315],
       [0.21110463],
       [0.2097305 ],
       [0.8750204 ],
       [0.00115514],
       [0.6094818 ],
       [0.9049611 ],
       [0.6554712 ],
       [0.12806177],


In [5]:
batch.qpos.shape

(100, 1)

In [6]:
mjx_model.qpos

AttributeError: 'Model' object has no attribute 'qpos'

In [7]:
mjx_data.qpos

Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],      dtype=float32)

In [8]:
mjx_data.qpos.shape

(134,)

In [9]:
# test the naive mujoco-mjx memory usage
path = _HERE/ 'third_party' / 'shadow_hand' / 'robopianist.xml'

mj_model = mujoco.MjModel.from_xml_path(path.as_posix())
mj_model.opt.timestep = _PHYSICS_TIMESTEP

mj_model = mujoco.MjModel.from_xml_path(path.as_posix())
mj_data = mujoco.MjData(mj_model)
renderer = mujoco.Renderer(mj_model, height=240, width=320)

mjx_model = mjx.put_model(mj_model)
mjx_data = mjx.put_data(mj_model, mj_data)

duration = 3.8
framerate = 60

rng = jax.random.PRNGKey(0)
rng = jax.random.split(rng, 100)
batch = jax.vmap(lambda rng: mjx_data.replace(qpos=jax.random.uniform(rng, (134,))))(rng)
jit_step = jax.vmap(mjx.step, in_axes=(None, 0))
batch = jit_step(mjx_model, batch)

2024-03-14 16:30:18.155145: W external/tsl/tsl/framework/bfc_allocator.cc:296] Allocator (GPU_0_bfc) ran out of memory trying to allocate 12.86GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.


XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 13812625216 bytes.

Restarted pianist (Python 3.10.12)

In [1]:
import os
os.environ['MUJOCO_GL'] = 'egl'
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]=".45"

from mjx_env import PipelineEnv, EnvState
import mjcf
from pathlib import Path
import mujoco
import jax
import sim_math as utils
from jax import numpy as jp
from mujoco import mjx

_HERE = Path(__file__).resolve().parent
_SHADOW_HAND_DIR = _HERE / "third_party" / "shadow_hand"


################## Constants ##################
# Timestep of the physics simulation, in seconds.
_PHYSICS_TIMESTEP = 0.005
# Interval between agent actions, in seconds.
_CONTROL_TIMESTEP = 0.05  # 20 Hz. # TODO: whether to increase control timesteps
# Distance thresholds for the shaping reward.
_FINGER_CLOSE_ENOUGH_TO_KEY = 0.01
_KEY_CLOSE_ENOUGH_TO_PRESSED = 0.05
# Energy penalty coefficient.
_ENERGY_PENALTY_COEF = 5e-3
# Transparency of fingertip geoms.
_FINGERTIP_ALPHA = 1.0
# Bounds for the uniform distribution from which initial hand offset is sampled.
_POSITION_OFFSET = 0.05
###############################################


class RoboPianist(PipelineEnv):

  def __init__(
      self,

      **kwargs,
  ):
    path = _HERE/ 'third_party' / 'shadow_hand' / 'robopianist.xml'

    mj_model = mujoco.MjModel.from_xml_path(path.as_posix())
    mj_model.opt.timestep = _PHYSICS_TIMESTEP

    sys = mjcf.load_model(mj_model)

    physics_steps_per_control_step = utils.compute_n_steps(_CONTROL_TIMESTEP, _PHYSICS_TIMESTEP)
    kwargs['n_frames'] = kwargs.get(
        'n_frames', physics_steps_per_control_step)
    super().__init__(sys, **kwargs)


  def reset(self, rng: jp.ndarray) -> EnvState:
    """Resets the environment to an initial state."""
    # TODO: decide the correct initial state
    rng, rng1, rng2 = jax.random.split(rng, 3)

    low, hi = -0.01, 0.01
    qpos = self.sys.qpos0 + jax.random.uniform(
        rng1, (self.sys.nq,), minval=low, maxval=hi
    )
    qvel = jax.random.uniform(
        rng2, (self.sys.nv,), minval=low, maxval=hi
    )

    data = self.pipeline_init(qpos, qvel)
    obs = self._get_obs(data, jp.zeros(self.sys.nu))
    return EnvState(data, obs, jp.zeros(1), jp.zeros(1), {})

  def step(self, state: EnvState, action: jp.ndarray) -> EnvState:
    """Runs one timestep of the environment's dynamics."""
    data0 = state.pipeline_state
    data = self.pipeline_step(data0, action)
    obs = self._get_obs(data, action)
    return state.replace(
        pipeline_state=data, obs=obs, reward=jp.zeros(1), done=jp.zeros(1)
    )
  
  def _get_obs(
      self, data: mjx.Data, action: jp.ndarray
  ) -> jp.ndarray:
    """Observes humanoid body position, velocities, and angles."""
    position = data.qpos
    return jp.concatenate([position, action])



# envs.register_environment('humanoid', Humanoid)

In [2]:
# test the naive mujoco-mjx memory usage
path = _HERE/ 'third_party' / 'shadow_hand' / 'robopianist.xml'

mj_model = mujoco.MjModel.from_xml_path(path.as_posix())
mj_model.opt.timestep = _PHYSICS_TIMESTEP

mj_model = mujoco.MjModel.from_xml_path(path.as_posix())
mj_data = mujoco.MjData(mj_model)
renderer = mujoco.Renderer(mj_model, height=240, width=320)

mjx_model = mjx.put_model(mj_model)
mjx_data = mjx.put_data(mj_model, mj_data)

duration = 3.8
framerate = 60

rng = jax.random.PRNGKey(0)
rng = jax.random.split(rng, 100)
batch = jax.vmap(lambda rng: mjx_data.replace(qpos=jax.random.uniform(rng, (134,))))(rng)
jit_step = jax.vmap(mjx.step, in_axes=(None, 0))
batch = jit_step(mjx_model, batch)

IndexError: f argument "3" with type "q" has length "134" which does not match the in_types[3] expected length of "46".

In [3]:
# venv
from mjx_env import VmapWrapper, EpisodeWrapper
env = RoboPianist()
env = EpisodeWrapper(env, episode_length=1000, action_repeat=2)
env = VmapWrapper(env)

env_key = jax.random.PRNGKey(42)
num_envs = 100
local_devices_to_use = jax.local_device_count()
env_keys = jax.random.split(env_key, num_envs // jax.process_count())
#   env_keys = jnp.reshape(env_keys,
#                          (local_devices_to_use, -1) + env_keys.shape[1:])
#   env_state = jax.pmap(env.reset)(env_keys)
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)
env_state = jit_reset(env_keys)
rollout = [env_state.pipeline_state]
for i in range(100):
  env_key, action_key = jax.random.split(env_key, 2)

  action = jax.random.uniform(action_key, (*env_keys.shape[:1], env.action_size))
  env_state = jit_step(env_state, action)
  rollout.append(env_state.pipeline_state)
  print(env_state.reward.shape)
  # if jnp.all(env_state.done):
  #   env_state = jax.pmap(env.reset)(env_keys) # TODO: check how to reset only the done envs
# import ipdb; ipdb.set_trace()
# video = env.render(rollout, camera='side')

2024-03-14 16:38:41.892713: W external/tsl/tsl/framework/bfc_allocator.cc:296] Allocator (GPU_0_bfc) ran out of memory trying to allocate 12.86GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.


XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 13812625216 bytes.

Restarted pianist (Python 3.10.12)

In [1]:
import os
os.environ['MUJOCO_GL'] = 'egl'
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]=".45"

from mjx_env import PipelineEnv, EnvState
import mjcf
from pathlib import Path
import mujoco
import jax
import sim_math as utils
from jax import numpy as jp
from mujoco import mjx

_HERE = Path(__file__).resolve().parent
_SHADOW_HAND_DIR = _HERE / "third_party" / "shadow_hand"


################## Constants ##################
# Timestep of the physics simulation, in seconds.
_PHYSICS_TIMESTEP = 0.005
# Interval between agent actions, in seconds.
_CONTROL_TIMESTEP = 0.05  # 20 Hz. # TODO: whether to increase control timesteps
# Distance thresholds for the shaping reward.
_FINGER_CLOSE_ENOUGH_TO_KEY = 0.01
_KEY_CLOSE_ENOUGH_TO_PRESSED = 0.05
# Energy penalty coefficient.
_ENERGY_PENALTY_COEF = 5e-3
# Transparency of fingertip geoms.
_FINGERTIP_ALPHA = 1.0
# Bounds for the uniform distribution from which initial hand offset is sampled.
_POSITION_OFFSET = 0.05
###############################################


class RoboPianist(PipelineEnv):

  def __init__(
      self,

      **kwargs,
  ):
    path = _HERE/ 'third_party' / 'shadow_hand' / 'robopianist.xml'

    mj_model = mujoco.MjModel.from_xml_path(path.as_posix())
    mj_model.opt.timestep = _PHYSICS_TIMESTEP

    sys = mjcf.load_model(mj_model)

    physics_steps_per_control_step = utils.compute_n_steps(_CONTROL_TIMESTEP, _PHYSICS_TIMESTEP)
    kwargs['n_frames'] = kwargs.get(
        'n_frames', physics_steps_per_control_step)
    super().__init__(sys, **kwargs)


  def reset(self, rng: jp.ndarray) -> EnvState:
    """Resets the environment to an initial state."""
    # TODO: decide the correct initial state
    rng, rng1, rng2 = jax.random.split(rng, 3)

    low, hi = -0.01, 0.01
    qpos = self.sys.qpos0 + jax.random.uniform(
        rng1, (self.sys.nq,), minval=low, maxval=hi
    )
    qvel = jax.random.uniform(
        rng2, (self.sys.nv,), minval=low, maxval=hi
    )

    data = self.pipeline_init(qpos, qvel)
    obs = self._get_obs(data, jp.zeros(self.sys.nu))
    return EnvState(data, obs, jp.zeros(1), jp.zeros(1), {})

  def step(self, state: EnvState, action: jp.ndarray) -> EnvState:
    """Runs one timestep of the environment's dynamics."""
    data0 = state.pipeline_state
    data = self.pipeline_step(data0, action)
    obs = self._get_obs(data, action)
    return state.replace(
        pipeline_state=data, obs=obs, reward=jp.zeros(1), done=jp.zeros(1)
    )
  
  def _get_obs(
      self, data: mjx.Data, action: jp.ndarray
  ) -> jp.ndarray:
    """Observes humanoid body position, velocities, and angles."""
    position = data.qpos
    return jp.concatenate([position, action])



# envs.register_environment('humanoid', Humanoid)

In [2]:
# venv
from mjx_env import VmapWrapper, EpisodeWrapper
env = RoboPianist()
env = EpisodeWrapper(env, episode_length=1000, action_repeat=2)
env = VmapWrapper(env)

env_key = jax.random.PRNGKey(42)
num_envs = 100
local_devices_to_use = jax.local_device_count()
env_keys = jax.random.split(env_key, num_envs // jax.process_count())
#   env_keys = jnp.reshape(env_keys,
#                          (local_devices_to_use, -1) + env_keys.shape[1:])
#   env_state = jax.pmap(env.reset)(env_keys)
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)
env_state = jit_reset(env_keys)
rollout = [env_state.pipeline_state]
for i in range(100):
  env_key, action_key = jax.random.split(env_key, 2)

  action = jax.random.uniform(action_key, (*env_keys.shape[:1], env.action_size))
  env_state = jit_step(env_state, action)
  rollout.append(env_state.pipeline_state)
  print(env_state.reward.shape)
  # if jnp.all(env_state.done):
  #   env_state = jax.pmap(env.reset)(env_keys) # TODO: check how to reset only the done envs
# import ipdb; ipdb.set_trace()
# video = env.render(rollout, camera='side')

2024-03-14 16:39:51.041817: W external/tsl/tsl/framework/bfc_allocator.cc:296] Allocator (GPU_0_bfc) ran out of memory trying to allocate 5.15GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.


XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 5535116416 bytes.

In [3]:
import os
os.environ['MUJOCO_GL'] = 'egl'
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]=".45"

from mjx_env import PipelineEnv, EnvState
import mjcf
from pathlib import Path
import mujoco
import jax
import sim_math as utils
from jax import numpy as jp
from mujoco import mjx

_HERE = Path(__file__).resolve().parent
_SHADOW_HAND_DIR = _HERE / "third_party" / "shadow_hand"


################## Constants ##################
# Timestep of the physics simulation, in seconds.
_PHYSICS_TIMESTEP = 0.005
# Interval between agent actions, in seconds.
_CONTROL_TIMESTEP = 0.05  # 20 Hz. # TODO: whether to increase control timesteps
# Distance thresholds for the shaping reward.
_FINGER_CLOSE_ENOUGH_TO_KEY = 0.01
_KEY_CLOSE_ENOUGH_TO_PRESSED = 0.05
# Energy penalty coefficient.
_ENERGY_PENALTY_COEF = 5e-3
# Transparency of fingertip geoms.
_FINGERTIP_ALPHA = 1.0
# Bounds for the uniform distribution from which initial hand offset is sampled.
_POSITION_OFFSET = 0.05
###############################################


class RoboPianist(PipelineEnv):

  def __init__(
      self,

      **kwargs,
  ):
    path = _HERE/ 'third_party' / 'shadow_hand' / 'left_hand.xml'

    mj_model = mujoco.MjModel.from_xml_path(path.as_posix())
    mj_model.opt.timestep = _PHYSICS_TIMESTEP

    sys = mjcf.load_model(mj_model)

    physics_steps_per_control_step = utils.compute_n_steps(_CONTROL_TIMESTEP, _PHYSICS_TIMESTEP)
    kwargs['n_frames'] = kwargs.get(
        'n_frames', physics_steps_per_control_step)
    super().__init__(sys, **kwargs)


  def reset(self, rng: jp.ndarray) -> EnvState:
    """Resets the environment to an initial state."""
    # TODO: decide the correct initial state
    rng, rng1, rng2 = jax.random.split(rng, 3)

    low, hi = -0.01, 0.01
    qpos = self.sys.qpos0 + jax.random.uniform(
        rng1, (self.sys.nq,), minval=low, maxval=hi
    )
    qvel = jax.random.uniform(
        rng2, (self.sys.nv,), minval=low, maxval=hi
    )

    data = self.pipeline_init(qpos, qvel)
    obs = self._get_obs(data, jp.zeros(self.sys.nu))
    return EnvState(data, obs, jp.zeros(1), jp.zeros(1), {})

  def step(self, state: EnvState, action: jp.ndarray) -> EnvState:
    """Runs one timestep of the environment's dynamics."""
    data0 = state.pipeline_state
    data = self.pipeline_step(data0, action)
    obs = self._get_obs(data, action)
    return state.replace(
        pipeline_state=data, obs=obs, reward=jp.zeros(1), done=jp.zeros(1)
    )
  
  def _get_obs(
      self, data: mjx.Data, action: jp.ndarray
  ) -> jp.ndarray:
    """Observes humanoid body position, velocities, and angles."""
    position = data.qpos
    return jp.concatenate([position, action])



# envs.register_environment('humanoid', Humanoid)

In [4]:
# if __name__ == "__main__":

env = RoboPianist()
# #   import ipdb; ipdb.set_trace()
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)

state = jit_reset(jax.random.PRNGKey(0))
rollout = [state.pipeline_state]

# grab a trajectory
for i in range(100):
  ctrl = -0.1 * jp.ones(env.sys.nu)
  state = jit_step(state, ctrl)
  rollout.append(state.pipeline_state)

video = env.render(rollout, camera='side')

In [5]:
import mediapy as media
media.show_video(video, fps=60)

0
This browser does not support the video tag.


In [6]:
# venv
from mjx_env import VmapWrapper, EpisodeWrapper
env = RoboPianist()
env = EpisodeWrapper(env, episode_length=1000, action_repeat=2)
env = VmapWrapper(env)

env_key = jax.random.PRNGKey(42)
num_envs = 100
local_devices_to_use = jax.local_device_count()
env_keys = jax.random.split(env_key, num_envs // jax.process_count())
#   env_keys = jnp.reshape(env_keys,
#                          (local_devices_to_use, -1) + env_keys.shape[1:])
#   env_state = jax.pmap(env.reset)(env_keys)
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)
env_state = jit_reset(env_keys)
rollout = [env_state.pipeline_state]
for i in range(100):
  env_key, action_key = jax.random.split(env_key, 2)

  action = jax.random.uniform(action_key, (*env_keys.shape[:1], env.action_size))
  env_state = jit_step(env_state, action)
  rollout.append(env_state.pipeline_state)
  print(env_state.reward.shape)
  # if jnp.all(env_state.done):
  #   env_state = jax.pmap(env.reset)(env_keys) # TODO: check how to reset only the done envs
# import ipdb; ipdb.set_trace()
# video = env.render(rollout, camera='side')

2024-03-14 16:44:11.907385: W external/tsl/tsl/framework/bfc_allocator.cc:296] Allocator (GPU_0_bfc) ran out of memory trying to allocate 5.15GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.


XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 5535116416 bytes.

In [7]:
# test the naive mujoco-mjx memory usage
path = _HERE/ 'third_party' / 'shadow_hand' / 'left_hand.xml'

mj_model = mujoco.MjModel.from_xml_path(path.as_posix())
mj_model.opt.timestep = _PHYSICS_TIMESTEP

mj_model = mujoco.MjModel.from_xml_path(path.as_posix())
mj_data = mujoco.MjData(mj_model)
renderer = mujoco.Renderer(mj_model, height=240, width=320)

mjx_model = mjx.put_model(mj_model)
mjx_data = mjx.put_data(mj_model, mj_data)

duration = 3.8
framerate = 60

rng = jax.random.PRNGKey(0)
rng = jax.random.split(rng, 100)
batch = jax.vmap(lambda rng: mjx_data.replace(qpos=jax.random.uniform(rng, (134,))))(rng)
jit_step = jax.vmap(mjx.step, in_axes=(None, 0))
batch = jit_step(mjx_model, batch)

IndexError: f argument "3" with type "q" has length "134" which does not match the in_types[3] expected length of "24".

In [8]:
# test the naive mujoco-mjx memory usage
path = _HERE/ 'third_party' / 'shadow_hand' / 'left_hand.xml'

mj_model = mujoco.MjModel.from_xml_path(path.as_posix())
mj_model.opt.timestep = _PHYSICS_TIMESTEP

mj_model = mujoco.MjModel.from_xml_path(path.as_posix())
mj_data = mujoco.MjData(mj_model)
renderer = mujoco.Renderer(mj_model, height=240, width=320)

mjx_model = mjx.put_model(mj_model)
mjx_data = mjx.put_data(mj_model, mj_data)

duration = 3.8
framerate = 60

rng = jax.random.PRNGKey(0)
rng = jax.random.split(rng, 100)
batch = jax.vmap(lambda rng: mjx_data.replace(qpos=jax.random.uniform(rng, (24,))))(rng)
jit_step = jax.vmap(mjx.step, in_axes=(None, 0))
batch = jit_step(mjx_model, batch)

2024-03-14 16:47:24.681874: W external/tsl/tsl/framework/bfc_allocator.cc:296] Allocator (GPU_0_bfc) ran out of memory trying to allocate 5.15GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.


XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 5535116416 bytes.

Restarted pianist (Python 3.10.12)

In [1]:
# test the naive mujoco-mjx memory usage
path = _HERE/ 'third_party' / 'shadow_hand' / 'left_hand_mjx.xml'

mj_model = mujoco.MjModel.from_xml_path(path.as_posix())
mj_model.opt.timestep = _PHYSICS_TIMESTEP

mj_model = mujoco.MjModel.from_xml_path(path.as_posix())
mj_data = mujoco.MjData(mj_model)
renderer = mujoco.Renderer(mj_model, height=240, width=320)

mjx_model = mjx.put_model(mj_model)
mjx_data = mjx.put_data(mj_model, mj_data)

duration = 3.8
framerate = 60

rng = jax.random.PRNGKey(0)
rng = jax.random.split(rng, 100)
batch = jax.vmap(lambda rng: mjx_data.replace(qpos=jax.random.uniform(rng, (24,))))(rng)
jit_step = jax.vmap(mjx.step, in_axes=(None, 0))
batch = jit_step(mjx_model, batch)

NameError: name '_HERE' is not defined

In [2]:
import os
os.environ['MUJOCO_GL'] = 'egl'
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]=".45"

from mjx_env import PipelineEnv, EnvState
import mjcf
from pathlib import Path
import mujoco
import jax
import sim_math as utils
from jax import numpy as jp
from mujoco import mjx

_HERE = Path(__file__).resolve().parent
_SHADOW_HAND_DIR = _HERE / "third_party" / "shadow_hand"


################## Constants ##################
# Timestep of the physics simulation, in seconds.
_PHYSICS_TIMESTEP = 0.005
# Interval between agent actions, in seconds.
_CONTROL_TIMESTEP = 0.05  # 20 Hz. # TODO: whether to increase control timesteps
# Distance thresholds for the shaping reward.
_FINGER_CLOSE_ENOUGH_TO_KEY = 0.01
_KEY_CLOSE_ENOUGH_TO_PRESSED = 0.05
# Energy penalty coefficient.
_ENERGY_PENALTY_COEF = 5e-3
# Transparency of fingertip geoms.
_FINGERTIP_ALPHA = 1.0
# Bounds for the uniform distribution from which initial hand offset is sampled.
_POSITION_OFFSET = 0.05
###############################################


class RoboPianist(PipelineEnv):

  def __init__(
      self,

      **kwargs,
  ):
    path = _HERE/ 'third_party' / 'shadow_hand' / 'left_hand.xml'

    mj_model = mujoco.MjModel.from_xml_path(path.as_posix())
    mj_model.opt.timestep = _PHYSICS_TIMESTEP

    sys = mjcf.load_model(mj_model)

    physics_steps_per_control_step = utils.compute_n_steps(_CONTROL_TIMESTEP, _PHYSICS_TIMESTEP)
    kwargs['n_frames'] = kwargs.get(
        'n_frames', physics_steps_per_control_step)
    super().__init__(sys, **kwargs)


  def reset(self, rng: jp.ndarray) -> EnvState:
    """Resets the environment to an initial state."""
    # TODO: decide the correct initial state
    rng, rng1, rng2 = jax.random.split(rng, 3)

    low, hi = -0.01, 0.01
    qpos = self.sys.qpos0 + jax.random.uniform(
        rng1, (self.sys.nq,), minval=low, maxval=hi
    )
    qvel = jax.random.uniform(
        rng2, (self.sys.nv,), minval=low, maxval=hi
    )

    data = self.pipeline_init(qpos, qvel)
    obs = self._get_obs(data, jp.zeros(self.sys.nu))
    return EnvState(data, obs, jp.zeros(1), jp.zeros(1), {})

  def step(self, state: EnvState, action: jp.ndarray) -> EnvState:
    """Runs one timestep of the environment's dynamics."""
    data0 = state.pipeline_state
    data = self.pipeline_step(data0, action)
    obs = self._get_obs(data, action)
    return state.replace(
        pipeline_state=data, obs=obs, reward=jp.zeros(1), done=jp.zeros(1)
    )
  
  def _get_obs(
      self, data: mjx.Data, action: jp.ndarray
  ) -> jp.ndarray:
    """Observes humanoid body position, velocities, and angles."""
    position = data.qpos
    return jp.concatenate([position, action])



# envs.register_environment('humanoid', Humanoid)

Restarted pianist (Python 3.10.12)

In [1]:
import os
os.environ['MUJOCO_GL'] = 'egl'
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]=".45"

from mjx_env import PipelineEnv, EnvState
import mjcf
from pathlib import Path
import mujoco
import jax
import sim_math as utils
from jax import numpy as jp
from mujoco import mjx

_HERE = Path(__file__).resolve().parent
_SHADOW_HAND_DIR = _HERE / "third_party" / "shadow_hand"


################## Constants ##################
# Timestep of the physics simulation, in seconds.
_PHYSICS_TIMESTEP = 0.005
# Interval between agent actions, in seconds.
_CONTROL_TIMESTEP = 0.05  # 20 Hz. # TODO: whether to increase control timesteps
# Distance thresholds for the shaping reward.
_FINGER_CLOSE_ENOUGH_TO_KEY = 0.01
_KEY_CLOSE_ENOUGH_TO_PRESSED = 0.05
# Energy penalty coefficient.
_ENERGY_PENALTY_COEF = 5e-3
# Transparency of fingertip geoms.
_FINGERTIP_ALPHA = 1.0
# Bounds for the uniform distribution from which initial hand offset is sampled.
_POSITION_OFFSET = 0.05
###############################################

In [2]:
# test the naive mujoco-mjx memory usage
path = _HERE/ 'third_party' / 'shadow_hand' / 'left_hand_mjx.xml'

mj_model = mujoco.MjModel.from_xml_path(path.as_posix())
mj_model.opt.timestep = _PHYSICS_TIMESTEP

mj_model = mujoco.MjModel.from_xml_path(path.as_posix())
mj_data = mujoco.MjData(mj_model)
renderer = mujoco.Renderer(mj_model, height=240, width=320)

mjx_model = mjx.put_model(mj_model)
mjx_data = mjx.put_data(mj_model, mj_data)

duration = 3.8
framerate = 60

rng = jax.random.PRNGKey(0)
rng = jax.random.split(rng, 100)
batch = jax.vmap(lambda rng: mjx_data.replace(qpos=jax.random.uniform(rng, (24,))))(rng)
jit_step = jax.vmap(mjx.step, in_axes=(None, 0))
batch = jit_step(mjx_model, batch)

ValueError: mjParseXML: resource not found via provider or OS filesystem: '/ssd/rl/robopianist/robopianist-mjx/dev/third_party/shadow_hand/left_hand_mjx.xml'

In [3]:
# test the naive mujoco-mjx memory usage
path = _HERE/ 'third_party' / 'shadow_hand' / 'right_hand_mjx.xml'

mj_model = mujoco.MjModel.from_xml_path(path.as_posix())
mj_model.opt.timestep = _PHYSICS_TIMESTEP

mj_model = mujoco.MjModel.from_xml_path(path.as_posix())
mj_data = mujoco.MjData(mj_model)
renderer = mujoco.Renderer(mj_model, height=240, width=320)

mjx_model = mjx.put_model(mj_model)
mjx_data = mjx.put_data(mj_model, mj_data)

duration = 3.8
framerate = 60

rng = jax.random.PRNGKey(0)
rng = jax.random.split(rng, 100)
batch = jax.vmap(lambda rng: mjx_data.replace(qpos=jax.random.uniform(rng, (24,))))(rng)
jit_step = jax.vmap(mjx.step, in_axes=(None, 0))
batch = jit_step(mjx_model, batch)

2024-03-14 16:55:51.314587: W external/tsl/tsl/framework/bfc_allocator.cc:296] Allocator (GPU_0_bfc) ran out of memory trying to allocate 5.15GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.


XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 5535116416 bytes.

Restarted pianist (Python 3.10.12)

In [1]:
import os
os.environ['MUJOCO_GL'] = 'egl'
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]=".45"

from mjx_env import PipelineEnv, EnvState
import mjcf
from pathlib import Path
import mujoco
import jax
import sim_math as utils
from jax import numpy as jp
from mujoco import mjx

_HERE = Path(__file__).resolve().parent
_SHADOW_HAND_DIR = _HERE / "third_party" / "shadow_hand"


################## Constants ##################
# Timestep of the physics simulation, in seconds.
_PHYSICS_TIMESTEP = 0.005
# Interval between agent actions, in seconds.
_CONTROL_TIMESTEP = 0.05  # 20 Hz. # TODO: whether to increase control timesteps
# Distance thresholds for the shaping reward.
_FINGER_CLOSE_ENOUGH_TO_KEY = 0.01
_KEY_CLOSE_ENOUGH_TO_PRESSED = 0.05
# Energy penalty coefficient.
_ENERGY_PENALTY_COEF = 5e-3
# Transparency of fingertip geoms.
_FINGERTIP_ALPHA = 1.0
# Bounds for the uniform distribution from which initial hand offset is sampled.
_POSITION_OFFSET = 0.05
###############################################

In [2]:
# test the naive mujoco-mjx memory usage
path = _HERE/ 'third_party' / 'shadow_hand' / 'scene_right_mjx.xml'

mj_model = mujoco.MjModel.from_xml_path(path.as_posix())
mj_model.opt.timestep = _PHYSICS_TIMESTEP

mj_model = mujoco.MjModel.from_xml_path(path.as_posix())
mj_data = mujoco.MjData(mj_model)
renderer = mujoco.Renderer(mj_model, height=240, width=320)

mjx_model = mjx.put_model(mj_model)
mjx_data = mjx.put_data(mj_model, mj_data)

duration = 3.8
framerate = 60

rng = jax.random.PRNGKey(0)
rng = jax.random.split(rng, 100)
batch = jax.vmap(lambda rng: mjx_data.replace(qpos=jax.random.uniform(rng, (24,))))(rng)
jit_step = jax.vmap(mjx.step, in_axes=(None, 0))
batch = jit_step(mjx_model, batch)

ValueError: Error: material 'gray' not found in geom 1
Object name = , id = 1, line = 0

In [3]:
# test the naive mujoco-mjx memory usage
path = _HERE/ 'third_party' / 'shadow_hand' / 'scene_right_mjx.xml'

mj_model = mujoco.MjModel.from_xml_path(path.as_posix())
mj_model.opt.timestep = _PHYSICS_TIMESTEP

mj_model = mujoco.MjModel.from_xml_path(path.as_posix())
mj_data = mujoco.MjData(mj_model)
renderer = mujoco.Renderer(mj_model, height=240, width=320)

mjx_model = mjx.put_model(mj_model)
mjx_data = mjx.put_data(mj_model, mj_data)

duration = 3.8
framerate = 60

rng = jax.random.PRNGKey(0)
rng = jax.random.split(rng, 100)
batch = jax.vmap(lambda rng: mjx_data.replace(qpos=jax.random.uniform(rng, (24,))))(rng)
jit_step = jax.vmap(mjx.step, in_axes=(None, 0))
batch = jit_step(mjx_model, batch)

IndexError: f argument "3" with type "q" has length "24" which does not match the in_types[3] expected length of "31".

Restarted pianist (Python 3.10.12)

In [1]:
import os
os.environ['MUJOCO_GL'] = 'egl'
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]=".45"

from mjx_env import PipelineEnv, EnvState
import mjcf
from pathlib import Path
import mujoco
import jax
import sim_math as utils
from jax import numpy as jp
from mujoco import mjx

_HERE = Path(__file__).resolve().parent
_SHADOW_HAND_DIR = _HERE / "third_party" / "shadow_hand"


################## Constants ##################
# Timestep of the physics simulation, in seconds.
_PHYSICS_TIMESTEP = 0.005
# Interval between agent actions, in seconds.
_CONTROL_TIMESTEP = 0.05  # 20 Hz. # TODO: whether to increase control timesteps
# Distance thresholds for the shaping reward.
_FINGER_CLOSE_ENOUGH_TO_KEY = 0.01
_KEY_CLOSE_ENOUGH_TO_PRESSED = 0.05
# Energy penalty coefficient.
_ENERGY_PENALTY_COEF = 5e-3
# Transparency of fingertip geoms.
_FINGERTIP_ALPHA = 1.0
# Bounds for the uniform distribution from which initial hand offset is sampled.
_POSITION_OFFSET = 0.05
###############################################

In [2]:
# test the naive mujoco-mjx memory usage
path = _HERE/ 'third_party' / 'shadow_hand' / 'scene_right_mjx.xml'

mj_model = mujoco.MjModel.from_xml_path(path.as_posix())
mj_model.opt.timestep = _PHYSICS_TIMESTEP

mj_model = mujoco.MjModel.from_xml_path(path.as_posix())
mj_data = mujoco.MjData(mj_model)
renderer = mujoco.Renderer(mj_model, height=240, width=320)

mjx_model = mjx.put_model(mj_model)
mjx_data = mjx.put_data(mj_model, mj_data)

duration = 3.8
framerate = 60

rng = jax.random.PRNGKey(0)
rng = jax.random.split(rng, 100)
batch = jax.vmap(lambda rng: mjx_data.replace(qpos=jax.random.uniform(rng, (31,))))(rng)
jit_step = jax.vmap(mjx.step, in_axes=(None, 0))
batch = jit_step(mjx_model, batch)

2024-03-14 17:14:26.838526: W external/tsl/tsl/framework/bfc_allocator.cc:296] Allocator (GPU_0_bfc) ran out of memory trying to allocate 5.15GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.


XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 5535116416 bytes.

Restarted pianist (Python 3.10.12)

In [1]:
import os
os.environ['MUJOCO_GL'] = 'egl'
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]=".45"

from mjx_env import PipelineEnv, EnvState
import mjcf
from pathlib import Path
import mujoco
import jax
import sim_math as utils
from jax import numpy as jp
from mujoco import mjx

_HERE = Path(__file__).resolve().parent
_SHADOW_HAND_DIR = _HERE / "third_party" / "shadow_hand"


################## Constants ##################
# Timestep of the physics simulation, in seconds.
_PHYSICS_TIMESTEP = 0.005
# Interval between agent actions, in seconds.
_CONTROL_TIMESTEP = 0.05  # 20 Hz. # TODO: whether to increase control timesteps
# Distance thresholds for the shaping reward.
_FINGER_CLOSE_ENOUGH_TO_KEY = 0.01
_KEY_CLOSE_ENOUGH_TO_PRESSED = 0.05
# Energy penalty coefficient.
_ENERGY_PENALTY_COEF = 5e-3
# Transparency of fingertip geoms.
_FINGERTIP_ALPHA = 1.0
# Bounds for the uniform distribution from which initial hand offset is sampled.
_POSITION_OFFSET = 0.05
###############################################

In [2]:
class RoboPianist(PipelineEnv):

  def __init__(
      self,

      **kwargs,
  ):
    path = _HERE/ 'third_party' / 'shadow_hand' / 'left_hand.xml'

    mj_model = mujoco.MjModel.from_xml_path(path.as_posix())
    mj_model.opt.timestep = _PHYSICS_TIMESTEP

    sys = mjcf.load_model(mj_model)

    physics_steps_per_control_step = utils.compute_n_steps(_CONTROL_TIMESTEP, _PHYSICS_TIMESTEP)
    kwargs['n_frames'] = kwargs.get(
        'n_frames', physics_steps_per_control_step)
    super().__init__(sys, **kwargs)


  def reset(self, rng: jp.ndarray) -> EnvState:
    """Resets the environment to an initial state."""
    # TODO: decide the correct initial state
    rng, rng1, rng2 = jax.random.split(rng, 3)

    low, hi = -0.01, 0.01
    qpos = self.sys.qpos0 + jax.random.uniform(
        rng1, (self.sys.nq,), minval=low, maxval=hi
    )
    qvel = jax.random.uniform(
        rng2, (self.sys.nv,), minval=low, maxval=hi
    )

    data = self.pipeline_init(qpos, qvel)
    obs = self._get_obs(data, jp.zeros(self.sys.nu))
    return EnvState(data, obs, jp.zeros(1), jp.zeros(1), {})

  def step(self, state: EnvState, action: jp.ndarray) -> EnvState:
    """Runs one timestep of the environment's dynamics."""
    data0 = state.pipeline_state
    data = self.pipeline_step(data0, action)
    obs = self._get_obs(data, action)
    return state.replace(
        pipeline_state=data, obs=obs, reward=jp.zeros(1), done=jp.zeros(1)
    )
  
  def _get_obs(
      self, data: mjx.Data, action: jp.ndarray
  ) -> jp.ndarray:
    """Observes humanoid body position, velocities, and angles."""
    position = data.qpos
    return jp.concatenate([position, action])



# envs.register_environment('humanoid', Humanoid)

In [3]:
# venv
from mjx_env import VmapWrapper, EpisodeWrapper
env = RoboPianist()
env = EpisodeWrapper(env, episode_length=1000, action_repeat=2)
env = VmapWrapper(env)

env_key = jax.random.PRNGKey(42)
num_envs = 100
local_devices_to_use = jax.local_device_count()
env_keys = jax.random.split(env_key, num_envs // jax.process_count())
#   env_keys = jnp.reshape(env_keys,
#                          (local_devices_to_use, -1) + env_keys.shape[1:])
#   env_state = jax.pmap(env.reset)(env_keys)
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)
env_state = jit_reset(env_keys)
rollout = [env_state.pipeline_state]
for i in range(100):
  env_key, action_key = jax.random.split(env_key, 2)

  action = jax.random.uniform(action_key, (*env_keys.shape[:1], env.action_size))
  env_state = jit_step(env_state, action)
  rollout.append(env_state.pipeline_state)
  print(env_state.reward.shape)
  # if jnp.all(env_state.done):
  #   env_state = jax.pmap(env.reset)(env_keys) # TODO: check how to reset only the done envs
# import ipdb; ipdb.set_trace()
# video = env.render(rollout, camera='side')

2024-03-14 18:39:07.527811: W external/tsl/tsl/framework/bfc_allocator.cc:296] Allocator (GPU_0_bfc) ran out of memory trying to allocate 5.15GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.


XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 5535116416 bytes.

In [4]:
class RoboPianist(PipelineEnv):

  def __init__(
      self,

      **kwargs,
  ):
    path = _HERE/ 'third_party' / 'shadow_hand' / 'robopianist.xml'

    mj_model = mujoco.MjModel.from_xml_path(path.as_posix())
    mj_model.opt.timestep = _PHYSICS_TIMESTEP

    sys = mjcf.load_model(mj_model)

    physics_steps_per_control_step = utils.compute_n_steps(_CONTROL_TIMESTEP, _PHYSICS_TIMESTEP)
    kwargs['n_frames'] = kwargs.get(
        'n_frames', physics_steps_per_control_step)
    super().__init__(sys, **kwargs)


  def reset(self, rng: jp.ndarray) -> EnvState:
    """Resets the environment to an initial state."""
    # TODO: decide the correct initial state
    rng, rng1, rng2 = jax.random.split(rng, 3)

    low, hi = -0.01, 0.01
    qpos = self.sys.qpos0 + jax.random.uniform(
        rng1, (self.sys.nq,), minval=low, maxval=hi
    )
    qvel = jax.random.uniform(
        rng2, (self.sys.nv,), minval=low, maxval=hi
    )

    data = self.pipeline_init(qpos, qvel)
    obs = self._get_obs(data, jp.zeros(self.sys.nu))
    return EnvState(data, obs, jp.zeros(1), jp.zeros(1), {})

  def step(self, state: EnvState, action: jp.ndarray) -> EnvState:
    """Runs one timestep of the environment's dynamics."""
    data0 = state.pipeline_state
    data = self.pipeline_step(data0, action)
    obs = self._get_obs(data, action)
    return state.replace(
        pipeline_state=data, obs=obs, reward=jp.zeros(1), done=jp.zeros(1)
    )
  
  def _get_obs(
      self, data: mjx.Data, action: jp.ndarray
  ) -> jp.ndarray:
    """Observes humanoid body position, velocities, and angles."""
    position = data.qpos
    return jp.concatenate([position, action])



# envs.register_environment('humanoid', Humanoid)

In [5]:
# venv
from mjx_env import VmapWrapper, EpisodeWrapper
env = RoboPianist()
env = EpisodeWrapper(env, episode_length=1000, action_repeat=2)
env = VmapWrapper(env)

env_key = jax.random.PRNGKey(42)
num_envs = 100
local_devices_to_use = jax.local_device_count()
env_keys = jax.random.split(env_key, num_envs // jax.process_count())
#   env_keys = jnp.reshape(env_keys,
#                          (local_devices_to_use, -1) + env_keys.shape[1:])
#   env_state = jax.pmap(env.reset)(env_keys)
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)
env_state = jit_reset(env_keys)
rollout = [env_state.pipeline_state]
for i in range(100):
  env_key, action_key = jax.random.split(env_key, 2)

  action = jax.random.uniform(action_key, (*env_keys.shape[:1], env.action_size))
  env_state = jit_step(env_state, action)
  rollout.append(env_state.pipeline_state)
  print(env_state.reward.shape)
  # if jnp.all(env_state.done):
  #   env_state = jax.pmap(env.reset)(env_keys) # TODO: check how to reset only the done envs
# import ipdb; ipdb.set_trace()
# video = env.render(rollout, camera='side')

(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
(100, 1)
