# Installation

In [None]:
! pip install unsloth
! pip install unsloth vllm

# Load model

In [None]:
from unsloth import FastLanguageModel

# Total token budget per sequence (prompt + completion). It can be raised,
# at the cost of more memory.
max_seq_length = 4500
# LoRA rank. Also handed to the loader as the maximum rank for fast inference.
# Any value > 0 works; suggested 8, 16, 32, 64, 128.
lora_rank = 32

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="Qwen/Qwen3-4B-Base",
    max_seq_length=max_seq_length,
    fast_inference=True,  # Enable vLLM fast inference
    max_lora_rank=lora_rank,
    gpu_memory_utilization=0.85,
)

# Projection layers that receive LoRA adapters.
# Remove QKVO (the first four) if out of memory.
lora_target_modules = [
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj",
]

model = FastLanguageModel.get_peft_model(
    model,
    r=lora_rank,
    target_modules=lora_target_modules,
    lora_alpha=lora_rank,
    use_gradient_checkpointing="unsloth",  # Enable long context finetuning
    random_state=3407,
)

# Pre Train
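Pre-train (SFT) the model so that it learns to follow the interactive commit format: a `<DIFF>…</DIFF>` user turn answered by a `<COMMIT>…</COMMIT>` assistant turn.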

In [None]:
import datasets

# System prompt used for every SFT example (and as the template's fallback
# system message when a conversation has none).
sys_prompt="""You are a professional code assistant with expertise in code analysis and version control best practices.
1. Your task is to analyze code changes between <DIFF></DIFF>, summarize their theme and core content.
2. Generate a structured commit message between <COMMIT></COMMIT>. Your message must strictly follow the conventional commits specification.
"""

# Load the on-disk dataset, keep only the two columns the SFT stage needs,
# and restrict the warm-up to the first 50 rows.
ds = (
    datasets.load_from_disk("/home/circle/code/AImmit_ai/dataset_generation")
    .to_pandas()[["diff", "commit_message"]]
    .head(50)
)

def format_dataset(x):
    """Turn one dataset row into a system / user / assistant conversation.

    Args:
        x: Row-like mapping providing ``"diff"`` and ``"commit_message"``.

    Returns:
        A list of three chat messages: the module-level ``sys_prompt`` as the
        system turn, the diff wrapped in ``<DIFF>`` tags as the user turn, and
        the commit message wrapped in ``<COMMIT>`` tags as the assistant turn.
    """
    user_turn = "<DIFF>{}</DIFF>".format(x["diff"])
    assistant_turn = "<COMMIT>{}</COMMIT>".format(x["commit_message"])

    conversation = []
    for role, content in (
        ("system", sys_prompt),
        ("user", user_turn),
        ("assistant", assistant_turn),
    ):
        conversation.append({"role": role, "content": content})
    return conversation

# Row-wise apply: each row becomes one three-message conversation.
ds["prompt"] = ds.apply(format_dataset, axis="columns")

In [None]:
# Minimal Jinja chat template:
#   * the system message (or the default system prompt when the conversation
#     has none) is emitted first, followed by EOS;
#   * user content is emitted as-is;
#   * assistant content is followed by EOS;
#   * the generation prompt is the opening "<COMMIT>" tag.
chat_template = (
    "{% if messages[0]['role'] == 'system' %}"
        "{{ messages[0]['content'] + eos_token }}"
        "{% set loop_messages = messages[1:] %}"
    "{% else %}"
        "{{ '{system_prompt}' + eos_token }}"
        "{% set loop_messages = messages %}"
    "{% endif %}"
    "{% for message in loop_messages %}"
        "{% if message['role'] == 'user' %}"
            "{{ message['content'] }}"
        "{% elif message['role'] == 'assistant' %}"
            "{{ message['content'] + eos_token }}"
        "{% endif %}"
    "{% endfor %}"
    "{% if add_generation_prompt %}{{ '<COMMIT>' }}"
    "{% endif %}"
)

# Substitute the placeholder with the actual default system prompt.
default_system_literal = f"'{sys_prompt}'"
chat_template = chat_template.replace("'{system_prompt}'", default_system_literal)
tokenizer.chat_template = chat_template

# Preview the first conversation rendered through the new template.
tokenizer.apply_chat_template(ds["prompt"][0], tokenize = False)

In [None]:
from datasets import Dataset

# Render every conversation into one training string using the chat template,
# then convert the frame into a `datasets.Dataset` for the trainer.
conversations = ds["prompt"].tolist()
ds["text"] = tokenizer.apply_chat_template(conversations, tokenize=False)
dataset = Dataset.from_pandas(ds)
dataset

In [None]:
from trl import SFTTrainer, SFTConfig
# Short supervised warm-up: per-device batch of 1 with no gradient
# accumulation, for two passes over the data.
sft_args = SFTConfig(
    dataset_text_field="text",       # column produced by the chat template
    seed=3407,
    # batching
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,   # raise to mimic a larger batch size
    num_train_epochs=2,
    # optimisation
    optim="adamw_8bit",
    learning_rate=2e-4,              # reduce to 2e-5 for long training runs
    weight_decay=0.01,
    warmup_steps=5,
    lr_scheduler_type="linear",
    # logging
    logging_steps=5,
    report_to="none",                # set to e.g. "wandb" to enable tracking
)

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    args=sft_args,
)
trainer.train()

In [None]:
from transformers import TextStreamer

# Build a prompt from the system + user turns only; the template then appends
# the generation prompt so the model continues with the commit message.
text = tokenizer.apply_chat_template(
    dataset[0]["prompt"][:2],
    add_generation_prompt=True,  # Must add for generation
    tokenize=False,
)

# Stream the generation (prompt included) to the cell output.
model_inputs = tokenizer(text, return_tensors="pt").to("cuda")
streamer = TextStreamer(tokenizer, skip_prompt=False)
_ = model.generate(
    **model_inputs,
    temperature=0,
    max_new_tokens=1024,
    streamer=streamer,
)

In [None]:
import gc
import torch
# Drop the SFT-stage data before the GRPO stage.
del dataset
del ds
# Collect garbage first: objects that are only unreachable (e.g. caught in
# reference cycles) still occupy their memory until the collector runs, and
# empty_cache() can only release cached CUDA blocks that are no longer in use.
gc.collect()
torch.cuda.empty_cache()

# Train

In [None]:
import datasets

# Extra instructions used for the GRPO stage. They are appended only when not
# already present, so re-running this cell does not keep growing the prompt.
grpo_extra_rules = """
3. Ensure the commit message is concise yet informative.
4. Focus on the most significant changes if the diff is extensive."""
if not sys_prompt.endswith(grpo_extra_rules):
    sys_prompt = f"{sys_prompt}{grpo_extra_rules}"

# Alternative source: ds = datasets.load_dataset("circle33/conventional_commits")
ds = datasets.load_from_disk("/home/circle/code/AImmit_ai/dataset_generation")

# Build prompt-only conversations (system + user). The diff is wrapped in
# <DIFF> tags, matching both the system prompt's instructions and the format
# the model saw during the SFT stage; the reference commit message is kept
# for the similarity reward.
ds = ds.map(
    lambda x: {
        "prompt": [
            {"role": "system", "content": sys_prompt},
            {"role": "user", "content": f"<DIFF>{x['diff']}</DIFF>"},
        ],
        "commit_message": x["commit_message"],
    }
)

ds[0]

# Reward functions

1. Format reward: ensure the output is in the correct format. (10 points)
2. Similarity reward: compare the output with the known (reference) commit message; the higher the similarity, the higher the reward. (90 points)

In [None]:
import re

# Closing tag, optionally followed by whitespace and the EOS token.
commit_end_regex = r"</COMMIT>[\s]{0,}" + \
    "(?:" + re.escape(tokenizer.eos_token) + ")?"

# The chat template's generation prompt already emits "<COMMIT>", so a sampled
# completion is expected to start with the message itself and contain only the
# closing tag. The opening tag is therefore optional here; requiring it meant
# a well-formed completion could never match. Completions that do repeat the
# opening tag still match as before.
match_format = re.compile(
    rf"(?:<COMMIT>)?(.+?){commit_end_regex}"
    rf"[\s]{{0,}}$",
    flags = re.MULTILINE | re.DOTALL
)

def format_exactaly_reward(prompts, completions, **kwargs):
    """Reward completions whose text matches the ``match_format`` pattern.

    Args:
        prompts: Unused; accepted for the reward-function signature.
        completions: One entry per sample; the text is read from
            ``completion[0]["content"]``.
        **kwargs: Ignored.

    Returns:
        A list with ``3.0`` for every completion that matches the pattern and
        ``0`` for every one that does not.
    """
    rewards = []
    for completion in completions:
        text = completion[0]["content"]
        if match_format.search(text) is None:
            rewards.append(0)
        else:
            rewards.append(3.0)
    return rewards

def format_approximately_reward(prompts, completions, **kwargs):
    """Give partial credit based on how often the closing tag appears.

    Args:
        prompts: Unused; accepted for the reward-function signature.
        completions: One entry per sample; the text is read from
            ``completion[0]["content"]``.
        **kwargs: Ignored.

    Returns:
        A list with ``0.5`` for each completion containing exactly one
        ``</COMMIT>`` and ``-1.0`` for any other count (zero or several).
    """
    closing_tag_counts = (
        completion[0]["content"].count("</COMMIT>") for completion in completions
    )
    return [0.5 if count == 1 else -1.0 for count in closing_tag_counts]

In [None]:
from sentence_transformers.cross_encoder import CrossEncoder

# Cross-encoder used by cross_compare() to score the similarity between a
# generated commit message and the reference one.
CROSS_MODEL_NAME = "cross-encoder/stsb-distilroberta-base"
CROSS_MODEL = CrossEncoder(CROSS_MODEL_NAME)

def cross_compare(model_commit, input_commit):
    """Map the cross-encoder similarity of two commit messages to a tiered reward.

    Args:
        model_commit: Commit message produced by the model.
        input_commit: Reference commit message to compare against.

    Returns:
        ``2.0`` when the similarity score is above 0.9, ``1.5`` above 0.7,
        ``1.0`` above 0.5, and ``0.0`` otherwise. If scoring raises, the
        error is printed and ``0.0`` is returned.
    """
    # (exclusive lower bound on the similarity score, reward), strictest first.
    reward_tiers = ((0.9, 2.0), (0.7, 1.5), (0.5, 1.0))
    try:
        similarity = CROSS_MODEL.predict([model_commit, input_commit])
        for lower_bound, reward in reward_tiers:
            if similarity > lower_bound:
                return reward
        return 0.0
    except Exception as e:
        # Best effort: a scoring failure must not abort the training step.
        print(f"Error during semantic similarity calculation: {e}")
        return 0.0

# Captures the text between "<COMMIT>" and the next "</COMMIT>".
# DOTALL lets the captured message span several lines.
match_commit = re.compile(r"<COMMIT>(.*?)</COMMIT>", flags=re.DOTALL | re.MULTILINE)

# Debug-print throttle for score_reward(): PRINTED_TIMES counts its calls and
# a sample is printed whenever the count is a multiple of PRINT_EVERY_STEPS.
PRINTED_TIMES = 0
PRINT_EVERY_STEPS = 5

def _extract_commit(response):
    """Return the commit message found in ``response``, or None if there is none.

    A fully tagged ``<COMMIT>...</COMMIT>`` span is preferred. Because the chat
    template's generation prompt already emits the opening ``<COMMIT>`` tag, a
    completion usually carries only the closing tag; in that case everything
    before the first ``</COMMIT>`` is taken as the message.
    """
    tagged = match_commit.search(response)
    if tagged is not None:
        return tagged.group(1)
    body, closing_tag, _ = response.partition("</COMMIT>")
    return body if closing_tag else None


def score_reward(prompts, completions, commit_message, diff, **kwargs):
    """Reward each completion by its similarity to the reference commit message.

    Args:
        prompts: Conversations; only ``prompts[0][-1]["content"]`` is read,
            for the periodic debug print.
        completions: One entry per sample; the text is read from
            ``completion[0]["content"]``.
        commit_message: Reference commit messages, aligned with ``completions``.
        diff: Unused; accepted as part of the signature
            (presumably a dataset column forwarded by the trainer).
        **kwargs: Ignored.

    Returns:
        One score per completion: ``-2.5`` when no commit message can be
        extracted, otherwise the value returned by ``cross_compare`` (``0.0``
        if that call raises).
    """
    global PRINTED_TIMES

    responses = [completion[0]["content"] for completion in completions]

    # Print one sample every PRINT_EVERY_STEPS calls to monitor training.
    if PRINTED_TIMES % PRINT_EVERY_STEPS == 0:
        diff_input = prompts[0][-1]["content"]
        print(
            '='*50 + f"DIFF_INPUT:\n{diff_input}", f"\nEXPECTED_COMMIT:\n{commit_message[0]}", f"\nMODEL_RESPONSE:\n{responses[0]}"
        )
    PRINTED_TIMES += 1

    scores = []
    for response, input_commit in zip(responses, commit_message):
        guess = _extract_commit(response)
        if guess is None:
            # No closing tag at all: penalise the malformed output.
            scores.append(-2.5)
            continue

        try:
            scores.append(cross_compare(model_commit=guess, input_commit=input_commit))
        except Exception as e:
            # Best effort, but do not swallow KeyboardInterrupt / SystemExit
            # the way a bare `except:` would.
            print(f"Error while scoring commit similarity: {e}")
            scores.append(0.0)

    return scores

# Tokenize

In [None]:
import numpy as np

# Tokenize every prompt through the chat template, generation prompt included.
tokenized = ds.map(
    lambda batch: {
        "tokens": tokenizer.apply_chat_template(
            batch["prompt"],
            add_generation_prompt=True,
            tokenize=True,
        )
    },
    batched=True,
)
# Decode the first example back to text as a visual check of the template.
print(tokenizer.decode(tokenized[0]["tokens"]))

# Token count of each prompt.
tokenized = tokenized.map(lambda row: {"L": len(row["tokens"])})

# The 90th percentile of the prompt lengths is used as the prompt budget.
maximun_length = int(np.quantile(tokenized["L"], 0.9))
print("Max Length = ", maximun_length)

# NOTE(review): the filter below is disabled, so prompts longer than
# `maximun_length` stay in `ds` — confirm this is intended.
# ds = ds.select(np.where(np.array(tokenized["L"]) <= maximun_length)[0])
del tokenized

# Training

In [None]:

from vllm import SamplingParams
from trl import GRPOConfig, GRPOTrainer

# Split the sequence budget: the prompt gets the measured length (+1), and
# the completion gets whatever remains of max_seq_length.
max_prompt_length = maximun_length + 1
max_completion_length = max_seq_length - max_prompt_length
new_model_id="circle33/qwen-commit-7b-grpo"

# Sampling used for rollouts: generation stops at EOS, and the stop string is
# kept in the output.
vllm_sampling_params = SamplingParams(
    seed=3407,
    min_p=0.1,
    top_p=1.0,
    top_k=-1,
    stop=[tokenizer.eos_token],
    include_stop_str_in_output=True,
)

training_args = GRPOConfig(
    # --- rollout / generation ---
    vllm_sampling_params=vllm_sampling_params,
    num_generations=2,  # Decrease if out of memory
    max_prompt_length=max_prompt_length,
    max_completion_length=max_completion_length,
    unsloth_num_chunks=4,
    # --- optimisation ---
    optim="paged_adamw_8bit",
    learning_rate=8e-6,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_ratio=0.01,
    lr_scheduler_type="cosine",
    max_grad_norm=0.1,
    # --- batching / schedule ---
    per_device_train_batch_size=4,
    gradient_accumulation_steps=1,
    num_train_epochs=3,
    # --- logging / checkpoints ---
    logging_steps=1,
    output_dir="outputs",
    overwrite_output_dir=True,
    save_strategy="steps",
    save_steps=50,
    save_total_limit=1,
    # --- Hugging Face Hub ---
    push_to_hub=True,
    hub_model_id=new_model_id,
    hub_strategy="every_save",
)

In [None]:
import wandb

# Start a Weights & Biases run for the GRPO stage.
WANDB_PROJECT = "GRPO-reboost"
wandb.init(project=WANDB_PROJECT)

In [None]:
# Reward functions passed to the trainer: strict format match, closing-tag
# count, and similarity to the reference commit message.
reward_functions = [
    format_exactaly_reward,
    format_approximately_reward,
    score_reward,
]

trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=reward_functions,
    args=training_args,
    train_dataset=ds,
)

trainer.train()
