Skip to content

Commit

Permalink
Merge 9496d0f into fc115c5
Browse files Browse the repository at this point in the history
  • Loading branch information
Stéphane Caron committed Mar 6, 2023
2 parents fc115c5 + 9496d0f commit 69cdd88
Show file tree
Hide file tree
Showing 14 changed files with 55 additions and 58 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,13 @@ All notable changes to this project will be documented in this file.

### Added

- Add ``upkie_locomotion.envs.register`` function
- PPO balancer: setting for total number of training timesteps

### Changed

- UpkieWheelsEnv: remove dependency on gin

## [0.2.0] - 2023/03/03

### Added
Expand Down
5 changes: 2 additions & 3 deletions agents/ppo_balancer/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
import gin
import mpacklog
from loop_rate_limiters import AsyncRateLimiter
from settings import Settings
from stable_baselines3 import PPO
from upkie_locomotion.envs import UpkieWheelsEnv

from settings import Settings
from upkie_locomotion.envs import UpkieWheelsEnv

keep_going = True

Expand Down Expand Up @@ -60,7 +60,6 @@ def load_policy(agent_dir: str, policy_name: str):

if __name__ == "__main__":
agent_dir = os.path.abspath(os.path.dirname(__file__))
gin.parse_config_file(UpkieWheelsEnv.gin_config())
gin.parse_config_file(f"{agent_dir}/settings.gin")

parser = argparse.ArgumentParser(description=__doc__)
Expand Down
1 change: 0 additions & 1 deletion agents/ppo_balancer/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,6 @@ def get_bullet_argv(agent_name: str, show: bool) -> List[str]:
args = parser.parse_args()

agent_dir = os.path.dirname(__file__)
gin.parse_config_file(UpkieWheelsEnv.gin_config())
gin.parse_config_file(f"{agent_dir}/settings.gin")

agent_name = generate_agent_name()
Expand Down
1 change: 0 additions & 1 deletion envs/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ py_library(
],
data = [
"spine.yaml",
"upkie_wheels_env.gin",
],
deps = [
"@upkie_locomotion//observers/base_pitch",
Expand Down
12 changes: 12 additions & 0 deletions envs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import gym

from .upkie_wheels_env import UpkieWheelsEnv


def register():
gym.envs.register(
id=f"UpkieWheelsEnv-v{UpkieWheelsEnv.version}",
entry_point="upkie_locomotion.envs:UpkieWheelsEnv",
max_episode_steps=1_000_000_000,
)


__all__ = [
"UpkieWheelsEnv",
"register",
]
2 changes: 0 additions & 2 deletions envs/tests/upkie_wheels_env_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

import unittest

import gin
import numpy as np
import posix_ipc

Expand Down Expand Up @@ -61,7 +60,6 @@ def get_observation(self) -> dict:

class TestUpkieWheelsEnv(unittest.TestCase):
def setUp(self):
gin.parse_config_file(UpkieWheelsEnv.gin_config())
shm_name = "/vroum"
shared_memory = posix_ipc.SharedMemory(
shm_name, posix_ipc.O_RDWR | posix_ipc.O_CREAT, size=42
Expand Down
17 changes: 0 additions & 17 deletions envs/upkie_wheels_env.gin

This file was deleted.

30 changes: 5 additions & 25 deletions envs/upkie_wheels_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from os import path
from typing import Dict, Optional, Tuple, Union

import gin
import gym
import numpy as np
import yaml
Expand All @@ -36,7 +35,6 @@
MAX_IMU_ANGULAR_VELOCITY: float = 1000.0 # rad/s


@gin.configurable
class UpkieWheelsEnv(gym.Env):

"""!
Expand Down Expand Up @@ -109,22 +107,13 @@ class UpkieWheelsEnv(gym.Env):
for joint in ("hip", "knee")
]

def id(self) -> str:
"""!
Name and version of this environment for registration.
Returns:
Name and version of the environment.
"""
return f"UpkieWheelsEnv-v{self.version}"

def __init__(
self,
config: Optional[dict],
fall_pitch: float,
max_ground_velocity: float,
shm_name: str,
wheel_radius: float,
config: Optional[dict] = None,
fall_pitch: float = 1.0,
max_ground_velocity: float = 1.0,
shm_name: str = "/vulp",
wheel_radius: float = 0.06,
):
"""!
Initialize environment.
Expand Down Expand Up @@ -210,15 +199,6 @@ def _observe(self) -> dict:
self.last_observation = observation_dict
return observation_dict

@staticmethod
def gin_config():
"""!
Path to the Gin configuration for this environment.
"""
dirname = path.dirname(__file__)
basename = path.basename(__file__).replace(".py", ".gin")
return f"{dirname}/{basename}"

def vectorize_observation(self, observation_dict: dict) -> np.ndarray:
"""!
Extract observation vector from a full observation dictionary.
Expand Down
12 changes: 5 additions & 7 deletions envs/upkie_wheels_reward.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,9 @@

from typing import Tuple

import gin
import numpy as np


@gin.configurable
class UpkieWheelsReward:

"""!
Expand Down Expand Up @@ -42,11 +40,11 @@ def get_range() -> Tuple[float, float]:

def __init__(
self,
lookahead_duration: float,
max_pitch: float,
max_position: float,
pitch_weight: float,
position_weight: float,
lookahead_duration: float = 0.1,
max_pitch: float = 1.5707963267948966,
max_position: float = 0.5,
pitch_weight: float = 1.0,
position_weight: float = 1.0,
):
"""!
Initialize reward.
Expand Down
24 changes: 24 additions & 0 deletions examples/upkie_wheels_env_simple.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright 2023 Inria

import gym
import numpy as np

import upkie_locomotion.envs

if __name__ == "__main__":
upkie_locomotion.envs.register()
env = gym.make("UpkieWheelsEnv-v1")
observation = env.reset(seed=42)

action = np.zeros(env.action_space.shape)
for step in range(1_000_000):
observation, reward, done, _ = env.step(action)
if done:
observation = env.reset()
pitch = observation[0]
action[0] = 10.0 * pitch

env.close()
File renamed without changes.
File renamed without changes.
2 changes: 1 addition & 1 deletion pypi/tmpflit.sh → tools/pypi/tmpflit.sh
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ fi

BASEDIR=$(dirname 0)
COMMAND=$@
SRCDIR=${BASEDIR}/..
SRCDIR=${BASEDIR}/../..
TMPDIR=$(mktemp -d)

echo "[debug] COMMAND=${COMMAND}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@

"""Real-time motion control for Python."""

__version__ = "0.3.0rc0"
__version__ = "0.3.0rc2"

0 comments on commit 69cdd88

Please sign in to comment.