In [None]:
import os

import chess
import chess.svg
from jax import random as jrandom
import numpy as np

In [None]:
from searchless_chess.src import tokenizer
from searchless_chess.src import training_utils
from searchless_chess.src import transformer
from searchless_chess.src import utils
from searchless_chess.src.engines import engine
from searchless_chess.src.engines import neural_engines

In [None]:
# @title Create the predictor.

policy = 'action_value'
num_return_buckets = 128

# Pick the size of the model's output for the selected policy: both value
# policies use `num_return_buckets` outputs, behavioral cloning uses
# `utils.NUM_ACTIONS` outputs. Anything else is rejected up front.
if policy in ('action_value', 'state_value'):
  output_size = num_return_buckets
elif policy == 'behavioral_cloning':
  output_size = utils.NUM_ACTIONS
else:
  raise ValueError(f'Unknown policy: {policy}')

# Architecture hyper-parameters. NOTE(review): these presumably have to match
# the checkpoint that is loaded later in the notebook -- confirm before
# changing any of them.
predictor_config = transformer.TransformerConfig(
    vocab_size=utils.NUM_ACTIONS,
    output_size=output_size,
    pos_encodings=transformer.PositionalEncodings.LEARNED,
    max_sequence_length=tokenizer.SEQUENCE_LENGTH + 2,
    num_heads=4,
    num_layers=4,
    embedding_dim=64,
    apply_post_ln=True,
    apply_qk_layernorm=False,
    use_causal_mask=False,
)

predictor = transformer.build_transformer_predictor(config=predictor_config)

In [None]:
# @title Load the predictor parameters

# Checkpoint location, resolved relative to the current working directory.
checkpoint_dir = os.path.join(
    os.getcwd(),
    '../checkpoints/local/action_value/',
)
# Freshly initialised parameters, handed to the loader below as `params`
# (presumably as a structural template for the checkpoint -- TODO confirm).
dummy_params = predictor.initial_params(
    rng=jrandom.PRNGKey(0),
    targets=np.zeros((1, 1), dtype=np.uint32),
)
# `params` must stay bound to the loaded checkpoint: binding it to
# `dummy_params` instead would make every later cell run the engine on the
# fresh initialisation rather than the trained weights.
# NOTE(review): `step=-1` presumably selects the most recent checkpoint --
# confirm against `training_utils.load_parameters`.
params = training_utils.load_parameters(
    checkpoint_dir=checkpoint_dir,
    params=dummy_params,
    use_ema_params=True,
    step=-1,
)

In [None]:
# @title Create the engine

# Only the second element returned by the helper (the bucket values) is used;
# the first one is discarded.
_, return_buckets_values = utils.get_uniform_buckets_edges_values(
    num_return_buckets
)

# Wrap the predictor and its parameters into a prediction callable.
predict_fn = neural_engines.wrap_predict_fn(predictor, params, batch_size=1)

# Look up the engine implementation registered for `policy` and build it.
neural_engine = neural_engines.ENGINE_FROM_POLICY[policy](
    predict_fn=predict_fn,
    return_buckets_values=return_buckets_values,
    temperature=0.005,
)

In [None]:
# @title Play a move with the agent

# Default-constructed board (per the python-chess documentation this is the
# standard starting position).
board = chess.Board()
# Ask the engine for its move on this board and show it.
best_move = neural_engine.play(board)
print(f'Best move: {best_move}')

In [None]:
# @title Compute the win percentages for all legal moves

board = chess.Board()
results = neural_engine.analyse(board)
# NOTE(review): assumes `log_probs` has one row per legal move and one column
# per return bucket -- confirm against the engine's `analyse` output.
buckets_log_probs = results['log_probs']

# Expected return: bucket probabilities weighted by the bucket values.
win_probs = np.inner(np.exp(buckets_log_probs), return_buckets_values)
# NOTE(review): presumably in the same order as the entries of `win_probs` --
# verify against `engine.get_ordered_legal_moves`.
sorted_legal_moves = engine.get_ordered_legal_moves(board)

print(board.fen())
print('Win percentages:')
# Ascending argsort reversed, i.e. moves listed from highest to lowest value.
for i in np.argsort(win_probs)[::-1]:
  move_uci = sorted_legal_moves[i].uci()
  win_percentage = 100 * win_probs[i]
  print(f'  {move_uci} -> {win_percentage:.1f}%')