Skip to content

Commit

Permalink
[AutoTVM] Group-Conv2D support, enabling dynamic recompile in simulat…
Browse files Browse the repository at this point in the history
…or (apache#28)

* simulator fix in autotvm build module: enabling dynamic recompile

* model string should be more specific for now

* avoiding errors when running simulator locally

* proper handling of xlnk driver stack reset

* typo fix

* tunable topi group conv2d operator
  • Loading branch information
tmoreau89 committed Jan 2, 2019
1 parent bb6b024 commit 60e781e
Show file tree
Hide file tree
Showing 9 changed files with 94 additions and 185 deletions.
4 changes: 2 additions & 2 deletions topi/python/topi/nn/pad.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ def pad(data, pad_before, pad_after=None, pad_value=0.0, name="PadInput"):
n = len(data.shape)
pad_after = pad_after if pad_after else pad_before
if len(pad_before) != n:
raise ValueError("Input dimension and pad_before dismatch : %d vs %d" % (
raise ValueError("Input dimension and pad_before mismatch : %d vs %d" % (
n, len(pad_before)))
if len(pad_after) != n:
raise ValueError("Input dimension and pad_after dismatch : %d vs %d" % (
raise ValueError("Input dimension and pad_after mismatch : %d vs %d" % (
n, len(pad_before)))
out_shape = tuple(
tvm.ir_pass.Simplify(
Expand Down
25 changes: 20 additions & 5 deletions vta/python/vta/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from __future__ import absolute_import as _abs

import tvm
from tvm import rpc
from . import ir_pass
from . import ptr_alias
from .environment import get_env
Expand Down Expand Up @@ -124,15 +125,29 @@ def vta_autotvm_build_func(measure_input, tmp_dir, **kwargs):
raise InstantiationError(config.errors)

func = build(s, args, target_host=task.target_host)
func2 = build(s, args)
func_sim = build(s, args)

arg_info = tuple((get_const_tuple(x.shape), x.dtype) for x in args)
func.export_library(filename)

# check by local simulator
ctx = tvm.context(str(target))
args = [tvm.nd.empty(x[0], dtype=x[1], ctx=ctx) for x in arg_info]
func2(*args)
# When targeting VTA test the schedule on simulator first
if measure_input.target.device_name == 'vta':
from vta import reconfig_runtime
# Note: if you're not running the RPC locally, you cannot benefit
# from rumtime recompilation...
local_rpc_port = int(os.environ.get("VTA_LOCAL_SIM_RPC_PORT", "0"))
if local_rpc_port:
remote = rpc.connect("localhost", local_rpc_port)
reconfig_runtime(remote)
else:
remote = rpc.LocalSession()
obj_path = os.path.join(tmp_dir, "tmp_func_%0x.tar" % getrandbits(64))
func_sim.export_library(obj_path)
remote.upload(obj_path)
f = remote.load_module(os.path.split(obj_path)[1])
ctx = remote.context(str(measure_input.target), 0)
args = [tvm.nd.empty(x[0], dtype=x[1], ctx=ctx) for x in arg_info]
f(*args)

except Exception as e: # pylint: disable=broad-except
return BuildResult(None, None, e, time.time() - tic)
Expand Down
29 changes: 6 additions & 23 deletions vta/python/vta/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,8 @@ def __init__(self, cfg):
self._mock_env = None
self._dev_ctx = None
self._last_env = None
# derive bitstream name
self.BITSTREAM = "{}/{}_{}x{}x{}_a{}w{}o{}s{}_{}_{}_{}_{}_{}_{}MHz_{}ns_gii{}".format(
self.HW_VER.replace('.', '_'),
# model - autoTVM signature that identifies VTA configuration.
self.MODEL = "{}_{}x{}x{}_a{}w{}o{}s{}_{}_{}_{}_{}_{}_{}MHz_{}ns_gii{}".format(
self.TARGET,
self.BATCH,
self.BLOCK_IN,
Expand All @@ -171,27 +170,11 @@ def __init__(self, cfg):
self.HW_CLK_TARGET,
self.GEMM_II)
if self.ALU_EN:
self.BITSTREAM += "_aii{}".format(self.TALU_II)
self.MODEL += "_aii{}".format(self.TALU_II)
if self.MUL_EN and self.ALU_EN:
self.BITSTREAM += "_mul"
self.BITSTREAM += ".bit"
# model - autoTVM signature that identifies VTA configuration.
# This is WIP: knobs that could influence the efficacy of the
# schedule have been left out for now.
self.MODEL = "{}-{}x{}x{}_a{}w{}o{}_{}_{}_{}_{}_{}".format(
self.TARGET,
self.BATCH,
self.BLOCK_IN,
self.BLOCK_OUT,
self.INP_WIDTH,
self.WGT_WIDTH,
self.OUT_WIDTH,
self.LOG_BUS_WIDTH,
self.LOG_UOP_BUFF_SIZE,
self.LOG_INP_BUFF_SIZE,
self.LOG_WGT_BUFF_SIZE,
self.LOG_ACC_BUFF_SIZE)

self.MODEL += "_mul"
# derive bitstream name
self.BITSTREAM = self.HW_VER.replace('.', '_') + "/" + self.MODEL + ".bit"
def __enter__(self):
self._last_env = Environment.current
Environment.current = self
Expand Down
5 changes: 4 additions & 1 deletion vta/python/vta/exec/rpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from tvm._ffi.base import c_str
from tvm import rpc
from tvm.contrib import cc
from pynq import Bitstream

from ..environment import get_env
from ..pkg_config import PkgConfig
Expand Down Expand Up @@ -52,7 +51,10 @@ def ext_dev_callback():

@tvm.register_func("tvm.contrib.vta.init", override=True)
def program_fpga(file_name):
from pynq import Bitstream, xlnk
path = tvm.get_global_func("tvm.rpc.server.workpath")(file_name)
# Reset driver stack
xlnk.Xlnk().xlnk_reset()
bitstream = Bitstream(path)
bitstream.download()
logging.info("Program FPGA with %s", file_name)
Expand All @@ -72,6 +74,7 @@ def reconfig_runtime(cfg_json):
cfg_json : str
JSON string used for configurations.
"""
# This ensures the tmp directory does not fill out too much
if runtime_dll:
raise RuntimeError("Can only reconfig in the beginning of session...")
env = get_env()
Expand Down
4 changes: 2 additions & 2 deletions vta/python/vta/top/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def compute_conv2d(attrs, inputs, out):

return topi.nn.conv2d(inputs[0], inputs[1], strides, padding, dilation, layout, out_dtype)
else:
return packed_group_conv2d(inputs[0], inputs[1], padding, strides, groups, out_dtype)
return topi.nn.group_conv2d_nchw(inputs[0], inputs[1], strides, padding, dilation, groups, out_dtype)

with tvm.target.arm_cpu(tvm.target.current_target().model):
return _nn.compute_conv2d(attrs, inputs, out)
Expand All @@ -93,7 +93,7 @@ def schedule_conv2d(attrs, outs, target):
if groups == 1:
return topi.generic.schedule_conv2d_nchw(outs)
else:
return schedule_packed_group_conv2d(outs)
return topi.generic.schedule_group_conv2d_nchw(outs)
elif str(target).startswith("llvm"):
return tvm.create_schedule([x.op for x in outs])
else:
Expand Down
9 changes: 8 additions & 1 deletion vta/python/vta/top/vta_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,14 @@


@autotvm.register_topi_compute(topi.nn.conv2d, 'vta', 'direct')
def packed_conv2d(cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
def packed_conv2d(cfg,
data,
kernel,
strides,
padding,
dilation,
layout,
out_dtype):
""" Packed conv2d function."""
if not is_packed_layout(layout):
raise topi.InvalidShapeError()
Expand Down
Loading

0 comments on commit 60e781e

Please sign in to comment.