# Installation

In [None]:
# Later cells also import sentence_transformers, rouge_score and wandb;
# install them as well if they are not already present in the environment.
! pip install unsloth

# Load model

In [None]:
from unsloth import FastLanguageModel

max_seq_length = 2048
lora_rank = 32

# Base model + tokenizer, with vLLM-backed fast inference for generation.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="Qwen/Qwen3-4B",
    max_seq_length=max_seq_length,  # context length; longer is possible but costs memory
    fast_inference=True,  # enable vLLM fast inference
    max_lora_rank=lora_rank,
    gpu_memory_utilization=0.85,
)

# Projections that receive LoRA adapters; drop q/k/v/o first if memory runs out.
lora_target_modules = [
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    "gate_proj",
    "up_proj",
    "down_proj",
]

model = FastLanguageModel.get_peft_model(
    model,
    r=lora_rank,  # any positive int works; 8, 16, 32, 64 and 128 are typical
    target_modules=lora_target_modules,
    lora_alpha=lora_rank,
    use_gradient_checkpointing="unsloth",  # enables long-context finetuning
    random_state=3407,
)

# Data Preparation

In [None]:
import datasets

# Prompts for the commit-message task. The closing tag shown in USER_PROMPT's
# required format must be `</commit_message>`, exactly the tag the reward
# regexes search for; a mismatched tag makes correct-looking output score 0.
SYSTEM_PROMPT="""You are a professional code assistant with expertise in code analysis and version control best practices.
1. Your task is to analyze code changes (diffs), summarize their theme and core content.
2. Generate a structured commit message inside <commit_message> and </commit_message> tags. Your message must strictly follow the rules provided by the user."""

USER_PROMPT="""Task: generate a structured commit message based on the given code changes.

Rules:
- Commits MUST be prefixed with a type, which consists of a noun, `feat`, `fix`, etc., followed by an OPTIONAL scope, and a REQUIRED terminal colon and space.
- The type `feat` MUST be used when a commit adds a new feature to your application or library.
- The type `fix` MUST be used when a commit represents a bug fix for your application.
- A scope MAY be provided after a type. A scope MUST consist of a noun describing a section of the codebase surrounded by parenthesis, e.g., `fix(parser)`:
- A description MUST immediately follow the space after the type/scope prefix. The description is a short summary of the code changes, e.g., fix: array parsing issue when multiple spaces were contained in string.
- A longer commit body MAY be provided after the short description, providing additional contextual information about the code changes. The body MUST begin one blank line after the description.
- A footer of one or more lines MAY be provided one blank line after the body. The footer MUST contain meta-information about the commit, e.g., related pull-requests, reviewers, breaking changes, with one piece of meta-information per-line.
- Breaking changes MUST be indicated at the very beginning of the body section, or at the beginning of a line in the footer section. A breaking change MUST consist of the uppercase text BREAKING CHANGE, followed by a colon and a space.
- A description MUST be provided after the `BREAKING CHANGE`: , describing what has changed about the API, e.g., BREAKING CHANGE: environment variables now take precedence over config files.
- Types other than `feat` and `fix` MAY be used in your commit messages.
- The units of information that make up conventional commits MUST NOT be treated as case sensitive by implementors, with the exception of BREAKING CHANGE which MUST be uppercase.
- A `!` MAY be appended prior to the `:` in the type/scope prefix, to further draw attention to breaking changes. `BREAKING CHANGE: description` MUST also be included in the body or footer, along with the `!` in the prefix.

Examples:
## Commit message with description and breaking change in body
```
feat: allow provided config object to extend other configs

BREAKING CHANGE: `extends` key in config file is now used for extending other config files
```
## Commit message with optional `!` to draw attention to breaking change
```
chore!: drop Node 6 from testing matrix

BREAKING CHANGE: dropping Node 6 which hits end of life in April
```
## Commit message with no body
```
docs: correct spelling of CHANGELOG
```
## Commit message with scope
```
feat(lang): add polish language
```
## Commit message for a fix using an (optional) issue number.
```
fix: correct minor typos in code

see the issue for details on the typos fixed

closes issue #12
```

You must use this format:

<commit_message>
...
</commit_message>
"""

# Alternative source: datasets.load_dataset("circle33/conventional_commits")
ds = datasets.load_from_disk("/home/circle/code/AImmit_ai/dataset_generation")


def build_prompt(example):
    """
    Build the chat prompt for one dataset row.

    The user turn is USER_PROMPT followed by the row's ``diff``, so the model
    actually sees the code changes it is asked to summarise. ``diff`` is the same
    column the score reward receives as a keyword argument.

    Args:
        example: one dataset row; must contain a ``"diff"`` field
            (presumably the unified-diff text — TODO confirm against the dataset).
    Returns:
        dict: ``{"prompt": [system_message, user_message]}``.
    """
    diff_text = example["diff"] or ""  # guard against missing/None diffs
    user_content = f"{USER_PROMPT}\nCode changes:\n```diff\n{diff_text}\n```"
    return {
        "prompt": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_content},
        ]
    }


ds = ds.map(build_prompt)

ds[0]

# Reward functions

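1. Format reward: checks that the commit message is wrapped in `<commit_message>...</commit_message>` tags. (10 points)
2. Score reward: compares the generated commit message with the reference commit message (type, scope, breaking change, description and body). The more similar they are, the higher the reward. (up to 90 points)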

In [None]:
import re
import numpy as np
from sentence_transformers.cross_encoder import CrossEncoder
from rouge_score import rouge_scorer

pattern = r"<commit_message>.*</commit_message>"
regex = re.compile(pattern, re.DOTALL)


def format_reward(prompts, completions, **kwargs):
    """
    Reward completions that wrap their commit message in the required tags.

    A completion earns 10.0 when it contains ``<commit_message>`` ... ``</commit_message>``
    anywhere in its text, otherwise 0.0. The tags are searched for rather than
    matched at position 0. With ``match``, any leading output (whitespace, a
    reasoning preamble, ...) zeroed the reward even when the tags were present.

    Args:
        prompts: unused.
        completions: one entry per completion; each is a list whose first element
            is a message dict holding the generated text under ``"content"``.
    Returns:
        list[float]: one reward per completion.
    """
    responses = [completion[0]["content"] for completion in completions]
    return [10.0 if regex.search(response) else 0.0 for response in responses]

CONVENTIONAL_COMMIT_REGEX = re.compile(
    r"^(?P<type>[a-zA-Z_]+)"  # type (e.g., feat, fix, chore)
    r"(?:\((?P<scope>[^\)\n]+)\))?"  # Optional scope
    r"(?P<breaking_indicator>!)?"  # Optional breaking change indicator
    r":\s*(?P<description>[^\n]+)"  # Description
    r"(?:(?:\r\n|\r|\n){2}(?P<body>.*))?", # Optional body
    re.DOTALL | re.MULTILINE
)

# `BREAKING CHANGE: <description>` must begin a line: either the very start of the
# body or the start of a footer line. It is uppercase, as the prompt's rules require.
BREAKING_CHANGE_REGEX = re.compile(r"^BREAKING CHANGE:\s*([^\n]+)", re.MULTILINE)


def parse_commit(commit_message: str):
    """
    Parse a commit message string into its Conventional Commit components.

    Args:
        commit_message: the raw commit message; surrounding whitespace is ignored.
    Returns:
        dict with keys ``type``, ``scope``, ``is_breaking_change_header`` (``!`` in
        the header), ``description``, ``body``, ``breaking_change_footer`` (the text
        after ``BREAKING CHANGE:``, or None), ``is_breaking_change``, ``raw`` and
        ``parse_success``.

        ``body`` is everything after the first blank line (body and footers
        together), or None when there is none. When the header does not match, the
        whole stripped message is returned as ``description`` with
        ``parse_success=False``.
    """
    match = CONVENTIONAL_COMMIT_REGEX.match(commit_message.strip())
    if not match:
        return {
            "type": None,
            "scope": None,
            "is_breaking_change_header": False,
            "description": commit_message.strip(), # Fallback to full message as description
            "body": None,
            "breaking_change_footer": None,
            "is_breaking_change": False,
            "raw": commit_message,
            "parse_success": False
        }

    components = match.groupdict()
    breaking_indicator = components["breaking_indicator"] == "!"
    body_and_footers = components["body"]

    # Search the whole body + footers. The previous split on the first blank line
    # missed a BREAKING CHANGE that opens the body when a footer paragraph followed.
    breaking_change_desc = None
    if body_and_footers:
        bc_match = BREAKING_CHANGE_REGEX.search(body_and_footers)
        if bc_match:
            breaking_change_desc = bc_match.group(1).strip()

    return {
        "type": components["type"],
        "scope": components["scope"],  # None when the header has no scope
        "is_breaking_change_header": breaking_indicator,
        "description": components["description"].strip(),
        "body": body_and_footers,  # body and footers together; None if absent
        "breaking_change_footer": breaking_change_desc,
        "is_breaking_change": breaking_indicator or breaking_change_desc is not None,
        "raw": commit_message,
        "parse_success": True,
    }

# Shared scorers: a cross-encoder used for description similarity and a stemmed
# ROUGE-L scorer used for bodies.
CROSS_ENCODER_NAME = "cross-encoder/stsb-distilroberta-base"
CROSS_MODEL = CrossEncoder(CROSS_ENCODER_NAME)
ROUGE_SCORER = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True)

def cross_compare(model_commit, input_commit, scale=1.0):
    """
    Score the semantic similarity of two texts with the shared cross-encoder.

    The (model_commit, input_commit) pair is passed to ``CROSS_MODEL.predict`` and
    the prediction is multiplied by ``scale``. The raw prediction is presumably in
    [0, 1] for this STS model (TODO confirm), giving a result in [0, scale]. Any
    exception is printed and yields 0.0.
    """
    sentence_pair = [model_commit, input_commit]
    try:
        return CROSS_MODEL.predict(sentence_pair) * scale
    except Exception as err:
        print(f"Error during semantic similarity calculation: {err}")
        return 0.0

def rouge_compare(model_commit, input_commit, scale=1.0):
    """
    Score two texts by ROUGE-L F-measure, multiplied by ``scale``.

    Two empty (falsy) texts count as identical and earn the full ``scale``. Any
    exception raised while scoring is printed and yields 0.0.
    """
    if not (model_commit or input_commit):
        return scale  # nothing on either side: treat as a perfect match
    try:
        rouge_l = ROUGE_SCORER.score(model_commit, input_commit)["rougeL"]
        return rouge_l.fmeasure * scale
    except Exception as err:
        print(f"Error during ROUGE calculation: {err}")
        return 0.0

COMMIT_MESSAGE_TAG_REGEX = re.compile(r"<commit_message>(.*?)</commit_message>", re.DOTALL)


def extract_commit_message(response):
    """
    Return the text inside the last ``<commit_message>...</commit_message>`` pair
    in *response*, or *response* unchanged when no complete pair is present.

    The last pair is used so that a final answer wins over any earlier draft.
    """
    tagged = COMMIT_MESSAGE_TAG_REGEX.findall(response)
    return tagged[-1] if tagged else response


def score_reward(prompts, completions, commit_message, diff, **kwargs):
    """
    Score each completion against its reference commit message (max 90 points).

    The commit message is first extracted from the ``<commit_message>`` tags,
    falling back to the whole completion. Parsing the tagged text directly always
    failed, because the header regex must start with the commit type. A completion
    that does not parse as a Conventional Commit scores 0.0. Otherwise its score is
    the sum of:

    - Parse success: 5.0
    - Breaking change: 5.0 if the breaking-change flag equals the reference's
    - Type: 5.0 if the type equals the reference type
    - Scope: 5.0 if the scope equals the reference scope (always awarded when the
      reference has no scope)
    - Description: up to 35.0, cross-encoder similarity to the reference description
    - Body: up to 35.0, ROUGE-L F-measure against the reference body + footers

    Args:
        prompts: unused.
        completions: one entry per completion; each is a list whose first element
            holds the generated text under ``"content"``.
        commit_message: reference commit messages, aligned with *completions*.
        diff: reference diffs; accepted but currently unused.
    Returns:
        list[float]: one score per completion.
    """
    responses = [completion[0]["content"] for completion in completions]
    scores = []

    for response, reference in zip(responses, commit_message):
        model_commit = parse_commit(extract_commit_message(response))
        if not model_commit["parse_success"]:
            scores.append(0.0)
            continue
        raw_commit = parse_commit(reference)

        # A reference without a scope accepts any (or no) scope from the model.
        scope_match = (model_commit["scope"] == raw_commit["scope"]) if raw_commit["scope"] else True

        parse_success_score = 5.0  # only reached when the completion parsed
        breaking_change_score = 5.0 if model_commit["is_breaking_change"] == raw_commit["is_breaking_change"] else 0.0
        type_score = 5.0 if model_commit["type"] == raw_commit["type"] else 0.0
        scope_score = 5.0 if scope_match else 0.0

        # Description: semantic similarity; body: ROUGE-L overlap.
        description_score = cross_compare(
            model_commit=model_commit["description"],
            input_commit=raw_commit["description"],
            scale=35.0,
        )
        body_score = rouge_compare(
            model_commit=model_commit["body"] or "",
            input_commit=raw_commit["body"] or "",
            scale=35.0,
        )

        scores.append(float(
            parse_success_score
            + breaking_change_score
            + type_score
            + scope_score
            + description_score
            + body_score
        ))

    return scores

# Training

In [None]:
from trl import GRPOConfig, GRPOTrainer

# Token ids of every chat-formatted prompt (generation prompt included), and the
# longest such prompt, used to sanity-check max_prompt_length.
tokenized_prompts = [
    tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True)
    for conversation in ds["prompt"]
]
exact_max_prompt_length = max(map(len, tokenized_prompts))

In [None]:

max_prompt_length = 448  # manually adjusted
# NOTE(review): the hub id says "7b" but the base model loaded earlier is
# Qwen/Qwen3-4B — confirm the intended repository name.
new_model_id="circle33/qwen-commit-7b-grpo"

# Make a too-small prompt budget visible instead of silent. Prompts longer than
# max_prompt_length are presumably truncated by GRPOTrainer (confirm for the
# installed TRL version), which could drop the system prompt/rules from context.
if exact_max_prompt_length > max_prompt_length:
    print(
        f"Warning: the longest prompt is {exact_max_prompt_length} tokens but "
        f"max_prompt_length is {max_prompt_length}; longer prompts may be truncated."
    )

training_args = GRPOConfig(
    learning_rate = 8e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.01,
    lr_scheduler_type = "cosine",
    optim = "paged_adamw_8bit",
    logging_steps = 1,
    per_device_train_batch_size = 8,
    gradient_accumulation_steps = 1,
    num_generations = 8, # Decrease if out of memory
    max_prompt_length = max_prompt_length,
    # Prompt + completion budgets together fill the model's max_seq_length.
    max_completion_length = max_seq_length - max_prompt_length,
    max_grad_norm = 0.1,
    output_dir = "outputs",
    overwrite_output_dir = True,
    push_to_hub = True,
    hub_model_id=new_model_id,
    hub_strategy="every_save",
    save_strategy="steps",
    save_steps=50,
    save_total_limit=1,
    num_train_epochs=3,
)

In [None]:
import wandb

# Start a Weights & Biases run in the "GRPO-reboost" project before training begins.
wandb.init(project="GRPO-reboost")

In [None]:

# Every sampled completion is scored by both reward functions.
reward_functions = [format_reward, score_reward]

trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=reward_functions,
    args=training_args,
    train_dataset=ds,
)
trainer.train()
