In [16]:
import os
os.environ["HF_HOME"] = "/network/scratch/a/aghajohm/hf_home" # set before transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import sglang
import sys
# Add the parent directory to the path so we can import from aha.py
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

from aha import initialize_model, generate_r1_prompt

In [None]:
def format_response(response):
    """Render generated text as a styled HTML card in the notebook output.

    The text is HTML-escaped before it is embedded in the markup, so
    tag-like content in the model output (for example ``<think>`` or
    ``<answer>``) and bare ``<`` / ``&`` characters are displayed literally
    instead of being interpreted by the browser.

    Args:
        response: The generated text. Non-string values are converted with
            ``str()`` before escaping.

    Returns:
        IPython.display.HTML: A displayable object wrapping the card markup.
    """
    import html
    from IPython.display import HTML

    # Escape first: unescaped output would be parsed as HTML and any
    # unknown tags (and the text structure around them) would disappear.
    escaped = html.escape(str(response))

    # Wrap the escaped text in a titled card with a wrapping <pre> block.
    formatted_html = f"""
    <div style="background-color: #f8f9fa; padding: 15px; border-radius: 5px; border: 1px solid #ddd;">
        <h3 style="color: #333; margin-top: 0;">Generated Response:</h3>
        <pre style="background-color: #f5f5f5; padding: 10px; border-radius: 3px; overflow-x: auto; white-space: pre-wrap; word-wrap: break-word;">{escaped}</pre>
    </div>
    """

    return HTML(formatted_html)

In [None]:
# Tokenizer source: should be the chat model whose tokenizer the checkpoint was trained with.
CHAT_MODEL_NAME = 'Qwen/Qwen2.5-3B-Instruct'
tokenizer = AutoTokenizer.from_pretrained(CHAT_MODEL_NAME)

# Model to serve: a checkpoint path on disk, or (see the disabled line) a model name.
CHECKPOINT_OR_NAME = "/network/scratch/a/aghajohm/aha_models/test_checkpoint"
# CHECKPOINT_OR_NAME = "Qwen/Qwen2.5-3B"

In [None]:
# Start the sglang engine on the configured checkpoint.
# NOTE(review): the tuning values below are passed through unchanged; their exact
# semantics are defined by sglang's server arguments — confirm before changing.
sglang_engine = sglang.Engine(
    model_path=CHECKPOINT_OR_NAME,
    enable_memory_saver=True,
    # The notebook passes token ids in and decodes the returned token ids
    # itself with `tokenizer`, so the engine is told to skip its own tokenizer.
    skip_tokenizer_init=True,
    mem_fraction_static=0.20,
    schedule_policy="fcfs",
    schedule_conservativeness=0.001,
    max_running_requests=10000,
)

In [None]:
# Countdown game: prompt built by generate_r1_prompt from `numbers` and `target`.
numbers = [1, 2, 3, 4]
target = 10
prompt = generate_r1_prompt(numbers, target, tokenizer)

# Sampling configuration used for this generation.
eval_sampling_params = dict(
    temperature=0.3,
    max_new_tokens=1024,
    top_p=1.0,
    n=1,  # a single completion per question
)

# The engine receives token ids directly; its output token ids are decoded
# back to text with the chat tokenizer.
generation = sglang_engine.generate(
    input_ids=prompt["input_ids"],
    sampling_params=eval_sampling_params,
)
response = tokenizer.decode(generation["token_ids"])

# Last expression of the cell, so the HTML card is rendered as the cell output.
format_response(response)


In [29]:
# general chat with the model
def generate_chat_prompt(query, tokenizer, add_generation_prompt=True):
    """Build a system + user chat prompt for ``query`` and tokenize it.

    Args:
        query: The user message; it is formatted into the user turn.
        tokenizer: Object providing ``apply_chat_template`` and ``decode``
            (in this notebook, the Hugging Face tokenizer).
        add_generation_prompt: Forwarded to ``apply_chat_template``. Defaults
            to True so the prompt ends with the template's assistant-turn
            header and the model replies as the assistant, rather than the
            prompt stopping right after the user turn. Pass False to get the
            bare system + user conversation.

    Returns:
        dict: ``{"prompt": <decoded prompt text>, "input_ids": <token ids>}``.
    """
    # NOTE(review): the system prompt wording is kept verbatim (grammar
    # included) — presumably it mirrors the prompt used at training time;
    # verify against aha.py before editing it.
    r1_prefix = [{
        "role": "system",
        "content": "You are a helpful assistant. You first thinks about the reasoning process in the mind and then provides the user with the answer."
      },
      {
        "role": "user",
        "content": f"{query}"
      }]
    # The conversation ends on a user turn, so there is no final message to
    # continue; the generation prompt is what cues the assistant reply.
    input_ids = tokenizer.apply_chat_template(
        r1_prefix,
        tokenize=True,
        add_generation_prompt=add_generation_prompt,
        continue_final_message=False,
    )
    # Decode without any cleanup so the text mirrors the token ids exactly,
    # special tokens included.
    prompt = tokenizer.decode(input_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)
    return {"prompt": prompt, "input_ids": input_ids}

In [None]:
# Free-form chat turn: build the prompt, sample one reply, decode and display it.
user_query = "Hello the mirror on the wall, who is the best one of all?"

chat_prompt = generate_chat_prompt(user_query, tokenizer)

eval_sampling_params = {
    "temperature": 0.7,
    "max_new_tokens": 1024,
    "top_p": 1.0,
    "n": 1,  # Only generate one response per question
}

# Keep the raw engine output and the decoded text under separate names
# (mirrors `generation` / `response` in the countdown cell) so that
# `chat_response` is only ever the decoded string.
chat_generation = sglang_engine.generate(
    input_ids=chat_prompt["input_ids"],
    sampling_params=eval_sampling_params,
)
chat_response = tokenizer.decode(chat_generation["token_ids"])

# Last expression of the cell, so the HTML card is rendered as the cell output.
format_response(chat_response)


In [None]:
# model = initialize_model(CHECKPOINT_OR_NAME)