In [1]:
import io
import os
import random
import time
from collections import deque

import chess
import chess.engine
import chess.pgn
import torch
from datasets import Dataset, load_dataset
from dotenv import load_dotenv
from peft import LoraConfig  # type: ignore
from pydantic import BaseModel, Field
from tqdm import tqdm

import wandb
import xverify as xv
from xverify import GuidedSchema

# Load variables from a local .env file (if present) into the environment.
load_dotenv()

# Fail fast if no Hugging Face token is configured. An empty value (for
# example a bare `HF_TOKEN=` line in .env) is as unusable as a missing one,
# so both cases are rejected.
if not os.environ.get("HF_TOKEN"):
    raise ValueError("HF_TOKEN not found! Please set")

# One UCI engine process shared by the reward functions below; it is shut
# down with ENGINE.quit() in the training cell.
ENGINE = chess.engine.SimpleEngine.popen_uci("stockfish")

INFO 03-25 13:06:35 [__init__.py:256] Automatically detected platform cuda.


In [2]:
# ---------------- Tunable settings ----------------
# fmt: off

# Number of positions loaded from the games dataset (also reused as the
# trainer's max_steps further down)
NUM_SAMPLES = 1000

# Weight of each reward component
UCI_FORMAT_WEIGHT   = 0.25   # the move parses as UCI notation
LEGAL_MOVE_WEIGHT   = 0.5    # the move is legal in the given position
MOVE_QUALITY_WEIGHT = 2.0    # multiplier for the engine-judged move quality

# Repetition handling: fraction of the legal-move reward removed per earlier
# occurrence of the same move, counted in a rolling window of 50 entries
REPEAT_MOVE_PENALTY = 0.05
MOVE_HISTORY = deque(maxlen=50)

# Seconds the engine may spend on each analysis call
ENGINE_ANALYSIS_TIME = 1.5

# fmt: on
# --------------------------------------------------


In [3]:
# Structured response the model is asked to produce for each position.
# The class is wrapped in GuidedSchema below; the rendered schema text is
# interpolated into SYSTEM_PROMPT and the same wrapper parses completions.
# All three fields are required strings (`...` marks a field as required).
class Chess_Reason_and_Act(BaseModel):
    # Notes pulled from the observation (the position) before reasoning.
    scratchpad: str = Field(
        ...,
        description="Information from the Observation useful to answer the question",
    )
    # Free-text explanation of how the move was chosen.
    reasoning: str = Field(
        ...,
        description="It describes your thoughts about the question you have been asked",
    )
    # The final answer; the reward functions read only this field.
    best_move: str = Field(
        ...,
        description="The best move to make in the current position, in UCI notation (e.g. b7b3)",
    )


# Wrap the response model: `.doc` supplies the format description used in the
# prompt, `.parse` is used to read completions back, and the same object is
# handed to the trainer.
guided_schema = GuidedSchema(Chess_Reason_and_Act)


# System message sent with every training prompt; the schema description is
# interpolated once, when this cell runs.
SYSTEM_PROMPT = f"""
Given a chess position in FEN notation, analyze it and suggest the best move in UCI notation.

Respond in the following format:

{guided_schema.doc}
"""

# Show the rendered prompt so the generated format section can be inspected.
print(SYSTEM_PROMPT)


Given a chess position in FEN notation, analyze it and suggest the best move in UCI notation.

Respond in the following format:

Output Model: Chess_Reason_and_Act
  Output Fields:
    scratchpad (str):
        Description: Information from the Observation useful to answer the question
    reasoning (str):
        Description: It describes your thoughts about the question you have been asked
    best_move (str):
        Description: The best move to make in the current position, in UCI notation (e.g. b7b3)



In [4]:
def extract_answer(trajectory: list[dict[str, str]]) -> str:
    """Return the `best_move` parsed from the final message of a trajectory.

    The final message must have the "assistant" role (enforced with an
    assertion). Its content is parsed with `guided_schema`; when the parser
    yields nothing, an empty string is returned instead of a move.
    """
    final_turn = trajectory[-1]
    assert final_turn["role"] == "assistant", "should be assistant"
    structured: Chess_Reason_and_Act | None = guided_schema.parse(final_turn["content"])  # type: ignore
    if not structured:
        return ""
    return structured.best_move

def extract_completions(completions) -> list[str]:
    """Extract the suggested move from every completion in a batch.

    Args:
        completions: One entry per generated sample, each a list of chat
            messages (dicts with "role" and "content") ending with the
            assistant's reply.

    Returns:
        One move string per completion, in order; "" where the reply could
        not be parsed.
    """
    # Pass the whole message list: extract_answer indexes the trajectory and
    # reads the "role"/"content" keys itself, so handing it a bare content
    # string (as was done previously) fails with a TypeError.
    return [extract_answer(completion) for completion in completions]

def is_valid_uci_format(move_str: str) -> bool:
    """Check if a string is in valid UCI move format (e.g., e2e4).

    Only the notation is checked; whether the move is playable in a given
    position is a separate question (see `is_legal_move`).

    Returns:
        True if `chess.Move.from_uci` accepts the string, False otherwise.
    """
    try:
        chess.Move.from_uci(move_str)
    except (ValueError, TypeError):
        # ValueError: malformed notation; TypeError: non-string input.
        # Anything else (KeyboardInterrupt, programming errors) propagates
        # instead of being silently reported as "invalid".
        return False
    return True


def is_legal_move(move_str: str, board: chess.Board) -> bool:
    """Check if a move string is a legal move for the given board position.

    Args:
        move_str: Candidate move in UCI notation.
        board: Position the move is checked against; it is not modified.

    Returns:
        True if the string parses as UCI and the move is among the board's
        legal moves, False otherwise (including unparseable input).
    """
    try:
        move = chess.Move.from_uci(move_str)
    except (ValueError, TypeError):
        # Unparseable or non-string input is simply "not legal"; unrelated
        # exceptions are no longer swallowed.
        return False
    return move in board.legal_moves



def valid_uci_reward(completions, fen, **kwargs) -> list[float]:
    """Score each completion on whether its move is well-formed UCI.

    A completion earns UCI_FORMAT_WEIGHT when its extracted move (with
    surrounding whitespace removed) parses as UCI notation, and 0.0
    otherwise. `fen` and the extra keyword arguments are accepted but unused.
    """
    return [
        UCI_FORMAT_WEIGHT if is_valid_uci_format(candidate.strip()) else 0.0
        for candidate in extract_completions(completions)
    ]


def legal_move_reward(completions, fen, **kwargs) -> list[float]:
    """Reward function that checks if the move is legal.

    A legal move earns LEGAL_MOVE_WEIGHT, scaled down by
    REPEAT_MOVE_PENALTY for every time the same move already appears in
    MOVE_HISTORY (never below 0.0). Illegal or unparseable moves earn 0.0.

    Args:
        completions: Batch of generated completions.
        fen: FEN strings, one per completion, giving the position each move
            is checked against.

    Returns:
        One reward per completion, in order.

    Side effects:
        Every legal move is appended to MOVE_HISTORY after it is scored.
    """
    extracted_moves = extract_completions(completions)

    rewards = []
    for i, move in enumerate(extracted_moves):
        move = move.strip()
        board = chess.Board(fen[i])

        if not is_legal_move(move, board):
            rewards.append(0.0)
            continue

        # Shrink the reward in proportion to how often this move was
        # suggested recently.
        repeats = MOVE_HISTORY.count(move)
        reward = max(0.0, LEGAL_MOVE_WEIGHT * (1.0 - repeats * REPEAT_MOVE_PENALTY))

        # Record the move only after scoring it, so a move is never penalised
        # for its own occurrence. Without this append the history stays
        # empty and the repeat penalty can never apply.
        MOVE_HISTORY.append(move)

        rewards.append(reward)

    return rewards


def engine_analysis_reward(completions, fen, **kwargs) -> list[float]:
    """
    Reward based on how good the suggested move is according to the engine.

    For each legal move the position is analysed before and after the move;
    the drop in evaluation (centipawn loss, from the mover's point of view)
    is mapped linearly onto [0, 1] and multiplied by MOVE_QUALITY_WEIGHT.
    Illegal, empty or unparseable moves get 0.0 without touching the engine.

    This is the final reward function, so it is also responsible for logging
    the engine time and the per-sample centipawn losses to wandb.

    Args:
        completions: Batch of generated completions.
        fen: FEN strings, one per completion.

    Returns:
        One reward per completion, in order.
    """
    engine_time = 0.0
    centipawn_losses = []
    move_rewards = []

    # Same time budget for every analysis call.
    limit = chess.engine.Limit(time=ENGINE_ANALYSIS_TIME)

    for i, move in enumerate(extract_completions(completions)):
        move = move.strip()
        board = chess.Board(fen[i])

        # Skip evaluation for invalid moves; is_legal_move already rejects
        # empty and malformed strings.
        if not is_legal_move(move, board):
            move_rewards.append(0.0)
            centipawn_losses.append(None)
            continue

        start_time = time.perf_counter()

        # Evaluation of the current position, from the side to move.
        initial_eval = ENGINE.analyse(board, limit)
        initial_score = initial_eval["score"].relative.score(mate_score=10000)

        # Make player's move and get new evaluation
        board.push(chess.Move.from_uci(move))
        player_eval = ENGINE.analyse(board, limit)

        # Negate because it's from opponent's perspective
        after_move_score = -player_eval["score"].relative.score(mate_score=10000)

        # Calculate centipawn loss
        centipawn_loss = initial_score - after_move_score
        centipawn_losses.append(centipawn_loss)

        # Reward scaling
        # - No loss (or an apparent gain) is 1.0
        # - A loss of 300 or more (bishop / rook blunder) is 0.0
        # - Linear in between
        if centipawn_loss <= 0:
            reward = 1.0
        elif centipawn_loss >= 300:
            reward = 0.0
        else:
            reward = 1.0 - (centipawn_loss / 300.0)

        move_rewards.append(reward * MOVE_QUALITY_WEIGHT)

        engine_time += time.perf_counter() - start_time

    # A single call keeps both metrics on the same logged row.
    wandb.log(
        {
            "train/engine_time": engine_time,
            "train/centipawn_losses": centipawn_losses,
        }
    )

    return move_rewards


In [None]:

def get_random_position(row) -> str:
    """Extract a random position from a chess game.

    Args:
        row: Dataset record whose "text" entry holds one game in PGN.

    Returns:
        FEN of the position reached after a randomly chosen number of
        mainline moves. `chess.STARTING_FEN` is returned when the PGN yields
        no game or the game has no mainline moves.
    """
    game = chess.pgn.read_game(io.StringIO(row["text"]))
    if not game:
        return chess.STARTING_FEN

    mainline_moves = list(game.mainline_moves())
    if not mainline_moves:
        return chess.STARTING_FEN

    # Choose a random point in the game (not too early, not too late).
    # max_move is always at least min_move + 1, so the range is never empty
    # and needs no further adjustment.
    total = len(mainline_moves)
    min_move = min(5, total // 5)
    max_move = max(min_move + 1, total - 5)

    # Apply moves up to the random point (randint is inclusive at both ends).
    move_count = random.randint(min_move, max_move)
    board = game.board()
    for move in mainline_moves[:move_count]:
        board.push(move)

    return board.fen()


def format_dataset(row):
    """Turn a `{"fen": ...}` record into a GRPO training example.

    The result carries a two-message chat prompt (the shared system prompt
    followed by a user request that embeds the FEN) and the FEN itself under
    "fen".
    """
    position = row["fen"]
    system_turn = {"role": "system", "content": SYSTEM_PROMPT}
    user_turn = {
        "role": "user",
        "content": f"Analyze this chess position and give the best move: {position}",
    }
    return {"prompt": [system_turn, user_turn], "fen": position}

# Stream the games rather than downloading the full dataset; only the first
# NUM_SAMPLES records are read. The stream gets its own name so `dataset`
# refers to exactly one kind of object (the training Dataset).
# NOTE(review): trust_remote_code=True runs the dataset repository's loading
# script locally - confirm this source is trusted.
lichess_games = load_dataset(  # type: ignore
    "Icannos/lichess_games",
    streaming=True,
    trust_remote_code=True,
)

# range() comes first in the zip so iteration stops after NUM_SAMPLES
# without pulling an extra record from the stream.
positions = []
for _, row in tqdm(
    zip(range(NUM_SAMPLES), lichess_games["train"]),
    desc="Loading chess positions",
    total=NUM_SAMPLES,
):
    positions.append(get_random_position(row))

# Build the in-memory training set: one sampled FEN per game, formatted
# into a chat prompt.
dataset: Dataset = Dataset.from_dict({"fen": positions}).map(format_dataset)


In [None]:
# model_name = "google/gemma-3-4b-it"
model_name = "Qwen/Qwen2.5-3B-Instruct"

# LoRA adapter hyper-parameters
lora_rank = 16
lora_alpha = 64

# Token budget: the completion gets whatever the prompt does not use
max_seq_length = 1280
max_prompt_length = 256
max_completion_length = max_seq_length - max_prompt_length

# Passed to the trainer config as the vLLM GPU memory utilization
gpu_memory_utilization = 0.85

# Base model and tokenizer, loaded in bfloat16 with flash-attention 2 and
# the KV cache turned off
model, tokenizer = xv.get_model_and_tokenizer(
    model_name,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    use_cache=False,
)

# Adapters are attached to every linear layer
peft_config = LoraConfig(
    target_modules="all-linear",
    r=lora_rank,
    lora_alpha=lora_alpha,
    lora_dropout=0.05,
)


Using Liger kernel
Applied Liger kernels to Qwen2


`low_cpu_mem_usage` was None, now default to True since model is quantized.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
checkpoint_path = None  # No checkpoint: None

training_args = xv.get_default_grpo_config(
    "chess-reasoner-training",
    learning_rate=5e-6,
    weight_decay=0.1,
    optim="adamw_8bit",
    per_device_train_batch_size=6,
    gradient_accumulation_steps=1,
    num_generations=6,
    # Use the prompt budget the completion length was derived from;
    # passing max_seq_length here let prompt + completion exceed it.
    max_prompt_length=max_prompt_length,
    max_completion_length=max_completion_length,
    max_steps=NUM_SAMPLES,
    save_steps=1000,
    save_total_limit=3,
    max_grad_norm=0.1,
    report_to="wandb",
    output_dir="outputs",
    vllm_gpu_memory_utilization=gpu_memory_utilization,
)


# engine_analysis_reward is listed last because it also does the wandb
# logging for the step.
trainer = xv.GRPOGuidedTrainer(
    guided_schema=guided_schema,
    model=model,
    args=training_args,
    train_dataset=dataset,
    peft_config=peft_config,
    reward_funcs=[
        legal_move_reward,
        valid_uci_reward,
        engine_analysis_reward,
    ],
)

try:
    trainer.train(resume_from_checkpoint=checkpoint_path)

    # NOTE(review): the adapter path mentions "llama_8b" while model_name is
    # a Qwen 3B model - confirm the intended name before relying on it.
    model.save_lora("chess_reasoner_llama_8b_lora")
    print("Model saved to chess_reasoner_llama_8b_lora")
finally:
    # Always stop the engine subprocess, even if training fails or is
    # interrupted; re-run the first cell to get a fresh engine.
    ENGINE.quit()

No label_names provided for model class `PeftModel`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


In [None]:
def push_to_hub(model, tokenizer, repo_id):
    """Publish the model and then the tokenizer to a Hugging Face Hub repo.

    Progress is reported with `tqdm.write` before the upload starts and
    after both uploads have returned.
    """
    tqdm.write(f"Uploading model to Hugging Face Hub: {repo_id}")
    for artifact in (model, tokenizer):
        artifact.push_to_hub(repo_id)
    tqdm.write(f"Successfully uploaded model to: https://huggingface.co/{repo_id}")

# push_to_hub(model, tokenizer, "tommyp111/chess-reasoner")