In [None]:
import json
# Load problems and reference answers from a JSON-Lines file, then wrap each
# problem in the chat-style prompt that is fed to the model.
file_path = "/home/xhl/eval/my_eval/data/math/train.jsonl"

# Fixed text placed before / after every problem statement.
PROMPT_PREFIX = "Please reason step by step, and put your final answer within \\boxed{}.\nUser: "
PROMPT_SUFFIX = "\nAssistant: <think>"

problems = []
answers = []

with open(file_path, "r", encoding="utf-8") as f:
    for line in f:
        line = line.strip()
        if not line:
            # Blank lines (e.g. a trailing newline block) are not valid JSON
            # records; skip them instead of crashing in json.loads.
            continue
        item = json.loads(line)
        # Every record is expected to carry both a "problem" and an "answer" key.
        problems.append(item["problem"])
        answers.append(item["answer"])

# One prompt per problem, in file order.
examples = [PROMPT_PREFIX + prompt + PROMPT_SUFFIX for prompt in problems]

In [None]:
import os
import torch
import transformers
import easysteer.reft.pyreft as pyreft
import time
# Set GPU device
# NOTE: CUDA_VISIBLE_DEVICES only takes effect if it is set before CUDA is
# initialised in this process, so run this cell on a fresh kernel.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
device = "cuda"

model_name_or_path = "/data/zju-46/shenyl/hf/model/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B/"
reft_save_dir = "./speed_test"

model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name_or_path, torch_dtype=torch.bfloat16, device_map=device
)
tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_name_or_path, padding_side="left", use_fast=False
)
# EOS doubles as the padding token. Because the two ids are identical, a
# "!= pad_token_id" mask cannot tell a generated EOS from padding.
tokenizer.pad_token = tokenizer.eos_token

reft_model = pyreft.ReftModel.load(reft_save_dir, model)
reft_model.set_device(device)
reft_model.eval()


def _count_new_tokens(generated, prompt_len, eos_token_id):
    """Count generated tokens per sequence.

    Args:
        generated: 2-D tensor of token ids laid out as prompt followed by
            completion (the layout the length arithmetic in this cell relies on).
        prompt_len: number of prompt positions at the start of every row.
        eos_token_id: id that terminates a sequence (also used as padding).

    Returns:
        1-D integer tensor with one count per row. The first EOS in a
        completion is counted (it was produced by a decoding step); anything
        after it is treated as padding and ignored.
    """
    completion = generated[:, prompt_len:]
    is_eos = completion == eos_token_id
    # True for positions that come strictly after the first EOS of the row.
    after_first_eos = (is_eos.cumsum(dim=1) - is_eos.long()) > 0
    return (~after_first_eos).sum(dim=1)


def _sync_device():
    """Wait for queued CUDA work so wall-clock readings cover all GPU time."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()


# Single-sample inputs: tokenise each of the first 10 prompts on its own
# (batch size 1, so no padding is introduced).
input_texts = []
for text in examples[:10]:  # slicing tolerates datasets with fewer than 10 items
    inputs = tokenizer(text, return_tensors="pt").to(device)
    input_texts.append(
        {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
        }
    )

# Accumulates into a 1-element tensor once the first count is added; the
# reporting cell indexes it with tokens[0].
tokens = 0
_sync_device()
time1 = time.perf_counter()
for request in input_texts:
    # To intervene on every prompt position instead, pass
    # unit_locations={"base": list(range(request["input_ids"].shape[1]))}.
    # NOTE(review): the original comment described this call as intervening
    # only during the prompt phase, yet intervene_on_prompt=False is passed —
    # confirm the intended setting against pyreft's generate semantics.
    _, generated = reft_model.generate(
        request,
        intervene_on_prompt=False,
        max_new_tokens=2048,
        do_sample=False,
        eos_token_id=tokenizer.eos_token_id,
        early_stopping=True,
    )
    # Count generated tokens per sample from the completion part only, so the
    # terminating EOS is included even though it shares its id with padding.
    prompt_len = request["input_ids"].shape[1]
    tokens += _count_new_tokens(generated, prompt_len, tokenizer.eos_token_id)
_sync_device()
time2 = time.perf_counter()

In [None]:
# Report per-request latency and decoding throughput for the timed loop.
elapsed = time2 - time1
# Number of requests actually timed, rather than a hard-coded count.
num_requests = len(input_texts)
# Collapse the accumulated per-sequence counts into one plain integer.
total_new_tokens = int(torch.as_tensor(tokens).sum())

print(tokens)
sec_per_req = elapsed / num_requests if num_requests else 0.0
print(f"{sec_per_req:.8f}s/req")
toks_per_sec = total_new_tokens / elapsed if elapsed > 0 else 0
print(f"output toks/s: {toks_per_sec:.2f}")