diff --git a/pgx/_mahjong/_mahjong.py b/pgx/_mahjong/_mahjong.py deleted file mode 100644 index b1f1d077c..000000000 --- a/pgx/_mahjong/_mahjong.py +++ /dev/null @@ -1,359 +0,0 @@ -# Copyright 2023 The Pgx Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import jax -import jax.numpy as jnp - -import pgx.v1 as v1 -from pgx._mahjong._action import Action -from pgx._mahjong._hand import Hand -from pgx._mahjong._meld import Meld -from pgx._src.struct import dataclass - -FALSE = jnp.bool_(False) -TRUE = jnp.bool_(True) -NUM_ACTION = 78 - - -@dataclass -class State(v1.State): - current_player: jnp.ndarray = jnp.int8(0) # actionを行うplayer - observation: jnp.ndarray = jnp.int8(0) - rewards: jnp.ndarray = jnp.float32([0.0, 0.0, 0.0, 0.0]) - terminated: jnp.ndarray = FALSE - truncated: jnp.ndarray = FALSE - legal_action_mask: jnp.ndarray = jnp.zeros(NUM_ACTION, dtype=jnp.bool_) - _rng_key: jax.random.KeyArray = jax.random.PRNGKey(0) - _step_count: jnp.ndarray = jnp.int32(0) - # --- Mahjong specific --- - deck: jnp.ndarray = jnp.zeros(136, dtype=jnp.int8) - - # 次に引く牌のindex - next_deck_ix: jnp.ndarray = jnp.int8(135) - - # 各プレイヤーの手牌. 長さ34で、数字は持っている牌の数 - hand: jnp.ndarray = jnp.zeros((4, 34), dtype=jnp.int8) - - doras: jnp.ndarray = jnp.zeros(4, dtype=jnp.int8) - - num_kan: jnp.ndarray = jnp.int8(0) - - # 直前に捨てられてron,pon,chi,kanの対象になっている牌. 存在しなければ-1 - target: jnp.ndarray = jnp.int8(0) - - # 手牌が3n+2枚のplayerが直前に引いた牌. 存在しなければ-1 - last_draw: jnp.ndarray = jnp.int8(0) - - # 最後にプレイしたプレイヤ(ron,pon,chi,kanの対象). - last_player: jnp.ndarray = jnp.int8(0) - - # state.current_player がリーチ宣言してから, その直後の打牌が通るまでTrue - riichi_declared: jnp.ndarray = FALSE - - # 各playerのリーチが成立しているかどうか - riichi: jnp.ndarray = jnp.zeros(4, dtype=jnp.bool_) - - # 各playerの副露回数 - n_meld: jnp.ndarray = jnp.zeros(4, dtype=jnp.int8) - - # melds[i][j]: player i のj回目の副露(j=1,2,3,4). 存在しなければ0 - melds: jnp.ndarray = jnp.zeros((4, 4), dtype=jnp.int32) - - is_menzen: jnp.ndarray = jnp.zeros(4, dtype=jnp.bool_) - - # pon[i][j]: player i がjをポンを所有している場合, src << 2 | index. or 0 - pon: jnp.ndarray = jnp.zeros((4, 34), dtype=jnp.int32) - - @property - def env_id(self) -> v1.EnvId: - # TODO add envid - return "mahjong" # type:ignore - - -class Mahjong(v1.Env): - def __init__(self): - super().__init__() - - def _init(self, key: jax.random.KeyArray) -> State: - return _init(key) - - def _step(self, state: v1.State, action: jnp.ndarray) -> State: - assert isinstance(state, State) - return _step(state, action, player=state.current_player) - - def _observe(self, state: v1.State, player_id: jnp.ndarray) -> jnp.ndarray: - assert isinstance(state, State) - return _observe(state, player_id) - - @property - def id(self) -> v1.EnvId: - # TODO add envid - return "mahjong" # type:ignore - - @property - def version(self) -> str: - return "alpha" - - @property - def num_players(self) -> int: - return 4 - - -def _init(rng: jax.random.KeyArray) -> State: - rng, subkey = jax.random.split(rng) - current_player = jnp.int8(jax.random.bernoulli(subkey)) - last_player = jnp.int8(-1) - deck = jax.random.permutation(rng, jnp.arange(136) // 4) - init_hand = Hand.make_init_hand(deck) - state = State( # type:ignore - current_player=current_player, - last_player=last_player, - deck=deck, - hand=init_hand, - next_deck_ix=135 - 13 * 4, - ) - return _draw(state) - - -def _step(state: State, action: jnp.ndarray, player: jnp.ndarray) -> State: - # TODO - # - Actionの処理 - # - meld - # - riichi - # - ron, tsumo - # - 勝利条件確認 - # - playerどうするか - # - lax.switch使った方が良いんだろうけど、簡単なうちはcondで分岐させる - - state = state.replace(current_player=player) # type:ignore - - discard = (action < 34) | (action == 68) - self_kan = ~discard & (action < 68) - state = jax.lax.cond( - discard, - lambda: _discard(state, action), - lambda: state, - ) - state = jax.lax.cond( - self_kan, - lambda: _selfkan(state, action), - lambda: state, - ) - state = jax.lax.cond( - (~discard) & (~self_kan), - lambda: jax.lax.switch( - action - 68, - [ - lambda: _discard(state, action), - lambda: _riichi(state), - lambda: _tsumo(state), - lambda: _ron(state), - lambda: _pon(state), - lambda: _minkan(state), - lambda: _chi(state, player, Action.CHI_L), - lambda: _chi(state, player, Action.CHI_M), - lambda: _chi(state, player, Action.CHI_R), - lambda: _draw(state), - ], - ), - lambda: state, - ) - - return state - - -def _draw(state: State): - new_tile = state.deck[state.next_deck_ix] - next_deck_ix = state.next_deck_ix - 1 - hand = state.hand.at[state.current_player].set( - Hand.add(state.hand[state.current_player], new_tile) - ) - - legal_action_mask = jnp.zeros(NUM_ACTION, dtype=jnp.bool_) - legal_action_mask = legal_action_mask.at[:34].set( - hand[state.current_player] > 0 - ) - legal_action_mask = legal_action_mask.at[new_tile].set(FALSE) - legal_action_mask = legal_action_mask.at[Action.TSUMOGIRI].set(TRUE) - - return state.replace( # type:ignore - current_player=state.current_player, - next_deck_ix=next_deck_ix, - hand=hand, - last_draw=new_tile, - legal_action_mask=legal_action_mask, - ) - - -def _discard(state: State, tile: jnp.ndarray): - tile = jax.lax.select(tile == 68, state.last_draw, tile) - hand = state.hand.at[state.current_player].set( - Hand.sub(state.hand[state.current_player], tile) - ) - state = state.replace( # type:ignore - target=jnp.int8(tile), last_draw=-1, hand=hand - ) - - return _draw(state) - - -def _append_meld(state: State, meld, player): - melds = state.melds.at[(player, state.n_meld[player])].set(meld) - n_meld = state.n_meld.at[player].add(1) - return state.replace(melds=melds, n_meld=n_meld) # type:ignore - - -def _selfkan(state: State, action): - target = action - 34 - pon = state.pon[(state.current_player, target)] - state = jax.lax.cond( - pon == 0, - lambda: _ankan(state, target), - lambda: _kakan(state, target, pon >> 2, pon & 0b11), - ) - - # 嶺上牌 - rinshan_tile = state.deck[state.next_deck_ix] - next_deck_ix = state.next_deck_ix - 1 - hand = state.hand.at[state.current_player].set( - Hand.add(state.hand[state.current_player], rinshan_tile) - ) - return state.replace( # type:ignore - next_deck_ix=next_deck_ix, last_draw=rinshan_tile, hand=hand - ) - - -def _ankan(state: State, target): - curr_player = state.current_player - meld = Meld.init(target + 34, target, src=0) - state = _append_meld(state, meld, curr_player) - hand = state.hand.at[curr_player].set( - Hand.ankan(state.hand[curr_player], target) - ) - # TODO: 国士無双ロンの受付 - - return state.replace(hand=hand) # type:ignore - - -def _kakan(state: State, target, pon_src, pon_idx): - melds = state.melds.at[(state.current_player, pon_idx)].set( - Meld.init(target + 34, target, pon_src) - ) - hand = state.hand.at[state.current_player].set( - Hand.kakan(state.hand[state.current_player], target) - ) - pon = state.pon.at[(state.current_player, target)].set(0) - # TODO: 槍槓の受付 - - return state.replace(melds=melds, hand=hand, pon=pon) # type:ignore - - -def _accept_riichi(state: State) -> State: - riichi = state.riichi.at[state.current_player].set( - state.riichi[state.current_player] | state.riichi_declared - ) - return state.replace(riichi=riichi, riichi_declared=FALSE) # type:ignore - - -def _minkan(state: State): - state = _accept_riichi(state) - src = (state.last_player - state.current_player) % 4 - meld = Meld.init(Action.MINKAN, state.target, src) - state = _append_meld(state, meld, state.current_player) - hand = state.hand.at[state.current_player].set( - Hand.minkan(state.hand[state.current_player], state.target) - ) - is_menzen = state.is_menzen.at[state.current_player].set(FALSE) - - rinshan_tile = state.deck[state.next_deck_ix] - next_deck_ix = state.next_deck_ix - 1 - hand = hand.at[state.current_player].set( - Hand.add(state.hand[state.current_player], rinshan_tile) - ) - return state.replace( # type:ignore - target=jnp.int8(-1), - is_menzen=is_menzen, - next_deck_ix=next_deck_ix, - last_draw=rinshan_tile, - hand=hand, - ) - - -def _pon(state: State): - state = _accept_riichi(state) - src = (state.last_player - state.current_player) % 4 - meld = Meld.init(Action.PON, state.target, src) - state = _append_meld(state, meld, state.current_player) - hand = state.hand.at[state.current_player].set( - Hand.pon(state.hand[state.current_player], state.target) - ) - is_menzen = state.is_menzen.at[state.current_player].set(FALSE) - pon = state.pon.at[(state.current_player, state.target)].set( - src << 2 | state.n_meld[state.current_player] - 1 - ) - - return state.replace( # type:ignore - target=jnp.int8(-1), - is_menzen=is_menzen, - pon=pon, - hand=hand, - ) - - -def _chi(state: State, player, action: int): - state = _accept_riichi(state) - meld = Meld.init(action, state.target, src=3) - state = _append_meld(state, meld, player) - hand = state.hand.at[player].set( - Hand.chi(state.hand[player], state.target, action) - ) - is_menzen = state.is_menzen.at[player].set(False) - legal_action_mask = jnp.zeros(NUM_ACTION, dtype=jnp.bool_) - legal_action_mask = legal_action_mask.at[:34].set(hand[player] > 0) - - return state.replace( # type:ignore - current_player=player, - target=jnp.int8(-1), - is_menzen=is_menzen, - hand=hand, - legal_action_mask=legal_action_mask, - ) - - -def _pass(state: State): - # pon -> chi - - # ponでpassした場合 - - # chiでpassした場合 - - # kanでpassした場合 - - ... - - -def _riichi(state: State): - ... - - -def _tsumo(state: State): - ... - - -def _ron(state: State): - ... - - -def _observe(state: State, player_id: jnp.ndarray) -> jnp.ndarray: - ... diff --git a/pgx/_mahjong/cache/__init__.py b/pgx/_mahjong/cache/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/pgx/_src/dwg/mahjong.py b/pgx/_src/dwg/mahjong.py index 855954625..af14c6842 100644 --- a/pgx/_src/dwg/mahjong.py +++ b/pgx/_src/dwg/mahjong.py @@ -1,7 +1,7 @@ -from pgx._mahjong._action import Action -from pgx._mahjong._mahjong2 import State as MahjongState -from pgx._mahjong._meld import Meld from pgx._src.dwg.mahjong_tile import TilePath +from pgx.mahjong._action import Action +from pgx.mahjong._env import State as MahjongState +from pgx.mahjong._meld import Meld path_list = TilePath.str_list tile_w = 30 @@ -51,8 +51,8 @@ def _make_mahjong_dwg(dwg, state: MahjongState, config): kanji = ["", "一", "二", "三", "四", "五", "六", "七", "八", "九", "十"] ro = state._round round = f"{wind[ro//4]}{kanji[ro%4+1]}局" - if state.honba > 0: - round += f"{kanji[state.honba]}本場" + if state._honba > 0: + round += f"{kanji[state._honba]}本場" fontsize = 20 y = -25 @@ -73,7 +73,7 @@ def _make_mahjong_dwg(dwg, state: MahjongState, config): dora_scale = 0.6 x = (BOARD_WIDTH * GRID_SIZE) / 2 - tile_w * dora_scale * 2.5 y = (BOARD_WIDTH * GRID_SIZE) / 2 - 15 - for _x, dora in enumerate(state.doras): + for _x, dora in enumerate(state._doras): if dora == -1: dora = 34 p = dwg.path(d=path_list[dora]) @@ -92,7 +92,7 @@ def _make_mahjong_dwg(dwg, state: MahjongState, config): board_g.add(p) board_g.add( dwg.text( - text=f"x {state.next_deck_ix-14+1}", + text=f"x {state._next_deck_ix-14+1}", insert=(x + tile_w * yama_scale + 5, y + tile_h * yama_scale - 5), fill=color_set.text_color, font_size=f"{fontsize}px", @@ -115,7 +115,13 @@ def _make_mahjong_dwg(dwg, state: MahjongState, config): def _make_players_dwg( - dwg, state, i, color_set, BOARD_WIDTH, BOARD_HEIGHT, GRID_SIZE + dwg, + state: MahjongState, + i, + color_set, + BOARD_WIDTH, + BOARD_HEIGHT, + GRID_SIZE, ): players_g = dwg.g( style="stroke:#000000;stroke-width:0.01mm;fill:#000000", @@ -128,7 +134,7 @@ def _make_players_dwg( fontsize = 22 players_g.add( dwg.text( - text=wind[(i - state.oya) % 4], + text=wind[(i - state._oya) % 4], insert=(x, y), fill=color_set.text_color, font_size=f"{fontsize}px", @@ -138,7 +144,7 @@ def _make_players_dwg( # score fontsize = 20 - score = str(int(state.score[i]) * 100) + score = str(int(state._score[i]) * 100) y = 70 players_g.add( dwg.text( @@ -157,7 +163,7 @@ def _make_players_dwg( width = 100 height = 10 y = 75 - if state.riichi[i]: + if state._riichi[i]: players_g.add( dwg.rect( ( @@ -185,7 +191,7 @@ def _make_players_dwg( # hand offset = 0 - hand = state.hand[i] + hand = state._hand[i] for tile, num in enumerate(hand): for _ in range(num): p = dwg.path(d=path_list[tile]) @@ -196,7 +202,7 @@ def _make_players_dwg( offset += tile_w # meld - for meld in state.melds[i]: + for meld in state._melds[i]: if meld == 0: continue if Meld.action(meld) == Action.PON: @@ -218,7 +224,7 @@ def _make_players_dwg( x = BOARD_WIDTH * GRID_SIZE / 2 - 3 * tile_w y = 450 - river = state.river[i] + river = state._river[i] for river_ix, tile in enumerate(river): fill = "black" if (tile >> 7) & 0b1: diff --git a/pgx/mahjong/__init__.py b/pgx/mahjong/__init__.py new file mode 100644 index 000000000..a2fcf891a --- /dev/null +++ b/pgx/mahjong/__init__.py @@ -0,0 +1,3 @@ +from ._env import Mahjong + +__all__ = ["Mahjong"] diff --git a/pgx/_mahjong/_action.py b/pgx/mahjong/_action.py similarity index 100% rename from pgx/_mahjong/_action.py rename to pgx/mahjong/_action.py diff --git a/pgx/_mahjong/_mahjong2.py b/pgx/mahjong/_env.py similarity index 63% rename from pgx/_mahjong/_mahjong2.py rename to pgx/mahjong/_env.py index 51995cdcf..4ada626bd 100644 --- a/pgx/_mahjong/_mahjong2.py +++ b/pgx/mahjong/_env.py @@ -16,11 +16,11 @@ import jax.numpy as jnp import pgx.v1 as v1 -from pgx._mahjong._action import Action -from pgx._mahjong._hand import Hand -from pgx._mahjong._meld import Meld -from pgx._mahjong._yaku import Yaku from pgx._src.struct import dataclass +from pgx.mahjong._action import Action +from pgx.mahjong._hand import Hand +from pgx.mahjong._meld import Meld +from pgx.mahjong._yaku import Yaku FALSE = jnp.bool_(False) TRUE = jnp.bool_(True) @@ -39,71 +39,71 @@ class State(v1.State): _step_count: jnp.ndarray = jnp.int32(0) # --- Mahjong specific --- _round: jnp.ndarray = jnp.int8(0) - honba: jnp.ndarray = jnp.int8(0) + _honba: jnp.ndarray = jnp.int8(0) # 東1局の親 - # 各局の親は (state.oya+round)%4 - oya: jnp.ndarray = jnp.int8(0) + # 各局の親は (state._oya+round)%4 + _oya: jnp.ndarray = jnp.int8(0) # 点数(百の位から) - score: jnp.ndarray = jnp.full(4, 250, dtype=jnp.float32) + _score: jnp.ndarray = jnp.full(4, 250, dtype=jnp.float32) # 嶺上 ドラ カンドラ # ... 13 11 9 7 5 3 1 # ... 12 10 8 6 4 2 0 - deck: jnp.ndarray = jnp.zeros(136, dtype=jnp.int8) + _deck: jnp.ndarray = jnp.zeros(136, dtype=jnp.int8) # 次に引く牌のindex - next_deck_ix: jnp.ndarray = jnp.int8(135 - 13 * 4) + _next_deck_ix: jnp.ndarray = jnp.int8(135 - 13 * 4) # 各playerの手牌. 長さ34で、数字は持っている牌の数 - hand: jnp.ndarray = jnp.zeros((4, 34), dtype=jnp.int8) + _hand: jnp.ndarray = jnp.zeros((4, 34), dtype=jnp.int8) # 河 # int8 # 0b 0 0 0 0 0 0 0 0 # 灰色|リーチ| 牌(0-33) - river: jnp.ndarray = 34 * jnp.ones((4, 4 * 6), dtype=jnp.uint8) + _river: jnp.ndarray = 34 * jnp.ones((4, 4 * 6), dtype=jnp.uint8) # 各playerの河の数 - n_river: jnp.ndarray = jnp.zeros(4, dtype=jnp.int8) + _n_river: jnp.ndarray = jnp.zeros(4, dtype=jnp.int8) # ドラ - doras: jnp.ndarray = jnp.zeros(5, dtype=jnp.int8) + _doras: jnp.ndarray = jnp.zeros(5, dtype=jnp.int8) # カンの回数=追加ドラ枚数 - n_kan: jnp.ndarray = jnp.int8(0) + _n_kan: jnp.ndarray = jnp.int8(0) # 直前に捨てられてron,pon,chiの対象になっている牌. 存在しなければ-1 - target: jnp.ndarray = jnp.int8(-1) + _target: jnp.ndarray = jnp.int8(-1) # 手牌が3n+2枚のplayerが直前に引いた牌. 存在しなければ-1 - last_draw: jnp.ndarray = jnp.int8(0) + _last_draw: jnp.ndarray = jnp.int8(0) # 最後のプレイヤー. ron,pon,chiの対象 - last_player: jnp.ndarray = jnp.int8(0) + _last_player: jnp.ndarray = jnp.int8(0) # 打牌の後に競合する副露が生じた場合用 # pon player, kan player, chi player # b0 00 00 00 - furo_check_num: jnp.ndarray = jnp.uint8(0) + _furo_check_num: jnp.ndarray = jnp.uint8(0) # state.current_player がリーチ宣言してから, その直後の打牌が通るまでTrue - riichi_declared: jnp.ndarray = FALSE + _riichi_declared: jnp.ndarray = FALSE # 各playerのリーチが成立しているかどうか - riichi: jnp.ndarray = jnp.zeros(4, dtype=jnp.bool_) + _riichi: jnp.ndarray = jnp.zeros(4, dtype=jnp.bool_) # 各playerの副露回数 - n_meld: jnp.ndarray = jnp.zeros(4, dtype=jnp.int8) + _n_meld: jnp.ndarray = jnp.zeros(4, dtype=jnp.int8) # melds[i][j]: player i のj回目の副露(j=1,2,3,4). 存在しなければ0 - melds: jnp.ndarray = jnp.zeros((4, 4), dtype=jnp.int32) + _melds: jnp.ndarray = jnp.zeros((4, 4), dtype=jnp.int32) - is_menzen: jnp.ndarray = jnp.zeros(4, dtype=jnp.bool_) + _is_menzen: jnp.ndarray = jnp.zeros(4, dtype=jnp.bool_) # pon[i][j]: player i がjをポンを所有している場合, src << 2 | index. or 0 - pon: jnp.ndarray = jnp.zeros((4, 34), dtype=jnp.int32) + _pon: jnp.ndarray = jnp.zeros((4, 34), dtype=jnp.int32) @property def env_id(self) -> v1.EnvId: @@ -145,26 +145,26 @@ def decode_state(data: dict): _rng_key=jnp.array(data["_rng_key"]), _step_count=jnp.int32(data["_step_count"]), _round=jnp.int8(data["_round"]), - honba=jnp.int8(data["honba"]), - oya=jnp.int8(data["oya"]), - score=jnp.array(data["score"], dtype=jnp.float32), - deck=jnp.array(data["deck"], dtype=jnp.int8), - next_deck_ix=jnp.int8(data["next_deck_ix"]), - hand=jnp.array(data["hand"], dtype=jnp.int8), - river=jnp.array(data["river"], dtype=jnp.uint8), - n_river=jnp.array(data["n_river"], dtype=jnp.int8), - doras=jnp.array(data["doras"], dtype=jnp.int8), - n_kan=jnp.int8(data["n_kan"]), - target=jnp.int8(data["target"]), - last_draw=jnp.int8(data["last_draw"]), - last_player=jnp.int8(data["last_player"]), - furo_check_num=jnp.uint8(data["furo_check_num"]), - riichi_declared=jnp.bool_(data["riichi_declared"]), - riichi=jnp.array(data["riichi"], dtype=jnp.bool_), - n_meld=jnp.array(data["n_meld"], dtype=jnp.int8), - melds=jnp.array(data["melds"], dtype=jnp.int32), - is_menzen=jnp.array(data["is_menzen"], dtype=jnp.bool_), - pon=jnp.array(data["pon"], dtype=jnp.int32), + _honba=jnp.int8(data["_honba"]), + _oya=jnp.int8(data["_oya"]), + _score=jnp.array(data["_score"], dtype=jnp.float32), + _deck=jnp.array(data["_deck"], dtype=jnp.int8), + _next_deck_ix=jnp.int8(data["_next_deck_ix"]), + _hand=jnp.array(data["_hand"], dtype=jnp.int8), + _river=jnp.array(data["_river"], dtype=jnp.uint8), + _n_river=jnp.array(data["_n_river"], dtype=jnp.int8), + _doras=jnp.array(data["_doras"], dtype=jnp.int8), + _n_kan=jnp.int8(data["_n_kan"]), + _target=jnp.int8(data["_target"]), + _last_draw=jnp.int8(data["_last_draw"]), + _last_player=jnp.int8(data["_last_player"]), + _furo_check_num=jnp.uint8(data["_furo_check_num"]), + _riichi_declared=jnp.bool_(data["_riichi_declared"]), + _riichi=jnp.array(data["_riichi"], dtype=jnp.bool_), + _n_meld=jnp.array(data["_n_meld"], dtype=jnp.int8), + _melds=jnp.array(data["_melds"], dtype=jnp.int32), + _is_menzen=jnp.array(data["_is_menzen"], dtype=jnp.bool_), + _pon=jnp.array(data["_pon"], dtype=jnp.int32), ) with open(path, "r") as f: @@ -184,25 +184,25 @@ def __eq__(self, b): and (a._rng_key == b._rng_key).all() and a._step_count == b._step_count and a._round == b._round - and a.honba == b.honba - and a.oya == b.oya - and (a.score == b.score).all() - and (a.deck == b.deck).all() - and a.next_deck_ix == b.next_deck_ix - and (a.hand == b.hand).all() - and (a.river == b.river).all() - and (a.doras == b.doras).all() - and a.n_kan == b.n_kan - and a.target == b.target - and a.last_draw == b.last_draw - and a.last_player == b.last_player - and a.furo_check_num == b.furo_check_num - and a.riichi_declared == b.riichi_declared - and (a.riichi == b.riichi).all() - and (a.n_meld == b.n_meld).all() - and (a.melds == b.melds).all() - and (a.is_menzen == b.is_menzen).all() - and (a.pon == b.pon).all() + and a._honba == b._honba + and a._oya == b._oya + and (a._score == b._score).all() + and (a._deck == b._deck).all() + and a._next_deck_ix == b._next_deck_ix + and (a._hand == b._hand).all() + and (a._river == b._river).all() + and (a._doras == b._doras).all() + and a._n_kan == b._n_kan + and a._target == b._target + and a._last_draw == b._last_draw + and a._last_player == b._last_player + and a._furo_check_num == b._furo_check_num + and a._riichi_declared == b._riichi_declared + and (a._riichi == b._riichi).all() + and (a._n_meld == b._n_meld).all() + and (a._melds == b._melds).all() + and (a._is_menzen == b._is_menzen).all() + and (a._pon == b._pon).all() ) @@ -228,7 +228,7 @@ def id(self) -> v1.EnvId: @property def version(self) -> str: - return "alpha" + return "beta" @property def num_players(self) -> int: @@ -244,11 +244,11 @@ def _init(rng: jax.random.KeyArray) -> State: doras = jnp.array([deck[9], -1, -1, -1, -1], dtype=jnp.int8) state = State( # type:ignore current_player=current_player, - oya=current_player, - last_player=last_player, - deck=deck, - doras=doras, - hand=init_hand, + _oya=current_player, + _last_player=last_player, + _deck=deck, + _doras=doras, + _hand=init_hand, _rng_key=subkey, ) return _draw(state) @@ -292,22 +292,22 @@ def _step(state: State, action) -> State: def _draw(state: State): state = _accept_riichi(state) c_p = state.current_player - new_tile = state.deck[state.next_deck_ix] - next_deck_ix = state.next_deck_ix - 1 - hand = state.hand.at[c_p].set(Hand.add(state.hand[c_p], new_tile)) + new_tile = state._deck[state._next_deck_ix] + next_deck_ix = state._next_deck_ix - 1 + hand = state._hand.at[c_p].set(Hand.add(state._hand[c_p], new_tile)) legal_action_mask = jax.lax.select( - state.riichi[c_p], + state._riichi[c_p], _make_legal_action_mask_w_riichi(state, hand, c_p, new_tile), _make_legal_action_mask(state, hand, c_p, new_tile), ) return state.replace( # type:ignore - target=jnp.int8(-1), - next_deck_ix=next_deck_ix, - hand=hand, - last_draw=new_tile, - last_player=c_p, + _target=jnp.int8(-1), + _next_deck_ix=next_deck_ix, + _hand=hand, + _last_draw=new_tile, + _last_player=c_p, legal_action_mask=legal_action_mask, ) @@ -320,15 +320,15 @@ def _make_legal_action_mask(state: State, hand, c_p, new_tile): legal_action_mask = legal_action_mask.at[new_tile + 34].set( Hand.can_ankan(hand[c_p], new_tile) | ( - Hand.can_kakan(hand[c_p], new_tile) & state.pon[(c_p, new_tile)] + Hand.can_kakan(hand[c_p], new_tile) & state._pon[(c_p, new_tile)] > 0 ) ) legal_action_mask = legal_action_mask.at[Action.RIICHI].set( jax.lax.cond( - state.riichi[c_p] - | state.is_menzen[c_p] - | (state.next_deck_ix < 13 + 4), + state._riichi[c_p] + | state._is_menzen[c_p] + | (state._next_deck_ix < 13 + 4), lambda: FALSE, lambda: Hand.can_riichi(hand[c_p]), ) @@ -336,13 +336,13 @@ def _make_legal_action_mask(state: State, hand, c_p, new_tile): legal_action_mask = legal_action_mask.at[Action.TSUMO].set( Hand.can_tsumo(hand[c_p]) & Yaku.judge( - state.hand[c_p], - state.melds[c_p], - state.n_meld[c_p], - state.last_draw, - state.riichi[c_p], + state._hand[c_p], + state._melds[c_p], + state._n_meld[c_p], + state._last_draw, + state._riichi[c_p], FALSE, - _dora_array(state, state.riichi[c_p]), + _dora_array(state, state._riichi[c_p]), )[0].any() ) return legal_action_mask @@ -354,13 +354,13 @@ def _make_legal_action_mask_w_riichi(state, hand, c_p, new_tile): legal_action_mask = legal_action_mask.at[Action.TSUMO].set( Hand.can_tsumo(hand[c_p]) & Yaku.judge( - state.hand[c_p], - state.melds[c_p], - state.n_meld[c_p], - state.last_draw, - state.riichi[c_p], + state._hand[c_p], + state._melds[c_p], + state._n_meld[c_p], + state._last_draw, + state._riichi[c_p], FALSE, - _dora_array(state, state.riichi[c_p]), + _dora_array(state, state._riichi[c_p]), )[0].any() ) return legal_action_mask @@ -368,18 +368,18 @@ def _make_legal_action_mask_w_riichi(state, hand, c_p, new_tile): def _discard(state: State, tile: jnp.ndarray): c_p = state.current_player - tile = jnp.where(tile == 68, state.last_draw, tile) + tile = jnp.where(tile == 68, state._last_draw, tile) _tile = jnp.where( - state.riichi_declared, tile | jnp.uint8(0b01000000), tile + state._riichi_declared, tile | jnp.uint8(0b01000000), tile ) - river = state.river.at[c_p, state.n_river[c_p]].set(jnp.uint8(_tile)) - n_river = state.n_river.at[c_p].add(1) - hand = state.hand.at[c_p].set(Hand.sub(state.hand[c_p], tile)) + river = state._river.at[c_p, state._n_river[c_p]].set(jnp.uint8(_tile)) + n_river = state._n_river.at[c_p].add(1) + hand = state._hand.at[c_p].set(Hand.sub(state._hand[c_p], tile)) state = state.replace( # type:ignore - last_draw=jnp.int8(-1), - hand=hand, - river=river, - n_river=n_river, + _last_draw=jnp.int8(-1), + _hand=hand, + _river=river, + _n_river=n_river, ) # ポンとかチーとかがあるか @@ -387,9 +387,9 @@ def _discard(state: State, tile: jnp.ndarray): pon_player = kan_player = ron_player = c_p chi_player = (c_p + 1) % 4 can_chi = ( - Hand.can_chi(state.hand[chi_player], tile, Action.CHI_L) - | Hand.can_chi(state.hand[chi_player], tile, Action.CHI_M) - | Hand.can_chi(state.hand[chi_player], tile, Action.CHI_R) + Hand.can_chi(state._hand[chi_player], tile, Action.CHI_L) + | Hand.can_chi(state._hand[chi_player], tile, Action.CHI_M) + | Hand.can_chi(state._hand[chi_player], tile, Action.CHI_R) ) meld_type = jax.lax.cond( can_chi, @@ -402,7 +402,7 @@ def search(i, tpl): meld_type, pon_player, kan_player, ron_player = tpl player = (c_p + 1 + i) % 4 # 絶対位置 pon_player, meld_type = jax.lax.cond( - Hand.can_pon(state.hand[player], tile), + Hand.can_pon(state._hand[player], tile), lambda: (i, jnp.max(jnp.array([2, meld_type]))), lambda: (pon_player, meld_type), ) @@ -412,15 +412,15 @@ def search(i, tpl): lambda: (kan_player, meld_type), ) ron_player, meld_type = jax.lax.cond( - Hand.can_ron(state.hand[player], tile) + Hand.can_ron(state._hand[player], tile) & Yaku.judge( - state.hand[player], - state.melds[player], - state.n_meld[player], - state.last_draw, - state.riichi[player], + state._hand[player], + state._melds[player], + state._n_meld[player], + state._last_draw, + state._riichi[player], FALSE, - _dora_array(state, state.riichi[player]), + _dora_array(state, state._riichi[player]), )[0].any(), lambda: (i, jnp.max(jnp.array([4, meld_type]))), lambda: (ron_player, meld_type), @@ -442,7 +442,9 @@ def search(i, tpl): (c_p + 1 + ron_player) % 4, ) - rewards = jnp.float32([Hand.is_tenpai(hand) * 100 for hand in state.hand]) + rewards = jnp.float32( + [Hand.is_tenpai(_hand) * 100 for _hand in state._hand] + ) no_meld_state = jax.lax.cond( _is_ryukyoku(state), lambda: state.replace( # type:ignore @@ -452,7 +454,7 @@ def search(i, tpl): lambda: _draw( state.replace( # type:ignore current_player=(c_p + 1) % 4, - target=jnp.int8(-1), + _target=jnp.int8(-1), ) ), ) @@ -463,24 +465,24 @@ def search(i, tpl): lambda: no_meld_state, lambda: state.replace( # type:ignore current_player=chi_player, - last_player=c_p, - target=jnp.int8(tile), - furo_check_num=furo_num & 0b11111100, + _last_player=c_p, + _target=jnp.int8(tile), + _furo_check_num=furo_num & 0b11111100, legal_action_mask=jnp.zeros(NUM_ACTION, dtype=jnp.bool_) .at[Action.CHI_L] - .set(Hand.can_chi(state.hand[chi_player], tile, Action.CHI_L)) + .set(Hand.can_chi(state._hand[chi_player], tile, Action.CHI_L)) .at[Action.CHI_M] - .set(Hand.can_chi(state.hand[chi_player], tile, Action.CHI_M)) + .set(Hand.can_chi(state._hand[chi_player], tile, Action.CHI_M)) .at[Action.CHI_R] - .set(Hand.can_chi(state.hand[chi_player], tile, Action.CHI_R)) + .set(Hand.can_chi(state._hand[chi_player], tile, Action.CHI_R)) .at[Action.PASS] .set(TRUE), ), lambda: state.replace( # type:ignore current_player=pon_player, - last_player=c_p, - target=jnp.int8(tile), - furo_check_num=furo_num & 0b11110011, + _last_player=c_p, + _target=jnp.int8(tile), + _furo_check_num=furo_num & 0b11110011, legal_action_mask=jnp.zeros(NUM_ACTION, dtype=jnp.bool_) .at[Action.PON] .set(TRUE) @@ -489,9 +491,9 @@ def search(i, tpl): ), lambda: state.replace( # type:ignore current_player=kan_player, - last_player=c_p, - target=jnp.int8(tile), - furo_check_num=furo_num & 0b11001111, + _last_player=c_p, + _target=jnp.int8(tile), + _furo_check_num=furo_num & 0b11001111, legal_action_mask=jnp.zeros(NUM_ACTION, dtype=jnp.bool_) .at[Action.MINKAN] .set(Hand.can_minkan(hand[kan_player], tile)) @@ -500,9 +502,9 @@ def search(i, tpl): ), lambda: state.replace( # type:ignore current_player=ron_player, - last_player=c_p, - target=jnp.int8(tile), - furo_check_num=furo_num, + _last_player=c_p, + _target=jnp.int8(tile), + _furo_check_num=furo_num, legal_action_mask=jnp.zeros(NUM_ACTION, dtype=jnp.bool_) .at[Action.RON] .set(TRUE) @@ -514,14 +516,14 @@ def search(i, tpl): def _append_meld(state: State, meld, player): - melds = state.melds.at[(player, state.n_meld[player])].set(meld) - n_meld = state.n_meld.at[player].add(1) - return state.replace(melds=melds, n_meld=n_meld) # type:ignore + melds = state._melds.at[(player, state._n_meld[player])].set(meld) + n_meld = state._n_meld.at[player].add(1) + return state.replace(_melds=melds, _n_meld=n_meld) # type:ignore def _selfkan(state: State, action): target = action - 34 - pon = state.pon[(state.current_player, target)] + pon = state._pon[(state.current_player, target)] state = jax.lax.cond( pon == 0, lambda: _ankan(state, target), @@ -529,10 +531,10 @@ def _selfkan(state: State, action): ) # 嶺上牌 - rinshan_tile = state.deck[state.next_deck_ix] - next_deck_ix = state.next_deck_ix - 1 - hand = state.hand.at[state.current_player].set( - Hand.add(state.hand[state.current_player], rinshan_tile) + rinshan_tile = state._deck[state._next_deck_ix] + next_deck_ix = state._next_deck_ix - 1 + hand = state._hand.at[state.current_player].set( + Hand.add(state._hand[state.current_player], rinshan_tile) ) legal_action_mask = jnp.zeros(NUM_ACTION, dtype=jnp.bool_) legal_action_mask = legal_action_mask.at[0:34].set( @@ -542,13 +544,13 @@ def _selfkan(state: State, action): legal_action_mask = legal_action_mask.at[Action.TSUMOGIRI].set(TRUE) return state.replace( # type:ignore - next_deck_ix=next_deck_ix, - last_draw=rinshan_tile, - hand=hand, + _next_deck_ix=next_deck_ix, + _last_draw=rinshan_tile, + _hand=hand, legal_action_mask=legal_action_mask, - n_kan=state.n_kan + 1, - doras=state.doras.at[state.n_kan + 1].set( - state.deck[9 - 2 * (state.n_kan + 1)] + _n_kan=state._n_kan + 1, + _doras=state._doras.at[state._n_kan + 1].set( + state._deck[9 - 2 * (state._n_kan + 1)] ), ) @@ -557,56 +559,58 @@ def _ankan(state: State, target): curr_player = state.current_player meld = Meld.init(target + 34, target, src=0) state = _append_meld(state, meld, curr_player) - hand = state.hand.at[curr_player].set( - Hand.ankan(state.hand[curr_player], target) + hand = state._hand.at[curr_player].set( + Hand.ankan(state._hand[curr_player], target) ) # TODO: 国士無双ロンの受付 return state.replace( # type:ignore - hand=hand, + _hand=hand, ) def _kakan(state: State, target, pon_src, pon_idx): - melds = state.melds.at[(state.current_player, pon_idx)].set( + melds = state._melds.at[(state.current_player, pon_idx)].set( Meld.init(target + 34, target, pon_src) ) - hand = state.hand.at[state.current_player].set( - Hand.kakan(state.hand[state.current_player], target) + hand = state._hand.at[state.current_player].set( + Hand.kakan(state._hand[state.current_player], target) ) - pon = state.pon.at[(state.current_player, target)].set(0) + pon = state._pon.at[(state.current_player, target)].set(0) # TODO: 槍槓の受付 - return state.replace(melds=melds, hand=hand, pon=pon) # type:ignore + return state.replace(_melds=melds, _hand=hand, _pon=pon) # type:ignore def _accept_riichi(state: State) -> State: - l_p = state.last_player - score = state.score.at[l_p].add( - jnp.where(~state.riichi[l_p] & state.riichi_declared, -10, 0) + l_p = state._last_player + _score = state._score.at[l_p].add( + jnp.where(~state._riichi[l_p] & state._riichi_declared, -10, 0) ) - riichi = state.riichi.at[l_p].set( - state.riichi[l_p] | state.riichi_declared + riichi = state._riichi.at[l_p].set( + state._riichi[l_p] | state._riichi_declared ) return state.replace( # type:ignore - riichi=riichi, riichi_declared=FALSE, score=score + _riichi=riichi, _riichi_declared=FALSE, _score=_score ) def _minkan(state: State): c_p = state.current_player - l_p = state.last_player + l_p = state._last_player state = _accept_riichi(state) src = (l_p - c_p) % 4 - meld = Meld.init(Action.MINKAN, state.target, src) + meld = Meld.init(Action.MINKAN, state._target, src) state = _append_meld(state, meld, c_p) - hand = state.hand.at[c_p].set(Hand.minkan(state.hand[c_p], state.target)) - state = state.replace(hand=hand) # type:ignore - is_menzen = state.is_menzen.at[c_p].set(FALSE) + hand = state._hand.at[c_p].set( + Hand.minkan(state._hand[c_p], state._target) + ) + state = state.replace(_hand=hand) # type:ignore + is_menzen = state._is_menzen.at[c_p].set(FALSE) - rinshan_tile = state.deck[state.next_deck_ix] - next_deck_ix = state.next_deck_ix - 1 - hand = state.hand.at[c_p].set(Hand.add(state.hand[c_p], rinshan_tile)) + rinshan_tile = state._deck[state._next_deck_ix] + _next_deck_ix = state._next_deck_ix - 1 + hand = state._hand.at[c_p].set(Hand.add(state._hand[c_p], rinshan_tile)) legal_action_mask = jnp.zeros(NUM_ACTION, dtype=jnp.bool_) legal_action_mask = legal_action_mask.at[0:34].set(hand[c_p] > 0) @@ -614,95 +618,95 @@ def _minkan(state: State): legal_action_mask = legal_action_mask.at[Action.TSUMOGIRI].set(TRUE) # 半透明処理 - river = state.river.at[l_p, state.n_river[l_p] - 1].set( - state.river[l_p, state.n_river[l_p] - 1] | jnp.uint8(0b10000000) + river = state._river.at[l_p, state._n_river[l_p] - 1].set( + state._river[l_p, state._n_river[l_p] - 1] | jnp.uint8(0b10000000) ) return state.replace( # type:ignore - target=jnp.int8(-1), - is_menzen=is_menzen, - next_deck_ix=next_deck_ix, - last_draw=rinshan_tile, - hand=hand, + _target=jnp.int8(-1), + _is_menzen=is_menzen, + _next_deck_ix=_next_deck_ix, + _last_draw=rinshan_tile, + _hand=hand, legal_action_mask=legal_action_mask, - river=river, - n_kan=state.n_kan + 1, - doras=state.doras.at[state.n_kan + 1].set( - state.deck[9 - 2 * (state.n_kan + 1)] + _river=river, + _n_kan=state._n_kan + 1, + _doras=state._doras.at[state._n_kan + 1].set( + state._deck[9 - 2 * (state._n_kan + 1)] ), ) def _pon(state: State): c_p = state.current_player - l_p = state.last_player + l_p = state._last_player state = _accept_riichi(state) src = (l_p - c_p) % 4 - meld = Meld.init(Action.PON, state.target, src) + meld = Meld.init(Action.PON, state._target, src) state = _append_meld(state, meld, c_p) - hand = state.hand.at[c_p].set(Hand.pon(state.hand[c_p], state.target)) - is_menzen = state.is_menzen.at[c_p].set(FALSE) - pon = state.pon.at[(c_p, state.target)].set( - src << 2 | state.n_meld[c_p] - 1 + hand = state._hand.at[c_p].set(Hand.pon(state._hand[c_p], state._target)) + is_menzen = state._is_menzen.at[c_p].set(FALSE) + pon = state._pon.at[(c_p, state._target)].set( + src << 2 | state._n_meld[c_p] - 1 ) # 半透明処理 - river = state.river.at[l_p, state.n_river[l_p] - 1].set( - state.river[l_p, state.n_river[l_p] - 1] | jnp.uint8(0b10000000) + river = state._river.at[l_p, state._n_river[l_p] - 1].set( + state._river[l_p, state._n_river[l_p] - 1] | jnp.uint8(0b10000000) ) legal_action_mask = jnp.zeros(NUM_ACTION, dtype=jnp.bool_) legal_action_mask = legal_action_mask.at[:34].set(hand[c_p] > 0) return state.replace( # type:ignore - target=jnp.int8(-1), - is_menzen=is_menzen, - pon=pon, - hand=hand, + _target=jnp.int8(-1), + _is_menzen=is_menzen, + _pon=pon, + _hand=hand, legal_action_mask=legal_action_mask, - river=river, + _river=river, ) def _chi(state: State, action): c_p = state.current_player tar_p = (c_p + 3) % 4 - tar = state.target + tar = state._target state = _accept_riichi(state) meld = Meld.init(action, tar, src=jnp.int32(3)) state = _append_meld(state, meld, c_p) - hand = state.hand.at[c_p].set(Hand.chi(state.hand[c_p], tar, action)) - is_menzen = state.is_menzen.at[c_p].set(FALSE) + hand = state._hand.at[c_p].set(Hand.chi(state._hand[c_p], tar, action)) + is_menzen = state._is_menzen.at[c_p].set(FALSE) legal_action_mask = jnp.zeros(NUM_ACTION, dtype=jnp.bool_) legal_action_mask = legal_action_mask.at[:34].set(hand[c_p] > 0) # 半透明処理 - river = state.river.at[tar_p, state.n_river[tar_p] - 1].set( - state.river[tar_p, state.n_river[tar_p] - 1] | jnp.uint8(0b10000000) + river = state._river.at[tar_p, state._n_river[tar_p] - 1].set( + state._river[tar_p, state._n_river[tar_p] - 1] | jnp.uint8(0b10000000) ) return state.replace( # type:ignore - target=jnp.int8(-1), - is_menzen=is_menzen, - hand=hand, + _target=jnp.int8(-1), + _is_menzen=is_menzen, + _hand=hand, legal_action_mask=legal_action_mask, - river=river, + _river=river, ) def _pass(state: State): - last_player = (state.furo_check_num & 0b11000000) >> 6 - kan_player = (state.furo_check_num & 0b00110000) >> 4 - pon_player = (state.furo_check_num & 0b00001100) >> 2 - chi_player = state.furo_check_num & 0b00000011 + last_player = (state._furo_check_num & 0b11000000) >> 6 + kan_player = (state._furo_check_num & 0b00110000) >> 4 + pon_player = (state._furo_check_num & 0b00001100) >> 2 + chi_player = state._furo_check_num & 0b00000011 return jax.lax.cond( kan_player > 0, lambda: state.replace( # type:ignore current_player=jnp.int8(last_player + 1 + kan_player) % 4, - furo_check_num=jnp.uint8(state.furo_check_num & 0b11001111), + _furo_check_num=jnp.uint8(state._furo_check_num & 0b11001111), legal_action_mask=jnp.zeros(NUM_ACTION, dtype=jnp.bool_) .at[Action.MINKAN] - .set(Hand.can_minkan(state.hand[kan_player], state.target)) + .set(Hand.can_minkan(state._hand[kan_player], state._target)) .at[Action.PASS] .set(TRUE), ), @@ -710,10 +714,10 @@ def _pass(state: State): pon_player > 0, lambda: state.replace( # type:ignore current_player=jnp.int8(last_player + 1 + pon_player) % 4, - furo_check_num=jnp.uint8(state.furo_check_num & 0b11110011), + _furo_check_num=jnp.uint8(state._furo_check_num & 0b11110011), legal_action_mask=jnp.zeros(NUM_ACTION, dtype=jnp.bool_) .at[Action.PON] - .set(Hand.can_pon(state.hand[pon_player], state.target)) + .set(Hand.can_pon(state._hand[pon_player], state._target)) .at[Action.PASS] .set(TRUE), ), @@ -721,31 +725,31 @@ def _pass(state: State): chi_player > 0, lambda: state.replace( # type:ignore current_player=jnp.int8(last_player + 1 + chi_player) % 4, - furo_check_num=jnp.uint8( - state.furo_check_num & 0b11111100 + _furo_check_num=jnp.uint8( + state._furo_check_num & 0b11111100 ), legal_action_mask=jnp.zeros(NUM_ACTION, dtype=jnp.bool_) .at[Action.CHI_L] .set( Hand.can_chi( - state.hand[(last_player + 1 + chi_player) % 4], - state.target, + state._hand[(last_player + 1 + chi_player) % 4], + state._target, Action.CHI_L, ) ) .at[Action.CHI_M] .set( Hand.can_chi( - state.hand[(last_player + 1 + chi_player) % 4], - state.target, + state._hand[(last_player + 1 + chi_player) % 4], + state._target, Action.CHI_M, ) ) .at[Action.CHI_R] .set( Hand.can_chi( - state.hand[(last_player + 1 + chi_player) % 4], - state.target, + state._hand[(last_player + 1 + chi_player) % 4], + state._target, Action.CHI_R, ) ) @@ -771,19 +775,19 @@ def _riichi(state: State): 34, lambda i, arr: arr.at[i].set( jax.lax.cond( - state.hand[c_p][i] > (i == state.last_draw), - lambda: Hand.is_tenpai(Hand.sub(state.hand[c_p], i)), + state._hand[c_p][i] > (i == state._last_draw), + lambda: Hand.is_tenpai(Hand.sub(state._hand[c_p], i)), lambda: FALSE, ) ), jnp.zeros(NUM_ACTION, dtype=jnp.bool_), ) .at[Action.TSUMOGIRI] - .set(Hand.is_tenpai(Hand.sub(state.hand[c_p], state.last_draw))) + .set(Hand.is_tenpai(Hand.sub(state._hand[c_p], state._last_draw))) ) return state.replace( # type:ignore - riichi_declared=TRUE, legal_action_mask=legal_action_mask + _riichi_declared=TRUE, legal_action_mask=legal_action_mask ) @@ -791,31 +795,31 @@ def _tsumo(state: State): c_p = state.current_player score = Yaku.score( - state.hand[c_p], - state.melds[c_p], - state.n_meld[c_p], - state.target, - state.riichi[c_p], + state._hand[c_p], + state._melds[c_p], + state._n_meld[c_p], + state._target, + state._riichi[c_p], is_ron=FALSE, - dora=_dora_array(state, state.riichi[c_p]), + dora=_dora_array(state, state._riichi[c_p]), ) s1 = score + (-score) % 100 s2 = (score * 2) + (-(score * 2)) % 100 - oya = (state.oya + state._round) % 4 + _oya = (state._oya + state._round) % 4 reward = jax.lax.cond( - oya == c_p, + _oya == c_p, lambda: jnp.full(4, -s2, dtype=jnp.int32).at[c_p].set(s2 * 3), lambda: jnp.full(4, -s1, dtype=jnp.int32) - .at[oya] + .at[_oya] .set(-s2) .at[c_p] .set(s1 * 2 + s2), ) # 供託 - reward -= 1000 * state.riichi - reward = reward.at[c_p].set(reward[c_p] + 1000 * jnp.sum(state.riichi)) + reward -= 1000 * state._riichi + reward = reward.at[c_p].set(reward[c_p] + 1000 * jnp.sum(state._riichi)) return state.replace( # type:ignore terminated=TRUE, rewards=jnp.float32(reward) ) @@ -824,16 +828,16 @@ def _tsumo(state: State): def _ron(state: State): c_p = state.current_player score = Yaku.score( - state.hand[c_p], - state.melds[c_p], - state.n_meld[c_p], - state.target, - state.riichi[c_p], + state._hand[c_p], + state._melds[c_p], + state._n_meld[c_p], + state._target, + state._riichi[c_p], is_ron=TRUE, - dora=_dora_array(state, state.riichi[c_p]), + dora=_dora_array(state, state._riichi[c_p]), ) score = jax.lax.cond( - (state.oya + state._round) % 4 == c_p, + (state._oya + state._round) % 4 == c_p, lambda: score * 6, lambda: score * 4, ) @@ -842,20 +846,20 @@ def _ron(state: State): jnp.zeros(4, dtype=jnp.int32) .at[c_p] .set(score) - .at[state.last_player] + .at[state._last_player] .set(-score) ) # 供託 - reward -= 1000 * state.riichi - reward = reward.at[c_p].set(reward[c_p] + 1000 * jnp.sum(state.riichi)) + reward -= 1000 * state._riichi + reward = reward.at[c_p].set(reward[c_p] + 1000 * jnp.sum(state._riichi)) return state.replace( # type:ignore terminated=TRUE, rewards=jnp.float32(reward) ) def _is_ryukyoku(state: State): - return state.next_deck_ix == 13 + return state._next_deck_ix == 13 def _next_game(state: State): @@ -884,17 +888,17 @@ def next(tile): riichi, lambda: jax.lax.fori_loop( 0, - state.n_kan + 1, - lambda i, arr: arr.at[next(state.deck[5 + 2 * i])] + state._n_kan + 1, + lambda i, arr: arr.at[next(state._deck[5 + 2 * i])] .set(TRUE) - .at[next(state.doras[4 + 2 * i])] + .at[next(state._doras[4 + 2 * i])] .set(TRUE), dora, ), lambda: jax.lax.fori_loop( 0, - state.n_kan + 1, - lambda i, arr: arr.at[next(state.doras[5 + 2 * i])].set(TRUE), + state._n_kan + 1, + lambda i, arr: arr.at[next(state._doras[5 + 2 * i])].set(TRUE), dora, ), ) diff --git a/pgx/_mahjong/_hand.py b/pgx/mahjong/_hand.py similarity index 99% rename from pgx/_mahjong/_hand.py rename to pgx/mahjong/_hand.py index bf58e4a5a..7a3aa89cb 100644 --- a/pgx/_mahjong/_hand.py +++ b/pgx/mahjong/_hand.py @@ -4,7 +4,7 @@ import jax import jax.numpy as jnp -from pgx._mahjong._action import Action # type: ignore +from pgx.mahjong._action import Action # type: ignore DIR = os.path.join(os.path.dirname(__file__), "cache") diff --git a/pgx/_mahjong/_meld.py b/pgx/mahjong/_meld.py similarity index 99% rename from pgx/_mahjong/_meld.py rename to pgx/mahjong/_meld.py index 729eacdf5..5d0cc6eeb 100644 --- a/pgx/_mahjong/_meld.py +++ b/pgx/mahjong/_meld.py @@ -1,7 +1,7 @@ import jax import jax.numpy as jnp -from pgx._mahjong._action import Action +from pgx.mahjong._action import Action class Meld: diff --git a/pgx/_mahjong/_shanten.py b/pgx/mahjong/_shanten.py similarity index 100% rename from pgx/_mahjong/_shanten.py rename to pgx/mahjong/_shanten.py diff --git a/pgx/_mahjong/_yaku.py b/pgx/mahjong/_yaku.py similarity index 99% rename from pgx/_mahjong/_yaku.py rename to pgx/mahjong/_yaku.py index 13569037b..f38484abe 100644 --- a/pgx/_mahjong/_yaku.py +++ b/pgx/mahjong/_yaku.py @@ -4,9 +4,9 @@ import jax import jax.numpy as jnp -from pgx._mahjong._action import Action -from pgx._mahjong._hand import Hand -from pgx._mahjong._meld import Meld +from pgx.mahjong._action import Action +from pgx.mahjong._hand import Hand +from pgx.mahjong._meld import Meld DIR = os.path.join(os.path.dirname(__file__), "cache") diff --git a/pgx/_mahjong/__init__.py b/pgx/mahjong/cache/__init__.py similarity index 100% rename from pgx/_mahjong/__init__.py rename to pgx/mahjong/cache/__init__.py diff --git a/pgx/_mahjong/cache/hand_cache.json b/pgx/mahjong/cache/hand_cache.json similarity index 100% rename from pgx/_mahjong/cache/hand_cache.json rename to pgx/mahjong/cache/hand_cache.json diff --git a/pgx/_mahjong/cache/shanten_cache.json b/pgx/mahjong/cache/shanten_cache.json similarity index 100% rename from pgx/_mahjong/cache/shanten_cache.json rename to pgx/mahjong/cache/shanten_cache.json diff --git a/pgx/_mahjong/cache/yaku_cache.json b/pgx/mahjong/cache/yaku_cache.json similarity index 100% rename from pgx/_mahjong/cache/yaku_cache.json rename to pgx/mahjong/cache/yaku_cache.json diff --git a/pgx/v1.py b/pgx/v1.py index 005eb9d64..18a62d7c6 100644 --- a/pgx/v1.py +++ b/pgx/v1.py @@ -50,7 +50,7 @@ "hex", "kuhn_poker", "leduc_holdem", - # "mahjong", + "mahjong", "minatar-asterix", "minatar-breakout", "minatar-freeway", @@ -387,6 +387,10 @@ def make(env_id: EnvId): # noqa: C901 from pgx.leduc_holdem import LeducHoldem return LeducHoldem() + elif env_id == "mahjong": + from pgx.mahjong import Mahjong + + return Mahjong() elif env_id == "minatar-asterix": try: from pgx_minatar.asterix import MinAtarAsterix # type: ignore diff --git a/tests/assets/mahjong/riichi_test.json b/tests/assets/mahjong/riichi_test.json index 487f021c0..d01b9e805 100644 --- a/tests/assets/mahjong/riichi_test.json +++ b/tests/assets/mahjong/riichi_test.json @@ -96,15 +96,15 @@ ], "_step_count": 0, "_round": 0, - "honba": 0, - "oya": 0, - "score": [ + "_honba": 0, + "_oya": 0, + "_score": [ 250.0, 250.0, 250.0, 250.0 ], - "deck": [ + "_deck": [ 7, 13, 7, @@ -242,8 +242,8 @@ 30, 10 ], - "next_deck_ix": 82, - "hand": [ + "_next_deck_ix": 82, + "_hand": [ [ 0, 0, @@ -389,7 +389,7 @@ 0 ] ], - "river": [ + "_river": [ [ 34, 34, @@ -495,38 +495,38 @@ 34 ] ], - "n_river": [ + "_n_river": [ 0, 0, 0, 0 ], - "doras": [ + "_doras": [ 7, -1, -1, -1, -1 ], - "n_kan": 0, - "target": -1, - "last_draw": 28, - "last_player": 0, - "furo_check_num": 0, - "riichi_declared": false, - "riichi": [ + "_n_kan": 0, + "_target": -1, + "_last_draw": 28, + "_last_player": 0, + "_furo_check_num": 0, + "_riichi_declared": false, + "_riichi": [ false, false, false, false ], - "n_meld": [ + "_n_meld": [ 0, 0, 0, 0 ], - "melds": [ + "_melds": [ [ 0, 0, @@ -552,13 +552,13 @@ 0 ] ], - "is_menzen": [ + "_is_menzen": [ false, false, false, false ], - "pon": [ + "_pon": [ [ 0, 0, diff --git a/tests/assets/mahjong/ron_test.json b/tests/assets/mahjong/ron_test.json index dd6092a5b..7b4b6ce7e 100644 --- a/tests/assets/mahjong/ron_test.json +++ b/tests/assets/mahjong/ron_test.json @@ -96,15 +96,15 @@ ], "_step_count": 0, "_round": 0, - "honba": 0, - "oya": 0, - "score": [ + "_honba": 0, + "_oya": 0, + "_score": [ 250.0, 250.0, 250.0, 250.0 ], - "deck": [ + "_deck": [ 7, 13, 7, @@ -242,8 +242,8 @@ 30, 10 ], - "next_deck_ix": 82, - "hand": [ + "_next_deck_ix": 82, + "_hand": [ [ 0, 0, @@ -389,7 +389,7 @@ 0 ] ], - "river": [ + "_river": [ [ 34, 34, @@ -495,38 +495,38 @@ 34 ] ], - "n_river": [ + "_n_river": [ 0, 0, 0, 0 ], - "doras": [ + "_doras": [ 7, -1, -1, -1, -1 ], - "n_kan": 0, - "target": -1, - "last_draw": 28, - "last_player": 0, - "furo_check_num": 0, - "riichi_declared": false, - "riichi": [ + "_n_kan": 0, + "_target": -1, + "_last_draw": 28, + "_last_player": 0, + "_furo_check_num": 0, + "_riichi_declared": false, + "_riichi": [ false, false, false, false ], - "n_meld": [ + "_n_meld": [ 0, 0, 0, 0 ], - "melds": [ + "_melds": [ [ 0, 0, @@ -552,13 +552,13 @@ 0 ] ], - "is_menzen": [ + "_is_menzen": [ false, false, false, false ], - "pon": [ + "_pon": [ [ 0, 0, diff --git a/tests/assets/mahjong/tsumo_test.json b/tests/assets/mahjong/tsumo_test.json index 9a436e32c..01c7b24a7 100644 --- a/tests/assets/mahjong/tsumo_test.json +++ b/tests/assets/mahjong/tsumo_test.json @@ -96,15 +96,15 @@ ], "_step_count": 0, "_round": 0, - "honba": 0, - "oya": 0, - "score": [ + "_honba": 0, + "_oya": 0, + "_score": [ 250.0, 250.0, 250.0, 250.0 ], - "deck": [ + "_deck": [ 7, 13, 7, @@ -242,8 +242,8 @@ 30, 10 ], - "next_deck_ix": 82, - "hand": [ + "_next_deck_ix": 82, + "_hand": [ [ 0, 0, @@ -389,7 +389,7 @@ 0 ] ], - "river": [ + "_river": [ [ 34, 34, @@ -495,38 +495,38 @@ 34 ] ], - "n_river": [ + "_n_river": [ 0, 0, 0, 0 ], - "doras": [ + "_doras": [ 7, -1, -1, -1, -1 ], - "n_kan": 0, - "target": -1, - "last_draw": 28, - "last_player": 0, - "furo_check_num": 0, - "riichi_declared": false, - "riichi": [ + "_n_kan": 0, + "_target": -1, + "_last_draw": 28, + "_last_player": 0, + "_furo_check_num": 0, + "_riichi_declared": false, + "_riichi": [ false, false, false, false ], - "n_meld": [ + "_n_meld": [ 0, 0, 0, 0 ], - "melds": [ + "_melds": [ [ 0, 0, @@ -552,13 +552,13 @@ 0 ] ], - "is_menzen": [ + "_is_menzen": [ false, false, false, false ], - "pon": [ + "_pon": [ [ 0, 0, diff --git a/tests/test_mahjong.py b/tests/test_mahjong.py index 025220983..312faea67 100644 --- a/tests/test_mahjong.py +++ b/tests/test_mahjong.py @@ -1,8 +1,9 @@ -from pgx._mahjong._hand import Hand -from pgx._mahjong._yaku import Yaku -from pgx._mahjong._action import Action -from pgx._mahjong._shanten import Shanten -from pgx._mahjong._mahjong2 import Mahjong +from pgx.mahjong import Mahjong +from pgx.mahjong._env import State +from pgx.mahjong._hand import Hand +from pgx.mahjong._yaku import Yaku +from pgx.mahjong._shanten import Shanten +from pgx.mahjong._action import Action import jax.numpy as jnp from jax import jit import jax @@ -92,7 +93,7 @@ def test_hand(): assert jit(Hand.can_riichi)(hand) - from pgx._mahjong._action import Action + from pgx.mahjong._action import Action # fmt:off hand = jnp.int8([ @@ -202,29 +203,29 @@ def test_shanten(): def test_discard(): key = jax.random.PRNGKey(0) - state = init(key=key) + state: State = init(key=key) assert state.current_player == jnp.int8(0) - assert state.target == jnp.int8(-1) - assert state.deck[state.next_deck_ix] == jnp.int8(8) - assert state.hand[0, 8] == jnp.int8(1) + assert state._target == jnp.int8(-1) + assert state._deck[state._next_deck_ix] == jnp.int8(8) + assert state._hand[0, 8] == jnp.int8(1) - state = step(state, 8) - assert state.hand[0, 8] == jnp.int8(0) + state: State = step(state, 8) + assert state._hand[0, 8] == jnp.int8(0) assert state.current_player == jnp.int8(1) - assert state.target == jnp.int8(-1) - assert state.deck[state.next_deck_ix] == jnp.int8(31) + assert state._target == jnp.int8(-1) + assert state._deck[state._next_deck_ix] == jnp.int8(31) - assert state.hand[1, 8] == jnp.int8(2) + assert state._hand[1, 8] == jnp.int8(2) - state = step(state, Action.TSUMOGIRI) - assert state.hand[1, 8] == jnp.int8(1) + state: State = step(state, Action.TSUMOGIRI) + assert state._hand[1, 8] == jnp.int8(1) assert state.current_player == jnp.int8(2) - assert state.target == jnp.int8(-1) + assert state._target == jnp.int8(-1) def test_chi(): key = jax.random.PRNGKey(0) - state = init(key=key) + state: State = init(key=key) """ current_player 0 [[0 0 0 0 1 0 1 0 1 1 1 0 1 0 0 0 0 2 1 1 0 0 0 0 0 1 0 1 1 0 1 0 0 0] @@ -233,23 +234,23 @@ def test_chi(): [1 0 2 0 0 0 0 0 0 0 0 1 0 0 1 1 1 0 0 1 0 1 0 1 1 0 0 0 0 0 0 2 0 0]] """ assert state.legal_action_mask[6] - state = step(state, 6) + state: State = step(state, 6) assert state.current_player == jnp.int8(1) - assert state.target == jnp.int8(6) + assert state._target == jnp.int8(6) assert state.legal_action_mask[Action.CHI_R] state1 = step(state, Action.CHI_R) assert state1.current_player == jnp.int8(1) - assert state1.melds[1, 0] == jnp.int32(25420) + assert state1._melds[1, 0] == jnp.int32(25420) state2 = step(state, Action.PASS) assert state2.current_player == jnp.int8(1) - assert state2.melds[1, 0] == jnp.int8(0) + assert state2._melds[1, 0] == jnp.int8(0) def test_ankan(): key = jax.random.PRNGKey(352) - state = init(key=key) + state: State = init(key=key) assert state.current_player == jnp.int8(0) """ [[1 2 0 0 0 0 1 0 0 0 0 1 0 1 0 0 0 0 0 0 0 2 0 0 0 0 1 0 0 1 0 4 0 0] @@ -258,49 +259,45 @@ def test_ankan(): [0 0 1 0 1 0 0 0 1 0 0 1 1 2 0 0 0 0 0 0 0 0 0 1 0 1 1 1 1 0 1 0 0 0]] """ assert state.legal_action_mask[65] - assert (state.doras == jnp.int32([28, -1, -1, -1, -1])).all() - assert state.n_kan == jnp.int8(0) + assert (state._doras == jnp.int32([28, -1, -1, -1, -1])).all() + assert state._n_kan == jnp.int8(0) - state = step(state, 65) - assert state.melds[0, 0] == jnp.int32(4033) - assert (state.doras == jnp.int32([28, 23, -1, -1, -1])).all() - assert state.n_kan == jnp.int8(1) + state: State = step(state, 65) + assert state._melds[0, 0] == jnp.int32(4033) + assert (state._doras == jnp.int32([28, 23, -1, -1, -1])).all() + assert state._n_kan == jnp.int8(1) def test_riichi(): - from pgx._mahjong._mahjong2 import State - rng = jax.random.PRNGKey(0) state = State.from_json("tests/assets/mahjong/riichi_test.json") visualize(state, "tests/assets/mahjong/before_riichi.svg") assert state.current_player == jnp.int8(0) - state = step(state, 9) + state: State = step(state, 9) assert state.legal_action_mask[Action.RIICHI] - state = step(state, Action.RIICHI) + state: State = step(state, Action.RIICHI) assert not state.terminated N = 10 for _ in range(N): rng, subkey = jax.random.split(rng) a = act_randomly(subkey, state) - state = step(state, a) + state: State = step(state, a) visualize(state, f"tests/assets/mahjong/after_riichi_{N}.svg") def test_ron(): - from pgx._mahjong._mahjong2 import State - state = State.from_json("tests/assets/mahjong/ron_test.json") visualize(state, "tests/assets/mahjong/before_ron.svg") assert state.current_player == jnp.int8(0) - state = step(state, 30) # 北 + state: State = step(state, 30) # 北 assert state.legal_action_mask[Action.RON] - state = step(state, Action.RON) + state: State = step(state, Action.RON) assert state.terminated assert ( @@ -311,16 +308,14 @@ def test_ron(): def test_tsumo(): - from pgx._mahjong._mahjong2 import State - state = State.from_json("tests/assets/mahjong/tsumo_test.json") visualize(state, "tests/assets/mahjong/before_tsumo.svg") assert state.current_player == jnp.int8(0) - state = step(state, 30) + state: State = step(state, 30) assert state.legal_action_mask[Action.TSUMO] - state = step(state, Action.TSUMO) + state: State = step(state, Action.TSUMO) assert state.terminated assert ( @@ -336,13 +331,12 @@ def test_transparent(): for _ in range(65): rng, subkey = jax.random.split(rng) a = act_randomly(subkey, state) - state = step(state, a) + state: State = step(state, a) visualize(state, "tests/assets/mahjong/transparent.svg") def test_json(): - from pgx._mahjong._mahjong2 import State import os rng = jax.random.PRNGKey(0) @@ -350,7 +344,7 @@ def test_json(): for _ in range(50): rng, subkey = jax.random.split(rng) a = act_randomly(subkey, state) - state = step(state, a) + state: State = step(state, a) path = "temp.json" with open(path, mode="w") as f: @@ -369,11 +363,18 @@ def test_random_play(): for _ in range(70): rng, subkey = jax.random.split(rng) a = act_randomly(subkey, state) - state = step(state, a) + state: State = step(state, a) - assert state.hand[state.current_player].sum() + jnp.count_nonzero( - state.melds[state.current_player] + assert state._hand[state.current_player].sum() + jnp.count_nonzero( + state._melds[state.current_player] ) * 3 in [13, 14] - assert (0 <= state.hand).all() - assert (state.hand <= 4).all() - assert (0 <= state.melds).all() + assert (0 <= state._hand).all() + assert (state._hand <= 4).all() + assert (0 <= state._melds).all() + + +# def test_api(): +# import pgx +# +# env = pgx.make("mahjong") +# pgx.v1_api_test(env, 1)