diff --git a/python/tvm/autotvm/task/dispatcher.py b/python/tvm/autotvm/task/dispatcher.py
index 8cd87b058801..f6ab10064b4a 100644
--- a/python/tvm/autotvm/task/dispatcher.py
+++ b/python/tvm/autotvm/task/dispatcher.py
@@ -41,6 +41,7 @@
 
 logger = logging.getLogger('autotvm')
 
+
 class DispatchContext(object):
     """
     Base class of dispatch context.
@@ -281,8 +282,12 @@ def load(self, records):
                 Each row of this file is an encoded record pair.
             Otherwise, it is an iterator.
         """
+        from pathlib import Path
         from ..record import load_from_file
 
+        if isinstance(records, Path):
+            records = str(records)
+
         if isinstance(records, str):
             records = load_from_file(records)
         if not records:
@@ -404,8 +409,10 @@ def update(self, target, workload, cfg):
         key = (str(target), workload)
         self.memory[key] = cfg
 
+
 DispatchContext.current = FallbackContext()
 
+
 def clear_fallback_cache(target, workload):
     """Clear fallback cache. Pass the same argument as _query_inside to this function
     to clean the cache.
@@ -426,6 +433,7 @@ def clear_fallback_cache(target, workload):
         context = context._old_ctx
     context.clear_cache(target, workload)
 
+
 class ApplyGraphBest(DispatchContext):
     """Load the graph level tuning optimal schedules.
 
diff --git a/python/tvm/autotvm/tuner/callback.py b/python/tvm/autotvm/tuner/callback.py
index a2a2f32bded6..154406b9b2ed 100644
--- a/python/tvm/autotvm/tuner/callback.py
+++ b/python/tvm/autotvm/tuner/callback.py
@@ -26,6 +26,7 @@
 
 logger = logging.getLogger('autotvm')
 
+
 def log_to_file(file_out, protocol='json'):
     """Log the tuning records into file.
     The rows of the log are stored in the format of autotvm.record.encode.
@@ -51,6 +52,11 @@ def _callback(_, inputs, results):
         else:
             for inp, result in zip(inputs, results):
                 file_out.write(record.encode(inp, result, protocol) + "\n")
+
+    from pathlib import Path
+    if isinstance(file_out, Path):
+        file_out = str(file_out)
+
     return _callback
 
diff --git a/python/tvm/module.py b/python/tvm/module.py
index fea5a850d9b3..0177bce66a33 100644
--- a/python/tvm/module.py
+++ b/python/tvm/module.py
@@ -112,18 +112,9 @@ def export_library(self,
         kwargs : dict, optional
            Additional arguments passed to fcompile
        """
-        if self.is_empty():
-            logging.info("The lib generated by the NNVM compiler does not contain optimized "
-                         "functions for any operators. This usually happens when an external "
-                         "accelerator, e.g. TensorRT, is employed along with TVM to compile "
-                         "the model, and all the operators in the model are supported by the "
-                         "external accelerator at runtime. Therefore, "
-                         "the NNVM compiler skipped optimizing them at the compile time.")
-            if os.path.isfile(file_name):
-                logging.warning("Lib file %s exists, and will be overwritten by the newly created"
-                                " lib with the same name.", file_name)
-            open(file_name, 'w').close()
-            return
+        from pathlib import Path
+        if isinstance(file_name, Path):
+            file_name = str(file_name)
 
        if self.type_key == "stackvm":
            if not file_name.endswith(".stackvm"):
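
Taken together, these hunks let the autotvm entry points accept `pathlib.Path` objects in addition to plain strings: `ApplyHistoryBest.load`, the `log_to_file` tuner callback, and `Module.export_library` each convert a `Path` to `str` before using it. Below is a minimal usage sketch, assuming a TVM build that includes this change; the log-file name, the commented tuner/build calls, and the export path are illustrative placeholders, not part of the diff.

```python
# Minimal sketch: passing pathlib.Path objects where autotvm previously
# required plain strings. "tune.log" and the commented calls are
# placeholders for illustration only.
from pathlib import Path

from tvm import autotvm

log_file = Path("tune.log")
log_file.touch()  # ensure an (empty) log exists for this illustration

# log_to_file now accepts a Path; it is converted to str before the
# callback closure is returned.
callback = autotvm.callback.log_to_file(log_file)
# tuner.tune(n_trial=..., measure_option=..., callbacks=[callback])

# ApplyHistoryBest.load (and therefore apply_history_best) also accepts a Path.
with autotvm.apply_history_best(log_file):
    pass  # e.g. build the tuned operators here with relay.build / tvm.build

# Module.export_library likewise converts a Path before writing the library:
# lib.export_library(Path("deploy") / "model.so")
```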