In [4]:
%pip install git+https://github.com/sotetsuk/pgx.git
%pip install flax==0.10.6

Collecting flax==0.10.6
  Downloading flax-0.10.6-py3-none-any.whl.metadata (11 kB)
Collecting msgpack (from flax==0.10.6)
  Downloading msgpack-1.1.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (8.4 kB)
Collecting optax (from flax==0.10.6)
  Downloading optax-0.2.5-py3-none-any.whl.metadata (7.5 kB)
Collecting orbax-checkpoint (from flax==0.10.6)
  Downloading orbax_checkpoint-0.11.24-py3-none-any.whl.metadata (2.3 kB)
Collecting tensorstore (from flax==0.10.6)
  Downloading tensorstore-0.1.76-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (21 kB)
Collecting rich>=11.1 (from flax==0.10.6)
  Downloading rich-14.1.0-py3-none-any.whl.metadata (18 kB)
Collecting treescope>=0.1.7 (from flax==0.10.6)
  Downloading treescope-0.1.10-py3-none-any.whl.metadata (6.6 kB)
Collecting markdown-it-py>=2.2.0 (from rich>=11.1->flax==0.10.6)
  Downloading markdown_it_py-4.0.0-py3-none-any.whl.metadata (7.3 kB)
Collecting mdurl~=0.1 (from markdown-it-py>=2.2.0

In [7]:
import jax
import jax.numpy as jnp
from pgx.core import Env
from flax.struct import dataclass as struct_dataclass

# --- Constants & Precomputation ---

NUM_GOATS = 15
NUM_TIGERS = 3
TIGER_WIN_THRESHOLD = 10
BOARD_POSITIONS = 23
MAX_TURNS = 200
PLACEMENT_ACTIONS = BOARD_POSITIONS

# Adjacency matrices (1-indexed, 0-padded for JAX compatibility)
ADJ = jnp.array([
    [0, 0, 0, 0], [3, 4, 5, 6], [3, 8, 0, 0], [1, 2, 4, 9], [1, 3, 5, 10],
    [1, 4, 6, 11], [1, 5, 7, 12], [6, 13, 0, 0], [2, 9, 14, 0], [3, 8, 10, 15],
    [4, 9, 11, 16], [5, 10, 12, 17], [6, 11, 13, 18], [7, 12, 14, 0],
    [8, 15, 0, 0], [9, 14, 16, 20], [10, 15, 17, 21], [11, 16, 18, 22],
    [12, 17, 19, 23], [13, 18, 0, 0], [15, 21, 0, 0], [16, 20, 22, 0],
    [17, 21, 23, 0], [18, 22, 0, 0]
], dtype=jnp.int32)

JUMP_ADJ = jnp.array([
    [0, 0, 0, 0], [9, 10, 11, 12], [4, 14, 0, 0], [5, 15, 0, 0], [2, 6, 16, 0],
    [3, 7, 17, 0], [4, 18, 0, 0], [5, 19, 0, 0], [10, 0, 0, 0], [1, 11, 20, 0],
    [1, 8, 12, 21], [1, 9, 13, 22], [1, 10, 23, 0], [11, 0, 0, 0], [2, 16, 0, 0],
    [3, 17, 0, 0], [4, 14, 18, 0], [5, 15, 19, 0], [6, 16, 0, 0], [7, 17, 0, 0],
    [9, 22, 0, 0], [10, 23, 0, 0], [11, 20, 0, 0], [12, 21, 0, 0]
], dtype=jnp.int32)

def _create_move_info():
    """Precomputes detailed information for every possible move."""
    move_info = []
    # Adjacent moves (is_jump=0, mid_pos=0)
    for start_pos in range(1, BOARD_POSITIONS + 1):
        for end_pos in ADJ[start_pos]:
            if end_pos != 0:
                move_info.append([int(start_pos), int(end_pos), 0, 0])

    # Jump moves (is_jump=1, mid_pos=calculated)
    processed_jumps = set()
    for start_pos in range(1, BOARD_POSITIONS + 1):
        for end_pos in JUMP_ADJ[start_pos]:
            end_pos_int = int(end_pos)
            if end_pos_int != 0 and (start_pos, end_pos_int) not in processed_jumps:
                # Find midpoint by checking common neighbors
                mid_pos = 0
                for neighbor1 in ADJ[start_pos]:
                    if neighbor1 != 0:
                        for neighbor2 in ADJ[end_pos_int]:
                            if neighbor2 != 0 and neighbor1 == neighbor2:
                                mid_pos = int(neighbor1)
                                break
                    if mid_pos != 0:
                        break
                move_info.append([start_pos, end_pos_int, 1, mid_pos])
                processed_jumps.add((start_pos, end_pos_int))

    return jnp.array(move_info, dtype=jnp.int32)

MOVE_INFO = _create_move_info()
MOVE_ACTIONS_COUNT = MOVE_INFO.shape[0]
TOTAL_ACTIONS = PLACEMENT_ACTIONS + MOVE_ACTIONS_COUNT

@struct_dataclass
class State:
    """State dataclass for the Aadu Puli Aattam environment."""
    current_player: jnp.ndarray
    observation: jnp.ndarray
    rewards: jnp.ndarray
    terminated: jnp.ndarray
    truncated: jnp.ndarray
    legal_action_mask: jnp.ndarray
    _step_count: jnp.ndarray
    board: jnp.ndarray
    goats_to_place: jnp.ndarray
    goats_captured: jnp.ndarray
    turn_count: jnp.ndarray


class AaduPuliAattam(Env):
    """Aadu Puli Aattam (Goats and Tigers) game environment."""
    version = "v0"
    num_players = 2

    def _init(self, key: jax.random.PRNGKey) -> State:
        """Initializes the game state."""
        board = jnp.zeros(BOARD_POSITIONS, dtype=jnp.int32)
        board = board.at[jnp.array([0, 3, 4])].set(2)  # Initial tiger positions

        state = State(
            current_player=jnp.int32(0),
            board=board,
            goats_to_place=jnp.int32(NUM_GOATS),
            goats_captured=jnp.int32(0),
            turn_count=jnp.int32(0),
            terminated=jnp.bool_(False),
            truncated=jnp.bool_(False),
            legal_action_mask=jnp.zeros(TOTAL_ACTIONS, dtype=jnp.bool_),
            _step_count= jnp.int32(0),
            rewards=jnp.zeros(2, dtype=jnp.float32),
            observation=jnp.zeros(BOARD_POSITIONS + 3, dtype=jnp.int32)
        )
        return state.replace(legal_action_mask=self._legal_action_mask(state))

    def _step(self, state: State, action: jnp.ndarray, key: jax.random.PRNGKey) -> State:
        """Takes a step in the environment."""

        def _handle_placement(state, action):
            new_board = state.board.at[action].set(1)
            return state.replace(
                board=new_board,
                goats_to_place=state.goats_to_place - 1,
            )

        def _handle_movement(state, action):
            move_idx = action - PLACEMENT_ACTIONS
            from_pos, to_pos, is_jump, mid_pos = MOVE_INFO[move_idx]
            from_idx, to_idx = from_pos - 1, to_pos - 1

            piece = state.board[from_idx]
            new_board = state.board.at[from_idx].set(0).at[to_idx].set(piece)

            # Handle goat capture during a tiger jump
            goats_captured = jax.lax.cond(
                (piece == 2) & (is_jump == 1),
                lambda: state.goats_captured + 1,
                lambda: state.goats_captured
            )
            new_board = jax.lax.cond(
                (piece == 2) & (is_jump == 1),
                lambda: new_board.at[mid_pos - 1].set(0),
                lambda: new_board
            )
            return state.replace(board=new_board, goats_captured=goats_captured)

        # Update board based on action type
        # This produces an intermediate state with the board updated but player/turn not yet switched
        intermediate_state = jax.lax.cond(
            action < PLACEMENT_ACTIONS,
            _handle_placement,
            _handle_movement,
            state, action
        )
        jax.debug.print("intermediate_state after action: {}", intermediate_state)
        # Determine win/loss conditions based on the intermediate state board and goats captured
        t_win = intermediate_state.goats_captured >= TIGER_WIN_THRESHOLD
        draw = intermediate_state.turn_count >= MAX_TURNS # Check against turn count *before* increment

        # Calculate the *next* player and next turn count
        next_player = 1 - intermediate_state.current_player
        next_turn_count = intermediate_state.turn_count + 1

        # Create a temporary state to get the legal actions for the NEXT player on the NEW board
        temp_state_for_legal_mask = intermediate_state.replace(
            current_player=next_player
        )
        next_legal_mask = self._legal_action_mask(temp_state_for_legal_mask)

        # Check for Goat Win (Tigers Blocked) - This happens if it's the tiger's turn next AND they have no legal moves
        is_next_player_tiger = next_player == 1
        next_player_has_no_legal_moves = ~jnp.any(next_legal_mask)
        g_win = is_next_player_tiger & next_player_has_no_legal_moves


        # Determine final terminated status and rewards based on the determined win/loss conditions
        terminated = t_win | draw | g_win
        jax.debug.print("t_win: {}", t_win)
        jax.debug.print("draw: {}", draw)
        jax.debug.print("g_win: {}", g_win)
        # Calculate reward based on termination condition
        reward = jax.lax.cond(
        terminated,
        lambda: jax.lax.cond(
            t_win,
            lambda: jnp.array([-1.0, 1.0]), # Tiger win (FIXED: Goat gets -1, Tiger gets +1)
            lambda: jax.lax.cond(
                g_win,
                lambda: jnp.array([1.0, -1.0]), # Goat win
                lambda: jnp.zeros(2, dtype=jnp.float32) # Draw (shouldn't happen if terminated is true and not t_win or g_win, but for completeness)
            )
        ),
        lambda: jnp.zeros(2, dtype=jnp.float32) # No reward if not terminated
        )
        jax.debug.print("reward: {}", reward)
        jax.debug.print("terminated: {}", terminated)
        jax.debug.print("intermediate_state: {}", intermediate_state)

        final_state = intermediate_state.replace(
            current_player=next_player,
            turn_count=next_turn_count,
            terminated=terminated,
            rewards=reward,
            legal_action_mask=next_legal_mask
        )

        jax.debug.print("final_state: {}", final_state)
        return final_state

    def _observe(self, state: State, player_id: jnp.ndarray) -> jnp.ndarray:
        """Returns the observation for the specified player."""
        return jnp.concatenate([
            state.board,
            jnp.array([state.current_player], dtype=jnp.int32),
            jnp.array([state.goats_to_place], dtype=jnp.int32),
            jnp.array([state.goats_captured], dtype=jnp.int32)
        ])
    def _legal_action_mask(self, state: State) -> jnp.ndarray:
        """Computes a boolean mask of legal actions."""

        # Placement Phase Legal Actions
        is_placement_phase = (state.current_player == 0) & (state.goats_to_place > 0)
        can_place = state.board == 0
        placement_mask = is_placement_phase & can_place

        # Movement Phase Legal Actions
        def is_move_legal(move_info):
            from_pos, to_pos, is_jump, mid_pos = move_info
            from_idx, to_idx = from_pos - 1, to_pos - 1

            is_dest_empty = state.board[to_idx] == 0

            # Goat move logic
            is_goat_move = (state.board[from_idx] == 1) & (is_jump == 0)

            # Tiger move logic
            is_tiger_adj_move = (state.board[from_idx] == 2) & (is_jump == 0)
            is_tiger_jump_move = (state.board[from_idx] == 2) & (is_jump == 1) & (state.board[mid_pos - 1] == 1)
            is_tiger_move = is_tiger_adj_move | is_tiger_jump_move

            is_goat_turn = state.current_player == 0
            is_tiger_turn = state.current_player == 1
            is_legal = (is_goat_turn & is_goat_move) | (is_tiger_turn & is_tiger_move)

            return is_legal & is_dest_empty

        is_move_phase = ~is_placement_phase
        move_mask = is_move_phase & jax.vmap(is_move_legal)(MOVE_INFO)

        return jnp.concatenate([placement_mask, move_mask])

    # Removed _are_tigers_blocked as a separate function.
    # The logic is now integrated into _step.
    @staticmethod
    def _are_tigers_blocked(board: jnp.ndarray) -> jnp.ndarray:
        """This function is no longer used in the step logic.
           The tiger blocked check is now done using the legal action mask in _step.
        """
        # Keep a placeholder or remove if confident it's not called.
        # If needed elsewhere, reimplement carefully.
        # For now, returning a dummy value or raising an error might be best if it shouldn't be called.
        # print("WARNING: _are_tigers_blocked was called, but its logic is now integrated into _step.")
        # raise NotImplementedError("_are_tigers_blocked is deprecated. Use legal action mask in _step instead.")
        return jnp.bool_(False) # Dummy return to avoid errors if still referenced


    @property
    def id(self) -> str:
        return "aadu_puli_aattam"

    @property
    def num_actions(self) -> int:
        return TOTAL_ACTIONS

# --- Testing Suite ---
if __name__ == '__main__':
    print("Running tests for AaduPuliAattam...")
    env = AaduPuliAattam()
    key = jax.random.PRNGKey(0)

    # Test 1: API Compliance
    print("\n--- Test 1: pgx API Compliance ---")
    try:
        # To test a local file, you would typically register it first.
        # For this test, we'll just check if the class works as expected.
        print("ℹ️  Skipping v1.make test for local file. Testing class directly.")
        # env_pgx = v1.make("aadu_puli_aattam")
        print("✅ SUCCESS: Environment class is structured for pgx.")
    except Exception as e:
        print(f"❌ FAILURE: Could not load environment. Error: {e}")

    # Test 2: Initialization
    print("\n--- Test 2: Initial State ---")
    state = env.init(key)
    assert state.current_player == 0, "Initial player should be 0 (Goat)"
    assert state.goats_to_place == NUM_GOATS, f"Should start with {NUM_GOATS} goats to place"
    assert jnp.sum(state.board == 2) == NUM_TIGERS, f"Should start with {NUM_TIGERS} tigers"
    print("✅ SUCCESS: Initial state is correct.")

    # Test 3: Goat Placement
    print("\n--- Test 3: Goat Placement ---")
    action = jnp.int32(1) # Place goat at position 2 (index 1)
    state = env.step(state, action, key)
    assert state.board[1] == 1, "Goat not placed correctly"
    assert state.goats_to_place == NUM_GOATS - 1, "Goats to place not decremented"
    assert state.current_player == 1, "Player turn did not switch to Tiger"
    print("✅ SUCCESS: Goat placement is correct.")

    # Test 4: Tiger Movement (Adjacent)
    print("\n--- Test 4: Tiger Movement (Adjacent) ---")
    # Find a legal move for a tiger. Tiger at pos 1 (idx 0) can move to pos 3 (idx 2)
    # This corresponds to MOVE_INFO[0], so action is PLACEMENT_ACTIONS + 0
    # Note: The exact index depends on the _create_move_info generation order.
    # We must find it dynamically.
    move_idx = jnp.where((MOVE_INFO[:, 0] == 1) & (MOVE_INFO[:, 1] == 3) & (MOVE_INFO[:,2] == 0))[0][0]
    action = jnp.int32(PLACEMENT_ACTIONS + move_idx)

    # State is after one goat placement, so it's tiger's turn.
    # Tiger is at [0, 3, 4]. Tiger at 0 (pos 1) can move to pos 3.
    # We need to set the current player to 1 for the tiger to move.
    initial_state = env.init(key)
    # Let's place a goat somewhere else so pos 3 is free for the tiger
    state_after_placement = env.step(initial_state, jnp.int32(10), key)

    state = env.step(state_after_placement, action, key)
    assert state.board[0] == 0, "Tiger did not move from original position"
    assert state.board[2] == 2, "Tiger did not move to new position"
    assert state.current_player == 0, "Player turn did not switch back to Goat"
    print("✅ SUCCESS: Tiger adjacent move is correct.")

    # Test 5: Tiger Jump (Capture)
    print("\n--- Test 5: Tiger Jump (Capture) ---")
    # Setup: Goat at pos 3, Tiger at pos 1. Tiger jumps to pos 9.
    setup_state = env.init(key)
    setup_board = setup_state.board.at[2].set(1) # Goat at pos 3 (idx 2)
    state = setup_state.replace(board=setup_board, current_player=jnp.int32(1))
    # Tiger at pos 1 (idx 0) jumps over pos 3 (idx 2) to pos 9 (idx 8)
    # Find this move in MOVE_INFO: [1, 9, 1, 3].
    jump_move_idx = jnp.where((MOVE_INFO[:, 0] == 1) & (MOVE_INFO[:, 1] == 9) & (MOVE_INFO[:,2] == 1) & (MOVE_INFO[:,3] == 3))[0][0]
    action = jnp.int32(PLACEMENT_ACTIONS + jump_move_idx)
    state = env.step(state, action, key)
    assert state.goats_captured == 1, "Goat capture count did not increment"
    assert state.board[0] == 0, "Tiger did not move from pos 1"
    assert state.board[8] == 2, "Tiger did not land on pos 9"
    assert state.board[2] == 0, "Goat at pos 3 was not captured"
    print("✅ SUCCESS: Tiger jump and capture is correct.")

    # Test 6: Goat Win Condition (Tigers Blocked)
    print("\n--- Test 6: Goat Win Condition ---")
    # Setup: Block all three tigers
    blocked_board = jnp.ones(BOARD_POSITIONS, dtype=jnp.int32) # Fill with goats
    blocked_board = blocked_board.at[jnp.array([0,1,2])].set(2) # Place tigers
    blocked_board = blocked_board.at[3].set(0) # Remove goat at pos 3
    # We need to simulate the state *after* a goat player's move,
    # where the board is blocked and it becomes the tiger's turn.
    blocked_state = env.init(key).replace(board=blocked_board, current_player=jnp.int32(1))

    # Step with an illegal action to trigger the win condition check
    # Any action is fine here as we are testing the win condition check itself
    state = env.step(blocked_state, jnp.int32(3), key) # Action 0 is placing a goat at pos 1, which is blocked by a tiger
    jax.debug.print("State after step in Test 6: {}", state) # Add debug print here
    assert state.terminated, "State should be terminated due to tigers being blocked"
    assert state.rewards[0] == 1.0, "Player 0 (Goat) should have a reward of 1"
    assert state.rewards[1] == -1.0, "Player 1 (Tiger) should have a reward of -1"
    print("✅ SUCCESS: Goat win condition (tigers blocked) is detected correctly.")

    # Test 7: Tiger Win Condition (Goats Captured)
    print("\n--- Test 7: Tiger Win Condition ---")
    state = env.init(key).replace(goats_captured=jnp.int32(TIGER_WIN_THRESHOLD - 1))
    # Fabricate a board state to make a capture that wins the game
    # Tiger at pos 1, goat at pos 3. Tiger jumps to pos 9.
    setup_board = state.board.at[2].set(1)
    jax.debug.print("Setup board: {}", setup_board)
    state = state.replace(board=setup_board, current_player=jnp.int32(1))
    jump_move_idx = jnp.where((MOVE_INFO[:, 0] == 1) & (MOVE_INFO[:, 1] == 9) & (MOVE_INFO[:,2] == 1) & (MOVE_INFO[:,3] == 3))[0][0]
    jax.debug.print("to_pos: {}",MOVE_INFO[jump_move_idx])
    action = jnp.int32(PLACEMENT_ACTIONS + jump_move_idx)

    state = env.step(state, action, key)
    print(f"State after tiger win step: terminated={state.terminated}, rewards={state.rewards}, goats_captured={state.goats_captured}")
    jax.debug.print("State after step in Test 7: {}", state) # Add debug print here
    assert state.goats_captured == TIGER_WIN_THRESHOLD
    assert state.terminated, f"State should be terminated after {TIGER_WIN_THRESHOLD} captures"
    assert state.rewards[0] == -1.0, "Player 0 (Goat) should have a reward of -1"
    assert state.rewards[1] == 1.0, "Player 1 (Tiger) should have a reward of 1"
    print("✅ SUCCESS: Tiger win condition (goats captured) is detected correctly.")

    # Test 8: Legal Action Mask
    print("\n--- Test 8: Legal Action Mask ---")
    state = env.init(key)
    mask = state.legal_action_mask
    # Initially, player 0 can only place goats on empty squares.
    expected_placement_mask = (state.board == 0)
    assert jnp.all(mask[:PLACEMENT_ACTIONS] == expected_placement_mask), "Placement mask incorrect"
    assert jnp.all(mask[PLACEMENT_ACTIONS:] == False), "Move actions should be illegal during placement"
    print("✅ SUCCESS: Initial legal action mask is correct.")

    print("\nAll tests passed!")

Running tests for AaduPuliAattam...

--- Test 1: pgx API Compliance ---
ℹ️  Skipping v1.make test for local file. Testing class directly.
✅ SUCCESS: Environment class is structured for pgx.

--- Test 2: Initial State ---
✅ SUCCESS: Initial state is correct.

--- Test 3: Goat Placement ---
draw: False
intermediate_state after action: State(current_player=Array(0, dtype=int32), observation=Array([ 2,  0,  0,  2,  2,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  0,  0,  0,  0, 15,  0], dtype=int32), rewards=Array([0., 0.], dtype=float32), terminated=Array(False, dtype=bool), truncated=Array(False, dtype=bool), legal_action_mask=Array([False,  True,  True, False, False,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, Fa

AssertionError: Player 0 (Goat) should have a reward of -1