Add support for FPGA devices
    Also add a vect_add op for the FPGA device; the following code can be
    used to test it:
        import tensorflow as tf
        with tf.device("/fpga:0"):
            result = tf.user_ops.vect_add(4, 0)
            #result = tf.user_ops.vect_add([4, 3], [2, 1])
        sess = tf.Session(config=tf.ConfigProto(log_device_placement=True))
        print(sess.run(result))

    +--------------------+
    |       python       |
    |                    |
    +--------------------+
    +--------------------+
    |                    |
    |        c/c++       |
    +--------------------+
    +--------------------+
    | opkernel(VectAdd)  |
    +--------------------+   tensorflow
    +-------------------------------------+
    +--------------------+   FPGA
    | API(fpga_mag.cc/   |
    |  calc_vector_add)  |
    +--------------------+
    +--------------------+
    |       driver       |
    |                    |
    +--------------------+
    +--------+  +--------+
    |   FPGA |  | FPGA   |
    +--------+  +--------+
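
    For reference, a minimal sketch of what the VectAdd kernel could look
    like, assuming only the calc_vector_add entry point named in the
    diagram; its signature and every other name in this sketch are
    illustrative, not taken from the commit:

        #include "tensorflow/core/framework/op.h"
        #include "tensorflow/core/framework/op_kernel.h"
        #include "tensorflow/core/lib/core/errors.h"

        // Assumed FPGA runtime entry point (fpga_mag.cc in the diagram);
        // the exact signature is an assumption.
        extern void calc_vector_add(const int* a, const int* b, int* out, int n);

        namespace tensorflow {

        REGISTER_OP("VectAdd")
            .Input("a: int32")
            .Input("b: int32")
            .Output("sum: int32");

        class VectAddOp : public OpKernel {
         public:
          explicit VectAddOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

          void Compute(OpKernelContext* ctx) override {
            const Tensor& a = ctx->input(0);
            const Tensor& b = ctx->input(1);
            OP_REQUIRES(ctx, a.shape() == b.shape(),
                        errors::InvalidArgument("a and b must have the same shape"));
            Tensor* sum = nullptr;
            OP_REQUIRES_OK(ctx, ctx->allocate_output(0, a.shape(), &sum));
            // The FPGA device added below allocates from host memory, so the
            // flat buffers can be handed straight to the FPGA API.
            calc_vector_add(a.flat<int32>().data(), b.flat<int32>().data(),
                            sum->flat<int32>().data(),
                            static_cast<int>(a.NumElements()));
          }
        };

        // DEVICE_FPGA is the constant this commit adds in types.cc.
        REGISTER_KERNEL_BUILDER(Name("VectAdd").Device(DEVICE_FPGA), VectAddOp);

        }  // namespace tensorflow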
ZhaoHb committed Mar 22, 2019
1 parent 080d59b commit 7456ddb
Showing 21 changed files with 389 additions and 3 deletions.
1 change: 1 addition & 0 deletions configure.py
@@ -1393,6 +1393,7 @@ def main():
set_trisycl_include_dir(environ_cp)

set_action_env_var(environ_cp, 'TF_NEED_CUDA', 'CUDA', False)
set_action_env_var(environ_cp, 'TF_NEED_FPGA', 'FPGA', True)
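# Note: the trailing True makes FPGA support enabled by default, unlike
# TF_NEED_CUDA above, which defaults to off.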
if (environ_cp.get('TF_NEED_CUDA') == '1' and
'TF_CUDA_CONFIG_REPO' not in environ_cp):
set_tf_cuda_version(environ_cp)
4 changes: 4 additions & 0 deletions tensorflow/core/BUILD
@@ -654,6 +654,7 @@ tf_gen_op_libs(
# And one for all user ops
cc_library(
name = "user_ops_op_lib",
hdrs = ["user_ops/*.h"],
srcs = glob(["user_ops/**/*.cc"]),
copts = tf_copts(),
linkstatic = 1,
@@ -2047,6 +2048,7 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [
"common_runtime/stats_publisher_interface.h",
"common_runtime/step_stats_collector.h",
"common_runtime/threadpool_device.h",
"common_runtime/threadpiscine_device.h",
"common_runtime/visitable_allocator.h",
"graph/gradients.h",
"graph/quantize_training.h",
@@ -2088,7 +2090,9 @@ tf_cuda_library(
"common_runtime/stats_publisher_interface.cc",
"common_runtime/step_stats_collector.cc",
"common_runtime/threadpool_device.cc",
"common_runtime/threadpiscine_device.cc",
"common_runtime/threadpool_device_factory.cc",
"common_runtime/threadpiscine_device_factory.cc",
"graph/gradients.cc",
"graph/mkl_layout_pass.cc",
"graph/mkl_tfconversion_pass.cc",
10 changes: 9 additions & 1 deletion tensorflow/core/common_runtime/copy_tensor.cc
@@ -109,7 +109,9 @@ void CopyDeviceToHost(const Tensor* input, Allocator* cpu_allocator,
Allocator* out_allocator, StringPiece edge_name,
Device* src, Tensor* output,
DeviceContext* send_dev_context, StatusCallback done) {
printf("%s (%d) - <%s>\n",__FILE__,__LINE__,__FUNCTION__);
if (input->dtype() == DT_VARIANT) {
printf("%s (%d) - <%s>\n",__FILE__,__LINE__,__FUNCTION__);
Tensor copy(cpu_allocator, DT_VARIANT, input->shape());
auto* status_cb = new ReffedStatusCallback(std::move(done));
core::ScopedUnref status_cb_unref(status_cb);
@@ -134,6 +136,7 @@ void CopyDeviceToHost(const Tensor* input, Allocator* cpu_allocator,
if (status_cb->ok()) {
status_cb->Ref();
*to = Tensor(out_allocator, from.dtype(), from.shape());
printf("%s (%d) - <%s>\n",__FILE__,__LINE__,__FUNCTION__);
send_dev_context->CopyDeviceTensorToCPU(&from, edge_name, src, to,
wrapped_done_);
return Status::OK();
@@ -147,6 +150,7 @@ void CopyDeviceToHost(const Tensor* input, Allocator* cpu_allocator,
Variant* v_out = copy.flat<Variant>().data();
Status s_copy_init;
for (int64 i = 0; i < input->NumElements(); ++i) {
printf("%s (%d) - <%s>\n",__FILE__,__LINE__,__FUNCTION__);
s_copy_init = VariantDeviceCopy(
VariantDeviceCopyDirection::DEVICE_TO_HOST, v[i], &v_out[i], copier);
if (!s_copy_init.ok()) {
@@ -158,6 +162,7 @@ void CopyDeviceToHost(const Tensor* input, Allocator* cpu_allocator,
*output = std::move(copy);
}
} else {
printf("%s (%d) - <%s>\n",__FILE__,__LINE__,__FUNCTION__);
send_dev_context->CopyDeviceTensorToCPU(input, edge_name, src, output,
std::move(done));
}
@@ -244,7 +249,8 @@ void CopyTensor::ViaDMA(StringPiece edge_name, DeviceContext* send_dev_context,
src_alloc_attr.on_host() ? DEVICE_CPU : src->attributes().device_type());
const DeviceType dst_device_type(
dst_alloc_attr.on_host() ? DEVICE_CPU : dst->attributes().device_type());
const bool non_cpu_src = src_device_type != DeviceType(DEVICE_CPU);
const bool non_cpu_src = (src_device_type != DeviceType(DEVICE_CPU)) &&
                         (src_device_type != DeviceType(DEVICE_FPGA));
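// Grouping FPGA with the CPU here makes sense because the FPGA device
// allocates from host memory (its factory passes cpu_allocator()), so the
// plain host copy path can be used for FPGA-resident tensors.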
const bool non_cpu_dst = dst_device_type != DeviceType(DEVICE_CPU);

// TODO(phawkins): choose an allocator optimal for both the src and dst
@@ -301,6 +307,7 @@ void CopyTensor::ViaDMA(StringPiece edge_name, DeviceContext* send_dev_context,
std::move(delete_and_done_));
},
std::move(delete_and_done), std::placeholders::_1);
printf("%s (%d) - <%s>\n",__FILE__,__LINE__,__FUNCTION__);
CopyDeviceToHost(input, cpu_allocator, out_allocator, edge_name, src,
cpu_tensor, send_dev_context,
std::move(then_copy_to_other_device));
@@ -310,6 +317,7 @@ void CopyTensor::ViaDMA(StringPiece edge_name, DeviceContext* send_dev_context,
// E.g., gpu -> cpu
if (non_cpu_src && !non_cpu_dst) {
// Device to host copy.
printf("%s (%d) - <%s>\n",__FILE__,__LINE__,__FUNCTION__);
CopyDeviceToHost(input, cpu_allocator, out_allocator, edge_name, src,
output, send_dev_context, std::move(done));
return;
12 changes: 12 additions & 0 deletions tensorflow/core/common_runtime/device_factory.cc
@@ -104,6 +104,18 @@ Status DeviceFactory::AddDevices(const SessionOptions& options,
return errors::NotFound("No CPU devices are available in this process");
}

//------added by zhaohb-----
#if 0
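// Presumably disabled because the loop over device_factories() below
// already invokes every registered factory, so the FPGA factory from
// threadpiscine_device_factory.cc is picked up without this explicit call.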
init_size = devices->size();
auto fpga_factory = GetFactory("FPGA");
if (fpga_factory) {
TF_RETURN_IF_ERROR(fpga_factory->CreateDevices(options, name_prefix, devices));
}
if (devices->size() == init_size) {
return errors::NotFound("No FPGA devices are available in this process");
}
#endif

// Then the rest (including GPU).
mutex_lock l(*get_device_factory_lock());
for (auto& p : device_factories()) {
1 change: 1 addition & 0 deletions tensorflow/core/common_runtime/placer.cc
@@ -292,6 +292,7 @@ class ColocationGraph {
std::vector<Device*>** possible_devices) {
*possible_devices = nullptr;
const int node_root = FindRoot(node->id());
printf("node_root: %d\n", node_root);
if (!members_[node_root].possible_devices.empty()) {
*possible_devices = &members_[node_root].possible_devices;
return Status::OK();
78 changes: 78 additions & 0 deletions tensorflow/core/common_runtime/threadpiscine_device.cc
@@ -0,0 +1,78 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/threadpiscine_device.h"

#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/allocator_registry.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.pb_text.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/types.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {

ThreadPiscineDevice::ThreadPiscineDevice(const SessionOptions& options,
const string& name, Bytes memory_limit,
const DeviceLocality& locality,
Allocator* allocator)
: LocalDevice(options, Device::BuildDeviceAttributes(
name, DEVICE_FPGA, memory_limit, locality)),
allocator_(allocator) {}

ThreadPiscineDevice::~ThreadPiscineDevice() {}

void ThreadPiscineDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
// When TraceMe profiling is off (which is the default), the
// following TraceMe constructor is simply a conditional test of
// false value. Measurements show that its overhead is negligible.
port::Tracing::TraceMe trace_me(op_kernel->name(), op_kernel->type_string(),
op_kernel->IsExpensive());
if (port::Tracing::IsActive()) {
// TODO(pbar) We really need a useful identifier of the graph node.
const uint64 id = Hash64(op_kernel->name());
port::Tracing::ScopedActivity region(port::Tracing::EventCategory::kCompute,
id);
op_kernel->Compute(context);
} else {
op_kernel->Compute(context);
}
}

Allocator* ThreadPiscineDevice::GetAllocator(AllocatorAttributes attr) {
return allocator_;
}

Status ThreadPiscineDevice::MakeTensorFromProto(
const TensorProto& tensor_proto, const AllocatorAttributes alloc_attrs,
Tensor* tensor) {
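// Note: parsing into cpu_allocator() memory is consistent with
// GetAllocator() above, since the factory hands this device the CPU
// allocator and FPGA tensors therefore live in host memory.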
if (tensor_proto.dtype() > 0 && tensor_proto.dtype() <= DataType_MAX) {
Tensor parsed(tensor_proto.dtype());
if (parsed.FromProto(cpu_allocator(), tensor_proto)) {
*tensor = std::move(parsed);
return Status::OK();
}
}
return errors::InvalidArgument("Cannot parse tensor from proto: ",
ProtoDebugString(tensor_proto));
}

} // namespace tensorflow
46 changes: 46 additions & 0 deletions tensorflow/core/common_runtime/threadpiscine_device.h
@@ -0,0 +1,46 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMMON_RUNTIME_THREADPISCINE_DEVICE_H_
#define TENSORFLOW_COMMON_RUNTIME_THREADPISCINE_DEVICE_H_

#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/local_device.h"

namespace tensorflow {

// FPGA device implementation, modeled on the CPU ThreadPoolDevice.
class ThreadPiscineDevice : public LocalDevice {
public:
ThreadPiscineDevice(const SessionOptions& options, const string& name,
Bytes memory_limit, const DeviceLocality& locality,
Allocator* allocator);
~ThreadPiscineDevice() override;

void Compute(OpKernel* op_kernel, OpKernelContext* context) override;
Allocator* GetAllocator(AllocatorAttributes attr) override;
Status MakeTensorFromProto(const TensorProto& tensor_proto,
const AllocatorAttributes alloc_attrs,
Tensor* tensor) override;

Status Sync() override { return Status::OK(); }

private:
Allocator* allocator_; // Not owned
};

} // namespace tensorflow

#endif // TENSORFLOW_COMMON_RUNTIME_THREADPISCINE_DEVICE_H_
50 changes: 50 additions & 0 deletions tensorflow/core/common_runtime/threadpiscine_device_factory.cc
@@ -0,0 +1,50 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Register a factory that provides FPGA devices.
#include "tensorflow/core/common_runtime/threadpiscine_device.h"

#include <vector>
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {

// TODO(zhifengc/tucker): Figure out the bytes of available RAM.
class ThreadPiscineDeviceFactory : public DeviceFactory {
public:
Status CreateDevices(const SessionOptions& options, const string& name_prefix,
std::vector<Device*>* devices) override {
// TODO: Figure out the number of FPGAs actually present instead of
// defaulting to a single device.
int n = 1;
auto iter = options.config.device_count().find("FPGA");
if (iter != options.config.device_count().end()) {
n = iter->second;
}
for (int i = 0; i < n; i++) {
string name = strings::StrCat(name_prefix, "/device:FPGA:", i);
devices->push_back(new ThreadPiscineDevice(
options, name, Bytes(256 << 20), DeviceLocality(), cpu_allocator()));
}

return Status::OK();
}
};

REGISTER_LOCAL_DEVICE_FACTORY("FPGA", ThreadPiscineDeviceFactory, 60);

} // namespace tensorflow
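
Since CreateDevices consults options.config.device_count(), the number of FPGA devices can be raised per session. A minimal sketch of the C++ side (the Python equivalent, usable in the test script at the top, would be tf.ConfigProto(device_count={'FPGA': 2})):

    #include "tensorflow/core/public/session_options.h"

    // Ask the factory above for two FPGA devices instead of the default one.
    tensorflow::SessionOptions FpgaSessionOptions() {
      tensorflow::SessionOptions options;
      (*options.config.mutable_device_count())["FPGA"] = 2;
      return options;
    }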
2 changes: 2 additions & 0 deletions tensorflow/core/framework/op_kernel.cc
@@ -1009,7 +1009,9 @@ string KernelsRegisteredForOp(StringPiece op_name) {
for (const auto& key_registration : *GlobalKernelRegistryTyped()) {
const KernelDef& kernel_def(key_registration.second.def);
if (kernel_def.op() == op_name) {
printf("----------zhaohb--kernel_def.op:%s\n", kernel_def.op().c_str());
strings::StrAppend(&ret, " device='", kernel_def.device_type(), "'");
printf("-----------zhaohb--ret:%s\n", ret.c_str());
if (!kernel_def.label().empty()) {
strings::StrAppend(&ret, "; label='", kernel_def.label(), "'");
}
1 change: 1 addition & 0 deletions tensorflow/core/framework/types.cc
@@ -37,6 +37,7 @@ std::ostream& operator<<(std::ostream& os, const DeviceType& d) {

const char* const DEVICE_CPU = "CPU";
const char* const DEVICE_GPU = "GPU";
const char* const DEVICE_FPGA = "FPGA";
const char* const DEVICE_SYCL = "SYCL";

const std::string DeviceName<Eigen::ThreadPoolDevice>::value = DEVICE_CPU;
1 change: 1 addition & 0 deletions tensorflow/core/framework/types.h
@@ -72,6 +72,7 @@ std::ostream& operator<<(std::ostream& os, const DeviceType& d);
// Convenient constants that can be passed to a DeviceType constructor
TF_EXPORT extern const char* const DEVICE_CPU; // "CPU"
TF_EXPORT extern const char* const DEVICE_GPU; // "GPU"
TF_EXPORT extern const char* const DEVICE_FPGA; // "FPGA"
TF_EXPORT extern const char* const DEVICE_SYCL; // "SYCL"

template <typename Device>
1 change: 1 addition & 0 deletions tensorflow/core/kernels/constant_op.cc
@@ -64,6 +64,7 @@ void ConstantOp::Compute(OpKernelContext* ctx) {
ConstantOp::~ConstantOp() {}

REGISTER_KERNEL_BUILDER(Name("Const").Device(DEVICE_CPU), ConstantOp);
REGISTER_KERNEL_BUILDER(Name("Const").Device(DEVICE_FPGA), ConstantOp);

#if GOOGLE_CUDA
#define REGISTER_KERNEL(D, TYPE) \
1 change: 1 addition & 0 deletions tensorflow/core/kernels/no_op.cc
@@ -19,6 +19,7 @@ namespace tensorflow {

REGISTER_KERNEL_BUILDER(Name("NoOp").Device(DEVICE_CPU), NoOp);
REGISTER_KERNEL_BUILDER(Name("NoOp").Device(DEVICE_GPU), NoOp);
REGISTER_KERNEL_BUILDER(Name("NoOp").Device(DEVICE_FPGA), NoOp);

#if TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name("NoOp").Device(DEVICE_SYCL), NoOp);
4 changes: 4 additions & 0 deletions tensorflow/core/kernels/sendrecv_ops.cc
@@ -109,6 +109,7 @@ void SendOp::Compute(OpKernelContext* ctx) {

REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE_CPU), SendOp);
REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE_GPU), SendOp);
REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE_FPGA), SendOp);

#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE_SYCL), SendOp);
@@ -117,6 +118,7 @@ REGISTER_KERNEL_BUILDER(
#endif // TENSORFLOW_USE_SYCL

REGISTER_KERNEL_BUILDER(Name("_HostSend").Device(DEVICE_CPU), SendOp);
REGISTER_KERNEL_BUILDER(Name("_HostSend").Device(DEVICE_FPGA), SendOp);
REGISTER_KERNEL_BUILDER(
Name("_HostSend").Device(DEVICE_GPU).HostMemory("tensor"), SendOp);

@@ -195,12 +197,14 @@ void RecvOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {

REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE_CPU), RecvOp);
REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE_GPU), RecvOp);
REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE_FPGA), RecvOp);

#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE_SYCL), RecvOp);
#endif // TENSORFLOW_USE_SYCL

REGISTER_KERNEL_BUILDER(Name("_HostRecv").Device(DEVICE_CPU), RecvOp);
REGISTER_KERNEL_BUILDER(Name("_HostRecv").Device(DEVICE_FPGA), RecvOp);
REGISTER_KERNEL_BUILDER(
Name("_HostRecv").Device(DEVICE_GPU).HostMemory("tensor"), RecvOp);
