In [8]:
# Restrict this process to GPU devices 4 and 5 via CUDA_VISIBLE_DEVICES.
# This cell runs ahead of the model-loading cell, so the setting is in place
# by the time the model is placed on devices.
import os

os.environ.update(CUDA_VISIBLE_DEVICES="4,5")

In [9]:
from transformers import AutoTokenizer, AutoModelForCausalLM

# Local directory that holds the model checkpoint.
checkpoint = "/data/models/huggingface/Qwen/Qwen3-32B"

# Tokenizer: pad on the left side, which is what a decoder-only model needs
# for batched generation.
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
tokenizer.padding_side = 'left'

# Model loading options:
#   device_map="auto"  -> automatically spread the model over the available
#                         GPUs (model parallelism)
#   torch_dtype="auto" -> automatically pick a suitable precision
model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    device_map="auto",
    torch_dtype="auto",
)

Loading checkpoint shards: 100%|██████████| 17/17 [00:11<00:00,  1.42it/s]


In [10]:
# Show the dtype the model parameters were loaded with.
# torch_dtype="auto" makes from_pretrained use the dtype recorded by the
# checkpoint instead of the float32 default, so no fp32 -> bf16 conversion is
# implied here; the expected result for this checkpoint is torch.bfloat16.
# NOTE(review): the Qwen3 weights appear to be stored as bf16 — confirm against
# the checkpoint's config.json.
model.dtype

torch.bfloat16

In [11]:
# Turn a single prompt string into token ids.
# return_tensors="pt" asks the tokenizer to return PyTorch tensors.
prompt = "你好，你知道什么是Qwen吗？"
inputs = tokenizer(prompt, return_tensors="pt")

# Print the encoded result, then the type of its "input_ids" entry.
print(inputs, type(inputs["input_ids"]), sep="\n")

{'input_ids': tensor([[108386,   3837, 107733, 106582,     48,  16948, 101037,  11319]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1]])}
<class 'torch.Tensor'>


In [12]:
# The tokenizer also accepts several prompts in a single call.
prompts = [
    "你是谁？",
    "你会做什么？",
    "你能告诉我一些关于Qwen的信息吗？",
]

# padding=True pads every prompt in the batch to the length of the longest
# one, so all rows share the same shape and can be processed as one batch.
inputs = tokenizer(prompts, padding=True, return_tensors="pt")
print(inputs)

{'input_ids': tensor([[151643, 151643, 151643, 151643, 151643, 151643, 105043, 100165,  11319],
        [151643, 151643, 151643, 151643, 151643, 151643, 102762, 106428,  11319],
        [107809, 106525, 101883, 101888,     48,  16948, 105427, 101037,  11319]]), 'attention_mask': tensor([[0, 0, 0, 0, 0, 0, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1]])}


In [13]:
# Generate continuations for the whole batch of prompts.
# max_new_tokens=32 caps the number of newly generated tokens per prompt.
# The tokenized batch is moved to the model's device first, so the input
# tensors are not left on the CPU while the model runs on GPU.
outputs = model.generate(**inputs.to(model.device), max_new_tokens=32)
print(outputs)

# Convert the generated token ids back to text.
# Each output row holds the (padded) prompt tokens followed by the newly
# generated tokens, so the decoded text starts with the prompt itself.
# skip_special_tokens=True drops special tokens such as padding and
# end-of-sequence markers from the decoded text.
for prompt_text, output_ids in zip(prompts, outputs):
    text = tokenizer.decode(output_ids, skip_special_tokens=True)
    print(f"<{prompt_text}>: {text}\n")



tensor([[151643, 151643, 151643, 151643, 151643, 151643, 105043, 100165,  11319,
         103929,  71109,  17992, 111558,  26850, 104198,  31935,  64559,  99320,
          56007,   3837, 107076, 100338, 111477,  31935,  64559, 104800, 111411,
           9370,  71304, 105483, 102064, 104949,   1773,  97611, 105205,  13072,
          20412,     48,  16948,   3837, 104811],
        [151643, 151643, 151643, 151643, 151643, 151643, 102762, 106428,  11319,
            220,    220,  11622, 104811, 102104,    271, 106249, 101951, 102064,
         104949,   3837,  35946, 100006, 100364,  20002,  60548, 101312,  88802,
           3837, 100630, 116509,  48443,     16,     13,   3070, 102104,  86119,
            334,   5122, 102215, 104380,  86119],
        [107809, 106525, 101883, 101888,     48,  16948, 105427, 101037,  11319,
          84897,  60894,  73670,   6313,  31935,  64559,  99320,  56007,   9909,
             48,  16948,   7552, 104625,  31935,  64559, 104800, 100013,   9370,
         