In [None]:
import sys
import os
sys.path.append(os.path.abspath('..'))

import gymnasium as gym
import chess_gym

# Build the chess environment by its Gymnasium id (presumably registered as a
# side effect of importing chess_gym above -- TODO confirm) and reset it.
env = gym.make("Chess-v0")
env.reset()
# The board object is read from the environment's action space.
board = env.action_space.board

# Render the board in the notebook output as a quick sanity check.
display(board)

### Load Policy Network

In [None]:
import torch
import torch.nn as nn
import hydra
from omegaconf import DictConfig, OmegaConf
from MCTS.models.network import ChessNetwork
from chess_gym.chess_custom import FullyTrackedBoard

# Load the training config straight from its YAML file (no hydra decorator needed).
checkpoint_path = "../checkpoints/model.pth"
cfg = OmegaConf.load("../config/train_mcts.yaml")
# Prefer CUDA, then Apple MPS, and fall back to CPU.
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
print(device)

# Initialize network with the hyper-parameters stored under cfg.network
model = ChessNetwork(
    input_channels=cfg.network.input_channels,
    dim_piece_type=cfg.network.dim_piece_type,
    board_size=cfg.network.board_size,
    num_residual_layers=cfg.network.num_residual_layers,
    num_filters=cfg.network.num_filters,
    conv_blocks_channel_lists=cfg.network.conv_blocks_channel_lists,
    action_space_size=cfg.network.action_space_size,
    num_pieces=cfg.network.num_pieces,
    value_head_hidden_size=cfg.network.value_head_hidden_size
)
# Restore the weights from the checkpoint's 'model_state_dict' entry. Tensors are
# mapped to CPU first; the model is moved to `device` afterwards.
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a trusted source.
model.load_state_dict(torch.load(checkpoint_path, map_location='cpu')['model_state_dict'])
print('Model loaded!')
model.eval()  # Set model to evaluation mode
model.to(device)

def execute_policy(model: nn.Module, board: FullyTrackedBoard, device=device, deterministic=True):
    """
    Pick an action id for the current position from the network's policy output.

    The board is encoded with ``board.get_board_vector()``, given a batch
    dimension of 1, and passed through ``model``. No legality masking is
    applied here, so the returned id may refer to an illegal move; callers
    that care must check it against the board's legal actions themselves.

    Args:
        model: Network whose forward pass returns a pair; the first element is
            used as the policy logits and the second is ignored.
        board: Board exposing ``get_board_vector()`` (a NumPy array).
        device: Device the observation tensor is moved to. Defaults to the
            module-level ``device`` as it was when this function was defined.
            The model is not moved by this function, so it is expected to
            already live on the same device.
        deterministic: If True, take the arg-max of the policy logits;
            otherwise sample one action from the softmax distribution.

    Returns:
        tuple: ``(action_id, policy_probs)`` where ``action_id`` is the chosen
        policy index plus one (an ``int``) and ``policy_probs`` is the 1-D
        softmax tensor over the policy logits.
    """
    model.eval()  # Make sure layers such as dropout/batch-norm run in inference mode
    observation = torch.from_numpy(board.get_board_vector()).unsqueeze(0).to(device=device, dtype=torch.float32)

    # Forward pass without building an autograd graph
    with torch.no_grad():
        policy, _ = model(observation)

    # Flatten the single-sample output to a 1-D logit vector. A bare squeeze()
    # would also collapse a one-action policy to a 0-dim tensor, which
    # torch.multinomial is expected to reject.
    policy = policy.reshape(-1)

    # Convert policy logits to probabilities
    policy_probs = torch.softmax(policy, dim=0)

    if deterministic:
        # Greedy choice: highest logit (action ids are index + 1)
        action_id = policy.argmax().item() + 1
    else:
        # Stochastic choice: sample one index according to the probabilities
        action_id = torch.multinomial(policy_probs, num_samples=1).item() + 1

    return action_id, policy_probs


In [3]:
import numpy as np

def print_game_result(board, reward, terminated, truncated):
    """Print a short summary of how a finished game ended.

    Nothing is printed unless ``terminated`` or ``truncated`` is truthy.
    A foul reported by ``board.is_foul()`` is announced first; after that
    exactly one outcome line is printed, chosen by the first matching rule:
    reward of 1, reward of -1, stalemate, insufficient material, claimable
    draw, truncation, and finally a generic draw message.

    Args:
        board: Board exposing ``is_foul``, ``is_stalemate``,
            ``is_insufficient_material`` and ``can_claim_draw``.
        reward: Final reward; 1 is reported as a white win, -1 as a black win.
        terminated: Whether the episode terminated.
        truncated: Whether the episode was truncated.
    """
    if not (terminated or truncated):
        return

    if board.is_foul():
        print("Game ended in a foul!")

    # Ordered (predicate, message) rules. Predicates are wrapped in lambdas so
    # that board queries only run when every earlier rule has failed.
    outcome_rules = (
        (lambda: reward == 1, "White wins!"),
        (lambda: reward == -1, "Black wins!"),
        (lambda: board.is_stalemate(), "It's a stalemate!"),
        (lambda: board.is_insufficient_material(), "It's a draw due to insufficient material!"),
        (lambda: board.can_claim_draw(), "It's a draw by repetition or 50-move rule!"),
        (lambda: truncated, "Game truncated."),
        (lambda: True, "Game ended in a draw (other reason)."),
    )
    print(next(message for applies, message in outcome_rules if applies()))


### Simulation with the Policy Network

In [None]:
import gymnasium as gym
from utils.policy_human import sample_action
from utils.visualize import board_to_svg, display_svgs_horizontally, visualize_policy_on_board, draw_possible_actions_on_board
from IPython.display import SVG
from utils.analyze import interpret_action

# Fresh environment for the simulated game.
env = gym.make("Chess-v0", render_mode='rgb_array', show_possible_actions=False)
env.reset()
board = env.action_space.board

terminated = False
truncated = False
step_count = 0
board_size = 500
deterministic = False  # False: sample each move from the policy; True: always take the arg-max

# Display initial board state
initial_svg = board_to_svg(board, size=board_size)
display(SVG(initial_svg))

# Boards are displayed two steps at a time: the first step of each pair is
# buffered in the `last_*` variables and shown together with the second one.
last_svg_str = None
last_log_str = ''
last_policy_probs = None

print("Starting Game...")

try:
    while not terminated and not truncated:
        # Ask the network for an action id and the full policy distribution.
        action, policy_probs = execute_policy(model, board, deterministic=deterministic)

        # Log the move before it is applied; an id outside board.legal_actions
        # cannot be rendered with board.san(), so it is described via interpret_action.
        if action not in board.legal_actions:
            action_info = interpret_action(action)
            if action_info:
                last_log_str += f"Step {step_count + 1}: Action {action} (Color: {action_info['piece_color']}, Piece: {action_info['piece_type_str']}, Move: {action_info['move_type']})\n"
            else:
                last_log_str += f"Step {step_count + 1}: Action {action} (Invalid/Unknown Action)\n"
        else:
            last_log_str += f"Step {step_count + 1}: Action {action} ({board.san(action)})\n"

        # Visualise the policy on the pre-move board, then apply the action.
        current_policy_probs = visualize_policy_on_board(board, policy_probs, board_size=board_size)
        observation, reward, terminated, truncated, info = env.step(action)
        step_count += 1
        current_svg_str = board_to_svg(board, size=board_size)

        if last_svg_str:
            # Second step of a pair: show both steps side by side and reset the buffer.
            print("Displaying boards from previous and current step:")
            display_svgs_horizontally([last_policy_probs, last_svg_str, current_policy_probs, current_svg_str])
            print(last_log_str)
            last_svg_str = None
            last_log_str = ''
            last_policy_probs = None
        else:
            # First step of a pair: buffer it until the next step arrives.
            last_svg_str = current_svg_str
            last_policy_probs = current_policy_probs

    # The game ended on the first step of a pair: flush the buffered step on its own.
    if last_svg_str:
        print("Displaying final board state:")
        display_svgs_horizontally([last_policy_probs, last_svg_str])
        print(last_log_str)

    print_game_result(board, reward, terminated, truncated)
finally:
    # Release the environment even if a step or a rendering call raised.
    env.close()
