Skip to content

Commit

Permalink
[RPC][RUNTIME] Support dynamic reload of runtime API according to con…
Browse files Browse the repository at this point in the history
…fig (apache#19)
  • Loading branch information
tqchen committed Jul 12, 2018
1 parent 092460c commit 5d15b16
Show file tree
Hide file tree
Showing 17 changed files with 433 additions and 211 deletions.
33 changes: 14 additions & 19 deletions vta/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -40,36 +40,31 @@ ifneq ($(ADD_LDFLAGS), NONE)
LDFLAGS += $(ADD_LDFLAGS)
endif

ifeq ($(UNAME_S), Darwin)
SHARED_LIBRARY_SUFFIX := dylib
WHOLE_ARCH= -all_load
NO_WHOLE_ARCH= -noall_load
LDFLAGS += -undefined dynamic_lookup
else
SHARED_LIBRARY_SUFFIX := so
WHOLE_ARCH= --whole-archive
NO_WHOLE_ARCH= --no-whole-archive
endif


all: lib/libvta.$(SHARED_LIBRARY_SUFFIX)
all: lib/libvta.so lib/libvta_runtime.so

VTA_LIB_SRC = $(wildcard src/*.cc src/tvm/*.cc)

ifeq ($(TARGET), VTA_PYNQ_TARGET)
VTA_LIB_SRC += $(wildcard src/pynq/*.cc)
LDFLAGS += -L/usr/lib -lsds_lib
LDFLAGS += -L/opt/python3.6/lib/python3.6/site-packages/pynq/drivers/ -l:libdma.so
LDFLAGS += -L/opt/python3.6/lib/python3.6/site-packages/pynq/drivers/
LDFLAGS += -L/opt/python3.6/lib/python3.6/site-packages/pynq/lib/
LDFLAGS += -l:libdma.so
endif
VTA_LIB_OBJ = $(patsubst %.cc, build/%.o, $(VTA_LIB_SRC))

test: $(TEST)
VTA_LIB_OBJ = $(patsubst src/%.cc, build/%.o, $(VTA_LIB_SRC))

build/src/%.o: src/%.cc
build/%.o: src/%.cc
@mkdir -p $(@D)
$(CXX) $(CFLAGS) -MM -MT build/src/$*.o $< >build/src/$*.d
$(CXX) $(CFLAGS) -MM -MT build/src/$*.o $< >build/$*.d
$(CXX) -c $(CFLAGS) -c $< -o $@

lib/libvta.$(SHARED_LIBRARY_SUFFIX): $(VTA_LIB_OBJ)
lib/libvta.so: $(filter-out build/runtime.o, $(VTA_LIB_OBJ))
@mkdir -p $(@D)
$(CXX) $(CFLAGS) -shared -o $@ $(filter %.o, $^) $(LDFLAGS)

lib/libvta_runtime.so: build/runtime.o
@mkdir -p $(@D)
$(CXX) $(CFLAGS) -shared -o $@ $(filter %.o, $^) $(LDFLAGS)

Expand All @@ -79,7 +74,7 @@ cpplint:
python nnvm/dmlc-core/scripts/lint.py vta cpp include src hardware tests

pylint:
pylint python/tvm_vta --rcfile=$(ROOTDIR)/tests/lint/pylintrc
pylint python/vta --rcfile=$(ROOTDIR)/tests/lint/pylintrc

doc:
doxygen docs/Doxyfile
Expand Down
4 changes: 2 additions & 2 deletions vta/apps/pynq_rpc/start_rpc_server.sh
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#!/bin/bash
export PYTHONPATH=${PYTHONPATH}:/home/xilinx/tvm/python
export PYTHONPATH=${PYTHONPATH}:/home/xilinx/tvm/python:/home/xilinx/vta/python
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/opt/python3.6/lib/python3.6/site-packages/pynq/drivers/
python -m tvm.exec.rpc_server --load-library /home/xilinx/vta/lib/libvta.so
python -m vta.exec.rpc_server
4 changes: 1 addition & 3 deletions vta/examples/resnet18/pynq/imagenet_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,7 @@
# Program the FPGA remotely
assert tvm.module.enabled("rpc")
remote = rpc.connect(host, port)
remote.upload(BITSTREAM_FILE, BITSTREAM_FILE)
fprogram = remote.get_function("tvm.contrib.vta.init")
fprogram(BITSTREAM_FILE)
vta.program_fpga(remote, BITSTREAM_FILE)

if verbose:
logging.basicConfig(level=logging.INFO)
Expand Down
66 changes: 34 additions & 32 deletions vta/include/vta/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,40 +23,20 @@ extern "C" {
#define VTA_DEBUG_SKIP_WRITE_BARRIER (1 << 4)
#define VTA_DEBUG_FORCE_SERIAL (1 << 5)

/*! \brief VTA command handle */
typedef void * VTACommandHandle;

/*! \brief Shutdown hook of VTA to cleanup resources */
void VTARuntimeShutdown();

/*!
* \brief Get thread local command handle.
* \return A thread local command handle.
*/
VTACommandHandle VTATLSCommandHandle();

/*!
* \brief Allocate data buffer.
* \param cmd The VTA command handle.
* \param size Buffer size.
* \return A pointer to the allocated buffer.
*/
void* VTABufferAlloc(VTACommandHandle cmd, size_t size);
void* VTABufferAlloc(size_t size);

/*!
* \brief Free data buffer.
* \param cmd The VTA command handle.
* \param buffer The data buffer to be freed.
*/
void VTABufferFree(VTACommandHandle cmd, void* buffer);

/*!
* \brief Get the buffer access pointer on CPU.
* \param cmd The VTA command handle.
* \param buffer The data buffer.
* \return The pointer that can be accessed by the CPU.
*/
void* VTABufferCPUPtr(VTACommandHandle cmd, void* buffer);
void VTABufferFree(void* buffer);

/*!
* \brief Copy data buffer from one location to another.
Expand All @@ -68,20 +48,32 @@ void* VTABufferCPUPtr(VTACommandHandle cmd, void* buffer);
* \param size Size of copy.
* \param kind_mask The memory copy kind.
*/
void VTABufferCopy(VTACommandHandle cmd,
const void* from,
void VTABufferCopy(const void* from,
size_t from_offset,
void* to,
size_t to_offset,
size_t size,
int kind_mask);

/*! \brief VTA command handle */
typedef void* VTACommandHandle;

/*! \brief Shutdown hook of VTA to cleanup resources */
void VTARuntimeShutdown();

/*!
* \brief Set debug mode on the command handle.
* \brief Get thread local command handle.
* \return A thread local command handle.
*/
VTACommandHandle VTATLSCommandHandle();

/*!
* \brief Get the buffer access pointer on CPU.
* \param cmd The VTA command handle.
* \param debug_flag The debug flag.
* \param buffer The data buffer.
* \return The pointer that can be accessed by the CPU.
*/
void VTASetDebugMode(VTACommandHandle cmd, int debug_flag);
void* VTABufferCPUPtr(VTACommandHandle cmd, void* buffer);

/*!
* \brief Perform a write barrier to make a memory region visible to the CPU.
Expand All @@ -92,9 +84,10 @@ void VTASetDebugMode(VTACommandHandle cmd, int debug_flag);
* \param extent The end of the region (in elements).
*/
void VTAWriteBarrier(VTACommandHandle cmd,
void* buffer, uint32_t elem_bits,
uint32_t start, uint32_t extent);

void* buffer,
uint32_t elem_bits,
uint32_t start,
uint32_t extent);
/*!
* \brief Perform a read barrier to a memory region visible to VTA.
* \param cmd The VTA command handle.
Expand All @@ -104,8 +97,17 @@ void VTAWriteBarrier(VTACommandHandle cmd,
* \param extent The end of the region (in elements).
*/
void VTAReadBarrier(VTACommandHandle cmd,
void* buffer, uint32_t elem_bits,
uint32_t start, uint32_t extent);
void* buffer,
uint32_t elem_bits,
uint32_t start,
uint32_t extent);

/*!
* \brief Set debug mode on the command handle.
* \param cmd The VTA command handle.
* \param debug_flag The debug flag.
*/
void VTASetDebugMode(VTACommandHandle cmd, int debug_flag);

/*!
* \brief Perform a 2D data load from DRAM.
Expand Down
1 change: 1 addition & 0 deletions vta/make/config.mk
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ VTA_LOG_WGT_BUFF_SIZE = 15
# Log of acc buffer size in Bytes
VTA_LOG_ACC_BUFF_SIZE = 17


#---------------------
# Derived VTA hardware parameters
#--------------------
Expand Down
22 changes: 15 additions & 7 deletions vta/python/vta/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
"""TVM VTA runtime"""
"""TVM-based VTA Compiler Toolchain"""
from __future__ import absolute_import as _abs

from .hw_spec import *

from .runtime import SCOPE_INP, SCOPE_OUT, SCOPE_WGT, DMA_COPY, ALU
from .intrin import GEVM, GEMM
from .build import debug_mode
try:
from .runtime import SCOPE_INP, SCOPE_OUT, SCOPE_WGT, DMA_COPY, ALU
from .intrin import GEVM, GEMM
from .build import debug_mode
from . import mock, ir_pass
from . import arm_conv2d, vta_conv2d
except AttributeError:
pass

from . import mock, ir_pass
from . import arm_conv2d, vta_conv2d
from . import graph
from .rpc_client import reconfig_runtime, program_fpga

try:
from . import graph
except ImportError:
pass
1 change: 1 addition & 0 deletions vta/python/vta/exec/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""VTA Command line utils."""
104 changes: 104 additions & 0 deletions vta/python/vta/exec/rpc_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
"""VTA customized TVM RPC Server
Provides additional runtime function and library loading.
"""
from __future__ import absolute_import

import logging
import argparse
import os
import ctypes
import tvm
from tvm.contrib import rpc, util, cc


@tvm.register_func("tvm.contrib.rpc.server.start", override=True)
def server_start():
curr_path = os.path.dirname(
os.path.abspath(os.path.expanduser(__file__)))
dll_path = os.path.abspath(
os.path.join(curr_path, "../../../lib/libvta_runtime.so"))
runtime_dll = []
_load_module = tvm.get_global_func("tvm.contrib.rpc.server.load_module")

@tvm.register_func("tvm.contrib.rpc.server.load_module", override=True)
def load_module(file_name):
if not runtime_dll:
runtime_dll.append(ctypes.CDLL(dll_path, ctypes.RTLD_GLOBAL))
return _load_module(file_name)

@tvm.register_func("tvm.contrib.rpc.server.shutdown", override=True)
def server_shutdown():
if runtime_dll:
runtime_dll[0].VTARuntimeShutdown()
runtime_dll.pop()

@tvm.register_func("tvm.contrib.vta.reconfig_runtime", override=True)
def reconfig_runtime(cflags):
"""Rebuild and reload runtime with new configuration.
Parameters
----------
cfg_json : str
JSON string used for configurations.
"""
if runtime_dll:
raise RuntimeError("Can only reconfig in the beginning of session...")
cflags = cflags.split()
cflags += ["-O2", "-std=c++11"]
lib_name = dll_path
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
proj_root = os.path.abspath(os.path.join(curr_path, "../../../"))
runtime_source = os.path.join(proj_root, "src/runtime.cc")
cflags += ["-I%s/include" % proj_root]
cflags += ["-I%s/nnvm/tvm/include" % proj_root]
cflags += ["-I%s/nnvm/tvm/dlpack/include" % proj_root]
cflags += ["-I%s/nnvm/dmlc-core/include" % proj_root]
logging.info("Rebuild runtime dll with %s", str(cflags))
cc.create_shared(lib_name, [runtime_source], cflags)


def main():
"""Main funciton"""
parser = argparse.ArgumentParser()
parser.add_argument('--host', type=str, default="0.0.0.0",
help='the hostname of the server')
parser.add_argument('--port', type=int, default=9090,
help='The port of the PRC')
parser.add_argument('--port-end', type=int, default=9199,
help='The end search port of the PRC')
parser.add_argument('--key', type=str, default="",
help="RPC key used to identify the connection type.")
parser.add_argument('--tracker', type=str, default="",
help="Report to RPC tracker")
args = parser.parse_args()
logging.basicConfig(level=logging.INFO)
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
proj_root = os.path.abspath(os.path.join(curr_path, "../../../"))
lib_path = os.path.abspath(os.path.join(proj_root, "lib/libvta.so"))

libs = []
for file_name in [lib_path]:
libs.append(ctypes.CDLL(file_name, ctypes.RTLD_GLOBAL))
logging.info("Load additional library %s", file_name)

if args.tracker:
url, port = args.tracker.split(":")
port = int(port)
tracker_addr = (url, port)
if not args.key:
raise RuntimeError(
"Need key to present type of resource when tracker is available")
else:
tracker_addr = None

server = rpc.Server(args.host,
args.port,
args.port_end,
key=args.key,
tracker_addr=tracker_addr)
server.libs += libs
server.proc.join()

if __name__ == "__main__":
main()
22 changes: 21 additions & 1 deletion vta/python/vta/hw_spec.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,31 @@
"""VTA configuration constants (should match hw_spec.h"""
from __future__ import absolute_import as _abs

# Log of input/activation width in bits (default 3 -> 8 bits)
VTA_LOG_INP_WIDTH = 3
# Log of kernel weight width in bits (default 3 -> 8 bits)
VTA_LOG_WGT_WIDTH = 3
# Log of accum width in bits (default 5 -> 32 bits)
VTA_LOG_ACC_WIDTH = 5
# Log of tensor batch size (A in (A,B)x(B,C) matrix multiplication)
VTA_LOG_BATCH = 0
# Log of tensor inner block size (B in (A,B)x(B,C) matrix multiplication)
VTA_LOG_BLOCK_IN = 4
# Log of tensor outer block size (C in (A,B)x(B,C) matrix multiplication)
VTA_LOG_BLOCK_OUT = 4
VTA_LOG_OUT_WIDTH = VTA_LOG_INP_WIDTH
# Log of uop buffer size in Bytes
VTA_LOG_UOP_BUFF_SIZE = 15
# Log of acc buffer size in Bytes
VTA_LOG_ACC_BUFF_SIZE = 17

# The Constants
VTA_WGT_WIDTH = 8
VTA_INP_WIDTH = VTA_WGT_WIDTH
VTA_OUT_WIDTH = 32

VTA_TARGET = "VTA_PYNQ_TARGET"

# Dimensions of the GEMM unit
# (BATCH,BLOCK_IN) x (BLOCK_IN,BLOCK_OUT)
VTA_BATCH = 1
Expand Down Expand Up @@ -67,4 +87,4 @@
DEBUG_DUMP_INSN = (1 << 1)
DEBUG_DUMP_UOP = (1 << 2)
DEBUG_SKIP_READ_BARRIER = (1 << 3)
DEBUG_SKIP_WRITE_BARRIER = (1 << 4)
DEBUG_SKIP_WRITE_BARRIER = (1 << 4)
Loading

0 comments on commit 5d15b16

Please sign in to comment.