In [1]:
import torch
from pytorch.src.models import ResNet

device = "cpu"

# load model
model_path = "pytorch/model/ResNet18.pth"
model = ResNet(18, 10)
model.load_state_dict(torch.load(model_path))
model.to(device)
model.eval()

# transfer to TorchScript
example_input = torch.randn(1, 1, 28, 28).to(device)  # input shape: [batch_size, channel, height, width]
traced_model = torch.jit.trace(model, example_input)




In [2]:
import tvm
from tvm import relay

In [3]:
tvm.__version__

'0.20.dev0'

In [4]:
target = tvm.target.Target("llvm")

In [5]:
target.keys

["cpu"]

In [6]:
target = tvm.target.Target("llvm -mtriple=armv7l-linux-gnueabihf")

In [9]:
target.keys

["arm_cpu", "cpu"]

In [10]:
target.kind.name

'llvm'

In [37]:
input_shape = example_input.shape
input_name = "input"

In [38]:
mod, params = relay.frontend.from_pytorch(traced_model, [(input_name, input_shape)])

In [None]:
# target = tvm.target.Target("llvm -mtriple=armv7l-linux-gnueabihf")

# # Compile with the ARM target
# with tvm.transform.PassContext(opt_level=3):
#     lib = relay.build(mod, target=target, params=params)

# # Export the compiled model
# lib.export_library("tvm_rpi3.tar")

In [39]:
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target, params=params)

In [40]:
from torch.utils import data
from torchvision import datasets, transforms
class MNIST(data.DataLoader):

    def __init__(self, batch_size: int, train: bool, **kwargs):
        transform = transforms.Compose([
            # transforms.Resize((28,28)),
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])

        dataset = datasets.MNIST('./pytorch/data/mnist', train=False, transform=transform)

        super(MNIST, self).__init__(dataset=dataset, batch_size=batch_size, shuffle=train, **kwargs)

test_loader = MNIST(batch_size=1, train=False)

In [41]:
for i, (data, target) in enumerate(test_loader):
    if i == 0:
        break

In [42]:
import time

In [43]:
import numpy as np
from tvm.contrib import graph_executor

# create a runtime executor module
dev = tvm.cpu()
module = graph_executor.GraphModule(lib["default"](dev))

# input data
input_data = data.reshape(1, 1, 28, 28)
input_data = np.array(input_data.numpy(), dtype="float32")

# process model
start = time.time()
for i in range(1000):
    tvm_data = tvm.nd.array(input_data)
    module.set_input(input_name, tvm_data)
    module.run()
    output = module.get_output(0).asnumpy()
end = time.time()
print("TVM running time:", end - start)

# get output
output = module.get_output(0).asnumpy()
print("TVM output:", output)

TVM running time: 1.0580708980560303
TVM output: [[3.1008712e-23 1.2309332e-20 9.8386543e-24 1.6017456e-19 2.9745590e-22
  8.2686773e-26 9.5610522e-23 1.0000000e+00 2.6356636e-23 2.4026528e-19]]


In [44]:
model.eval()
start = time.time()
for i in range(1000):
    output = model(data)
end = time.time()
print("PyTorch runnung time:", end - start)
print("PyTorch output:", output.detach().numpy())

PyTorch runnung time: 1.8217601776123047
PyTorch output: [[3.1008592e-23 1.2309191e-20 9.8385044e-24 1.6017272e-19 2.9745476e-22
  8.2685510e-26 9.5609797e-23 1.0000000e+00 2.6356335e-23 2.4026345e-19]]
