<a href="https://colab.research.google.com/github/sotetsuk/pgx/blob/sotetsuk%2Fcolab%2Fupdate-check-chess/colab/check_chess.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install open_spiel pgx

Collecting open_spiel
  Downloading open_spiel-1.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (5.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.4/5.4 MB[0m [31m9.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting pgx
  Downloading pgx-1.4.0-py3-none-any.whl (413 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m413.9/413.9 kB[0m [31m35.4 MB/s[0m eta [36m0:00:00[0m
Collecting svgwrite (from pgx)
  Downloading svgwrite-1.4.3-py3-none-any.whl (67 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m67.1/67.1 kB[0m [31m6.6 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: svgwrite, open_spiel, pgx
Successfully installed open_spiel-1.3 pgx-1.4.0 svgwrite-1.4.3


In [2]:
import random
from tqdm import tqdm
import numpy as np
import jax
import jax.numpy as jnp
import pyspiel
import pgx
from pgx.chess import State

pgx.__version__



'1.4.0'

In [3]:
# Reference implementation: OpenSpiel's chess game.
game = pyspiel.load_game('chess')

# Implementation under test: the Pgx chess environment, with its two entry
# points jit-compiled once up front so every game reuses the compiled code.
env = pgx.make("chess")
init, step = (jax.jit(fn) for fn in (env.init, env.step))

def check(seed):
    """Play one random chess game in OpenSpiel and Pgx in lockstep and compare them.

    Before every move, the legal actions reported by OpenSpiel must be exactly
    the entries set in Pgx's ``legal_action_mask``. After the same randomly
    chosen action is applied to both, the FEN produced by Pgx must equal the
    one taken from OpenSpiel. The game ends when OpenSpiel reports no legal
    actions, or after 512 moves.

    Args:
        seed: Seed for NumPy's global RNG, which selects the move at each step.

    Raises:
        AssertionError: If the legal actions or the resulting position differ.
            The message contains the FEN before the move and the action
            sequence played so far, so the failure can be replayed.
    """
    np.random.seed(seed)
    spiel_state = game.new_initial_state()
    # The PRNG key is fixed: `seed` only drives the random move selection below.
    pgx_state = init(jax.random.PRNGKey(0))
    action_seq = []
    for _ in range(512):  # pgx chess terminates after 512 steps (following AZ paper)
        # First line of the debug string with its first 5 characters dropped
        # (presumably a "FEN: " label -- TODO confirm against OpenSpiel).
        fen_before = spiel_state.debug_string().splitlines()[0][5:]

        expected_legal_actions = list(spiel_state.legal_actions())
        if not expected_legal_actions:
            break

        # check legal actions
        # Copy the mask to the host once instead of indexing the JAX array per
        # action. The explicit bounds test matters: JAX clamps out-of-range
        # indices on reads, which would silently check the wrong mask entry.
        legal_mask = np.asarray(pgx_state.legal_action_mask)
        expected = np.asarray(expected_legal_actions)
        ok = (
            legal_mask.sum() == expected.size
            and ((0 <= expected) & (expected < legal_mask.shape[0])).all()
            and legal_mask[expected].all()
        )

        assert ok, f"\n{fen_before}\n{pgx_state.legal_action_mask.sum()} != {len(expected_legal_actions)}\nactual:{jnp.nonzero(pgx_state.legal_action_mask)[0]}\nexpected:{expected_legal_actions}\naction sequence: {action_seq}"

        # step by OpenSpiel
        action = np.random.choice(expected_legal_actions)
        action_seq.append(action)
        spiel_state.apply_action(action)
        fen_after = spiel_state.debug_string().splitlines()[0][5:]

        # step by Pgx
        pgx_state = step(pgx_state, jnp.int32(action))

        # check state transition
        actual_fen = pgx_state._to_fen()
        assert actual_fen == fen_after, f"\n{fen_before}\nactual:{actual_fen}\nexpected: {fen_after}\naction sequence: {action_seq}"

In [5]:
# Number of random games to replay; each game uses its index as the seed.
# The recorded full run of 1000 games took about three hours, so lower this
# for a quick smoke test.
NUM_GAMES = 1000

for seed in tqdm(range(NUM_GAMES)):
    check(seed)

100%|██████████| 1000/1000 [3:05:04<00:00, 11.10s/it]
