Skip to content
This repository has been archived by the owner on Apr 30, 2020. It is now read-only.

feat: (1) Support zero copy from tf tensor to dlpack (2) Fix pymodul… #4

Merged
merged 2 commits into from Nov 20, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
46 changes: 46 additions & 0 deletions CMakeLists.txt
@@ -0,0 +1,46 @@
cmake_minimum_required(VERSION 3.2)
project(tftvm)
set(CMAKE_VERBOSE_MAKEFILE ON)

if ("${TVM_HOME}" STREQUAL "")
message(FATAL_ERROR "TVM_HOME is not defined")
else()
message("Use TVM_HOME=\"${TVM_HOME}\"")
endif()


include_directories(${TVM_HOME}/include)
include_directories(${TVM_HOME}/3rdparty/dlpack/include)
include_directories(${TVM_HOME}/3rdparty/dmlc-core/include)


execute_process(COMMAND python -c "import tensorflow as tf; print(' '.join(tf.sysconfig.get_compile_flags()))"
OUTPUT_VARIABLE TF_COMPILE_FLAGS_STR
RESULT_VARIABLE TF_STATUS)
if (NOT ${TF_STATUS} EQUAL 0)
message(FATAL_ERROR "Fail to get TensorFlow compile flags")
endif()

execute_process(COMMAND python -c "import tensorflow as tf; print(' '.join(tf.sysconfig.get_link_flags()))"
OUTPUT_VARIABLE TF_LINK_FLAGS_STR
RESULT_VARIABLE TF_STATUS)
if (NOT ${TF_STATUS} EQUAL 0)
message(FATAL_ERROR "Fail to get TensorFlow link flags")
endif()

string(REGEX REPLACE "\n" " " TF_FLAGS "${TF_COMPILE_FLAGS} ${TF_LINK_FLAGS}")
message("Use TensorFlow flags=\"${TF_FLAGS}\"")
separate_arguments(TF_COMPILE_FLAGS UNIX_COMMAND ${TF_COMPILE_FLAGS_STR})
separate_arguments(TF_LINK_FLAGS UNIX_COMMAND ${TF_LINK_FLAGS_STR})


set(OP_LIBRARY_NAME tvm_dso_op)
file(GLOB_RECURSE TFTVM_SRCS src/*.cc)
add_library(${OP_LIBRARY_NAME} SHARED ${TFTVM_SRCS})
set_target_properties(${OP_LIBRARY_NAME} PROPERTIES PREFIX "")

set(TFTVM_COMPILE_FLAGS -O2 -ldl -g)
set(TFTVM_LINK_FLAGS -ltvm_runtime -L${TVM_HOME}/build)
target_compile_options(${OP_LIBRARY_NAME} PUBLIC ${TFTVM_COMPILE_FLAGS} ${TF_COMPILE_FLAGS})
target_link_options(${OP_LIBRARY_NAME} PUBLIC ${TFTVM_LINK_FLAGS} ${TF_LINK_FLAGS})

6 changes: 6 additions & 0 deletions build.sh
@@ -0,0 +1,6 @@
#!/bin/bash

mkdir -p build
cd build
cmake .. -DTVM_HOME=${TVM_HOME}
make
130 changes: 0 additions & 130 deletions cpp/tvm_dso_op_kernels.cc

This file was deleted.

4 changes: 2 additions & 2 deletions python/tf_op/module.py
Expand Up @@ -41,7 +41,7 @@ def __init__(self, lib_path, func_name, output_dtype, output_shape, device):
self.tvm_dso_op = tvm_dso_op.tvm_dso_op

def apply(self, *params):
return self.tvm_dso_op(params, lib_path=self.lib_path, func_name=self.func_name, output_dtype=self.output_dtype, output_shape=self.output_shape, device=self.device)
return self.tvm_dso_op(*params, lib_path=self.lib_path, func_name=self.func_name, output_dtype=self.output_dtype, output_shape=self.output_shape, device=self.device)

def __call__(self, *params):
return self.apply(params)
return self.apply(*params)
File renamed without changes.