Skip to content

Commit

Permalink
[TVM][AutoTVM] cast filepath arguments to string (apache#3968)
Browse files Browse the repository at this point in the history
  • Loading branch information
cchung100m authored and wweic committed Sep 30, 2019
1 parent db0cc4a commit 7ded8e0
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 12 deletions.
8 changes: 8 additions & 0 deletions python/tvm/autotvm/task/dispatcher.py
Expand Up @@ -41,6 +41,7 @@

logger = logging.getLogger('autotvm')


class DispatchContext(object):
"""
Base class of dispatch context.
Expand Down Expand Up @@ -281,8 +282,12 @@ def load(self, records):
Each row of this file is an encoded record pair.
Otherwise, it is an iterator.
"""
from pathlib import Path
from ..record import load_from_file

if isinstance(records, Path):
records = str(records)

if isinstance(records, str):
records = load_from_file(records)
if not records:
Expand Down Expand Up @@ -404,8 +409,10 @@ def update(self, target, workload, cfg):
key = (str(target), workload)
self.memory[key] = cfg


DispatchContext.current = FallbackContext()


def clear_fallback_cache(target, workload):
"""Clear fallback cache. Pass the same argument as _query_inside to this function
to clean the cache.
Expand All @@ -426,6 +433,7 @@ def clear_fallback_cache(target, workload):
context = context._old_ctx
context.clear_cache(target, workload)


class ApplyGraphBest(DispatchContext):
"""Load the graph level tuning optimal schedules.
Expand Down
6 changes: 6 additions & 0 deletions python/tvm/autotvm/tuner/callback.py
Expand Up @@ -26,6 +26,7 @@

logger = logging.getLogger('autotvm')


def log_to_file(file_out, protocol='json'):
"""Log the tuning records into file.
The rows of the log are stored in the format of autotvm.record.encode.
Expand All @@ -51,6 +52,11 @@ def _callback(_, inputs, results):
else:
for inp, result in zip(inputs, results):
file_out.write(record.encode(inp, result, protocol) + "\n")

from pathlib import Path
if isinstance(file_out, Path):
file_out = str(file_out)

return _callback


Expand Down
15 changes: 3 additions & 12 deletions python/tvm/module.py
Expand Up @@ -112,18 +112,9 @@ def export_library(self,
kwargs : dict, optional
Additional arguments passed to fcompile
"""
if self.is_empty():
logging.info("The lib generated by the NNVM compiler does not contain optimized "
"functions for any operators. This usually happens when an external "
"accelerator, e.g. TensorRT, is employed along with TVM to compile "
"the model, and all the operators in the model are supported by the "
"external accelerator at runtime. Therefore, "
"the NNVM compiler skipped optimizing them at the compile time.")
if os.path.isfile(file_name):
logging.warning("Lib file %s exists, and will be overwritten by the newly created"
" lib with the same name.", file_name)
open(file_name, 'w').close()
return
from pathlib import Path
if isinstance(file_name, Path):
file_name = str(file_name)

if self.type_key == "stackvm":
if not file_name.endswith(".stackvm"):
Expand Down

0 comments on commit 7ded8e0

Please sign in to comment.