From 737c40ed0709ae44af2cf2ac10c3dfc17f877871 Mon Sep 17 00:00:00 2001 From: Sotetsu KOYAMADA Date: Mon, 21 Aug 2023 17:18:27 +0900 Subject: [PATCH] Bump sparrow_mahjong to v1: transpose observation shape from (15, 11) to (11, 15) --- README.md | 2 +- docs/sparrow_mahjong.md | 6 +- pgx/sparrow_mahjong.py | 4 +- tests/test_sparrow_mahjong.py | 124 +++++++++++++++++----------------- 4 files changed, 68 insertions(+), 68 deletions(-) diff --git a/README.md b/README.md index a2c938ca1..957dda499 100644 --- a/README.md +++ b/README.md @@ -114,7 +114,7 @@ Use `pgx.available_envs() -> Tuple[EnvId]` to see the list of currently availabl |MinAtar/SpaceInvaders
`"minatar-space_invaders"` || `v0` | *Alien shooter game, dodge bullets.* | |Othello
`"othello"` || `v0` | *Flip and conquer opponent's pieces.* | |Shogi
`"shogi"` | | `v0` | *Japanese chess with captured pieces.* | -|Sparrow Mahjong
`"sparrow_mahjong"` || `v0` | *A simplified, children-friendly Mahjong.* | +|Sparrow Mahjong
`"sparrow_mahjong"` || `v1` | *A simplified, children-friendly Mahjong.* | |Tic-tac-toe
`"tic_tac_toe"` || `v0` | *Three in a row wins.* | - Mahjong environments are under development 🚧 If you have any requests for new environments, please let us know by [opening an issue](https://github.com/sotetsuk/pgx/issues/new) diff --git a/docs/sparrow_mahjong.md b/docs/sparrow_mahjong.md index 5ac6276ee..f426d774f 100644 --- a/docs/sparrow_mahjong.md +++ b/docs/sparrow_mahjong.md @@ -65,10 +65,10 @@ Pgx implementation is simplified as follows: | Name | Value | |:---|:----:| -| Version | `v0` | +| Version | `v1` | | Number of players | `3` | | Number of actions | `11` | -| Observation shape | `(15, 11)` | +| Observation shape | `(11, 15)` | | Observation type | `bool` | | Rewards | `[-1, 1]` | @@ -97,4 +97,4 @@ Terminates when either player wins or the wall becomes empty. ## Version History -- `v0` : Initial release (v1.0.0) \ No newline at end of file +- `v1` : Initial release (v1.0.0) \ No newline at end of file diff --git a/pgx/sparrow_mahjong.py b/pgx/sparrow_mahjong.py index 7bf655ce1..d96b8b7da 100644 --- a/pgx/sparrow_mahjong.py +++ b/pgx/sparrow_mahjong.py @@ -162,7 +162,7 @@ def id(self) -> v1.EnvId: @property def version(self) -> str: - return "v0" + return "v1" @property def num_players(self) -> int: @@ -498,7 +498,7 @@ def _observe(state: State, player_id: jnp.ndarray): ), lambda: obs, ) - return obs + return jnp.transpose(obs) def _tile_type_to_str(tile_type) -> str: diff --git a/tests/test_sparrow_mahjong.py b/tests/test_sparrow_mahjong.py index 4e6699a43..48fc9204c 100644 --- a/tests/test_sparrow_mahjong.py +++ b/tests/test_sparrow_mahjong.py @@ -500,17 +500,17 @@ def test_observe(): state = step(state, jnp.int32(1)) print(_to_str(state)) obs = observe(state, player_id=jnp.int8(2)) - assert obs.shape[0] == 15 - assert obs.shape[1] == 11 - assert jnp.all(obs[0] == jnp.bool_([0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0])) - assert jnp.all(obs[1] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0])) - assert jnp.all(obs[2] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0])) - assert jnp.all(obs[3] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])) - assert jnp.all(obs[4] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0])) - assert jnp.all(obs[5] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1])) - assert jnp.all(obs[6] == jnp.bool_([0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0])) - assert jnp.all(obs[7] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])) - assert jnp.all(obs[8] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])) + assert obs.shape[0] == 11 + assert obs.shape[1] == 15 + assert jnp.all(obs[:, 0] == jnp.bool_([0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0])) + assert jnp.all(obs[:, 1] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0])) + assert jnp.all(obs[:, 2] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])) + assert jnp.all(obs[:, 3] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])) + assert jnp.all(obs[:, 4] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0])) + assert jnp.all(obs[:, 5] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1])) + assert jnp.all(obs[:, 6] == jnp.bool_([0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0])) + assert jnp.all(obs[:, 7] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])) + assert jnp.all(obs[:, 8] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])) seed = 5 key = jax.random.PRNGKey(seed) @@ -529,59 +529,59 @@ def test_observe(): [1] 1 2*3 4 5 : r*_ _ _ _ _ _ _ _ _ """ obs = observe(state, player_id=jnp.int8(0)) - assert obs.shape[0] == 15 - assert obs.shape[1] == 11 - assert jnp.all(obs[0] == jnp.bool_([1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0])) - assert jnp.all(obs[1] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])) - assert jnp.all(obs[2] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])) - assert jnp.all(obs[3] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])) - assert jnp.all(obs[4] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])) - assert jnp.all(obs[5] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0])) - assert jnp.all(obs[6] == jnp.bool_([0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1])) - assert jnp.all(obs[7] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0])) - assert 
jnp.all(obs[8] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1])) - assert jnp.all(obs[9] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0])) - assert jnp.all(obs[10] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0])) - assert jnp.all(obs[11] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])) - assert jnp.all(obs[12] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1])) - assert jnp.all(obs[13] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])) - assert jnp.all(obs[14] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])) + assert obs.shape[0] == 11 + assert obs.shape[1] == 15 + assert jnp.all(obs[:, 0] == jnp.bool_([1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0])) + assert jnp.all(obs[:, 1] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])) + assert jnp.all(obs[:, 2] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])) + assert jnp.all(obs[:, 3] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])) + assert jnp.all(obs[:, 4] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])) + assert jnp.all(obs[:, 5] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0])) + assert jnp.all(obs[:, 6] == jnp.bool_([0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1])) + assert jnp.all(obs[:, 7] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0])) + assert jnp.all(obs[:, 8] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1])) + assert jnp.all(obs[:, 9] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0])) + assert jnp.all(obs[:, 10] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0])) + assert jnp.all(obs[:, 11] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])) + assert jnp.all(obs[:, 12] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1])) + assert jnp.all(obs[:, 13] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])) + assert jnp.all(obs[:, 14] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])) obs = observe(state, player_id=jnp.int8(1)) - assert obs.shape[0] == 15 - assert obs.shape[1] == 11 - assert jnp.all(obs[0] == jnp.bool_([1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0])) - assert jnp.all(obs[1] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])) - assert jnp.all(obs[2] == jnp.bool_([0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0])) - assert jnp.all(obs[3] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])) - assert jnp.all(obs[4] == jnp.bool_([0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0])) - assert jnp.all(obs[5] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0])) - assert jnp.all(obs[6] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1])) - assert jnp.all(obs[7] == jnp.bool_([0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1])) - assert jnp.all(obs[8] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0])) - assert jnp.all(obs[9] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1])) - assert jnp.all(obs[10] == jnp.bool_([0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0])) - assert jnp.all(obs[11] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])) - assert jnp.all(obs[12] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0])) - assert jnp.all(obs[13] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0])) - assert jnp.all(obs[14] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])) + assert obs.shape[0] == 11 + assert obs.shape[1] == 15 + assert jnp.all(obs[:, 0] == jnp.bool_([1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0])) + assert jnp.all(obs[:, 1] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])) + assert jnp.all(obs[:, 2] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])) + assert jnp.all(obs[:, 3] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])) + assert jnp.all(obs[:, 4] == jnp.bool_([0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0])) + assert jnp.all(obs[:, 5] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0])) + assert jnp.all(obs[:, 6] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1])) + assert jnp.all(obs[:, 7] == jnp.bool_([0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1])) + assert jnp.all(obs[:, 8] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0])) + assert jnp.all(obs[:, 9] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1])) + assert jnp.all(obs[:, 10] == jnp.bool_([0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0])) + assert jnp.all(obs[:, 11] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])) + assert jnp.all(obs[:, 12] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0])) + assert jnp.all(obs[:, 13] == 
jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0])) + assert jnp.all(obs[:, 14] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])) obs = observe(state, player_id=jnp.int8(2)) - assert obs.shape[0] == 15 - assert obs.shape[1] == 11 - assert jnp.all(obs[0] == jnp.bool_([0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0])) - assert jnp.all(obs[1] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0])) - assert jnp.all(obs[2] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])) - assert jnp.all(obs[3] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])) - assert jnp.all(obs[4] == jnp.bool_([0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0])) - assert jnp.all(obs[5] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0])) - assert jnp.all(obs[6] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0])) - assert jnp.all(obs[7] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1])) - assert jnp.all(obs[8] == jnp.bool_([0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1])) - assert jnp.all(obs[9] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1])) - assert jnp.all(obs[10] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])) - assert jnp.all(obs[11] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])) - assert jnp.all(obs[12] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1])) - assert jnp.all(obs[13] == jnp.bool_([0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0])) - assert jnp.all(obs[14] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])) + assert obs.shape[0] == 11 + assert obs.shape[1] == 15 + assert jnp.all(obs[:, 0] == jnp.bool_([0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0])) + assert jnp.all(obs[:, 1] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0])) + assert jnp.all(obs[:, 2] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])) + assert jnp.all(obs[:, 3] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])) + assert jnp.all(obs[:, 4] == jnp.bool_([0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0])) + assert jnp.all(obs[:, 5] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0])) + assert jnp.all(obs[:, 6] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0])) + assert jnp.all(obs[:, 7] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1])) + assert 
jnp.all(obs[:, 8] == jnp.bool_([0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1])) + assert jnp.all(obs[:, 9] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1])) + assert jnp.all(obs[:, 10] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])) + assert jnp.all(obs[:, 11] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])) + assert jnp.all(obs[:, 12] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1])) + assert jnp.all(obs[:, 13] == jnp.bool_([0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0])) + assert jnp.all(obs[:, 14] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])) def test_api():