Skip to content

Commit

Permalink
[AutoTVM] Re-enable ref_input (apache#8113)
Browse files Browse the repository at this point in the history
* [AutoTVM] Re-enable ref_input

* add ref_input on measure_option

* add ref_input unittest

* fix: test reformat

* [autotvm] [ref-input] refine test and description

* [autotvm] [ref-input] revert arg on measure_option
  • Loading branch information
Tantalus13A98B5F authored and ylc committed Jan 13, 2022
1 parent cbdf253 commit d92bdb6
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 12 deletions.
53 changes: 41 additions & 12 deletions python/tvm/autotvm/measure/measure_methods.py
Expand Up @@ -32,6 +32,7 @@
import typing
from collections import namedtuple
from random import getrandbits
import warnings

import tvm._ffi
import tvm.ir.transform
Expand Down Expand Up @@ -235,13 +236,33 @@ def __init__(
self.number = number
self.repeat = repeat
self.min_repeat_ms = min_repeat_ms
self._ref_input = None

self.enable_cpu_cache_flush = enable_cpu_cache_flush
self.cooldown_interval = cooldown_interval
self.module_loader = module_loader

self.executor = LocalExecutor(timeout=timeout * (self.n_parallel + 1))

@property
def ref_input(self):
"""
Fixed input for tuning special operators, e.g., sparse operators
requiring indices as input.
"""
return self._ref_input

@ref_input.setter
def ref_input(self, val):
warnings.warn(
"You are specifying fixed input for tuning the operator. "
"Be sure your input always fits the operator. Some "
"operators may conduct layout transformation during tuning, "
"thus can lead to unexpected behaviors. ",
RuntimeWarning,
)
self._ref_input = val

def set_task(self, task):
self.task = task

Expand Down Expand Up @@ -308,6 +329,7 @@ def run(self, measure_inputs, build_results):
self.min_repeat_ms,
self.cooldown_interval,
remote_kwargs,
self.ref_input,
self.enable_cpu_cache_flush,
module_loader,
)
Expand Down Expand Up @@ -508,6 +530,7 @@ def run_through_rpc(
min_repeat_ms,
cooldown_interval,
remote_kwargs,
ref_input,
enable_cpu_cache_flush=False,
module_loader=None,
):
Expand Down Expand Up @@ -539,6 +562,8 @@ def run_through_rpc(
The cool down interval between two measurements
remote_kwargs: dict
Passed to module_loader(). Ultimately, keyword args to request_remote().
ref_input: List of np.ndarray
The reference input used for tuning. Empty for randomly filled input.
enable_cpu_cache_flush: bool
Whether to flush cache on CPU between repeated measurements.
Flushing cache can make the measured latency of one operator closer to
Expand Down Expand Up @@ -573,18 +598,22 @@ def run_through_rpc(
f_preproc=f_prepare,
)

try:
random_fill = remote.get_function("tvm.contrib.random.random_fill")
except AttributeError:
raise AttributeError(
"Please make sure USE_RANDOM is ON in the config.cmake " "on the remote devices"
)
args = [nd.empty(x[0], x[1], dev) for x in build_result.arg_info]
if "scatter" not in measure_input.task.name:
# the index tensor of scatter op cannot be randomly initialized
for arg in args:
random_fill(arg)
dev.sync()
if ref_input:
args = [nd.array(x, device=dev) for x in ref_input]
else:
try:
random_fill = remote.get_function("tvm.contrib.random.random_fill")
except AttributeError:
raise AttributeError(
"Please make sure USE_RANDOM is ON in the config.cmake "
"on the remote devices"
)
args = [nd.empty(x[0], x[1], dev) for x in build_result.arg_info]
if "scatter" not in measure_input.task.name:
# the index tensor of scatter op cannot be randomly initialized
for arg in args:
random_fill(arg)
dev.sync()

costs = time_f(*args).results

Expand Down
24 changes: 24 additions & 0 deletions tests/python/unittest/test_autotvm_measure.py
Expand Up @@ -26,6 +26,8 @@
from test_autotvm_common import DummyRunner, bad_matmul, get_sample_task
from tvm import autotvm
from tvm.autotvm.measure.measure import MeasureErrorNo, MeasureResult
from tvm.autotvm import measure
from inspect import Signature


def test_task_tuner_without_measurement():
Expand Down Expand Up @@ -60,8 +62,30 @@ def test_task_tuner_without_measurement_spawn():
p.join()


def test_task_runner_with_ref_input():
"""test runner ref_input without measurement"""
refinp = [np.random.rand(128, 128) for i in range(3)]
runner = measure.LocalRunner()
runner.ref_input = refinp

class DummyExecutor(measure.executor.Executor):
def __init__(self):
self.ran_dummy_executor = False

def submit(self, func, *args, **kwargs):
self.ran_dummy_executor = True
sig = Signature.from_callable(func)
assert sig.bind(*args, **kwargs).arguments["ref_input"] == refinp
return measure.local_executor.LocalFutureNoFork(None)

runner.executor = DummyExecutor()
runner.run([None], [None])
assert runner.executor.ran_dummy_executor


if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)

test_task_tuner_without_measurement()
test_task_tuner_without_measurement_spawn()
test_task_runner_with_ref_input()

0 comments on commit d92bdb6

Please sign in to comment.