In [1]:
import sys
sys.path.append('../')
import torch
import torch.nn as nn
from torchvision import models
from pyqcr.torchgraph.tracer import TorchTracer
from pyqcr.transformation.module_fusion import Fuser

In [18]:
def print_head(model, line_count=10):
    """Print only the first ``line_count`` lines of ``model``'s text representation.

    Useful for peeking at large module trees without flooding the cell output.

    Args:
        model: any object; its ``repr`` is what gets printed.
        line_count: number of leading lines to keep (applied as ``[:line_count]``).
    """
    repr_lines = repr(model).split('\n')
    head = repr_lines[:line_count]
    print('\n'.join(head))

In [19]:
# Seed torch's global RNG so the randomly initialised weights (no pretrained
# weights are requested here) and any random inputs drawn in later cells are
# identical on every Restart & Run All.
torch.manual_seed(0)
model = models.resnet18()
print_head(model)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)


In [20]:
# Use the tracer to generate the module graph and store it as `model.graph`.
# NOTE: the graph is a snapshot; it is currently not updated automatically when
# `model` changes, so re-run this cell after any structural edit to the model.
# NOTE(review): `trace_model` presumably runs a forward pass with `inp`; with
# `model` still in train mode that would also update BatchNorm running
# statistics — confirm against TorchTracer.
inp = torch.rand((1, 3, 224, 224))  # dummy input sized for ResNet's 3-channel conv1
with TorchTracer() as tt:
    tt.trace_model(model, inp)
    model.graph = tt.to_graph()

model.graph.to_namegraph()

{'Conv2d0': ['BatchNorm2d0'], 'Tensor0': ['Conv2d0'], 'BatchNorm2d0': ['ReLU0'], 'ReLU0': ['MaxPool2d0'], 'MaxPool2d0': ['Conv2d1', '__iadd__0'], 'Conv2d1': ['BatchNorm2d1'], '__iadd__0': ['ReLU2'], 'BatchNorm2d1': ['ReLU1'], 'ReLU1': ['Conv2d2'], 'Conv2d2': ['BatchNorm2d2'], 'BatchNorm2d2': ['__iadd__0'], 'ReLU2': ['Conv2d3', '__iadd__1'], 'Conv2d3': ['BatchNorm2d3'], '__iadd__1': ['ReLU4'], 'BatchNorm2d3': ['ReLU3'], 'ReLU3': ['Conv2d4'], 'Conv2d4': ['BatchNorm2d4'], 'BatchNorm2d4': ['__iadd__1'], 'ReLU4': ['Conv2d5', 'Conv2d7'], 'Conv2d5': ['BatchNorm2d5'], 'Conv2d7': ['BatchNorm2d7'], 'BatchNorm2d5': ['ReLU5'], 'ReLU5': ['Conv2d6'], 'Conv2d6': ['BatchNorm2d6'], 'BatchNorm2d6': ['__iadd__2'], '__iadd__2': ['ReLU6'], 'ReLU6': ['Conv2d8', '__iadd__3'], 'Conv2d8': ['BatchNorm2d8'], '__iadd__3': ['ReLU8'], 'BatchNorm2d7': ['__iadd__2'], 'BatchNorm2d11': ['__iadd__4'], 'Conv2d11': ['BatchNorm2d11'], 'BatchNorm2d8': ['ReLU7'], 'ReLU7': ['Conv2d9'], 'Conv2d9': ['BatchNorm2d9'], 'BatchNorm2

### Inspect fusion patterns

In [4]:
fuser = Fuser()

# These are all the supported fusion patterns, listed in the order the fuser
# searches for them.
for idx, pattern in enumerate(fuser.get_default_patterns()):
    # str.join places the arrow only *between* module names, so the line has no
    # trailing separator or whitespace (the old manual concatenation left "  ").
    print('{}: {}'.format(idx, ' -> '.join(t.__name__ for t in pattern)))

0: Conv1d -> BatchNorm1d -> ReLU  
1: Conv2d -> BatchNorm2d -> ReLU  
2: Conv3d -> BatchNorm3d -> ReLU  
3: Conv1d -> BatchNorm1d  
4: Conv2d -> BatchNorm2d  
5: Conv3d -> BatchNorm3d  
6: BatchNorm2d -> ReLU  
7: BatchNorm3d -> ReLU  
8: Conv1d -> ReLU  
9: Conv2d -> ReLU  
10: Conv3d -> ReLU  
11: Linear -> ReLU  


In [5]:
# Collect every group of modules in the model that matches a fusion pattern.
# Patterns are searched in the order listed above; the result pairs each
# pattern with the list of matching module-name groups.
fusable_groups = fuser.find_fusable_modules(model)
fusable_groups

[((torch.nn.modules.conv.Conv2d,
   torch.nn.modules.batchnorm.BatchNorm2d,
   torch.nn.modules.activation.ReLU),
  [['conv1', 'bn1', 'relu'],
   ['layer1.0.conv1', 'layer1.0.bn1', 'layer1.0.relu'],
   ['layer1.1.conv1', 'layer1.1.bn1', 'layer1.1.relu'],
   ['layer2.0.conv1', 'layer2.0.bn1', 'layer2.0.relu'],
   ['layer2.1.conv1', 'layer2.1.bn1', 'layer2.1.relu'],
   ['layer3.0.conv1', 'layer3.0.bn1', 'layer3.0.relu'],
   ['layer3.1.conv1', 'layer3.1.bn1', 'layer3.1.relu'],
   ['layer4.1.conv1', 'layer4.1.bn1', 'layer4.1.relu'],
   ['layer4.0.conv1', 'layer4.0.bn1', 'layer4.0.relu']]),
 ((torch.nn.modules.conv.Conv2d, torch.nn.modules.batchnorm.BatchNorm2d),
  [['layer1.0.conv2', 'layer1.0.bn2'],
   ['layer1.1.conv2', 'layer1.1.bn2'],
   ['layer2.0.downsample.0', 'layer2.0.downsample.1'],
   ['layer2.0.conv2', 'layer2.0.bn2'],
   ['layer2.1.conv2', 'layer2.1.bn2'],
   ['layer3.0.downsample.0', 'layer3.0.downsample.1'],
   ['layer3.0.conv2', 'layer3.0.bn2'],
   ['layer3.1.conv2', 'layer

### Apply fusion patterns to the model

In [23]:
# Fuse for inference: in eval mode the BatchNorm layers are folded into the
# preceding convolutions (note the `ConvReLU2d` modules and the `Identity`
# placeholders left where `bn*`/`relu` used to be in the output below).
model.eval()
fuser = Fuser()
# Apply all fusion patterns by default, in their predefined order.
fused = fuser.fuse(model)
print_head(fused, line_count=25)

ResNet(
  (conv1): ConvReLU2d(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
    (1): ReLU(inplace=True)
  )
  (bn1): Identity()
  (relu): Identity()
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): ConvReLU2d(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
      )
      (bn1): Identity()
      (relu): Identity()
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): Identity()
    )
    (1): BasicBlock(
      (conv1): ConvReLU2d(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
      )
      (bn1): Identity()


In [24]:
# Alternatively, pass an explicit list of patterns to apply only those.
# `model` is still in eval mode from the previous cell, so BatchNorm is folded here too.
selected_patterns = [
    [nn.Conv2d, nn.BatchNorm2d, nn.ReLU],
    [nn.Conv2d, nn.BatchNorm2d],
]
fused = fuser.fuse(model, patterns=selected_patterns)
print_head(fused, line_count=25)

ResNet(
  (conv1): ConvReLU2d(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
    (1): ReLU(inplace=True)
  )
  (bn1): Identity()
  (relu): Identity()
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): ConvReLU2d(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
      )
      (bn1): Identity()
      (relu): Identity()
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): Identity()
    )
    (1): BasicBlock(
      (conv1): ConvReLU2d(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
      )
      (bn1): Identity()


In [25]:
# Call model.train() or model.eval() before fusing to choose training-time or
# inference-time fusion. In train mode the BatchNorm layers are kept inside the
# fused modules (`ConvBnReLU2d`, `ConvBn2d` in the output below) instead of
# being folded into the convolution weights.
model.train()
fused = fuser.fuse(model)
print_head(fused, line_count=25)

ResNet(
  (conv1): ConvBnReLU2d(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (bn1): Identity()
  (relu): Identity()
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): ConvBnReLU2d(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (bn1): Identity()
      (relu): Identity()
      (conv2): ConvBn2d(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (bn2): Identity()
    )
    (1): BasicBlock(
