diff --git a/README.md b/README.md
index a2c938ca1..957dda499 100644
--- a/README.md
+++ b/README.md
@@ -114,7 +114,7 @@ Use `pgx.available_envs() -> Tuple[EnvId]` to see the list of currently availabl
|MinAtar/SpaceInvaders
`"minatar-space_invaders"` || `v0` | *Alien shooter game, dodge bullets.* |
|Othello
`"othello"` || `v0` | *Flip and conquer opponent's pieces.* |
|Shogi
`"shogi"` | | `v0` | *Japanese chess with captured pieces.* |
-|Sparrow Mahjong
`"sparrow_mahjong"` || `v0` | *A simplified, children-friendly Mahjong.* |
+|Sparrow Mahjong
`"sparrow_mahjong"` || `v1` | *A simplified, children-friendly Mahjong.* |
|Tic-tac-toe
`"tic_tac_toe"` || `v0` | *Three in a row wins.* |
- Mahjong environments are under development 🚧 If you have any requests for new environments, please let us know by [opening an issue](https://github.com/sotetsuk/pgx/issues/new)
diff --git a/docs/sparrow_mahjong.md b/docs/sparrow_mahjong.md
index 5ac6276ee..f426d774f 100644
--- a/docs/sparrow_mahjong.md
+++ b/docs/sparrow_mahjong.md
@@ -65,10 +65,10 @@ Pgx implementation is simplified as follows:
| Name | Value |
|:---|:----:|
-| Version | `v0` |
+| Version | `v1` |
| Number of players | `3` |
| Number of actions | `11` |
-| Observation shape | `(15, 11)` |
+| Observation shape | `(11, 15)` |
| Observation type | `bool` |
| Rewards | `[-1, 1]` |
@@ -97,4 +97,4 @@ Terminates when either player wins or the wall becomes empty.
## Version History
-- `v0` : Initial release (v1.0.0)
\ No newline at end of file
+- `v1` : Initial release (v1.0.0)
\ No newline at end of file
diff --git a/pgx/sparrow_mahjong.py b/pgx/sparrow_mahjong.py
index 7bf655ce1..d96b8b7da 100644
--- a/pgx/sparrow_mahjong.py
+++ b/pgx/sparrow_mahjong.py
@@ -162,7 +162,7 @@ def id(self) -> v1.EnvId:
@property
def version(self) -> str:
- return "v0"
+ return "v1"
@property
def num_players(self) -> int:
@@ -498,7 +498,7 @@ def _observe(state: State, player_id: jnp.ndarray):
),
lambda: obs,
)
- return obs
+ return jnp.transpose(obs)
def _tile_type_to_str(tile_type) -> str:
diff --git a/tests/test_sparrow_mahjong.py b/tests/test_sparrow_mahjong.py
index 4e6699a43..48fc9204c 100644
--- a/tests/test_sparrow_mahjong.py
+++ b/tests/test_sparrow_mahjong.py
@@ -500,17 +500,17 @@ def test_observe():
state = step(state, jnp.int32(1))
print(_to_str(state))
obs = observe(state, player_id=jnp.int8(2))
- assert obs.shape[0] == 15
- assert obs.shape[1] == 11
- assert jnp.all(obs[0] == jnp.bool_([0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0]))
- assert jnp.all(obs[1] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]))
- assert jnp.all(obs[2] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
- assert jnp.all(obs[3] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
- assert jnp.all(obs[4] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0]))
- assert jnp.all(obs[5] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]))
- assert jnp.all(obs[6] == jnp.bool_([0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
- assert jnp.all(obs[7] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
- assert jnp.all(obs[8] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
+ assert obs.shape[0] == 11
+ assert obs.shape[1] == 15
+ assert jnp.all(obs[:, 0] == jnp.bool_([0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0]))
+ assert jnp.all(obs[:, 1] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]))
+ assert jnp.all(obs[:, 2] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
+ assert jnp.all(obs[:, 3] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
+ assert jnp.all(obs[:, 4] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0]))
+ assert jnp.all(obs[:, 5] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]))
+ assert jnp.all(obs[:, 6] == jnp.bool_([0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
+ assert jnp.all(obs[:, 7] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
+ assert jnp.all(obs[:, 8] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
seed = 5
key = jax.random.PRNGKey(seed)
@@ -529,59 +529,59 @@ def test_observe():
[1] 1 2*3 4 5 : r*_ _ _ _ _ _ _ _ _
"""
obs = observe(state, player_id=jnp.int8(0))
- assert obs.shape[0] == 15
- assert obs.shape[1] == 11
- assert jnp.all(obs[0] == jnp.bool_([1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0]))
- assert jnp.all(obs[1] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
- assert jnp.all(obs[2] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
- assert jnp.all(obs[3] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
- assert jnp.all(obs[4] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
- assert jnp.all(obs[5] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]))
- assert jnp.all(obs[6] == jnp.bool_([0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1]))
- assert jnp.all(obs[7] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]))
- assert jnp.all(obs[8] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]))
- assert jnp.all(obs[9] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]))
- assert jnp.all(obs[10] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]))
- assert jnp.all(obs[11] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
- assert jnp.all(obs[12] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]))
- assert jnp.all(obs[13] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
- assert jnp.all(obs[14] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
+ assert obs.shape[0] == 11
+ assert obs.shape[1] == 15
+ assert jnp.all(obs[:, 0] == jnp.bool_([1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0]))
+ assert jnp.all(obs[:, 1] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
+ assert jnp.all(obs[:, 2] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
+ assert jnp.all(obs[:, 3] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
+ assert jnp.all(obs[:, 4] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
+ assert jnp.all(obs[:, 5] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]))
+ assert jnp.all(obs[:, 6] == jnp.bool_([0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1]))
+ assert jnp.all(obs[:, 7] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]))
+ assert jnp.all(obs[:, 8] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]))
+ assert jnp.all(obs[:, 9] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]))
+ assert jnp.all(obs[:, 10] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]))
+ assert jnp.all(obs[:, 11] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
+ assert jnp.all(obs[:, 12] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]))
+ assert jnp.all(obs[:, 13] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
+ assert jnp.all(obs[:, 14] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
obs = observe(state, player_id=jnp.int8(1))
- assert obs.shape[0] == 15
- assert obs.shape[1] == 11
- assert jnp.all(obs[0] == jnp.bool_([1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0]))
- assert jnp.all(obs[1] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
- assert jnp.all(obs[2] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
- assert jnp.all(obs[3] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
- assert jnp.all(obs[4] == jnp.bool_([0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
- assert jnp.all(obs[5] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]))
- assert jnp.all(obs[6] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]))
- assert jnp.all(obs[7] == jnp.bool_([0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1]))
- assert jnp.all(obs[8] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]))
- assert jnp.all(obs[9] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]))
- assert jnp.all(obs[10] == jnp.bool_([0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0]))
- assert jnp.all(obs[11] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
- assert jnp.all(obs[12] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]))
- assert jnp.all(obs[13] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]))
- assert jnp.all(obs[14] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
+ assert obs.shape[0] == 11
+ assert obs.shape[1] == 15
+ assert jnp.all(obs[:, 0] == jnp.bool_([1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0]))
+ assert jnp.all(obs[:, 1] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
+ assert jnp.all(obs[:, 2] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
+ assert jnp.all(obs[:, 3] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
+ assert jnp.all(obs[:, 4] == jnp.bool_([0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
+ assert jnp.all(obs[:, 5] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]))
+ assert jnp.all(obs[:, 6] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]))
+ assert jnp.all(obs[:, 7] == jnp.bool_([0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1]))
+ assert jnp.all(obs[:, 8] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]))
+ assert jnp.all(obs[:, 9] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]))
+ assert jnp.all(obs[:, 10] == jnp.bool_([0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0]))
+ assert jnp.all(obs[:, 11] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
+ assert jnp.all(obs[:, 12] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]))
+ assert jnp.all(obs[:, 13] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]))
+ assert jnp.all(obs[:, 14] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
obs = observe(state, player_id=jnp.int8(2))
- assert obs.shape[0] == 15
- assert obs.shape[1] == 11
- assert jnp.all(obs[0] == jnp.bool_([0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0]))
- assert jnp.all(obs[1] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]))
- assert jnp.all(obs[2] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
- assert jnp.all(obs[3] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
- assert jnp.all(obs[4] == jnp.bool_([0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0]))
- assert jnp.all(obs[5] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]))
- assert jnp.all(obs[6] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]))
- assert jnp.all(obs[7] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]))
- assert jnp.all(obs[8] == jnp.bool_([0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1]))
- assert jnp.all(obs[9] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]))
- assert jnp.all(obs[10] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
- assert jnp.all(obs[11] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
- assert jnp.all(obs[12] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]))
- assert jnp.all(obs[13] == jnp.bool_([0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0]))
- assert jnp.all(obs[14] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
+ assert obs.shape[0] == 11
+ assert obs.shape[1] == 15
+ assert jnp.all(obs[:, 0] == jnp.bool_([0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0]))
+ assert jnp.all(obs[:, 1] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]))
+ assert jnp.all(obs[:, 2] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
+ assert jnp.all(obs[:, 3] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
+ assert jnp.all(obs[:, 4] == jnp.bool_([0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0]))
+ assert jnp.all(obs[:, 5] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]))
+ assert jnp.all(obs[:, 6] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]))
+ assert jnp.all(obs[:, 7] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]))
+ assert jnp.all(obs[:, 8] == jnp.bool_([0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1]))
+ assert jnp.all(obs[:, 9] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]))
+ assert jnp.all(obs[:, 10] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
+ assert jnp.all(obs[:, 11] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
+ assert jnp.all(obs[:, 12] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]))
+ assert jnp.all(obs[:, 13] == jnp.bool_([0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0]))
+ assert jnp.all(obs[:, 14] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
def test_api():