In [1]:
import logging
import torch
from torch import nn, fx, optim
from torch.nn import functional as F
from torch.profiler import profile, record_function, ProfilerActivity
from torch.ao.quantization import observer as _observer
from torchvision.models import resnet18, ResNet18_Weights
from torch_book.data.simple_vision import load_data_cifar10
from torch_book.tools import train, try_all_gpus # try_gpu
from torch_book.tools import evaluate_accuracy
from fixed_sigmoid import linear_sigmoid as _linear_sigmoid
from fixed_sigmoid import get_nodes as _get_nodes
# from fit_sigmoid import sigmoid2linear

def set_bin(xs, bins=5):
    """Select roughly evenly spaced positions from the ascending sort order of ``xs``.

    Args:
        xs: 1-D tensor of values to subsample.
        bins: Number of strides the sorted order is divided into.

    Returns:
        A 1-D tensor of indices into ``xs``: every ``step``-th entry of
        ``xs.argsort(dim=0)`` followed by the index of the largest element.
        The last index is always appended, so it can appear twice when the
        stride already lands on it.
    """
    n_calib = len(xs)
    # With fewer elements than bins the integer division gives 0, and a zero
    # slice step raises ValueError; fall back to taking every element.
    step = max(1, n_calib // bins)
    full_index = xs.argsort(dim=0)
    return torch.concat([full_index[::step],
                         full_index[-1].unsqueeze(dim=0)])
    
def get_nodes(xs, bins=5):
    """Derive breakpoint nodes from the observed value range of ``xs``.

    ``xs`` is passed through a ``PerChannelMinMaxObserver`` (``ch_axis=0``);
    the recorded minima and maxima are each subsampled with ``set_bin``,
    the unique values of both selections are merged, and the merged list is
    handed to ``fixed_sigmoid.get_nodes``.

    Args:
        xs: Tensor to observe.
        bins: Forwarded to ``set_bin`` for both the minima and the maxima.

    Returns:
        Whatever ``fixed_sigmoid.get_nodes`` returns for the merged values.
        The merged list follows set iteration order; it is not sorted here.
    """
    range_observer = _observer.PerChannelMinMaxObserver(
        ch_axis=0,
        dtype=torch.qint8,
        qscheme=torch.per_channel_affine,
    )
    with torch.no_grad():
        range_observer(xs)
        lows = range_observer.min_val
        highs = range_observer.max_val
        picked_lows = lows.take(set_bin(lows, bins=bins)).numpy()
        picked_highs = highs.take(set_bin(highs, bins=bins)).numpy()
    value_ranges = list(set(picked_lows) | set(picked_highs))
    # logging.info(f"node info:\n {value_ranges}")
    return _get_nodes(value_ranges)

In [2]:
# Send log records to draft/test.log; filemode "w" overwrites the file on each run.
# NOTE(review): presumably the draft/ directory has to exist beforehand — confirm.
logging.basicConfig(filename='draft/test.log',
                    filemode="w",
                    format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s-%(funcName)s',
                    level=logging.DEBUG)
torch.cuda.empty_cache() # clear the GPU cache

# Batch size handed to the CIFAR-10 loader helper.
batch_size = 1000
train_iter, test_iter = load_data_cifar10(batch_size=batch_size)

Files already downloaded and verified
Files already downloaded and verified


In [6]:
# Grab the first batch of inputs from the training iterator (labels are discarded).
# next(iter(...)) fails immediately if the iterator is empty, instead of
# silently leaving `xs` undefined as a for/break loop would.
xs, _ = next(iter(train_iter))

In [3]:
def linear_sigmoid(x, bins=4):
    """Apply ``fixed_sigmoid.linear_sigmoid`` to ``x`` using nodes computed from ``x``.

    Args:
        x: Input tensor; also used to compute the nodes via ``get_nodes``.
        bins: Forwarded to ``get_nodes``.

    Returns:
        The result of ``fixed_sigmoid.linear_sigmoid(x, nodes)``.
    """
    # The nodes are recomputed from the current input on every call.
    return _linear_sigmoid(x, get_nodes(x, bins=bins))

In [20]:
model = resnet18(weights=ResNet18_Weights.DEFAULT)
# Replace the stem with a 3x3, stride-1, padding-1 convolution, turn the
# max-pool into a no-op and use a 10-way classifier head
# (presumably to suit the small CIFAR-10 images — confirm).
model.conv1 = nn.Conv2d(model.conv1.in_channels, 
                        model.conv1.out_channels, 
                        3, 1, 1)
model.maxpool = nn.Identity()
model.fc = nn.Linear(model.fc.in_features, 10)

# Replace every ReLU module call with torch.sigmoid
mod = fx.symbolic_trace(model)
# Walk over all nodes in the Graph
for node in mod.graph.nodes:
    # Only module calls whose target path contains "relu" are rewritten
    if node.op == "call_module":
        if "relu" in node.target:
            # Set the insertion point, add the new node, and redirect every
            # use of `node` to the new node
            with mod.graph.inserting_after(node):
                new_node = mod.graph.call_function(torch.sigmoid, node.args, node.kwargs)
                node.replace_all_uses_with(new_node)
            # Remove the old node from the graph
            mod.graph.erase_node(node)
mod.graph.lint()
# Don't forget to recompile!
new_code = mod.recompile()
# Load the model parameters. map_location="cpu" deserializes the checkpoint
# onto the CPU so it loads regardless of the device it was saved from;
# load_state_dict then copies the values into the module's own parameters.
state_dict = torch.load("models/resnet18_cifar10_sigmoid.h5", map_location="cpu")
mod.load_state_dict(state_dict)

<All keys matched successfully>

In [21]:
# Stop at the first call_function node whose target is torch.sigmoid and
# print its name; `node` stays bound to that node after the loop.
for node in mod.graph.nodes:
    if node.op != "call_function" or node.target != torch.sigmoid:
        continue
    print(node.name)
    break

sigmoid


In [23]:
# `node` is the loop variable left bound by an earlier cell's loop over
# mod.graph.nodes; this cell depends on that leftover state.
inputs = node.all_input_nodes
# First input node feeding `node`.
s = inputs[0]

In [5]:
# valid_acc = evaluate_accuracy(mod, test_iter)
# valid_acc
# 0.8407

In [5]:
# Point every call_function node whose name contains "sigmoid" at
# linear_sigmoid, printing each rewritten node as it is updated.
for node in mod.graph.nodes:
    is_sigmoid_call = node.op == "call_function" and "sigmoid" in node.name
    if not is_sigmoid_call:
        continue
    node.target = linear_sigmoid
    print(node.name, node.target)
mod.graph.lint()
# Don't forget to recompile!
new_code = mod.recompile()

sigmoid <function linear_sigmoid at 0x7fa21c8900d0>
sigmoid_1 <function linear_sigmoid at 0x7fa21c8900d0>
sigmoid_2 <function linear_sigmoid at 0x7fa21c8900d0>
sigmoid_3 <function linear_sigmoid at 0x7fa21c8900d0>
sigmoid_4 <function linear_sigmoid at 0x7fa21c8900d0>
sigmoid_5 <function linear_sigmoid at 0x7fa21c8900d0>
sigmoid_6 <function linear_sigmoid at 0x7fa21c8900d0>
sigmoid_7 <function linear_sigmoid at 0x7fa21c8900d0>
sigmoid_8 <function linear_sigmoid at 0x7fa21c8900d0>
sigmoid_9 <function linear_sigmoid at 0x7fa21c8900d0>
sigmoid_10 <function linear_sigmoid at 0x7fa21c8900d0>
sigmoid_11 <function linear_sigmoid at 0x7fa21c8900d0>
sigmoid_12 <function linear_sigmoid at 0x7fa21c8900d0>
sigmoid_13 <function linear_sigmoid at 0x7fa21c8900d0>
sigmoid_14 <function linear_sigmoid at 0x7fa21c8900d0>
sigmoid_15 <function linear_sigmoid at 0x7fa21c8900d0>
sigmoid_16 <function linear_sigmoid at 0x7fa21c8900d0>


In [7]:
# Evaluate the rewritten module on the test iterator and display the accuracy.
# NOTE(review): the recorded output is 0.1, which is chance level for the
# 10-way classifier head — the linear_sigmoid replacement appears to break
# the model; investigate before relying on this result.
valid_acc = evaluate_accuracy(mod, test_iter)
valid_acc

0.1