diff --git a/CMakeLists.txt b/CMakeLists.txt
new file mode 100644
index 0000000..06b8c2c
--- /dev/null
+++ b/CMakeLists.txt
@@ -0,0 +1,46 @@
+cmake_minimum_required(VERSION 3.13)
+project(tftvm)
+set(CMAKE_VERBOSE_MAKEFILE ON)
+
+if ("${TVM_HOME}" STREQUAL "")
+  message(FATAL_ERROR "TVM_HOME is not defined")
+else()
+  message("Use TVM_HOME=\"${TVM_HOME}\"")
+endif()
+
+
+include_directories(${TVM_HOME}/include)
+include_directories(${TVM_HOME}/3rdparty/dlpack/include)
+include_directories(${TVM_HOME}/3rdparty/dmlc-core/include)
+
+
+execute_process(COMMAND python -c "import tensorflow as tf; print(' '.join(tf.sysconfig.get_compile_flags()))"
+                OUTPUT_VARIABLE TF_COMPILE_FLAGS_STR
+                RESULT_VARIABLE TF_STATUS)
+if (NOT ${TF_STATUS} EQUAL 0)
+  message(FATAL_ERROR "Failed to get TensorFlow compile flags")
+endif()
+
+execute_process(COMMAND python -c "import tensorflow as tf; print(' '.join(tf.sysconfig.get_link_flags()))"
+                OUTPUT_VARIABLE TF_LINK_FLAGS_STR
+                RESULT_VARIABLE TF_STATUS)
+if (NOT ${TF_STATUS} EQUAL 0)
+  message(FATAL_ERROR "Failed to get TensorFlow link flags")
+endif()
+
+string(REGEX REPLACE "\n" " " TF_FLAGS "${TF_COMPILE_FLAGS_STR} ${TF_LINK_FLAGS_STR}")
+message("Use TensorFlow flags=\"${TF_FLAGS}\"")
+separate_arguments(TF_COMPILE_FLAGS UNIX_COMMAND ${TF_COMPILE_FLAGS_STR})
+separate_arguments(TF_LINK_FLAGS UNIX_COMMAND ${TF_LINK_FLAGS_STR})
+
+
+set(OP_LIBRARY_NAME tvm_dso_op)
+file(GLOB_RECURSE TFTVM_SRCS src/*.cc)
+add_library(${OP_LIBRARY_NAME} SHARED ${TFTVM_SRCS})
+set_target_properties(${OP_LIBRARY_NAME} PROPERTIES PREFIX "")
+
+set(TFTVM_COMPILE_FLAGS -O2 -g)
+set(TFTVM_LINK_FLAGS -ldl -ltvm_runtime -L${TVM_HOME}/build)
+target_compile_options(${OP_LIBRARY_NAME} PUBLIC ${TFTVM_COMPILE_FLAGS} ${TF_COMPILE_FLAGS})
+target_link_options(${OP_LIBRARY_NAME} PUBLIC ${TFTVM_LINK_FLAGS} ${TF_LINK_FLAGS})
+
diff --git a/build.sh b/build.sh
new file mode 100644
index 0000000..e66deaf
--- /dev/null
+++ b/build.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+
+mkdir -p build
+cd build
+cmake .. -DTVM_HOME=${TVM_HOME}
+make
diff --git a/cpp/tvm_dso_op_kernels.cc b/cpp/tvm_dso_op_kernels.cc
deleted file mode 100644
index 9340ca5..0000000
--- a/cpp/tvm_dso_op_kernels.cc
+++ /dev/null
@@ -1,130 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-#include <cuda_runtime.h>
-#include <dlpack/dlpack.h>
-#include <tvm/runtime/c_runtime_api.h>
-#include <tvm/runtime/module.h>
-#include <tvm/runtime/packed_func.h>
-#include <tvm/runtime/registry.h>
-#include "tensorflow/core/framework/op_kernel.h"
-
-using namespace tensorflow;
-
-typedef Eigen::ThreadPoolDevice CPUDevice;
-typedef Eigen::GpuDevice GPUDevice;
-
-template <typename DEVICE_TYPE>
-class TVMDSOOpTrait;
-
-template <>
-class TVMDSOOpTrait<CPUDevice> {
- public:
-  static const int device_type = kDLCPU;
-  static int device_id(OpKernelContext* context) {
-    return 0;
-  }
-};
-
-template <>
-class TVMDSOOpTrait<GPUDevice> {
- public:
-  static const int device_type = kDLGPU;
-  static int device_id(OpKernelContext* context) {
-    auto device_base = context->device();
-    auto gpu_device_info = device_base->tensorflow_gpu_device_info();
-    return gpu_device_info->gpu_id;
-  }
-};
-
-template <typename DEVICE_TYPE>
-class TVMDSOOp : public OpKernel {
-
- private:
-  tvm::runtime::PackedFunc tvm_func;
-  string lib_path;
-  string func_name;
-  string output_dtype;
-  string output_shape;
-  string device;
-
-  void initAttributes(OpKernelConstruction* context) {
-    context->GetAttr("lib_path", &lib_path);
-    context->GetAttr("func_name", &func_name);
-    context->GetAttr("output_dtype", &output_dtype);
-    context->GetAttr("output_shpae", &output_shape);
-    context->GetAttr("device", &device);
-  }
-
- public:
-  explicit TVMDSOOp(OpKernelConstruction* context) : OpKernel(context) {
-
-    // Get attr
-    initAttributes(context);
-
-    // Load TVM function from dynamic library
-    tvm::runtime::Module mod_dylib = tvm::runtime::Module::LoadFromFile(lib_path);
-    LOG(INFO) << "Verify dynamic loading from " << lib_path << " device_type=" << TVMDSOOpTrait<DEVICE_TYPE>::device_type;
-    tvm_func = mod_dylib.GetFunction(func_name);
-    CHECK(tvm_func != nullptr);
-  }
-
-  void Compute(OpKernelContext* context) override {
-    // Grab the input tensor
-    const Tensor& input_tensor = context->input(0);
-    auto input = input_tensor.flat<float>();
-
-    DLTensor* x;
-    DLTensor* y;
-    int ndim = 1;
-    int dtype_code = kDLFloat;
-    int dtype_bits = 32;
-    int dtype_lanes = 1;
-    int device_type = TVMDSOOpTrait<DEVICE_TYPE>::device_type;
-    int device_id = TVMDSOOpTrait<DEVICE_TYPE>::device_id(context);
-    int64_t shape[1] = {10};
-    TVMArrayAlloc(shape, ndim, dtype_code, dtype_bits, dtype_lanes,
-                  device_type, device_id, &x);
-    TVMArrayAlloc(shape, ndim, dtype_code, dtype_bits, dtype_lanes,
-                  device_type, device_id, &y);
-
-    // Copy input tensor data to DLPack
-    x->data = const_cast<float*>(input.data());
-    const int input_size = input.size();
-
-    tvm_func(x, y);
-
-    // Create output tensor from DLPack
-    Tensor* output_tensor = NULL;
-    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
-                                                     &output_tensor));
-    auto output_flat = output_tensor->flat<float>();
-
-    // TODO: Use zero-copy instead of memory copy
-    if (device_type == kDLCPU) {
-      memcpy(output_flat.data(), y->data, input_size*4);
-    } else {
-      cudaMemcpy(output_flat.data(), y->data, input_size*4, cudaMemcpyDeviceToDevice);
-    }
-  }
-};
-
-
-REGISTER_KERNEL_BUILDER(Name("TvmDsoOp").Device(DEVICE_CPU), TVMDSOOp<CPUDevice>);
-REGISTER_KERNEL_BUILDER(Name("TvmDsoOp").Device(DEVICE_GPU), TVMDSOOp<GPUDevice>);
diff --git a/python/tf_op/module.py b/python/tf_op/module.py
index fa0a35f..fd32eec 100644
--- a/python/tf_op/module.py
+++ b/python/tf_op/module.py
@@ -41,7 +41,7 @@ def __init__(self, lib_path, func_name, output_dtype, output_shape, device):
         self.tvm_dso_op = tvm_dso_op.tvm_dso_op
 
     def apply(self, *params):
-        return self.tvm_dso_op(params, lib_path=self.lib_path, func_name=self.func_name, output_dtype=self.output_dtype, output_shape=self.output_shape, device=self.device)
+        return self.tvm_dso_op(*params, lib_path=self.lib_path, func_name=self.func_name, output_dtype=self.output_dtype, output_shape=self.output_shape, device=self.device)
 
     def __call__(self, *params):
-        return self.apply(params)
+        return self.apply(*params)
diff --git a/cpp/build.sh b/src/build.sh
similarity index 100%
rename from cpp/build.sh
rename to src/build.sh
diff --git a/src/tvm_dso_op_kernels.cc b/src/tvm_dso_op_kernels.cc
new file mode 100644
index 0000000..b6d8957
--- /dev/null
+++ b/src/tvm_dso_op_kernels.cc
@@ -0,0 +1,235 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <cstring>
+#include <cuda_runtime.h>
+#include <dlpack/dlpack.h>
+#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/module.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
+#include "tensorflow/core/framework/op_kernel.h"
+
+
+using namespace tensorflow;
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+template <typename DEVICE_TYPE>
+class TVMDSOOpTrait;
+
+
+// Wraps a TF tensor buffer; when the original buffer is not 64-byte
+// aligned, data is staged through a temporary aligned buffer.
+class TensorAsBuf {
+ public:
+  Tensor inline_tensor;
+  Tensor* tensor;
+
+  size_t size;
+  size_t offset;
+
+  int device_type;
+
+  char* origin_buf;
+  char* buf;
+
+  void CopyToOrigin() {
+    if (buf == origin_buf) {
+      return;
+    }
+    if (device_type == kDLCPU) {
+      memcpy(origin_buf, buf + offset, size);
+    } else {
+      cudaMemcpy(origin_buf, buf + offset, size, cudaMemcpyDeviceToDevice);
+    }
+  }
+
+  void CopyFromOrigin() {
+    if (buf == origin_buf) {
+      return;
+    }
+    if (device_type == kDLCPU) {
+      memcpy(buf + offset, origin_buf, size);
+    } else {
+      cudaMemcpy(buf + offset, origin_buf, size, cudaMemcpyDeviceToDevice);
+    }
+  }
+};
+
+
+int GetDLPackDtype(const Tensor& tf_tensor, DLDataType* res) {
+  auto dtype = tf_tensor.dtype();
+  if (dtype == DT_FLOAT) {
+    res->code = kDLFloat;
+    res->bits = 32;
+    res->lanes = 1;
+  } else {
+    return -1;
+  }
+  return 0;
+}
+
+
+// TVM expects 64-byte-aligned data; if the TF buffer is unaligned,
+// allocate a padded temporary tensor and record the aligned offset.
+void EnsureAlignment(OpKernelContext* ctx, const Tensor& tensor, TensorAsBuf* out) {
+  char* buf = (char*) tensor.tensor_data().data();
+  out->origin_buf = buf;
+  out->size = tensor.TotalBytes();
+
+  int alignment = 64;
+  char* aligned = (char*)(((uint64_t)buf + alignment - 1) & (~(alignment - 1)));
+  if (buf == aligned) {
+    out->tensor = const_cast<Tensor*>(&tensor);
+    out->buf = buf;
+    out->offset = 0;
+  } else {
+    TensorShape buf_shape;
+    int64 dims[1] = { (int64)(tensor.TotalBytes() + alignment) };
+    TensorShapeUtils::MakeShape(dims, 1, &buf_shape);
+
+    out->tensor = &out->inline_tensor;
+    ctx->allocate_temp(tensor.dtype(), buf_shape, out->tensor);
+
+    buf = (char*)(out->tensor->tensor_data().data());
+    char* buf_aligned = (char*)(((uint64_t)buf + alignment - 1) & (~(alignment - 1)));
+    out->buf = buf;
+    out->offset = buf_aligned - buf;
+  }
+}
+
+
+int MakeDLTensor(const TensorAsBuf& src, const DLContext& ctx, int64_t* tf_shape, DLTensor* out) {
+  DLDataType dlpack_type;
+  const Tensor& tensor = *src.tensor;
+
+  int status = GetDLPackDtype(tensor, &dlpack_type);
+  if (status != 0) {
+    return status;
+  }
+  out->ctx = ctx;
+  out->ndim = tensor.shape().dims();
+  out->shape = tf_shape;
+  out->strides = NULL;
+  out->byte_offset = 0;
+  out->dtype = dlpack_type;
+  out->data = src.buf + src.offset;
+  return 0;
+}
+
+
+template <>
+class TVMDSOOpTrait<CPUDevice> {
+ public:
+  static const int device_type = kDLCPU;
+
+  static int device_id(OpKernelContext* context) {
+    return 0;
+  }
+};
+
+
+template <>
+class TVMDSOOpTrait<GPUDevice> {
+ public:
+  static const int device_type = kDLGPU;
+
+  static int device_id(OpKernelContext* context) {
+    auto device_base = context->device();
+    auto gpu_device_info = device_base->tensorflow_gpu_device_info();
+    return gpu_device_info->gpu_id;
+  }
+};
+
+
+template <typename DEVICE_TYPE>
+class TVMDSOOp : public OpKernel {
+
+ private:
+  tvm::runtime::PackedFunc tvm_func;
+  string lib_path;
+  string func_name;
+  string output_dtype;
+  string output_shape;
+  string device;
+
+  void initAttributes(OpKernelConstruction* context) {
+    context->GetAttr("lib_path", &lib_path);
+    context->GetAttr("func_name", &func_name);
+    context->GetAttr("output_dtype", &output_dtype);
+    context->GetAttr("output_shape", &output_shape);
+    context->GetAttr("device", &device);
+  }
+
+ public:
+  explicit TVMDSOOp(OpKernelConstruction* context) : OpKernel(context) {
+    // Get attr
+    initAttributes(context);
+
+    // Load TVM function from dynamic library
+    tvm::runtime::Module mod_dylib = tvm::runtime::Module::LoadFromFile(lib_path);
+    LOG(INFO) << "Verify dynamic loading from " << lib_path
+              << " device_type=" << TVMDSOOpTrait<DEVICE_TYPE>::device_type;
+    tvm_func = mod_dylib.GetFunction(func_name);
+    CHECK(tvm_func != nullptr);
+  }
+
+  void Compute(OpKernelContext* context) override {
+    // Grab the input tensor
+    auto input_tensor = context->input(0);
+    auto input_shape_buf = input_tensor.shape().dim_sizes();
+    auto input_shape_ptr = (int64_t*) input_shape_buf.data();
+
+    // Allocate output tensor
+    Tensor* output_tensor;
+    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), &output_tensor));
+    auto output_shape_buf = output_tensor->shape().dim_sizes();
+    auto output_shape_ptr = (int64_t*) output_shape_buf.data();
+
+    int device_id = TVMDSOOpTrait<DEVICE_TYPE>::device_id(context);
+    int device_type = TVMDSOOpTrait<DEVICE_TYPE>::device_type;
+
+    DLContext dl_ctx = { DLDeviceType(device_type), device_id };
+
+    DLTensor dl_input;
+    TensorAsBuf input;
+    input.device_type = device_type;
+    EnsureAlignment(context, input_tensor, &input);
+
+    int status = MakeDLTensor(input, dl_ctx, input_shape_ptr, &dl_input);
+    OP_REQUIRES(context, status == 0,
+                Status(error::INTERNAL, "Failed to create DLPack tensor for input"));
+
+    DLTensor dl_output;
+    TensorAsBuf output;
+    output.device_type = device_type;
+    EnsureAlignment(context, *output_tensor, &output);
+
+    status = MakeDLTensor(output, dl_ctx, output_shape_ptr, &dl_output);
+    OP_REQUIRES(context, status == 0,
+                Status(error::INTERNAL, "Failed to create DLPack tensor for output"));
+
+    input.CopyFromOrigin();
+
+    tvm_func(&dl_input, &dl_output);
+
+    output.CopyToOrigin();
+  }
+};
+
+
+REGISTER_KERNEL_BUILDER(Name("TvmDsoOp").Device(DEVICE_CPU), TVMDSOOp<CPUDevice>);
+REGISTER_KERNEL_BUILDER(Name("TvmDsoOp").Device(DEVICE_GPU), TVMDSOOp<GPUDevice>);
diff --git a/cpp/tvm_dso_ops.cc b/src/tvm_dso_ops.cc
similarity index 100%
rename from cpp/tvm_dso_ops.cc
rename to src/tvm_dso_ops.cc
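
For reference, a minimal sketch of how the resulting op is meant to be driven from Python (not part of the patch). The wrapper class name `Module` and the library `tvm_addone_dll.so` exporting `addone` are assumptions for illustration; only the `__init__`/`apply`/`__call__` signatures are visible in the `python/tf_op/module.py` hunk above. Note that the kernel as written supports only `DT_FLOAT` inputs and always allocates the output with the input's shape.

    # Hypothetical usage sketch; `Module` and `tvm_addone_dll.so` are assumed names.
    import tensorflow as tf
    from tf_op.module import Module

    # lib_path points at a TVM-compiled shared library exporting `addone`
    addone = Module(lib_path="tvm_addone_dll.so", func_name="addone",
                    output_dtype="float", output_shape=[8], device="cpu")

    x = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
    y = addone(x)  # __call__ -> apply() -> the registered TvmDsoOp kernel

    with tf.Session() as sess:  # TF 1.x graph mode, matching this codebase
        print(sess.run(y))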