Skip to content

Commit

Permalink
Profiler benchmark fix (#47713)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #47713

Fix the import and also always use internal Timer

Test Plan: python benchmarks/profiler_benchmark/profiler_bench.py

Reviewed By: dzhulgakov

Differential Revision: D24873991

Pulled By: ilia-cher

fbshipit-source-id: 1c3950d7d289a4fb5bd7043ba2d842a35c263eaa
  • Loading branch information
Ilia Cherniavskii authored and facebook-github-bot committed Nov 13, 2020
1 parent 1afdcbf commit a97c7e2
Showing 1 changed file with 9 additions and 22 deletions.
31 changes: 9 additions & 22 deletions benchmarks/profiler_benchmark/profiler_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import timeit
import torch

from torch.utils._benchmark import Timer
from torch.utils.benchmark import Timer

PARALLEL_TASKS_NUM = 4
INTERNAL_ITER = None
Expand Down Expand Up @@ -37,8 +37,6 @@ def parallel_task(x):
parser.add_argument('--profiling_tensor_size', default=1, type=int)
parser.add_argument('--workload', default='loop', type=str)
parser.add_argument('--internal_iter', default=256, type=int)
parser.add_argument('--n', default=100, type=int)
parser.add_argument('--use_timer', action='store_true')
parser.add_argument('--timer_min_run_time', default=100, type=int)

args = parser.parse_args()
Expand All @@ -47,8 +45,8 @@ def parallel_task(x):
print("No CUDA available")
sys.exit()

print("Payload: {}; {} iterations, N = {}\n".format(
args.workload, args.internal_iter, args.n))
print("Payload: {}, {} iterations; timer min. runtime = {}\n".format(
args.workload, args.internal_iter, args.timer_min_run_time))
INTERNAL_ITER = args.internal_iter

for profiling_enabled in [False, True]:
Expand Down Expand Up @@ -90,20 +88,9 @@ def payload():
def payload():
return workload(input_x)

if args.use_timer:
t = Timer(
"payload()",
globals={"payload": payload},
timer=timeit.default_timer,
).blocked_autorange(min_run_time=args.timer_min_run_time)
print(t)
else:
runtimes = timeit.repeat(payload, repeat=args.n, number=1)
avg_time = statistics.mean(runtimes) * 1000.0
stddev_time = statistics.stdev(runtimes) * 1000.0
print("\tavg. time: {:.3f} ms, stddev: {:.3f} ms".format(
avg_time, stddev_time))
if args.workload == "loop":
print("\ttime per iteration: {:.3f} ms".format(
avg_time / args.internal_iter))
print()
t = Timer(
"payload()",
globals={"payload": payload},
timer=timeit.default_timer,
).blocked_autorange(min_run_time=args.timer_min_run_time)
print(t)

0 comments on commit a97c7e2

Please sign in to comment.