In [12]:
import ast

import cv2
import mxnet as mx
import numpy as np
import tvm
from tvm import relay

In [2]:
# ResNet-18 v2 from the Gluon model zoo with pretrained weights (downloaded on first use).
model = mx.gluon.model_zoo.vision.resnet18_v2(pretrained=True)
# Structural sanity check: number of feature layers and the classifier head.
n_feature_layers = len(model.features)
n_feature_layers, model.output

Downloading /home/sean/.mxnet/models/resnet18_v2-a81db45f.zip8f1a84da-c783-4cc7-bc85-5126566d18cd from https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/models/resnet18_v2-a81db45f.zip...


(13, Dense(512 -> 1000, linear))

In [3]:
# Class-index -> label mapping, stored as a Python literal in a text file.
# ast.literal_eval only accepts literals, so unlike eval() a tampered or
# malformed labels file cannot execute arbitrary code.
with open('./imagenet1k_labels.txt') as f:
    labels = ast.literal_eval(f.read())

In [8]:
IMAGE_PATH = './cat.png'
INPUT_SIZE = (224, 224)  # (width, height) expected by cv2.resize

# cv2.imread's second argument must be an IMREAD_* flag; a cv2.COLOR_* code passed
# there is silently reinterpreted (COLOR_BGR2RGB == 4 == IMREAD_ANYCOLOR) and no
# conversion happens. Decode as 3-channel BGR, then convert explicitly to RGB so the
# channel order matches the R, G, B mean/std used during preprocessing.
image_bgr = cv2.imread(IMAGE_PATH, cv2.IMREAD_COLOR)
if image_bgr is None:  # imread returns None instead of raising on a missing/unreadable file
    raise FileNotFoundError(f'could not read image {IMAGE_PATH!r}')
image = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
image = cv2.resize(image, INPUT_SIZE).astype(np.float32)

In [9]:
def image_preprocessing(image, mean=(123., 117., 104.), std=(58.395, 57.12, 57.375)):
    """Normalize an HWC image and return it as a float32 NCHW batch of one.

    Args:
        image: array-like of shape (H, W, C). Its channel order must match
            ``mean``/``std``; the defaults look like ImageNet statistics on a
            0-255 scale in R, G, B order.
        mean: per-channel values subtracted from the image (broadcast over H, W).
        std: per-channel values the centred image is divided by.

    Returns:
        numpy.ndarray of dtype float32 and shape (1, C, H, W).
    """
    # Work in float64 so integer (e.g. uint8) input cannot wrap or truncate.
    pixels = np.asarray(image, dtype=np.float64)
    normalized = (pixels - np.asarray(mean, dtype=np.float64)) / np.asarray(std, dtype=np.float64)
    chw = normalized.transpose((2, 0, 1))  # HWC -> CHW
    return chw[np.newaxis, ...].astype('float32')  # add the batch axis

In [10]:
# Normalized NCHW float32 batch of one; its shape is the model's input shape.
x = image_preprocessing(image)
input_shape = x.shape
input_shape

(1, 3, 224, 224)

In [13]:
# Convert the Gluon model into a Relay IRModule plus its weight dict, binding the
# input named "data" to the concrete shape of the preprocessed batch.
input_shapes = {'data': x.shape}
relay_mod, relay_params = relay.frontend.from_mxnet(model, shape=input_shapes)
type(relay_mod), type(relay_params)

(tvm.ir.module.IRModule, dict)

In [15]:
# Code generated for a specific -mcpu may not run on CPUs lacking that ISA;
# use plain 'llvm' for a portable build.
target = 'llvm -mcpu=tigerlake'
# tvm.transform.PassContext replaces the deprecated relay.build_config.
with tvm.transform.PassContext(opt_level=3):
    factory = relay.build(relay_mod, target=target, params=relay_params)
# Use the factory's accessors instead of legacy tuple unpacking of relay.build,
# which emits a DeprecationWarning.
graph, mod, params = factory.get_graph_json(), factory.get_lib(), factory.get_params()

  graph, mod, params = relay.build(relay_mod, target, params=relay_params)


In [16]:
# Inspect what the build produced: graph JSON, compiled module, and weights.
tuple(map(type, (graph, mod, params)))

(str, tvm.runtime.module.Module, dict)

In [19]:
# Device for the compile target (the 'llvm' target runs on the CPU).
ctx = tvm.device(target)
rt = tvm.contrib.graph_executor.create(graph, mod, ctx)
rt.set_input(**params)
rt.run(data=tvm.nd.array(x))
# NDArray.numpy() supersedes the deprecated asnumpy(); [0] drops the batch axis.
scores = rt.get_output(0).numpy()[0]
scores.shape

(1000,)

In [20]:
TOP_K = 5
# np.argsort is ascending, so reverse it to rank classes from highest to lowest
# score before slicing; a slice like [-1:-5:-1] would yield only four indices.
a = np.argsort(scores)[::-1][:TOP_K]
labels[a[0]], labels[a[1]]

('Egyptian cat', 'tabby, tabby cat')

In [21]:
!rm -rf resnet18*

name = 'resnet18'
graph_fn, mod_fn, params_fn = [name+ext for ext in ('.json','.tar','.params')]
mod.export_library(mod_fn)
with open(graph_fn, 'w') as f:
    f.write(graph)
with open(params_fn, 'wb') as f:
    f.write(relay.save_param_dict(params))

!ls -alht resnet18*

-rw-rw-r-- 1 sean sean 45M Jun  4 15:09 resnet18.params
-rw-rw-r-- 1 sean sean 36K Jun  4 15:09 resnet18.json
-rw-rw-r-- 1 sean sean 61K Jun  4 15:09 resnet18.tar


In [22]:
# Read the exported artifacts back from disk; `with` closes each file handle promptly.
with open(graph_fn) as f:
    loaded_graph = f.read()
loaded_mod = tvm.runtime.load_module(mod_fn)
with open(params_fn, 'rb') as f:
    loaded_params = f.read()

In [24]:
# Rebuild the executor from the reloaded artifacts and verify it reproduces the
# scores computed by the in-memory module.
loaded_rt = tvm.contrib.graph_executor.create(loaded_graph, loaded_mod, ctx)
loaded_rt.load_params(loaded_params)
loaded_rt.run(data=tvm.nd.array(x))
# NDArray.numpy() supersedes the deprecated asnumpy(); [0] drops the batch axis.
loaded_scores = loaded_rt.get_output(0).numpy()[0]
np.testing.assert_allclose(loaded_scores, scores)