In [1]:
import os
from pathlib import Path
# Keep all Hugging Face downloads (weights, tokenizers) under ~/scratch/hf_home.
# HF_HOME is set here, before `transformers` is imported below, so the library
# picks up this cache location.
SCRATCH = Path.home() / "scratch"
os.environ["HF_HOME"] = str(SCRATCH / "hf_home")

from transformers import AutoTokenizer

# Checkpoint whose weights are loaded by the vLLM engine in the next cell.
# Alternatives are kept commented out for quick switching.
CHECKPOINT_OR_NAME = 'McGill-NLP/nano-aha-moment-3b'
# CHECKPOINT_OR_NAME = 'Qwen/Qwen2.5-3B'
# CHECKPOINT_OR_NAME = 'google/gemma-3-4b-pt'
# CHECKPOINT_OR_NAME = "meta-llama/Llama-3.2-3B"
# Model whose tokenizer (and chat template) is used to build prompts; it must be
# the tokenizer the checkpoint above was trained with.
CHAT_MODEL_NAME = "Qwen/Qwen2.5-3B-Instruct" # should have the tokenizer we trained the checkpoint with
# CHAT_MODEL_NAME = "google/gemma-3-4b-it"
# CHAT_MODEL_NAME = "meta-llama/Llama-3.2-3B-Instruct"
# Global tokenizer used by the prompt-building helpers and for decoding outputs.
tokenizer = AutoTokenizer.from_pretrained(CHAT_MODEL_NAME)

In [None]:
import torch
from vllm import LLM, SamplingParams

# vLLM engine configuration: half of the GPU memory for the engine, bf16
# weights, 2 GiB of CPU swap, prefix caching on, and a 2048-token context
# (which is also the longest sequence captured in CUDA graphs).
engine_kwargs = {
    "gpu_memory_utilization": 0.5,
    "dtype": torch.bfloat16,
    "swap_space": 2,
    "enable_prefix_caching": True,
    "max_model_len": 2048,
    "max_seq_len_to_capture": 2048,
}
inference_engine = LLM(model=CHECKPOINT_OR_NAME, **engine_kwargs)

In [11]:
def format_response(query, response):
    """Render a query and the model's response as a styled HTML panel.

    Both texts are HTML-escaped, so tags such as <think>/<answer> (and any
    other markup the model emits) are shown literally rather than being
    interpreted by the notebook's HTML renderer.

    Args:
        query: Text shown in the "Query" section.
        response: Text shown in the "Generated Response" section.

    Returns:
        An ``IPython.display.HTML`` object, which Jupyter renders when it is
        the last expression of a cell.
    """
    import html
    from IPython.display import HTML

    # Escape "&" as well as "<"/">": escaping only the angle brackets would make
    # entity-like text (e.g. a literal "&lt;") render as a different character.
    response = html.escape(response, quote=False)
    query = html.escape(query, quote=False)

    # Format the response with syntax highlighting
    formatted_html = f"""
    <div style="background-color: #f8f9fa; padding: 15px; border-radius: 5px; border: 1px solid #ddd;">
        <h3 style="color: #333; margin-top: 0;">Query:</h3>
        <pre style="background-color: #e9f7fe; padding: 10px; border-radius: 3px; overflow-x: auto; white-space: pre-wrap; word-wrap: break-word; color: #0066cc;">{query}</pre>
        <h3 style="color: #333; margin-top: 10px;">Generated Response:</h3>
        <pre style="background-color: #f5f5f5; padding: 10px; border-radius: 3px; overflow-x: auto; white-space: pre-wrap; word-wrap: break-word; color: #333333;">{response}</pre>
    </div>
    """

    return HTML(formatted_html)



def generate_chat_prompt(query, assistance_prefix="Let me think step by step\n<think>"):
    """Build an R1-style chat prompt for ``query`` with the global ``tokenizer``.

    The conversation contains a fixed system message and the user ``query``.
    Unless ``assistance_prefix`` is None, it ends with a partial assistant turn
    that generation should continue.

    Args:
        query: The user's message (converted to text with an f-string).
        assistance_prefix: Opening text of the assistant's reply, or None to
            stop the conversation after the user turn.

    Returns:
        dict with "prompt" (the decoded template text, special tokens kept)
        and "input_ids" (the ids returned by ``apply_chat_template``).
    """
    system_text = (
        "You are a helpful assistant. You first think about the reasoning process in the mind "
        "and then provides the user with the answer."
    )
    keep_assistant_open = assistance_prefix is not None

    conversation = [
        {"role": "system", "content": system_text},
        {"role": "user", "content": f"{query}"},
    ] + ([{"role": "assistant", "content": assistance_prefix}] if keep_assistant_open else [])

    # continue_final_message leaves the trailing assistant turn unterminated so
    # the model continues right after the prefix.
    token_ids = tokenizer.apply_chat_template(
        conversation, tokenize=True, continue_final_message=keep_assistant_open
    )
    text = tokenizer.decode(
        token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
    )
    return {"prompt": text, "input_ids": token_ids}

# play the countdown game
def preprocess_countdown_example(example):
    """Turn a Countdown puzzle into a model-ready chat prompt.

    The puzzle text is wrapped by ``generate_chat_prompt``, so the system
    message and chat-template handling are shared with free-form prompts
    instead of being duplicated here.

    Args:
        example: Mapping with "nums" (the numbers that may be used) and
            "target" (the value the equation must evaluate to).

    Returns:
        dict with "input_ids" and "prompt" for the conversation, ending in an
        open assistant turn that starts with "Let me think step by step"
        followed by "<think>" on the next line.
    """
    PROMPT_TEMPLATE = (
        "Using the numbers {numbers}, create an equation that equals {target}. "
        "You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. "
        "Show your work in <think> </think> tags. And return the final equation and answer in "
        "<answer> </answer> tags, for example <answer>(1 + 2) / (3 * 5)</answer>."
    )
    question = PROMPT_TEMPLATE.format(numbers=example["nums"], target=example["target"])

    # Pass the prefix explicitly so this prompt stays fixed even if the helper's
    # default prefix is changed later.
    chat = generate_chat_prompt(question, assistance_prefix="Let me think step by step\n<think>")
    return {"input_ids": chat["input_ids"], "prompt": chat["prompt"]}

## Countdown

In [None]:
# Solve one Countdown puzzle with the trained checkpoint and render the result.
sample = {"nums":  [7, 71, 19, 4], "target": 68}
sample.update(preprocess_countdown_example(sample))

print(f"######################## Prompt:\n`{sample['prompt']}`")

# Feed the pre-tokenized prompt as a {"prompt_token_ids": ...} input. The
# `prompt_token_ids=` keyword of LLM.generate is deprecated and removed in
# newer vLLM releases.
generation = inference_engine.generate(
    {"prompt_token_ids": sample["input_ids"]},
    sampling_params=SamplingParams(
        temperature=0.3,
        max_tokens=1024,
        top_p=1.0,
        n=1,  # Only generate one response per question
    ),
)
response = tokenizer.decode(generation[0].outputs[0].token_ids)
format_response(sample["prompt"], response)


## General

In [None]:
# Example queries; `o` is the one actually sent to the model below.
# o_1 is a raw string so LaTeX such as \begin, \[ and \\ survives intact
# (in a normal string "\b" would turn into a backspace character).
o_1 = r"""A quadratic equation has roots that are also the solutions to the system:

\[
\begin{cases}
x + y = 7 \\
xy = 10
\end{cases}
\]

1. Find the quadratic equation whose roots are \( x \) and \( y \).
2. Solve the quadratic equation.
3. Verify that the roots satisfy the original system.
"""

o = """Hello, how are you?"""

o_3 = """My slurm job failed. I look at the stdout and I observe this:

### stdout ###
slurmstepd: error: container_p_join: open failed for /var/opt/slurm/localstorage/6450599/.ns: No such file or directory
slurmstepd: error: container_g_join(6450599): No such file or directory"""

o_4 = """
How many of the first 500 positive integers are divisible by 3, 4 and 5?
"""

o_5 = """
You have 1,2,3,4. Provide a math equation that equals 10. You can use each number only once. You can use basic arithmetic operations (+, -, *, /).
"""


sample = generate_chat_prompt((
    f"{o}\n"
    "Show your work in <think> </think> tags. And return the final answer in "
    "<answer> </answer> tags"
))

print(f"######################## Prompt:\n`{sample['prompt']}`")

# Feed the pre-tokenized prompt as a {"prompt_token_ids": ...} input. The
# `prompt_token_ids=` keyword of LLM.generate is deprecated and removed in
# newer vLLM releases.
generation = inference_engine.generate(
    {"prompt_token_ids": sample["input_ids"]},
    sampling_params=SamplingParams(
        temperature=0.6,
        max_tokens=1024,
        top_p=1.0,
        n=1,  # Only generate one response per question
    ),
)
response = tokenizer.decode(generation[0].outputs[0].token_ids)
format_response(o, response)