In [1]:
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
)
import torch

model_name = "rmeireles/SmolLM2-FT-ORPO"

# Pick the compute device in order of preference: CUDA GPU, then Apple MPS,
# and finally plain CPU when no accelerator is reported as available.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Load the causal LM checkpoint named above and move it to the chosen device;
# the matching tokenizer is loaded from the same checkpoint name.
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=model_name)
model = model.to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [22]:
def create_chat_function(model, tokenizer, device, max_history=10, max_new_tokens=2048):
    """Build a stateful chat closure around a causal LM and its tokenizer.

    The returned ``chat`` callable keeps a rolling list of
    ``{"role": ..., "content": ...}`` messages, renders it with the
    tokenizer's chat template, generates a reply, and records that reply so
    the next call sees the earlier turns.

    Args:
        model: Object exposing ``generate(**inputs, max_new_tokens=...)``.
        tokenizer: Object exposing ``apply_chat_template``, ``__call__`` and
            ``decode``.
        device: Device the tokenized inputs are moved to via ``.to(device)``.
        max_history: Maximum number of messages retained after a turn (a user
            message and an assistant message each count as one). ``0`` keeps
            nothing, making every call independent.
        max_new_tokens: Upper bound on the number of tokens generated per
            reply.

    Returns:
        A ``(chat, clear)`` tuple. ``chat(message)`` returns the reply text;
        ``clear()`` forgets the conversation so far.

    Raises:
        ValueError: If ``max_history`` is negative.
    """
    if max_history < 0:
        raise ValueError(f"max_history must be >= 0, got {max_history}")

    conversation_history = []

    def chat(message):
        """Send ``message`` to the model and return the decoded reply."""
        user_turn = {"role": "user", "content": message}

        # Work on a provisional copy so that a failed or interrupted
        # generation does not leave a dangling user message in the history.
        pending = conversation_history + [user_turn]

        prompt = tokenizer.apply_chat_template(
            pending, tokenize=False, add_generation_prompt=True
        )

        # The rendered template already carries its special tokens, so the
        # tokenizer must not add another set on top of them.
        inputs = tokenizer(
            prompt, return_tensors="pt", add_special_tokens=False
        ).to(device)
        outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)

        # The output sequence starts with the prompt tokens; decode only
        # what comes after them.
        prompt_length = len(inputs["input_ids"][0])
        response = tokenizer.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        )

        # Commit the completed exchange only once generation succeeded.
        conversation_history.append(user_turn)
        conversation_history.append({"role": "assistant", "content": response})

        # Maintain the rolling window. Counting the excess explicitly makes
        # max_history=0 work (a [-0:] slice would keep the whole list).
        excess = len(conversation_history) - max_history
        if excess > 0:
            del conversation_history[:excess]
            # An odd window can cut an exchange in half; drop the orphaned
            # leading reply so the history always starts with a user turn.
            while conversation_history and conversation_history[0]["role"] != "user":
                del conversation_history[0]

        return response

    def clear():
        """Forget every message exchanged so far."""
        conversation_history.clear()

    return chat, clear

# Build the chat closure and its companion reset function for the loaded model
chat_fn, clear = create_chat_function(model=model, tokenizer=tokenizer, device=device)

# Example usage (each call continues the same conversation):
# response = chat_fn("Hello!")
# print(response)
# response = chat_fn("How are you?")
# print(response)