In [2]:
import torch
from pytorch.src.models import ResNet

# load model
model_path = "pytorch/model/ResNet18.pth"
model = ResNet(18, 10)
model.load_state_dict(torch.load(model_path))
model.eval()

# transfer to TorchScript
example_input = torch.randn(1, 1, 28, 28)  # input shape: [batch_size, channel, height, width]
traced_model = torch.jit.trace(model, example_input)




In [3]:
traced_model

ResNet(
  original_name=ResNet
  (conv1): Sequential(
    original_name=Sequential
    (0): Conv2d(original_name=Conv2d)
    (1): BatchNorm2d(original_name=BatchNorm2d)
    (2): ReLU(original_name=ReLU)
    (3): MaxPool2d(original_name=MaxPool2d)
  )
  (stage1): Sequential(
    original_name=Sequential
    (0): ResidualBlock(
      original_name=ResidualBlock
      (CB1): ConvBN(
        original_name=ConvBN
        (0): Conv2d(original_name=Conv2d)
        (1): BatchNorm2d(original_name=BatchNorm2d)
      )
      (CB2): ConvBN(
        original_name=ConvBN
        (0): Conv2d(original_name=Conv2d)
        (1): BatchNorm2d(original_name=BatchNorm2d)
      )
    )
    (1): ResidualBlock(
      original_name=ResidualBlock
      (CB1): ConvBN(
        original_name=ConvBN
        (0): Conv2d(original_name=Conv2d)
        (1): BatchNorm2d(original_name=BatchNorm2d)
      )
      (CB2): ConvBN(
        original_name=ConvBN
        (0): Conv2d(original_name=Conv2d)
        (1): BatchNorm2d(o

In [4]:
import tvm
from tvm import relay

In [5]:
target = "llvm"

In [6]:
input_shape = example_input.shape
input_name = "input"

In [7]:
mod, params = relay.frontend.from_pytorch(traced_model, [(input_name, input_shape)])

In [8]:
print(mod)

type List[A] {
  Cons(A, List[A]),
  Nil,
}

type Option[A] {
  Some(A),
  None,
}

type Tree[A] {
  Rose(A, List[Tree[A]]),
}

type tensor_float16_t {
  tensor_nil_float16,
  tensor0_float16(float16),
  tensor1_float16(Tensor[(?), float16]),
  tensor2_float16(Tensor[(?, ?), float16]),
  tensor3_float16(Tensor[(?, ?, ?), float16]),
  tensor4_float16(Tensor[(?, ?, ?, ?), float16]),
  tensor5_float16(Tensor[(?, ?, ?, ?, ?), float16]),
  tensor6_float16(Tensor[(?, ?, ?, ?, ?, ?), float16]),
}

type tensor_float32_t {
  tensor_nil_float32,
  tensor0_float32(float32),
  tensor1_float32(Tensor[(?), float32]),
  tensor2_float32(Tensor[(?, ?), float32]),
  tensor3_float32(Tensor[(?, ?, ?), float32]),
  tensor4_float32(Tensor[(?, ?, ?, ?), float32]),
  tensor5_float32(Tensor[(?, ?, ?, ?, ?), float32]),
  tensor6_float32(Tensor[(?, ?, ?, ?, ?, ?), float32]),
}

type tensor_float64_t {
  tensor_nil_float64,
  tensor0_float64(float64),
  tensor1_float64(Tensor[(?), float64]),
  tensor2_float64(Ten

In [9]:
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target, params=params)

One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.


In [41]:
from torch.utils import data
from torchvision import datasets, transforms
class MNIST(data.DataLoader):

    def __init__(self, batch_size: int, train: bool, **kwargs):
        transform = transforms.Compose([
            # transforms.Resize((28,28)),
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])

        dataset = datasets.MNIST('./pytorch/data/mnist', train=False, transform=transform)

        super(MNIST, self).__init__(dataset=dataset, batch_size=batch_size, shuffle=train, **kwargs)

test_loader = MNIST(batch_size=1, train=False)

In [42]:
for i, (data, target) in enumerate(test_loader):
    if i == 0:
        break

In [49]:
import time

In [69]:
import numpy as np
from tvm.contrib import graph_executor

# create a runtime executor module
dev = tvm.cpu()
module = graph_executor.GraphModule(lib["default"](dev))

# input data
input_data = data.reshape(1, 1, 28, 28)
input_data = np.array(input_data.numpy(), dtype="float32")

# process model
start = time.time()
for i in range(1000):
    tvm_data = tvm.nd.array(input_data)
    module.set_input(input_name, tvm_data)
    module.run()
    output = module.get_output(0).asnumpy()
end = time.time()
print("TVM running time:", end - start)

# get output
output = module.get_output(0).asnumpy()
print("TVM output:", output)

TVM running time: 1.0734262466430664
TVM output: [[3.1008712e-23 1.2309332e-20 9.8386543e-24 1.6017456e-19 2.9745590e-22
  8.2686773e-26 9.5610522e-23 1.0000000e+00 2.6356636e-23 2.4026528e-19]]


In [66]:
model.eval()
start = time.time()
for i in range(1000):
    output = model(data)
end = time.time()
print("PyTorch runnung time:", end - start)
print("PyTorch output:", output.detach().numpy())

PyTorch runnung time: 1.7855355739593506
PyTorch output: [[3.1008592e-23 1.2309191e-20 9.8385044e-24 1.6017272e-19 2.9745476e-22
  8.2685510e-26 9.5609797e-23 1.0000000e+00 2.6356335e-23 2.4026345e-19]]
