In [None]:
"""
Test the base model's capability using few-shot prompting.
"""

In [25]:
import yaml

# Load the system prompt and the worked examples used to build the few-shot prompt.
# The file is opened inside a `with` block so the handle is closed as soon as parsing
# finishes, and with an explicit UTF-8 encoding so decoding does not depend on the
# platform's default locale.
with open("primary_maths.yaml", encoding="utf-8") as prompt_file:
    d = yaml.safe_load(prompt_file)

system_prompt = d["system_prompt"]
# Only the first 10 examples are used as demonstrations.
few_shot_examples = d["examples"][:10]

# Flatten every example into a user turn (the question) followed by an
# assistant turn (the reference answer), preserving the order of the examples.
few_shot_messages = []
for example in few_shot_examples:
    few_shot_messages.extend([
        {"role": "user", "content": example["question"]},
        {"role": "assistant", "content": example["answer"]},
    ])

In [2]:
from transformers import AutoModelForCausalLM, AutoTokenizer

# Path of the checkpoint directory, relative to the notebook's working directory.
model_name = "models/Qwen2.5-7B-Instruct"

# Load the causal LM onto the first CUDA device; dtype selection is left to the library.
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="cuda:0")

# The tokenizer is loaded from the same checkpoint directory as the model.
tokenizer = AutoTokenizer.from_pretrained(model_name)

  from .autonotebook import tqdm as notebook_tqdm
Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.
Loading checkpoint shards: 100%|██████████| 4/4 [00:03<00:00,  1.00it/s]


In [26]:
# The new question the model has to answer after seeing the demonstrations.
user_input = "小明家里有三只狗，他的哥哥送给他两只小狗，现在小明家里有几只狗？"

# Conversation layout: system prompt first, then the few-shot turns, then the new question.
messages = [
    {"role": "system", "content": system_prompt},
    *few_shot_messages,
    {"role": "user", "content": user_input},
]

In [None]:
# Render the conversation with the tokenizer's chat template. `tokenize=False` keeps the
# result un-tokenized (it is tokenized in a later cell), and `add_generation_prompt=True`
# asks the template to append the prompt for the assistant's next turn.
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
# Show the rendered prompt for inspection.
text

In [28]:
# Tokenize the rendered prompt as a batch of one, returning PyTorch tensors ...
model_inputs = tokenizer([text], return_tensors="pt")
# ... and move them to the device the model lives on.
model_inputs = model_inputs.to(model.device)

In [29]:
# Number of tokens in the (single) prompt, i.e. the sequence dimension of the input ids.
model_inputs.input_ids.shape[1]

14085

In [30]:
# Display the tokenized inputs (input ids and attention mask) for inspection.
model_inputs

{'input_ids': tensor([[151644,   8948,    198,  ..., 151644,  77091,    198]],
       device='cuda:0'), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1]], device='cuda:0')}

In [31]:
# Generate a continuation of the prompt, capped at 4096 newly generated tokens.
# NOTE(review): no decoding options or seed are set here, so the decoding strategy is
# whatever the checkpoint's generation config specifies; if that enables sampling,
# the output will differ between runs. Confirm, and seed if reproducibility matters.
generated_ids = model.generate(**model_inputs, max_new_tokens=4096)

In [32]:
# Display the raw ids returned by `generate` for inspection.
generated_ids

tensor([[151644,   8948,    198,  ...,  86119,   8997, 151645]],
       device='cuda:0')

In [33]:
# Total length of the (single) returned sequence, i.e. its sequence dimension.
generated_ids.shape[1]

14869

In [None]:
# The sequences returned by `generate` start with the prompt tokens, so drop the first
# `len(input_ids)` tokens of each sequence to keep only the newly generated part.
# The result is stored under a new name instead of overwriting `generated_ids`: this keeps
# the cell idempotent. Rebinding `generated_ids` would make a second run of this cell
# slice the already-trimmed ids again and decode an empty or garbled response.
completion_ids = [
    output_ids[len(input_ids):]
    for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]

# Decode the single sequence in the batch, dropping special tokens, and show the answer.
response = tokenizer.batch_decode(completion_ids, skip_special_tokens=True)[0]
print(response)