# Example: Improving LLM Chess Ability with Best/Mixture-of-N Sampling


## Setup

In [None]:
import asyncio
import random
from copy import deepcopy
from typing import Dict, List, Optional, Tuple
from uuid import UUID

import altair as alt
import chess
import chess.engine
import chess.pgn
import pandas as pd
from tensorzero import AsyncTensorZeroGateway
from tqdm import tqdm, trange

We're going to evaluate the following variants in this notebook. Feel free to create your own variants and evaluate them!

In [None]:
# Names of the TensorZero variants of `play_chess_board` evaluated below.
VARIANTS = ["baseline", "best_of_5", "mixture_of_5"]

Let's load the training set and keep only the first `NUM_EXAMPLES` (1000) puzzles.

For a quick trial, you can use `NUM_EXAMPLES = 10`.


In [None]:
# How many puzzles to evaluate; a small value such as 10 makes for a quick trial.
NUM_EXAMPLES = 1000

# Load the training split and keep only its first NUM_EXAMPLES rows.
puzzle_df = pd.read_csv("data/lichess_easy_puzzles_train.csv").head(NUM_EXAMPLES)

We'll solve the puzzles concurrently to speed up the evaluation.
Reduce `MAX_CONCURRENT_T0_REQUESTS` below if you're getting rate-limited by the API providers.

In [None]:
# Cap on concurrent requests to the TensorZero Gateway; lower it if the
# model providers start rate-limiting.
MAX_CONCURRENT_T0_REQUESTS = 100

# Shared by every inference and feedback call made from this notebook.
semaphore = asyncio.Semaphore(value=MAX_CONCURRENT_T0_REQUESTS)

Let's initialize the client for the TensorZero Gateway.

In [None]:
t0 = await AsyncTensorZeroGateway.build_http(
    gateway_url="http://localhost:3000", timeout=60
)

## Helper Functions: Solving Chess Puzzles

Here we define a helper function to predict the next move for a given variant.

In [None]:
async def predict_next_move(
    board: chess.Board, variant_name: str, episode_id: Optional[UUID] = None
) -> Tuple[str, Optional[UUID]]:
    """
    Ask the TensorZero Gateway which move to play in the given position.

    The board, the side to move, and the list of legal moves (in SAN) are sent to
    the `play_chess_board` function using the requested variant. If the request
    raises, or the response carries no parsed output or no "move" field, an error
    is printed and a uniformly random legal move is returned instead.

    Args:
        board (chess.Board): Position to move from; it is only read, never modified.
        variant_name (str): Variant to use for the inference (e.g., "baseline").
        episode_id (Optional[UUID], optional): Episode to attach this inference to,
            so successive moves of one puzzle share an episode. Defaults to None.

    Returns:
        Tuple[str, Optional[UUID]]:
            - str: The chosen move (as returned by the model, or a random SAN move).
            - UUID: The episode ID from the response, or the incoming ``episode_id``
              if the request itself failed.
    """
    candidate_moves = [board.san(legal_move) for legal_move in board.legal_moves]
    side_to_move = "white" if board.turn else "black"
    user_message = {
        "role": "user",
        "content": {
            "board": str(board),
            "color": side_to_move,
            "legal_moves_san": candidate_moves,
        },
    }

    try:
        response = await t0.inference(
            function_name="play_chess_board",
            input={"messages": [user_message]},
            variant_name=variant_name,
            episode_id=episode_id,
        )
    except Exception as e:
        print(f"Error occurred: {type(e).__name__}: {e}")
        return random.choice(candidate_moves), episode_id

    parsed_output = response.output.parsed
    if parsed_output is None:
        print("Error: TensorZero returned no parsed output.")
        return random.choice(candidate_moves), response.episode_id

    predicted_move = parsed_output.get("move")
    if predicted_move is None:
        print("Error: TensorZero returned no move.")
        return random.choice(candidate_moves), response.episode_id

    return predicted_move, response.episode_id

Next we define a helper function to solve a single puzzle.

In [None]:
def _parse_expected_move(board: "chess.Board", move_text: str) -> "chess.Move":
    """
    Parse one move from a puzzle's "Moves" list in the context of ``board``.

    The text is parsed as SAN first and, if that fails, as UCI. A UCI move is
    checked for legality so an illegal move is never pushed onto the board.

    Raises:
        ValueError: If the move cannot be parsed or is illegal in this position.
    """
    try:
        return board.parse_san(move_text)
    except ValueError:
        move = chess.Move.from_uci(move_text)
        if move not in board.legal_moves:
            raise ValueError(
                f"Illegal expected move {move_text!r} in position {board.fen()}"
            )
        return move


async def solve_puzzle(
    puzzle_data: Dict, variant_name: Optional[str] = None
) -> Tuple[bool, Optional[UUID]]:
    """
    Play through one chess puzzle with the model and report whether it was solved.

    The first expected move is applied before the model plays (as expected by the
    benchmark); the side to move afterwards is the player. From then on, the
    opponent's expected moves are replayed, and on each player turn the model's move
    (from ``predict_next_move``) must either deliver checkmate or match the expected
    move. The puzzle is solved when every expected move has been played, or as soon
    as the player checkmates.

    Args:
        puzzle_data (Dict): Puzzle details including:
            - "FEN": The FEN of the starting position.
            - "Moves": Space-separated expected moves (SAN or UCI).
        variant_name (Optional[str]): Variant used for prediction. When omitted, the
            notebook-level ``variant_name`` global is used (this function's original
            behavior, relied on by callers that do not pass it).

    Returns:
        Tuple[bool, Optional[UUID]]:
            - bool: True if the player solves the puzzle, False otherwise.
            - UUID: Episode ID of the attempt (for feedback), or None if the model
              was never queried.

    Raises:
        NameError: If ``variant_name`` is omitted and no global ``variant_name`` exists.
        ValueError: If an expected move in the puzzle data cannot be parsed.
    """
    if variant_name is None:
        # Backward compatibility: this function used to read the global implicitly.
        variant_name = globals().get("variant_name")
        if variant_name is None:
            raise NameError(
                "variant_name was not provided and no global 'variant_name' is defined"
            )

    expected_moves = puzzle_data.get("Moves", "").split()
    total_moves = len(expected_moves)
    if total_moves == 0:
        # Nothing to solve; previously this crashed with an IndexError.
        return False, None

    board = chess.Board(puzzle_data.get("FEN"))

    # Apply the first move before starting the puzzle (as expected by the benchmark)
    board.push(_parse_expected_move(board, expected_moves[0]))

    # The side to move after the setup move is the player.
    player_color = board.turn  # True for White, False for Black
    episode_id: Optional[UUID] = None
    move_index = 1

    while move_index < total_moves and not board.is_game_over():
        expected_move_obj = _parse_expected_move(board, expected_moves[move_index])

        if board.turn != player_color:  # Opponent's move: replay the solution
            board.push(expected_move_obj)
            move_index += 1
            continue

        # Player's move: ask the model.
        async with semaphore:
            player_move_san, episode_id = await predict_next_move(
                deepcopy(board), variant_name, episode_id
            )

        try:
            player_move_obj = board.parse_san(str(player_move_san))
        except ValueError:
            return False, episode_id

        # Check for mate *after* playing the move: any mating move solves the
        # puzzle, even if it differs from the expected one. (Checking before the
        # push inspected the player's own side and could never be true.)
        board.push(player_move_obj)
        if board.is_checkmate():
            return True, episode_id

        if player_move_obj != expected_move_obj:
            return False, episode_id

        move_index += 1

    # Solved only if the whole expected line was played out.
    return move_index == total_moves, episode_id

Finally, we define a function to solve all the puzzles in the dataframe.

In [None]:
async def solve_puzzles(
    puzzle_df: pd.DataFrame,
    variant_name: str,
) -> List[bool]:
    """
    Solve a batch of chess puzzles concurrently, then send success feedback.

    Every row of ``puzzle_df`` is attempted concurrently with ``solve_puzzle``;
    progress and the running success count are shown in a progress bar. Afterwards,
    a "puzzle_success" feedback is sent for every attempt that produced an episode ID.

    Args:
        puzzle_df (pd.DataFrame): DataFrame containing chess puzzles with FEN positions and expected moves.
        variant_name (str): The name of the variant to use for prediction (e.g., "baseline", "best_of_5").

    Returns:
        List[bool]: One entry per row of ``puzzle_df``, in row order: True if the
        puzzle was solved, False otherwise.
    """
    total_puzzles = len(puzzle_df)
    successes: List[bool] = [False] * total_puzzles
    episode_ids: List[Optional[UUID]] = [None] * total_puzzles

    async def _solve_indexed(index: int) -> Tuple[int, bool, Optional[UUID]]:
        # Carry the row index through: asyncio.as_completed yields in completion
        # order, and results must be stored in puzzle order.
        # NOTE(review): variant_name is not passed to solve_puzzle here; it picks the
        # variant from the module-level `variant_name`, which must match this
        # argument (the evaluation loop guarantees that).
        success, episode_id = await solve_puzzle(puzzle_df.iloc[index].to_dict())
        return index, success, episode_id

    tasks = [asyncio.create_task(_solve_indexed(i)) for i in range(total_puzzles)]

    num_successes = 0
    with tqdm(total=total_puzzles, desc=f"[Inference] {variant_name}") as progress_bar:
        for num_done, task in enumerate(asyncio.as_completed(tasks), start=1):
            index, success, episode_id = await task
            successes[index] = success
            episode_ids[index] = episode_id
            if success:
                num_successes += 1
            progress_bar.update(1)
            progress_bar.set_postfix(
                {"Success": f"{num_successes}/{num_done}"},
                refresh=True,
            )

    async def _send_feedback(episode_id: UUID, success: bool) -> None:
        # Feedback calls run concurrently (bounded by the shared semaphore); a
        # failure is reported rather than abandoning the other in-flight calls.
        async with semaphore:
            try:
                await t0.feedback(
                    episode_id=episode_id,
                    metric_name="puzzle_success",
                    value=success,
                )
            except Exception as e:
                print(f"Error occurred: {type(e).__name__}: {e}")

    feedback_tasks = [
        asyncio.create_task(_send_feedback(episode_id, success))
        for success, episode_id in zip(successes, episode_ids)
        if episode_id is not None
    ]
    for task in tqdm(
        asyncio.as_completed(feedback_tasks),
        total=len(feedback_tasks),
        desc=f"[Feedback] {variant_name}",
    ):
        await task

    return successes

## Evaluate the Variants

In [None]:
# Evaluate each variant on the same puzzles; maps variant name -> the list of
# per-puzzle successes returned by solve_puzzles.
results = {}

# NOTE: keep the loop variable named `variant_name`. solve_puzzle reads the
# module-level `variant_name` to choose which variant to call, so this loop is
# what sets it for each run. Variants are evaluated one after another.
for variant_name in VARIANTS:
    results[variant_name] = await solve_puzzles(
        puzzle_df,
        variant_name,
    )

## Plot Results

In [None]:
# Format the results for plotting: one row per (puzzle, variant) pair.
results_df = pd.DataFrame(results).stack().reset_index()
results_df.columns = ["Puzzle", "Variant", "Success Rate"]
# Successes are booleans; cast to float so the mean and CI aggregations are numeric.
results_df["Success Rate"] = results_df["Success Rate"].astype(float)
results_df["Mean Success Rate by Variant"] = results_df.groupby("Variant")[
    "Success Rate"
].transform("mean")
# Append each variant's mean success rate to its label, e.g. "baseline (42.0%)".
results_df["Variant"] = [
    f"{variant} ({mean_rate * 100:.1f}%)"
    for variant, mean_rate in zip(
        results_df["Variant"], results_df["Mean Success Rate by Variant"]
    )
]

# Plot the results
chart = (
    alt.Chart(results_df)
    .encode(
        x=alt.X(
            "mean(Success Rate):Q",
            axis=alt.Axis(format="%"),
            scale=alt.Scale(domain=[0, 1]),
        ),
        y=alt.Y("Variant:N", sort=None),
    )
    .mark_bar()
)

# Confidence intervals of the success rate; use the same y ordering as the bars
# so the layered chart does not get conflicting sort settings.
error_bars = (
    alt.Chart(results_df)
    .mark_errorbar(extent="ci")
    .encode(
        x=alt.X("Success Rate:Q"),
        y=alt.Y("Variant:N", sort=None),
    )
)

(chart + error_bars).properties(title="Success Rate by Variant")