In [1]:
import warnings
warnings.filterwarnings("ignore")

import sys
sys.path.append('multiclass')
sys.path.append('salientnet')
sys.path.append('sceneRecog')

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from sklearn.model_selection import train_test_split
import numpy as np
from typing import List
from copy import deepcopy
import torchvision.models as models

import multiclass.model_utils as multi_model
import salientnet.model as salient_model
import multiclass.dataloader as multi_dl
import salientnet.dataloader as salient_dl
import sceneRecog.dataloader as recog_dl
from test_func import test_top1, test_top5, test_multiclass, test_salient, test_multi_result

from metamorph.compiler.compiler import MetaMorph
from metamorph.graph.abs_graph import Graph
from metamorph.graph.cmp_graph import ComputeGraph
from metamorph.metrics.testing_utils import test_accuracy, test_latency
from metamorph.compiler.policy import SimulatedAnnealingPolicy
from metamorph.data.dataloader import DatasetSampler

DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
SAMPLE_INPUT = torch.rand(1,3,224,224).to(DEVICE)
kwargs = {'num_workers': 4, 'pin_memory': True}

In [2]:
if_print_ori_model = False

if_test_scene = False
if_test_ori_accuracy = False
if_test_ori_latency = False

test_random_connet = False

In [3]:
# scene recognition + places365
rec_loader = recog_dl.load_data(kwargs=kwargs)
rec_train_indices, rec_test_indices = train_test_split(list(range(len(rec_loader.dataset.targets))), 
                                                        test_size=0.2, stratify=rec_loader.dataset.targets, random_state=10)
rec_train_data_sampler = torch.utils.data.Subset(rec_loader.dataset, rec_train_indices)
rec_test_data_sampler = torch.utils.data.Subset(rec_loader.dataset, rec_test_indices)
rec_test_loader = torch.utils.data.DataLoader(rec_test_data_sampler, batch_size=128, shuffle=False, **kwargs)
print(len(rec_train_data_sampler),len(rec_test_data_sampler))

recNet = models.__dict__['resnet18'](num_classes=365)
recNet.load_state_dict(torch.load('sceneRecog/sceneNet.model', map_location=DEVICE))
recNet = recNet.to(DEVICE).eval()

if if_print_ori_model:
    print(recNet)

if if_test_scene:
    acc_top1 = test_top1(recNet, rec_test_loader, DEVICE)
    acc_top5 = test_top5(recNet, rec_test_loader, DEVICE)

29200 7300


In [4]:
# multi-label classification + VOC2007
multi_train_loader, multi_test_loader = multi_dl.load_data(kwargs=kwargs)
print(len(multi_train_loader.dataset), len(multi_test_loader.dataset))

multiNet = multi_model.get_resnet34_model_with_custom_head()
multiNet.load_state_dict(torch.load('multiclass/objectNet.model', map_location=DEVICE))
multiNet = multiNet.to(DEVICE).eval()

if if_print_ori_model:
    print(multiNet)

if if_test_scene:
    multi_mAP = test_multiclass(multiNet, multi_test_loader, DEVICE)

5011 4952


In [5]:
# salient-object-Subitizing + SOS dataset
salient_train_loader, salient_test_loader = salient_dl.load_data(kwargs=kwargs)
print(len(salient_train_loader.dataset), len(salient_test_loader.dataset))

salientNet = salient_model.get_resnet18_model_with_custom_head()
salientNet.load_state_dict(torch.load('salientnet/salientNet.model', map_location=DEVICE))
salientNet = salientNet.to(DEVICE).eval()

if if_print_ori_model:
    print(salientNet)

if if_test_scene:
    salient_mAP = test_salient(salientNet, salient_test_loader, DEVICE)

10966 2741


In [6]:
def parse_model(model: nn.Module) -> List[nn.Module]:
    res = []
    for layer in model.children():
        if type(layer) in MetaMorph.BASIC_OPS:
            res.append(layer)
        elif isinstance(layer, nn.Sequential):
            res.extend(parse_model(layer))
        else:
            res.append(layer)
    return res

parse_models = [parse_model(recNet), parse_model(multiNet), parse_model(salientNet)]
absGraph = Graph(SAMPLE_INPUT, parse_models, DEVICE)
cmpGraph = ComputeGraph(absGraph, parse_models, DEVICE)

if if_print_ori_model:
    print(absGraph)

if if_test_ori_latency: 
    ori_latency = test_latency(cmpGraph, SAMPLE_INPUT)

if test_random_connet:
    n_trial = 30
    graph1 = deepcopy(absGraph)
    graph1.random_connect(n_trial=n_trial, verbose=True)
    graph1.build_mergeable_nodes()
    graph1.random_connect(n_trial=n_trial, verbose=True)
    graph1.build_mergeable_nodes()
    graph1.random_connect(n_trial=n_trial, verbose=True)
    graph1.build_mergeable_nodes()
    graph1.random_connect(n_trial=n_trial, verbose=True)
    graph1.build_mergeable_nodes()
    graph1.random_connect(n_trial=n_trial, verbose=True)
    graph1.build_mergeable_nodes()
    graph1.random_connect(n_trial=n_trial, verbose=True)
    print(graph1)
    cmpGraph = ComputeGraph(graph1, parse_models, DEVICE)
    cmpGraph.freeze_all_node()
    latency1 = test_latency(cmpGraph, SAMPLE_INPUT)

Encountered an error. Try to insert a nn.Flatten layer ... Success!


In [8]:
# list of models
MODELS = [recNet, multiNet, salientNet]

multi_train_data_sampler = multi_dl.load_data_sampler()
salient_train_data_sampler = salient_dl.load_data_sampler()

# dataloader
ds_samples = DatasetSampler(
        [rec_train_data_sampler, multi_train_data_sampler, salient_train_data_sampler],
        MODELS,
        DEVICE,
        [10000, 5000, 5000]
    )
samples_dataloader = torch.utils.data.DataLoader(ds_samples, batch_size=128, shuffle=True, **kwargs)
print(len(samples_dataloader.dataset))
test_loader_list = [rec_test_loader, multi_test_loader, salient_test_loader]

if if_test_ori_accuracy:
    print("Task Accuracy of original graph: ")
    test_accuracy(cmpGraph, test_multi_result, test_loader_list, DEVICE)

20
Task Accuracy of original graph: 
net1 Result: 53.479450941085815%   net2 Result: 88.34654146141654%   net3 Result: 70.10493360305297%   



In [None]:
import time

# compiler settings
optimizer = torch.optim.Adam
compiler = MetaMorph(
    models=MODELS, optimizer=optimizer, optimizer_lr=0.0001,
    input_size=SAMPLE_INPUT.shape, train_loader=samples_dataloader, test_loader=test_loader_list,
    f_accuracy=test_multi_result, fine_tune_epochs=30, max_epoch=1, device=DEVICE
)
policy = SimulatedAnnealingPolicy(
    base_graph=compiler.original_graph,
    models=compiler.models,
    f_finetune=compiler.fine_tune, f_latency=compiler.f_latency, f_accuracy=compiler.f_accuracy,
    accuracy_tolerence = 0.02,
    device=compiler.device
)

# test the compiling time
torch.cuda.synchronize()
start = time.time()
best_result = compiler.optimize(policy)
torch.cuda.synchronize()
end = time.time()

print('---------------------------- Evaluation ---------------------------------')
print("Optimal Graph: \n", best_result.graph)
print("Optimal Latency: ", best_result.latency)
print("Compiling time: ", end - start)

cmpGraph_opt = best_result.cmp_graph
print("Task Accuracy of optimized graph: ")
test_accuracy(cmpGraph_opt, test_multi_result, test_loader_list, DEVICE)

In [7]:
if_tvm_tune = False
use_tvm_tuned = False
log_file = "tvm.log"

if_test_tvm_accuracy = True
if_test_tvm_latency = True

In [11]:
# test TVM
import tvm, time
from tvm import relay, autotvm
from tvm.contrib import graph_executor
from tvm_build import get_network, tune_tasks, tune_and_evaluate
torch.manual_seed(0)

traced_module = torch.jit.trace(cmpGraph, SAMPLE_INPUT).eval()
# print(traced_module)

input_shape = (1, 3, 224, 224)
input_data = torch.randn(input_shape)
input_name = 'input0'
shape_list = [(input_name, input_shape)]

mod, params = get_network(traced_module, shape_list)
# print(mod)

# # running TVM to compile model
if if_tvm_tune:
    tune_and_evaluate(mod, params, input_shape, log_name=log_file)

target = tvm.target.cuda()
if use_tvm_tuned:
    with autotvm.apply_history_best(log_file):
        with tvm.transform.PassContext(opt_level=3):
            lib = lib = relay.build(mod, target=target, params=params)
else:
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target=target, params=params)
dev = tvm.device(str(target), 0)
tvm_model = graph_executor.GraphModule(lib["default"](dev))  

In [12]:
# run TVM model
tvm_model.set_input(input_name, input_data)

tvm_model.run()

if if_test_tvm_accuracy:
    tvm_out_0, tvm_out_1, tvm_out_2 = tvm_model.get_output(0), tvm_model.get_output(1), tvm_model.get_output(2)
    tvm_out_0 = torch.tensor(tvm_out_0.numpy()).to(DEVICE)
    tvm_out_1 = torch.tensor(tvm_out_1.numpy()).to(DEVICE)
    tvm_out_2 = torch.tensor(tvm_out_2.numpy()).to(DEVICE)
    ori_out_0, ori_out_1, ori_out_2 = cmpGraph(input_data.to(DEVICE))
    print(tvm_out_2)
    print(ori_out_2)

if if_test_tvm_latency:
    import timeit
    timing_number = 30
    timing_repeat = 30
    optimized = (
            np.array(timeit.Timer(lambda:tvm_model.run()).repeat(repeat=timing_repeat, number=timing_number))
            * 1000
            / timing_number
        )
    optimized = {"mean": np.mean(optimized), "median": np.median(optimized), "std":np.std(optimized)}

    print("optimized: %s" % (optimized))


tensor([[-1.8443, -6.8599,  2.1578,  9.5740, 12.9538]], device='cuda:0')
tensor([[-1.8459, -6.8583,  2.1531,  9.5788, 12.9610]], device='cuda:0')
optimized: {'mean': 0.7961396691906784, 'median': 0.7933063975845774, 'std': 0.06722413711788505}
