## Installation

In [1]:
%%capture
import os
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth
else:
    # Do this only in Colab notebooks! Otherwise use pip install unsloth
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf "datasets>=3.4.1" huggingface_hub hf_transfer
    !pip install --no-deps unsloth

In [None]:
%%capture
import os
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth vllm
else:
    # [NOTE] Do the below ONLY in Colab! Use [[pip install unsloth vllm]]
    !pip install --no-deps unsloth vllm==0.8.5.post1

In [None]:
%%capture
#@title Colab Extra Install { display-mode: "form" }
import os
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth vllm
else:
    !pip install --no-deps unsloth vllm==0.8.5.post1
    # [NOTE] Do the below ONLY in Colab! Use [[pip install unsloth vllm]]
    # Skip restarting message in Colab
    import sys, re, requests; modules = list(sys.modules.keys())
    for x in modules: sys.modules.pop(x) if "PIL" in x or "google" in x else None
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf "datasets>=3.4.1,<4.0.0" huggingface_hub hf_transfer
    
    # vLLM requirements - vLLM breaks Colab due to reinstalling numpy
    f = requests.get("https://raw.githubusercontent.com/vllm-project/vllm/refs/heads/main/requirements/common.txt").content
    with open("vllm_requirements.txt", "wb") as file:
        file.write(re.sub(rb"(transformers|numpy|xformers)[^\n]{1,}\n", b"", f))
    !pip install -r vllm_requirements.txt

In [None]:
!pip install langid -qq

## Unsloth

In [None]:
from unsloth import FastModel
import torch
max_seq_length = 4096

model, tokenizer = FastModel.from_pretrained(
    model_name = "lmq1909/Qwen3-8B-LQA-14e-full",
    max_seq_length = max_seq_length, # Choose any for long context!
    load_in_4bit = True,  # 4 bit quantization to reduce memory
    load_in_8bit = False, # [NEW!] A bit more accurate, uses 2x memory
    full_finetuning = False, # [NEW!] We have full finetuning now!
    # token = "hf_...", # use one if using gated models
)

In [None]:
# Wrap the loaded model with LoRA adapters; the flags below choose which
# groups of layers receive adapters.
model = FastModel.get_peft_model(
    model,
    finetune_vision_layers     = False, # Turn off for just text!
    finetune_language_layers   = True,  # Should leave on!
    finetune_attention_modules = True,  # Attention good for GRPO
    finetune_mlp_modules       = True,  # Should leave on always!

    r = 8,           # Larger = higher accuracy, but might overfit
    lora_alpha = 16,  # Recommended alpha == r at least (here alpha is 2 * r)
    lora_dropout = 0,
    bias = "none",
    random_state = 3407,
)

In [None]:
# Special-token strings discovered from the tokenizer. Each one stays None
# when no token in the added vocabulary matches its substring test.
reasoning_start = None
reasoning_end = None
user_token = None
assistant_token = None

# Classify every added-vocabulary token by substring. The elif chain gives the
# "think" tests priority over "user"/"assistant", and a later matching token
# overwrites an earlier one, so the last match in iteration order wins.
# NOTE(review): presumably this resolves to "<think>" / "</think>" for this
# model - confirm against the tokenizer, because the reward functions below
# rely on both values being real strings rather than None.
for token in tokenizer.get_added_vocab().keys():
    if "think" in token and "/" in token:
        reasoning_end = token
    elif "think" in token:
        reasoning_start = token
    elif "user" in token:
        user_token = token
    elif "assistant" in token:
        assistant_token = token

## data ALQAC 2025

In [2]:
!gdown 13WRuCCpUFPSqwQ3csIe7-UcoSfVcHjpU

Downloading...
From: https://drive.google.com/uc?id=13WRuCCpUFPSqwQ3csIe7-UcoSfVcHjpU
To: /kaggle/working/data_train_legal_qa_new.json
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 5.83M/5.83M [00:00<00:00, 215MB/s]


In [3]:
import json
import random

random.seed(42)

# Gi·∫£ ƒë·ªãnh ƒë∆∞·ªùng d·∫´n t·ªáp JSON
file_path = 'data_train_legal_qa_new.json'

# B∆∞·ªõc 1: T·∫£i d·ªØ li·ªáu
try:
    with open(file_path, 'r', encoding='utf-8') as f:
        questions = json.load(f)
except FileNotFoundError:
    print(f"L·ªói: Kh√¥ng t√¨m th·∫•y t·ªáp {file_path}. Vui l√≤ng ƒë·∫£m b·∫£o t·ªáp t·ªìn t·∫°i v√† ƒë√∫ng ƒë∆∞·ªùng d·∫´n.")
    exit()
except json.JSONDecodeError:
    print(f"L·ªói: Kh√¥ng th·ªÉ gi·∫£i m√£ JSON t·ª´ t·ªáp {file_path}. ƒê·∫£m b·∫£o t·ªáp c√≥ ƒë·ªãnh d·∫°ng JSON h·ª£p l·ªá.")
    exit()

# B∆∞·ªõc 2: Ph√¢n lo·∫°i c√¢u h·ªèi
true_false_questions = [q for q in questions if q['question_type'] == 'ƒê√∫ng/Sai']
multiple_choice_questions = [q for q in questions if q['question_type'] == 'Tr·∫Øc nghi·ªám']
essay_questions = [q for q in questions if q['question_type'] == 'T·ª± lu·∫≠n']

# S·ªë l∆∞·ª£ng ban ƒë·∫ßu c·ªßa t·ª´ng lo·∫°i
num_true_false_total = len(true_false_questions)
num_multiple_choice_total = len(multiple_choice_questions)
num_essay_total = len(essay_questions)
total_samples = len(questions)

print(f"T·ªïng s·ªë m·∫´u ban ƒë·∫ßu: {total_samples}")
print(f" - ƒê√∫ng/Sai: {num_true_false_total}")
print(f" - Tr·∫Øc nghi·ªám: {num_multiple_choice_total}")
print(f" - T·ª± lu·∫≠n: {num_essay_total}\n")

# B∆∞·ªõc 3: Chu·∫©n b·ªã chia t·∫≠p d·ªØ li·ªáu
validation_set_target_size = 100
num_essay_in_val_requested = 5

validation_questions = []
train_questions = []

# --- X·ª≠ l√Ω c√¢u h·ªèi 'T·ª± lu·∫≠n' cho t·∫≠p validation ---
if num_essay_total < num_essay_in_val_requested:
    print(f"C·∫£nh b√°o: S·ªë l∆∞·ª£ng c√¢u h·ªèi 'T·ª± lu·∫≠n' ({num_essay_total}) √≠t h∆°n s·ªë l∆∞·ª£ng y√™u c·∫ßu cho t·∫≠p validation ({num_essay_in_val_requested}). S·∫Ω l·∫•y t·∫•t c·∫£ {num_essay_total} c√¢u h·ªèi 'T·ª± lu·∫≠n' cho t·∫≠p validation.")
    num_essay_in_val_actual = num_essay_total
else:
    num_essay_in_val_actual = num_essay_in_val_requested

random.shuffle(essay_questions)
val_essay_q = essay_questions[:num_essay_in_val_actual]
train_essay_q = essay_questions[num_essay_in_val_actual:]

validation_questions.extend(val_essay_q)
train_questions.extend(train_essay_q)

# --- X·ª≠ l√Ω c√¢u h·ªèi 'ƒê√∫ng/Sai' v√† 'Tr·∫Øc nghi·ªám' cho ph·∫ßn c√≤n l·∫°i c·ªßa t·∫≠p validation ---
remaining_val_slots = validation_set_target_size - len(val_essay_q)

if remaining_val_slots > 0:
    total_tf_mc_available = num_true_false_total + num_multiple_choice_total

    if total_tf_mc_available == 0:
        print("Kh√¥ng c√≥ c√¢u h·ªèi 'ƒê√∫ng/Sai' ho·∫∑c 'Tr·∫Øc nghi·ªám' n√†o trong d·ªØ li·ªáu g·ªëc ƒë·ªÉ chia cho t·∫≠p validation.")
    else:
        # T√≠nh to√°n t·ª∑ l·ªá d·ª±a tr√™n t·ªïng s·ªë c√¢u h·ªèi ƒê√∫ng/Sai v√† Tr·∫Øc nghi·ªám ban ƒë·∫ßu
        prop_tf = num_true_false_total / total_tf_mc_available
        prop_mc = num_multiple_choice_total / total_tf_mc_available

        # T√≠nh s·ªë l∆∞·ª£ng mong mu·ªën cho t·∫≠p validation d·ª±a tr√™n t·ª∑ l·ªá v√† s·ªë slot c√≤n l·∫°i
        num_tf_val_desired = round(remaining_val_slots * prop_tf)
        num_mc_val_desired = round(remaining_val_slots * prop_mc)
        
        # ƒêi·ªÅu ch·ªânh l√†m tr√≤n ƒë·ªÉ t·ªïng s·ªë l∆∞·ª£ng ch√≠nh x√°c v·ªõi remaining_val_slots
        current_sum_desired = num_tf_val_desired + num_mc_val_desired
        diff = remaining_val_slots - current_sum_desired
        if diff != 0:
            if diff > 0: # C·∫ßn th√™m m·∫´u
                # ∆Øu ti√™n th√™m v√†o lo·∫°i c√≥ t·ª∑ l·ªá l·ªõn h∆°n ƒë·ªÉ duy tr√¨ t√≠nh ƒë·∫°i di·ªán
                if prop_tf >= prop_mc:
                    num_tf_val_desired += diff
                else:
                    num_mc_val_desired += diff
            else: # C·∫ßn b·ªõt m·∫´u (diff l√† s·ªë √¢m)
                # ∆Øu ti√™n b·ªõt t·ª´ lo·∫°i c√≥ t·ª∑ l·ªá l·ªõn h∆°n
                if prop_tf >= prop_mc:
                    num_tf_val_desired += diff # diff l√† s·ªë √¢m
                else:
                    num_mc_val_desired += diff # diff l√† s·ªë √¢m
        
        # ƒê·∫£m b·∫£o kh√¥ng l·∫•y qu√° s·ªë l∆∞·ª£ng c√≥ s·∫µn trong d·ªØ li·ªáu g·ªëc
        num_tf_val_final = min(num_tf_val_desired, num_true_false_total)
        num_mc_val_final = min(num_mc_val_desired, num_multiple_choice_total)

        # Ki·ªÉm tra xem t·ªïng s·ªë l∆∞·ª£ng sau khi capping c√≥ ƒë·ªß ƒë·ªÉ ƒëi·ªÅn v√†o remaining_val_slots kh√¥ng
        if num_tf_val_final + num_mc_val_final < remaining_val_slots:
            print(f"C·∫£nh b√°o: Kh√¥ng ƒë·ªß c√¢u h·ªèi ƒê√∫ng/Sai v√† Tr·∫Øc nghi·ªám ƒë·ªÉ ƒëi·ªÅn ƒë·∫ßy ƒë·ªß {remaining_val_slots} v·ªã tr√≠ c√≤n l·∫°i trong t·∫≠p validation. S·∫Ω l·∫•y t·∫•t c·∫£ {num_tf_val_final + num_mc_val_final} c√¢u h·ªèi ƒê√∫ng/Sai v√† Tr·∫Øc nghi·ªám c√≥ th·ªÉ.")
        
        random.shuffle(true_false_questions)
        random.shuffle(multiple_choice_questions)

        val_tf_q = true_false_questions[:num_tf_val_final]
        val_mc_q = multiple_choice_questions[:num_mc_val_final]

        train_tf_q = true_false_questions[num_tf_val_final:]
        train_mc_q = multiple_choice_questions[num_mc_val_final:]

        validation_questions.extend(val_tf_q)
        validation_questions.extend(val_mc_q)

        train_questions.extend(train_tf_q)
        train_questions.extend(train_mc_q)
else:
    print("Kh√¥ng c√≤n ch·ªó tr·ªëng trong t·∫≠p validation ƒë·ªÉ th√™m c√¢u h·ªèi ƒê√∫ng/Sai ho·∫∑c Tr·∫Øc nghi·ªám (ƒë√£ ƒë·ªß 100 m·∫´u t·ª´ 'T·ª± lu·∫≠n').")


# B∆∞·ªõc 4: Ki·ªÉm tra v√† in k·∫øt qu·∫£
num_train_true_false = sum(1 for q in train_questions if q['question_type'] == 'ƒê√∫ng/Sai')
num_train_multiple_choice = sum(1 for q in train_questions if q['question_type'] == 'Tr·∫Øc nghi·ªám')
num_train_essay = sum(1 for q in train_questions if q['question_type'] == 'T·ª± lu·∫≠n')
total_train = len(train_questions)

num_val_true_false = sum(1 for q in validation_questions if q['question_type'] == 'ƒê√∫ng/Sai')
num_val_multiple_choice = sum(1 for q in validation_questions if q['question_type'] == 'Tr·∫Øc nghi·ªám')
num_val_essay = sum(1 for q in validation_questions if q['question_type'] == 'T·ª± lu·∫≠n')

print("\n--- K·∫øt qu·∫£ chia t·∫≠p d·ªØ li·ªáu ---")
print(f"T·ªïng s·ªë m·∫´u t·∫≠p hu·∫•n luy·ªán (train): {total_train}")
print(f" - ƒê√∫ng/Sai: {num_train_true_false}")
print(f" - Tr·∫Øc nghi·ªám: {num_train_multiple_choice}")
print(f" - T·ª± lu·∫≠n: {num_train_essay}\n")

print(f"T·ªïng s·ªë m·∫´u t·∫≠p ki·ªÉm ƒë·ªãnh (valid): {len(validation_questions)}")
print(f" - ƒê√∫ng/Sai: {num_val_true_false}")
print(f" - Tr·∫Øc nghi·ªám: {num_val_multiple_choice}")
print(f" - T·ª± lu·∫≠n: {num_val_essay}\n")

print(f"T·ªïng c·ªông: {total_train + len(validation_questions)} m·∫´u (kh·ªõp v·ªõi {total_samples} ban ƒë·∫ßu).")

T·ªïng s·ªë m·∫´u ban ƒë·∫ßu: 728
 - ƒê√∫ng/Sai: 387
 - Tr·∫Øc nghi·ªám: 285
 - T·ª± lu·∫≠n: 56


--- K·∫øt qu·∫£ chia t·∫≠p d·ªØ li·ªáu ---
T·ªïng s·ªë m·∫´u t·∫≠p hu·∫•n luy·ªán (train): 628
 - ƒê√∫ng/Sai: 332
 - Tr·∫Øc nghi·ªám: 245
 - T·ª± lu·∫≠n: 51

T·ªïng s·ªë m·∫´u t·∫≠p ki·ªÉm ƒë·ªãnh (valid): 100
 - ƒê√∫ng/Sai: 55
 - Tr·∫Øc nghi·ªám: 40
 - T·ª± lu·∫≠n: 5

T·ªïng c·ªông: 728 m·∫´u (kh·ªõp v·ªõi 728 ban ƒë·∫ßu).


In [4]:
from datasets import Dataset
train_dataset = Dataset.from_list(train_questions)
valid_dataset = Dataset.from_list(validation_questions)

In [5]:
# Format l·∫°i ƒë·ªÉ ph√π h·ª£p v·ªõi GRPO training
def format_to_chat_template(example):
    """Convert one raw record into the prompt/answer layout used for GRPO training.

    Args:
        example: Mapping that provides the keys ``system_prompt``, ``prompt``
            and ``answer``.

    Returns:
        dict: ``{"prompt": [system_message, user_message], "answer": ...}``.
        Each message is a ``{"role", "content"}`` dict and ``answer`` is
        forwarded unchanged.
    """
    system_message = {"role": "system", "content": example["system_prompt"]}
    user_message = {"role": "user", "content": example["prompt"]}
    # The answer is passed through as-is; post-process it here if a cleaned-up
    # label is ever needed.
    return {
        "prompt": [system_message, user_message],
        "answer": example["answer"],
    }

# # Lo·∫°i b·ªè c√°c c·ªôt kh√¥ng c·∫ßn thi·∫øt
# columns_to_remove = [
#     'question_id', 'text', 'relevant_articles', 'system_prompt',
#     'prompt', 'answer', 'answer_think', 'question_type'
# ]

# √Åp d·ª•ng l√™n train v√† valid datasets
train_dataset = train_dataset.map(format_to_chat_template)
valid_dataset = valid_dataset.map(format_to_chat_template)

Map:   0%|          | 0/628 [00:00<?, ? examples/s]

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

In [6]:
dataset = train_dataset

In [7]:
dataset[100]

{'question_id': 'train_alqac25_14',
 'question_type': 'ƒê√∫ng/Sai',
 'text': 'Ch·ªìng c√≥ ngu·ªìn thu nh·∫≠p nhi·ªÅu h∆°n v·ª£ th√¨ ch·ªìng c√≥ quy·ªÅn chi·∫øm h·ªØu, s·ª≠ d·ª•ng t√†i s·∫£n chung nhi·ªÅu h∆°n so v·ªõi v·ª£, ƒë√∫ng hay sai?',
 'relevant_articles': ['Nguy√™n t·∫Øc chung v·ªÅ ch·∫ø ƒë·ªô t√†i s·∫£n c·ªßa v·ª£ ch·ªìng\n1. V·ª£, ch·ªìng b√¨nh ƒë·∫≥ng v·ªõi nhau v·ªÅ quy·ªÅn, nghƒ©a v·ª• trong vi·ªác t·∫°o l·∫≠p, chi·∫øm h·ªØu, s·ª≠ d·ª•ng, ƒë·ªãnh ƒëo·∫°t t√†i s·∫£n chung; kh√¥ng ph√¢n bi·ªát gi·ªØa lao ƒë·ªông trong gia ƒë√¨nh v√† lao ƒë·ªông c√≥ thu nh·∫≠p.\n\n2. V·ª£, ch·ªìng c√≥ nghƒ©a v·ª• b·∫£o ƒë·∫£m ƒëi·ªÅu ki·ªán ƒë·ªÉ ƒë√°p ·ª©ng nhu c·∫ßu thi·∫øt y·∫øu c·ªßa gia ƒë√¨nh.\n\n3. Vi·ªác th·ª±c hi·ªán quy·ªÅn, nghƒ©a v·ª• v·ªÅ t√†i s·∫£n c·ªßa v·ª£ ch·ªìng m√† x√¢m ph·∫°m ƒë·∫øn quy·ªÅn, l·ª£i √≠ch h·ª£p ph√°p c·ªßa v·ª£, ch·ªìng, gia ƒë√¨nh v√† c·ªßa ng∆∞·ªùi kh√°c th√¨ ph·∫£i b·ªìi th∆∞·ªùng.'],
 'answer': 'Sai',
 'system_prompt': 'B·∫°n l√† m·ªôt tr·ª£ l√Ω ph√°p l√

## Reward functions

In [None]:
import re

# Capture everything that follows the closing reasoning tag.
# The tag is escaped so it is always matched literally. We also fail fast when
# the tokenizer scan found no closing tag, because interpolating None would
# silently produce the useless pattern "None(.*)".
if reasoning_end is None:
    raise ValueError(
        "reasoning_end is not set: no closing reasoning token was found in the tokenizer vocabulary."
    )
solution_end_regex = rf"{re.escape(reasoning_end)}(.*)"

# DOTALL lets the captured answer span multiple lines.
match_format = re.compile(solution_end_regex, re.DOTALL)
match_format

In [None]:
# Reward completions that contain the closing reasoning tag.
def match_format_exactly(completions, **kwargs):
    """Give one point to every completion in which ``match_format`` finds a match.

    Args:
        completions: Sequence of completions; each is a sequence whose first
            element is a dict carrying the generated text under ``"content"``.
        **kwargs: Extra arguments supplied by the trainer; ignored.

    Returns:
        list: One score per completion, ``1.0`` when the module-level
        ``match_format`` pattern is found in the text and ``0`` otherwise.
    """
    return [
        1.0 if match_format.search(item[0]["content"]) is not None else 0
        for item in completions
    ]

completions = [
    [{"content": "<think>\n\n**1. Ph√¢n t√≠ch c√¢u h·ªèi:**\nC√¢u h·ªèi y√™u c·∫ßu ph√¢n t√≠ch v√† gi·∫£i th√≠ch cho k·∫øt lu·∫≠n v·ªÅ c√°c quy·ªÅn c·ªßa H·ªôi ƒë·ªìng x√©t x·ª≠ ph√∫c th·∫©m khi xem x√©t quy·∫øt ƒë·ªãnh s∆° th·∫©m b·ªã kh√°ng c√°o, kh√°ng ngh·ªã. K·∫øt lu·∫≠n ƒë∆∞a ra l√† H·ªôi ƒë·ªìng x√©t x·ª≠ c√≥ ba quy·ªÅn: Gi·ªØ nguy√™n, S·ª≠a, H·ªßy v√† chuy·ªÉn h·ªì s∆°.\n\n**2. D·∫´n ch·ª©ng t·ª´ b·ªëi c·∫£nh:**\n- H·ªôi ƒë·ªìng x√©t x·ª≠ ph√∫c th·∫©m c√≥ quy·ªÅn **gi·ªØ nguy√™n** quy·∫øt ƒë·ªãnh c·ªßa T√≤a √°n c·∫•p s∆° th·∫©m (ƒêi·ªÉm a, M·ª•c 5).\n- H·ªôi ƒë·ªìng x√©t x·ª≠ ph√∫c th·∫©m c√≥ quy·ªÅn **s·ª≠a** quy·∫øt ƒë·ªãnh c·ªßa T√≤a √°n c·∫•p s∆° th·∫©m (ƒêi·ªÉm b, M·ª•c 5).\n- H·ªôi ƒë·ªìng x√©t x·ª≠ ph√∫c th·∫©m c√≥ quy·ªÅn **h·ªßy** quy·∫øt ƒë·ªãnh c·ªßa T√≤a √°n c·∫•p s∆° th·∫©m v√† chuy·ªÉn h·ªì s∆° v·ª• √°n cho T√≤a √°n c·∫•p s∆° th·∫©m ƒë·ªÉ ti·∫øp t·ª•c gi·∫£i quy·∫øt v·ª• √°n (ƒêi·ªÉm c, M·ª•c 5).\n\n**3. Suy lu·∫≠n step-by-step:**\na) CƒÉn c·ª© v√†o quy ƒë·ªãnh t·∫°i M·ª•c 5, H·ªôi ƒë·ªìng x√©t x·ª≠ ph√∫c th·∫©m c√≥ ba quy·ªÅn khi gi·∫£i quy·∫øt kh√°ng c√°o, kh√°ng ngh·ªã.\nb) Quy·ªÅn Gi·ªØ nguy√™n ƒë∆∞·ª£c quy ƒë·ªãnh t·∫°i ƒêi·ªÉm a, M·ª•c 5.\nc) Quy·ªÅn S·ª≠a ƒë∆∞·ª£c quy ƒë·ªãnh t·∫°i ƒêi·ªÉm b, M·ª•c 5.\nd) Quy·ªÅn H·ªßy v√† chuy·ªÉn h·ªì s∆° ƒë∆∞·ª£c quy ƒë·ªãnh t·∫°i ƒêi·ªÉm c, M·ª•c 5.\n\n**4. K·∫øt lu·∫≠n:**\nGi·ªØ nguy√™n quy·∫øt ƒë·ªãnh c·ªßa T√≤a √°n c·∫•p s∆° th·∫©m; s·ª≠a quy·∫øt ƒë·ªãnh c·ªßa T√≤a √°n c·∫•p s∆° th·∫©m; h·ªßy quy·∫øt ƒë·ªãnh c·ªßa T√≤a √°n c·∫•p s∆° th·∫©m v√† chuy·ªÉn h·ªì s∆° v·ª• √°n cho T√≤a √°n c·∫•p s∆° th·∫©m ƒë·ªÉ ti·∫øp t·ª•c gi·∫£i quy·∫øt v·ª• √°n.\n</think>\nGi·ªØ nguy√™n quy·∫øt ƒë·ªãnh c·ªßa T√≤a √°n c·∫•p s∆° th·∫©m; s·ª≠a quy·∫øt ƒë·ªãnh c·ªßa T√≤a √°n c·∫•p s∆° th·∫©m; h·ªßy quy·∫øt ƒë·ªãnh c·ªßa T√≤a √°n c·∫•p s∆° th·∫©m v√† chuy·ªÉn h·ªì s∆° v·ª• √°n cho T√≤a √°n c·∫•p s∆° th·∫©m ƒë·ªÉ ti·∫øp t·ª•c gi·∫£i quy·∫øt v·ª• √°n."}],
    [{"content": "**1. Ph√¢n t√≠ch c√¢u h·ªèi:**\nC√¢u h·ªèi y√™u c·∫ßu ph√¢n t√≠ch v√† gi·∫£i th√≠ch cho k·∫øt lu·∫≠n v·ªÅ c√°c quy·ªÅn c·ªßa H·ªôi ƒë·ªìng x√©t x·ª≠ ph√∫c th·∫©m khi xem x√©t quy·∫øt ƒë·ªãnh s∆° th·∫©m b·ªã kh√°ng c√°o, kh√°ng ngh·ªã. K·∫øt lu·∫≠n ƒë∆∞a ra l√† H·ªôi ƒë·ªìng x√©t x·ª≠ c√≥ ba quy·ªÅn: Gi·ªØ nguy√™n, S·ª≠a, H·ªßy v√† chuy·ªÉn h·ªì s∆°.\n\n**2. D·∫´n ch·ª©ng t·ª´ b·ªëi c·∫£nh:**\n- H·ªôi ƒë·ªìng x√©t x·ª≠ ph√∫c th·∫©m c√≥ quy·ªÅn **gi·ªØ nguy√™n** quy·∫øt ƒë·ªãnh c·ªßa T√≤a √°n c·∫•p s∆° th·∫©m (ƒêi·ªÉm a, M·ª•c 5).\n- H·ªôi ƒë·ªìng x√©t x·ª≠ ph√∫c th·∫©m c√≥ quy·ªÅn **s·ª≠a** quy·∫øt ƒë·ªãnh c·ªßa T√≤a √°n c·∫•p s∆° th·∫©m (ƒêi·ªÉm b, M·ª•c 5).\n- H·ªôi ƒë·ªìng x√©t x·ª≠ ph√∫c th·∫©m c√≥ quy·ªÅn **h·ªßy** quy·∫øt ƒë·ªãnh c·ªßa T√≤a √°n c·∫•p s∆° th·∫©m v√† chuy·ªÉn h·ªì s∆° v·ª• √°n cho T√≤a √°n c·∫•p s∆° th·∫©m ƒë·ªÉ ti·∫øp t·ª•c gi·∫£i quy·∫øt v·ª• √°n (ƒêi·ªÉm c, M·ª•c 5).\n\n**3. Suy lu·∫≠n step-by-step:**\na) CƒÉn c·ª© v√†o quy ƒë·ªãnh t·∫°i M·ª•c 5, H·ªôi ƒë·ªìng x√©t x·ª≠ ph√∫c th·∫©m c√≥ ba quy·ªÅn khi gi·∫£i quy·∫øt kh√°ng c√°o, kh√°ng ngh·ªã.\nb) Quy·ªÅn Gi·ªØ nguy√™n ƒë∆∞·ª£c quy ƒë·ªãnh t·∫°i ƒêi·ªÉm a, M·ª•c 5.\nc) Quy·ªÅn S·ª≠a ƒë∆∞·ª£c quy ƒë·ªãnh t·∫°i ƒêi·ªÉm b, M·ª•c 5.\nd) Quy·ªÅn H·ªßy v√† chuy·ªÉn h·ªì s∆° ƒë∆∞·ª£c quy ƒë·ªãnh t·∫°i ƒêi·ªÉm c, M·ª•c 5.\n\n**4. K·∫øt lu·∫≠n:**\nGi·ªØ nguy√™n quy·∫øt ƒë·ªãnh c·ªßa T√≤a √°n c·∫•p s∆° th·∫©m; s·ª≠a quy·∫øt ƒë·ªãnh c·ªßa T√≤a √°n c·∫•p s∆° th·∫©m; h·ªßy quy·∫øt ƒë·ªãnh c·ªßa T√≤a √°n c·∫•p s∆° th·∫©m v√† chuy·ªÉn h·ªì s∆° v·ª• √°n cho T√≤a √°n c·∫•p s∆° th·∫©m ƒë·ªÉ ti·∫øp t·ª•c gi·∫£i quy·∫øt v·ª• √°n.\n</think>\nGi·ªØ nguy√™n quy·∫øt ƒë·ªãnh c·ªßa T√≤a √°n c·∫•p s∆° th·∫©m; s·ª≠a quy·∫øt ƒë·ªãnh c·ªßa T√≤a √°n c·∫•p s∆° th·∫©m; h·ªßy quy·∫øt ƒë·ªãnh c·ªßa T√≤a √°n c·∫•p s∆° th·∫©m v√† chuy·ªÉn h·ªì s∆° v·ª• √°n cho T√≤a √°n c·∫•p s∆° th·∫©m ƒë·ªÉ ti·∫øp t·ª•c gi·∫£i quy·∫øt v·ª• √°n."}],
    [{"content": "ƒê√°p √°n l√† 42."}],
]
scores = match_format_exactly(completions)
print(scores)  

In [None]:
# ƒê√∫ng ƒë·ªãnh d·∫°ng (<think> 1 l·∫ßn, </think> 1 l·∫ßn)
def match_format_approximately(completions, **kwargs):
    """Score each completion by how often the two reasoning tags occur.

    For each of the module-level ``reasoning_start`` and ``reasoning_end``
    strings, a completion earns ``+0.5`` when the tag occurs exactly once and
    ``-2.0`` for any other count (missing or repeated).

    Args:
        completions: Sequence of completions; each is a sequence whose first
            element is a dict carrying the generated text under ``"content"``.
        **kwargs: Extra arguments supplied by the trainer; ignored.

    Returns:
        list: One score per completion, between ``-4.0`` and ``1.0``.
    """
    scores = []
    for item in completions:
        text = item[0]["content"]
        # Both the opening and the closing tag are scored, in that order.
        scores.append(
            sum(
                0.5 if text.count(tag) == 1 else -2.0
                for tag in (reasoning_start, reasoning_end)
            )
        )
    return scores

completions = [
    # ‚úÖ 1. ƒê√∫ng ƒë·ªãnh d·∫°ng
    [{"content": "<think> reasoning here </think>\nThe answer is 42."}],
    # ‚ùå 2. Thi·∫øu <think>
    [{"content": "reasoning here </think>\nThe answer is 42."}],
    # ‚ùå 3. Thi·∫øu </think>
    [{"content": "<think> reasoning here \nThe answer is 42."}],
    # ‚ùå 4. Kh√¥ng c√≥ g√¨
    [{"content": "The answer is 42."}],
    # ‚ùå 5. Qu√° nhi·ªÅu <think>
    [{"content": "<think> step 1 </think>\n<think> step 2 </think>\nAnswer: 10"}],
    # ‚ùå 6. L·∫∑p </think>
    [{"content": "<think> a </think> </think>\nAnswer"}],
]
scores = match_format_approximately(completions)
for i, score in enumerate(scores):
    print(f"Test case {i+1} - Score: {score}")

In [None]:
def check_answer(prompts, completions, answer, **kwargs):
    """Reward completions whose text after the closing reasoning tag equals the reference label.

    Only closed-form questions are graded here: references outside the
    true/false and A-D label set always score ``0``.

    Args:
        prompts: Prompts supplied by the trainer. Unused; accepted only to
            match the reward-function signature.
        completions: Sequence of completions; each is a sequence whose first
            element is a dict carrying the generated text under ``"content"``.
        answer: Reference answers, aligned with ``completions``.
        **kwargs: Extra arguments supplied by the trainer; ignored.

    Returns:
        list: One score per (completion, answer) pair:
        ``-2.0`` when ``match_format`` finds no match in the completion,
        ``0`` when the reference is not a closed-form label,
        ``5`` for an exact match, ``3`` for a match after stripping
        surrounding whitespace, and ``-20`` for anything else.
    """
    closed_form_labels = ['ƒê√∫ng','Sai','A','B','C','D']

    # Text after the closing reasoning tag, or None when the tag is absent.
    extracted_responses = [
        found.group(1)
        if (found := match_format.search(completion[0]["content"])) is not None else None
        for completion in completions
    ]

    scores = []
    for guess, true_answer in zip(extracted_responses, answer):
        if guess is None:
            scores.append(-2.0)
            continue

        guess = str(guess)
        true_answer = str(true_answer)  # one reference per completion, not the whole list

        if true_answer not in closed_form_labels:
            score = 0
        elif guess == true_answer:
            score = 5
        elif guess.strip() == true_answer.strip():
            score = 3
        else:
            # The former numeric-ratio fallback could never succeed: the
            # reference is always a non-numeric label here, so float() always
            # raised and the bare except turned every such case into -20.
            score = -20
        scores.append(score)
    return scores

prompts = [[{"role": "user", "content": "T√≠nh 7 x 6"}]]

completions = [
    [{"content": "<think>\n\n1. Ph√¢n t√≠ch c√¢u h·ªèi: C√¢u h·ªèi ƒë·∫∑t ra v·∫•n ƒë·ªÅ v·ªÅ vi·ªác, n·∫øu ch·ªìng c√≥ ngu·ªìn thu nh·∫≠p nhi·ªÅu h∆°n v·ª£, th√¨ li·ªáu ch·ªìng c√≥ ƒë∆∞·ª£c ph√©p chi·∫øm h·ªØu, s·ª≠ d·ª•ng t√†i s·∫£n chung nhi·ªÅu h∆°n so v·ªõi v·ª£ hay kh√¥ng. ƒêi·ªÅu n√†y li√™n quan ƒë·∫øn nguy√™n t·∫Øc b√¨nh ƒë·∫≥ng gi·ªØa v·ª£ v√† ch·ªìng trong vi·ªác qu·∫£n l√Ω t√†i s·∫£n chung.\n\n2. D·∫´n ch·ª©ng t·ª´ b·ªëi c·∫£nh:\n   - 'V·ª£, ch·ªìng b√¨nh ƒë·∫≥ng v·ªõi nhau v·ªÅ quy·ªÅn, nghƒ©a v·ª• trong vi·ªác t·∫°o l·∫≠p, chi·∫øm h·ªØu, s·ª≠ d·ª•ng, ƒë·ªãnh ƒëo·∫°t t√†i s·∫£n chung; kh√¥ng ph√¢n bi·ªát gi·ªØa lao ƒë·ªông trong gia ƒë√¨nh v√† lao ƒë·ªông c√≥ thu nh·∫≠p.' (ƒêi·ªÉm 1)\n   - 'V·ª£, ch·ªìng c√≥ nghƒ©a v·ª• b·∫£o ƒë·∫£m ƒëi·ªÅu ki·ªán ƒë·ªÉ ƒë√°p ·ª©ng nhu c·∫ßu thi·∫øt y·∫øu c·ªßa gia ƒë√¨nh.' (ƒêi·ªÉm 2)\n   - 'Vi·ªác th·ª±c hi·ªán quy·ªÅn, nghƒ©a v·ª• v·ªÅ t√†i s·∫£n c·ªßa v·ª£ ch·ªìng m√† x√¢m ph·∫°m ƒë·∫øn quy·ªÅn, l·ª£i √≠ch h·ª£p ph√°p c·ªßa v·ª£, ch·ªìng, gia ƒë√¨nh v√† c·ªßa ng∆∞·ªùi kh√°c th√¨ ph·∫£i b·ªìi th∆∞·ªùng.' (ƒêi·ªÉm 3)\n\n3. Suy lu·∫≠n step-by-step:\n   a) CƒÉn c·ª© v√†o ƒëi·ªÉm 1 c·ªßa b·ªëi c·∫£nh, v·ª£ v√† ch·ªìng b√¨nh ƒë·∫≥ng v·ªÅ m·ªçi quy·ªÅn v√† nghƒ©a v·ª• trong t√†i s·∫£n chung, kh√¥ng ph√¢n bi·ªát thu nh·∫≠p hay lao ƒë·ªông c√≥ thu nh·∫≠p hay kh√¥ng.\n   b) Thu nh·∫≠p nhi·ªÅu h∆°n kh√¥ng t·∫°o ra quy·ªÅn ∆∞u ti√™n h∆°n trong vi·ªác chi·∫øm h·ªØu, s·ª≠ d·ª•ng t√†i s·∫£n chung.\n   c) ƒê·∫∑t bi·ªát, ƒëi·ªÉm 3 n√™u r√µ n·∫øu c√≥ h√†nh vi x√¢m ph·∫°m quy·ªÅn l·ª£i th√¨ ph·∫£i b·ªìi th∆∞·ªùng, ƒëi·ªÅu n√†y c≈©ng x√°c nh·∫≠n r·∫±ng kh√¥ng c√≥ s·ª± ∆∞u ti√™n n√†o ƒë∆∞·ª£c ph√©p x√¢m ph·∫°m quy·ªÅn l·ª£i c·ªßa v·ª£/ch·ªìng.\n\n4. K·∫øt lu·∫≠n: Sai. Quy·ªÅn chi·∫øm h·ªØu t√†i s·∫£n chung ph·∫£i b√¨nh ƒë·∫≥ng gi·ªØa v·ª£ v√† ch·ªìng, kh√¥ng ph·ª• thu·ªôc v√†o thu nh·∫≠p c√° nh√¢n.\n</think>\nSai"}],  
    [{"content": "<think> 7 x 6 = 42 </think>Sai"}],           # ƒê√öng k·∫øt qu·∫£
    [{"content": "<think> 7 x 6 = 42 </think>ƒê√∫ng"}],          # Sai k·∫øt qu·∫£
]

answer = ["Sai"] * len(completions)
scores = check_answer(prompts, completions, answer)
for i, score in enumerate(scores, 1):
    print(f"Test case {i}: Score = {score}")

In [None]:
import langid

def get_lang(text: str) -> str:
    """Return the language code that ``langid`` assigns to ``text``.

    Empty or otherwise falsy input is reported as ``"und"`` without calling
    the classifier.
    """
    if text:
        language_code, _ = langid.classify(text)
        return language_code
    return "und"


print(get_lang("Hello, How are you")) # This should return en
print(get_lang("Aku berpikir kalau aku adalah kamu")) # This should return id
print(get_lang("Xin ch√†o m·ªçi ng∆∞·ªùi")) # This should return vi

In [None]:
import re

# Reward completions written in Vietnamese and penalize other languages.
def format_and_language_reward_func(completions, **kwargs):
    """Score each completion by the language ``get_lang`` detects for its text.

    Args:
        completions: Sequence of completions; each is expected to be a
            non-empty sequence whose first element is a dict with a
            ``"content"`` key.
        **kwargs: Extra arguments supplied by the trainer; ignored.

    Returns:
        list: One score per completion: ``0.5`` for ``"vi"``, ``-7.0`` for
        ``"en"`` or ``"zh"``, ``-10.0`` for any other detected language, and
        ``-5.0`` (with a printed warning) for a malformed completion item.
    """
    language_scores = {"vi": 0.5, "en": -7.0, "zh": -7.0}
    scores = []

    for completion_item in completions:
        first_message = completion_item[0] if completion_item else None
        well_formed = isinstance(first_message, dict) and "content" in first_message

        if not well_formed:
            scores.append(-5.0)
            print(f"Warning: Malformed completion item, assigning default low score: {completion_item}")
            continue

        detected = get_lang(first_message["content"])
        # Any language without an explicit entry gets the harshest penalty.
        scores.append(language_scores.get(detected, -10.0))

    return scores

completions = [
    [{"role": "assistant", "content": "<think>The sum of 1 and 2 is 3, which we multiply by 4 to get 12.</think><answer>(1 + 2) * 4 = 12</answer>"}],
    [{"role": "assistant", "content": "<think>The sum of 1 and 2 is 3, which we multiply by 4 to get 12.</think>\nƒê√∫ng"}],
    [{"role": "assistant", "content": "<think>Ph√¢n t√≠ch</think>\nƒê√∫ng"}],
]
format_and_language_reward_func(prompts=prompts, completions=completions)

In [None]:
# tr·∫Øc nghi·ªám, ƒë√∫ng/ sai ph·∫£i ƒë√∫ng format
def format_answer(prompts, completions, answer, **kwargs):
    """Reward closed-form answers that consist of nothing but a valid label.

    This checks the answer format only, not correctness. When the reference is
    a true/false label, the text after the closing reasoning tag must itself
    be one of the true/false labels; when the reference is ``A``-``D``, it
    must be one of ``A``-``D``. Both are compared after stripping whitespace.

    Args:
        prompts: Prompts supplied by the trainer. Unused; accepted only to
            match the reward-function signature.
        completions: Sequence of completions; each is a sequence whose first
            element is a dict carrying the generated text under ``"content"``.
        answer: Reference answers, aligned with ``completions``.
        **kwargs: Extra arguments supplied by the trainer; ignored.

    Returns:
        list: One score per (completion, answer) pair:
        ``-2.0`` when ``match_format`` finds no match in the completion,
        ``2`` when the answer is a label from the same group as the reference,
        ``-20`` when it is not, and ``0`` when the reference is not a
        closed-form label at all.
    """
    # Each group lists the labels accepted for one closed-form question type.
    label_groups = (['ƒê√∫ng','Sai'], ['A','B', 'C', 'D'])

    # Text after the closing reasoning tag, or None when the tag is absent.
    extracted_responses = [
        found.group(1)
        if (found := match_format.search(completion[0]["content"])) is not None else None
        for completion in completions
    ]

    scores = []
    for guess, true_answer in zip(extracted_responses, answer):
        if guess is None:
            scores.append(-2.0)
            continue

        guess = str(guess).strip()
        true_answer = str(true_answer).strip()  # one reference per completion, not the whole list

        score = 0  # references outside every label group are not graded here
        for labels in label_groups:
            if true_answer in labels:
                score = 2 if guess in labels else -20
                break

        scores.append(score)
    return scores

prompts = [[{"role": "user", "content": "T√≠nh 7 x 6"}]]

completions = [
    [{"content": "<think>\n\n1. Ph√¢n t√≠ch c√¢u h·ªèi: C√¢u h·ªèi ƒë·∫∑t ra v·∫•n ƒë·ªÅ v·ªÅ vi·ªác, n·∫øu ch·ªìng c√≥ ngu·ªìn thu nh·∫≠p nhi·ªÅu h∆°n v·ª£, th√¨ li·ªáu ch·ªìng c√≥ ƒë∆∞·ª£c ph√©p chi·∫øm h·ªØu, s·ª≠ d·ª•ng t√†i s·∫£n chung nhi·ªÅu h∆°n so v·ªõi v·ª£ hay kh√¥ng. ƒêi·ªÅu n√†y li√™n quan ƒë·∫øn nguy√™n t·∫Øc b√¨nh ƒë·∫≥ng gi·ªØa v·ª£ v√† ch·ªìng trong vi·ªác qu·∫£n l√Ω t√†i s·∫£n chung.\n\n2. D·∫´n ch·ª©ng t·ª´ b·ªëi c·∫£nh:\n   - 'V·ª£, ch·ªìng b√¨nh ƒë·∫≥ng v·ªõi nhau v·ªÅ quy·ªÅn, nghƒ©a v·ª• trong vi·ªác t·∫°o l·∫≠p, chi·∫øm h·ªØu, s·ª≠ d·ª•ng, ƒë·ªãnh ƒëo·∫°t t√†i s·∫£n chung; kh√¥ng ph√¢n bi·ªát gi·ªØa lao ƒë·ªông trong gia ƒë√¨nh v√† lao ƒë·ªông c√≥ thu nh·∫≠p.' (ƒêi·ªÉm 1)\n   - 'V·ª£, ch·ªìng c√≥ nghƒ©a v·ª• b·∫£o ƒë·∫£m ƒëi·ªÅu ki·ªán ƒë·ªÉ ƒë√°p ·ª©ng nhu c·∫ßu thi·∫øt y·∫øu c·ªßa gia ƒë√¨nh.' (ƒêi·ªÉm 2)\n   - 'Vi·ªác th·ª±c hi·ªán quy·ªÅn, nghƒ©a v·ª• v·ªÅ t√†i s·∫£n c·ªßa v·ª£ ch·ªìng m√† x√¢m ph·∫°m ƒë·∫øn quy·ªÅn, l·ª£i √≠ch h·ª£p ph√°p c·ªßa v·ª£, ch·ªìng, gia ƒë√¨nh v√† c·ªßa ng∆∞·ªùi kh√°c th√¨ ph·∫£i b·ªìi th∆∞·ªùng.' (ƒêi·ªÉm 3)\n\n3. Suy lu·∫≠n step-by-step:\n   a) CƒÉn c·ª© v√†o ƒëi·ªÉm 1 c·ªßa b·ªëi c·∫£nh, v·ª£ v√† ch·ªìng b√¨nh ƒë·∫≥ng v·ªÅ m·ªçi quy·ªÅn v√† nghƒ©a v·ª• trong t√†i s·∫£n chung, kh√¥ng ph√¢n bi·ªát thu nh·∫≠p hay lao ƒë·ªông c√≥ thu nh·∫≠p hay kh√¥ng.\n   b) Thu nh·∫≠p nhi·ªÅu h∆°n kh√¥ng t·∫°o ra quy·ªÅn ∆∞u ti√™n h∆°n trong vi·ªác chi·∫øm h·ªØu, s·ª≠ d·ª•ng t√†i s·∫£n chung.\n   c) ƒê·∫∑t bi·ªát, ƒëi·ªÉm 3 n√™u r√µ n·∫øu c√≥ h√†nh vi x√¢m ph·∫°m quy·ªÅn l·ª£i th√¨ ph·∫£i b·ªìi th∆∞·ªùng, ƒëi·ªÅu n√†y c≈©ng x√°c nh·∫≠n r·∫±ng kh√¥ng c√≥ s·ª± ∆∞u ti√™n n√†o ƒë∆∞·ª£c ph√©p x√¢m ph·∫°m quy·ªÅn l·ª£i c·ªßa v·ª£/ch·ªìng.\n\n4. K·∫øt lu·∫≠n: Sai. Quy·ªÅn chi·∫øm h·ªØu t√†i s·∫£n chung ph·∫£i b√¨nh ƒë·∫≥ng gi·ªØa v·ª£ v√† ch·ªìng, kh√¥ng ph·ª• thu·ªôc v√†o thu nh·∫≠p c√° nh√¢n.\n</think>\nSai"}],           # ‚úÖ Ch√≠nh x√°c
    [{"content": "<think> 7 x 6 = 42 </think>K·∫øt lu·∫≠n: Sai"}],           # üî¢ Kh√¥ng g·∫ßn ƒë√∫ng
    [{"content": "<think> 7 x 6 = 42 </think>ƒê√∫ng"}],          # ‚ùå Kh√¥ng ph·∫£i s·ªë
]

answer = ["Sai"] * len(completions)
scores = format_answer(prompts, completions, answer)
for i, score in enumerate(scores, 1):
    print(f"Test case {i}: Score = {score}")

<a name="Train"></a>
## Train the model

Now set up GRPO Trainer and all configurations!

In [None]:
from datasets import Dataset
from collections import Counter

# T√≠nh ƒë·ªô d√†i prompt sau khi tokenize
def compute_prompt_length(example):
    """Add a ``prompt_length`` field holding the tokenized length of the prompt.

    The prompt is passed through the module-level ``tokenizer``'s chat
    template with ``tokenize=True`` and the length of the result is stored.

    Args:
        example: Mutable mapping with a ``"prompt"`` entry (a chat message list).

    Returns:
        The same ``example`` object, updated in place with ``"prompt_length"``.
    """
    example["prompt_length"] = len(
        tokenizer.apply_chat_template(example["prompt"], tokenize=True)
    )
    return example

# Th√™m c·ªôt prompt_length
dataset_with_length = dataset.map(compute_prompt_length)

# ƒê·∫øm lo·∫°i c√¢u h·ªèi
type_counts = Counter(dataset_with_length["question_type"])
print("T·ªïng s·ªë c√¢u h·ªèi theo lo·∫°i:", type_counts)

# S·ªë m·∫´u c·∫ßn l·∫•y
target_counts = {
    "T·ª± lu·∫≠n": 60,
    "Tr·∫Øc nghi·ªám": 160,
    "ƒê√∫ng/Sai": 60,
}

# L·∫•y sample ng·∫Øn nh·∫•t theo t·ª´ng lo·∫°i
selected_samples = []
for q_type, n_samples in target_counts.items():
    filtered = dataset_with_length.filter(lambda x: x["question_type"] == q_type)
    sorted_samples = filtered.sort("prompt_length")
    selected = sorted_samples.select(range(min(n_samples, len(sorted_samples))))
    selected_samples.extend(selected)

# T·∫°o l·∫°i Dataset
final_dataset = Dataset.from_list(selected_samples)

# In k·∫øt qu·∫£
final_counts = Counter(final_dataset["question_type"])
print("S·ªë l∆∞·ª£ng m·∫´u sau khi ch·ªçn:")
for k, v in final_counts.items():
    print(f"{k}: {v} m·∫´u")

print(f"T·ªïng c·ªông: {len(final_dataset)} m·∫´u")

In [None]:
# Tokenize every prompt with the chat template (generation prompt included) so
# the prompt lengths can be measured.
tokenized = final_dataset.map(
    lambda x: {"tokens" : tokenizer.apply_chat_template(x["prompt"], add_generation_prompt = True, tokenize = True)},
    batched = True,
)
# print(tokenizer.decode(tokenized[0]["tokens"]))
# "L" is the token count of each prompt.
tokenized = tokenized.map(lambda x: {"L" : len(x["tokens"])})

import numpy as np
# The quantile is 1, so maximum_length is simply the longest prompt.
maximum_length = int(np.quantile(tokenized["L"], 1))
print("Max Length = ", maximum_length)

# Keep samples whose prompt length is <= maximum_length. With the quantile
# above set to 1 this keeps every sample; lower the quantile (e.g. 0.9) to
# actually drop the longest prompts.
final_dataset = final_dataset.select(np.where(np.array(tokenized["L"]) <= maximum_length)[0])
print(final_dataset)
del tokenized

In [None]:
# Token budget: the prompt gets the longest measured prompt length (+1 for
# safety) and the completion gets whatever is left of max_seq_length.
max_prompt_length = maximum_length + 1 # + 1 just in case!
max_completion_length = max_seq_length - max_prompt_length
if max_completion_length <= 0:
    # Fail early with a clear message instead of starting a run that has no
    # room to generate anything.
    raise ValueError(
        f"No room left for completions: max_seq_length={max_seq_length}, "
        f"max_prompt_length={max_prompt_length}."
    )

from trl import GRPOConfig, GRPOTrainer
training_args = GRPOConfig(
    learning_rate = 5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    optim = "adamw_torch_fused",
    logging_steps = 1,
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 1, # Increase to 4 for smoother training
    num_generations = 4, # Decrease if out of memory
    max_prompt_length = max_prompt_length,
    max_completion_length = max_completion_length, # reuse the budget computed above
    # num_train_epochs = 1, # Set to 1 for a full training run
    max_steps = 100,
    save_steps = 10,
    max_grad_norm = 0.1,
    report_to = "none", # Can use Weights & Biases
    output_dir = "outputs",
)

# The rewards from every function in reward_funcs are combined by the trainer.
trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [
        match_format_exactly,
        match_format_approximately,
        check_answer,
        # check_numbers,
        format_and_language_reward_func,
        format_answer,
    ],
    args = training_args,
    train_dataset = final_dataset,
)
trainer.train()

## save model

In [21]:
import os

# SECURITY: never hardcode access tokens in a notebook. The token is read from
# the HF_TOKEN environment variable instead. Any token that was previously
# committed in this cell must be treated as leaked and revoked in the
# Hugging Face account settings.
hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
    raise RuntimeError(
        "Set the HF_TOKEN environment variable to a Hugging Face write token before pushing."
    )
username = 'lmq1909'
your_model_repo = 'Qwen3-8B-LQA-14e-GRPO-100s-save3-full'
# model.save_pretrained_merged("model", tokenizer, save_method = "merged_16bit")
# Merge the adapters into 16-bit weights and upload the result to the Hub.
model.push_to_hub_merged(f"{username}/{your_model_repo}", tokenizer, save_method = "merged_16bit", token = hf_token)

No files have been modified since last commit. Skipping to prevent empty commit.
No files have been modified since last commit. Skipping to prevent empty commit.


Found HuggingFace hub cache directory: /root/.cache/huggingface/hub
Checking cache directory for required files...
Successfully copied all 4 files from cache to /tmp/tmp6nhk8xjy.
Downloading safetensors index for lmq1909/Qwen3-8B-LQA-14e-full...


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

No files have been modified since last commit. Skipping to prevent empty commit.
Unsloth: Merging weights into 16bit:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.90G [00:00<?, ?B/s]

Unsloth: Merging weights into 16bit:  25%|‚ñà‚ñà‚ñå       | 1/4 [02:50<08:30, 170.19s/it]

model-00002-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

Unsloth: Merging weights into 16bit:  50%|‚ñà‚ñà‚ñà‚ñà‚ñà     | 2/4 [05:45<05:46, 173.23s/it]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

Unsloth: Merging weights into 16bit:  75%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå  | 3/4 [08:28<02:48, 168.50s/it]

model-00004-of-00004.safetensors:   0%|          | 0.00/1.58G [00:00<?, ?B/s]

Unsloth: Merging weights into 16bit: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [09:17<00:00, 139.48s/it]


<a name="Inference"></a>
## Inference
Now let's try the model we just trained!

In [8]:
import os
from unsloth import FastLanguageModel
import torch

# Find the highest-numbered checkpoint in a results directory
def get_latest_checkpoint(results_dir="./results"):
    """Return the path of the highest-numbered ``checkpoint-<step>`` entry.

    Args:
        results_dir: Directory whose entries are scanned.

    Returns:
        ``results_dir`` joined with the name of the entry that has the
        largest integer step.

    Raises:
        ValueError: If no entry of the form ``checkpoint-<int>`` is found.
    """
    prefix = "checkpoint-"
    candidates = [entry for entry in os.listdir(results_dir) if entry.startswith(prefix)]

    ranked = []
    for entry in candidates:
        # The step is the field right after the first dash.
        step_text = entry.split("-")[1]
        try:
            ranked.append((int(step_text), entry))
        except ValueError:
            # Not a numbered checkpoint (e.g. "checkpoint-best"); ignore it.
            pass

    if len(ranked) == 0:
        raise ValueError("‚ùå Kh√¥ng t√¨m th·∫•y checkpoint trong th∆∞ m·ª•c.")

    # Tuples compare by step first, so max() picks the numerically largest step.
    _, newest = max(ranked)
    return os.path.join(results_dir, newest)

# Pick the directory to search for checkpoints: prefer the Kaggle dataset
# mount and fall back to the local ./outputs directory when it is not present.
# (Previously "./outputs" was assigned and then immediately overwritten, so
# the notebook could only run with the Kaggle dataset attached.)
kaggle_results_dir = "/kaggle/input/alqac-qwen-grpo-save3-full/outputs"
local_results_dir = "./outputs"
results_dir = kaggle_results_dir if os.path.isdir(kaggle_results_dir) else local_results_dir

# Get the latest checkpoint and display its path as the cell output.
checkpoint_path = get_latest_checkpoint(results_dir)
checkpoint_path

ü¶• Unsloth: Will patch your computer to enable 2x faster free finetuning.


2025-07-22 06:57:34.230819: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1753167454.613568      36 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1753167454.710766      36 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


ü¶• Unsloth Zoo will now patch everything to make training faster!


'/kaggle/input/alqac-qwen-grpo-save3-full/outputs/checkpoint-100'

In [9]:
from unsloth import FastLanguageModel
import torch
# Loading settings reused by the inference cells below.
max_seq_length = 4096  # can be increased for longer reasoning traces
lora_rank = 8  # larger rank = smarter, but slower

# Load the checkpoint selected in `checkpoint_path` together with its tokenizer.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=checkpoint_path,
    max_seq_length=max_seq_length,
    load_in_4bit=True,  # use False for LoRA 16bit
    fast_inference=False,  # vLLM fast inference is turned off here
    max_lora_rank=lora_rank,
    gpu_memory_utilization=0.7,  # reduce if out of memory
)

Unrecognized keys in `rope_scaling` for 'rope_type'='yarn': {'attn_factor'}


==((====))==  Unsloth 2025.7.6: Fast Qwen3 patching. Transformers: 4.52.4.
   \\   /|    Tesla T4. Num GPUs = 2. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 7.5. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.29.post3. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Unrecognized keys in `rope_scaling` for 'rope_type'='yarn': {'attn_factor'}
Unrecognized keys in `rope_scaling` for 'rope_type'='yarn': {'attn_factor'}


model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/1.58G [00:00<?, ?B/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.90G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/171 [00:00<?, ?B/s]

Unsloth 2025.7.6 patched 36 layers with 36 QKV layers, 36 O layers and 36 MLP layers.


In [10]:
# Display the chat-style prompt (a list of role/content messages) of the last
# example in valid_dataset.
valid_dataset[-1]['prompt']

[{'content': 'B·∫°n l√† m·ªôt tr·ª£ l√Ω ph√°p l√Ω Ti·∫øng Vi·ªát, c√≥ nhi·ªám v·ª• tr·∫£ l·ªùi c√°c c√¢u h·ªèi ng·∫Øn g·ªçn, ch√≠nh x√°c, d·ª±a tr√™n n·ªôi dung ƒëi·ªÅu lu·∫≠t ƒë∆∞·ª£c cung c·∫•p. Ch·ªâ tr·∫£ l·ªùi b·∫±ng Ti·∫øng Vi·ªát.\n\nB·∫°n h√£y tr·∫£ l·ªùi theo ƒë·ªãnh d·∫°ng sau:\n<think>\n[Suy nghƒ©, ph√¢n t√≠ch c·ªßa b·∫°n]\n</think>\n[C√¢u tr·∫£ l·ªùi c·ªßa b·∫°n]\n',
  'role': 'system'},
 {'content': "D·ª±a v√†o b·ªëi c·∫£nh b√™n d∆∞·ªõi, h√£y ph√¢n t√≠ch k·ªπ tr∆∞·ªõc khi tr·∫£ l·ªùi c√¢u h·ªèi.\n\nLo·∫°i c√¢u h·ªèi: Tr·∫Øc nghi·ªám (format c·ªßa K·∫øt lu·∫≠n cu·ªëi c√πng sau khi suy lu·∫≠n l√† 1 trong 4 k·∫øt lu·∫≠n: 'A', 'B', 'C', 'D'. Kh√¥ng ƒë∆∞·ª£c gi·∫£i th√≠ch g√¨ th√™m.)\n\nC√¢u h·ªèi: Ch·ªìng kh√¥ng c√≥ quy·ªÅn y√™u c·∫ßu ly h√¥n \ntrong tr∆∞·ªùng h·ª£p n√†o ?\n\nB·ªëi c·∫£nh: \nQuy·ªÅn y√™u c·∫ßu gi·∫£i quy·∫øt ly h√¥n\n1. V·ª£, ch·ªìng ho·∫∑c c·∫£ hai ng∆∞·ªùi c√≥ quy·ªÅn y√™u c·∫ßu T√≤a √°n gi·∫£i quy·∫øt ly h√¥n.\n\n2. Cha, m·∫π, ng∆∞·ªùi th√¢n th√≠ch kh√°c c√

In [None]:
import time

# Time the whole cell: prompt construction, tokenization, generation, decoding.
start = time.time()

# Use the user turn of the last example in valid_dataset as the question.
prompt = [msg['content'] for msg in valid_dataset[-1]['prompt'] if msg['role'] == 'user'][0]
messages = [
    {'role': 'system',
     'content': 'B·∫°n l√† m·ªôt tr·ª£ l√Ω ph√°p l√Ω, c√≥ nhi·ªám v·ª• tr·∫£ l·ªùi c√°c c√¢u h·ªèi ng·∫Øn g·ªçn, ch√≠nh x√°c, d·ª±a tr√™n n·ªôi dung ƒëi·ªÅu lu·∫≠t ƒë∆∞·ª£c cung c·∫•p.'},
    {'role': 'user',
     'content': prompt},
]

text = tokenizer.apply_chat_template(
    messages,
    tokenize = False,
    add_generation_prompt = True, # Must add for generation
    enable_thinking = True, # not needed for this model
)

# Tokenize the fully templated text once and reuse it, so the measured input
# length includes the system prompt and template tokens, not just the raw
# user prompt.
inputs = tokenizer(text, return_tensors = "pt").to("cuda")
max_input_tokens = inputs["input_ids"].shape[1]

# Generate at most 1024 new tokens, and never more than what is left of
# max_seq_length after the input (at least 1 so generate() stays valid).
max_new_tokens = max(1, min(1024, max_seq_length - max_input_tokens))
print(max_new_tokens)

# To stream tokens while generating, import TextStreamer from transformers and
# pass `streamer = TextStreamer(tokenizer, skip_prompt = True)` to generate().
generated_output = model.generate(
    **inputs,
    max_new_tokens = max_new_tokens,
    temperature = 0.6, top_p = 0.95, top_k = 20, # For thinking
)

# Decode token IDs into text. `decoded_output` is the full sequence (prompt +
# completion); `output` is only the newly generated part. The prompt is removed
# by slicing off the input tokens rather than splitting on an assistant marker
# string, because skip_special_tokens strips template markers and the split
# would then return the whole text.
decoded_output = tokenizer.decode(generated_output[0], skip_special_tokens=True)
output = tokenizer.decode(generated_output[0][max_input_tokens:], skip_special_tokens=True)
print(output)

end = time.time()
print(end-start)

In [None]:
# Character offsets of the opening and closing think tags in the generated
# text; str.find returns -1 for a tag that is absent.
output.find('<think>'), output.find('</think>')

In [None]:
# Rough generation speed observed: about 10 tokens per second.