In [None]:
import os
import sys

from sarashina_grpo.config import PROJECT_ROOT

# Make `<PROJECT_ROOT>/src` importable for the imports in later cells.
# NOTE: `sarashina_grpo.config` is imported above *before* this entry is added,
# so that package must already be importable on its own (e.g. installed).
# The membership check keeps re-runs of this notebook cell from appending the
# same directory to `sys.path` again and again.
_SRC_DIR = os.path.join(PROJECT_ROOT, "src")
if _SRC_DIR not in sys.path:
    sys.path.append(_SRC_DIR)

###################
# Config
###################

# fmt: off
MODEL_NAME = "sbintuitions/sarashina2.2-3b-instruct-v0.1"  # https://huggingface.co/sbintuitions
# MODEL_NAME = f"{PROJECT_ROOT}/artifact/outputs/checkpoint-500"  # Path to the checkpoint directory
LORA_RANK = 32  # Larger rank = smarter, but slower

# Required: MAX_SEQ_LENGTH >= MAX_PROMPT_LENGTH + MAX_COMPLETION_LENGTH (checked below)
MAX_SEQ_LENGTH = 4096  # Can increase for longer reasoning traces
MAX_PROMPT_LENGTH = 1024  # default 512 - Maximum prompt length; longer prompts are truncated from the left.
MAX_COMPLETION_LENGTH = 512  # default 256 - Maximum length of the generated completion.

USE_VLLM = True  # Enable vLLM for fast inference

GRPO_OUTPUT_DIR = f"{PROJECT_ROOT}/artifact/grpo"  # Contains the checkpoints
# fmt: on

# Enforce the length budget instead of relying on the comment above: a prompt
# plus its completion must fit inside the model's sequence length.
if MAX_PROMPT_LENGTH + MAX_COMPLETION_LENGTH > MAX_SEQ_LENGTH:
    raise ValueError(
        f"MAX_SEQ_LENGTH ({MAX_SEQ_LENGTH}) must be >= MAX_PROMPT_LENGTH + "
        f"MAX_COMPLETION_LENGTH ({MAX_PROMPT_LENGTH} + {MAX_COMPLETION_LENGTH})."
    )

###################
# Load the model
###################

from unsloth import FastLanguageModel, PatchFastRL

# Apply Unsloth's GRPO patch to FastLanguageModel; this must run before the
# model is loaded below.
PatchFastRL("GRPO", FastLanguageModel)

# Returns the model together with its tokenizer.
model, tokenizer = FastLanguageModel.from_pretrained(
    # Which weights to load and the context window to configure.
    model_name=MODEL_NAME,
    max_seq_length=MAX_SEQ_LENGTH,
    # vLLM-backed fast inference (toggled by USE_VLLM); lower the memory
    # share if the GPU runs out of memory.
    fast_inference=USE_VLLM,
    gpu_memory_utilization=0.6,
    # LoRA / quantization settings: 4-bit weights (set False for 16-bit LoRA)
    # and the largest LoRA rank that will be used.
    max_lora_rank=LORA_RANK,
    load_in_4bit=True,
)

In [None]:
from transformers import TextStreamer, set_seed
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from transformers.models.llama.tokenization_llama_fast import LlamaTokenizerFast

from sarashina_grpo.config import SYSTEM_PROMPT

# Optional sanity check of the concrete classes returned above; uncomment to print them:
# print(type(model))      # => <class 'transformers.models.llama.modeling_llama.LlamaForCausalLM'>
# print(type(tokenizer))  # => <class 'transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast'>

# Annotation-only declarations: nothing is assigned here. They only tell
# editors / type checkers the types of `model` and `tokenizer`, which were
# created by `FastLanguageModel.from_pretrained` in the previous cell.
model: LlamaForCausalLM
tokenizer: LlamaTokenizerFast


###################
# Inference
###################

# Prepare the model for generation using Unsloth's inference helper, as shown in:
# https://docs.unsloth.ai/basics/tutorial-how-to-finetune-llama-3-and-use-in-ollama#id-11.-inference-running-the-model
FastLanguageModel.for_inference(model)

# Uncomment to make the sampled outputs reproducible:
# set_seed(123)

# Streams generated text to stdout; `skip_prompt` keeps the prompt from being echoed.
streamer = TextStreamer(tokenizer=tokenizer, skip_prompt=True)


def remove_s_tag(text: str, stream_end: bool = False) -> None:
    """Print a chunk of streamed text with every ``</s>`` tag removed.

    Intended as a replacement for ``TextStreamer.on_finalized_text``: it
    prints the chunk immediately (``flush=True``) without a trailing newline,
    and ends the line only when ``stream_end`` is true.

    Args:
        text: Decoded text chunk produced by the streamer.
        stream_end: True for the final chunk of a generation.
    """
    cleaned = text.replace("</s>", "")
    print(cleaned, flush=True, end=None if stream_end else "")


# Backward-compatible alias for the original (misspelled) name.
remobe_s_tag = remove_s_tag


# Replace the streamer's print callback on this instance so that `</s>` tags
# are stripped from the streamed output.
streamer.on_finalized_text = remobe_s_tag


def print_and_return_response(
    streamer,
    tokenizer,
    model,
    prompts,
    *,
    max_new_tokens: int = 1024,
    do_sample: bool = True,
    temperature: float = 1.0,
    top_k: int = 50,
    top_p: float = 1.0,
):
    """Print the latest user turn, stream the model's reply, and return it.

    Args:
        streamer: Passed to ``model.generate`` so the reply is printed while
            it is being generated.
        tokenizer: Tokenizer providing ``apply_chat_template``, ``__call__``
            and ``decode``.
        model: Model providing ``generate`` and ``device``.
        prompts: Chat history as ``{"role": ..., "content": ...}`` dicts; the
            content of the last entry is printed as the user turn.
        max_new_tokens: Upper bound on the number of generated tokens.
        do_sample: If False, always pick the most likely next token.
        temperature: Higher temperature = more randomness.
        top_k: Sample only among the ``top_k`` most likely next tokens.
        top_p: Nucleus-sampling mass; ``1.0`` means no limit, ``0.9`` keeps
            only the tokens that make up 90% of the probability.

    Returns:
        The generated reply, cut at the first ``</s>`` and stripped of
        surrounding whitespace.
    """
    print(f"user: {prompts[-1]['content']}")
    print("ai:")

    formatted_text = tokenizer.apply_chat_template(
        prompts,
        tokenize=False,
        add_generation_prompt=True,
    )

    # The chat template already emits the special tokens it needs, so do not
    # let the tokenizer add them a second time (e.g. a duplicate BOS). This
    # presumably also matches how prompts were tokenized during GRPO training
    # (TRL passes add_special_tokens=False) — verify if the trainer changes.
    inputs = tokenizer(
        formatted_text,
        padding=True,
        truncation=True,
        add_special_tokens=False,
        return_tensors="pt",
    ).to(model.device)

    outputs = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        use_cache=True,
        do_sample=do_sample,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
    )

    # For a decoder-only model, `generate` returns prompt + completion; decode
    # only the completion instead of searching the full text for a role tag.
    prompt_length = inputs["input_ids"].shape[-1]
    completion = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=False)
    return completion.split("</s>")[0].strip()


# Turn 1: system prompt plus the first user message.
prompts = [
    {"role": "system", "content": SYSTEM_PROMPT},
    {"role": "user", "content": "今日は疲れたので休みたい。"},
]
response = print_and_return_response(streamer, tokenizer, model, prompts)

# Turn 2: append the model's reply as history (in place), then ask a follow-up.
prompts += [
    {"role": "assistant", "content": response},
    {"role": "user", "content": "あなたも同じですか？"},
]
response = print_and_return_response(streamer, tokenizer, model, prompts)