# Convert AaduPulliEnv (Gymnasium) to PGX Environment
This notebook demonstrates how to convert the custom Gymnasium environment `AaduPulliEnv` into a PGX-compatible environment.

In [1]:
# Install required packages (if not already installed)
!pip install gymnasium pgx numpy


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.1.1[0m[39;49m -> [0m[32;49m25.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3 -m pip install --upgrade pip[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.1.1[0m[39;49m -> [0m[32;49m25.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3 -m pip install --upgrade pip[0m



## Original AaduPulliEnv (Gymnasium) Implementation
Below is the original custom environment code for reference.

In [2]:
import gymnasium as gym
from gymnasium import spaces
import numpy as np

class AaduPulliEnv(gym.Env):
    metadata = {'render.modes': ['human']}
    def __init__(self):
        super(AaduPulliEnv, self).__init__()
        self.NUM_GOATS = 15
        self.NUM_TIGERS = 3
        self.TIGER_WIN_THRESHOLD = 10
        self.BOARD_POSITIONS = 23
        self.MAX_TURNS = 200
        self.adj = self._get_adjacency()
        self.jump_adj = self._get_jump_adjacency()
        self.placement_actions = self.BOARD_POSITIONS
        self._move_action_map, self._move_action_lookup = self._create_move_maps()
        self.move_actions_count = len(self._move_action_map)
        total_actions = self.placement_actions + self.move_actions_count
        self.action_space = spaces.Discrete(total_actions)
        self.observation_space = spaces.Dict({
            'board': spaces.Box(low=0, high=2, shape=(self.BOARD_POSITIONS,), dtype=np.int32),
            'player_turn': spaces.Discrete(2),
            'goats_to_place': spaces.Box(low=0, high=self.NUM_GOATS, shape=(1,), dtype=np.int32),
            'goats_captured': spaces.Box(low=0, high=self.TIGER_WIN_THRESHOLD, shape=(1,), dtype=np.int32),
        })
        self.board_points = self._get_board_coordinates()
        self.reset()
    def _get_adjacency(self):
        return {
            1: [3, 4, 5, 6], 2: [3, 8], 3: [1, 4, 9, 2], 4: [1, 5, 10, 3], 5: [1, 6, 11, 4], 6: [1, 7, 12, 5], 7: [6, 13],
            8: [2, 9, 14], 9: [3, 10, 15, 8], 10: [4, 11, 16, 9], 11: [5, 12, 17, 10], 12: [6, 13, 18, 11], 13: [7, 14, 12],
            14: [8, 15], 15: [9, 16, 20, 14], 16: [10, 17, 21, 15], 17: [11, 18, 22, 16], 18: [12, 19, 23, 17], 19: [13, 18],
            20: [15, 21], 21: [16, 20, 22], 22: [17, 21, 23], 23: [18, 22]
        }
    def _get_jump_adjacency(self):
        return {
            1: [9, 10, 11, 12], 2: [4, 14], 3: [5, 15], 4: [2, 6, 16], 5: [3, 7, 17], 6: [4, 18], 7: [5, 19],
            8: [10], 9: [1, 11, 20], 10: [1, 8, 12, 21], 11: [1, 9, 13, 22], 12: [1, 10, 23], 13: [11],
            14: [2, 16], 15: [3, 17], 16: [4, 14, 18], 17: [5, 15, 19], 18: [6, 16], 19: [7, 17],
            20: [9, 22], 21: [10, 23], 22: [11, 20], 23: [12, 21]
        }
    def _create_move_maps(self):
        action_map, action_lookup, index = {}, {}, 0
        for start_pos in range(1, self.BOARD_POSITIONS + 1):
            for end_pos in self.adj.get(start_pos, []):
                move = (start_pos, end_pos); action_map[index] = move; action_lookup[move] = index; index += 1
            for end_pos in self.jump_adj.get(start_pos, []):
                move = (start_pos, end_pos)
                if move not in action_lookup:
                    action_map[index] = move; action_lookup[move] = index; index += 1
        return action_map, action_lookup
    def is_action_valid(self, action):
        if not (0 <= action < self.action_space.n): return False, {'error': 'Action out of bounds.'}
        if action < self.placement_actions:
            to_idx = action
            if self.player_turn != 0 or self.goats_placed_count >= self.NUM_GOATS: return False, {'error': 'Cannot place piece now.'}
            if self.board[to_idx] != 0: return False, {'error': 'Destination square is not empty.'}
            return True, {'type': 'place', 'to_idx': to_idx}
        move_idx = action - self.placement_actions; from_pos, to_pos = self._move_action_map[move_idx]; from_idx, to_idx = from_pos - 1, to_pos - 1
        if self.board[to_idx] != 0: return False, {'error': 'Destination square is not empty.'}
        if self.player_turn == 0:
            if self.goats_placed_count < self.NUM_GOATS: return False, {'error': 'Goat is still in placement phase.'}
            if self.board[from_idx] != 1: return False, {'error': 'Player must move a goat.'}
            if to_pos not in self.adj.get(from_pos, []): return False, {'error': 'Goat can only move to adjacent squares.'}
            return True, {'type': 'move', 'from_idx': from_idx, 'to_idx': to_idx}
        else:
            if self.board[from_idx] != 2: return False, {'error': 'Player must move a tiger.'}
            if to_pos in self.adj.get(from_pos, []): return True, {'type': 'move', 'from_idx': from_idx, 'to_idx': to_idx, 'is_jump': False}
            if to_pos in self.jump_adj.get(from_pos, []):
                from_neighbors = set(self.adj.get(from_pos, [])); to_neighbors = set(self.adj.get(to_pos, [])); mid_pos_set = from_neighbors.intersection(to_neighbors)
                if mid_pos_set:
                    mid_pos = mid_pos_set.pop()
                    if self.board[mid_pos - 1] == 1: return True, {'type': 'move', 'from_idx': from_idx, 'to_idx': to_idx, 'is_jump': True, 'mid_idx': mid_pos - 1}
            return False, {'error': 'Invalid tiger move.'}
    def _are_tigers_blocked(self):
        for t_idx in np.where(self.board == 2)[0]:
            t_pos = t_idx + 1
            for dest_pos in self.adj.get(t_pos, []):
                if self.board[dest_pos - 1] == 0: return False
            for dest_pos in self.jump_adj.get(t_pos, []):
                if self.board[dest_pos - 1] == 0:
                    from_neighbors = set(self.adj.get(t_pos, [])); to_neighbors = set(self.adj.get(dest_pos, [])); mid_pos_set = from_neighbors.intersection(to_neighbors)
                    if mid_pos_set and self.board[mid_pos_set.pop() - 1] == 1: return False
        return True
    def _get_current_observation(self):
        return {"board":self.board.copy(),"player_turn":self.player_turn,"goats_to_place":np.array([self.NUM_GOATS-self.goats_placed_count],dtype=np.int32),"goats_captured":np.array([self.goats_captured_count],dtype=np.int32)}
    def reset(self):
        self.board=np.zeros(self.BOARD_POSITIONS,dtype=np.int32); self.board[0]=2; self.board[3]=2; self.board[4]=2;self.player_turn=0; self.goats_placed_count=0; self.goats_captured_count=0; self.turn_count=0
        return self._get_current_observation()
    def step(self, action):
        is_valid, details = self.is_action_valid(action);reward, done, info = 0, False, details
        if not is_valid: reward = -1
        else:
            if details['type']=='place': self.board[details['to_idx']]=1; self.goats_placed_count+=1
            elif details['type']=='move':
                p=self.board[details['from_idx']]; self.board[details['from_idx']]=0; self.board[details['to_idx']]=p
                if p==2 and details.get('is_jump'): self.board[details['mid_idx']]=0; self.goats_captured_count+=1; reward=5
            g_win=self._are_tigers_blocked(); t_win=self.goats_captured_count>=self.TIGER_WIN_THRESHOLD
            if g_win: done=True; reward=100 if self.player_turn == 0 else -100; info['winner']=0
            elif t_win: done=True; reward=100 if self.player_turn == 1 else -100; info['winner']=1
            self.player_turn=1-self.player_turn; self.turn_count+=1
        if not done and self.turn_count>=self.MAX_TURNS: done=True; info['winner']=-1
        return self._get_current_observation(), reward, done, info
    def _get_board_coordinates(self):
        return {1:(12,20), 2:(1,16),3:(9.2,16),4:(10.7,16),5:(13.3,16),6:(14.8,16),7:(23,16), 8:(1,12),9:(6.5,12),10:(9.5,12),11:(14.5,12),12:(17.5,12),13:(23,12), 14:(1,8),15:(3.8,8),16:(8.3,8),17:(15.7,8),18:(20.3,8),19:(23,8), 20:(1,4),21:(7,4),22:(17,4),23:(23,4)}
    def copy(self):
        new_env = AaduPulliEnv(); new_env.board = self.board.copy(); new_env.player_turn = self.player_turn; new_env.goats_placed_count = self.goats_placed_count; new_env.goats_captured_count = self.goats_captured_count; new_env.turn_count = self.turn_count
        return new_env

## PGX Environment Overview
PGX is a fast RL environment library for board games. To convert a Gymnasium environment to PGX, we need to implement the PGX API: `step`, `reset`, `legal_actions`, `current_player`, and `state_to_tensor`.

In [3]:
import pgx
import numpy as np
class AaduPulliPGXEnv(pgx.Env):
    def id(self):
        return "aadupulli"
    def version(self):
        return "v0"
    def num_players(self):
        return 2
    def _init(self):
        self.NUM_GOATS = 15
        self.NUM_TIGERS = 3
        self.TIGER_WIN_THRESHOLD = 10
        self.BOARD_POSITIONS = 23
        self.MAX_TURNS = 200
        self.adj = self._get_adjacency()
        self.jump_adj = self._get_jump_adjacency()
        self.placement_actions = self.BOARD_POSITIONS
        self._move_action_map, self._move_action_lookup = self._create_move_maps()
        self.move_actions_count = len(self._move_action_map)
        self.total_actions = self.placement_actions + self.move_actions_count
        self.reset()
        return self._get_state()
    def _observe(self, state, player):
        return state['observation']
    def _step(self, state, action):
        self.board = state['observation']['board']
        self.player_turn = state['observation']['player_turn']
        self.goats_placed_count = self.NUM_GOATS - state['observation']['goats_to_place']
        self.goats_captured_count = state['observation']['goats_captured']
        self.turn_count = state.get('turn_count', 0)
        next_state, reward, done, info = self._step_game(action)
        return next_state, reward, done, info
    def _step_game(self, action):
        is_valid, details = self.is_action_valid(action)
        reward, done, info = 0, False, details
        if not is_valid:
            reward = -1
        else:
            if details['type'] == 'place':
                self.board[details['to_idx']] = 1; self.goats_placed_count += 1
            elif details['type'] == 'move':
                p = self.board[details['from_idx']]; self.board[details['from_idx']] = 0; self.board[details['to_idx']] = p
                if p == 2 and details.get('is_jump'):
                    self.board[details['mid_idx']] = 0; self.goats_captured_count += 1; reward = 5
            g_win = self._are_tigers_blocked(); t_win = self.goats_captured_count >= self.TIGER_WIN_THRESHOLD
            if g_win: done = True; reward = 100 if self.player_turn == 0 else -100; info['winner'] = 0
            elif t_win: done = True; reward = 100 if self.player_turn == 1 else -100; info['winner'] = 1
            self.player_turn = 1 - self.player_turn; self.turn_count += 1
        if not done and self.turn_count >= self.MAX_TURNS: done = True; info['winner'] = -1
        return self._get_state(), reward, done, info
    def _get_adjacency(self):
        # ...same as Gym env...
        return {
            1: [3, 4, 5, 6], 2: [3, 8], 3: [1, 4, 9, 2], 4: [1, 5, 10, 3], 5: [1, 6, 11, 4], 6: [1, 7, 12, 5], 7: [6, 13],
            8: [2, 9, 14], 9: [3, 10, 15, 8], 10: [4, 11, 16, 9], 11: [5, 12, 17, 10], 12: [6, 13, 18, 11], 13: [7, 14, 12],
            14: [8, 15], 15: [9, 16, 20, 14], 16: [10, 17, 21, 15], 17: [11, 18, 22, 16], 18: [12, 19, 23, 17], 19: [13, 18],
            20: [15, 21], 21: [16, 20, 22], 22: [17, 21, 23], 23: [18, 22]
        }
    def _get_jump_adjacency(self):
        # ...same as Gym env...
        return {
            1: [9, 10, 11, 12], 2: [4, 14], 3: [5, 15], 4: [2, 6, 16], 5: [3, 7, 17], 6: [4, 18], 7: [5, 19],
            8: [10], 9: [1, 11, 20], 10: [1, 8, 12, 21], 11: [1, 9, 13, 22], 12: [1, 10, 23], 13: [11],
            14: [2, 16], 15: [3, 17], 16: [4, 14, 18], 17: [5, 15, 19], 18: [6, 16], 19: [7, 17],
            20: [9, 22], 21: [10, 23], 22: [11, 20], 23: [12, 21]
        }
    def _create_move_maps(self):
        action_map, action_lookup, index = {}, {}, 0
        for start_pos in range(1, self.BOARD_POSITIONS + 1):
            for end_pos in self.adj.get(start_pos, []):
                move = (start_pos, end_pos); action_map[index] = move; action_lookup[move] = index; index += 1
            for end_pos in self.jump_adj.get(start_pos, []):
                move = (start_pos, end_pos)
                if move not in action_lookup:
                    action_map[index] = move; action_lookup[move] = index; index += 1
        return action_map, action_lookup
    def reset(self):
        self.board = np.zeros(self.BOARD_POSITIONS, dtype=np.int32)
        self.board[0] = 2; self.board[3] = 2; self.board[4] = 2
        self.player_turn = 0
        self.goats_placed_count = 0
        self.goats_captured_count = 0
        self.turn_count = 0
        return self._get_state()
    def is_action_valid(self, action):
        # ...same as Gym env...
        if not (0 <= action < self.total_actions): return False, {'error': 'Action out of bounds.'}
        if action < self.placement_actions:
            to_idx = action
            if self.player_turn != 0 or self.goats_placed_count >= self.NUM_GOATS: return False, {'error': 'Cannot place piece now.'}
            if self.board[to_idx] != 0: return False, {'error': 'Destination square is not empty.'}
            return True, {'type': 'place', 'to_idx': to_idx}
        move_idx = action - self.placement_actions; from_pos, to_pos = self._move_action_map[move_idx]; from_idx, to_idx = from_pos - 1, to_pos - 1
        if self.board[to_idx] != 0: return False, {'error': 'Destination square is not empty.'}
        if self.player_turn == 0:
            if self.goats_placed_count < self.NUM_GOATS: return False, {'error': 'Goat is still in placement phase.'}
            if self.board[from_idx] != 1: return False, {'error': 'Player must move a goat.'}
            if to_pos not in self.adj.get(from_pos, []): return False, {'error': 'Goat can only move to adjacent squares.'}
            return True, {'type': 'move', 'from_idx': from_idx, 'to_idx': to_idx}
        else:
            if self.board[from_idx] != 2: return False, {'error': 'Player must move a tiger.'}
            if to_pos in self.adj.get(from_pos, []): return True, {'type': 'move', 'from_idx': from_idx, 'to_idx': to_idx, 'is_jump': False}
            if to_pos in self.jump_adj.get(from_pos, []):
                from_neighbors = set(self.adj.get(from_pos, [])); to_neighbors = set(self.adj.get(to_pos, [])); mid_pos_set = from_neighbors.intersection(to_neighbors)
                if mid_pos_set:
                    mid_pos = mid_pos_set.pop()
                    if self.board[mid_pos - 1] == 1: return True, {'type': 'move', 'from_idx': from_idx, 'to_idx': to_idx, 'is_jump': True, 'mid_idx': mid_pos - 1}
            return False, {'error': 'Invalid tiger move.'}
    def _are_tigers_blocked(self):
        # ...same as Gym env...
        for t_idx in np.where(self.board == 2)[0]:
            t_pos = t_idx + 1
            for dest_pos in self.adj.get(t_pos, []):
                if self.board[dest_pos - 1] == 0: return False
            for dest_pos in self.jump_adj.get(t_pos, []):
                if self.board[dest_pos - 1] == 0:
                    from_neighbors = set(self.adj.get(t_pos, [])); to_neighbors = set(self.adj.get(dest_pos, [])); mid_pos_set = from_neighbors.intersection(to_neighbors)
                    if mid_pos_set and self.board[mid_pos_set.pop() - 1] == 1: return False
        return True
    def _get_state(self):
        # PGX expects a dict with 'observation', 'current_player', 'legal_actions', etc.
        obs = {
            'board': self.board.copy(),
            'player_turn': self.player_turn,
            'goats_to_place': self.NUM_GOATS - self.goats_placed_count,
            'goats_captured': self.goats_captured_count
        }
        return {
            'observation': obs,
            'current_player': self.player_turn,
            'legal_actions': self.legal_actions(),
        }
    def legal_actions(self):
        actions = []
        for a in range(self.total_actions):
            valid, _ = self.is_action_valid(a)
            if valid:
                actions.append(a)
        return actions
    def current_player(self):
        return self.player_turn
    def state_to_tensor(self):
        # Example: flatten board and add player_turn, goats_to_place, goats_captured
        return np.concatenate([self.board, [self.player_turn, self.NUM_GOATS - self.goats_placed_count, self.goats_captured_count]]).astype(np.float32)

## Usage Example
Instantiate and interact with the PGX environment.

In [4]:
import pgx
import numpy as np



env = AaduPulliPGXEnv()
state = env._init()
print('Initial state:', state)
actions = env.legal_actions()
print('Legal actions:', actions)
next_state, reward, done, info = env._step(state, actions[0])
print('Next state:', next_state)
print('Reward:', reward, 'Done:', done, 'Info:', info)

Initial state: {'observation': {'board': array([2, 0, 0, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0], dtype=int32), 'player_turn': 0, 'goats_to_place': 15, 'goats_captured': 0}, 'current_player': 0, 'legal_actions': [1, 2, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22]}
Legal actions: [1, 2, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22]
Next state: {'observation': {'board': array([2, 1, 0, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0], dtype=int32), 'player_turn': 1, 'goats_to_place': 14, 'goats_captured': 0}, 'current_player': 1, 'legal_actions': [23, 26, 43, 44, 49, 50]}
Reward: 0 Done: False Info: {'type': 'place', 'to_idx': 1}


In [5]:
# Example: Training AlphaZero on AaduPulliPGXEnv using PGX
# This assumes you have PGX's AlphaZero implementation installed and accessible.
# If PGX's alphazero.py script is available, you can use it like this:

from pgx.alphazero import train_alphazero

# Make sure AaduPulliPGXEnv is imported or defined in the same script/notebook

env = AaduPulliPGXEnv()

# Set AlphaZero hyperparameters (adjust as needed for your game)
config = {
    'env': env,
    'num_simulations': 50,
    'num_training_steps': 1000,
    'batch_size': 32,
    'learning_rate': 1e-3,
    'model': 'mlp',  # or 'resnet' if supported
    # Add more config options as needed
}

# Start training (this may take time and resources)
train_alphazero(**config)

# For more advanced usage, refer to PGX's AlphaZero documentation and scripts.

ModuleNotFoundError: No module named 'pgx.alphazero'