In [1]:
import onnx
from tvm.contrib.download import download_testdata
from PIL import Image
import numpy as np
from tvm import relax
from tvm.script import tir as T
from tvm.script import relax as R   
import tvm
from tvm.contrib import graph_executor

In [2]:
import torch
import torchvision

In [3]:
from torch import fx

In [None]:
# `transforms` is only needed for the preprocessing pipeline in this cell;
# `Image` is already imported by the notebook's first cell.
from torchvision import transforms

# Load the pretrained torchvision model selected by name and put it in eval mode.
model_name = "resnet50"
model_factory = getattr(torchvision.models, model_name)
model = model_factory(pretrained=True)
model = model.eval()

# Trace the model once with a random input to obtain a TorchScript version.
input_shape = [1, 3, 224, 224]
input_data = torch.randn(input_shape)
scripted_model = torch.jit.trace(model, input_data).eval()

# Download the sample picture (cached by download_testdata) and load it at 224x224.
img_url = "https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true"
img_path = download_testdata(img_url, "cat.png", module="data")
print(img_path)
img = Image.open(img_path).resize((224, 224))

# Per-channel statistics handed to transforms.Normalize.
channel_mean = [0.485, 0.456, 0.406]
channel_std = [0.229, 0.224, 0.225]

# Resize -> center crop -> tensor -> normalize, applied in this order.
preprocess_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=channel_mean, std=channel_std),
]
my_preprocess = transforms.Compose(preprocess_steps)

# Preprocess the picture and prepend a batch axis.
img = np.expand_dims(my_preprocess(img), 0)

In [5]:
from tvm.relax.frontend.torch import from_fx

In [None]:
# Describe the single model input as a (shape, dtype) pair for the FX importer.
# NOTE(review): the spatial size here is 244x244, while `input_shape` and the
# preprocessed image defined earlier use 224x224. The random test input built
# in a later cell also uses 244, so the notebook is self-consistent as written;
# confirm which size is actually intended before changing either one.
input_info = [([1,3,244,244], "float32")]
# Symbolically trace the eval-mode model and convert the FX graph to a Relax
# module. Gradients are disabled for the duration of the conversion.
with torch.no_grad():
    fx_module = fx.symbolic_trace(model)
    # keep_params_as_input=True: the next cell calls detach_params on the result
    # and a later cell binds weights onto "main" via BindParams.
    mod_from_torch = from_fx(fx_module, input_info, keep_params_as_input=True)

In [7]:
# Split the imported module into the module itself and its detached parameters.
# NOTE(review): `params_from_torch` is not read by any later cell visible in this
# notebook; the weights that get bound are rebuilt from model.named_parameters().
mod_from_torch, params_from_torch = relax.frontend.detach_params(mod_from_torch)

In [8]:
from tvm import meta_schedule as ms

In [9]:
# Look up the pipeline registered under the name "zero" and run it on the
# imported module. Later cells index the result by function name, so it is
# presumably where the per-operator functions come from (TODO confirm).
zero_pipeline = relax.get_pipeline("zero")
mod = zero_pipeline(mod_from_torch)

In [10]:
# Collect the names of every function in the module except the entry point
# "main". The names are read from GlobalVar.name_hint directly instead of
# parsing the printed form of each GlobalVar, which depended on the exact
# repr format and re-queried the module on every loop iteration.
l = list(mod.get_global_vars())
mod_list = [gv.name_hint for gv in l if gv.name_hint != "main"]

In [None]:
# Wrap each collected function in its own single-function IRModule, with the
# function's global symbol rewritten to "main".
mod_list_ins = [
    tvm.IRModule.from_expr(mod[func_name].with_attr("global_symbol", "main"))
    for func_name in mod_list
]

In [11]:
# Map every named PyTorch parameter to a TVM NDArray holding the same values.
nd_params = dict(
    (param_name, tvm.nd.array(tensor.detach().numpy()))
    for param_name, tensor in model.named_parameters()
)

In [12]:
# `mod2` is a second name for the same module object as `mod` (no copy is made);
# it is the module handed to the tuner and queried for tuned schedules below.
mod2 = mod

In [None]:
# Build the tuning components first so the tune_tir call itself stays short.
tune_runner = ms.runner.LocalRunner(
    evaluator_config=ms.runner.EvaluatorConfig(),
    alloc_repeat=1,
)
tune_cost_model = ms.cost_model.XGBModel(
    extractor=ms.feature_extractor.PerStoreFeature(),
    adaptive_training=True,
)
tune_strategy = ms.search_strategy.EvolutionarySearch()

# Run MetaSchedule tuning on the module; records are written under ./tune_tmp.
# NOTE(review): the whole module is passed as `mod`, while a later cell queries
# the resulting database one PrimFunc at a time via compile_tir; confirm those
# per-function lookups actually hit records produced by this call.
database = ms.tune_tir(
    mod=mod2,
    target="llvm --num-cores=8",
    work_dir="./tune_tmp",
    max_trials_global=600,
    num_trials_per_iter=10,
    runner=tune_runner,
    cost_model=tune_cost_model,
    strategy=tune_strategy,
)

In [20]:
# Bind the weights onto "main", then replace each function in the bound module
# with the schedule the tuning database has for it.
MyMod2 = relax.transform.BindParams("main", nd_params)(mod2)
for mod_str in mod_list:
    sch = ms.tir_integration.compile_tir(database, mod2[mod_str], "llvm --num-cores=8")
    if sch is None:
        # The database has no schedule for this function: leave the existing
        # implementation in place and say which function was skipped.
        print("no tuned schedule found for %s; keeping the original" % mod_str)
        continue
    # The tuned function comes back under the symbol "main"; restore its real
    # name before swapping it into the module.
    new_func = sch.mod["main"].with_attr("global_symbol", mod_str)
    gv = MyMod2.get_global_var(mod_str)
    MyMod2.update_func(gv, new_func)

In [23]:
# Build a reproducible random float32 test input.
# The shape is taken from `input_info` (the shape the module was imported with)
# so the test input cannot drift from the compiled input shape, and the
# generator is seeded so the prediction printed below is repeatable.
data_shape = tuple(input_info[0][0])
data_rng = np.random.default_rng(0)
data_nd = data_rng.random(data_shape, dtype=np.float32)

# Wrap the NumPy array as a TVM NDArray for the virtual machine.
data_nd = tvm.nd.array(data_nd)

In [24]:
# Run on the host CPU. The original passed the full compile target string
# ("llvm  --num-cores=8") to tvm.device; only the device kind matters for a
# runtime device, so request the CPU device explicitly instead.
dev = tvm.cpu(0)

In [25]:
# Compile the module with the tuned functions swapped in, then load the
# resulting executable into a Relax virtual machine on `dev`.
ex = relax.build(MyMod2, target="llvm  --num-cores=8")
vm = relax.VirtualMachine(ex, dev)


In [26]:
# Run "main" of the tuned module on the test input and report, for each batch
# row, the index of the largest output value.
nd_res = vm["main"](data_nd)
pred_kind = nd_res.numpy().argmax(axis=1)
print("MyModuleWithParams2 Prediction:",pred_kind)

MyModuleWithParams2 Prediction: [463]


In [36]:
# Time "main" of the tuned module (number=10 is passed to the evaluator) and
# print the mean in milliseconds.
ftimer = vm.module.time_evaluator("main", dev, number=10)
tuned_mean_ms = ftimer(data_nd).mean * 1000
print("MyModuleWithParams time-cost: %g ms" % tuned_mean_ms)

MyModuleWithParams time-cost: 79.6973 ms


In [33]:
# Baseline for comparison: bind the same weights onto the module without
# swapping in any tuned functions, then build it and load it into its own VM.
bind_weights = relax.transform.BindParams("main", nd_params)
MyMod1 = bind_weights(mod)
ex2 = relax.build(MyMod1, target="llvm  --num-cores=8")
vm2 = relax.VirtualMachine(ex2, dev)

In [34]:
# Run "main" of the baseline VM (vm2) on the same test input. The printed label
# is the same text the tuned-module cell prints; the result here is from vm2.
nd_res = vm2["main"](data_nd)
pred_kind = nd_res.numpy().argmax(axis=1)
print("MyModuleWithParams2 Prediction:",pred_kind)

MyModuleWithParams2 Prediction: [463]


In [35]:
# Time "main" of the baseline VM (vm2) with the same evaluator settings as the
# tuned run and print the mean in milliseconds.
ftimer = vm2.module.time_evaluator("main", dev, number=10)
baseline_mean_ms = ftimer(data_nd).mean * 1000
print("MyModuleWithParams time-cost: %g ms" % baseline_mean_ms)

MyModuleWithParams time-cost: 4493.13 ms
