<a href="https://colab.research.google.com/github/rezabonyadi/chess-language/blob/main/language_model_learn_to_chess.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install vllm

In [None]:
!pip install trl datasets peft python-chess accelerate bitsandbytes unsloth

In [None]:
!apt-get install -y stockfish

In [None]:
# Path to the Stockfish binary installed by the apt-get cell above; change it if Stockfish lives elsewhere.
STOCKFISH_PATH = "/usr/games/stockfish"

In [None]:
import re
import torch
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import GRPOConfig, GRPOTrainer

# A small trick that helps instruct models avoid wasting time working out the required response format.
# It is appended to every user prompt (in generate_puzzle) and embedded in SYSTEM_PROMPT below.
FORMAT_REMINDER = """
 Respond in the following format:
<reasoning> your reasoning process here </reasoning>
<answer> your final answer here </answer>
"""

# System message placed before every user prompt. Its tag layout is what
# extract_xml_answer and the formatting reward functions look for.
SYSTEM_PROMPT = f"""
You are going to respond to a user query. You always first reason and then provide your answer.
You enclose your reasoning process and answer within <reasoning> </reasoning> and <answer> </answer> tags, respectively, i.e.,
{FORMAT_REMINDER}
 """

# Multi-line <reasoning>/<answer> template with {reasoning} and {answer} placeholders.
# NOTE(review): not referenced anywhere in the visible code.
XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""


In [None]:
# Generate chess games. Run this cell if you don't have games saved; otherwise they are loaded in a later cell.
import chess
import chess.engine
import math
import random
import chess.svg
from IPython.display import SVG
import numpy as np
from tqdm import tqdm
import random
from multiprocessing import Pool, cpu_count
import json

def generate_random_game(min_moves=2, max_moves=10):
    """Play uniformly random legal moves from the starting position.

    Args:
        min_moves: inclusive lower bound on the number of plies (half-moves) to attempt.
        max_moves: inclusive upper bound on the number of plies to attempt.

    Returns:
        (moves, board): the plies actually played as UCI strings, and the final
        ``chess.Board``. Play stops early if the game ends.
    """
    board = chess.Board()
    played = []
    target_plies = random.randint(min_moves, max_moves)
    while len(played) < target_plies and not board.is_game_over():
        choice = random.choice(list(board.legal_moves))
        board.push(choice)
        played.append(choice.uci())
    return played, board


def generate_realistic_game(min_moves=2, max_moves=10, engine_path=STOCKFISH_PATH):
    """Play a short game in which each ply is a random pick among Stockfish's top lines.

    Args:
        min_moves: inclusive lower bound on the number of plies (half-moves) to play.
        max_moves: inclusive upper bound on the number of plies to play (same
            convention as ``generate_random_game``).
        engine_path: path to the UCI engine binary.

    Returns:
        (moves, board): the plies played as UCI strings, and the resulting
        ``chess.Board``. Fewer plies are played if the game ends early.
    """
    # Use the stdlib ``random`` module, as generate_random_game does, rather than
    # ``np.random``. Forked multiprocessing workers all inherit the same NumPy RNG
    # state, whereas ``random`` is reseeded in each child. Also, ``random.randint``
    # includes ``max_moves``; ``np.random.randint`` excluded it and raised when
    # min_moves == max_moves.
    num_moves = random.randint(min_moves, max_moves)

    board = chess.Board()
    moves = []
    engine = chess.engine.SimpleEngine.popen_uci(engine_path)
    try:
        for _ in range(num_moves):
            if board.is_game_over():
                break  # Stop if the game is finished

            # Ask for up to 5 principal variations and pick the first move of one at random.
            info = engine.analyse(board, chess.engine.Limit(time=0.1), multipv=5)
            candidates = [entry["pv"][0] for entry in info if entry.get("pv")]
            if not candidates:
                break  # No usable line from the engine; stop instead of crashing.

            chosen_move = random.choice(candidates)
            board.push(chosen_move)
            moves.append(chosen_move.uci())
    finally:
        # Always release the engine process, even if analysis raises.
        engine.close()

    return moves, board

def generate_puzzle(args):
    """Build one chess prompt; this is the worker function run by the multiprocessing pool.

    Args:
        args: tuple ``(index, challenging, min_moves, max_moves, stockfish_path)``.
            ``index`` is not used. ``challenging`` selects Stockfish-guided games
            instead of random ones.

    Returns:
        dict with a chat-style ``'prompt'`` (system + user message) and the
        resulting position as a FEN string under ``'boards'``.
    """
    _index, challenging, min_moves, max_moves, stockfish_path = args

    if challenging:
        moves, board = generate_realistic_game(min_moves, max_moves, stockfish_path)
    else:
        moves, board = generate_random_game(min_moves, max_moves)

    side = "white" if board.turn else "black"
    user_message = (
        f"The following chess moves have been played: {' '.join(moves)}. "
        f"It is {side} to move. Suggest the best feasible next move for {side}. {FORMAT_REMINDER}"
    )

    messages = [
        {'role': 'system', 'content': SYSTEM_PROMPT},
        {'role': 'user', 'content': user_message},
    ]
    # The board is stored as FEN (a plain string) instead of the full object.
    return {'prompt': messages, 'boards': board.fen()}

def get_chess_puzzles_parallel(num_examples=100, challenging=True, min_moves=2, max_moves=10,
                               num_workers=None, chunksize=1):
    """Generate ``num_examples`` chess prompts in parallel using a multiprocessing pool.

    Args:
        num_examples: number of puzzles to generate (<= 0 returns an empty list).
        challenging: if True, games come from Stockfish (slower); otherwise they are
            random legal games.
        min_moves, max_moves: bounds on the number of plies played before the prompt.
        num_workers: pool size. Defaults to ``min(cpu_count(), 8)``, the previous
            hard-coded value. It is never larger than ``num_examples``.
        chunksize: number of tasks sent to a worker per round-trip. Larger values cut
            IPC overhead in the fast random mode; 1 keeps the progress bar smooth.

    Returns:
        list of puzzle dicts produced by ``generate_puzzle``, in task order.
    """
    if num_examples <= 0:
        return []

    if num_workers is None:
        num_workers = min(cpu_count(), 8)
    # No point in starting more processes than there are tasks.
    num_workers = max(1, min(num_workers, num_examples))

    # Prepare arguments for parallel execution
    task_args = [(i, challenging, min_moves, max_moves, STOCKFISH_PATH) for i in range(num_examples)]

    with Pool(num_workers) as pool:
        # imap (not imap_unordered) keeps results in task order.
        puzzles = list(tqdm(pool.imap(generate_puzzle, task_args, chunksize=chunksize),
                            total=num_examples))

    return puzzles

def save_puzzles(puzzles, filename):
    """Write ``puzzles`` to ``filename`` as pretty-printed JSON (4-space indent)."""
    with open(filename, mode='w') as out_file:
        json.dump(obj=puzzles, fp=out_file, indent=4)

# Generate chess games. "challenging=True" uses Stockfish to create the games, which takes more time.
# False produces random (but legal) games.
dataset = get_chess_puzzles_parallel(num_examples=50000, challenging=False, min_moves=2, max_moves=6)

# Save them for future use, if you want.
# NOTE(review): the loading cell further down reads 'chess_difficult_5_10_moves.json'. Uncomment the
# save below if you want that cell to pick up this dataset. The file name does not reflect these
# settings (random games, 2-6 plies).
# save_puzzles(dataset, 'chess_difficult_5_10_moves.json')

In [None]:
# Reward functions for learning Chess
import chess
import chess.engine
import math
import random
import chess.svg
from IPython.display import SVG
import numpy as np
from tqdm import tqdm
import random


def parse_chess_move(answer):
    """Normalize a move string by deleting every '.', ' ' and '-' character.

    For example, ``"e2-e4"`` becomes ``"e2e4"``. Other whitespace (tabs, newlines)
    and letter case are left unchanged.
    """
    return answer.translate(str.maketrans('', '', '. -'))

def evaluate_move(board, move_str, turn=None, analysis_time=0.1, scaling=0.003):
    """Score a proposed move for the side to move in ``board``.

    The reward accumulates in three stages:
      * +0.1 if ``move_str`` parses as a UCI move,
      * +0.1 more if that move is legal in the position,
      * +0.0 / 1.0 / 2.0 more, from Stockfish's evaluation of the position after the
        move (a move that delivers checkmate gets the full 2.0).

    Args:
        board: FEN string of the position (passed to ``chess.Board``).
        move_str: candidate move in UCI notation, e.g. ``"e2e4"``.
        turn: unused; kept for backward compatibility.
        analysis_time: seconds Stockfish may spend on the resulting position.
        scaling: slope of the tanh mapping from centipawns to [0, max_score].

    Returns:
        Total reward as a float in [0.0, 2.2].
    """
    max_score = 2.0

    valid_move_reward = 0.1
    legal_move_reward = 0.1

    total_score = 0.0
    board = chess.Board(board)
    try:  # Is it a syntactically valid UCI move?
        move = chess.Move.from_uci(move_str)
    except ValueError:  # chess.InvalidMoveError subclasses ValueError
        return total_score
    total_score += valid_move_reward

    if move not in board.legal_moves:  # Is it a legal move in this position?
        return total_score
    total_score += legal_move_reward

    board_copy = board.copy()
    board_copy.push(move)

    # Checkmate is the best possible outcome, so reward it directly. The engine reports
    # a mated position as "mate 0"; negating that gives MateGiven, whose mate() is 0,
    # so the mate branch below would wrongly give 0.0 for a mating move.
    if board_copy.is_checkmate():
        return total_score + max_score

    # Start the engine.
    engine = chess.engine.SimpleEngine.popen_uci(STOCKFISH_PATH)
    try:  # Score: the higher, the better
        info = engine.analyse(board_copy, chess.engine.Limit(time=analysis_time))
        # The engine returns a PovScore relative to the side to move, which is now the
        # opponent. Negate it to get the evaluation from the mover's point of view.
        effective_score = -info["score"].relative

        if effective_score.is_mate():
            # Positive: the mover mates in N. Negative: the mover gets mated.
            move_score = max_score if effective_score.mate() > 0 else 0.0
        else:
            # Centipawn score (e.g. +50 means +0.50 pawns), squashed into [0, max_score].
            cp = effective_score.score()
            move_score = max_score * (1 + math.tanh(scaling * cp)) / 2.0

        # Quantize to 3 levels so bad moves are not rewarded at all:
        # [0, 0.5*max) -> 0.0, [0.5*max, 0.75*max) -> 1.0, [0.75*max, ...] -> 2.0.
        if move_score < 0.5 * max_score:
            move_score = 0.0
        elif move_score < 0.75 * max_score:
            move_score = 1.0
        else:
            move_score = 2.0

        return total_score + move_score
    except Exception:
        # Best effort: an engine failure only forfeits the evaluation bonus; it does not
        # abort training. Catching Exception (not a bare except) still lets
        # KeyboardInterrupt through.
        return total_score
    finally:
        engine.quit()

def extract_xml_answer(text: str) -> str:
    """Return the stripped text between the last ``<answer>`` and the next ``</answer>``.

    Missing tags are tolerated. Without ``<answer>`` the text is taken from the
    start; without ``</answer>`` it runs to the end.
    """
    _, _, after_open = text.rpartition("<answer>")
    inner, _, _ = after_open.partition("</answer>")
    return inner.strip()

# # Reward functions
def correctness_reward_func(prompts, completions, boards, **kwargs) -> list[float]:
    """GRPO reward: engine-based quality of the move inside each completion's <answer>.

    ``boards`` holds the FEN of each prompt (the dataset's ``'boards'`` column) and is
    aligned with ``completions``. The first sample is printed for monitoring.
    """
    q = prompts[0][-1]['content']
    responses = [msgs[0]['content'] for msgs in completions]
    extracted_responses = list(map(extract_xml_answer, responses))

    scores = []
    for fen, raw_answer in zip(boards, extracted_responses):
        scores.append(evaluate_move(fen, parse_chess_move(raw_answer)))

    print('-'*20, f"\nQuestion:\n{q}", f"\nBoard:\n{chess.Board(boards[0])}", f"\nResponse:\n{responses[0]}",
          f"\nExtracted:\n{extracted_responses[0]}", f"\nScore:\n{scores[0]}")
    return scores


In [None]:
import json

def load_puzzles(filename):
    """Read and return the JSON puzzle list stored in ``filename``."""
    with open(filename, mode='r') as in_file:
        return json.load(in_file)

# Load the saved puzzles back into a variable when the file exists. Otherwise keep the dataset
# generated in memory, so Restart & Run All works even if the puzzles were never saved.
import os

PUZZLES_FILE = 'chess_difficult_5_10_moves.json'
if os.path.exists(PUZZLES_FILE):
    dataset = load_puzzles(PUZZLES_FILE)
elif 'dataset' not in globals():
    raise FileNotFoundError(
        f"{PUZZLES_FILE} not found and no in-memory dataset exists; run the generation cell first.")

In [None]:
# Formatting reward functions
def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward 0.5 for completions shaped like ``<reasoning>...</reasoning> <answer>...</answer>``.

    The match is anchored at the start of the completion, but leading whitespace is
    allowed. ``re.DOTALL`` lets ``.`` cross newlines, so multi-line reasoning or
    answers are accepted. Without it, any newline inside the tags made the match fail.

    Args:
        completions: per-sample message lists; the first message's ``'content'`` is checked.

    Returns:
        One reward per completion: 0.5 on a match, else 0.0.
    """
    pattern = r"\s*<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def count_xml(text) -> float:
    """Partial-credit score (roughly 0.5 at best) for the XML tag layout of a completion.

    Each of ``<reasoning>``, ``</reasoning>``, ``<answer>`` and ``</answer>`` earns
    0.125 when it appears exactly once. Characters after the last ``</answer>`` (the
    whole text if there is none) are penalised: 0.001 each in the ``<answer>`` check,
    and 0.001 per character beyond the first in the ``</answer>`` check.
    """
    trailing = len(text.split("</answer>")[-1])
    score = 0.0
    for tag in ("<reasoning>", "</reasoning>"):
        if text.count(tag) == 1:
            score += 0.125
    if text.count("<answer>") == 1:
        score += 0.125
        score -= trailing * 0.001
    if text.count("</answer>") == 1:
        score += 0.125
        score -= (trailing - 1) * 0.001
    return score

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    """GRPO reward: apply ``count_xml`` to the first message of each completion."""
    return list(map(count_xml, (msgs[0]["content"] for msgs in completions)))

# Load models and train

## Using standard Transformers models

In [None]:
# Base model to fine-tune; uncomment an alternative to switch.
model_name = "Qwen/Qwen2.5-0.5B-Instruct"
# model_name = "microsoft/Phi-3.5-mini-instruct"
# model_name = "Qwen/Qwen2.5-1.5B-Instruct"

# NOTE(review): output_dir is hard-coded for Qwen-0.5B and does not follow model_name;
# update it when switching models so checkpoints from different models are not mixed.
output_dir="outputs/Qwen-0.5B-GRPO"
# Run name derived from the model id (passed to GRPOConfig as run_name); replace "Some-ID" per run.
run_name = f"{model_name.replace('/', '-')}-Some-ID"

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Reuse the EOS token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

# Load the full model in bfloat16 without a device map, then move it to the GPU explicitly.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map=None
).to("cuda")

# You can of course use some quantization
# model = AutoModelForCausalLM.from_pretrained(model_name,
#                                                   device_map="auto",
#                                                   load_in_8bit=True,  # Enables 8-bit quantization
#                                                   torch_dtype="auto"  # Automatically selects the correct data type
#                                                   )


It is better to train only part of the model: full fine-tuning needs a lot of memory and may change the model's foundation too much.
For Qwen 0.5B, an L4 GPU is sufficient.

There are two options here: LoRA, or training only a selected subset of layers.


In [None]:

# LORA
# from peft import LoraConfig, get_peft_model

# # Define your target modules (for example, you might target the query and value projection layers)
# target_modules = ["q_proj", "v_proj"]  # adjust as needed for your model architecture
# target_modules = ["qkv_proj", "o_proj"]
# # Your LoRA configuration
# lora_config = LoraConfig(
#     r=8,
#     lora_alpha=16,
#     lora_dropout=0.05,
#     bias="none",
#     task_type="CAUSAL_LM",
#     target_modules=target_modules
# )

# # ------
# # Subset of layers
# # Freeze all layers first
# for param in model.parameters():
#     param.requires_grad = False

# last_layers = list(model.named_parameters())  # Adjust based on model specifics
# trainable_layers = ['31', '30', '29']
# for name, param in last_layers:
#     try:
#         if name.split('.')[2] in trainable_layers:
#             param.requires_grad = True
#             print(f"Unfreezing layer: {name}")
#     except:
#         continue


To start from a saved checkpoint instead, load the model from it:


In [None]:
# model_weights_path = f"outputs_1/Qwen-0.5B-GRPO/checkpoint-2500"
# model_weights_path = f"drive/MyDrive/results/llm_chess/Qwen-0.5B-GRPO-Chess/Qwen-0.5B-GRPO-Chess/checkpoint-2500"

# model = AutoModelForCausalLM.from_pretrained(
#     model_weights_path,
#     torch_dtype=torch.bfloat16,
#     device_map=None
# ).to("cuda")


In [None]:
import torch.optim as optim

training_args = GRPOConfig(
    output_dir=output_dir,
    run_name=run_name,
    learning_rate=5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type='cosine',
    logging_steps=1,
    bf16=True,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    # NOTE(review): newer TRL versions require the global batch size to be divisible by
    # num_generations; confirm these values against the installed TRL version.
    num_generations=8,
    max_prompt_length=256,
    max_completion_length=200,
    num_train_epochs=1,
    save_steps=1000,
    max_grad_norm=0.1,
    log_on_each_node=False,
    use_vllm=True,
    vllm_gpu_memory_utilization=0.4, #.3,
    beta=0.04,
    vllm_device="cuda:0",
    report_to="wandb"  # Logs to Weights & Biases; set to "none" to disable.
)

# A hand-built optimizer is passed to the trainer so that only parameters with
# requires_grad=True (e.g. after freezing layers) are optimized. When an optimizer is
# supplied, the trainer does not build its own, so the AdamW settings in GRPOConfig
# (betas, weight decay, epsilon) would be silently ignored. Read them from
# training_args so the two stay in sync.
optimizer = optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=training_args.learning_rate,
    betas=(training_args.adam_beta1, training_args.adam_beta2),
    eps=training_args.adam_epsilon,
    weight_decay=training_args.weight_decay,
)


In [None]:
# PEFT doesn't seem to work on multi-GPU with this library.
# Three reward functions score each completion: XML tag counting, an overall format
# check, and the Stockfish-based move quality.
trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    optimizers=(optimizer, None),  # None: the trainer builds the LR scheduler from training_args
    reward_funcs=[
        xmlcount_reward_func,
        soft_format_reward_func,
        correctness_reward_func],
    args=training_args,
    train_dataset=dataset,  # list of {'prompt', 'boards'} dicts; 'boards' is read by correctness_reward_func
    # peft_config=lora_config,
)
trainer.train()

## Using Unsloth

If you want to use Unsloth, use the code below instead.

In [None]:
# Patch TRL's GRPO implementation with Unsloth before the trainer is created.
from unsloth import FastLanguageModel, PatchFastRL, is_bfloat16_supported
PatchFastRL("GRPO", FastLanguageModel)

import torch
max_seq_length = 512 # Can increase for longer reasoning traces

# NOTE(review): this rebinds `model` and `tokenizer` from the Transformers section above;
# presumably only one of the two sections should be run per kernel session (GPU memory).
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "meta-llama/meta-Llama-3.1-8B-Instruct",  # NOTE(review): confirm this repo id resolves on the Hub
    max_seq_length = max_seq_length,
    load_in_4bit = True, # False for LoRA 16bit
    fast_inference = True, # Enable vLLM fast inference
    # max_lora_rank = lora_rank,  # lora_rank is only defined in the next cell; define it first if enabling
    gpu_memory_utilization = 0.6, # Reduce if out of memory
)


In [None]:
# It is better to train only part of the model: full fine-tuning needs a lot of memory and may change its foundation too much.
# Wrap the model with LoRA adapters on the attention and MLP projection layers.

lora_rank = 32 # Larger rank = smarter, but slower

model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ], # Remove QKVO if out of memory
    lora_alpha = lora_rank,  # alpha equal to rank
    use_gradient_checkpointing = "unsloth", # Enable long context finetuning
    random_state = 3407,  # seed for reproducible adapter initialization
)



In [None]:
# Re-import after PatchFastRL so these names presumably refer to the Unsloth-patched TRL classes.
from trl import GRPOConfig, GRPOTrainer
training_args = GRPOConfig(
    use_vllm = True, # use vLLM for fast inference!
    learning_rate = 5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    optim = "paged_adamw_8bit",  # 8-bit paged AdamW (bitsandbytes, installed above)
    logging_steps = 1,
    bf16 = is_bfloat16_supported(),  # bf16 where supported, otherwise fp16
    fp16 = not is_bfloat16_supported(),
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 4, # Increase for smoother training
    # NOTE(review): newer TRL versions require the global batch size to be divisible by
    # num_generations; confirm against the installed TRL version.
    num_generations = 6, # Decrease if out of memory
    max_prompt_length = 256,
    max_completion_length = 200,
    # num_train_epochs = 1, # Set to 1 for a full training run
    max_steps = 250,  # caps training at 250 steps; takes precedence over num_train_epochs
    save_steps = 250,
    max_grad_norm = 0.1,
    report_to = "wandb", # Use "none" to disable logging
    output_dir = "outputs",
)

In [None]:
# Use peft_config at your own risk; it did not work with multi-GPU training.
# `model` already carries LoRA adapters from get_peft_model above, so no peft_config is passed.
trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        xmlcount_reward_func,
        soft_format_reward_func,
        correctness_reward_func],
    args=training_args,
    train_dataset=dataset,
    # peft_config=lora_config
)
trainer.train()