[OpenCL] Implementation improvements (#9117)
* OpenCL Improvements

* Registers Scatter and ScatterNd Ops for SYCL

* Registers Stack op for SYCL

* Fixes No sycl buffer found error for debug ops

* Registers MatMul and Transpose Ops to SYCL device for double

* Extends analyzer_cli_test.py test to cover SYCL

* Fixes Transpose Op for double when on SYCL

* Bumps Eigen version to fix double precision issue on SYCL

* Extends SessionDebugTestBase to cover SYCL

* Register SYCL implementations for random ops

* Avoid functions that might not be defined on SYCL device (#51)

* Avoid functions that might not be defined on SYCL device

* Simplify by using Eigen math functions

* OpenCL improvements

 - Bumps Eigen Version
 - Refactors Ops registration
 - Introduces a workaround for the Const Op, needed because CUDA uses
   pointers while OpenCL uses buffers/accessors
 - Extends memory types to cover DEVICE_SYCL as well
 - Introduces the GetSYCLDevice() method, which returns a list of supported
   devices with the GPU device having the highest priority (blacklisted
   devices are excluded)
 - ::internal::Transpose -> tensorflow::internal::Transpose in order to
   avoid a reported compilation error
 - Re-introduces the fix for a buggy string replacement (-c -> --include)
   that caused many compilation warnings
 - Adds sycl_runtime to Bazel's ARRAY_DEPS
 - Replicates TF_CALL_GPU_PROXY_TYPES for SYCL

* [OpenCL] Fixes an issue caused by switch to aligned allocator for sycl buffer (#53)

* [Build] Use gcc/g++ as a host compiler to avoid #8394 (#54)

* [OpenCL] Fixes Scatter Op

* Fix testSimple and testConst in stack_op_test (#3)

* Fix testSimple and testConst in stack_op_test

* Create a specialisation of DoParallelConcatUpdate for SyclDevice and
register it

* Guard all code in TENSORFLOW_USE_SYCL

* Do not use sycl device for int32

* Registration of the Sycl version is now looking like the one for the GPU

* Remove added empty line

* Register batch normalization kernels for OpenCL (#61)

* [OpenCL] RandomGamma has no GPU friendly implementation (#57)

* [OpenCL] Compatibility fixes for TensorFlow 1.1.0-rc1

* [OpenCL] Implements BatchMatmul Op for SYCL

* Lowercase the device name when GPU or SYCL returned

* [OpenCL] kernel_estimator_test.py assertEqual-> assertAlmostEqual due to floating point representation on the device

* [Eigen] Version bump

* GPU device name string manipulation is not needed anymore

* [OpenCL] Adds SYCL to device backwards compatibility

* [OpenCL] Extends core_rnn_test.py to run for SYCL device

* [OpenCL] Minor optimizations for build script

* [OpenCL] Enables skip folder list in build script

* [OpenCL] Fixes ApplyAdamOp for Sycl device

* [OpenCL] SYCL device improvements

* [OpenCL] Fixes debug_ops's SEGFAULT for SYCL device

* [Build] Adds hexagon to skipped folders list

* [OpenCL] Removes EnterLameDuckMode from SYCL device and allocator

* [OpenCL] Registers Unique Op for SYCL device

* [OpenCL][Temporary] Disables tests for SYCL target due to features not being implemented yet

  Tests affected:
    - tensorflow/contrib/memory_stats/python/kernel_tests/memory_stats_ops_test.py
    - tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
    - tensorflow/python/kernel_tests/conv_ops_test.py
    - tensorflow/python/kernel_tests/depthwise_conv_op_test.py
    - tensorflow/python/kernel_tests/pooling_ops_3d_test.py
    - tensorflow/python/kernel_tests/pooling_ops_test.py
    - tensorflow/python/kernel_tests/scatter_nd_ops_test.py
    - tensorflow/python/training/adam_test.py
    - tensorflow/python/training/localhost_cluster_performance_test.py
    - tensorflow/python/training/training_ops_test.py

* [OpenCL][Temporary] Disables failing tests for SYCL in order to establish regression baseline

  Tests affected:
    - tensorflow/python/debug/cli/analyzer_cli_test.py
    - tensorflow/python/debug/lib/session_debug_testlib.py
    - tensorflow/python/debug/lib/stepper_test.py
    - tensorflow/python/kernel_tests/unstack_op_test.py
    - tensorflow/python/ops/image_ops_test.py

* [OpenCL] Take options.config.device_count() into consideration

* [OpenCL] Fixes compilation warning

* [OpenCL] device:SYCL:0 -> sycl:0

* [OpenCL] Removes unwanted flags in building script

Removes flags given to computecpp that enable SIMD instructions
Removes duplicate flags

* bool -> const bool

* [OpenCL] sycl in test_util.gpu_device_name() -> is_sycl_enabled()

* [OpenCL][Temporary] Disables failing tests for SYCL in order to establish regression baseline

  Test affected:
    - tensorflow/contrib/stateless/python/kernel_tests/stateless_random_ops_test.py

* Imports test_util from tensorflow.python.framework

* [OpenCL] Fixes formatting in Python code

* [OpenCL] Extends session_test.py to cover SYCL device

* [OpenCL] Cleans singleton class

* [OpenCL] Keeping CUDA happy

* [OpenCL][Temporary] Disables failing tests for SYCL in order to establish regression baseline

  Test affected:
   - tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
   - tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py

* Added support for building with SYCL on ARM.

* Acts on the review feedback from:
 - #9117 (comment)
 - #9117 (comment)

* [OpenCL] Fixes scatter_nd_op_test

* Fixes auto-merge mistake

* [OpenCL] struct SyclDevice -> class SyclDevice

* Revert "[OpenCL] struct SyclDevice -> class SyclDevice"

This reverts commit addd433.

* [OpenCL] Reverting refactoring commit.

  As requested in the review #9117 (comment)
  This change set will be re-introduced in smaller chunks.

* Revert "[OpenCL] device:SYCL:0 -> sycl:0"

This reverts commit cf16e60.

* Revert "[OpenCL] Adds SYCL to device backwards compatibility"

This reverts commit b8401b5.

* Acts on the feedback from #9117 (comment)

* control_flow_ops_py_test.py expects device name to be lower cased

* Acts on the feedback from #9117 (comment)

* Removes debug print

* Removes not needed partial specialisation

* [OpenCL] Registers ScatterNdFunctor for SYCL device

* [OpenCL] Make it compile

* [OpenCL] Follow gpu_device changes

* [OpenCL] Adds cxx_builtin_include_directory for python lib

  Fixes Bazel's undeclared-inclusion errors that appeared after the
  merge with TensorFlow upstream

* [OpenCL] Fixes Constant Op

* [OpenCL] gXX-4.8 -> gXX

* [OpenCL] Removes -D_GLIBCXX_USE_CXX11_ABI=0 as it breaks default compiler setup for Ubuntu 16.04

* Revert "[OpenCL] kernel_estimator_test.py assertEqual-> assertAlmostEqual due to floating point representation on the device"

This reverts commit 06c50c0.

* [OpenCL] CPU allocator is a singleton we should not delete it
Luke Iwanski authored and benoitsteiner committed May 31, 2017
1 parent a365082 commit fe589d9
Showing 43 changed files with 1,505 additions and 242 deletions.
4 changes: 2 additions & 2 deletions configure
@@ -686,7 +686,7 @@ if [ "$TF_NEED_OPENCL" == "1" ]; then
 while true; do
   fromuser=""
   if [ -z "$HOST_CXX_COMPILER" ]; then
-    default_cxx_host_compiler=$(which clang++-3.6 || true)
+    default_cxx_host_compiler=$(which g++ || true)
     read -p "Please specify which C++ compiler should be used as the host C++ compiler. [Default is $default_cxx_host_compiler]: " HOST_CXX_COMPILER
     fromuser="1"
     if [ -z "$HOST_CXX_COMPILER" ]; then
@@ -710,7 +710,7 @@ done
 while true; do
   fromuser=""
   if [ -z "$HOST_C_COMPILER" ]; then
-    default_c_host_compiler=$(which clang-3.6 || true)
+    default_c_host_compiler=$(which gcc || true)
     read -p "Please specify which C compiler should be used as the host C compiler. [Default is $default_c_host_compiler]: " HOST_C_COMPILER
     fromuser="1"
     if [ -z "$HOST_C_COMPILER" ]; then
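The two configure hunks switch the default host compilers from clang-3.6/clang++-3.6 to gcc/g++, keeping the `$(which ... || true)` idiom so that a failed lookup leaves the default empty. A rough Python sketch of that lookup, for illustration only: `default_host_compiler` is a hypothetical helper and is not part of the configure script.

```python
import shutil


def default_host_compiler(name):
    """Return the full path of `name` if it is on PATH, else an empty string.

    Approximates the shell idiom `$(which NAME || true)`: a failed lookup
    produces an empty default instead of an error status.
    """
    return shutil.which(name) or ""


if __name__ == "__main__":
    # Defaults used by the updated script: g++ for C++ and gcc for C.
    print("C++ host compiler default:", default_host_compiler("g++"))
    print("C host compiler default:", default_host_compiler("gcc"))
```

If the result is empty, the prompt in the script simply shows an empty default and the user has to supply a compiler path.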
19 changes: 12 additions & 7 deletions tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
@@ -42,7 +42,7 @@
 from tensorflow.python.platform import test
 from tensorflow.python.platform import tf_logging
 from tensorflow.python.util import nest
-
+from tensorflow.python.framework import test_util

 class Plus1RNNCell(rnn_lib.RNNCell):
   """RNN Cell generating (output, new_state) = (input + 1, state + 1)."""
@@ -2206,9 +2206,10 @@ def testRNNOnCPUCellOnGPU(self):
       return # Test requires access to a GPU

     run_metadata = self._execute_rnn_on(
-        rnn_device="/cpu:0", cell_device="/gpu:0")
+        rnn_device="/cpu:0", cell_device=test_util.gpu_device_name())
     step_stats = run_metadata.step_stats
-    ix = 0 if "gpu" in step_stats.dev_stats[0].device else 1
+    ix = 0 if (("gpu" in step_stats.dev_stats[0].device) or
+               ("sycl" in step_stats.dev_stats[0].device)) else 1
     gpu_stats = step_stats.dev_stats[ix].node_stats
     cpu_stats = step_stats.dev_stats[1 - ix].node_stats

@@ -2230,9 +2231,11 @@ def testRNNOnCPUCellOnCPU(self):
       return # Test requires access to a GPU

     run_metadata = self._execute_rnn_on(
-        rnn_device="/cpu:0", cell_device="/cpu:0", input_device="/gpu:0")
+        rnn_device="/cpu:0", cell_device="/cpu:0",
+        input_device=test_util.gpu_device_name())
     step_stats = run_metadata.step_stats
-    ix = 0 if "gpu" in step_stats.dev_stats[0].device else 1
+    ix = 0 if (("gpu" in step_stats.dev_stats[0].device) or
+               ("sycl" in step_stats.dev_stats[0].device)) else 1
     gpu_stats = step_stats.dev_stats[ix].node_stats
     cpu_stats = step_stats.dev_stats[1 - ix].node_stats

@@ -2247,9 +2250,11 @@ def testInputOnGPUCellNotDeclared(self):
     if not test.is_gpu_available():
       return # Test requires access to a GPU

-    run_metadata = self._execute_rnn_on(input_device="/gpu:0")
+    run_metadata = self._execute_rnn_on(
+        input_device=test_util.gpu_device_name())
     step_stats = run_metadata.step_stats
-    ix = 0 if "gpu" in step_stats.dev_stats[0].device else 1
+    ix = 0 if (("gpu" in step_stats.dev_stats[0].device) or
+               ("sycl" in step_stats.dev_stats[0].device)) else 1
     gpu_stats = step_stats.dev_stats[ix].node_stats
     cpu_stats = step_stats.dev_stats[1 - ix].node_stats

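All three tests in core_rnn_test.py repeat the same expression to decide which of the two `dev_stats` entries belongs to the accelerator. A minimal sketch of that selection as a standalone function follows. `accelerator_index` is a hypothetical name introduced here for illustration; it does not exist in the test file or in TensorFlow, and it takes plain device-name strings rather than `step_stats` objects.

```python
def accelerator_index(device_names, markers=("gpu", "sycl")):
    """Return 0 if the first device name contains any marker, else 1.

    Same rule as the test expression: with exactly two devices (one CPU,
    one accelerator), the accelerator is entry 0 when its name contains
    "gpu" or "sycl", and entry 1 otherwise.
    """
    first = device_names[0]
    return 0 if any(marker in first for marker in markers) else 1


if __name__ == "__main__":
    devices = ["/job:localhost/replica:0/task:0/cpu:0",
               "/job:localhost/replica:0/task:0/sycl:0"]
    ix = accelerator_index(devices)
    print("accelerator stats at index", ix, "and CPU stats at index", 1 - ix)
```

As in the diff, the substring match is case-sensitive, so a name such as `/device:SYCL:0` would not match unless it is lowercased first.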
1 change: 1 addition & 0 deletions tensorflow/core/BUILD
@@ -1836,6 +1836,7 @@ cc_library(
     hdrs = if_not_windows([
         "common_runtime/sycl/sycl_allocator.h",
         "common_runtime/sycl/sycl_device.h",
+        "common_runtime/sycl/sycl_util.h",
         "common_runtime/sycl/sycl_device_context.h",
     ]),
     copts = tf_copts(),
3 changes: 1 addition & 2 deletions tensorflow/core/common_runtime/direct_session_test.cc
@@ -877,8 +877,6 @@ class BlockingOp : public OpKernel {
 REGISTER_KERNEL_BUILDER(Name("BlockingOp").Device(DEVICE_CPU), BlockingOp);
 REGISTER_OP("BlockingOp").Input("x: float").Output("y: float").Doc("");

-REGISTER_KERNEL_BUILDER(Name("BlockingOp").Device(DEVICE_SYCL), BlockingOp);
-
 static void TestSessionInterOpThreadsImpl(bool use_function_lib) {
   FunctionDefLibrary library_graph_def;
   if (use_function_lib) {
@@ -916,6 +914,7 @@ static void TestSessionInterOpThreadsImpl(bool use_function_lib) {
       ->set_opt_level(OptimizerOptions_Level_L0);
   (*options.config.mutable_device_count())["CPU"] = 2;
   (*options.config.mutable_device_count())["GPU"] = 0;
+  (*options.config.mutable_device_count())["SYCL"] = 0;

   options.config.add_session_inter_op_thread_pool();
   auto* p = options.config.add_session_inter_op_thread_pool();
@@ -155,10 +155,16 @@ static void TestHWAccelerator(bool enableHWTrace) {
   test::FillValues<float>(&x_tensor, {1, 1});
   Node* x = test::graph::Constant(&graph, x_tensor);
   x->set_assigned_device_name("/job:localhost/replica:0/task:0/gpu:0");
+#ifdef TENSORFLOW_USE_SYCL
+  x->set_assigned_device_name("/job:localhost/replica:0/task:0/device:SYCL:0");
+#endif // TENSORFLOW_USE_SYCL

   // y = A * x
   Node* y = test::graph::Matmul(&graph, a, x, false, false);
   y->set_assigned_device_name("/job:localhost/replica:0/task:0/gpu:0");
+#ifdef TENSORFLOW_USE_SYCL
+  y->set_assigned_device_name("/job:localhost/replica:0/task:0/device:SYCL:0");
+#endif // TENSORFLOW_USE_SYCL

   Node* y_neg = test::graph::Unary(&graph, "Neg", y);
   y_neg->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0");
@@ -169,6 +175,9 @@ static void TestHWAccelerator(bool enableHWTrace) {
   SessionOptions options;
   (*options.config.mutable_device_count())["CPU"] = 1;
   (*options.config.mutable_device_count())["GPU"] = 1;
+#ifdef TENSORFLOW_USE_SYCL
+  (*options.config.mutable_device_count())["SYCL"] = 1;
+#endif // TENSORFLOW_USE_SYCL
   options.config.set_allow_soft_placement(true);
   options.config.mutable_graph_options()->set_build_cost_model(1);
   std::unique_ptr<Session> session(NewSession(options));
6 changes: 3 additions & 3 deletions tensorflow/core/common_runtime/memory_types.cc
@@ -47,12 +47,12 @@ struct EndpointEq {
 static Status ProcessMemoryTypes(
     const DeviceType& device_type, const Graph* g,
     const std::function<Status(const Edge*, MemoryType, MemoryType)>& fn) {
-  if (device_type != DEVICE_GPU) {
-    // On non-GPU devices, HOST_MEMORY and DEVICE_MEMORY are always
+  if (device_type != DEVICE_GPU && device_type != DEVICE_SYCL ) {
+    // On non-GPU and non-SYCL devices, HOST_MEMORY and DEVICE_MEMORY are always
     // compatible.
     return Status::OK();
   }
-  // For GPU device, HOST_MEMORY and DEVICE_MEMORY is not
+  // For GPU and SYCL device, HOST_MEMORY and DEVICE_MEMORY is not
   // compatible. I.e., a conversion/transfer must be done.
   //
   // {node id, slot id} -> memory type.
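The memory_types.cc change widens the early-return test: only GPU and SYCL devices are treated as having host and device memory that are not interchangeable, while every other device type skips the check. A small Python sketch of that rule follows. The string constants and the `needs_memory_type_check` helper are assumptions made for illustration and are not TensorFlow APIs.

```python
# Hypothetical stand-ins for TensorFlow's DEVICE_* constants.
DEVICE_CPU = "CPU"
DEVICE_GPU = "GPU"
DEVICE_SYCL = "SYCL"

# Device types for which HOST_MEMORY and DEVICE_MEMORY are not compatible,
# so an edge that crosses between them needs a conversion/transfer.
_SPLIT_MEMORY_DEVICES = frozenset({DEVICE_GPU, DEVICE_SYCL})


def needs_memory_type_check(device_type):
    """Return True when memory types must be checked for this device type.

    Mirrors the condition in the diff: the function returns early (no check
    needed) unless the device is GPU or SYCL.
    """
    return device_type in _SPLIT_MEMORY_DEVICES


if __name__ == "__main__":
    for device in (DEVICE_CPU, DEVICE_GPU, DEVICE_SYCL):
        print(device, "needs check:", needs_memory_type_check(device))
```

Using a set keeps the rule in one place if further device types with separate memory are ever added, whereas the C++ code chains the comparisons inline.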
18 changes: 18 additions & 0 deletions tensorflow/core/common_runtime/memory_types_test.cc
@@ -34,6 +34,9 @@ TEST(MemoryTypeChecker, Int32OK) {
   // There is a kernel for adding two int32s on host memory.
   TF_EXPECT_OK(ValidateMemoryTypes(DEVICE_GPU, g));
 #endif // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+  TF_EXPECT_OK(ValidateMemoryTypes(DEVICE_SYCL, g));
+#endif // TENSORFLOW_USE_SYCL
   delete g;
 }

@@ -53,6 +56,15 @@ TEST(MemoryTypeChecker, Int32NotOk) {
   TF_EXPECT_OK(EnsureMemoryTypes(DEVICE_GPU, "/gpu:0", g));
   TF_EXPECT_OK(ValidateMemoryTypes(DEVICE_GPU, g));
 #endif // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+  // There is no kernel for casting int32/host memory to float/device
+  // memory.
+  EXPECT_TRUE(errors::IsInternal(ValidateMemoryTypes(DEVICE_SYCL, g)));
+
+  // But we can insert _HostSend/_HostRecv to ensure the invariant.
+  TF_EXPECT_OK(EnsureMemoryTypes(DEVICE_SYCL, "/device:SYCL:0", g));
+  TF_EXPECT_OK(ValidateMemoryTypes(DEVICE_SYCL, g));
+#endif // TENSORFLOW_USE_SYCL
   delete g;
 }

@@ -74,6 +86,12 @@ TEST(MemoryTypeChecker, MemoryTypeForOutput) {
   // int Switch's output on GPU has HOST_MEMORY constraint.
   EXPECT_EQ(memory_type, HOST_MEMORY);
 #endif // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+  auto si = test::graph::Switch(g, test::graph::Constant(g, vi), pred);
+  TF_EXPECT_OK(MemoryTypeForOutput(DEVICE_SYCL, g, si, 0, &memory_type));
+  // int Switch's output on GPU has HOST_MEMORY constraint.
+  EXPECT_EQ(memory_type, HOST_MEMORY);
+#endif // TENSORFLOW_USE_SYCL
   delete g;
 }

23 changes: 10 additions & 13 deletions tensorflow/core/common_runtime/sycl/sycl_allocator.cc
@@ -19,29 +19,26 @@ limitations under the License.

 namespace tensorflow {

-SYCLAllocator::~SYCLAllocator() {}
+SYCLAllocator::~SYCLAllocator() {
+  if(sycl_device_) {
+    delete sycl_device_;
+  }
+}

 string SYCLAllocator::Name() { return "device:SYCL"; }

 void *SYCLAllocator::AllocateRaw(size_t alignment, size_t num_bytes) {
-  assert(device_);
+  assert(sycl_device_);
   if (num_bytes == 0) {
-    return device_->allocate(1);
+    return sycl_device_->allocate(1);
   }
-  auto p = device_->allocate(num_bytes);
+  auto p = sycl_device_->allocate(num_bytes);
   return p;
 }

 void SYCLAllocator::DeallocateRaw(void *ptr) {
-  if (device_) {
-    device_->deallocate(ptr);
-  }
-}
-
-void SYCLAllocator::EnterLameDuckMode() {
-  if (device_) {
-    device_->deallocate_all();
-    device_ = nullptr;
+  if (sycl_device_) {
+    sycl_device_->deallocate(ptr);
   }
 }

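Both the old and the new `AllocateRaw` forward a zero-byte request to the device as a one-byte allocation, so the caller always gets a real allocation back. The Python sketch below illustrates only that size rule. `FakeDevice` and `allocate_raw` are hypothetical names invented for this example; they model the behaviour visible in the diff and are not a binding to Eigen or TensorFlow.

```python
class FakeDevice:
    """Hypothetical stand-in for the device object; records requested sizes."""

    def __init__(self):
        self.requests = []

    def allocate(self, num_bytes):
        self.requests.append(num_bytes)
        return bytearray(num_bytes)


def allocate_raw(device, num_bytes):
    """Allocate num_bytes on the device, rounding a zero-byte request up to 1."""
    assert device is not None
    if num_bytes == 0:
        # Same special case as in AllocateRaw: never ask the device for 0 bytes.
        return device.allocate(1)
    return device.allocate(num_bytes)


if __name__ == "__main__":
    device = FakeDevice()
    allocate_raw(device, 0)
    allocate_raw(device, 64)
    print("sizes requested from the device:", device.requests)
```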
10 changes: 6 additions & 4 deletions tensorflow/core/common_runtime/sycl/sycl_allocator.h
@@ -28,17 +28,19 @@ namespace tensorflow {

class SYCLAllocator : public Allocator {
public:
SYCLAllocator(Eigen::QueueInterface *device) : device_(device) {}
SYCLAllocator(Eigen::QueueInterface *queue) : sycl_device_(new Eigen::SyclDevice(queue)) {}
virtual ~SYCLAllocator() override;
string Name() override;
void *AllocateRaw(size_t alignment, size_t num_bytes) override;
void DeallocateRaw(void *ptr) override;

void EnterLameDuckMode();
virtual bool ShouldAllocateEmptyTensors() override final { return true; }

void Synchronize() { sycl_device_->synchronize(); }
bool Ok() { return sycl_device_->ok(); }
Eigen::SyclDevice* getSyclDevice() { return sycl_device_; }
private:
Eigen::QueueInterface *device_; // not owned
Eigen::SyclDevice *sycl_device_; // owned

TF_DISALLOW_COPY_AND_ASSIGN(SYCLAllocator);
};

58 changes: 19 additions & 39 deletions tensorflow/core/common_runtime/sycl/sycl_device.cc
@@ -22,50 +22,18 @@ limitations under the License.
 #include "tensorflow/core/platform/tracing.h"

 namespace tensorflow {
-
-static std::unordered_set<SYCLDevice *> live_devices;
-static bool first_time = true;
+std::mutex GSYCLInterface::mutex_;
+GSYCLInterface *GSYCLInterface::s_instance = 0;

 void ShutdownSycl() {
-  for (auto device : live_devices) {
-    device->EnterLameDuckMode();
-  }
-  live_devices.clear();
+  GSYCLInterface::Reset();
 }

 void SYCLDevice::RegisterDevice() {
-  if (first_time) {
-    first_time = false;
     atexit(ShutdownSycl);
-  }
-  live_devices.insert(this);
 }

-SYCLDevice::~SYCLDevice() {
-  device_context_->Unref();
-  sycl_allocator_->EnterLameDuckMode();
-  if (sycl_device_) {
-    sycl_device_->synchronize();
-    delete sycl_device_;
-  }
-  if (sycl_queue_) {
-    delete sycl_queue_;
-  }
-  live_devices.erase(this);
-}
-
-void SYCLDevice::EnterLameDuckMode() {
-  sycl_allocator_->EnterLameDuckMode();
-  if (sycl_device_) {
-    sycl_device_->synchronize();
-    delete sycl_device_;
-    sycl_device_ = nullptr;
-  }
-  if (sycl_queue_) {
-    delete sycl_queue_;
-    sycl_queue_ = nullptr;
-  }
-}
+SYCLDevice::~SYCLDevice() {}

 void SYCLDevice::Compute(OpKernel *op_kernel, OpKernelContext *context) {
   assert(context);
@@ -88,8 +56,12 @@ Allocator *SYCLDevice::GetAllocator(AllocatorAttributes attr) {
 Status SYCLDevice::MakeTensorFromProto(const TensorProto &tensor_proto,
                                        const AllocatorAttributes alloc_attrs,
                                        Tensor *tensor) {
+  AllocatorAttributes attr;
+  attr.set_on_host(true);
+  Allocator* host_alloc = GetAllocator(attr);
+
   Tensor parsed(tensor_proto.dtype());
-  if (!parsed.FromProto(cpu_allocator_, tensor_proto)) {
+  if (!parsed.FromProto(host_alloc, tensor_proto)) {
     return errors::InvalidArgument("Cannot parse tensor from proto: ",
                                    tensor_proto.DebugString());
   }
@@ -98,6 +70,14 @@ Status SYCLDevice::MakeTensorFromProto(const TensorProto &tensor_proto,
     *tensor = parsed;
   } else {
     Tensor copy(GetAllocator(alloc_attrs), parsed.dtype(), parsed.shape());
+
+    // If the tensor is not initialized, we likely ran out of memory.
+    if (!copy.IsInitialized()) {
+      return errors::ResourceExhausted(
+          "OOM when allocating tensor of shape ", parsed.shape().DebugString(),
+          " and type ", DataTypeString(parsed.dtype()));
+    }
+
     device_context_->CopyCPUTensorToDevice(
         &parsed, this, &copy, [&status](const Status &s) { status = s; });
     *tensor = copy;
@@ -119,8 +99,8 @@ Status SYCLDevice::FillContextMap(const Graph *graph,
 }

 Status SYCLDevice::Sync() {
-  sycl_device_->synchronize();
-  if (sycl_device_->ok()) {
+  sycl_allocator_->Synchronize();
+  if (sycl_allocator_->Ok()) {
     return Status::OK();
   } else {
     return errors::Internal("Unknown error detected on device ", name());
