In [None]:
import json
# Location of the JSONL dataset; each non-empty line is parsed as a JSON object
# from which the "problem" and "answer" fields are read.
file_path = "/home/xhl/eval/my_eval/data/math/train.jsonl"

problems = []
answers = []

with open(file_path, "r", encoding="utf-8") as f:
    for line in f:
        # Skip blank lines (e.g. a trailing newline at end of file);
        # json.loads would raise JSONDecodeError on them.
        if not line.strip():
            continue
        item = json.loads(line)
        problems.append(item["problem"])
        answers.append(item["answer"])

# Wrap every problem in the fixed prompt template: an instruction line, the
# problem as the "User" turn, and an "Assistant" turn opened with "<think>".
examples = ["Please reason step by step, and put your final answer within \\boxed{}.\nUser: " + prompt + "\nAssistant: <think>" for prompt in problems]

In [None]:
import torch
import transformers
import easysteer.reft.pyreft as pyreft
import os
import time
# Restrict this process to one physical GPU.
# NOTE(review): this is only honoured if CUDA has not been initialised yet in
# this kernel — run this cell on a fresh kernel.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
device = "cuda"

model_name_or_path = "/data/zju-46/shenyl/hf/model/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B/"
reft_save_dir = "./speed_test"
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name_or_path, torch_dtype=torch.bfloat16, device_map=device
)

# Left padding aligns the last prompt token of every sample to the same
# column, so generation continues directly after each prompt.
tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_name_or_path, padding_side="left", use_fast=False
)
# Reuse EOS as the padding token. Downstream code must therefore not rely on
# "token != pad_token_id" to tell padding apart from a generated EOS.
tokenizer.pad_token = tokenizer.eos_token


# Load the saved ReFT interventions on top of the base model and switch to
# inference mode.
reft_model = pyreft.ReftModel.load(reft_save_dir, model)
reft_model.set_device(device)
reft_model.eval()


# 1) Encode one left-padded batch.
# Clamp to the number of available examples so `batch_size` always equals the
# number of requests actually sent (it is used for per-request statistics).
batch_size = min(256, len(examples))
batch = tokenizer(
    examples[:batch_size],
    return_tensors="pt",
    padding=True,
).to(device)

input_dict = {
    "input_ids": batch["input_ids"],
    "attention_mask": batch["attention_mask"],
}

# 2) Index range covering every (padded) prompt position.
# Currently unused: it only matters if the commented-out `unit_locations`
# argument of generate() below is re-enabled.
seq_len = input_dict["input_ids"].shape[1]
all_pos = list(range(seq_len))            # [0, 1, 2, ..., seq_len-1]
unit_locations = {"base": all_pos}

# 3) Timed generation.
# Drain any pending GPU work before starting the clock and wait for all
# kernels to finish before stopping it, so the interval covers exactly the
# generate() call. perf_counter() is monotonic, unlike time.time().
torch.cuda.synchronize()
time1 = time.perf_counter()
_, generated = reft_model.generate(
    input_dict,
    # unit_locations=unit_locations,
    intervene_on_prompt=False,
    max_new_tokens=2048,
    do_sample=False,
    eos_token_id=tokenizer.eos_token_id,
    # NOTE(review): early_stopping is a beam-search option; with greedy
    # decoding it is presumably a no-op — confirm and drop if so.
    early_stopping=True,
)
torch.cuda.synchronize()
time2 = time.perf_counter()

In [None]:
# Throughput statistics for the timed generate() call.
elapsed = time2 - time1
# Use the number of rows actually generated rather than the configured batch
# size, which may exceed the number of available examples.
num_requests = generated.size(0)

# Count newly generated tokens per sequence.
# Assumes `generated` holds the left-padded prompt followed by the
# continuation, so everything from column `prompt_width` on is new.
# The pad token is the EOS token here, so "!= pad_token_id" cannot distinguish
# right padding from a real EOS; instead count up to and including the first
# EOS in the continuation.
prompt_width = input_dict["input_ids"].shape[1]
continuation = generated[:, prompt_width:]

eos_id = tokenizer.eos_token_id
if eos_id is not None:
    is_eos = continuation == eos_id
    # Positions strictly before the first EOS have a running EOS count of 0.
    tokens_before_eos = (is_eos.cumsum(dim=1) == 0).sum(dim=1)
    # The EOS itself was generated too, so add 1 for rows that finished.
    new_tokens_per_seq = tokens_before_eos + is_eos.any(dim=1).long()
else:
    # Without an EOS id every row is treated as using the full continuation.
    new_tokens_per_seq = torch.full(
        (num_requests,),
        continuation.size(1),
        dtype=torch.long,
        device=generated.device,
    )

total_new_tokens = int(new_tokens_per_seq.sum().item())
tok_per_s = total_new_tokens / elapsed
s_per_req = elapsed / num_requests   # average seconds per request
req_per_s = num_requests / elapsed   # requests completed per second
print(f"{s_per_req:.8f}s/req")
print(f"throughput: {tok_per_s:.2f} tok/s")