In [None]:
import torch
import torchvision
from torch.utils.bundled_inputs import (
    augment_model_with_bundled_inputs)
from torch.utils.mobile_optimizer import optimize_for_mobile

import json
from device_profiling import check_device, run_on_device, parse_profiler_output, DEFAULT_PROF_CONFIG

## Build a ResNet-18 model and export it as mobile-optimized TorchScript

In [None]:
# Load the pretrained ResNet-18 and put it in inference mode.
# `.eval()` returns the module itself, so `model` is the same object either way.
# NOTE(review): `pretrained=True` is deprecated in torchvision >= 0.13 in favor of
# `weights=...`; it is kept here so the notebook still runs on older torchvision.
model = torchvision.models.resnet18(pretrained=True).eval()

# All-zeros dummy batch of one 3x224x224 image. It is used both as the tracing
# input and as the input bundled into the saved model.
example = torch.zeros((1, 3, 224, 224))

# Trace to TorchScript, then apply the mobile-specific graph optimizations.
script_module = torch.jit.trace(model, example)
script_module_optimized = optimize_for_mobile(script_module)

# Attach `example` as a bundled input. The call's return value is not used;
# the module is modified in place before being saved below.
augment_model_with_bundled_inputs(script_module_optimized, [(example,)])

torch.jit.save(script_module_optimized, "./resnet18.pt")

## Configure the profiler and run it on the device

In [None]:
# Thread cap passed to the on-device thread pool.
# NOTE(review): this name was previously never defined in the notebook, so this
# cell raised NameError on a fresh kernel. 1 is a placeholder — TODO set the
# thread count you actually want to benchmark.
num_threads = 1

# Build a NEW config dict from the library defaults plus our overrides.
# The previous code did `profiling_config = DEFAULT_PROF_CONFIG` and then
# assigned keys, which mutated the shared imported default object in place.
profiling_config = {
    **DEFAULT_PROF_CONFIG,
    'vulkan': False,
    'caffe2_threadpool_android_cap': num_threads,
    'caffe2_threadpool_force_inline': True,
    'iter': 100,
    'warmup': 30,
    # presumably the index of the bundled input attached when the model
    # was exported (only one, at index 0) — confirm against device_profiling
    'use_bundled_input': 0,
}
model_filename = './resnet18.pt'

# Run the profiler on the connected device and parse its raw text output.
raw_out = run_on_device(
    model_filename,
    prof_config=profiling_config,
    verbose=True)

res = parse_profiler_output(raw_out, is_file=False)

In [None]:
# Inspect the parsed profiler result as indented JSON.
res_json = json.dumps(res, indent=2)
print(res_json)