# TorchHook 使用示例
本示例展示了如何使用 TorchHook 捕获 PyTorch 模型的中间特征图，以及如何使用 `torchhook.utils` 中的工具函数统计模型参数量并打印模型摘要。

In [1]:
import torch
import torch.nn as nn
from torchhook import HookManager

# Define a simple model
class MyModel(nn.Module):
    """Toy CNN used to demonstrate intermediate-feature capture with TorchHook.

    Pipeline: conv1 (3 -> 16 channels, 3x3 kernel) -> ReLU -> flatten -> fc (-> 10).
    ``fc`` is sized for a flattened 16x30x30 map, i.e. 32x32 inputs: an unpadded
    3x3 convolution trims 2 pixels from each spatial dimension.
    """

    def __init__(self):
        super().__init__()
        # Keep this creation order: it fixes both the module registration order
        # and the order in which parameter initialisation consumes the RNG.
        self.conv1 = nn.Conv2d(3, 16, 3)
        self.relu = nn.ReLU()
        flat_features = 16 * 30 * 30
        self.fc = nn.Linear(flat_features, 10)

    def forward(self, x):
        activated = self.relu(self.conv1(x))
        # Keep the batch dimension, flatten channels and spatial dims together.
        flat = activated.view(activated.size(0), -1)
        return self.fc(flat)

# Initialize the model and the HookManager.
# Seed the RNG first so that weight initialisation, the random inputs below and
# therefore the printed statistics are reproducible on Restart & Run All.
torch.manual_seed(0)
model = MyModel()
hook_manager = HookManager(model)

# Register a hook by layer name (recommended: simplest to use).
hook_manager.register_forward_hook(layer_name="conv1")

# Register a hook by module object (auto-named as class name + index, e.g. "ReLU_0").
hook_manager.register_forward_hook(layer=model.relu)

# Register a hook under a custom name (useful for telling hooks apart when debugging).
hook_manager.register_forward_hook('CustomName', layer=model.fc)

# Run the model. This is inference only, so disable autograd: no graph is built
# for each forward pass and the captured outputs carry no autograd history.
with torch.no_grad():
    for _ in range(5):
        # Random input batch: 2 samples, 3 channels, 32x32 (the size MyModel.fc expects).
        input_tensor = torch.randn(2, 3, 32, 32)
        output = model(input_tensor)

# Print the HookManager summary.
print(hook_manager)
print("Current keys:", hook_manager.get_keys())  # names of all registered hooks

# Fetch intermediate results (feature maps).
print("\nconv1:", hook_manager.get_features('conv1')[0].shape)  # conv1 feature map
print("   fc:", hook_manager.get_features('CustomName')[0].shape)  # fc feature map

# Fetch all captured feature maps.
all_features = hook_manager.get_all()

# Concatenate each layer's captured feature maps along the batch dimension (dim=0).
# Every capture is held in memory at once, so this may run out of memory for large data.
concatenated_features = {key: torch.cat(features, dim=0) for key, features in all_features.items()}

# Compute mean and standard deviation per layer.
stats = {key: (torch.mean(value), torch.std(value)) for key, value in concatenated_features.items()}

# Print the results.
print("\nMean and Std of features:")
for key, (mean, std) in stats.items():
    print(f"Layer: {key}, Mean: {mean.item():.4f}, Std: {std.item():.4f}")

# Remove the registered hooks and drop the captured features.
hook_manager.clear_hooks()
hook_manager.clear_features()

Model: MyModel | Total Parameters: 144.46 K
Registered Hooks: 3 (max_size=unlimited)
--------------------------------------------------------------------------------
Captured Features Summary:
Layer Key                     Feature Count       Feature Shape                 
--------------------------------------------------------------------------------
conv1                         5                   (2, 16, 30, 30)               
ReLU_0                        5                   (2, 16, 30, 30)               
CustomName                    5                   (2, 10)                       
--------------------------------------------------------------------------------
Current keys: ['conv1', 'ReLU_0', 'CustomName']

conv1: torch.Size([2, 16, 30, 30])
   fc: torch.Size([2, 10])

Mean and Std of features:
Layer: conv1, Mean: 0.0469, Std: 0.5751
Layer: ReLU_0, Mean: 0.2531, Std: 0.3494
Layer: CustomName, Mean: 0.0771, Std: 0.2227


In [2]:
# 测试样例：统计模型参数量
from torchhook.utils import count_parameters, format_parameter_count, get_layerwise_parameter_count, model_summary
import torch.nn as nn

# Define a simple model


class SimpleModel(nn.Module):
    """Small conv + MLP network used to demonstrate the parameter-counting utilities.

    Layout: conv1 (3 -> 16, 3x3) -> conv2 (16 -> 32, 3x3) -> flatten -> fc1 (-> 120)
    -> ReLU -> fc2 (-> 10). There is no activation between the two convolutions.

    NOTE(review): fc1 expects a flattened 32x6x6 map. Two unpadded 3x3
    convolutions trim 4 pixels per spatial dimension, so as written ``forward``
    only accepts 10x10 inputs; a 32x32 input would fail at fc1. The value
    32*6*6 is what a LeNet-style net with 2x2 pooling after each conv gives for
    32x32 inputs. Confirm whether pooling was intended. Parameter counts are
    unaffected either way.
    """

    def __init__(self):
        super().__init__()
        # Keep this creation order: it is the module registration order (and so
        # the order of any per-layer report) and the parameter-init RNG order.
        self.conv1 = nn.Conv2d(3, 16, 3)
        self.conv2 = nn.Conv2d(16, 32, 3)
        self.fc1 = nn.Linear(32 * 6 * 6, 120)
        self.act = nn.ReLU()
        self.fc2 = nn.Linear(120, 10)

    def forward(self, x):
        conv_out = self.conv2(self.conv1(x))
        # Keep the batch dimension, flatten channels and spatial dims together.
        flat = conv_out.view(conv_out.size(0), -1)
        hidden = self.act(self.fc1(flat))
        return self.fc2(hidden)


# Instantiate the demo model.
model = SimpleModel()

# Total parameter count, rendered in human-readable units (e.g. "144.66 K").
print("Total parameters:", format_parameter_count(count_parameters(model)))

# Per-layer parameter counts, restricted to trainable parameters.
trainable_layerwise_params = get_layerwise_parameter_count(
    model,
    trainable_only=True,
)
print("Trainable layerwise parameters:")
for layer, count in trainable_layerwise_params.items():
    print(f"  {layer}:", count)

# Depth-limited summary of a larger reference model (torchvision VGG-11 built
# with default arguments, i.e. no pretrained weights requested).
from torchvision.models import vgg11
model_summary(vgg11(), max_depth=2)

Total parameters: 144.66 K
Trainable layerwise parameters:
  conv1: 448
  conv2: 4640
  fc1: 138360
  act: 0
  fc2: 1210
Model Summary: VGG
--------------------------------------------------------------------------------
Total Parameters: 132.86 M
Trainable Parameters: 132.86 M
Non-trainable Parameters: 0.00 
--------------------------------------------------------------------------------
Layer Name                               Total Params         Trainable Params    
features                                 0.00                 0.00                
  features.0                             1.79 K               1.79 K              
  features.1                             0.00                 0.00                
  features.2                             0.00                 0.00                
  features.3                             73.86 K              73.86 K             
  features.4                             0.00                 0.00                
  features.5               