Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,31 @@ def test_timer(self):
median = timer.blocked_autorange(min_run_time=0.1).median
self.assertIsInstance(median, float)

def test_adaptive_timer(self):
# Validate both on different sizes validate against blocked_autorange
# This looks for relative differences btetween orders of magnitude to
# provide a stable/portable test which is somewhat informative.
timer = benchmark_utils.Timer(
stmt="torch.sum(torch.ones((10,10)))",
)
small = timer.adaptive_autorange(min_run_time=0.1)
self.assertFalse(small.has_warnings())
timer = benchmark_utils.Timer(
stmt="torch.sum(torch.ones((500,100)))",
)
medium = timer.adaptive_autorange(min_run_time=0.1)
self.assertFalse(medium.has_warnings())
blocked_medium = timer.blocked_autorange(min_run_time=0.1)
self.assertLess(small, medium)
self.assertLess(small, blocked_medium)
timer = benchmark_utils.Timer(
stmt="torch.sum(torch.ones((1000,1000)))",
)
large = timer.adaptive_autorange(min_run_time=0.1).median
self.assertFalse(large.has_warnings())
self.assertLess(medium, large)
self.assertLess(blocked_medium, large)

def test_compare(self):
compare = benchmark_utils.Compare([
benchmark_utils.Timer(
Expand Down
5 changes: 4 additions & 1 deletion torch/utils/_benchmark/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ def __getstate__(self):
def __setstate__(self, state: Dict[str, Any]):
self.__init__(**state) # type: ignore

def meets_confidence(self, threshold=_IQR_WARN_THRESHOLD):
return self._iqr / self._median < threshold

def _populate_warnings(self):
warnings, rel_iqr = [], self._iqr / self._median * 100

Expand All @@ -87,7 +90,7 @@ def add_warning(msg):

if self._iqr / self._median > _IQR_GROSS_WARN_THRESHOLD:
add_warning("This suggests significant environmental influence.")
elif self._iqr / self._median > _IQR_WARN_THRESHOLD:
elif not self.meets_confidence():
add_warning("This could indicate system fluctuation.")
return warnings

Expand Down
52 changes: 42 additions & 10 deletions torch/utils/_benchmark/utils/timer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Timer class based on the timeit.Timer class, but torch aware."""

import time
import timeit
from typing import List, Optional

Expand Down Expand Up @@ -75,7 +76,37 @@ def repeat(self, repeat=-1, number=-1):
def autorange(self, callback=None):
raise NotImplementedError("See `Timer.blocked_autorange.`")

def blocked_autorange(self, callback=None, min_run_time=0.2):
def _threaded_measurement_loop(self, time_hook, stop_hook, min_run_time: float, max_run_time: float, callback=None):
total_time = 0.0
can_stop = False
times = []
with common.set_torch_threads(self._num_threads):
while (total_time < min_run_time) or (not can_stop):
time_spent = time_hook()
times.append(time_spent)
total_time += time_spent
can_stop = stop_hook(times)
if callback:
callback(number, time_taken)
if max_run_time and total_time > max_run_time:
break
return times

def adaptive_autorange(self, threshold=0.1, max_run_time=10, callback=None, min_run_time=0.01):
number = self._estimate_block_size(min_run_time=0.05)

def time_hook():
return self._timer.timeit(number)
def stop_hook(times):
if len(times) > 3:
measure = self._construct_measurement(number, times)
return measure.meets_confidence(threshold=threshold)
return False
times = self._threaded_measurement_loop(time_hook, stop_hook, min_run_time, max_run_time, callback=callback)
measure = self._construct_measurement(number, times)
return measure

def _estimate_block_size(self, min_run_time):
with common.set_torch_threads(self._num_threads):
# Estimate the block size needed for measurement to be negligible
# compared to the inner loop. This also serves as a warmup.
Expand All @@ -89,15 +120,16 @@ def blocked_autorange(self, callback=None, min_run_time=0.2):
if time_taken > min_run_time:
break
number *= 10
return number

total_time = 0.0
times = []
def blocked_autorange(self, callback=None, min_run_time=0.2):
number = self._estimate_block_size(min_run_time)
def time_hook():
return self._timer.timeit(number)
def stop_hook():
return True

while total_time < min_run_time:
time_taken = self._timer.timeit(number)
total_time += time_taken
if callback:
callback(number, time_taken)
times.append(time_taken)
times = self._threaded_measurement_loop(loop, min_run_time=min_run_time, max_run_time=None, callback=callback)

return self._construct_measurement(number_per_run=number, times=times)

return self._construct_measurement(number_per_run=number, times=times)