<a href="https://colab.research.google.com/github/rezabonyadi/language-models-experiments/blob/main/GRPO_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install transformers peft python-chess accelerate bitsandbytes datasets wandb


# Set up the prompt

In [None]:
import re
import torch
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

# A small trick that helps instruction-tuned models avoid wasting effort working out the required response format.
FORMAT_REMINDER = """
 Respond in the following format:
<reasoning> your reasoning process here </reasoning>
<answer> your final answer here </answer>
"""

SYSTEM_PROMPT = f"""
You are going to respond to a user query. You always first reason and then provide your answer.
You enclose your reasoning process and answer within <reasoning> </reasoning> and <answer> </answer> tags, respectively, i.e.,
{FORMAT_REMINDER}
 """

XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""
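As a quick sanity check of this format contract, the sketch below fills a template with the same shape as `XML_COT_FORMAT` and pulls the answer back out with the same split-based logic that `extract_xml_answer` uses later. The template, regex, and `extract_answer` helper are redefined locally so the snippet runs on its own, and the example reasoning text is made up.

```python
import re

# Local copy of the template shape used above (XML_COT_FORMAT).
TEMPLATE = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

# Same pattern as soft_format_reward_func: a reasoning block followed by an answer block.
FORMAT_PATTERN = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"


def extract_answer(text: str) -> str:
    """Return whatever sits between the last <answer> and the next </answer>."""
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()


completion = TEMPLATE.format(
    reasoning="Developing the knight controls the centre.",
    answer="g1f3",
)

print(completion)
print("format ok:", bool(re.match(FORMAT_PATTERN, completion, re.DOTALL)))  # True
print("answer:", extract_answer(completion))                                # g1f3
```

Note that when a completion has no tags at all, this extraction returns the whole (stripped) text rather than an empty string.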


# Load data

## Chess dataset and correctness reward

In [None]:
!apt-get install -y stockfish

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
Suggested packages:
  polyglot xboard | scid
The following NEW packages will be installed:
  stockfish
0 upgraded, 1 newly installed, 0 to remove and 29 not upgraded.
Need to get 24.8 MB of archives.
After this operation, 47.4 MB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu jammy/universe amd64 stockfish amd64 14.1-1 [24.8 MB]
Fetched 24.8 MB in 3s (8,508 kB/s)
Selecting previously unselected package stockfish.
(Reading database ... 124947 files and directories currently installed.)
Preparing to unpack .../stockfish_14.1-1_amd64.deb ...
Unpacking stockfish (14.1-1) ...
Setting up stockfish (14.1-1) ...
Processing triggers for man-db (2.10.2-1) ...


In [None]:
STOCKFISH_PATH = "/usr/games/stockfish"

In [None]:
# Generate chess games. Generate them if you don't have them saved; otherwise load them later.
import json
import math
import random
from multiprocessing import Pool, cpu_count

import chess
import chess.engine
import chess.svg
import numpy as np
from IPython.display import SVG
from tqdm import tqdm

def describe_board(board: chess.Board) -> str:
    # Build descriptions for white and black pieces.
    white_descriptions = []
    black_descriptions = []
    for square, piece in board.piece_map().items():
        square_name = chess.square_name(square)
        # Get the full piece name (e.g., "pawn", "knight", etc.)
        piece_name = chess.piece_name(piece.piece_type)
        # Build the description; here we include the color if desired.
        # You can adjust this formatting if you prefer to omit the color since the header indicates it.
        description = f"a { 'white' if piece.color == chess.WHITE else 'black' } {piece_name} in {square_name}"
        if piece.color == chess.WHITE:
            white_descriptions.append(description)
        else:
            black_descriptions.append(description)

    # Build the overall pieces description.
    white_line = "White pieces placement: " + ", ".join(white_descriptions)
    black_line = "Black pieces placement: " + ", ".join(black_descriptions)

    # Determine if each king has moved by checking castling rights.
    # Note: Losing a castling right can also happen when the rook moves.
    # For this example we assume that if both castling rights remain then the king never moved.
    white_king_has_not_moved = (board.has_kingside_castling_rights(chess.WHITE) and
                                board.has_queenside_castling_rights(chess.WHITE))
    black_king_has_not_moved = (board.has_kingside_castling_rights(chess.BLACK) and
                                board.has_queenside_castling_rights(chess.BLACK))

    white_king_status = "White king has not moved in the game." if white_king_has_not_moved else "White king has moved in this game at least once."
    black_king_status = "Black king has not moved in the game." if black_king_has_not_moved else "Black king has moved in this game at least once."

    # Combine all parts.
    description = "\n".join([white_line, black_line, white_king_status, black_king_status])
    return description


def generate_random_game(min_moves=2, max_moves=10):
    board = chess.Board()
    moves = []
    num_moves = random.randint(min_moves, max_moves)
    for _ in range(num_moves):
        if board.is_game_over():
            break
        legal_moves = list(board.legal_moves)
        move = random.choice(legal_moves)
        board.push(move)
        moves.append(move.uci())
    return moves, board


def generate_realistic_game(min_moves=2, max_moves=10, engine_path=STOCKFISH_PATH):
    engine = chess.engine.SimpleEngine.popen_uci(engine_path)

    board = chess.Board()
    moves = []
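    # random.randint includes max_moves (like generate_random_game) and, unlike NumPy's
    # global RNG, Python's random module is reseeded in each forked Pool worker.
    num_moves = random.randint(min_moves, max_moves)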
    for _ in range(num_moves):
        if board.is_game_over():
            break  # Stop if the game is finished

        info = engine.analyse(board, chess.engine.Limit(time=0.1), multipv=5)
        top_moves = [(tm['pv'][0], tm['score'].relative.score()) for tm in info]

        chosen_move = random.choice(top_moves)

        board.push(chosen_move[0])
        moves.append(str(chosen_move[0]))

    engine.close()

    return moves, board

def generate_puzzle(args):
    """Function to be executed in parallel."""
    index, challenging, min_moves, max_moves, stockfish_path = args
    engine_path = stockfish_path

    moves, board = generate_realistic_game(min_moves, max_moves, engine_path) if challenging else generate_random_game(min_moves, max_moves)

    moves_text = " ".join(moves)
    turn = "white" if board.turn else "black"
    prompt_text = (f"The following chess moves have been played: {moves_text}. "
                   f"It is {turn} to move. Suggest the best feasible next move for {turn}. {FORMAT_REMINDER}")

    # board_description = describe_board(board)
    # turn = "White" if board.turn else "Black"
    # prompt_text = (f"Here is a description of a chess board with pieces placement on the board: \n{board_description}. \n\n"
    #                f"It is {turn} turn to make a move. Suggest the best feasible next move for {turn}. {FORMAT_REMINDER}")

    return {
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': prompt_text}
        ],
        'boards': board.fen()  # Store board as FEN instead of full object
    }

def get_chess_puzzles_parallel(num_examples=100, challenging=True, min_moves=2, max_moves=10):
    """Parallelized chess puzzle generator."""
    num_workers = min(cpu_count(), 8)  # Use up to 8 cores (adjust if needed)

    # Prepare arguments for parallel execution
    task_args = [(i, challenging, min_moves, max_moves, STOCKFISH_PATH) for i in range(num_examples)]

    # Use multiprocessing Pool
    with Pool(num_workers) as pool:
        puzzles = list(tqdm(pool.imap(generate_puzzle, task_args), total=num_examples))

    return puzzles

def save_puzzles(puzzles, filename):
    with open(filename, 'w') as file:
        json.dump(puzzles, file, indent=4)

def load_puzzles(filename):
    with open(filename, 'r') as file:
        puzzles = json.load(file)
    return puzzles

min_moves = 2
max_moves = 10

# Generate puzzles
dataset = get_chess_puzzles_parallel(num_examples=50000, challenging=False, min_moves=min_moves, max_moves=max_moves)

# Save them for future use
# save_puzzles(dataset, f'chess_difficult_{min_moves}_{max_moves}_moves.json')

# To load the puzzles back into a variable
# dataset = load_puzzles(f'chess_difficult_{min_moves}_{max_moves}_moves.json')


100%|██████████| 50000/50000 [00:05<00:00, 9738.96it/s] 
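`get_chess_puzzles_parallel` farms the work out to a `multiprocessing.Pool`. On Linux the workers are forked, and NumPy's global RNG state is copied into every worker unchanged (Python's `random` module, by contrast, reseeds itself in forked children), so draws from `np.random` inside workers can repeat across workers. A minimal sketch of one fix, assuming you want reproducible datasets: derive a task-local `random.Random` from the `index` that `generate_puzzle` already receives but does not use. `BASE_SEED`, `task_rng`, and `sample_move_count` are hypothetical names, not part of the code above.

```python
import random

BASE_SEED = 1234  # hypothetical base seed; any fixed integer makes the dataset reproducible


def task_rng(index: int, base_seed: int = BASE_SEED) -> random.Random:
    """Return an RNG that depends only on the task index, not on worker state."""
    return random.Random(base_seed + index)


def sample_move_count(index: int, min_moves: int = 2, max_moves: int = 10) -> int:
    # randint is inclusive on both ends, matching generate_random_game.
    return task_rng(index).randint(min_moves, max_moves)


# The same index always yields the same draw, whatever order tasks are processed in,
# so results do not depend on which Pool worker handles which index.
forward = [sample_move_count(i) for i in range(8)]
backward = [sample_move_count(i) for i in reversed(range(8))][::-1]
print(forward == backward)  # True
```

To apply this to the generators, you would pass the task RNG into `generate_random_game`/`generate_realistic_game` and call `rng.randint(...)` and `rng.choice(...)` instead of the module-level functions.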


In [None]:
# Reward functions for learning chess
import math
import random

import chess
import chess.engine
import chess.svg
import numpy as np
from IPython.display import SVG
from tqdm import tqdm

def parse_chess_move(answer):
    answer = answer.replace('.', '')
    answer = answer.replace(' ', '')
    answer = answer.replace('-', '')
    return answer

def evaluate_move(board, move_str, turn=None, analysis_time=0.1, scaling=0.003):
    """
    Evaluate a candidate move and return a reward in [0, 2.2]:
      - 0.1 if move_str parses as a UCI move,
      - another 0.1 if the move is legal in the position,
      - plus a quality score of 0, 1 or 2 based on a Stockfish evaluation.

    Parameters:
      board         : the current position as a FEN string (the dataset's 'boards' field).
      move_str      : the candidate move in UCI format (e.g., "e2e4").
      turn          : unused; the side to move is taken from the FEN.
      analysis_time : time (in seconds) to allocate for engine analysis.
      scaling       : scaling factor for the tanh mapping of centipawn scores.

    The engine at STOCKFISH_PATH evaluates the position after the move. For non-mate
    positions, the centipawn score cp (from the mover's perspective) is mapped to
    max_score * (1 + tanh(scaling * cp)) / 2, which lies in (0, 2) for max_score = 2,
    and then quantized: below 1.0 gives 0, below 1.5 gives 1, otherwise 2.
    A move that checkmates, or leads to a forced mate for the mover, scores 2;
    a forced mate for the opponent scores 0.
    """
    max_score = 2.0

    valid_move_reward = 0.1
    legal_move_reward = 0.1

    total_score = 0.0
    board = chess.Board(board)
    try:  # Is it syntactically valid UCI?
        move = chess.Move.from_uci(move_str)
        total_score += valid_move_reward
    except ValueError:  # python-chess raises ValueError (or a subclass) on malformed UCI
        return total_score

    if move in board.legal_moves:  # Is it legal in this position?
        total_score += legal_move_reward
    else:
        return total_score

    # Work on a copy so as not to modify the original board.
    board_copy = board.copy()
    board_copy.push(move)

    # Delivering checkmate is the best possible move. The engine would report this as
    # "mate given", whose mate() is 0, so handle it explicitly instead of scoring it 0.
    if board_copy.is_checkmate():
        return total_score + max_score

    # Start the engine.
    engine = chess.engine.SimpleEngine.popen_uci(STOCKFISH_PATH)
    try:  # Score, the higher the better
        info = engine.analyse(board_copy, chess.engine.Limit(time=analysis_time))
        # The score is relative to the side to move after our move, i.e. the opponent.
        # Negating it gives the score from the mover's perspective (larger is better).
        effective_score = -info["score"].relative

        if effective_score.is_mate():
            # mate() is the number of moves to mate: positive if the mover mates,
            # negative if the mover gets mated.
            move_score = max_score if effective_score.mate() > 0 else 0.0
        else:
            # Centipawn score, e.g. +50 means a half-pawn advantage for the mover.
            cp = effective_score.score()
            # tanh is in (-1, 1), so this lies in (0, max_score).
            move_score = max_score * (1 + math.tanh(scaling * cp)) / 2.0

        # Alternative: quantize to 4 levels in steps of 0.5.
        # move_score = round(move_score * 2) / 2

        # Quantize to 3 levels so that bad moves earn nothing:
        # [0, 0.5*max) -> 0, [0.5*max, 0.75*max) -> 1, [0.75*max, max] -> 2.
        move_score = 0.0 if move_score < 0.5 * max_score else 1.0 if move_score < 0.75 * max_score else 2.0

        total_score += move_score
        return total_score
    except Exception:  # engine failure: keep only the validity/legality reward
        return total_score
    finally:
        engine.quit()

def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

# Reward functions
def correctness_reward_func(prompts, completions, boards, **kwargs) -> list[float]:
    responses = completions
    # print(len(boards))

    extracted_responses = [extract_xml_answer(r) for r in responses]
    parsed_moves = [parse_chess_move(e) for e in extracted_responses]
    scores = [evaluate_move(board, parsed_move) for parsed_move, board in zip(parsed_moves, boards)]
    print('-'*20, f"\nQuestion:\n{prompts[0]}", f"\nResponse:\n{responses[0]}",
          f"\nExtracted:\n{extracted_responses[0]}", f"\nScore:\n{scores}")

    return scores
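The reward ladder in `evaluate_move` can be checked without Stockfish. The sketch below stubs out the engine (you pass in a mover-perspective centipawn score and a legality flag) and uses a simplified regex in place of `chess.Move.from_uci`, so it only approximates python-chess's parser (for example, it rejects the null move `0000`, which python-chess accepts). `UCI_PATTERN`, `normalize_move`, `quantized_move_score`, and `total_reward` are hypothetical helpers. The sketch also makes the quantization thresholds concrete: with `scaling=0.003`, a move earns 1 once the evaluation is at least 0 cp, and 2 from about 183 cp upward.

```python
import math
import re

# Simplified UCI syntax check (an assumption, not python-chess's parser): e2e4, e7e8q, ...
UCI_PATTERN = re.compile(r"^[a-h][1-8][a-h][1-8][qrbn]?$")


def normalize_move(answer: str) -> str:
    """Mirror parse_chess_move: drop dots, spaces and hyphens."""
    return answer.replace(".", "").replace(" ", "").replace("-", "")


def quantized_move_score(cp: int, scaling: float = 0.003, max_score: float = 2.0) -> float:
    """Map a mover-perspective centipawn score to {0, 1, 2} as evaluate_move does."""
    raw = max_score * (1 + math.tanh(scaling * cp)) / 2.0  # in (0, max_score)
    if raw < 0.5 * max_score:
        return 0.0
    if raw < 0.75 * max_score:
        return 1.0
    return 2.0


def total_reward(answer: str, is_legal: bool, cp: int) -> float:
    """Hypothetical stand-in for evaluate_move with the board and engine stubbed out."""
    move = normalize_move(answer)
    if not UCI_PATTERN.match(move):
        return 0.0          # not parseable as UCI
    total = 0.1             # valid UCI syntax
    if not is_legal:
        return total
    total += 0.1            # legal in the position
    return total + quantized_move_score(cp)


# Score 2 starts where tanh(0.003 * cp) reaches 0.5, i.e. cp = atanh(0.5) / 0.003 ≈ 183.1.
print(math.atanh(0.5) / 0.003)
for cp in (-50, 0, 183, 184, 400):
    print(cp, quantized_move_score(cp))
print(total_reward("e2-e4", True, 250))
```

The prompt shows the move history in UCI but never explicitly asks for a UCI answer, so a SAN reply such as `Nf3` earns nothing from this reward.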


## GSM8K dataset and correctness reward

In [None]:
from datasets import load_dataset
# Reward functions for learning math reasoning
def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

# Reward functions
def int_reward_func(completions, **kwargs) -> list[float]:
    # responses = [completion[0]['content'] for completion in completions]
    responses = completions
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]

def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    # responses = [completion[0]['content'] for completion in completions]

    responses = completions

    extracted_responses = [extract_xml_answer(r) for r in responses]
    print('-'*20, f"Question:\n{prompts[0]}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}",
          f"\nExtracted:\n{extracted_responses[0]}")
    int_rewards = int_reward_func(completions, **kwargs)
    # print(len(prompts), len(extracted_responses), len(int_rewards), len(answer))
    # print(answer)

    scores = [i + 2.0 if r == a else i + 0.0 for r, a, i in zip(extracted_responses, answer, int_rewards)]
    # print(scores)
    return scores

def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

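# Build chat-style prompts (system + user) for GSM8K; the gold answer is the text after '####'.
# Note: running this cell replaces the chess `dataset` built earlier.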
def get_gsm8k_questions(split = "train"):
    data = load_dataset('openai/gsm8k', 'main')[split] # type: ignore
    data = data.map(lambda x: { # type: ignore
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': (x['question'] + FORMAT_REMINDER)}
        ],
        'answer': extract_hash_answer(x['answer'])
    }) # type: ignore
    return data # type: ignore

dataset = get_gsm8k_questions()
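GSM8K solutions end with a line of the form `#### <answer>`, which `extract_hash_answer` turns into the gold string. The combined GSM8K reward is then plain arithmetic: 0.5 from `int_reward_func` when the extracted answer is all digits, plus 2.0 from `correctness_reward_func` on an exact string match. Below is a self-contained sketch of that arithmetic; `extract_hash_answer` is a local copy, `gsm8k_reward` is a hypothetical helper, and the solution text is invented.

```python
def extract_hash_answer(text):
    """Return the text after '####' (GSM8K's final-answer marker), or None if absent."""
    if "####" not in text:
        return None
    return text.split("####")[1].strip()


def gsm8k_reward(extracted: str, gold: str) -> float:
    """int_reward_func + correctness_reward_func for a single completion."""
    digit_bonus = 0.5 if extracted.isdigit() else 0.0
    return digit_bonus + (2.0 if extracted == gold else 0.0)


# Invented solution text in GSM8K's format.
gold = extract_hash_answer("Tom has 3 + 4 = 7 apples.\n#### 7")

for candidate in ["7", "8", "7.0", "seven"]:
    print(repr(candidate), gsm8k_reward(candidate, gold))
# '7' 2.5, '8' 0.5, '7.0' 0.0, 'seven' 0.0
print(gsm8k_reward("-3", "-3"))  # 2.0: correct, but "-3".isdigit() is False
```

Because the comparison is an exact string match, `7.0` scores nothing against `7`, and a correct negative answer misses the 0.5 digit bonus.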

## Common formatting rewards

In [None]:
# Generic reward functions
def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward 0.5 if the completion starts with a <reasoning> block followed by an <answer> block."""
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    # responses = [completion[0]["content"] for completion in completions]
    responses = completions
    matches = [re.match(pattern, r, re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

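def count_xml(text) -> float:
    """Partial-credit format reward: 0.125 for each of the four tags that appears exactly once,
    minus a small per-character penalty for text after </answer>."""
    count = 0.0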
    if text.count("<reasoning>") == 1:
        count += 0.125
    if text.count("</reasoning>") == 1:
        count += 0.125
    if text.count("<answer>") == 1:
        count += 0.125
        count -= len(text.split("</answer>")[-1])*0.001
    if text.count("</answer>") == 1:
        count += 0.125
        count -= (len(text.split("</answer>")[-1]) - 1)*0.001
    # if count>0.498:
    #     reasoning = text.split("<reasoning>")[-1].split('</reasoning>')[0]
    #     count += (len(reasoning))*0.001
    return count

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    # contents = [completion[0]["content"] for completion in completions]
    # print(contents)
    contents = completions

    return [count_xml(c) for c in contents]
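`soft_format_reward_func` uses `re.match`, which anchors at the start of the completion but not at the end: a completion that opens with chatter fails, while trailing chatter after `</answer>` still passes (it is `count_xml` that penalizes trailing characters). If you also want to reject trailing text, a stricter variant can use `re.fullmatch` with explicit newlines. `STRICT_PATTERN` and both function names below are hypothetical and only illustrate the difference.

```python
import re

# Soft check (as in soft_format_reward_func): must start with <reasoning>; trailing text is tolerated.
SOFT_PATTERN = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
# Hypothetical strict check: the whole completion must be exactly the two blocks.
STRICT_PATTERN = r"<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n?"


def soft_format_reward(completions):
    return [0.5 if re.match(SOFT_PATTERN, c, re.DOTALL) else 0.0 for c in completions]


def strict_format_reward(completions):
    return [0.5 if re.fullmatch(STRICT_PATTERN, c, re.DOTALL) else 0.0 for c in completions]


samples = [
    "<reasoning>\nthink\n</reasoning>\n<answer>\ne2e4\n</answer>\n",          # clean
    "<reasoning>think</reasoning> <answer>e2e4</answer> and some chatter",  # soft only
    "Sure! <reasoning>think</reasoning><answer>e2e4</answer>",              # neither
]
print(soft_format_reward(samples))    # [0.5, 0.5, 0.0]
print(strict_format_reward(samples))  # [0.5, 0.0, 0.0]
```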

# Load model

In [None]:
# from transformers import AutoModelForCausalLM, AutoTokenizer
# model_name = "SweatyCrayfish/llama-3-8b-quantized"
# model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto",
#                                              load_in_4bit=True,
#                                              attn_implementation="flash_attention_2"
#                                              )

# tokenizer = AutoTokenizer.from_pretrained(model_name)
# tokenizer.pad_token = tokenizer.eos_token
# run_name = f"{model_name.replace('/', '-')}-GRPO-Chess-vLLM-scratch-easy-50000-1ep-out200-reward-quantized_2_smallmoves"


In [None]:
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
import torch

# Create a quantization configuration to load the model in 4-bit
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,  # or torch.float16 if preferred
    bnb_4bit_use_double_quant=True,         # also quantizes the quantization constants, saving extra memory
    bnb_4bit_quant_type="nf4"               # common option for 4-bit (NF4)
)
# model_name = "meta-llama/Llama-3.1-8B-Instruct"  # larger alternative
model_name = "meta-llama/Llama-3.2-3B-Instruct"
# Load the model with the quantization configuration
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=quantization_config,
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
run_name = f"{model_name.replace('/', '-')}-GRPO-Chess-vLLM-scratch-easy-50000-1ep-out200-reward-quantized_2_smallmoves"


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]
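Why 4-bit? A back-of-the-envelope estimate of weight memory is parameters × bits ÷ 8. The helper below is a hypothetical sketch that ignores activations, the KV cache, quantization constants, and any layers left in 16-bit; the parameter counts are round illustrative numbers, not values read from the checkpoints.

```python
def weight_memory_gib(num_params: float, bits_per_param: float) -> float:
    """Rough weight-only footprint in GiB (ignores activations, KV cache, unquantized layers)."""
    return num_params * bits_per_param / 8 / 1024**3


# Illustrative parameter counts (approximate).
for name, params in [("~3B model", 3.0e9), ("~8B model", 8.0e9)]:
    for bits in (16, 8, 4):
        print(f"{name} @ {bits}-bit: {weight_memory_gib(params, bits):.1f} GiB")
```

For comparison, the Unsloth log below reports `Loading model weights took 2.1362 GB` for the 4-bit 3B checkpoint, which is above the ~1.4 GiB estimate. Part of that gap is plausibly due to layers that stay in 16-bit, such as the token embeddings.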

In [None]:
from unsloth import FastLanguageModel, PatchFastRL, is_bfloat16_supported
PatchFastRL("GRPO", FastLanguageModel)

import torch
max_seq_length = 512 # Can increase for longer reasoning traces
# model_name = "meta-llama/Llama-3.1-8B-Instruct"  # larger alternative
model_name = "unsloth/Llama-3.2-3B-Instruct-bnb-4bit"
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = model_name,
    max_seq_length = max_seq_length,
    load_in_4bit = True, # False for LoRA 16bit
    fast_inference = True, # Enable vLLM fast inference
    # max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.6, # Reduce if out of memory
)


==((====))==  Unsloth 2025.2.12: Fast Llama patching. Transformers: 4.48.3.
   \\   /|    GPU: NVIDIA L4. Max memory: 22.161 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.5.1+cu124. CUDA: 8.9. CUDA Toolkit: 12.4. Triton: 3.1.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.28.post3. FA2 = False]
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: vLLM loading unsloth/llama-3.2-3b-instruct-bnb-4bit with actual GPU utilization = 59.5%
Unsloth: Your GPU has CUDA compute capability 8.9 with VRAM = 22.16 GB.
Unsloth: Using conservativeness = 1.0. Chunked prefill tokens = 512. Num Sequences = 224.
Unsloth: vLLM's KV Cache can use up to 10.63 GB. Also swap space = 6 GB.
INFO 02-19 03:52:29 config.py:542] This model supports multiple tasks: {'reward', 'generate', 'embed', 'score', 'classify'}. Defaulting to 'generate'.
Unsloth: vLLM Bitsandbytes config using kwargs = {'load_in_8bi

Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]



INFO 02-19 03:52:36 model_runner.py:1115] Loading model weights took 2.1362 GB
INFO 02-19 03:52:36 punica_selector.py:18] Using PunicaWrapperGPU.
INFO 02-19 03:52:39 worker.py:267] Memory profiling takes 2.54 seconds
INFO 02-19 03:52:39 worker.py:267] the current vLLM instance can use total_gpu_memory (22.16GiB) x gpu_memory_utilization (0.59) = 13.19GiB
INFO 02-19 03:52:39 worker.py:267] model weights take 2.14GiB; non_torch_memory takes 0.04GiB; PyTorch activation peak memory takes 1.04GiB; the rest of the memory reserved for KV Cache is 9.97GiB.
INFO 02-19 03:52:39 executor_base.py:110] # CUDA blocks: 5832, # CPU blocks: 3510
INFO 02-19 03:52:39 executor_base.py:115] Maximum concurrency for 512 tokens per request: 182.25x
INFO 02-19 03:52:43 model_runner.py:1434] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error oc

Capturing CUDA graph shapes: 100%|██████████| 31/31 [00:39<00:00,  1.27s/it]

INFO 02-19 03:53:22 model_runner.py:1562] Graph capturing finished in 39 secs, took 0.59 GiB
INFO 02-19 03:53:22 llm_engine.py:431] init engine (profile, create kv cache, warmup model) took 46.63 seconds
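The vLLM memory lines above can be cross-checked with a little arithmetic. The model and engine settings used below are assumptions that do not appear in the log: the Llama-3.2-3B configuration (28 decoder layers, 8 key/value heads, head dimension 128), a 16-bit KV cache, and vLLM's default block size of 16 tokens. Under those assumptions, the reported 5,832 GPU blocks account for both the ~9.97 GiB KV budget and the 182.25x concurrency at 512 tokens.

```python
GIB = 1024**3

# Figures copied from the log above.
total_gpu_gib = 22.16
utilization = 0.595          # "actual GPU utilization = 59.5%" (printed as 0.59 later)
weights_gib, non_torch_gib, activation_gib = 2.14, 0.04, 1.04
num_gpu_blocks = 5832
max_seq_len = 512

# Assumed model/engine settings (not shown in the log).
num_layers, num_kv_heads, head_dim = 28, 8, 128
bytes_per_value = 2          # 16-bit KV cache
block_size = 16              # tokens per KV block

budget_gib = total_gpu_gib * utilization
kv_budget_gib = budget_gib - weights_gib - non_torch_gib - activation_gib

# Keys and values (factor 2) for every layer and KV head.
kv_bytes_per_token = 2 * num_layers * num_kv_heads * head_dim * bytes_per_value
cache_tokens = num_gpu_blocks * block_size

print(f"usable budget   : {budget_gib:.2f} GiB")
print(f"KV budget       : {kv_budget_gib:.2f} GiB")
print(f"KV per token    : {kv_bytes_per_token / 1024:.0f} KiB")
print(f"blocks x size   : {cache_tokens * kv_bytes_per_token / GIB:.2f} GiB")
print(f"max concurrency : {cache_tokens / max_seq_len:.2f}x")
```

Raising `max_seq_length` would divide the same 93,312 cached tokens among proportionally fewer concurrent sequences.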





In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_name = "Qwen/Qwen2.5-0.5B-Instruct"
# model_name = "microsoft/Phi-3.5-mini-instruct"
# model_name = "Qwen/Qwen2.5-1.5B-Instruct"

# output_dir="outputs/Qwen-1.5B-GRPO"
# run_name="Phi-3.5-GRPO-Chess-vLLM-scratch-hard-5000-2ep-out200-r4.0-train-18-21-22-23"
run_name = f"{model_name.replace('/', '-')}-GRPO-Chess-scratch-easy-50000-1ep-out200-reward-quantized_2_smallmoves"
output_dir = f"outputs/{run_name}"

# tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left')
# tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map=None,
    # attn_implementation="flash_attention_2"
).to("cuda")


# model = AutoModelForCausalLM.from_pretrained(model_name,
#                                                   device_map="auto",
#                                                   load_in_8bit=True,  # Enables 8-bit quantization
#                                                   torch_dtype="auto",  # Automatically selects the correct data type
#                                                   attn_implementation="flash_attention_2"
#                                                   )




In [None]:
run_name = f"{model_name.replace('/', '-')}-GRPO-Chess-scratch-gsm8k-reward-quantized_2_smallmoves"
output_dir = f"outputs/{run_name}"


In [None]:
# from peft import LoraConfig, get_peft_model
# import torch.optim as optim

# target_modules=["q_proj", "v_proj"]
# target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
#                       "gate_proj", "up_proj", "down_proj",]
# for p in model.parameters():
#     p.requires_grad = False

# lora_config = LoraConfig(
#     r=16, lora_alpha=16, lora_dropout=0.05,
#     bias="none", task_type="CAUSAL_LM",
#     target_modules=target_modules
# )

# model.add_adapter(lora_config, adapter_name="policy")
# model.set_adapter("policy")

# # Make sure the "policy" adapter is trainable
# for name, param in model.named_parameters():
#     if "policy" in name:
#         param.requires_grad = True  # Unfreeze only this adapter

# trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
# print(f"Trainable parameters: {trainable_params}")

# optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()))

# Reinforcement Learning
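The trainer cell below relies on three small tensor operations: picking the log-probability of each generated token (`selective_log_softmax`), masking everything after the first EOS in a completion, and turning each prompt's group of rewards into group-normalized advantages. The pure-Python reference below does one row or one group at a time, which can help when checking the batched PyTorch code on toy inputs. The function names are hypothetical and are not part of the trainer. This version of the mask keeps an EOS even when it is the very first generated token.

```python
import math
from typing import List


def token_log_prob(logits: List[float], token_id: int) -> float:
    """log softmax(logits)[token_id] via a stable log-sum-exp (one position of selective_log_softmax)."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return logits[token_id] - log_z


def completion_mask(token_ids: List[int], eos_id: int) -> List[int]:
    """1 up to and including the first EOS, 0 afterwards; all ones if there is no EOS."""
    mask, seen_eos = [], False
    for t in token_ids:
        mask.append(0 if seen_eos else 1)
        if t == eos_id:
            seen_eos = True
    return mask


def group_advantages(rewards: List[float], group_size: int, eps: float = 1e-6) -> List[float]:
    """(r - group mean) / (group std + eps), using the unbiased std like torch.std's default.

    Requires group_size >= 2 (the unbiased std is undefined for a single sample).
    """
    out = []
    for start in range(0, len(rewards), group_size):
        group = rewards[start:start + group_size]
        mean = sum(group) / len(group)
        var = sum((r - mean) ** 2 for r in group) / (len(group) - 1)
        std = math.sqrt(var) + eps
        out.extend((r - mean) / std for r in group)
    return out


print(token_log_prob([0.0, 0.0], 1))               # log(0.5) ≈ -0.6931
print(completion_mask([5, 7, 2, 2, 2], eos_id=2))  # [1, 1, 1, 0, 0]
print(completion_mask([2, 2, 2], eos_id=2))        # [1, 0, 0]
print(group_advantages([2.2, 0.2, 1.2, 1.2], group_size=2))  # ≈ [0.7071, -0.7071, 0.0, 0.0]
```

One consequence is worth knowing: with `num_generations=2` (the trainer's default), any two unequal rewards get advantages of about ±0.707, however large or small the gap between them, because the unbiased std of two values is |a − b|/√2.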

In [None]:
from os import remove
import warnings
import torch
from torch import nn
from torch.utils.data import Sampler
from typing import Any, Callable, List, Optional, Union

import transformers
from packaging import version
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    GenerationConfig,
)
from datasets import Dataset


def selective_log_softmax(logits: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
    """Compute log softmax probabilities and select values for given input_ids."""
    log_probs = logits.log_softmax(dim=-1)
    return log_probs.gather(dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)


RewardFunc = Union[str, nn.Module, Callable[[List[str], List[str]], List[float]]]


class GRPOTrainer(Trainer):
    """
    GRPOTrainer for RL on language models with modular data preparation, reward computation,
    and generation selection components.
    """

    def __init__(
        self,
        model: Union[str, nn.Module],
        reward_funcs: Union[RewardFunc, List[RewardFunc]],
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
        processing_class: Optional[AutoTokenizer] = None,
        max_prompt_length: int = 256,
        max_completion_length: int = 128,
        num_generations: int = 2,
        beta: float = 0.1,
        temperature: float = 1.0,
        use_peft: bool = False,
        sample_selection: str = "best_worst",
        **kwargs,
    ):
        """
        Initialize the trainer with model, tokenizer, reward functions, and generation configuration.
        """
        self._metrics = {}

        # Initialize model.
        if isinstance(model, str):
            self.model_id = model
            self.model = AutoModelForCausalLM.from_pretrained(model)
        else:
            self.model_id = getattr(model.config, "_name_or_path", "unknown_model")
            self.model = model

        # Load tokenizer if not provided.
        if processing_class is None:
            self.processing_class = AutoTokenizer.from_pretrained(self.model_id, padding_side="left")
            if self.processing_class.pad_token is None:
                warnings.warn("No pad token found; setting EOS token as pad.")
                self.processing_class.pad_token = self.processing_class.eos_token
        else:
            self.processing_class = processing_class

        self.model.config.pad_token_id = self.processing_class.pad_token_id

        # Set up reference model.
        if use_peft:
            print('Using peft, turning off ref model...')
            self.ref_model = None
        else:
            self.ref_model = AutoModelForCausalLM.from_pretrained(self.model_id).to(self.model.device)

        # Prepare reward functions.
        self.reward_funcs = self._prepare_reward_functions(reward_funcs)

        self.max_prompt_length = max_prompt_length
        self.max_completion_length = max_completion_length
        self.num_generations = num_generations
        self.beta = beta
        self.temperature = temperature
        self.use_peft = use_peft
        self.sample_selection = sample_selection

        def data_collator(features):
            return features

        super().__init__(
            model=self.model,
            args=kwargs.get("args", None),
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            processing_class=self.processing_class,
            data_collator=data_collator,
        )

        self.generation_config = GenerationConfig(
            max_new_tokens=self.max_completion_length,
            do_sample=True,
            temperature=self.temperature,
            pad_token_id=self.processing_class.pad_token_id,
        )

    def _prepare_reward_functions(self, reward_funcs: Union[RewardFunc, List[RewardFunc]]) -> List:
        """Convert reward functions into a uniform list of callable functions or models."""
        if not isinstance(reward_funcs, list):
            reward_funcs = [reward_funcs]

        funcs = []
        for rf in reward_funcs:
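            if isinstance(rf, str):
                # Load a sequence-classification reward model and put it on the policy's device,
                # since its inputs are moved to that device in _compute_rewards_and_advantages.
                rm = AutoModelForSequenceClassification.from_pretrained(rf, num_labels=1)
                rm.config.pad_token_id = self.processing_class.pad_token_id
                funcs.append(rm.to(self.model.device))
            else:
                # nn.Module reward models and plain callables are used as-is.
                funcs.append(rf)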
        return funcs

    def _process_prompts(self, raw_prompts: List[str], device: torch.device):
        """
        Combine tokenizing prompts, generating completions, and building completion masks.
        Returns:
          - prompt_ids: Tensor of tokenized prompts.
          - prompt_mask: Attention mask for the prompts.
          - prompt_completion_ids: Generated sequences (prompt+completion).
          - completion_ids: Generated completions only.
          - completion_mask: Mask for completions (zeros out tokens after EOS).
        """
        # Tokenize prompts.
        tokenized = self.processing_class(raw_prompts, return_tensors="pt", padding=True, padding_side="left",
            add_special_tokens=False, truncation=True, max_length=self.max_prompt_length).to(device)
        prompt_ids = tokenized["input_ids"]
        prompt_mask = tokenized["attention_mask"]

        # Generate completions.
        with torch.no_grad():
            prompt_completion_ids = self.model.generate(prompt_ids, attention_mask=prompt_mask,
                                                        generation_config=self.generation_config,
                                                        num_return_sequences=self.num_generations,
                                                        do_sample=True, temperature=self.temperature)

        # Separate completions from prompts.
        prompt_len = prompt_ids.size(1)
        completion_ids = prompt_completion_ids[:, prompt_len:]

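        # Build a mask that keeps tokens up to and including the first EOS and zeros out the rest.
        # Completions without any EOS keep all their tokens. (Checking `pos > 0` would miss an EOS
        # generated as the very first token, since argmax also returns 0 when there is no EOS.)
        is_eos = completion_ids == self.processing_class.eos_token_id
        eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
        has_eos = is_eos.any(dim=1)
        eos_idx[has_eos] = is_eos.int().argmax(dim=1)[has_eos]
        seq_idx = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
        completion_mask = (seq_idx <= eos_idx.unsqueeze(1)).long()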

        return prompt_ids, prompt_mask, prompt_completion_ids, completion_ids, completion_mask

    def _compute_rewards_and_advantages(self, raw_prompts: List[str], completion_ids: torch.Tensor,
        device: torch.device, inputs: List[dict]) -> torch.Tensor:
        """
        Compute rewards from reward functions and normalize them to obtain advantages.
        Returns:
            advantages: Normalized advantages as a flat tensor.
        """
        batch_size = len(raw_prompts)

        # Decode completions and expand prompts.
        completions = [self.processing_class.decode(c, skip_special_tokens=True) for c in completion_ids]
        expanded_prompts = [p for p in raw_prompts for _ in range(self.num_generations)]

        device_rewards = torch.zeros(batch_size * self.num_generations, device=device)
        for rf in self.reward_funcs:
            if isinstance(rf, nn.Module):
                texts = [p + c for p, c in zip(expanded_prompts, completions)]
                rm_inputs = self.processing_class(texts, return_tensors="pt", padding=True, truncation=True,
                                                  max_length=512).to(device)
                with torch.no_grad():
                    rm_logits = rf(**rm_inputs).logits[:, 0]
                device_rewards += rm_logits
            elif callable(rf):
                keys = [key for key in inputs[0] if key not in ["prompt", "completion"]]
                reward_kwargs = {key: [example[key] for example in inputs] for key in keys}
                # Repeat each extra field num_generations times, matching how the prompts were expanded.
                for key in reward_kwargs:
                    reward_kwargs[key] = [r for r in reward_kwargs[key] for _ in range(self.num_generations)]
                rewards_list = rf(prompts=expanded_prompts, completions=completions, **reward_kwargs)
                rewards_tensor = torch.tensor(rewards_list, device=device)
                device_rewards += rewards_tensor
                self._metrics[f"rewards/{rf.__name__}"] = rewards_tensor.mean().item()

        rewards_group = device_rewards.view(batch_size, self.num_generations)
        group_mean = rewards_group.mean(dim=1, keepdim=True)
        group_std = rewards_group.std(dim=1, keepdim=True) + 1e-6
        advantages = ((rewards_group - group_mean) / group_std).view(-1)
        return advantages

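    def _select_indices_by_strategy(self, advantages: torch.Tensor, batch_size: int) -> torch.Tensor:
        """
        Select the completions to train on according to `self.sample_selection`.
        "best_worst" keeps the highest- and lowest-advantage completion per prompt; "all" keeps every completion.
        Returns a 1-D tensor of flat indices into the (batch_size * num_generations) completions.
        """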
        if self.sample_selection == "best_worst":
            advantages_reshaped = advantages.view(batch_size, self.num_generations)
            selected_indices = []
            for i in range(batch_size):
                best_idx = torch.argmax(advantages_reshaped[i]).item()
                worst_idx = torch.argmin(advantages_reshaped[i]).item()
                selected_indices.extend([i * self.num_generations + best_idx, i * self.num_generations + worst_idx])
            selected_indices = torch.tensor(selected_indices, device=self.model.device)
        elif self.sample_selection == "all":
            selected_indices = torch.arange(batch_size * self.num_generations, device=self.model.device)
        else:
            raise ValueError(f"Unknown selection strategy: {self.sample_selection}")
        return selected_indices

    def _select_generations(self, advantages: torch.Tensor, batch_size: int, prompt_completion_ids: torch.Tensor,
        completion_ids: torch.Tensor, completion_mask: torch.Tensor, prompt_mask: torch.Tensor
    ) -> dict:
        """
        Gather the prompt+completion ids, completion masks, full attention masks, and advantages of the
        completions chosen by `_select_indices_by_strategy`. Returns a dictionary with the selected tensors.
        """
        selected_indices = self._select_indices_by_strategy(advantages, batch_size)

        selected_prompt_completion_ids = prompt_completion_ids[selected_indices]
        selected_completion_ids = completion_ids[selected_indices]
        selected_completion_mask = completion_mask[selected_indices]
        selected_advantages = advantages[selected_indices]

        full_attention_mask_all = torch.cat([
            torch.repeat_interleave(prompt_mask, self.num_generations, dim=0),
            completion_mask
        ], dim=1)
        selected_full_attention_mask = full_attention_mask_all[selected_indices]

        return {
            "selected_prompt_completion_ids": selected_prompt_completion_ids,
            "selected_completion_ids": selected_completion_ids,
            "selected_completion_mask": selected_completion_mask,
            "selected_full_attention_mask": selected_full_attention_mask,
            "selected_advantages": selected_advantages,
        }

    def _prepare_inputs(self, inputs: List[dict]) -> dict:
        """
        Prepare the training inputs by:
          - Extracting prompts, processing them (tokenization, generation, and masking),
          - Decoding completions,
          - Computing rewards/advantages and selecting the best/worst completions,
          - Computing reference log-probabilities.
        """
        device = self.model.device

        # Extract the prompts from the inputs. Inputs may contain other info for calculation of reward.
        raw_prompts = [self.processing_class.apply_chat_template(d["prompt"], add_generation_prompt=True,
                continue_final_message=False, tokenize=False) for d in inputs]

        # Process prompts: tokenization, generation, and masking of prompts.
        prompt_ids, prompt_mask, prompt_completion_ids, completion_ids, completion_mask = self._process_prompts(raw_prompts, device)

        # Compute rewards and advantages.
        advantages = self._compute_rewards_and_advantages(raw_prompts, completion_ids, device, inputs)

        # Select the completions to train on (best/worst per prompt, or all, depending on `sample_selection`).
        batch_size = len(raw_prompts)
        selection = self._select_generations(advantages, batch_size, prompt_completion_ids, completion_ids,
                                             completion_mask, prompt_mask)

        # Compute reference log-probabilities for the selected completions.
        with torch.no_grad():
            ref_per_token_logps = self._get_per_token_logps(
                self.ref_model,
                selection["selected_prompt_completion_ids"].to(device),
                selection["selected_full_attention_mask"],
                logits_to_keep=selection["selected_completion_ids"].size(1)
            )

        return {
            # "completion_ids": selection["selected_completion_ids"],
            "completion_mask": selection["selected_completion_mask"],
            "ref_per_token_logps": ref_per_token_logps,
            "advantages": selection["selected_advantages"],
            "full_input_ids": selection["selected_prompt_completion_ids"],
            "full_attention_mask": selection["selected_full_attention_mask"],
        }

    def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
        """
        Compute per-token log probabilities using either PEFT or the provided model.
        """
        if self.use_peft:
            if model is None:  # Reference log-probs in PEFT mode: run the base model with adapters disabled.
                self.model.disable_adapters()
            else:  # Policy log-probs in PEFT mode: make sure the adapters are enabled.
                self.model.enable_adapters()
            to_be_used_model = self.model
        else:
            to_be_used_model = model

        # Keep one extra position: the last logit predicts a token beyond the sequence and is discarded below.
        logits = to_be_used_model(input_ids=input_ids, attention_mask=attention_mask,
                                  logits_to_keep=logits_to_keep + 1).logits

        if self.use_peft and model is None:  # Restore the adapters after the reference pass.
            self.model.enable_adapters()

        # The logit at position t predicts token t + 1: drop the final logit, then align the last
        # `logits_to_keep` logits with the completion tokens they predict.
        logits = logits[:, :-1, :]
        input_ids = input_ids[:, -logits_to_keep:]
        logits = logits[:, -logits_to_keep:]
        return selective_log_softmax(logits, input_ids)

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """
        Compute the GRPO loss using policy log-probs, advantages, and a KL penalty.
        """
        if return_outputs:
            raise ValueError("GRPOTrainer does not support `return_outputs=True`.")

        # completion_ids = inputs["completion_ids"]
        completion_mask = inputs["completion_mask"]
        ref_per_token_logps = inputs["ref_per_token_logps"]
        advantages = inputs["advantages"]
        full_input_ids = inputs["full_input_ids"]
        full_attention_mask = inputs["full_attention_mask"]
        seq_len = completion_mask.size(1)

        per_token_logps = self._get_per_token_logps(model, full_input_ids, full_attention_mask, logits_to_keep=seq_len)

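        if self.use_peft:
            # The KL penalty is skipped in PEFT mode, so ref_per_token_logps does not enter the loss.
            kl_term = 0.0
        else:
            # k3 estimator of KL(pi_theta || pi_ref): exp(r) - r - 1 with r = log pi_ref - log pi_theta (always >= 0).
            kl_term = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1

        # exp(logp - logp.detach()) equals 1 in the forward pass but has gradient grad(logp),
        # turning the advantage-weighted term into the policy-gradient surrogate.
        exp_term = torch.exp(per_token_logps - per_token_logps.detach())
        per_token_loss = -(exp_term * advantages.unsqueeze(1) - self.beta * kl_term)
        masked_loss = per_token_loss * completion_mask
        loss_per_sample = masked_loss.sum(dim=1) / completion_mask.sum(dim=1)  # mean over each sequence's valid tokens
        loss = loss_per_sample.mean()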

        completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean()
        mean_kl = ((kl_term * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
        self._metrics['kl'] = mean_kl.item()
        self._metrics['completion_length'] = completion_length.item()
        self._metrics['loss'] = loss.item()

        self.log(self._metrics.copy())
        self._metrics.clear()
        return loss

    def prediction_step(self, model, inputs, prediction_loss_only: bool, ignore_keys: Optional[List[str]] = None):
        """Perform a prediction step without backpropagation using the prepared inputs."""
        inputs = self._prepare_inputs(inputs)
        with torch.no_grad():
            loss = self.compute_loss(model, inputs)
            loss = loss.detach()
        return (loss, None, None)

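The reward-to-advantage step and the `best_worst` selection above can also be written without Python loops. The sketch below is a standalone re-implementation for experimenting outside the trainer, not a drop-in replacement for its methods; the function names `group_advantages` and `best_worst_indices` are hypothetical. It assumes rewards are laid out prompt-major, which is how `expanded_prompts` is built in `_compute_rewards_and_advantages`.

```python
import torch


def group_advantages(rewards: torch.Tensor, num_generations: int, eps: float = 1e-6) -> torch.Tensor:
    """Standardize rewards within each group of completions sampled from the same prompt.

    `rewards` is flat and prompt-major: the `num_generations` completions of prompt 0
    come first, then those of prompt 1, and so on.
    """
    grouped = rewards.view(-1, num_generations)         # (num_prompts, num_generations)
    mean = grouped.mean(dim=1, keepdim=True)
    std = grouped.std(dim=1, keepdim=True)              # unbiased std, like the trainer code
    return ((grouped - mean) / (std + eps)).view(-1)


def best_worst_indices(advantages: torch.Tensor, num_generations: int) -> torch.Tensor:
    """Flat indices of the highest- and lowest-advantage completion of every prompt."""
    grouped = advantages.view(-1, num_generations)
    offsets = torch.arange(grouped.size(0), device=advantages.device) * num_generations
    best = offsets + grouped.argmax(dim=1)              # ties resolve to the first index
    worst = offsets + grouped.argmin(dim=1)
    # Interleave as [best_0, worst_0, best_1, worst_1, ...], the order the loop version builds.
    return torch.stack([best, worst], dim=1).view(-1)


# Shapes used in this notebook: 4 prompts per step x 8 generations = 32 completions.
torch.manual_seed(0)
rewards = torch.rand(4 * 8)
advantages = group_advantages(rewards, num_generations=8)
selected = best_worst_indices(advantages, num_generations=8)
print(selected)  # 8 flat indices: one best and one worst completion per prompt
```

With `per_device_train_batch_size=4` and `num_generations=8`, each step scores 32 completions (the printed `Score:` list in the log has 32 entries) and `best_worst` keeps 8 of them. If every completion of a prompt receives the same total reward, that group's standard deviation is zero, all of its advantages are zero, and `argmax`/`argmin` both return the first index, so the same completion is selected twice and contributes only through the KL term.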

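The loss in `compute_loss` combines two tricks that are easy to misread: the `exp(logp - logp.detach())` factor and the k3 KL estimator. The standalone sketch below mirrors the non-PEFT branch on a toy batch so the value and the gradient can be checked by hand; the function name `grpo_per_token_loss` is hypothetical, and, unlike the notebook, it clamps each sequence's token count to at least 1 as a guard against an all-zero mask.

```python
import torch


def grpo_per_token_loss(per_token_logps, ref_per_token_logps, advantages, completion_mask, beta):
    """Masked GRPO loss with the k3 KL estimator, following the structure of compute_loss."""
    # k3 estimator of KL(pi_theta || pi_ref): exp(r) - r - 1 with r = log pi_ref - log pi_theta.
    log_ratio = ref_per_token_logps - per_token_logps
    kl = torch.exp(log_ratio) - log_ratio - 1
    # Equals 1 in the forward pass; its gradient is grad(logp), giving advantage * grad(logp).
    ratio = torch.exp(per_token_logps - per_token_logps.detach())
    per_token_loss = -(ratio * advantages.unsqueeze(1) - beta * kl)
    mask = completion_mask.float()
    # Average over each sequence's valid tokens, then over the batch.
    per_sample = (per_token_loss * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
    return per_sample.mean(), kl


# Toy batch: 2 sequences x 3 completion tokens; the second sequence ends after 2 tokens.
logps = torch.tensor([[-1.0, -0.5, -2.0], [-0.3, -0.7, -1.2]], requires_grad=True)
ref = logps.detach().clone()  # identical policy and reference, so the KL term is zero
adv = torch.tensor([1.0, -1.0])
mask = torch.tensor([[1, 1, 1], [1, 1, 0]])
loss, kl = grpo_per_token_loss(logps, ref, adv, mask, beta=0.02)
loss.backward()
print(loss.item())  # 0.0, i.e. minus the mean advantage, because the KL term vanishes
print(logps.grad)   # -advantage / (valid tokens in the sequence) / batch size; 0 where masked
```

Here the first sequence's tokens each get gradient -1/6 and the second sequence's two valid tokens get 0.25, while its padded third token gets none. Because the loss is averaged per sequence before averaging over the batch, a short completion carries the same total weight as a long one.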
In [None]:
from transformers import TrainingArguments

training_args = TrainingArguments(
    # output_dir='QWEN_1.5B_GRPO_custom_sign_only',
    output_dir=run_name,
    learning_rate=5e-6,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type='cosine',
    per_device_train_batch_size=4,
    gradient_accumulation_steps=1,
    num_train_epochs=1,
    bf16=True,
    max_grad_norm=0.1,
    logging_dir='./logs',
    report_to="wandb",
    logging_steps=1,
    remove_unused_columns=False
)

trainer = GRPOTrainer(
    model=model,
    reward_funcs=[xmlcount_reward_func,
        soft_format_reward_func,
        correctness_reward_func],
    train_dataset=dataset,
    # eval_dataset=eval_ds,
    max_prompt_length=256,
    max_completion_length=200,
    # optim="adamw_8bit" is a TrainingArguments option; set it in training_args if needed.
    num_generations=8,
    temperature=1.0,
    beta=0.02,
    use_peft=False,
    sample_selection="best_worst",

    # remove_unused_columns=False is a TrainingArguments option and is already set in training_args.
    args=training_args,  # the transformers.TrainingArguments defined above
)

trainer.train()


-------------------- 
Question:
<|im_start|>system

You are going to respond to a user query. You always first reason and then provide your answer.
You enclose your reasoning process and answer within <reasoning> </reasoning> and <answer> </answer> tags, respectively, i.e.,

 Respond in the following format:
<reasoning> your reasoning process here </reasoning>
<answer> your final answer here </answer>

 <|im_end|>
<|im_start|>user
The following chess moves have been played: b2b4 c7c5 f2f3 d8c7 g2g3 c5c4. It is white to move. Suggest the best feasible next move for white. 
 Respond in the following format:
<reasoning> your reasoning process here </reasoning>
<answer> your final answer here </answer>
<|im_end|>
<|im_start|>assistant
 
Response:
The chess strategy is to prevent the opponent from advancing their piece at c5. To do this, one can play the following moves:

1. a1: Moves c4
2. g3: Moves e7
3. a3: Moves a8
4. c4: Moves c2
5. a6: Moves a4
6. h2: Moves f3
7. h3: Moves d6
8. h5: M

Could not estimate the number of tokens of the input, floating-point operations will not be computed


Step,Training Loss
0,-0.265145
1,-0.134996
2,-0.096814
3,-0.488804
4,-0.530313
5,-0.054998
6,-0.031951
7,0.055398
8,0.008462
9,-0.134996


Streaming output truncated to the last 5000 lines.
<answer>
d2d4
</answer> 
Extracted:
d2d4 
Score:
[0.2, 1.2, 0.2, 1.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2]
-------------------- 
Question:
<|im_start|>system

You are going to respond to a user query. You always first reason and then provide your answer.
You enclose your reasoning process and answer within <reasoning> </reasoning> and <answer> </answer> tags, respectively, i.e.,

 Respond in the following format:
<reasoning> your reasoning process here </reasoning>
<answer> your final answer here </answer>

 <|im_end|>
<|im_start|>user
The following chess moves have been played: f2f3 b8a6 b2b4 g8f6 c2c4 a6b8 h2h4 g7g5 c1a3 a7a5. It is white to move. Suggest the best feasible next move for white. 
 Respond in the following format:
<reasoning> your reasoning process here </reasoning>
<answer> your final answer here </answe