In [1]:
import torch
import torchvision.models as models
from Torch2Tensor.t2t import T2TParser



In [3]:
# Tuning / compilation options; unpacked as keyword arguments when the
# T2TParser is constructed further down in the notebook.
mlc_dict = {
    "target": "cuda --max_threads_per_block=1024 --max_shared_memory_per_block=49152",
    "work_dir": "./tune_tmp",
    "task_name": "main",
    "max_trials_global": 64,
    "num_trials_per_iter": 32,
    "compile_tir_target": "cuda",
}

In [2]:
def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> torch.nn.Conv2d:
    """Create a bias-free 3x3 convolution whose padding equals its dilation.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Convolution stride.
        groups: Number of channel groups.
        dilation: Kernel dilation; also used as the padding amount.

    Returns:
        A freshly constructed ``torch.nn.Conv2d`` layer.
    """
    layer = torch.nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        groups=groups,
        bias=False,
    )
    return layer


def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> torch.nn.Conv2d:
    """Create a bias-free 1x1 (pointwise) convolution.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Convolution stride.

    Returns:
        A freshly constructed ``torch.nn.Conv2d`` layer.
    """
    layer = torch.nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return layer
class Demo(torch.nn.Module):
    """Small residual block operating on 3-channel feature maps.

    The main branch is conv3x3 -> BatchNorm -> ReLU -> conv3x3 -> BatchNorm.
    Its result is added to the skip branch and passed through ReLU once more.
    ``downsample`` is ``None`` here, so the skip branch is the unmodified input.
    """

    def __init__(self) -> None:
        super().__init__()
        channels = 3
        # Submodules are registered in the same order they are applied.
        self.conv1 = conv3x3(channels, channels, 1)
        self.bn1 = torch.nn.BatchNorm2d(channels)
        self.relu = torch.nn.ReLU(inplace=True)
        self.conv2 = conv3x3(channels, channels)
        self.bn2 = torch.nn.BatchNorm2d(channels)
        # Optional projection for the skip branch; not used by this demo.
        self.downsample = None
        self.stride = 1

    def forward(self, x):
        """Run the residual block on ``x`` and return the activated sum."""
        # Main branch, first stage: conv -> batch norm -> ReLU.
        y = self.relu(self.bn1(self.conv1(x)))
        # Main branch, second stage: conv -> batch norm (activation comes after the add).
        y = self.bn2(self.conv2(y))

        # Skip branch: the raw input unless a projection module was attached.
        shortcut = x if self.downsample is None else self.downsample(x)

        y += shortcut
        return self.relu(y)


In [4]:
# Seed before the model is built so that both the randomly initialised weights
# and the example input are reproducible on Restart & Run All.
torch.manual_seed(0)

# The rest of the notebook compiles and benchmarks the model for inference, so
# switch to eval mode: BatchNorm then uses its running statistics instead of
# per-batch statistics and no longer updates them on every forward pass.
model = Demo().eval()

# Single source of truth for the example input shape (batch, channels, height,
# width), so the tensor and the shape list passed to the parser cannot drift.
input_shape = (1, 3, 224, 224)
x = torch.randn(input_shape)
input_shapes = [input_shape]

# mlc_dict supplies the target / tuning keyword arguments.
PR = T2TParser(model, x, input_shapes, **mlc_dict)

2023-03-09 16:59:45.146 | INFO     | Torch2Tensor.t2t_optimizer.tir_tune.mlc_tune:__init__:22 - target: cuda --max_threads_per_block=1024 --max_shared_memory_per_block=49152; compile_tir_target: cuda


In [5]:
# Run the parser's conversion step, then log two views of the result:
# the traced PyTorch graph as an opcode/name/target/args table, and the
# RelaxIR attribute.
# NOTE(review): in the recorded run print_ir logged only "None" -- confirm
# whether the IR text is expected to appear elsewhere.
PR.convert()
PR.print_tabular(PR.pytorch_graph)
PR.print_ir(PR.RelaxIR)

2023-03-09 16:59:56.011 | DEBUG    | Torch2Tensor.t2t:print_tabular:48 - 
opcode         name    target                   args       kwargs
-------------  ------  -----------------------  ---------  --------
placeholder    x       x                        ()         {}
call_module    conv1   conv1                    (x,)       {}
call_module    bn1     bn1                      (conv1,)   {}
call_module    relu    relu                     (bn1,)     {}
call_module    conv2   conv2                    (relu,)    {}
call_module    bn2     bn2                      (conv2,)   {}
call_function  add     <built-in function add>  (bn2, x)   {}
call_module    relu_1  relu                     (add,)     {}
output         output  output                   (relu_1,)  {}
2023-03-09 16:59:56.013 | INFO     | Torch2Tensor.t2t_relax.input_layer:_generate_input:15 - input_layer: x created
2023-03-09 16:59:56.015 | DEBUG    | Torch2Tensor.t2t_relax.module.conv_layer:generate_node:32 - {'strides': (1, 1), '

2023-03-09 16:59:56.115 | INFO     | Torch2Tensor.t2t:print_ir:59 - None


In [6]:
# Run the parser's operator-fusion step and log the RelaxIR attribute again.
# NOTE(review): presumably this is what produces the fused_cbr0 / fused_cb0
# operators listed by print_op in a later cell -- confirm against the library.
PR.fuse_op()
PR.print_ir(PR.RelaxIR)

2023-03-09 17:00:07.728 | INFO     | Torch2Tensor.t2t:print_ir:59 - None


In [7]:
# Generate the TensorIR form, then log it together with its operator names.
# In the recorded run print_op listed: add, fused_cbr0, fused_cb0, relu.
PR.gen_TensorIR()
PR.print_ir(PR.TensorIR)
PR.print_op(PR.TensorIR)

2023-03-09 17:00:19.470 | INFO     | Torch2Tensor.t2t:print_ir:59 - None
2023-03-09 17:00:19.479 | INFO     | Torch2Tensor.t2t:print_op:68 - ['add', 'fused_cbr0', 'fused_cb0', 'relu']


In [8]:
# Tune the TensorIR and log the tuned result.
# NOTE(review): this looks like the slow cell -- in the recorded run the task
# scheduler reported completion roughly five and a half minutes after the
# previous cell's last log line. Consider timing or caching it.
PR.tune_tir()
PR.print_ir(PR.tuned_TensorIR)

2023-03-09 17:05:55 [INFO] [task_scheduler.cc:260] Task #0 has finished. Remaining task(s): 0


Unnamed: 0,Name,FLOP,Weight,Speed (GFLOPS),Latency (us),Weighted Latency (us),Trials,Done
0,main,150528,1,62.1716,2.4212,2.4212,6,Y


2023-03-09 17:05:55 [DEBUG] [task_scheduler.cc:318] 
 ID | Name |   FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Done 
-----------------------------------------------------------------------------------------------------
  0 | main | 150528 |      1 |        62.1716 |       2.4212 |                2.4212 |      6 |    Y 
-----------------------------------------------------------------------------------------------------
Total trials: 6
Total latency (us): 2.42117



2023-03-09 17:05:55.727 | INFO     | Torch2Tensor.t2t:print_ir:59 - None


In [9]:
# Validate the compiled program, then benchmark it.
# In the recorded run check_result logged "accuracy test passed" and the
# benchmark logged inference times for both the tensor program and the
# torch model.
PR.check_result()
PR.infer_benchmark()

2023-03-09 17:06:02.587 | INFO     | Torch2Tensor.benchmark:check_result:73 - accuracy test passed
2023-03-09 17:06:02.631 | INFO     | Torch2Tensor.benchmark:inf:98 - tensor program inf time: 0.036164(ms)
2023-03-09 17:06:03.011 | INFO     | Torch2Tensor.benchmark:inf:119 - torch model inf time : 0.315718(ms)
