In [6]:
from transformers import LlamaForCausalLM, AutoTokenizer
import torch

加载预训练模型

In [7]:
# 定义模型路径
model_dir = "models/story/"

# 加载分词器
tokenizer = AutoTokenizer.from_pretrained(model_dir)

model = LlamaForCausalLM.from_pretrained(model_dir)

### 创建输入数据
使用分词器将文本转化为模型的输入形式。

In [21]:
input_text = "这是一个测试句子。你可以利用Matplotlib或其他工具来可视化注意力权重等中间结果，帮助进一步理解模型的行为"
input_ids = tokenizer(input_text, return_tensors="pt").input_ids
print(input_ids.shape)
print(input_ids)

torch.Size([1, 10])
tensor([[   1,   80,    0,   38,  232,  125, 2028,  158,   54,    0]])


### 启用调试模式并逐层检查
通过torch.no_grad()确保模型在推理模式下运行，并且不会计算梯度。你可以通过注册hook来提取模型各层的输出。

In [15]:
# 存储中间层的输出
outputs = {}

def hook_fn_forward(module, input, output):
    outputs[module] = output

# 注册hook
for name, module in model.model.named_modules():
    module.register_forward_hook(hook_fn_forward)

# 运行模型
with torch.no_grad():
    model(input_ids)


In [16]:
# 打印每一层的输出
for layer, output in outputs.items():
    print(f"Layer: {layer}")
    print(f"Output shape: {output.shape}")
    print(output)

Layer: Embedding(2048, 128)
tensor([[[-0.3125, -0.3164,  0.0684,  0.0454,  0.3262, -0.4590,  0.0251,
           0.4062, -0.2852,  0.9102, -0.1973,  0.0327, -0.0247,  0.2002,
           0.2197,  0.4961, -0.1001, -0.1719,  0.1543, -0.4180,  0.1953,
           0.5273, -0.5938,  0.1768, -0.3125, -0.0417,  0.0518, -0.5000,
           0.0309, -0.2695, -0.0491,  0.0649, -0.4492,  0.1885,  0.2559,
           1.0078,  0.0527, -0.3457, -0.1562, -0.3906, -0.3262, -0.1836,
          -0.1387,  0.1904,  0.1504, -0.3594,  0.2773, -0.1484,  0.1367,
          -0.0811,  0.8906, -0.3047,  0.0679, -0.0791, -0.1934, -0.0605,
          -0.0369,  0.1748, -0.3418, -0.3047, -0.4922, -0.2129,  0.1631,
           0.0405,  0.2715, -0.2148, -0.0923,  0.4121,  0.1309,  0.0693,
           0.0610, -0.0150,  0.3242,  0.8516,  0.0942, -0.1523, -0.0605,
          -0.6797,  0.1787, -0.0046,  0.3047, -0.2988,  0.0525, -0.1982,
          -0.0099, -0.4746, -0.4844, -0.1533,  0.0062,  3.2812,  2.2500,
           0.0366, -0.2

KeyboardInterrupt: 