17 changes: 17 additions & 0 deletions benchmarks/framework_overhead_benchmark/SimpleAddModule.py
@@ -0,0 +1,17 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import torch
from utils import NUM_PT_LOOP_ITERS

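# Chains NUM_PT_LOOP_ITERS dependent torch.add calls so that a single timed
# forward pass covers many op dispatches; benchmark_module in utils.py divides
# the measured time by NUM_PT_LOOP_ITERS (and the outer iteration count) to
# report per-op latency.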
def add_tensors_loop(x, y):
z = torch.add(x, y)
for i in range(NUM_PT_LOOP_ITERS):
z = torch.add(z, x)
return z

class SimpleAddModule(torch.nn.Module):
def __init__(self, add_op):
super(SimpleAddModule, self).__init__()
self.add_op = add_op

def forward(self, x, y):
return self.add_op(x, y)
80 changes: 80 additions & 0 deletions benchmarks/framework_overhead_benchmark/framework_overhead_benchmark.py
@@ -0,0 +1,80 @@
from __future__ import absolute_import, division, print_function, unicode_literals
from utils import ms_to_us, benchmark_module, BenchmarkConfig, ModuleConfig
import argparse

from SimpleAddModule import SimpleAddModule, add_tensors_loop
from pt_wrapper_module import WrapperModule

""" Framework overhead benchmark script.
Benchmark framework overhead.
Currently supported ops: add.
As of now runs only forward pass.
Supports both graph mode and eager mode. In graph mode the module is traced via JIT tracing.
Debug option prints the traced graph is graph_mode is enabled.
Graph can be saved via save option. Saved in the directory where benchmark is run.
Example build/run:
buck run @mode/opt <path-to-framework_overhead_benchmark>:framework_overhead_benchmark --
--add_op --graph_mode --eager_mode (Runs both graph mode and eager mode)
buck run @mode/opt <path-to-framework_overhead_benchmark>:framework_overhead_benchmark --
--add_op --graph_mode (Runs only graph mode)
"""

SUPPORTED_OPS = {"add_op"}

def parse_op_args(op):
# Splits a comma-separated op string into a list of op names.
return op.split(",")

def print_results(result):
print("===================================")
for key, value in result.items():
print("{}, latency per iter (us):{}".format(key, ms_to_us(value)))
print("===================================")

def benchmark_simple_fn(args, config, module_config, module_type, result):
""" Benchmarks a PyTorch traceable function specified in the config.
Instantiates a wrapper object that wraps an instance of module_type and runs its
forward method using benchmark_module.
Args:
config: contains number of warmup and benchmark iterations.
module_config: contains the op, the number of parameters the op takes,
and whether graph mode is enabled or not.
module_type: Type of the module to be wrapped. e.g. SimpleAddModule for add op.
result: dictionary instance to be populated with the benchmark result (latency per iter).
"""
print("Benchmarking {}".format(module_type.__name__))
f_name = module_config.pt_fn.__name__ + ":Num Operands=" + str(module_config.num_params)
graph_mode_str = "Graph mode" + ":" + str(module_config.graph_mode)
result_key = ','.join((f_name, graph_mode_str))
module = WrapperModule(module_type, module_config, args.debug, args.save)
latency_per_iter_ms = benchmark_module(config, module)
result[result_key] = latency_per_iter_ms

def main():
parser = argparse.ArgumentParser()
parser.add_argument("--op", default="add_op", dest="op", type=str)
parser.add_argument("--debug", default=False, dest="debug", action="store_true")
parser.add_argument("--save", default=False, dest="save", action="store_true")
parser.add_argument("--eager_mode", default=False, dest="eager_mode", action="store_true")
parser.add_argument("--num_warmup_iters", type=int, default=100)
parser.add_argument("--num_iters", type=int, default=1000)
args = parser.parse_args()

if args.op not in SUPPORTED_OPS:
print("Op {} is not supported: Supported ops are:{}".format(args.op, SUPPORTED_OPS))
return

num_warmup_iters = args.num_warmup_iters
num_iters = args.num_iters
config = BenchmarkConfig(num_warmup_iters, num_iters)
graph_mode = True
if args.eager_mode:
graph_mode = False
result = {}
if args.op == "add_op":
num_params = 2
module_config = ModuleConfig(add_tensors_loop, num_params, graph_mode)
benchmark_simple_fn(args, config, module_config, SimpleAddModule, result)
print_results(result)

if __name__ == "__main__":
main()
44 changes: 44 additions & 0 deletions benchmarks/framework_overhead_benchmark/pt_wrapper_module.py
@@ -0,0 +1,44 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import torch

class WrapperModule(object):
""" Wraps the instance of wrapped_type.
For graph_mode traces the instance of wrapped_type.
Randomaly initializes num_params tensors with single float element.
Args:
wrapped_type:
- Object type to be wrapped.
Expects the wrapped_type to:
- be constructed with pt_fn specified in module_config.
- provide forward method that takes module_config.num_params args.
module_config:
- Specifies the pt_fn to construct wrapped_type with, whether graph_mode
is enabled, and number of parameters wrapped_type's forward method
takes.
debug:
- Whether debug mode is enabled.
save:
- In graph mode, whether graph is to be saved.
"""
def __init__(self, wrapped_type, module_config, debug, save=False):
pt_fn = module_config.pt_fn
self.module = wrapped_type(pt_fn)
self.tensor_inputs = []
self.module_name = wrapped_type.__name__
for _ in range(module_config.num_params):
self.tensor_inputs.append(torch.randn(1))
if module_config.graph_mode:
self.module = torch.jit.trace(self.module, self.tensor_inputs)
if save:
file_name = self.module_name + "_" + pt_fn.__name__ + ".pt"
torch.jit.save(self.module, file_name)
print("Generated graph is saved in {}".format(file_name))
print("Benchmarking module {} with fn {}: Graph mode:{}".format(self.module_name, pt_fn.__name__, module_config.graph_mode))
if debug and isinstance(self.module, torch.jit.ScriptModule):
print(self.module.graph)
print(self.module.code)

def forward(self, niters):
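# Runs under no_grad so autograd bookkeeping does not contribute to the
# measured framework overhead.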
with torch.no_grad():
for _ in range(niters):
self.module.forward(*self.tensor_inputs)
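# A minimal usage sketch (illustrative only; imports assume the sibling modules
# of this benchmark):
#
#   from utils import ModuleConfig
#   from SimpleAddModule import SimpleAddModule, add_tensors_loop
#   cfg = ModuleConfig(add_tensors_loop, 2, True)        # two operands, graph mode
#   wrapper = WrapperModule(SimpleAddModule, cfg, debug=False)
#   wrapper.forward(10)                                  # 10 timed-loop iterations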
25 changes: 25 additions & 0 deletions benchmarks/framework_overhead_benchmark/utils.py
@@ -0,0 +1,25 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import time
from collections import namedtuple

NUM_PT_LOOP_ITERS = 1000
BenchmarkConfig = namedtuple('BenchmarkConfig', 'num_warmup_iters num_iters')
ModuleConfig = namedtuple('ModuleConfig', 'pt_fn num_params graph_mode')

def ms_to_us(time_ms):
return (time_ms * 1e3)

def secs_to_us(time_s):
return (time_s * 1e6)

def secs_to_ms(time_s):
return (time_s * 1e3)

def benchmark_module(config, module):
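# Untimed warmup iterations let JIT optimization passes and allocations settle
# before measurement starts.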
module.forward(config.num_warmup_iters)
print("Running module for {} iterations".format(config.num_iters))
start = time.time()
module.forward(config.num_iters)
end = time.time()
time_elapsed_s = (end - start)
return (secs_to_ms(time_elapsed_s) / config.num_iters / NUM_PT_LOOP_ITERS)
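# Worked example of the returned value (illustrative numbers): with
# num_iters=1000 and NUM_PT_LOOP_ITERS=1000, an elapsed time of 2 s is 2000 ms,
# so benchmark_module returns 2000 / 1000 / 1000 = 0.002 ms per inner op, which
# the caller reports as 2 us via ms_to_us.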