In [1]:
import logging
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Tuple

import torch
from torch import nn, fx, optim
from torch.nn import functional as F
from torch.profiler import profile, record_function, ProfilerActivity
from torch.ao.quantization import observer as _observer
from torchvision.models import resnet18, ResNet18_Weights

from torch_book.data.simple_vision import load_data_cifar10
from torch_book.tools import train, try_all_gpus # try_gpu
from torch_book.tools import evaluate_accuracy
from fixed_sigmoid import linear_sigmoid as _linear_sigmoid
from fixed_sigmoid import get_nodes as _get_nodes


def create_model(state_dict_path: Optional[str] = "models/resnet18_cifar10_origin.h5",
                 map_location: Any = "cpu") -> nn.Module:
    """Build the 10-class ResNet-18 used in this notebook and load its trained weights.

    Compared with torchvision's ``resnet18()``, the stem is changed: ``conv1``
    becomes a 3x3 convolution with stride 1 and padding 1, ``maxpool`` is
    replaced by ``nn.Identity``, and ``fc`` outputs 10 classes.

    Args:
        state_dict_path: path of a saved ``state_dict`` to load into the model.
            Defaults to the original (ReLU) CIFAR-10 checkpoint. Pass ``None``
            to skip loading and keep the random initialisation.
        map_location: forwarded to ``torch.load``. Defaults to ``"cpu"`` so that
            a checkpoint saved from GPU tensors also loads on a CPU-only machine.
            ``load_state_dict`` copies the values into the model's parameters
            regardless of where they were loaded.

    Returns:
        The ResNet-18 ``nn.Module``.
    """
    model = resnet18()
    # NOTE: nn.Conv2d defaults to bias=True (torchvision's conv1 has no bias),
    # so the (strictly loaded) checkpoint is expected to contain ``conv1.bias``.
    model.conv1 = nn.Conv2d(model.conv1.in_channels,
                            model.conv1.out_channels,
                            kernel_size=3, stride=1, padding=1)
    model.maxpool = nn.Identity()
    model.fc = nn.Linear(model.fc.in_features, 10)
    if state_dict_path is not None:
        state_dict = torch.load(state_dict_path, map_location=map_location)
        model.load_state_dict(state_dict)
    return model

def replace(model: nn.Module,
            state_dict_path: Optional[str] = "models/resnet18_cifar10_sigmoid.h5",
            activation: Callable[..., Any] = torch.sigmoid,
            map_location: Any = "cpu") -> fx.GraphModule:
    """Trace ``model`` with torch.fx and swap every ``nn.ReLU`` call for ``activation``.

    Each ``call_module`` node whose target submodule is an ``nn.ReLU`` is
    replaced by a ``call_function`` node that calls ``activation`` with the
    same args/kwargs. All users of the old node are rewired to the new one.

    Args:
        model: the module to rewrite. It must be symbolically traceable.
        state_dict_path: checkpoint loaded into the rewritten ``GraphModule``.
            Defaults to the sigmoid CIFAR-10 checkpoint. Pass ``None`` to skip
            loading.
        activation: function inserted in place of each ReLU (default
            ``torch.sigmoid``).
        map_location: forwarded to ``torch.load`` (default ``"cpu"``).

    Returns:
        The rewritten and recompiled ``fx.GraphModule``.
    """
    mod = fx.symbolic_trace(model)
    # Walk a snapshot of the node list, because nodes are erased during the loop.
    for node in list(mod.graph.nodes):
        if node.op != "call_module":
            continue
        # Match on the submodule's type rather than its name, so modules whose
        # qualified name merely contains "relu" are left alone.
        if not isinstance(mod.get_submodule(node.target), nn.ReLU):
            continue
        # Insert the replacement right after the ReLU node, then redirect all
        # users of the ReLU node to it.
        with mod.graph.inserting_after(node):
            new_node = mod.graph.call_function(activation, node.args, node.kwargs)
        node.replace_all_uses_with(new_node)
        mod.graph.erase_node(node)
    mod.graph.lint()
    # Regenerate the GraphModule's forward() from the edited graph.
    mod.recompile()
    if state_dict_path is not None:
        state_dict = torch.load(state_dict_path, map_location=map_location)
        mod.load_state_dict(state_dict)
    return mod

# Log DEBUG-and-above records to draft/test.log, truncating the file on each run.
LOG_FILE = Path("draft") / "test.log"
# logging's FileHandler does not create missing directories, so create the
# directory up front; otherwise a fresh checkout fails with FileNotFoundError.
LOG_FILE.parent.mkdir(parents=True, exist_ok=True)
logging.basicConfig(filename=LOG_FILE,
                    filemode="w",
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s-%(funcName)s',
                    level=logging.DEBUG)
torch.cuda.empty_cache()  # Release cached, unused GPU memory held by PyTorch's allocator.

查看将 ReLU 替换为 sigmoid 后的网络在 CIFAR-10 测试集上的精度：

In [2]:
# Build the ReLU ResNet-18, then rewrite its ReLUs to sigmoid
# (replace() also loads the sigmoid checkpoint).
model = create_model()
mod = replace(model)

# Load CIFAR-10; only the test split is used in this cell.
batch_size = 256
train_iter, test_iter = load_data_cifar10(batch_size=batch_size)
# Evaluate the sigmoid network on the CIFAR-10 test split.
valid_acc = evaluate_accuracy(mod, test_iter)
valid_acc

Files already downloaded and verified
Files already downloaded and verified


0.8407

In [10]:
mod.graph.print_tabular()

opcode         name                   target                                                      args                                   kwargs
-------------  ---------------------  ----------------------------------------------------------  -------------------------------------  --------
placeholder    x                      x                                                           ()                                     {}
call_module    conv1                  conv1                                                       (x,)                                   {}
call_module    bn1                    bn1                                                         (conv1,)                               {}
call_function  sigmoid                <built-in method sigmoid of type object at 0x7fadc34bf200>  (bn1,)                                 {}
call_module    maxpool                maxpool                                                     (sigmoid,)                             {}
call_modul

In [6]:
class Tracer(fx.Tracer):
    """FX tracer that records, for each traced operation, the qualified name of
    the module the operation originated from.

    After ``tracer.trace(root)``, ``tracer.node_to_originating_module`` maps
    every node created through ``create_proxy`` to the qualified name of the
    submodule being traced at that moment ('' means the top-level module).
    """

    # Qualified name of the module currently being traced. The top-level module
    # is the empty string. Updated on entry to ``call_module`` and restored on exit.
    current_module_qualified_name: str = ''
    # Map from FX node to the qualname of its originating module, filled in by
    # ``create_proxy``. Only declared here and created per instance in
    # ``__init__``, so separate tracers never share (and pollute) one dict.
    node_to_originating_module: Dict[torch.fx.Node, str]

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Forward all arguments to ``fx.Tracer`` and set up per-instance state."""
        super().__init__(*args, **kwargs)
        self.current_module_qualified_name = ''
        self.node_to_originating_module = {}

    def call_module(self, m: torch.nn.Module, forward: Callable[..., Any],
                    args : Tuple[Any, ...], kwargs : Dict[str, Any]) -> Any:
        """Trace a submodule call while tracking the submodule's qualified name.

        1. Save the caller's qualified name so it can be restored later.
        2. Install the callee's qualified name in
           ``current_module_qualified_name`` for ``create_proxy`` to read.
        3. Delegate to the normal ``Tracer.call_module``.
        4. Restore the caller's qualified name, even if tracing raised.
        """
        old_qualname = self.current_module_qualified_name
        try:
            self.current_module_qualified_name = self.path_of_module(m)
            return super().call_module(m, forward, args, kwargs)
        finally:
            self.current_module_qualified_name = old_qualname

    def create_proxy(self, kind: str, target: torch.fx.node.Target, args: Tuple[Any, ...],
                     kwargs: Dict[str, Any], name: Optional[str] = None,
                     type_expr: Optional[Any] = None, *extra_args: Any, **extra_kwargs: Any):
        """Override of ``Tracer.create_proxy`` that records the originating module.

        Every operation recorded during tracing passes through here. The current
        traced module's qualified name is stored in ``node_to_originating_module``.
        Any extra arguments (e.g. ``proxy_factory_fn`` in newer FX versions) are
        forwarded unchanged to the base implementation.
        """
        proxy = super().create_proxy(kind, target, args, kwargs, name, type_expr,
                                     *extra_args, **extra_kwargs)
        self.node_to_originating_module[proxy.node] = self.current_module_qualified_name
        return proxy

In [9]:
# Trace the sigmoid GraphModule with the custom `Tracer`, which records the
# qualified name of the module each node originated from.
tracer = Tracer()
# NOTE: `Tracer.trace` returns an `fx.Graph` (not a GraphModule), despite the name.
traced_mod = tracer.trace(mod)

# Print (node, originating module qualified name) for every node in the graph.
# Nodes created outside any submodule call (e.g. the placeholder and functional
# ops such as sigmoid/add) map to '' (the top-level module). `.get` yields None
# for nodes the tracer did not record.
for node in traced_mod.nodes:
    module_qualname = tracer.node_to_originating_module.get(node)
    print('Node', node, 'is from module', module_qualname)

Node x is from module 
Node conv1 is from module conv1
Node bn1 is from module bn1
Node sigmoid is from module 
Node maxpool is from module maxpool
Node layer1_0_conv1 is from module layer1.0.conv1
Node layer1_0_bn1 is from module layer1.0.bn1
Node sigmoid_1 is from module 
Node layer1_0_conv2 is from module layer1.0.conv2
Node layer1_0_bn2 is from module layer1.0.bn2
Node add is from module 
Node sigmoid_2 is from module 
Node layer1_1_conv1 is from module layer1.1.conv1
Node layer1_1_bn1 is from module layer1.1.bn1
Node sigmoid_3 is from module 
Node layer1_1_conv2 is from module layer1.1.conv2
Node layer1_1_bn2 is from module layer1.1.bn2
Node add_1 is from module 
Node sigmoid_4 is from module 
Node layer2_0_conv1 is from module layer2.0.conv1
Node layer2_0_bn1 is from module layer2.0.bn1
Node sigmoid_5 is from module 
Node layer2_0_conv2 is from module layer2.0.conv2
Node layer2_0_bn2 is from module layer2.0.bn2
Node layer2_0_downsample_0 is from module layer2.0.downsample.0
Node 