# TorchHook 使用示例
本示例展示了如何使用 TorchHook 捕获 PyTorch 模型的中间特征图，并演示参数量统计与模型摘要等工具函数。

In [1]:
import torch
import torchvision.models as models
# 导入我们的库 TorchHook
from torchhook import HookManager

# 1. Load the model. Seed first so the randomly initialised weights and the
#    dummy input below are reproducible under "Restart & Run All".
torch.manual_seed(0)
model = models.resnet18()
model.eval()  # switch to evaluation (inference) mode

# 2. Create the HookManager.
#    max_size=1 means each hook keeps only the most recent feature map.
hook_manager = HookManager(model, max_size=1)

# Everything from hook registration onward runs inside try/finally so the
# hooks are always removed from `model` -- even if a layer name is wrong or
# the forward pass raises -- because later cells keep using this same model.
try:
    # 3. Register the layers you are interested in.
    # Register by layer name; valid names can be listed with
    # dict(model.named_modules()).keys().
    hook_manager.register_forward_hook(layer_name='conv1')
    # The alias `add` can be used as well.
    hook_manager.add(layer_name='layer4.1.relu')
    # Or pass the layer object directly (providing a name is recommended).
    target_layer = model.fc
    hook_manager.add(layer_name='fully_connected', layer=target_layer)

    # 4. Run the model's forward pass; the registered hooks fire here.
    dummy_input = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        output = model(dummy_input)

    # 5. Retrieve the captured features.
    features_conv1 = hook_manager.get_features('conv1')  # list of features captured for 'conv1'
    features_relu = hook_manager.get_features('layer4.1.relu')  # list of features captured for 'layer4.1.relu'
    all_features = hook_manager.get_all()  # dict holding every captured feature

    print(f"Conv1 feature shape: {features_conv1[0].shape}")
    print(f"Layer 4.1 ReLU feature shape: {features_relu[0].shape}")

    # 6. Show a summary of the hook state.
    hook_manager.summary()
    # or: print(hook_manager)
finally:
    # 7. Remove the hooks (important!) -- runs even if a step above failed.
    hook_manager.clear_hooks()

Conv1 feature shape: torch.Size([1, 64, 112, 112])
Layer 4.1 ReLU feature shape: torch.Size([1, 512, 7, 7])
Model: ResNet | Total Parameters: 11.69 M
Registered Hooks: 3 (max_size=1)
--------------------------------------------------------------------------------
Captured Features Summary:
Layer Key                     Feature Count       Feature Shape                 
--------------------------------------------------------------------------------
conv1                         1                   (1, 64, 112, 112)             
layer4.1.relu                 1                   (1, 512, 7, 7)                
fully_connected               1                   (1, 1000)                     
--------------------------------------------------------------------------------


In [2]:
# 测试样例：统计模型参数量
from torchhook.utils import count_parameters, format_parameter_count, get_layerwise_parameter_count, model_summary

# Total number of parameters in the whole model, formatted with a unit suffix.
total_param_count = count_parameters(model)
print("Total parameters:", format_parameter_count(total_param_count))

# Trainable parameter count for every named module (keys are module names
# such as 'conv1' or 'layer1.0.conv1').
trainable_layerwise_params = get_layerwise_parameter_count(model, trainable_only=True)
print("Trainable layerwise parameters:")
for module_name, n_params in trainable_layerwise_params.items():
    print(f"  {module_name}: {format_parameter_count(n_params)}")

Total parameters: 11.69 M
Trainable layerwise parameters:
  conv1: 9.41 K
  bn1: 128.00 
  relu: 0.00 
  maxpool: 0.00 
  layer1: 0.00 
  layer1.0: 0.00 
  layer1.0.conv1: 36.86 K
  layer1.0.bn1: 128.00 
  layer1.0.relu: 0.00 
  layer1.0.conv2: 36.86 K
  layer1.0.bn2: 128.00 
  layer1.1: 0.00 
  layer1.1.conv1: 36.86 K
  layer1.1.bn1: 128.00 
  layer1.1.relu: 0.00 
  layer1.1.conv2: 36.86 K
  layer1.1.bn2: 128.00 
  layer2: 0.00 
  layer2.0: 0.00 
  layer2.0.conv1: 73.73 K
  layer2.0.bn1: 256.00 
  layer2.0.relu: 0.00 
  layer2.0.conv2: 147.46 K
  layer2.0.bn2: 256.00 
  layer2.0.downsample: 0.00 
  layer2.0.downsample.0: 8.19 K
  layer2.0.downsample.1: 256.00 
  layer2.1: 0.00 
  layer2.1.conv1: 147.46 K
  layer2.1.bn1: 256.00 
  layer2.1.relu: 0.00 
  layer2.1.conv2: 147.46 K
  layer2.1.bn2: 256.00 
  layer3: 0.00 
  layer3.0: 0.00 
  layer3.0.conv1: 294.91 K
  layer3.0.bn1: 512.00 
  layer3.0.relu: 0.00 
  layer3.0.conv2: 589.82 K
  layer3.0.bn2: 512.00 
  layer3.0.downsample: 0.00 

In [4]:
# Print the model summary, limiting how deep into nested submodules it goes.
SUMMARY_MAX_DEPTH = 2
model_summary(model, max_depth=SUMMARY_MAX_DEPTH)

Model Summary: ResNet
--------------------------------------------------------------------------------
Total Parameters: 11.69 M
Trainable Parameters: 11.69 M
Non-trainable Parameters: 0.00 
--------------------------------------------------------------------------------
Layer Name                               Total Params         Trainable Params    
conv1                                    9.41 K               9.41 K              
bn1                                      128.00               128.00              
relu                                     0.00                 0.00                
maxpool                                  0.00                 0.00                
layer1                                   0.00                 0.00                
  layer1.0                               0.00                 0.00                
    layer1.0.conv1                       36.86 K              36.86 K             
    layer1.0.bn1                         128.00               12