In [None]:
import copy
import os
import time
import warnings
from time import time

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.ao.quantization import DeQuantStub, QuantStub
from torch.utils.data import DataLoader
from tqdm.auto import trange

warnings.filterwarnings("ignore")

In [None]:
class M(torch.nn.Module):
    """Toy conv -> ReLU -> linear network prepared for eager-mode static quantization.

    ``QuantStub`` / ``DeQuantStub`` mark where tensors enter and leave the
    quantized region once the model is prepared and converted; before that
    (per the PyTorch docs) they pass tensors through unchanged.

    Input-shape constraint: ``Linear(4500, 100)`` expects 4500 = 5 * 30 * 30
    flattened features. The 3x3 conv (stride 1, no padding) with 5 output
    channels therefore needs a 1-channel input with (H - 2) * (W - 2) == 900,
    e.g. a 32x32 image. Other sizes fail at ``self.linear``.
    """

    def __init__(self):
        super().__init__()
        # QuantStub converts tensors from floating point to quantized.
        self.quant = QuantStub()
        self.conv = nn.Conv2d(1, 5, 3)
        # Kept as a separate module so it can be fused with ``conv``
        # (``fuse_modules(model, [["conv", "relu"]])``).
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()
        self.linear = nn.Linear(4500, 100)
        # DeQuantStub converts tensors from quantized back to floating point.
        self.dequant = DeQuantStub()

    def forward(self, x):
        """Run the network.

        Args:
            x: tensor of shape (N, 1, H, W) with (H - 2) * (W - 2) == 900
               (e.g. 32x32), so the flattened conv output has 4500 features.

        Returns:
            Tensor of shape (N, 100).
        """
        # Start of the region that becomes quantized after convert().
        x = self.quant(x)
        x = self.conv(x)
        x = self.relu(x)
        x = self.linear(self.flatten(x))
        # End of the quantized region: back to floating point.
        x = self.dequant(x)
        return x

In [None]:
# Build the float model. It must be in eval mode for the static quantization
# logic to work; Module.eval() returns the module itself, so it can be chained.
model_fp32 = M().eval()
model_fp32

M(
  (quant): QuantStub()
  (conv): Conv2d(1, 5, kernel_size=(3, 3), stride=(1, 1))
  (relu): ReLU()
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear): Linear(in_features=4500, out_features=100, bias=True)
  (dequant): DeQuantStub()
)

In [None]:
# Fuse conv + relu into a single ConvReLU2d module (relu becomes Identity).
# fuse_modules defaults to inplace=False and returns a fused *copy*, so the
# result must be bound to a name; otherwise the fusion is thrown away and
# model_fp32 stays unfused. A stage-suffixed name keeps model_fp32 unchanged.
model_fp32_fused = torch.quantization.fuse_modules(model_fp32, [["conv", "relu"]])
model_fp32_fused

M(
  (quant): QuantStub()
  (conv): ConvReLU2d(
    (0): Conv2d(1, 5, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
  )
  (relu): Identity()
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear): Linear(in_features=4500, out_features=100, bias=True)
  (dequant): DeQuantStub()
)