In [6]:
!pip install trl

Collecting trl
  Using cached trl-0.19.1-py3-none-any.whl.metadata (10 kB)
Collecting accelerate>=1.4.0 (from trl)
  Using cached accelerate-1.9.0-py3-none-any.whl.metadata (19 kB)
Using cached trl-0.19.1-py3-none-any.whl (376 kB)
Using cached accelerate-1.9.0-py3-none-any.whl (367 kB)
Installing collected packages: accelerate, trl
Successfully installed accelerate-1.9.0 trl-0.19.1


In [None]:
import torch
import pandas as pd
from datasets import load_dataset, Dataset
from transformers import TrainingArguments, AutoTokenizer, AutoModelForCausalLM
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM, SFTConfig # setting up SFT training process

In [9]:
# helper function for inference

def generate_responses(model, tokenizer, user_message, system_message=None, max_new_tokens=100):
    """Generate one greedy-decoded reply to a single user message.

    Builds a single-turn chat (optional system message followed by the user
    message), renders it with the tokenizer's chat template, calls
    ``model.generate`` and returns only the newly generated text.

    Args:
        model: Object exposing ``device`` and ``generate(**inputs, ...)``.
        tokenizer: Object exposing ``apply_chat_template``, ``__call__``,
            ``decode`` and ``eos_token_id``.
        user_message: Content of the user turn.
        system_message: Optional system prompt; skipped when falsy.
        max_new_tokens: Cap on generated tokens, forwarded to ``generate``.

    Returns:
        The decoded continuation (prompt tokens removed, special tokens
        skipped) with leading/trailing whitespace stripped.
    """
    messages = []
    if system_message:
        messages.append({"role": "system", "content": system_message})

    # We assume the data are all single-turn conversations.
    messages.append({"role": "user", "content": user_message})

    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        # Extra keyword handed to the chat template. NOTE(review): presumably a
        # Qwen3-style switch; templates that do not reference it should ignore it.
        enable_thinking=False,
    )

    # The rendered prompt already carries the template's special tokens, so the
    # tokenizer must not add its own again (otherwise tokenizers that prepend
    # BOS would produce a duplicated BOS at the start of the prompt).
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        add_special_tokens=False,
    ).to(model.device)

    # can use vLLM, sglang or TensorRT here for more efficient inference
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # greedy decoding
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # generate() returns prompt + continuation; slice the prompt tokens off.
    input_len = inputs["input_ids"].shape[1]
    generated_ids = outputs[0][input_len:]
    response = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

    return response

In [None]:
# helper function to test model with questions

def test_model(model, tokenizer, questions, system_message=None, title="Model output"):
    """Print the model's reply to each question under a titled banner.

    Args:
        model: Passed through to ``generate_responses``.
        tokenizer: Passed through to ``generate_responses``.
        questions: Iterable of user messages, queried one at a time.
        system_message: Optional system prompt applied to every question.
        title: Text shown in the banner printed before the results.

    Returns:
        None. Results are written to stdout only.
    """
    banner = f"\n****** {title} ******"
    print(banner)

    # Numbering shown to the reader is 1-based.
    for i, question in enumerate(questions, start=1):
        response = generate_responses(
            model,
            tokenizer,
            user_message=question,
            system_message=system_message,
        )
        report = f"\nModel input {i}: \n{question} \nModel output: {response} \n"
        print(report)