AnyCapture是一个Python工具库,专门用于捕获函数执行过程中的局部变量。该库主要致力于解决深度学习模型中间结果提取的技术难题,特别适用于深度学习模型中Attention Map的可视化分析。
- 🚀 多变量捕获:支持通过装饰器同时捕获多个局部变量
- 📦 字典缓存:变量以结构化字典形式存储,便于管理和访问
- 🧹 缓存管理:提供clear()方法进行缓存清理
- 🔄 队列功能:支持限制缓存大小,自动管理内存使用
在深度学习模型可视化过程中,开发者经常遇到以下技术挑战:
传统解决方案的局限性:
- 返回值传递法:需要修改模型结构,将嵌套在模型深处的Attention Map逐层返回,在训练时又需要还原代码
- 全局变量法:使用全局变量直接记录Attention Map,容易在训练时遗忘修改导致内存溢出
这些问题在实际开发中普遍存在,严重影响了开发效率。
PyTorch Hook机制的技术限制:
虽然PyTorch提供了hook机制来获取中间结果:
handle = net.conv2.register_forward_hook(hook)但在实际应用中存在以下技术障碍:
以Vision Transformer为例,其典型结构如下:
class VisionTransformer(nn.Module):
def __init__(self, *args, **kwargs):
...
self.blocks = nn.Sequential(*[Block(...) for i in range(depth)])
...每个Block中包含Attention模块:
class Block(nn.Module):
def __init__(self, *args, **kwargs):
...
self.attn = Attention(...)
...Hook机制的技术挑战:
- 模块路径复杂:深度嵌套的模块结构导致准确定位目标模块困难
- 批量注册繁琐:Transformer中每层都包含attention map,逐个注册hook效率低下
AnyCapture的技术优势:
基于上述技术分析,AnyCapture提供了一种更为简洁高效的解决方案,具备以下核心特性:
- 🎯 精准定位:支持按变量名精确捕获模型中间结果
- ⚡ 多变量支持:装饰器支持同时捕获多个目标变量
- 🚀 高效便捷:可批量获取Transformer模型中所有层的attention map
- 🔄 非侵入式设计:无需修改现有函数代码
- 🎯 开发友好:可视化分析完成后无需修改训练代码
使用pip安装AnyCapture:
pip install AnyCapture安装完成后,通过get_local装饰器可以便捷地捕获函数内部的局部变量。
以捕获attention_map变量为例:
步骤1:在模型文件中添加装饰器
from anycapture import get_local
@get_local('attention_map')
def your_attention_function(*args, **kwargs):
...
attention_map = ...
...
return ...步骤2:在分析代码中激活装饰器并获取结果
from anycapture import get_local
get_local.activate() # 激活装饰器
from ... import model # 注意:模型导入必须在装饰器激活之后
# 加载模型和数据
...
output = model(data)
# 获取捕获的变量
cache = get_local.cache # 输出格式:{'your_attention_function.attention_map': [attention_map]}捕获结果以字典形式存储在get_local.cache中,键值格式为函数名.变量名,对应值为变量值列表。
# 查看缓存内容
print(get_local.cache)
# 清空缓存
get_local.clear()
# 激活/取消激活
get_local.activate() # 激活捕获
get_local.deactivate() # 取消激活,提高性能
# 队列功能:限制缓存大小
get_local.activate(max_size=10) # 只保留最近10次结果
get_local.set_size(5) # 动态调整为5个元素详细文档请参考:DOC.md | demo.ipynb | 更新日志
以下展示了使用AnyCapture对Vision Transformer小型模型(vit_small)进行可视化分析的部分结果。完整案例请参考 demo.ipynb。
由于标准Vision Transformer的所有Attention Map均在Attention.forward方法中计算,仅需对该方法添加装饰器,即可批量提取模型12层Transformer的全部Attention Map数据。
单个Attention Head可视化结果:
单层全部Attention Heads可视化结果:
网格级别Attention Map可视化:
原始作者: luo3300612
原始项目: Visualizer
当前维护者: zzaiyan
本项目基于luo3300612的Visualizer项目进行重构和功能扩展。为避免与PyPI现有软件包的命名冲突,项目重命名为AnyCapture。特此对原作者的卓越贡献表示诚挚感谢。


