From 60e781e6ad1f5e7927e73a044dc79e8409fe8189 Mon Sep 17 00:00:00 2001
From: Thierry Moreau
Date: Thu, 15 Nov 2018 20:05:19 -0800
Subject: [PATCH] [AutoTVM] Group-Conv2D support, enabling dynamic recompile in simulator (#28)

* simulator fix in autotvm build module: enabling dynamic recompile

* model string should be more specific for now

* avoiding errors when running simulator locally

* proper handling of xlnk driver stack reset

* typo fix

* tunable topi group conv2d operator
---
 topi/python/topi/nn/pad.py             |   4 +-
 vta/python/vta/build_module.py         |  25 +++-
 vta/python/vta/environment.py          |  29 +----
 vta/python/vta/exec/rpc_server.py      |   5 +-
 vta/python/vta/top/op.py               |   4 +-
 vta/python/vta/top/vta_conv2d.py       |   9 +-
 vta/python/vta/top/vta_group_conv2d.py | 169 ++++++++-----------------
 vta/src/pynq/pynq_driver.cc            |  16 ---
 vta/src/ultra96/ultra96_driver.cc      |  18 +--
 9 files changed, 94 insertions(+), 185 deletions(-)

diff --git a/topi/python/topi/nn/pad.py b/topi/python/topi/nn/pad.py
index 7ebbc566c3a34..ce4c9b1ff58d0 100644
--- a/topi/python/topi/nn/pad.py
+++ b/topi/python/topi/nn/pad.py
@@ -33,10 +33,10 @@ def pad(data, pad_before, pad_after=None, pad_value=0.0, name="PadInput"):
     n = len(data.shape)
     pad_after = pad_after if pad_after else pad_before
     if len(pad_before) != n:
-        raise ValueError("Input dimension and pad_before dismatch : %d vs %d" % (
+        raise ValueError("Input dimension and pad_before mismatch : %d vs %d" % (
             n, len(pad_before)))
     if len(pad_after) != n:
-        raise ValueError("Input dimension and pad_after dismatch : %d vs %d" % (
+        raise ValueError("Input dimension and pad_after mismatch : %d vs %d" % (
             n, len(pad_before)))
     out_shape = tuple(
         tvm.ir_pass.Simplify(
diff --git a/vta/python/vta/build_module.py b/vta/python/vta/build_module.py
index 0d6ee7fcf6a56..6a22551db0551 100644
--- a/vta/python/vta/build_module.py
+++ b/vta/python/vta/build_module.py
@@ -2,6 +2,7 @@
 from __future__ import absolute_import as _abs
 
 import tvm
+from tvm import rpc
 from . import ir_pass
 from . import ptr_alias
 from .environment import get_env
@@ -124,15 +125,29 @@ def vta_autotvm_build_func(measure_input, tmp_dir, **kwargs):
                 raise InstantiationError(config.errors)
 
             func = build(s, args, target_host=task.target_host)
-            func2 = build(s, args)
+            func_sim = build(s, args)
             arg_info = tuple((get_const_tuple(x.shape), x.dtype) for x in args)
             func.export_library(filename)
 
-            # check by local simulator
-            ctx = tvm.context(str(target))
-            args = [tvm.nd.empty(x[0], dtype=x[1], ctx=ctx) for x in arg_info]
-            func2(*args)
+            # When targeting VTA test the schedule on simulator first
+            if measure_input.target.device_name == 'vta':
+                from vta import reconfig_runtime
+                # Note: if you're not running the RPC locally, you cannot benefit
+                # from runtime recompilation...
+                local_rpc_port = int(os.environ.get("VTA_LOCAL_SIM_RPC_PORT", "0"))
+                if local_rpc_port:
+                    remote = rpc.connect("localhost", local_rpc_port)
+                    reconfig_runtime(remote)
+                else:
+                    remote = rpc.LocalSession()
+                obj_path = os.path.join(tmp_dir, "tmp_func_%0x.tar" % getrandbits(64))
+                func_sim.export_library(obj_path)
+                remote.upload(obj_path)
+                f = remote.load_module(os.path.split(obj_path)[1])
+                ctx = remote.context(str(measure_input.target), 0)
+                args = [tvm.nd.empty(x[0], dtype=x[1], ctx=ctx) for x in arg_info]
+                f(*args)
 
 
     except Exception as e:  # pylint: disable=broad-except
         return BuildResult(None, None, e, time.time() - tic)
diff --git a/vta/python/vta/environment.py b/vta/python/vta/environment.py
index 3d8959423b07d..c4bbffb00375d 100644
--- a/vta/python/vta/environment.py
+++ b/vta/python/vta/environment.py
@@ -151,9 +151,8 @@ def __init__(self, cfg):
         self._mock_env = None
         self._dev_ctx = None
         self._last_env = None
-        # derive bitstream name
-        self.BITSTREAM = "{}/{}_{}x{}x{}_a{}w{}o{}s{}_{}_{}_{}_{}_{}_{}MHz_{}ns_gii{}".format(
-            self.HW_VER.replace('.', '_'),
+        # model - autoTVM signature that identifies VTA configuration.
+        self.MODEL = "{}_{}x{}x{}_a{}w{}o{}s{}_{}_{}_{}_{}_{}_{}MHz_{}ns_gii{}".format(
             self.TARGET,
             self.BATCH,
             self.BLOCK_IN,
@@ -171,27 +170,11 @@ def __init__(self, cfg):
             self.HW_CLK_TARGET,
             self.GEMM_II)
         if self.ALU_EN:
-            self.BITSTREAM += "_aii{}".format(self.TALU_II)
+            self.MODEL += "_aii{}".format(self.TALU_II)
         if self.MUL_EN and self.ALU_EN:
-            self.BITSTREAM += "_mul"
-        self.BITSTREAM += ".bit"
-        # model - autoTVM signature that identifies VTA configuration.
-        # This is WIP: knobs that could influence the efficacy of the
-        # schedule have been left out for now.
-        self.MODEL = "{}-{}x{}x{}_a{}w{}o{}_{}_{}_{}_{}_{}".format(
-            self.TARGET,
-            self.BATCH,
-            self.BLOCK_IN,
-            self.BLOCK_OUT,
-            self.INP_WIDTH,
-            self.WGT_WIDTH,
-            self.OUT_WIDTH,
-            self.LOG_BUS_WIDTH,
-            self.LOG_UOP_BUFF_SIZE,
-            self.LOG_INP_BUFF_SIZE,
-            self.LOG_WGT_BUFF_SIZE,
-            self.LOG_ACC_BUFF_SIZE)
-
+            self.MODEL += "_mul"
+        # derive bitstream name
+        self.BITSTREAM = self.HW_VER.replace('.', '_') + "/" + self.MODEL + ".bit"
 
     def __enter__(self):
         self._last_env = Environment.current
         Environment.current = self
diff --git a/vta/python/vta/exec/rpc_server.py b/vta/python/vta/exec/rpc_server.py
index f5773f15959ba..9e7287dc61bbf 100644
--- a/vta/python/vta/exec/rpc_server.py
+++ b/vta/python/vta/exec/rpc_server.py
@@ -14,7 +14,6 @@
 from tvm._ffi.base import c_str
 from tvm import rpc
 from tvm.contrib import cc
-from pynq import Bitstream
 
 from ..environment import get_env
 from ..pkg_config import PkgConfig
@@ -52,7 +51,10 @@ def ext_dev_callback():
 
     @tvm.register_func("tvm.contrib.vta.init", override=True)
    def program_fpga(file_name):
+        from pynq import Bitstream, xlnk
         path = tvm.get_global_func("tvm.rpc.server.workpath")(file_name)
+        # Reset driver stack
+        xlnk.Xlnk().xlnk_reset()
         bitstream = Bitstream(path)
         bitstream.download()
         logging.info("Program FPGA with %s", file_name)
@@ -72,6 +74,7 @@ def reconfig_runtime(cfg_json):
        cfg_json : str
            JSON string used for configurations.
""" + # This ensures the tmp directory does not fill out too much if runtime_dll: raise RuntimeError("Can only reconfig in the beginning of session...") env = get_env() diff --git a/vta/python/vta/top/op.py b/vta/python/vta/top/op.py index 9838bc95ccd97..2dadfeb0ea128 100644 --- a/vta/python/vta/top/op.py +++ b/vta/python/vta/top/op.py @@ -74,7 +74,7 @@ def compute_conv2d(attrs, inputs, out): return topi.nn.conv2d(inputs[0], inputs[1], strides, padding, dilation, layout, out_dtype) else: - return packed_group_conv2d(inputs[0], inputs[1], padding, strides, groups, out_dtype) + return topi.nn.group_conv2d_nchw(inputs[0], inputs[1], strides, padding, dilation, groups, out_dtype) with tvm.target.arm_cpu(tvm.target.current_target().model): return _nn.compute_conv2d(attrs, inputs, out) @@ -93,7 +93,7 @@ def schedule_conv2d(attrs, outs, target): if groups == 1: return topi.generic.schedule_conv2d_nchw(outs) else: - return schedule_packed_group_conv2d(outs) + return topi.generic.schedule_group_conv2d_nchw(outs) elif str(target).startswith("llvm"): return tvm.create_schedule([x.op for x in outs]) else: diff --git a/vta/python/vta/top/vta_conv2d.py b/vta/python/vta/top/vta_conv2d.py index 7ef8279264013..663c4ba8c4610 100644 --- a/vta/python/vta/top/vta_conv2d.py +++ b/vta/python/vta/top/vta_conv2d.py @@ -11,7 +11,14 @@ @autotvm.register_topi_compute(topi.nn.conv2d, 'vta', 'direct') -def packed_conv2d(cfg, data, kernel, strides, padding, dilation, layout, out_dtype): +def packed_conv2d(cfg, + data, + kernel, + strides, + padding, + dilation, + layout, + out_dtype): """ Packed conv2d function.""" if not is_packed_layout(layout): raise topi.InvalidShapeError() diff --git a/vta/python/vta/top/vta_group_conv2d.py b/vta/python/vta/top/vta_group_conv2d.py index 97e9e939951b3..7404a4d034785 100644 --- a/vta/python/vta/top/vta_group_conv2d.py +++ b/vta/python/vta/top/vta_group_conv2d.py @@ -1,115 +1,38 @@ -import logging -from collections import namedtuple +"""Namespace for supporting group_conv2d of nnvm.""" import tvm +from tvm import autotvm import topi -from topi.util import get_const_int, get_const_tuple -from tvm.contrib.util import get_lower_ir +from topi.util import get_const_tuple -from ..environment import get_env - -Workload = namedtuple("GroupConv2DWorkload", - ('batch', 'height', 'width', 'in_filter', 'out_filter', 'groups', - 'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride')) - -Schedule = namedtuple("GroupConv2DSchedule", - ('b_factor', 'oc_factor', 'ic_factor', 'h_factor', 'w_factor', - 'oc_nthread', 'h_nthread', 'debug_sync')) - -workloads = [ - Workload(1, 112, 112, 32, 32, 2, 3, 3, 1, 1, 1, 1), - Workload(1, 112, 112, 64, 64, 4, 3, 3, 1, 1, 2, 2), - Workload(1, 56, 56, 128, 128, 8, 3, 3, 1, 1, 1, 1), - Workload(1, 56, 56, 128, 128, 8, 3, 3, 1, 1, 2, 2), - Workload(1, 28, 28, 256, 256, 16, 3, 3, 1, 1, 1, 1), - Workload(1, 28, 28, 256, 256, 16, 3, 3, 1, 1, 2, 2), - Workload(1, 14, 14, 512, 512, 32, 3, 3, 1, 1, 1, 1), - Workload(1, 14, 14, 512, 512, 32, 3, 3, 1, 1, 2, 2), - Workload(1, 7, 7, 1024, 1024, 64, 3, 3, 1, 1, 1, 1), -] - -schedules = [ - Schedule(1, 1, 1, 28, 56, 1, 1, False), - Schedule(1, 1, 1, 14, 28, 1, 1, False), - Schedule(1, 1, 1, 28, 56, 1, 1, False), - Schedule(1, 1, 1, 14, 28, 1, 1, False), - Schedule(1, 1, 1, 28, 28, 1, 1, False), - Schedule(1, 1, 1, 14, 14, 1, 1, False), - Schedule(1, 1, 1, 14, 14, 1, 1, False), - Schedule(1, 1, 1, 7, 7, 1, 1, False), - Schedule(1, 1, 1, 7, 7, 1, 1, False), -] - -injected_schedule = None - -# load schedule - -def 
-def find_schedules(layer, vt_only=False, best_only=False):
-    global injected_schedule
-    if injected_schedule:
-        return [injected_schedule]
-    for i, wkl in enumerate(workloads):
-        if str(wkl) == str(layer):
-            return [schedules[i]]
-    raise RuntimeError("No schedule for " + str(layer))
-
-def inject_schedule(sch):
-    global injected_schedule
-    injected_schedule = sch
-
-def _get_workload(data, pad_data, kernel, output):
-    """ Get the workload structure.
-    """
-    o_shape = get_const_tuple(output.shape)
-    d_shape = get_const_tuple(data.shape)
-    k_shape = get_const_tuple(kernel.shape)
-    o_b, o_c, o_h, o_w, ob_blk, o_blk = o_shape
-    i_b, i_c, i_h, i_w, ib_blk, i_blk = d_shape
-    k_o, k_i, k_h, k_w, ko_blk, ki_blk = k_shape
-    # For now we need to assume that input channel blocking is the same
-    # as the output channel blocking
-    assert o_blk == i_blk
-    assert ob_blk == ib_blk
-    # Make sure that dimensions match
-    assert o_b == i_b
-    assert o_blk == ko_blk
-    assert i_blk == ki_blk
-    assert k_o == o_c
-    groups = i_c // k_i
-    assert i_c % groups == 0
-    assert o_c % groups == 0
-
-    # Scale the channel size
-    i_c *= i_blk
-    o_c *= o_blk
-    if pad_data is not None:
-        p_shape = topi.util.get_const_tuple(pad_data.shape)
-        h_pad = (p_shape[2] - d_shape[2]) // 2
-        w_pad = (p_shape[3] - d_shape[3]) // 2
-    else:
-        h_pad, w_pad = 0, 0
-    h_str = (i_h + h_pad*2 - k_h) // (o_h - 1)
-    w_str = (i_w + w_pad*2 - k_w) // (o_w - 1)
-    return Workload(i_b, i_h, i_w, i_c, o_c, groups, k_h, k_w, h_pad, w_pad, h_str, w_str)
+import numpy as np
 
+from ..environment import get_env
 
-def packed_group_conv2d(data,
+@autotvm.register_topi_compute(topi.nn.group_conv2d_nchw, 'vta', 'direct')
+def packed_group_conv2d(cfg,
+                        data,
                         kernel,
-                        padding,
                         strides,
+                        padding,
+                        dilation,
                         group,
-                        out_dtype="int32"):
-    """ Packed conv2d function."""
+                        out_dtype):
+    """ Packed group conv2d nchw function."""
+    assert dilation == (1, 1)
+
+    print(padding)
+    print(data.shape)
     if padding[0]:
         pad_data = topi.nn.pad(data, [0, 0, padding[0], padding[1], 0, 0], name="pad_data")
     else:
         pad_data = data
-
     assert len(data.shape) == 6
     assert len(kernel.shape) == 6
     assert data.dtype == "int8", data.dtype
     assert kernel.dtype == "int8", kernel.dtype
+    assert out_dtype == "int32", out_dtype
 
     N, CI, IH, IW, B_BATCH, B_CI = get_const_tuple(data.shape)
     CO, CI_G, KH, KW, B_CO, B_CI = get_const_tuple(kernel.shape)
@@ -137,10 +60,14 @@ def packed_group_conv2d(data,
                 kernel[co, ci_o, kh, kw, b_co, ci_i].astype(out_dtype),
                 axis=[ci_o, kh, kw, ci_i]),
             name="res", tag="packed_group_conv2d")
+
+    cfg.add_flop(2 * np.prod(topi.util.get_const_tuple(oshape)) *
+                 KH * KW * CI * B_CI)
     return out
 
-def schedule_packed_group_conv2d(outs):
+@autotvm.register_topi_schedule(topi.generic.schedule_group_conv2d_nchw, 'vta', 'direct')
+def schedule_packed_group_conv2d(cfg, outs):
     """ Schedule the packed conv2d.
""" assert len(outs) == 1 @@ -167,6 +94,19 @@ def _traverse(op): _traverse(output.op) assert len(conv2d_res) == 1 conv2d_stage = conv2d_res[0].output(0) + s = tvm.create_schedule(output.op) + + ##### space definition begin ##### + b, co, h, w, bi, ci = s[conv2d_stage].op.axis + ci, kh, kw, bci = s[conv2d_stage].op.reduce_axis + cfg.define_split('tile_b', b, num_outputs=2) + cfg.define_split('tile_h', h, num_outputs=2) + cfg.define_split('tile_w', w, num_outputs=2) + cfg.define_split('tile_ci', ci, num_outputs=2) + cfg.define_split('tile_co', co, num_outputs=2) + cfg.define_knob('oc_nthread', [1, 2]) + cfg.define_knob('h_nthread', [1, 2]) + ###### space definition end ###### data, kernel = conv2d_stage.op.input_tensors if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag: @@ -175,17 +115,14 @@ def _traverse(op): data = temp else: pad_data = None - wrkld = _get_workload(data, pad_data, kernel, output) - plan = find_schedules(wrkld, vt_only=True, best_only=True)[0] - env = get_env() - load_inp = load_wgt = load_out = store_out = env.dma_copy + env = get_env() + load_inp = load_wgt = load_acc = store_out = env.dma_copy alu = env.alu gemm = env.gemm - # schedule1 + # schedule oshape = topi.util.get_const_tuple(output.shape) - s = tvm.create_schedule(output.op) # setup pad if pad_data is not None: @@ -195,26 +132,23 @@ def _traverse(op): cdata = s.cache_read(data, env.inp_scope, [conv2d_stage]) ckernel = s.cache_read(kernel, env.wgt_scope, [conv2d_stage]) s[conv2d_stage].set_scope(env.acc_scope) + # cache read input cache_read_ewise = [] - for consumer, tensor in ewise_inputs: cache_read_ewise.append( s.cache_read(tensor, env.acc_scope, [consumer])) + # set ewise scope for op in ewise_ops: s[op].set_scope(env.acc_scope) s[op].pragma(s[op].op.axis[0], alu) # tile - oc_factor = (plan.oc_factor if plan.oc_factor else 1) - h_factor = (plan.h_factor if plan.h_factor else 1) - w_factor = (plan.w_factor if plan.w_factor else 1) - x_bo, x_co, x_i, x_j, x_bi, x_ci = s[output].op.axis - x_co0, x_co1 = s[output].split(x_co, factor=oc_factor) - x_i0, x_i1 = s[output].split(x_i, factor=h_factor) - x_j0, x_j1 = s[output].split(x_j, factor=w_factor) + x_co0, x_co1 = cfg['tile_co'].apply(s, output, x_co) + x_i0, x_i1 = cfg['tile_h'].apply(s, output, x_i) + x_j0, x_j1 = cfg['tile_w'].apply(s, output, x_j) s[output].reorder(x_bo, x_i0, x_co0, x_j0, x_co1, x_i1, x_j1, x_bi, x_ci) store_pt = x_j0 @@ -225,17 +159,17 @@ def _traverse(op): for tensor in cache_read_ewise: s[tensor].compute_at(s[output], store_pt) - s[tensor].pragma(s[tensor].op.axis[0], load_out) + s[tensor].pragma(s[tensor].op.axis[0], load_acc) # virtual threading along output channel axes - if plan.oc_nthread > 1: - _, v_t = s[output].split(x_co0, factor=plan.oc_nthread) + if cfg['oc_nthread'].val > 1: + _, v_t = s[output].split(x_co0, factor=cfg['oc_nthread'].val) s[output].reorder(v_t, x_bo) s[output].bind(v_t, tvm.thread_axis("cthread")) # virtual threading along spatial rows - if plan.h_nthread > 1: - _, v_t = s[output].split(x_i0, factor=plan.h_nthread) + if cfg['h_nthread'].val > 1: + _, v_t = s[output].split(x_i0, factor=cfg['h_nthread'].val) s[output].reorder(v_t, x_bo) s[output].bind(v_t, tvm.thread_axis("cthread")) @@ -243,10 +177,9 @@ def _traverse(op): k_o, d_i, d_j, k_i = s[conv2d_stage].op.reduce_axis s[conv2d_stage].reorder(x_bo, k_o, x_j, d_j, d_i, x_co, x_i, x_bi, x_ci, k_i) - if plan.ic_factor: - k_o, _ = s[conv2d_stage].split(k_o, factor=plan.ic_factor) - s[cdata].compute_at(s[conv2d_stage], k_o) - 
-        s[ckernel].compute_at(s[conv2d_stage], k_o)
+    k_o, _ = cfg['tile_ci'].apply(s, conv2d_stage, k_o)
+    s[cdata].compute_at(s[conv2d_stage], k_o)
+    s[ckernel].compute_at(s[conv2d_stage], k_o)
 
     # Use VTA instructions
     s[cdata].pragma(s[cdata].op.axis[0], load_inp)
diff --git a/vta/src/pynq/pynq_driver.cc b/vta/src/pynq/pynq_driver.cc
index 62530925168cf..551f000f7ea19 100644
--- a/vta/src/pynq/pynq_driver.cc
+++ b/vta/src/pynq/pynq_driver.cc
@@ -10,20 +10,6 @@
 #define RESET_IOCTL _IOWR('X', 101, unsigned long)
 
-void _xlnk_reset() {
-  /* This performs the correct ioctl but probably isn't
-     particularly stable as a behaviour */
-  int xlnkfd = open("/dev/xlnk", O_RDWR | O_CLOEXEC);
-  if (xlnkfd < 0) {
-    printf("Reset failed - could not open device: %d\n", xlnkfd);
-    return;
-  }
-  if (ioctl(xlnkfd, RESET_IOCTL, 0) < 0) {
-    printf("Reset failed - IOCTL failed: %d\n", errno);
-  }
-  close(xlnkfd);
-}
-
 void* VTAMemAlloc(size_t size, int cached) {
   return cma_alloc(size, cached);
 }
@@ -89,8 +75,6 @@ class VTADevice {
     VTAUnmapRegister(vta_load_handle_, VTA_RANGE);
     VTAUnmapRegister(vta_compute_handle_, VTA_RANGE);
     VTAUnmapRegister(vta_store_handle_, VTA_RANGE);
-    // Reset xlnk drivers to clean up leaks
-    _xlnk_reset();
   }
 
   int Run(vta_phy_addr_t insn_phy_addr,
diff --git a/vta/src/ultra96/ultra96_driver.cc b/vta/src/ultra96/ultra96_driver.cc
index 8adc4497b1b44..72c70051d1b63 100644
--- a/vta/src/ultra96/ultra96_driver.cc
+++ b/vta/src/ultra96/ultra96_driver.cc
@@ -10,20 +10,6 @@
 #define RESET_IOCTL _IOWR('X', 101, unsigned long)
 
-void _xlnk_reset() {
-  /* This performs the correct ioctl but probably isn't
-     particularly stable as a behaviour */
-  int xlnkfd = open("/dev/xlnk", O_RDWR | O_CLOEXEC);
-  if (xlnkfd < 0) {
-    printf("Reset failed - could not open device: %d\n", xlnkfd);
-    return;
-  }
-  if (ioctl(xlnkfd, RESET_IOCTL, 0) < 0) {
-    printf("Reset failed - IOCTL failed: %d\n", errno);
-  }
-  close(xlnkfd);
-}
-
 void* VTAMemAlloc(size_t size, int cached) {
   void* ret_val = cma_alloc(size, cached);
   return ret_val;
 }
@@ -90,8 +76,6 @@ class VTADevice {
     VTAUnmapRegister(vta_load_handle_, VTA_RANGE);
     VTAUnmapRegister(vta_compute_handle_, VTA_RANGE);
     VTAUnmapRegister(vta_store_handle_, VTA_RANGE);
-    // Reset xlnk drivers to clean up leaks
-    _xlnk_reset();
   }
 
   int Run(vta_phy_addr_t insn_phy_addr,
@@ -154,4 +138,4 @@ int VTADeviceRun(VTADeviceHandle handle,
                  uint32_t wait_cycles) {
   return static_cast<VTADevice*>(handle)->Run(
       insn_phy_addr, insn_count, wait_cycles);
-}
\ No newline at end of file
+}
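
Note on the build_module.py change: the simulator check now runs through an RPC session (local by default) so the VTA runtime can be reconfigured for each measured candidate. A minimal usage sketch, assuming this patch is applied; the port 9091 and a separately launched "python -m vta.exec.rpc_server" process are only examples, not part of the patch:

# Sketch: drive the local-simulator RPC path added to vta_autotvm_build_func.
# Assumes a VTA RPC server is already listening on the port below.
import os
from tvm import rpc
from vta import reconfig_runtime

os.environ["VTA_LOCAL_SIM_RPC_PORT"] = "9091"   # picked up by the build func

port = int(os.environ["VTA_LOCAL_SIM_RPC_PORT"])
if port:
    remote = rpc.connect("localhost", port)
    reconfig_runtime(remote)     # rebuild the runtime for the current VTA config
else:
    remote = rpc.LocalSession()  # fallback used when no local RPC port is exported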
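
Note on the environment.py change: MODEL is now the full AutoTVM signature and BITSTREAM is derived from it, so tuning logs and bitstream files share one name. A rough sketch of the naming scheme with made-up values (the real fields come from the VTA configuration, and several knobs are omitted here):

# Sketch only: mirrors the string construction in Environment.__init__
# with hypothetical values; not the actual configuration parser.
hw_ver = "0.0.1"
model = "{}_{}x{}x{}_a{}w{}o{}".format(
    "sim",        # TARGET
    1, 16, 16,    # BATCH, BLOCK_IN, BLOCK_OUT
    8, 8, 32)     # INP_WIDTH, WGT_WIDTH, OUT_WIDTH (remaining knobs omitted)
model += "_aii2"  # appended when the ALU is enabled, with its initiation interval
bitstream = hw_ver.replace('.', '_') + "/" + model + ".bit"
print(model)       # sim_1x16x16_a8w8o32_aii2
print(bitstream)   # 0_0_1/sim_1x16x16_a8w8o32_aii2.bit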