In [11]:
import warnings
from PIL import Image
import numpy as np
from tvm.contrib.download import download_testdata
import tvm
from tvm import relay
from tvm.contrib import graph_executor
from tvm.contrib import graph_runtime
import tvm.autotvm as autotvm
from tvm.autotvm.tuner import XGBTuner
from tvm import autotvm
import tvm.contrib.graph_executor as runtime


import logging
import tvm
from tvm import relay, autotvm
from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner
from tvm.autotvm.graph_tuner import DPTuner, PBQPTuner


## 載入ONNX模型
import numpy as np
import onnx
import os
import glob
#import onnx_backend as backend
from onnx import backend
from onnx import numpy_helper



warnings.filterwarnings('ignore')


## Get network function

In [12]:
def get_network(onnx_model, batch_size, input_name="input"):
    """Convert a loaded ONNX model into a Relay module.

    Parameters
    ----------
    onnx_model
        The in-memory ONNX model handed to ``relay.frontend.from_onnx``.
    batch_size : int
        Leading dimension used for both the input and the output shape.
    input_name : str, optional
        Name of the graph input the shape is bound to. Defaults to
        ``"input"``, the value that was previously hard-coded.

    Returns
    -------
    tuple
        ``(mod, params, input_name, input_shape, output_shape)`` where
        ``input_shape`` is ``(batch_size, 3, 224, 224)`` and ``output_shape``
        is ``(batch_size, 1000)``.

    Raises
    ------
    Exception
        Whatever ``relay.frontend.from_onnx`` raises is propagated unchanged.
    """
    input_shape = (batch_size, 3, 224, 224)
    output_shape = (batch_size, 1000)
    shape_dict = {input_name: input_shape}
    # Deliberately no try/except here: swallowing the conversion error used to
    # leave `mod`/`params` unbound, so the caller only ever saw an unrelated
    # UnboundLocalError instead of the real cause.
    mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)
    return mod, params, input_name, input_shape, output_shape



## tuning option

In [13]:
# Build the AutoTVM settings dictionary (tuning_option)
def get_tuning_option(batch_size, target, log_file):
    """Return the keyword arguments that ``tune_kernels`` expects.

    Parameters
    ----------
    batch_size
        Accepted for interface compatibility; not used by this function.
    target
        Compilation target. Only ``target.keys`` is inspected: when it
        contains ``"cpu"`` the CPU settings are used, otherwise the
        non-CPU settings.
    log_file : str
        Path stored under ``"log_filename"``.

    Returns
    -------
    dict
        Keys ``log_filename``, ``tuner``, ``n_trial``, ``early_stopping``,
        ``use_transfer_learning`` and ``measure_option``.
    """
    # Always reach AutoTVM through `tvm.autotvm`. This notebook also defines a
    # function called `autotvm`, which rebinds the bare module alias; using the
    # bare name here made the non-CPU branch fail with AttributeError.
    if "cpu" in target.keys:
        n_trial = 1500
        runner = tvm.autotvm.LocalRunner(number=10, repeat=1, min_repeat_ms=1000)
    else:
        n_trial = 2000
        runner = tvm.autotvm.LocalRunner(
            number=20, repeat=3, timeout=4, min_repeat_ms=150
        )

    return {
        "log_filename": log_file,
        "tuner": "xgb",
        "n_trial": n_trial,
        "early_stopping": 600,
        "use_transfer_learning": True,
        "measure_option": tvm.autotvm.measure_option(
            builder=tvm.autotvm.LocalBuilder(timeout=10),
            runner=runner,
        ),
    }




## tune kernels

In [14]:
# tune_kernels
def tune_kernels(
    tasks,
    measure_option,
    tuner,
    n_trial,
    early_stopping,
    log_filename,
    use_transfer_learning,
):
    """Tune every task, walking the task list from last to first.

    Parameters
    ----------
    tasks
        Sequence of tuning tasks; each must expose ``config_space``.
    measure_option
        Forwarded unchanged to each tuner's ``tune`` call.
    tuner : str
        One of ``"random"``, ``"xgb"``, ``"xgb-rank"``, ``"ga"`` or
        ``"gridsearch"``.
    n_trial : int
        Upper bound on trials per task; capped at the task's config-space size.
    early_stopping
        Forwarded unchanged to each tuner's ``tune`` call.
    log_filename : str
        File that results are logged to, and that history is loaded from.
    use_transfer_learning : bool
        When true and ``log_filename`` exists, the XGB tuner is seeded with
        the records in that file.

    Raises
    ------
    ValueError
        If ``tuner`` is not one of the supported names (and the random-tuner
        shortcut below did not apply).
    """
    for position, task in enumerate(reversed(tasks), start=1):
        prefix = "[Task %2d/%2d] " % (position, len(tasks))
        space_size = len(task.config_space)

        # Pick the search strategy. When the trial budget already covers the
        # whole config space, the random tuner is used whatever was requested.
        if tuner == "random" or n_trial >= space_size:
            tuner_obj = RandomTuner(task)
        elif tuner in ("xgb", "xgb-rank"):
            tuner_obj = XGBTuner(task, loss_type="rank")
            # Pre-train the cost model from earlier results when available.
            if use_transfer_learning and os.path.isfile(log_filename):
                history = tvm.autotvm.record.load_from_file(log_filename)
                tuner_obj.load_history(history)
        elif tuner == "ga":
            tuner_obj = GATuner(task, pop_size=100)
        elif tuner == "gridsearch":
            tuner_obj = GridSearchTuner(task)
        else:
            raise ValueError("Invalid tuner: " + tuner)

        # Run the search, reporting progress and appending results to the log.
        trial_budget = min(n_trial, space_size)
        reporters = [
            tvm.autotvm.callback.progress_bar(trial_budget, prefix=prefix),
            tvm.autotvm.callback.log_to_file(log_filename),
        ]
        tuner_obj.tune(
            n_trial=trial_budget,
            early_stopping=early_stopping,
            measure_option=measure_option,
            callbacks=reporters,
        )



## tune graph

In [15]:
def tune_graph(graph, input_name, input_shape, target, kernel_log, graph_log, use_DP=True):
    """Run graph-level tuning for ``nn.conv2d`` and write the result to ``graph_log``.

    Parameters
    ----------
    graph
        The Relay function to tune.
    input_name, input_shape
        Combined into the single-entry shape dictionary handed to the tuner.
    target
        Compilation target handed to the tuner.
    kernel_log : str
        Kernel tuning records the graph tuner reads.
    graph_log : str
        Destination file for the optimal schedule records.
    use_DP : bool, optional
        Use ``DPTuner`` when true (the default), ``PBQPTuner`` otherwise.
    """
    if use_DP:
        tuner_cls = DPTuner
    else:
        tuner_cls = PBQPTuner

    # Only conv2d is considered at graph level.
    conv_ops = [relay.op.get("nn.conv2d")]
    input_shapes = {input_name: input_shape}

    graph_tuner = tuner_cls(graph, input_shapes, kernel_log, conv_ops, target)
    graph_tuner.benchmark_layout_transform(min_exec_num=2000)
    graph_tuner.run()
    graph_tuner.write_opt_sch2record_file(graph_log)

## autotvm

In [22]:
def autotvm(network,model_onnx,batch_size,ctx,target,device):
    """Tune one network: kernel tuning first, then graph tuning.

    NOTE: this function's name rebinds the bare ``autotvm`` module alias
    imported at the top of the notebook, so the AutoTVM package is reached
    through ``tvm.autotvm`` here.

    Parameters
    ----------
    network : str
        Network name, used to build the log file names.
    model_onnx
        The loaded ONNX model to tune.
    batch_size : int
        Batch size used when converting the model and building the options.
    ctx
        Accepted for interface compatibility; not used by this function.
    target
        Compilation target forwarded to task extraction and both tuners.
    device : str
        ``'llvm'`` selects the ``cpu`` log directory; anything else selects
        the ``gpu`` log directory.
    """
    hardware = "cpu" if device == 'llvm' else "gpu"
    log_dir = "logs/autotvm/" + hardware + "/"
    # The log writers do not create missing directories, so do it up front.
    os.makedirs(log_dir, exist_ok=True)
    kernel_log_file = log_dir + network + "_kernel_" + '.log'
    graph_log_file = log_dir + network + "_graph_" + '.log'

    # Convert the model that was passed in (not a notebook-level global) and
    # honour the requested batch size.
    mod, params, input_name, input_shape, _ = get_network(model_onnx, batch_size)

    # Build the AutoTVM settings (tuning_option)
    tuning_opt = get_tuning_option(batch_size, target, kernel_log_file)

    # Operators that get kernel-level tuning tasks
    ops = [
        relay.op.get("nn.batch_matmul"),
        relay.op.get("nn.dense"),
        relay.op.get("nn.conv2d"),
    ]
    # Split the IRModule's main function into task objects
    tasks = tvm.autotvm.task.extract_from_program(
        mod["main"], target=target, params=params, ops=ops
    )

    # Run the kernel tuning tasks
    tune_kernels(tasks, **tuning_opt)

    # Run graph tuning on top of the kernel records
    tune_graph(mod["main"], input_name, input_shape, target, kernel_log_file, graph_log_file)
    

In [23]:
# Shared parameters read by the tuning and benchmark cells.
# Compilation target built from the "llvm" target string.
target = tvm.target.Target("llvm")
# Device handle for CPU 0; passed to the tuning and benchmark functions as `ctx`.
ctx = tvm.cpu(0)
# Element type name; not referenced by any code in the visible cells.
dtype = "float32"

In [None]:
if __name__ == "__main__":
    device = 'llvm'
    model_dir = "./models/" + "onnx/"
    # Tune each network in turn; every iteration loads the model from disk
    # and hands it to the tuning entry point with batch size 1.
    for model_name in ("vgg16", "resnet50", "mobilenet_v2"):
        print(model_name)
        # Load the ONNX model
        model_path = model_dir + model_name + ".onnx"
        onnx_model = onnx.load(model_path)
        autotvm(model_name, onnx_model, 1, ctx, target, device)


## benchmark

In [None]:
def benchmark_autotvm(onnx_model,batch_size,log_file,ctx,repeat):
    """Build the model with the tuned records applied and time its ``run`` function.

    Parameters
    ----------
    onnx_model
        The loaded ONNX model to benchmark.
    batch_size : int
        Batch size forwarded to ``get_network``.
    log_file : str
        Tuning log handed to ``tvm.autotvm.apply_history_best``.
    ctx
        Device the graph module is created on and timed on.
    repeat : int
        Forwarded to ``time_evaluator`` as ``repeat``.

    Returns
    -------
    numpy.ndarray
        The ``results`` of the timing call, wrapped in an array.

    Notes
    -----
    The compilation target is read from the notebook-level ``target``
    variable, which must be defined before this function is called.
    """
    with tvm.autotvm.apply_history_best(log_file):
        with tvm.transform.PassContext(opt_level=3):
            mod, params, input_name, input_shape, _ = get_network(onnx_model, batch_size)
            lib = relay.build(mod, target=target, params=params)
        graph_factory = lib["default"]
        module = runtime.GraphModule(graph_factory(ctx))
        # Feed uniformly distributed random input data
        sample = np.random.uniform(size=input_shape)
        module.set_input(input_name, sample)
    # Evaluate
    timer = module.module.time_evaluator("run", ctx, min_repeat_ms=500, repeat=repeat)
    measurement = timer()
    return np.array(measurement.results)

    
    

In [None]:
if __name__ == "__main__":
    device = 'llvm'
    result_messages = []
    # Benchmark each tuned network and collect one formatted result line per model.
    for model_name in ["vgg16","resnet50","mobilenet_v2"]:
        print(model_name)
        # Load the ONNX model
        model_path = "./models/" + "onnx/" + model_name +".onnx"
        onnx_model = onnx.load(model_path)

        # Pick the kernel tuning log that matches the device
        if device =='llvm':
            # cpu
            kernel_log_file = log_file = "logs/autotvm/cpu/" + model_name + "_kernel_" + '.log'
        else:
            # gpu
            kernel_log_file = log_file = "logs/autotvm/gpu/" + model_name + "_kernel_" + '.log'

        res = benchmark_autotvm(onnx_model,1,kernel_log_file,ctx,3)

        # convert to millisecond
        res *= 1000
        # The statistics must be computed on `res`; the previously referenced
        # `res_kernel` was never defined and raised NameError.
        message = "%-18s %-12s %-19s (%s)" % (
                model_name,
                1,
                "%.2f ms" % np.mean(res),
                "%.2f ms" % np.std(res),
                )
        result_messages.append(message)

    # Print result
    print("cpu")
    print("-------------------------------------------------------------")
    print(
        "%-18s %-12s %-20s"
        % ("Network Name", "Batch size", "Mean Inference Time (std dev)")
    )
    print("-------------------------------------------------------------")
    for line in result_messages:
        print(line)
    print("-------------------------------------------------------------")