In [1]:
!git clone https://github.com/ZichengXu/Decoding-Tree-Sketching.git
!pip install transformers==4.54.0
%cd Decoding-Tree-Sketching

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer, set_seed

import sys, os
sys.path.append(os.path.abspath('..'))

from decoding_tree_sketching.kvbatch_decoder import KVBatchEGDT
from decoding_tree_sketching.utils.eval_utils import extract_answer_llm, extract_answer_qwq

fatal: destination path 'Decoding-Tree-Sketching' already exists and is not an empty directory.
/content/Decoding-Tree-Sketching


In [23]:
# print(torch.cuda.is_available())
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"

# Decoding hyperparameters
DECODE_CONFIG = {
    "entropy_threshold": 2.5,
    "branch_top_k": 3,
    "max_active_hyps": 12,
    "max_new_tokens": 10000,
    "temperature": 0.6,
}
tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=True
    )
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="cuda" if torch.cuda.is_available() else torch.device("cpu"),
    torch_dtype="auto",
    trust_remote_code=True
)
streamer = TextStreamer(tokenizer)

In [20]:
examples = [
    # "Find the sum of all integer bases $b>9$ for which $17_b$ is a divisor of $97_b.$",
    "Let $\\mathcal{B}$ be the set of rectangular boxes with surface area $54$ and volume $23$. Let $r$ be the radius of the smallest sphere that can contain each of the rectangular boxes that are elements of $\\mathcal{B}$. The value of $r^2$ can be written as $\\\frac{p}{q}$, where $p$ and $q$ are relatively prime positive integers. Find $p+q$.",
]
groundtruths = [
    # "70",
    "721",
]
reasoning_tail = r" Please reason step by step, and put your final answer within \boxed{}."
seed = 0

set_seed(seed)
ques_idx = 0
example = examples[ques_idx]
groundtruth = groundtruths[ques_idx]
full_prompt = example + reasoning_tail
text = tokenizer.apply_chat_template(
    [{"role": "user", "content": full_prompt}],
    tokenize=False,
    add_generation_prompt=True
)

In [8]:
# Regular output

inputs = tokenizer(text, return_tensors="pt").to(model.device)
out = model.generate(
    **inputs,
    max_new_tokens=DECODE_CONFIG["max_new_tokens"],
    do_sample=True,
    temperature=DECODE_CONFIG["temperature"],
    streamer=streamer,
)

Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.


<｜begin▁of▁sentence｜><｜begin▁of▁sentence｜><｜User｜>Find the sum of all integer bases $b>9$ for which $17_b$ is a divisor of $97_b.$ Please reason step by step, and put your final answer within \boxed{}.<｜Assistant｜><think>
Alright, so I have this problem: I need to find the sum of all integer bases \( b > 9 \) for which \( 17_b \) is a divisor of \( 97_b \). Hmm, okay. Let me try to break this down step by step.

First off, I remember that when a number is written in base \( b \), each digit represents a power of \( b \). So, for example, \( 17_b \) would be \( 1 \times b + 7 \) in decimal, right? Similarly, \( 97_b \) would be \( 9 \times b + 7 \) in decimal.

So, if I write that out, \( 17_b = b + 7 \) and \( 97_b = 9b + 7 \). The problem is saying that \( 17_b \) divides \( 97_b \), which in decimal terms means that \( b + 7 \) divides \( 9b + 7 \). Got it.

Now, if \( b + 7 \) divides \( 9b + 7 \), that implies that \( 9b + 7 \) is a multiple of \( b + 7 \). So, mathematically, we c

In [11]:
num_new_tokens = out[0].shape[0] - inputs["input_ids"].shape[1]
stat = tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
# print(gen)
ans = extract_answer_qwq(stat)
print(ans)

70


In [None]:
kvegdt = KVBatchEGDT(model, tokenizer, seed=seed)
dts_out = kvegdt.generate(
        text,
        entropy_threshold=DECODE_CONFIG["entropy_threshold"],
        branch_top_k=DECODE_CONFIG["branch_top_k"],
        max_active_hyps=DECODE_CONFIG["max_active_hyps"],
        max_new_tokens=DECODE_CONFIG["max_new_tokens"],
        temperature=DECODE_CONFIG["temperature"],
    )

Decoding:  54%|█████▍    | 5424/10000 [09:56<17:47,  4.29it/s]

In [18]:
# print(f"*** MODEL OUTPUT ***\n{dts_out['text']}")
# Print generation statistics such as steps, branch events, and sequence length
print(f"\n*** GENERATION STATS ***\n{dts_out['stats']}")
dts_ans = extract_answer_qwq(dts_out['text'])
print(f"Groundtruth = {groundtruth}, Regular decoding output = {ans}, DTS output = {dts_ans}")


*** GENERATION STATS ***
{'num_steps': 4096, 'num_finished': 0, 'generated_len': 4096, 'mean_entropy_best': 0.20606438029871388, 'branch_events': 0, 'total_branches_created': 0, 'max_active_batch': 1, 'unfinished_left': 1}
Groundtruth = 70, Regular decoding output = 70, DTS output = [invalid]
