**To learn how to train reasoning models, see the [Reasoning blog](https://unsloth.ai/blog/r1-reasoning).**


### Installation

In [1]:
# ! pip install unsloth vllm
# # If you are using Gemma 3
# ! pip install --no-deps git+https://github.com/huggingface/transformers@v4.49.0-Gemma-3

In [2]:
import torch

max_seq_length = 1024
model_name = "Yeongi/gemma-3-4b-it-bnb-4bit-lora"

### Unsloth

In [3]:
from unsloth import FastLanguageModel

model, _ = FastLanguageModel.from_pretrained(
    model_name=model_name,
    max_seq_length=max_seq_length,
    load_in_4bit=False,
    dtype=torch.bfloat16,
)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
INFO 04-12 04:08:56 [__init__.py:239] Automatically detected platform cuda.
==((====))==  Unsloth 2025.3.19: Fast Gemma3 patching. Transformers: 4.51.0.dev0. vLLM: 0.8.2.
   \\   /|    NVIDIA GeForce RTX 4070 Ti SUPER. Num GPUs = 1. Max memory: 15.992 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 8.9. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.29.post2. FA2 = True]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA.


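### Annoying tokenizer bug
See https://github.com/unslothai/unsloth/issues/2214. As a workaround, the tokenizer is loaded directly with `AutoTokenizer` rather than taken from `FastLanguageModel.from_pretrained`.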

In [4]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

### LoRA

In [5]:
model = FastLanguageModel.get_peft_model(
    model,
    finetune_vision_layers=False,  # Turn off for just text!
    finetune_language_layers=True,  # Should leave on!
    finetune_attention_modules=True,  # Attention good for GRPO
    finetune_mlp_modules=True,  # Should leave on always!
    r=8,  # Larger = higher accuracy, but might overfit
    lora_alpha=8,  # Recommended alpha == r at least
    lora_dropout=0.05,
    bias="none",
    random_state=3407,
    use_gradient_checkpointing="unsloth",
)

Unsloth: Making `base_model.model.vision_tower.vision_model` require gradients


Some LoRA layers are in float16, and the forward pass raises an error unless they are converted to float32. Since Unsloth casts the model to float32, `model.to(torch.float32)` is required.

In [6]:
model = model.to(torch.float32)

### Data preparation
Uses the [GSM8K](https://huggingface.co/datasets/openai/gsm8k) dataset.

In [7]:
from datasets import load_dataset

dataset = load_dataset("openai/gsm8k", "main", split="train")

Every answer contains `####`, so the final answer is extracted by splitting on that marker.

In [8]:
def extract_hash_answer(text):
    if "####" not in text:
        return None
    return text.split("####")[1].strip()
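
As a quick sanity check, the extraction can be tried on a single string. The sample text below is made up for illustration and is not taken from the dataset; the function body is repeated only so that the snippet runs on its own.

```python
def extract_hash_answer(text):
    """Return the text after the '####' marker, or None if the marker is absent."""
    if "####" not in text:
        return None
    return text.split("####")[1].strip()


# Hypothetical GSM8K-style solution string, made up for illustration.
sample = "She sold 48 clips in April and 24 in May.\n48 + 24 = 72\n#### 72"

print(extract_hash_answer(sample))            # 72
print(extract_hash_answer("no marker here"))  # None
```

Note that the extracted answer is a string, not a number, so any later comparison has to handle parsing.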

Create a customizable system prompt. Four extra tags are added: two to delimit the working-out (thinking/reasoning) section and two to delimit the final answer.

In [None]:
reasoning_start = "<|begin_of_thought|>"
reasoning_end = "<|end_of_thought|>"
solution_start = "<|begin_of_solution|>"
solution_end = "<|end_of_solution|>"
boxed_answer = r"\\boxed{.*?}"  # regex pattern

system_prompt = f"""You are a math problem solver. Follow these steps strictly:

1. {reasoning_start} to {reasoning_end}: 
- Analyze given problem
- Explore multiple approaches
- Detail step-by-step reasoning

2. {solution_start} to {solution_end}:
- State final answer in `$\\boxed{{ANSWER}}$`
- Provide concise explanation

**Rules:**
• Use exact tags: {reasoning_start}, {reasoning_end}, {solution_start}, {solution_end}
• Never use markdown
"""

In [10]:
dataset = dataset.map(
    lambda x: {
        "prompt": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": x["question"]},
        ],
        "answer": extract_hash_answer(x["answer"]),
    }
)
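
To make the resulting row layout concrete, here is a minimal sketch that applies the same mapping to a single hand-written row without loading the dataset. The name `to_grpo_row`, the shortened `system_prompt`, and the sample row are assumptions made for this illustration.

```python
def extract_hash_answer(text):
    if "####" not in text:
        return None
    return text.split("####")[1].strip()


system_prompt = "You are a math problem solver."  # shortened placeholder


def to_grpo_row(x):
    """Hypothetical helper: same mapping as the lambda passed to dataset.map."""
    return {
        "prompt": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": x["question"]},
        ],
        "answer": extract_hash_answer(x["answer"]),
    }


# Hypothetical row, made up for illustration.
row = {"question": "What is 2 + 3?", "answer": "2 + 3 = 5\n#### 5"}
mapped = to_grpo_row(row)

print(mapped["prompt"][-1]["content"])  # What is 2 + 3?
print(mapped["answer"])                 # 5
```

The user question is always the last message of `prompt`, which is how `check_answer` later recovers it through `prompts[0][-1]["content"]`.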

### Format reward function

In [11]:
import re
from ahocorasick import Automaton


def match_format_exactly(completions, **kwargs):
    """
    순차적 검사를 통과할 때마다 점수를 추가하는 보상 함수
    - 필수 태그 존재 여부 검사 통과: 1.5 점
    - 중복 태그 검사 통과: 0.75 점 (누적 2.25)
    - 태그 순서 검사 통과: 0.75 점 (총 3.0)
    """

    scores = []
    for completion in completions:
        score = 0.0
        response = completion[0]["content"]
        required_tags = [
            "<|begin_of_thought|>",
            "<|end_of_thought|>",
            "<|begin_of_solution|>",
            r"\\boxed{.*?}",
            "<|end_of_solution|>",
        ]

        # 1. Check that all required tags are present
        all_tags_present = True
        for token in required_tags:
            if token == r"\\boxed{.*?}":
                # Regex patterns are checked with re.search
                if not re.search(token, response):
                    all_tags_present = False
                    break
            elif token not in response:
                # Plain strings are checked with the `in` operator
                all_tags_present = False
                break

        if not all_tags_present:
            scores.append(score)
            continue  # Failed the first stage: the score stays at 0

        score += 1.5  # Passed the required-tag check

        # 2. Check for duplicate tags
        no_duplicates = True
        for tag in required_tags:
            if tag == r"\\boxed{.*?}":
                # Regex patterns are counted with re.findall
                matches = re.findall(tag, response)
                if len(matches) > 1:
                    no_duplicates = False
                    break
            else:
                # Plain strings are counted with str.count
                if response.count(tag) > 1:
                    no_duplicates = False
                    break

        if not no_duplicates:
            # Duplicate check failed: keep the score earned so far
            scores.append(score)
            continue

        score += 0.75  # Passed the duplicate check

        # 3. Check the tag order
        found_matches = []  # list of (original_tag, start_index, end_index) tuples

        # Prepare an Aho-Corasick Automaton (for the fixed strings)
        automaton = Automaton()
        regex_patterns = []  # list of (original_regex_pattern, compiled_regex) tuples
        regex_ok = True

        # Split the tags into fixed strings and regex patterns
        for tag in required_tags:
            # Only \\boxed{.*?} is treated as a regex here (extend the logic if needed)
            if tag == r"\\boxed{.*?}":
                try:
                    compiled_re = re.compile(tag)
                    regex_patterns.append((tag, compiled_re))
                except re.error as e:
                    print(f"Regex compile error: {tag} - {e}")
                    regex_ok = False
                    break
            else:
                # Add the fixed string to the Automaton (the tag is both key and value)
                automaton.add_word(tag, tag)

        if not regex_ok:
            # Without a valid regex the order check cannot run, so keep the current
            # score and move on to the next completion
            scores.append(score)
            continue

        automaton.make_automaton()

        # 1) Find the fixed strings (using Aho-Corasick)
        for end_index, original_tag in automaton.iter(response):
            # Aho-Corasick returns the matched word (original_tag), so use it directly
            start_index = end_index - len(original_tag) + 1
            found_matches.append((original_tag, start_index, end_index))

        # 2) Find the regex patterns
        for original_pattern, compiled_re in regex_patterns:
            for match in compiled_re.finditer(response):
                # Store the original regex pattern together with the match position
                found_matches.append((original_pattern, match.start(), match.end() - 1))

        # 3) Sort every tag found by its start position (start_index)
        found_matches.sort(key=lambda x: x[1])

        # 4) Extract only the original tags (patterns), in sorted order
        found_sequence = [match[0] for match in found_matches]

        # 5) Compare the extracted order with required_tags
        if found_sequence == required_tags:
            score += 0.75  # Passed the order check

        scores.append(score)

    # Return the final rewards (3.0 if the order check passed, 2.25 if it failed)
    return scores
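
The order check on its own can be illustrated without Aho-Corasick. The sketch below swaps in plain `str.find` and `re.search` and only looks at the first occurrence of each tag, so it is a simplification of the third stage above rather than a drop-in replacement. `tags_in_order` and both sample responses are hypothetical.

```python
import re

TAGS = [
    "<|begin_of_thought|>",
    "<|end_of_thought|>",
    "<|begin_of_solution|>",
    r"\\boxed{.*?}",  # regex; the other entries are literal strings
    "<|end_of_solution|>",
]


def tags_in_order(response):
    """Return True if every tag occurs and the first occurrences follow TAGS order."""
    positions = []
    for tag in TAGS:
        if tag.startswith(r"\\boxed"):
            m = re.search(tag, response)
            pos = m.start() if m else -1
        else:
            pos = response.find(tag)
        if pos < 0:
            return False  # a required tag is missing
        positions.append(pos)
    return positions == sorted(positions)


good = (
    "<|begin_of_thought|>2 + 3 = 5<|end_of_thought|>"
    "<|begin_of_solution|>$\\boxed{5}$<|end_of_solution|>"
)
swapped = (
    "<|begin_of_solution|>$\\boxed{5}$<|end_of_solution|>"
    "<|begin_of_thought|>2 + 3 = 5<|end_of_thought|>"
)

print(tags_in_order(good))     # True
print(tags_in_order(swapped))  # False
```

With only five tags per response, the simpler search is likely fast enough. Aho-Corasick pays off mainly when many fixed strings have to be located in a single pass.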

### Partial format reward function
If the exact format check fails, this function still rewards the model for partially following the format.

In [12]:
def match_format_approximately(completions, **kwargs):
    scores = []
    for completion in completions:
        score = 0
        response = completion[0]["content"]

        # Use re.findall() to find every match of the boxed-answer pattern.
        boxed_matches = re.findall(boxed_answer, response)

        # Check the count of each fixed-string tag
        score += 0.2 if response.count(reasoning_start) == 1 else -0.2
        score += 0.2 if response.count(reasoning_end) == 1 else -0.2
        score += 0.2 if response.count(solution_start) == 1 else -0.2
        score += 0.2 if response.count(solution_end) == 1 else -0.2
        score += 0.2 if len(boxed_matches) == 1 else -0.2
        scores.append(score)
    return scores
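
Each of the five checks contributes +0.2 or -0.2, so the score ranges from -1.0 to 1.0. A compact sketch of the same arithmetic on single strings is shown below. `approx_score` and the sample responses are hypothetical, and results are rounded only to hide floating-point noise.

```python
import re

TAGS = [
    "<|begin_of_thought|>",
    "<|end_of_thought|>",
    "<|begin_of_solution|>",
    "<|end_of_solution|>",
]
BOXED = r"\\boxed{.*?}"


def approx_score(response):
    """+0.2 for each element that appears exactly once, -0.2 otherwise."""
    counts = [response.count(t) for t in TAGS]
    counts.append(len(re.findall(BOXED, response)))
    return sum(0.2 if c == 1 else -0.2 for c in counts)


full = (
    "<|begin_of_thought|>2 + 3 = 5<|end_of_thought|>"
    "<|begin_of_solution|>$\\boxed{5}$<|end_of_solution|>"
)
unclosed = full.replace("<|end_of_solution|>", "")  # one tag missing

print(round(approx_score(full), 2))      # 1.0
print(round(approx_score(unclosed), 2))  # 0.6
print(round(approx_score(""), 2))        # -1.0
```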

### Answer reward function

In [13]:
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify


def check_answer(prompts, completions, answer, **kwargs):
    """정확도 기반 보상 함수로, 모델의 답변이 정답과 일치하는지 확인합니다.

    수학 문제에 특화되어 있으며, LaTeX 형식으로 된 답변을 파싱해 정답과 비교합니다.
    정답과 일치하면 3.0, 그렇지 않으면 -1.5의 보상을 반환합니다.

    Args:
        completions: 모델이 생성한 답변 목록
        answer: 정답 목록

    Returns:
        float: 각 답변에 대한 정확도 보상 점수 목록
    """
    question = prompts[0][-1]["content"]
    contents = [completion[0]["content"] for completion in completions]
    print(
        "*" * 20,
        f"\nQuestion:\n{question}",
        f"\nAnswer:\n{answer[0]}",
        f"\nResponse:\n{contents[0]}",
    )
    scores = []
    for content, ans in zip(contents, answer):
        if ans:  # also guards against None from extract_hash_answer
            # The answer must be given in valid LaTeX (no malformed operators)
            answer_parsed = parse(
                content,
                extraction_config=[
                    LatexExtractionConfig(
                        normalization_config=NormalizationConfig(
                            nits=False,
                            malformed_operators=False,
                            basic_latex=True,
                            equations=False,
                            boxed="all",
                            units=True,
                        ),
                        # Make sure boxed is tried first
                        boxed_match_priority=1,
                        try_extract_without_anchor=True,
                    )
                ],
                extraction_mode="first_match",
            )
            # Reward 3.0 if the parsed content matches the gold answer, otherwise -1.5
            try:
                is_correct = verify(answer_parsed, ans)
                score = 3.0 if is_correct else -1.5
            except Exception as e:
                print(f"Verification failed: {e}, answer: {answer_parsed}, gold: {ans}")
                score = 0.0
            scores.append(score)  # appended exactly once in every case
        else:
            score = 0.0
            print(f"Failed to parse gold solution: {ans}")
            scores.append(score)
    return scores

In [14]:
import os
from unsloth import is_bfloat16_supported

output_dir = os.path.join("outputs", model_name)
max_prompt_length = 256

from trl import GRPOConfig, GRPOTrainer

trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    train_dataset=dataset,
    reward_funcs=[
        match_format_exactly,
        match_format_approximately,
        check_answer,
    ],
    args=GRPOConfig(
        bf16=is_bfloat16_supported(),
        learning_rate=5e-6,
        adam_beta1=0.9,
        adam_beta2=0.99,
        weight_decay=0.1,
        warmup_ratio=0.1,
        lr_scheduler_type="constant_with_warmup",
        optim="adamw_torch_fused",
        logging_steps=1,
        per_device_train_batch_size=4,  # Reduce if memory usage is too high.
        gradient_accumulation_steps=4,
        # gradient_checkpointing=True,
        max_grad_norm=0.2,
        num_generations=4,  # Important! Number of completions generated per prompt. Reduce if memory usage is too high.
        max_prompt_length=max_prompt_length,
        max_completion_length=max_seq_length - max_prompt_length,
        num_train_epochs=1,  # Set to 1 or more to run a full training pass.
        # max_steps=100,  # Use only for testing
        save_steps=5,
        save_total_limit=20,
        report_to="wandb",  # Weights & Biases; use "none" if you are not using it
        run_name="gemma3_4b_lora_grpo",  # Only used with W&B
        output_dir=output_dir,  # Checkpoint save path
        # seed=42,
    ),
)

In [None]:
trainer_stat = trainer.train(
    resume_from_checkpoint="./outputs/Yeongi/gemma-3-4b-it-bnb-4bit-lora/checkpoint-565"
)

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 7,473 | Num Epochs = 1 | Total steps = 1,868
O^O/ \_/ \    Batch size per device = 4 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (4 x 4 x 1) = 16
 "-____-"     Trainable parameters = 16,394,240/4,000,000,000 (0.41% trained)
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mhiyo2044[0m ([33mhiyo2044-[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


`generation_config` default values have been modified to match model-specific defaults: {'top_k': 64, 'top_p': 0.95, 'bos_token_id': 2, 'eos_token_id': [1, 106]}. If this is not desired, please set these values explicitly.
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
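
The numbers in the training banner can be reproduced with a little arithmetic. The sketch below assumes that each optimizer step consumes `total batch size / num_generations` unique prompts and that the step count rounds down. Both assumptions are inferred from the logged values, not taken from the trainer's source.

```python
num_examples = 7_473
per_device_batch = 4
grad_accum = 4
num_gpus = 1
num_generations = 4

total_batch = per_device_batch * grad_accum * num_gpus  # completions per step
prompts_per_step = total_batch // num_generations       # unique prompts per step
total_steps = num_examples // prompts_per_step          # assumption: rounds down

print(total_batch, prompts_per_step, total_steps)       # 16 4 1868

trainable, total = 16_394_240, 4_000_000_000
print(f"{100 * trainable / total:.2f}%")                # 0.41%
```

Under this reading, lowering `num_generations` reduces the number of steps per epoch, because each step then covers more unique prompts.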


******************** 
Question:
Terry's mom brought home 4 different colored dogs from the shelter. The brown dog weighs 4 pounds. The black dog weighs 1 pound more than the brown dog. The white dog weighs twice as much as the brown dog. The grey dog weighs 2 pounds less than the black dog. What's the average weight of all the dogs? 
Answer:
5 
Response:
<|begin_of_thought|>
Let's analyze the given information. We have four dogs with different colors and weights. We need to find the average weight of all the dogs.

First, we are given the weight of the brown dog: 4 pounds.
The black dog weighs 1 pound more than the brown dog, so its weight is 4 + 1 = 5 pounds.
The white dog weighs twice as much as the brown dog, so its weight is 2 * 4 = 8 pounds.
The grey dog weighs 2 pounds less than the black dog, so its weight is 5 - 2 = 3 pounds.

Now we have the weights of all four dogs:
Brown dog: 4 pounds
Black dog: 5 pounds
White dog: 8 pounds
Grey dog: 3 pounds

To find the average weight, we 

| Step | Training Loss | reward | reward_std | completion_length | kl | rewards / match_format_exactly | rewards / match_format_approximately | rewards / check_answer |
|---|---|---|---|---|---|---|---|---|
| 566 | 0.0013 | 6.43125 | 1.1375 | 434.5625 | 0.03324 | 2.8125 | 0.9 | 2.71875 |
| 567 | 0.0011 | 5.015625 | 0.643561 | 502.625 | 0.027445 | 2.390625 | 0.75 | 1.875 |
| 568 | 0.0017 | 6.4375 | 0.649519 | 380.9375 | 0.041666 | 3.0 | 1.0 | 2.4375 |
| 569 | 0.0008 | 4.775 | 0.1 | 448.8125 | 0.019785 | 2.25 | 0.65 | 1.875 |
| 570 | 0.0012 | 7.0 | 0.0 | 335.9375 | 0.030165 | 3.0 | 1.0 | 3.0 |
| 571 | 0.0006 | 2.85625 | 0.8125 | 634.9375 | 0.014914 | 1.5 | 0.325 | 1.03125 |
| 572 | 0.0011 | 6.43125 | 1.1375 | 373.4375 | 0.02751 | 2.8125 | 0.9 | 2.71875 |
| 573 | 0.0013 | 6.4375 | 0.649519 | 367.3125 | 0.031675 | 3.0 | 1.0 | 2.4375 |
| 574 | 0.0015 | 4.453125 | 1.1286 | 505.5 | 0.037092 | 2.390625 | 0.75 | 1.3125 |
| 575 | 0.0011 | 6.21875 | 1.5625 | 423.5 | 0.028007 | 2.625 | 0.875 | 2.71875 |
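
The logged values are consistent with the `reward` column being the plain sum of the three reward-function columns, which puts the maximum at 3.0 + 1.0 + 3.0 = 7.0 (reached at step 570). The sketch below checks this on a few of the logged rows. The column names are shortened here for readability, and `component_sum` is a hypothetical helper.

```python
import csv
import io

# A few rows copied from the training log; column names are shortened.
log = """step,reward,exact,approx,answer
566,6.43125,2.8125,0.9,2.71875
569,4.775,2.25,0.65,1.875
570,7.0,3.0,1.0,3.0
571,2.85625,1.5,0.325,1.03125
"""


def component_sum(row):
    """Sum the three per-function reward columns of one log row."""
    return float(row["exact"]) + float(row["approx"]) + float(row["answer"])


rows = list(csv.DictReader(io.StringIO(log)))
mismatches = [
    r["step"] for r in rows if abs(component_sum(r) - float(r["reward"])) > 1e-6
]

print(mismatches)  # []
```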


******************** 
Question:
Tom's cat needs an expensive surgery.  He has had pet insurance for 24 months that cost $20 per month.  The procedure cost $5000 but the insurance covers all but 20% of this.  How much money did he save by having insurance? 
Answer:
3520 
Response:
<|begin_of_thought|>
The problem describes a scenario where Tom's cat needs surgery, and he has pet insurance. The insurance covers all but 20% of the surgery cost. We need to determine how much money Tom saved by having insurance.

First, let's calculate the total cost of the insurance over the 24 months.
Monthly cost = $20
Number of months = 24
Total insurance cost = $20 * 24 = $480

Next, let's calculate the amount covered by the insurance.
Surgery cost = $5000
Insurance covers 100% - 20% = 80% of the cost.
Insurance coverage = 80% of $5000 = 0.80 * $5000 = $4000

Now, we need to find out how much Tom has to pay for the surgery.
Tom's cost = Surgery cost - Insurance coverage
Tom's cost = $5000 - $4000 = $10