diff --git a/configure b/configure index de2cb07df2698f..8133b115ec702a 100755 --- a/configure +++ b/configure @@ -686,7 +686,7 @@ if [ "$TF_NEED_OPENCL" == "1" ]; then while true; do fromuser="" if [ -z "$HOST_CXX_COMPILER" ]; then - default_cxx_host_compiler=$(which clang++-3.6 || true) + default_cxx_host_compiler=$(which g++ || true) read -p "Please specify which C++ compiler should be used as the host C++ compiler. [Default is $default_cxx_host_compiler]: " HOST_CXX_COMPILER fromuser="1" if [ -z "$HOST_CXX_COMPILER" ]; then @@ -710,7 +710,7 @@ done while true; do fromuser="" if [ -z "$HOST_C_COMPILER" ]; then - default_c_host_compiler=$(which clang-3.6 || true) + default_c_host_compiler=$(which gcc || true) read -p "Please specify which C compiler should be used as the host C compiler. [Default is $default_c_host_compiler]: " HOST_C_COMPILER fromuser="1" if [ -z "$HOST_C_COMPILER" ]; then diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py index bf24347c43b28d..112dacc9abb640 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py @@ -42,7 +42,7 @@ from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging from tensorflow.python.util import nest - +from tensorflow.python.framework import test_util class Plus1RNNCell(rnn_lib.RNNCell): """RNN Cell generating (output, new_state) = (input + 1, state + 1).""" @@ -2206,9 +2206,10 @@ def testRNNOnCPUCellOnGPU(self): return # Test requires access to a GPU run_metadata = self._execute_rnn_on( - rnn_device="/cpu:0", cell_device="/gpu:0") + rnn_device="/cpu:0", cell_device=test_util.gpu_device_name()) step_stats = run_metadata.step_stats - ix = 0 if "gpu" in step_stats.dev_stats[0].device else 1 + ix = 0 if (("gpu" in step_stats.dev_stats[0].device) or + ("sycl" in step_stats.dev_stats[0].device)) else 1 gpu_stats = step_stats.dev_stats[ix].node_stats cpu_stats = step_stats.dev_stats[1 - ix].node_stats @@ -2230,9 +2231,11 @@ def testRNNOnCPUCellOnCPU(self): return # Test requires access to a GPU run_metadata = self._execute_rnn_on( - rnn_device="/cpu:0", cell_device="/cpu:0", input_device="/gpu:0") + rnn_device="/cpu:0", cell_device="/cpu:0", + input_device=test_util.gpu_device_name()) step_stats = run_metadata.step_stats - ix = 0 if "gpu" in step_stats.dev_stats[0].device else 1 + ix = 0 if (("gpu" in step_stats.dev_stats[0].device) or + ("sycl" in step_stats.dev_stats[0].device)) else 1 gpu_stats = step_stats.dev_stats[ix].node_stats cpu_stats = step_stats.dev_stats[1 - ix].node_stats @@ -2247,9 +2250,11 @@ def testInputOnGPUCellNotDeclared(self): if not test.is_gpu_available(): return # Test requires access to a GPU - run_metadata = self._execute_rnn_on(input_device="/gpu:0") + run_metadata = self._execute_rnn_on( + input_device=test_util.gpu_device_name()) step_stats = run_metadata.step_stats - ix = 0 if "gpu" in step_stats.dev_stats[0].device else 1 + ix = 0 if (("gpu" in step_stats.dev_stats[0].device) or + ("sycl" in step_stats.dev_stats[0].device)) else 1 gpu_stats = step_stats.dev_stats[ix].node_stats cpu_stats = step_stats.dev_stats[1 - ix].node_stats diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index c041a9285e8cb2..6930a72fe32601 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -1836,6 +1836,7 @@ cc_library( hdrs = if_not_windows([ "common_runtime/sycl/sycl_allocator.h", "common_runtime/sycl/sycl_device.h", 
+ "common_runtime/sycl/sycl_util.h", "common_runtime/sycl/sycl_device_context.h", ]), copts = tf_copts(), diff --git a/tensorflow/core/common_runtime/direct_session_test.cc b/tensorflow/core/common_runtime/direct_session_test.cc index f8deaaf2229279..3d06ca0ae4da08 100644 --- a/tensorflow/core/common_runtime/direct_session_test.cc +++ b/tensorflow/core/common_runtime/direct_session_test.cc @@ -877,8 +877,6 @@ class BlockingOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("BlockingOp").Device(DEVICE_CPU), BlockingOp); REGISTER_OP("BlockingOp").Input("x: float").Output("y: float").Doc(""); -REGISTER_KERNEL_BUILDER(Name("BlockingOp").Device(DEVICE_SYCL), BlockingOp); - static void TestSessionInterOpThreadsImpl(bool use_function_lib) { FunctionDefLibrary library_graph_def; if (use_function_lib) { @@ -916,6 +914,7 @@ static void TestSessionInterOpThreadsImpl(bool use_function_lib) { ->set_opt_level(OptimizerOptions_Level_L0); (*options.config.mutable_device_count())["CPU"] = 2; (*options.config.mutable_device_count())["GPU"] = 0; + (*options.config.mutable_device_count())["SYCL"] = 0; options.config.add_session_inter_op_thread_pool(); auto* p = options.config.add_session_inter_op_thread_pool(); diff --git a/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc b/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc index 6f92cd09d3b21d..0cfc28949407b6 100644 --- a/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc +++ b/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc @@ -155,10 +155,16 @@ static void TestHWAccelerator(bool enableHWTrace) { test::FillValues(&x_tensor, {1, 1}); Node* x = test::graph::Constant(&graph, x_tensor); x->set_assigned_device_name("/job:localhost/replica:0/task:0/gpu:0"); +#ifdef TENSORFLOW_USE_SYCL + x->set_assigned_device_name("/job:localhost/replica:0/task:0/device:SYCL:0"); +#endif // TENSORFLOW_USE_SYCL // y = A * x Node* y = test::graph::Matmul(&graph, a, x, false, false); y->set_assigned_device_name("/job:localhost/replica:0/task:0/gpu:0"); +#ifdef TENSORFLOW_USE_SYCL +y->set_assigned_device_name("/job:localhost/replica:0/task:0/device:SYCL:0"); +#endif // TENSORFLOW_USE_SYCL Node* y_neg = test::graph::Unary(&graph, "Neg", y); y_neg->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0"); @@ -169,6 +175,9 @@ static void TestHWAccelerator(bool enableHWTrace) { SessionOptions options; (*options.config.mutable_device_count())["CPU"] = 1; (*options.config.mutable_device_count())["GPU"] = 1; +#ifdef TENSORFLOW_USE_SYCL + (*options.config.mutable_device_count())["SYCL"] = 1; +#endif // TENSORFLOW_USE_SYCL options.config.set_allow_soft_placement(true); options.config.mutable_graph_options()->set_build_cost_model(1); std::unique_ptr session(NewSession(options)); diff --git a/tensorflow/core/common_runtime/memory_types.cc b/tensorflow/core/common_runtime/memory_types.cc index db053dd2fa0724..21ed73df77da46 100644 --- a/tensorflow/core/common_runtime/memory_types.cc +++ b/tensorflow/core/common_runtime/memory_types.cc @@ -47,12 +47,12 @@ struct EndpointEq { static Status ProcessMemoryTypes( const DeviceType& device_type, const Graph* g, const std::function& fn) { - if (device_type != DEVICE_GPU) { - // On non-GPU devices, HOST_MEMORY and DEVICE_MEMORY are always + if (device_type != DEVICE_GPU && device_type != DEVICE_SYCL ) { + // On non-GPU and non-SYCL devices, HOST_MEMORY and DEVICE_MEMORY are always // compatible. 
     return Status::OK();
   }
-  // For GPU device, HOST_MEMORY and DEVICE_MEMORY is not
+  // For GPU and SYCL devices, HOST_MEMORY and DEVICE_MEMORY are not
   // compatible. I.e., a conversion/transfer must be done.
   //
   // {node id, slot id} -> memory type.
diff --git a/tensorflow/core/common_runtime/memory_types_test.cc b/tensorflow/core/common_runtime/memory_types_test.cc
index 06d7daea9cdc87..55eade0566cc92 100644
--- a/tensorflow/core/common_runtime/memory_types_test.cc
+++ b/tensorflow/core/common_runtime/memory_types_test.cc
@@ -34,6 +34,9 @@ TEST(MemoryTypeChecker, Int32OK) {
   // There is a kernel for adding two int32s on host memory.
   TF_EXPECT_OK(ValidateMemoryTypes(DEVICE_GPU, g));
 #endif  // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+  TF_EXPECT_OK(ValidateMemoryTypes(DEVICE_SYCL, g));
+#endif  // TENSORFLOW_USE_SYCL
   delete g;
 }
 
@@ -53,6 +56,15 @@ TEST(MemoryTypeChecker, Int32NotOk) {
   TF_EXPECT_OK(EnsureMemoryTypes(DEVICE_GPU, "/gpu:0", g));
   TF_EXPECT_OK(ValidateMemoryTypes(DEVICE_GPU, g));
 #endif  // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+  // There is no kernel for casting int32/host memory to float/device
+  // memory.
+  EXPECT_TRUE(errors::IsInternal(ValidateMemoryTypes(DEVICE_SYCL, g)));
+
+  // But we can insert _HostSend/_HostRecv to ensure the invariant.
+  TF_EXPECT_OK(EnsureMemoryTypes(DEVICE_SYCL, "/device:SYCL:0", g));
+  TF_EXPECT_OK(ValidateMemoryTypes(DEVICE_SYCL, g));
+#endif  // TENSORFLOW_USE_SYCL
   delete g;
 }
 
@@ -74,6 +86,12 @@ TEST(MemoryTypeChecker, MemoryTypeForOutput) {
   // int Switch's output on GPU has HOST_MEMORY constraint.
   EXPECT_EQ(memory_type, HOST_MEMORY);
 #endif  // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+  auto si = test::graph::Switch(g, test::graph::Constant(g, vi), pred);
+  TF_EXPECT_OK(MemoryTypeForOutput(DEVICE_SYCL, g, si, 0, &memory_type));
+  // int Switch's output on SYCL has HOST_MEMORY constraint.
+  EXPECT_EQ(memory_type, HOST_MEMORY);
+#endif  // TENSORFLOW_USE_SYCL
   delete g;
 }
diff --git a/tensorflow/core/common_runtime/sycl/sycl_allocator.cc b/tensorflow/core/common_runtime/sycl/sycl_allocator.cc
index b7ef9361e95239..485e5397e89fd6 100644
--- a/tensorflow/core/common_runtime/sycl/sycl_allocator.cc
+++ b/tensorflow/core/common_runtime/sycl/sycl_allocator.cc
@@ -19,29 +19,26 @@ limitations under the License.
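Note: the sycl_allocator.cc/.h hunks that follow change ownership so that SYCLAllocator constructs and owns the Eigen::SyclDevice (built from an Eigen::QueueInterface), and the old EnterLameDuckMode teardown path is dropped in favour of deleting the device in the destructor. A minimal sketch of the resulting wiring, assuming `d` is a cl::sycl::device that has already been selected (variable names here are illustrative only):

  Eigen::QueueInterface* queue = new Eigen::QueueInterface(d);   // per-device SYCL queue (owned by GSYCLInterface)
  SYCLAllocator* allocator = new SYCLAllocator(queue);           // builds and owns an Eigen::SyclDevice
  Eigen::SyclDevice* eigen_device = allocator->getSyclDevice();  // handed to set_eigen_sycl_device()
  allocator->Synchronize();                                      // what SYCLDevice::Sync() now calls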
namespace tensorflow { -SYCLAllocator::~SYCLAllocator() {} +SYCLAllocator::~SYCLAllocator() { + if(sycl_device_) { + delete sycl_device_; + } +} string SYCLAllocator::Name() { return "device:SYCL"; } void *SYCLAllocator::AllocateRaw(size_t alignment, size_t num_bytes) { - assert(device_); + assert(sycl_device_); if (num_bytes == 0) { - return device_->allocate(1); + return sycl_device_->allocate(1); } - auto p = device_->allocate(num_bytes); + auto p = sycl_device_->allocate(num_bytes); return p; } void SYCLAllocator::DeallocateRaw(void *ptr) { - if (device_) { - device_->deallocate(ptr); - } -} - -void SYCLAllocator::EnterLameDuckMode() { - if (device_) { - device_->deallocate_all(); - device_ = nullptr; + if (sycl_device_) { + sycl_device_->deallocate(ptr); } } diff --git a/tensorflow/core/common_runtime/sycl/sycl_allocator.h b/tensorflow/core/common_runtime/sycl/sycl_allocator.h index 15d9ab41a461ed..8668cba06af7f8 100644 --- a/tensorflow/core/common_runtime/sycl/sycl_allocator.h +++ b/tensorflow/core/common_runtime/sycl/sycl_allocator.h @@ -28,17 +28,19 @@ namespace tensorflow { class SYCLAllocator : public Allocator { public: - SYCLAllocator(Eigen::QueueInterface *device) : device_(device) {} + SYCLAllocator(Eigen::QueueInterface *queue) : sycl_device_(new Eigen::SyclDevice(queue)) {} virtual ~SYCLAllocator() override; string Name() override; void *AllocateRaw(size_t alignment, size_t num_bytes) override; void DeallocateRaw(void *ptr) override; - void EnterLameDuckMode(); virtual bool ShouldAllocateEmptyTensors() override final { return true; } - + void Synchronize() { sycl_device_->synchronize(); } + bool Ok() { return sycl_device_->ok(); } + Eigen::SyclDevice* getSyclDevice() { return sycl_device_; } private: - Eigen::QueueInterface *device_; // not owned + Eigen::SyclDevice *sycl_device_; // owned + TF_DISALLOW_COPY_AND_ASSIGN(SYCLAllocator); }; diff --git a/tensorflow/core/common_runtime/sycl/sycl_device.cc b/tensorflow/core/common_runtime/sycl/sycl_device.cc index 2c2185b2c03c65..17f5edd572581c 100644 --- a/tensorflow/core/common_runtime/sycl/sycl_device.cc +++ b/tensorflow/core/common_runtime/sycl/sycl_device.cc @@ -22,50 +22,18 @@ limitations under the License. 
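Note: in the sycl_device.cc hunk below, per-device teardown moves out of SYCLDevice (its destructor becomes empty) and into GSYCLInterface::Reset(), registered once via atexit(ShutdownSycl). Sync() no longer touches an Eigen::SyclDevice member directly; it goes through the allocator, roughly as in this condensed sketch:

  Status SYCLDevice::Sync() {
    sycl_allocator_->Synchronize();  // wait on the queue owned by the allocator
    return sycl_allocator_->Ok()
               ? Status::OK()
               : errors::Internal("Unknown error detected on device ", name());
  }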
#include "tensorflow/core/platform/tracing.h" namespace tensorflow { - -static std::unordered_set live_devices; -static bool first_time = true; +std::mutex GSYCLInterface::mutex_; +GSYCLInterface *GSYCLInterface::s_instance = 0; void ShutdownSycl() { - for (auto device : live_devices) { - device->EnterLameDuckMode(); - } - live_devices.clear(); + GSYCLInterface::Reset(); } void SYCLDevice::RegisterDevice() { - if (first_time) { - first_time = false; atexit(ShutdownSycl); - } - live_devices.insert(this); } -SYCLDevice::~SYCLDevice() { - device_context_->Unref(); - sycl_allocator_->EnterLameDuckMode(); - if (sycl_device_) { - sycl_device_->synchronize(); - delete sycl_device_; - } - if (sycl_queue_) { - delete sycl_queue_; - } - live_devices.erase(this); -} - -void SYCLDevice::EnterLameDuckMode() { - sycl_allocator_->EnterLameDuckMode(); - if (sycl_device_) { - sycl_device_->synchronize(); - delete sycl_device_; - sycl_device_ = nullptr; - } - if (sycl_queue_) { - delete sycl_queue_; - sycl_queue_ = nullptr; - } -} +SYCLDevice::~SYCLDevice() {} void SYCLDevice::Compute(OpKernel *op_kernel, OpKernelContext *context) { assert(context); @@ -88,8 +56,12 @@ Allocator *SYCLDevice::GetAllocator(AllocatorAttributes attr) { Status SYCLDevice::MakeTensorFromProto(const TensorProto &tensor_proto, const AllocatorAttributes alloc_attrs, Tensor *tensor) { + AllocatorAttributes attr; + attr.set_on_host(true); + Allocator* host_alloc = GetAllocator(attr); + Tensor parsed(tensor_proto.dtype()); - if (!parsed.FromProto(cpu_allocator_, tensor_proto)) { + if (!parsed.FromProto(host_alloc, tensor_proto)) { return errors::InvalidArgument("Cannot parse tensor from proto: ", tensor_proto.DebugString()); } @@ -98,6 +70,14 @@ Status SYCLDevice::MakeTensorFromProto(const TensorProto &tensor_proto, *tensor = parsed; } else { Tensor copy(GetAllocator(alloc_attrs), parsed.dtype(), parsed.shape()); + + // If the tensor is not initialized, we likely ran out of memory. + if (!copy.IsInitialized()) { + return errors::ResourceExhausted( + "OOM when allocating tensor of shape ", parsed.shape().DebugString(), + " and type ", DataTypeString(parsed.dtype())); + } + device_context_->CopyCPUTensorToDevice( &parsed, this, ©, [&status](const Status &s) { status = s; }); *tensor = copy; @@ -119,8 +99,8 @@ Status SYCLDevice::FillContextMap(const Graph *graph, } Status SYCLDevice::Sync() { - sycl_device_->synchronize(); - if (sycl_device_->ok()) { + sycl_allocator_->Synchronize(); + if (sycl_allocator_->Ok()) { return Status::OK(); } else { return errors::Internal("Unknown error detected on device ", name()); diff --git a/tensorflow/core/common_runtime/sycl/sycl_device.h b/tensorflow/core/common_runtime/sycl/sycl_device.h index a5c7c5f0ec7304..7a6f11924863b4 100644 --- a/tensorflow/core/common_runtime/sycl/sycl_device.h +++ b/tensorflow/core/common_runtime/sycl/sycl_device.h @@ -27,31 +27,156 @@ limitations under the License. 
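Note: the sycl_device.h hunk below introduces the GSYCLInterface singleton. It walks Eigen::get_sycl_supported_devices(), prefers OpenCL GPUs and falls back to an OpenCL CPU device, and owns one QueueInterface/SYCLAllocator/SYCLDeviceContext triple per accepted device. The device factory later in this patch consumes it roughly like this (index 0 shown; sketch only):

  GSYCLInterface* sycl = GSYCLInterface::instance();
  SYCLAllocator* device_allocator = sycl->GetSYCLAllocator(0);
  Allocator* host_allocator = sycl->GetCPUAllocator(0);
  SYCLDeviceContext* device_context = sycl->GetSYCLContext(0);
  // These three, plus GetShortDeviceDescription(0), are passed to the SYCLDevice constructor.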
namespace tensorflow { + +class GSYCLInterface +{ + std::vector m_queue_interface_; // owned + std::vector m_cpu_allocator_; // not owned + std::vector m_sycl_allocator_; // owned + std::vector m_sycl_context_; // owned + + static std::mutex mutex_; + static GSYCLInterface* s_instance; + GSYCLInterface() { + bool found_device =false; + auto device_list = Eigen::get_sycl_supported_devices(); + // Obtain list of supported devices from Eigen + for (const auto& device : device_list) { + if(device.is_gpu()) { + // returns first found GPU + AddDevice(device); + found_device = true; + } + } + + if(!found_device) { + // Currently Intel GPU is not supported + LOG(WARNING) << "No OpenCL GPU found that is supported by ComputeCpp, trying OpenCL CPU"; + } + + for (const auto& device : device_list) { + if(device.is_cpu()) { + // returns first found CPU + AddDevice(device); + found_device = true; + } + } + + if(!found_device) { + // Currently Intel GPU is not supported + LOG(FATAL) << "No OpenCL GPU nor CPU found that is supported by ComputeCpp"; + } + } + + ~GSYCLInterface() { + m_cpu_allocator_.clear(); + + for (auto p : m_sycl_allocator_) { + p->Synchronize(); + delete p; + } + m_sycl_allocator_.clear(); + + for(auto p : m_sycl_context_) { + p->Unref(); + } + m_sycl_context_.clear(); + + for (auto p : m_queue_interface_) { + p->deallocate_all(); + delete p; + p = nullptr; + } + m_queue_interface_.clear(); + } + + void AddDevice(const cl::sycl::device & d) { + m_queue_interface_.push_back(new Eigen::QueueInterface(d)); + m_cpu_allocator_.push_back(cpu_allocator()); + m_sycl_allocator_.push_back(new SYCLAllocator(m_queue_interface_.back())); + m_sycl_context_.push_back(new SYCLDeviceContext()); + } + + public: + static GSYCLInterface *instance() + { + std::lock_guard lock(mutex_); + if (!s_instance) { + s_instance = new GSYCLInterface(); + } + return s_instance; + } + + static void Reset() + { + std::lock_guard lock(mutex_); + if(s_instance) { + delete s_instance; + s_instance = NULL; + } + } + + Eigen::QueueInterface * GetQueueInterface(size_t i = 0) { + if(!m_queue_interface_.empty()) { + return m_queue_interface_[i]; + } else { + std::cerr << "No cl::sycl::device has been added" << std::endl; + return nullptr; + } + } + + SYCLAllocator * GetSYCLAllocator(size_t i = 0) { + if(!m_sycl_allocator_.empty()) { + return m_sycl_allocator_[i]; + } else { + std::cerr << "No cl::sycl::device has been added" << std::endl; + return nullptr; + } + } + + Allocator * GetCPUAllocator(size_t i = 0) { + if(!m_cpu_allocator_.empty()) { + return m_cpu_allocator_[i]; + } else { + std::cerr << "No cl::sycl::device has been added" << std::endl; + return nullptr; + } + } + + SYCLDeviceContext * GetSYCLContext(size_t i = 0) { + if(!m_sycl_context_.empty()) { + return m_sycl_context_[i]; + } else { + std::cerr << "No cl::sycl::device has been added" << std::endl; + return nullptr; + } + } + + string GetShortDeviceDescription(int device_id = 0) { + return strings::StrCat("device: ", device_id, " ,name: SYCL"); + } +}; + + class SYCLDevice : public LocalDevice { public: - template SYCLDevice(const SessionOptions &options, const string &name, Bytes memory_limit, const DeviceLocality &locality, - const string &physical_device_desc, SYCLSelector sycl_selector, - Allocator *cpu_allocator) + const string &physical_device_desc, SYCLAllocator * sycl_allocator, + Allocator *cpu_allocator, SYCLDeviceContext* ctx) : LocalDevice( options, Device::BuildDeviceAttributes(name, DEVICE_SYCL, memory_limit, - locality, physical_device_desc), - 
nullptr), + locality, physical_device_desc)), cpu_allocator_(cpu_allocator), - sycl_queue_(new Eigen::QueueInterface(sycl_selector)), - sycl_device_(new Eigen::SyclDevice(sycl_queue_)), - sycl_allocator_(new SYCLAllocator(sycl_queue_)), - device_context_(new SYCLDeviceContext()) { - set_eigen_sycl_device(sycl_device_); + sycl_allocator_(sycl_allocator), + device_context_(ctx) { RegisterDevice(); + set_eigen_sycl_device(sycl_allocator->getSyclDevice()); } ~SYCLDevice() override; - void EnterLameDuckMode(); - void Compute(OpKernel *op_kernel, OpKernelContext *context) override; Allocator *GetAllocator(AllocatorAttributes attr) override; Status MakeTensorFromProto(const TensorProto &tensor_proto, @@ -62,18 +187,12 @@ class SYCLDevice : public LocalDevice { DeviceContextMap *device_context_map) override; Status Sync() override; - static string GetShortDeviceDescription(/*int device_id, - const DeviceDescription& desc*/) { - return strings::StrCat("device: 0, name SYCL, pci bus id: 0"); - } private: void RegisterDevice(); - Allocator *cpu_allocator_; // owned - Eigen::QueueInterface *sycl_queue_; // owned - Eigen::SyclDevice *sycl_device_; // owned - SYCLAllocator *sycl_allocator_; // owned + Allocator *cpu_allocator_; // not owned + SYCLAllocator *sycl_allocator_; // not owned SYCLDeviceContext *device_context_; }; diff --git a/tensorflow/core/common_runtime/sycl/sycl_device_factory.cc b/tensorflow/core/common_runtime/sycl/sycl_device_factory.cc index a643fc72580889..19c14770dcad7a 100644 --- a/tensorflow/core/common_runtime/sycl/sycl_device_factory.cc +++ b/tensorflow/core/common_runtime/sycl/sycl_device_factory.cc @@ -18,24 +18,34 @@ limitations under the License. #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/sycl/sycl_device.h" +#include "tensorflow/core/common_runtime/sycl/sycl_util.h" + namespace tensorflow { class SYCLDeviceFactory : public DeviceFactory { public: Status CreateDevices(const SessionOptions &options, const string &name_prefix, std::vector *devices) override { - int n = 1; + + auto syclInterface = GSYCLInterface::instance(); + + size_t n = 1; auto iter = options.config.device_count().find("SYCL"); if (iter != options.config.device_count().end()) { n = iter->second; } + for (int i = 0; i < n; i++) { string name = strings::StrCat(name_prefix, "/device:SYCL:", i); devices->push_back( - new SYCLDevice(options, name, Bytes(256 << 20), DeviceLocality(), - SYCLDevice::GetShortDeviceDescription(), - cl::sycl::gpu_selector(), cpu_allocator())); + new SYCLDevice(options, name, Bytes(256 << 20), DeviceLocality() + , syclInterface->GetShortDeviceDescription(i) + , syclInterface->GetSYCLAllocator(i) + , syclInterface->GetCPUAllocator(i) + , syclInterface->GetSYCLContext(i)) + ); } + return Status::OK(); } }; diff --git a/tensorflow/core/common_runtime/sycl/sycl_util.h b/tensorflow/core/common_runtime/sycl/sycl_util.h new file mode 100644 index 00000000000000..f58614c4ff9ccf --- /dev/null +++ b/tensorflow/core/common_runtime/sycl/sycl_util.h @@ -0,0 +1,37 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if !TENSORFLOW_USE_SYCL
+#error This file must only be included when building TensorFlow with SYCL support
+#endif
+
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_UTIL_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_UTIL_H_
+
+#include "tensorflow/core/common_runtime/device.h"
+// For DMA helper
+#include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/framework/tensor.h"
+
+namespace tensorflow {
+  inline void* GetBase(const Tensor* src) {
+    return const_cast<void*>(DMAHelper::base(src));
+  }
+
+  inline void* GetBase(Tensor* dst) { return DMAHelper::base(dst); }
+
+}
+
+#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_UTIL_H_
diff --git a/tensorflow/core/debug/debug_gateway.cc b/tensorflow/core/debug/debug_gateway.cc
index 1031ea843ed787..2aaed9563a6db9 100644
--- a/tensorflow/core/debug/debug_gateway.cc
+++ b/tensorflow/core/debug/debug_gateway.cc
@@ -86,7 +86,7 @@ void DebugGateway::CopyTensor(const string& node_name, const int output_slot,
   // Determine if the tensor is on device (GPU) or host (CPU).
   // The second part of the check is necessary because even an OpKernel on
   // may have output tensors allocated on CPU.
-  if (device->name().find("gpu:") != string::npos &&
+  if ((device->name().find("gpu:") != string::npos || device->name().find("SYCL:") != string::npos) &&
       !ctx->output_alloc_attr(output_slot).on_host()) {
     // GPU tensors: Copy it to host (CPU).
     DeviceContext* device_ctxt = ctx->op_device_context();
diff --git a/tensorflow/core/debug/debug_gateway_test.cc b/tensorflow/core/debug/debug_gateway_test.cc
index 2911205db2ca43..adbb1b21162eb5 100644
--- a/tensorflow/core/debug/debug_gateway_test.cc
+++ b/tensorflow/core/debug/debug_gateway_test.cc
@@ -46,6 +46,8 @@ class SessionDebugMinusAXTest : public ::testing::Test {
 #if GOOGLE_CUDA
   const string kDeviceName = "/job:localhost/replica:0/task:0/gpu:0";
+#elif defined(TENSORFLOW_USE_SYCL)
+  const string kDeviceName = "/job:localhost/replica:0/task:0/device:SYCL:0";
 #else
   const string kDeviceName = "/job:localhost/replica:0/task:0/cpu:0";
 #endif
@@ -303,6 +305,8 @@ TEST_F(SessionDebugMinusAXTest, RunSimpleNetworkWithTwoDebugNodesInserted) {
   // through RunMetadata, given whether GPU is involved.
 #if GOOGLE_CUDA
   ASSERT_EQ(2, run_metadata.partition_graphs().size());
+#elif defined(TENSORFLOW_USE_SYCL)
+  ASSERT_EQ(2, run_metadata.partition_graphs().size());
 #else
   ASSERT_EQ(1, run_metadata.partition_graphs().size());
 #endif
@@ -337,7 +341,7 @@ TEST_F(SessionDebugMinusAXTest, RunSimpleNetworkWithTwoDebugNodesInserted) {
   ASSERT_EQ(1, debug_nan_count_tensor_vals[0].scalar<int64>()());
 }
 
-#ifndef GOOGLE_CUDA
+#if !defined(GOOGLE_CUDA) && !defined(TENSORFLOW_USE_SYCL)
 // TODO(cais): Reinstate the following test for concurrent debugged runs on
 // a GPU once the root cause of the ~0.5% flakiness has been addressed.
// (b/34081273) @@ -500,6 +504,8 @@ class SessionDebugOutputSlotWithoutOngoingEdgeTest : public ::testing::Test { #if GOOGLE_CUDA const string kDeviceName = "/job:localhost/replica:0/task:0/gpu:0"; +#elif defined(TENSORFLOW_USE_SYCL) + const string kDeviceName = "/job:localhost/replica:0/task:0/device:SYCL:0"; #else const string kDeviceName = "/job:localhost/replica:0/task:0/cpu:0"; #endif @@ -600,6 +606,8 @@ class SessionDebugVariableTest : public ::testing::Test { #if GOOGLE_CUDA const string kDeviceName = "/job:localhost/replica:0/task:0/gpu:0"; +#elif defined(TENSORFLOW_USE_SYCL) + const string kDeviceName = "/job:localhost/replica:0/task:0/device:SYCL:0"; #else const string kDeviceName = "/job:localhost/replica:0/task:0/cpu:0"; #endif @@ -823,6 +831,8 @@ TEST_F(SessionDebugVariableTest, VariableAssignWithDebugOps) { #if GOOGLE_CUDA ASSERT_EQ(2, run_metadata.partition_graphs().size()); +#elif defined(TENSORFLOW_USE_SYCL) + ASSERT_EQ(2, run_metadata.partition_graphs().size()); #else ASSERT_EQ(1, run_metadata.partition_graphs().size()); #endif @@ -860,13 +870,17 @@ TEST_F(SessionDebugVariableTest, VariableAssignWithDebugOps) { ASSERT_EQ(2, debug_nan_count_tensor_vals[0].scalar()()); } -#if GOOGLE_CUDA +#if defined(GOOGLE_CUDA) || defined(TENSORFLOW_USE_SYCL) class SessionDebugGPUSwitchTest : public ::testing::Test { public: void Initialize() { Graph graph(OpRegistry::Global()); +#ifdef GOOGLE_CUDA const string kDeviceName = "/job:localhost/replica:0/task:0/gpu:0"; +#elif TENSORFLOW_USE_SYCL + const string kDeviceName = "/job:localhost/replica:0/task:0/device:SYCL:0"; +#endif Tensor vb(DT_BOOL, TensorShape({})); vb.scalar()() = true; diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc index 6c3917c6869e65..dec987e1ed9ebb 100644 --- a/tensorflow/core/framework/op_kernel.cc +++ b/tensorflow/core/framework/op_kernel.cc @@ -96,9 +96,9 @@ OpKernel::OpKernel(OpKernelConstruction* context) OP_REQUIRES_OK(context, CheckOpDeprecation(*context->op_def_, context->graph_def_version())); - // Kernels executing on GPU tie very few resources on the CPU where the + // Kernels executing on GPU/SYCL tie very few resources on the CPU where the // scheduler runs: we consider them as inexpensive. - expensive_ = context->device_type() != DeviceType(DEVICE_GPU); + expensive_ = context->device_type() != DeviceType(DEVICE_GPU) && context->device_type() != DeviceType(DEVICE_SYCL); } OpKernel::~OpKernel() {} diff --git a/tensorflow/core/graph/testlib.cc b/tensorflow/core/graph/testlib.cc index c495b2181207a0..c59c44c80edc08 100644 --- a/tensorflow/core/graph/testlib.cc +++ b/tensorflow/core/graph/testlib.cc @@ -36,6 +36,10 @@ namespace tensorflow { REGISTER_KERNEL_BUILDER(Name("HostConst").Device(DEVICE_CPU), HostConstantOp); REGISTER_KERNEL_BUILDER( Name("HostConst").Device(DEVICE_GPU).HostMemory("output"), HostConstantOp); +#ifdef TENSORFLOW_USE_SYCL +REGISTER_KERNEL_BUILDER( + Name("HostConst").Device(DEVICE_SYCL).HostMemory("output"), HostConstantOp); +#endif // TENSORFLOW_USE_SYCL // Register the HostConst Op // Returns a constant tensor on the host. 
Useful for writing C++ tests diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 06ef9dc9c999f2..8540ccba318e7e 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -33,6 +33,7 @@ load( "tf_mkl_kernel_library", "cc_header_only_library", ) +load("@local_config_sycl//sycl:build_defs.bzl", "if_sycl") load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_tests") load( @@ -494,7 +495,7 @@ ARRAY_DEPS = [ "//tensorflow/core:proto_text", "//tensorflow/core:protos_all_cc", "//third_party/eigen3", -] +] + if_sycl(["//tensorflow/core:sycl_runtime"]) cc_library( name = "array_not_windows", @@ -3431,7 +3432,7 @@ STATE_DEPS = [ "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:state_ops_op_lib", -] +] + if_sycl(["//tensorflow/core:sycl_runtime"]) tf_kernel_library( name = "count_up_to_op", diff --git a/tensorflow/core/kernels/batch_matmul_op_impl.h b/tensorflow/core/kernels/batch_matmul_op_impl.h index dfc81a960eac7d..b87c98c374e364 100644 --- a/tensorflow/core/kernels/batch_matmul_op_impl.h +++ b/tensorflow/core/kernels/batch_matmul_op_impl.h @@ -39,6 +39,9 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; +#ifdef TENSORFLOW_USE_SYCL +typedef Eigen::SyclDevice SYCLDevice; +#endif // TENSORFLOW_USE_SYCL namespace { @@ -413,6 +416,40 @@ struct LaunchBatchMatMul { #endif // GOOGLE_CUDA +#ifdef TENSORFLOW_USE_SYCL +template +struct ParallelMatMulKernelSYCL { + static void Run(const OpKernelContext* context, const Tensor& in_x, + const Tensor& in_y, bool adj_x, bool adj_y, Tensor* out, + int start, int limit) { + auto Tx = in_x.tensor(); + auto Ty = in_y.tensor(); + auto Tz = out->tensor(); + Eigen::array, 1> contract_pairs; + contract_pairs[0] = ContractionDims(adj_x, adj_y); + auto d = context->eigen_sycl_device(); + for (int i = start; i < limit; ++i) { + auto x = Tx.template chip<0>(i); + auto y = Ty.template chip<0>(i); + auto z = Tz.template chip<0>(i); + z.device(d) = x.contract(y, contract_pairs); + } + } +}; + +template +struct LaunchBatchMatMul { + static void Launch(OpKernelContext* context, const Tensor& in_x, + const Tensor& in_y, bool adj_x, bool adj_y, Tensor* out) { + + // Number of matrix multiplies i.e. size of the batch. 
+ const int64 num_units = in_x.dim_size(0); + ParallelMatMulKernelSYCL::Run(context, in_x, in_y, adj_x, adj_y, out, + 0, num_units); + } +}; +#endif // TENSORFLOW_USE_SYCL + template class BatchMatMul : public OpKernel { public: @@ -492,4 +529,10 @@ class BatchMatMul : public OpKernel { Name("BatchMatMul").Device(DEVICE_GPU).TypeConstraint("T"), \ BatchMatMul) +#ifdef TENSORFLOW_USE_SYCL +#define REGISTER_BATCH_MATMUL_SYCL(TYPE) \ + REGISTER_KERNEL_BUILDER( \ + Name("BatchMatMul").Device(DEVICE_SYCL).TypeConstraint("T"), \ + BatchMatMul) +#endif // TENSORFLOW_USE_SYCL } // end namespace tensorflow diff --git a/tensorflow/core/kernels/batch_matmul_op_real.cc b/tensorflow/core/kernels/batch_matmul_op_real.cc index c719e30c4d2ae1..1900ed8e31483a 100644 --- a/tensorflow/core/kernels/batch_matmul_op_real.cc +++ b/tensorflow/core/kernels/batch_matmul_op_real.cc @@ -30,4 +30,8 @@ TF_CALL_half(REGISTER_BATCH_MATMUL_GPU); #endif #endif // GOOGLE_CUDA +#ifdef TENSORFLOW_USE_SYCL +TF_CALL_float(REGISTER_BATCH_MATMUL_SYCL); +TF_CALL_double(REGISTER_BATCH_MATMUL_SYCL); +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/batch_norm_op.cc b/tensorflow/core/kernels/batch_norm_op.cc index 56f4e25fad096f..d3ed617f713094 100644 --- a/tensorflow/core/kernels/batch_norm_op.cc +++ b/tensorflow/core/kernels/batch_norm_op.cc @@ -28,6 +28,9 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; +#ifdef TENSORFLOW_USE_SYCL +typedef Eigen::SyclDevice SYCLDevice; +#endif // TENSORFLOW_USE_SYCL template class BatchNormOp : public OpKernel { @@ -201,6 +204,18 @@ TF_CALL_float(REGISTER_GPU_KERNEL); #endif // GOOGLE_CUDA +#if TENSORFLOW_USE_SYCL +#define REGISTER_KERNEL(T) \ + REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalization") \ + .Device(DEVICE_SYCL) \ + .TypeConstraint("T"), \ + BatchNormOp); + +TF_CALL_float(REGISTER_KERNEL); +TF_CALL_double(REGISTER_KERNEL); +#undef REGISTER_KERNEL +#endif // TENSORFLOW_USE_SYCL + #define REGISTER_KERNEL(T) \ REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalizationGrad") \ .Device(DEVICE_CPU) \ @@ -248,4 +263,17 @@ TF_CALL_float(REGISTER_GPU_KERNEL); #endif // GOOGLE_CUDA +#if TENSORFLOW_USE_SYCL +#define REGISTER_KERNEL(T) \ + REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalizationGrad") \ + .Device(DEVICE_SYCL) \ + .TypeConstraint("T"), \ + BatchNormGradOp); + +TF_CALL_float(REGISTER_KERNEL); +TF_CALL_double(REGISTER_KERNEL); +#undef REGISTER_KERNEL + +#endif // TENSORFLOW_USE_SYCL + } // namespace tensorflow diff --git a/tensorflow/core/kernels/cast_op_impl_int32.cc b/tensorflow/core/kernels/cast_op_impl_int32.cc index fca9cd60ec1e2f..b62e842cc9bace 100644 --- a/tensorflow/core/kernels/cast_op_impl_int32.cc +++ b/tensorflow/core/kernels/cast_op_impl_int32.cc @@ -44,4 +44,3 @@ GetSyclCastFromInt32(DataType dst_dtype) { #endif // TENSORFLOW_USE_SYCL } // namespace tensorflow - diff --git a/tensorflow/core/kernels/constant_op.cc b/tensorflow/core/kernels/constant_op.cc index 15fc08675202bc..68e960d6b75904 100644 --- a/tensorflow/core/kernels/constant_op.cc +++ b/tensorflow/core/kernels/constant_op.cc @@ -30,6 +30,10 @@ limitations under the License. 
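Note: the SYCL launcher added in batch_matmul_op_impl.h earlier in this patch loops over the batch dimension and issues one Eigen tensor contraction per 2-D slice on the SYCL device. For the non-adjoint case the contraction pairs the inner dimensions, roughly as in this sketch (x, y, z are assumed rank-2 Eigen tensor expressions and `d` an Eigen::SyclDevice):

  Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> pairs;
  pairs[0] = Eigen::IndexPair<Eigen::DenseIndex>(1, 0);  // contract columns of x with rows of y
  z.device(d) = x.contract(y, pairs);                    // one batch slice of BatchMatMul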
#include "tensorflow/core/kernels/fill_functor.h" #include "tensorflow/core/platform/macros.h" +#ifdef TENSORFLOW_USE_SYCL +#include "tensorflow/core/common_runtime/sycl/sycl_util.h" +#endif // TENSORFLOW_USE_SYCL + namespace tensorflow { ConstantOp::ConstantOp(OpKernelConstruction* ctx) @@ -52,18 +56,6 @@ ConstantOp::~ConstantOp() {} REGISTER_KERNEL_BUILDER(Name("Const").Device(DEVICE_CPU), ConstantOp); -#if TENSORFLOW_USE_SYCL -#define REGISTER_SYCL_KERNEL(TYPE) \ - REGISTER_KERNEL_BUILDER( \ - Name("Const").Device(DEVICE_SYCL).TypeConstraint("dtype"), \ - ConstantOp); -REGISTER_SYCL_KERNEL(float); -REGISTER_SYCL_KERNEL(double); -REGISTER_SYCL_KERNEL(bool); -REGISTER_SYCL_KERNEL(int64); -#undef REGISTER_SYCL_KERNEL -#endif - #if GOOGLE_CUDA #define REGISTER_KERNEL(D, TYPE) \ REGISTER_KERNEL_BUILDER( \ @@ -85,6 +77,22 @@ REGISTER_KERNEL(GPU, bool); #undef REGISTER_KERNEL #endif +#ifdef TENSORFLOW_USE_SYCL +#define REGISTER_SYCL_KERNEL(D, TYPE) \ + REGISTER_KERNEL_BUILDER( \ + Name("Const").Device(DEVICE_##D).TypeConstraint("dtype"), \ + ConstantOp); +REGISTER_SYCL_KERNEL(SYCL, float); +REGISTER_SYCL_KERNEL(SYCL, double); +REGISTER_SYCL_KERNEL(SYCL, uint8); +REGISTER_SYCL_KERNEL(SYCL, int8); +REGISTER_SYCL_KERNEL(SYCL, uint16); +REGISTER_SYCL_KERNEL(SYCL, int16); +REGISTER_SYCL_KERNEL(SYCL, int64); +REGISTER_SYCL_KERNEL(SYCL, bool); +#undef REGISTER_SYCL_KERNEL +#endif + HostConstantOp::HostConstantOp(OpKernelConstruction* ctx) : OpKernel(ctx), tensor_(ctx->output_type(0)) { const TensorProto* proto = nullptr; @@ -116,9 +124,6 @@ REGISTER_KERNEL_BUILDER(Name("Const") #endif #ifdef TENSORFLOW_USE_SYCL -// A special GPU kernel for int32. -// TODO(b/25387198): Also enable int32 in device memory. This kernel -// registration requires all int32 inputs and outputs to be in host memory. REGISTER_KERNEL_BUILDER(Name("Const") .Device(DEVICE_SYCL) .HostMemory("output") @@ -143,17 +148,6 @@ struct FillFunctor { } }; -#ifdef TENSORFLOW_USE_SYCL -// Partial specialization of FillFunctor. -template -struct FillFunctor { - void operator()(const SYCLDevice& d, typename TTypes::Flat out, - typename TTypes::ConstScalar in) { - To32Bit(out).device(d) = To32Bit(out).constant(in()); - } -}; -#endif // TENSORFLOW_USE_SYCL - } // end namespace functor template @@ -184,6 +178,28 @@ class FillOp : public OpKernel { } }; +#ifdef TENSORFLOW_USE_SYCL + +namespace functor { +// Partial specialization of FillFunctor. 
+template +struct FillFunctor { + void operator()(const SYCLDevice& d, typename TTypes::Flat out, + typename TTypes::ConstScalar in) { +#if !defined(EIGEN_HAS_INDEX_LIST) + Eigen::array rank1{1}; +#else + Eigen::IndexList> rank1; +#endif + const int size = out.dimension(0); + Eigen::array broadcast_dims{size}; + + To32Bit(out).device(d) = in.reshape(rank1).broadcast(broadcast_dims); + } +}; +} +#endif // TENSORFLOW_USE_SYCL + #define REGISTER_KERNEL(D, TYPE) \ REGISTER_KERNEL_BUILDER(Name("Fill") \ .Device(DEVICE_##D) \ @@ -199,8 +215,14 @@ REGISTER_KERNEL(CPU, quint8); #undef REGISTER_CPU_KERNEL #ifdef TENSORFLOW_USE_SYCL -REGISTER_KERNEL(SYCL, float) -REGISTER_KERNEL(SYCL, double) +REGISTER_KERNEL(SYCL, float); +REGISTER_KERNEL(SYCL, double); +REGISTER_KERNEL(SYCL, uint8); +REGISTER_KERNEL(SYCL, int8); +REGISTER_KERNEL(SYCL, uint16); +REGISTER_KERNEL(SYCL, int16); +REGISTER_KERNEL(SYCL, int64); + REGISTER_KERNEL_BUILDER(Name("Fill") .Device(DEVICE_SYCL) .TypeConstraint("T") @@ -208,6 +230,7 @@ REGISTER_KERNEL_BUILDER(Name("Fill") .HostMemory("value") .HostMemory("output"), FillOp); +#undef REGISTER_KERNEL_SYCL #endif // TENSORFLOW_USE_SYCL #if GOOGLE_CUDA @@ -260,8 +283,10 @@ TF_CALL_POD_STRING_TYPES(REGISTER_CPU); #undef REGISTER_CPU #ifdef TENSORFLOW_USE_SYCL -REGISTER_KERNEL(float, SYCL); REGISTER_KERNEL(bool, SYCL); +REGISTER_KERNEL(float, SYCL); +REGISTER_KERNEL(double, SYCL); +REGISTER_KERNEL(int64, SYCL); REGISTER_KERNEL_BUILDER(Name("ZerosLike") .Device(DEVICE_SYCL) .TypeConstraint("T") diff --git a/tensorflow/core/kernels/debug_ops.h b/tensorflow/core/kernels/debug_ops.h index a9fa59ab018c08..ef12e2e42cb734 100644 --- a/tensorflow/core/kernels/debug_ops.h +++ b/tensorflow/core/kernels/debug_ops.h @@ -19,6 +19,9 @@ limitations under the License. #if GOOGLE_CUDA #include "tensorflow/core/common_runtime/gpu/gpu_util.h" #endif +#ifdef TENSORFLOW_USE_SYCL +#include "tensorflow/core/common_runtime/sycl/sycl_util.h" +#endif // TENSORFLOW_USE_SYCL #include "tensorflow/core/debug/debug_io_utils.h" #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/op_kernel.h" @@ -84,6 +87,22 @@ class CopyOp : public OpKernel { // The input tensor is on the host (CPU): deep-copy from CPU to CPU. *copied_tensor = tensor::DeepCopy(src_tensor); } +#elif defined(TENSORFLOW_USE_SYCL) + Device* device = static_cast(context->device()); + // Determine if the input tensor is not on CPU (e.g., on GPU). 
+ const bool off_host_input = device->device_type() == DEVICE_SYCL && + !context->input_alloc_attr(0).on_host(); + + if (off_host_input) { + auto size = src_tensor.NumElements() * sizeof(src_tensor.dtype()); + auto dst_ptr = GetBase(copied_tensor); + auto src_ptr = GetBase(&src_tensor); + typedef decltype(src_tensor.dtype()) ttype; + context->eigen_sycl_device().memcpy( + dst_ptr, static_cast(src_ptr), size); + } else { + *copied_tensor = tensor::DeepCopy(src_tensor); + } #else *copied_tensor = tensor::DeepCopy(src_tensor); #endif diff --git a/tensorflow/core/kernels/fill_functor.cc b/tensorflow/core/kernels/fill_functor.cc index af06e12a5e4134..8a0a558eefa4c1 100644 --- a/tensorflow/core/kernels/fill_functor.cc +++ b/tensorflow/core/kernels/fill_functor.cc @@ -56,17 +56,22 @@ DEFINE_SETZERO_CPU(complex128); template void SetZeroFunctor::operator()( const Eigen::SyclDevice& d, typename TTypes::Flat out) { - out.device(d) = out.constant(T(0)); + To32Bit(out).device(d) = To32Bit(out).constant(T(0)); } #define DEFINE_SETZERO_SYCL(T) \ template struct SetZeroFunctor; -DEFINE_SETZERO_SYCL(float); DEFINE_SETZERO_SYCL(bool); +DEFINE_SETZERO_SYCL(float); DEFINE_SETZERO_SYCL(double); +DEFINE_SETZERO_SYCL(uint8); +DEFINE_SETZERO_SYCL(int8); +DEFINE_SETZERO_SYCL(uint16); +DEFINE_SETZERO_SYCL(int16); +DEFINE_SETZERO_SYCL(int32); +DEFINE_SETZERO_SYCL(int64); #undef DEFINE_SETZERO_SYCL #endif // TENSORFLOW_USE_SYCL - template void SetOneFunctor::operator()( const Eigen::ThreadPoolDevice& d, typename TTypes::Flat out) { diff --git a/tensorflow/core/kernels/inplace_ops.cc b/tensorflow/core/kernels/inplace_ops.cc index 4433b9eea9892e..67bec7d50e9d55 100644 --- a/tensorflow/core/kernels/inplace_ops.cc +++ b/tensorflow/core/kernels/inplace_ops.cc @@ -25,11 +25,14 @@ limitations under the License. 
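Note: the CopyOp change in debug_ops.h above mirrors the CUDA branch: when the input tensor lives off-host on a SYCL device, it copies raw bytes through the Eigen SYCL device instead of using tensor::DeepCopy. With the GetBase helpers from the new sycl_util.h, the copy reduces to roughly the following (tensor, context, and `nbytes` names are assumptions for the sketch; `nbytes` stands for the total byte size of the source tensor):

  void* dst = GetBase(copied_tensor);                     // destination Tensor's buffer
  const void* src = GetBase(&src_tensor);                 // source Tensor's buffer
  context->eigen_sycl_device().memcpy(dst, src, nbytes);  // device-to-device byte copy on the SYCL queue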
namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; +#ifdef TENSORFLOW_USE_SYCL +typedef Eigen::SyclDevice SyclDevice; +#endif // TENSORFLOW_USE_SYCL namespace functor { -template -Status DoParallelConcatUpdate(const CPUDevice& d, const Tensor& value, +template +Status DoParallelConcatUpdate(const Device& d, const Tensor& value, int32 loc, Tensor* output) { auto Tvalue = value.flat_outer_dims(); auto Toutput = output->flat_outer_dims(); @@ -46,7 +49,7 @@ Status DoParallelConcat(const CPUDevice& d, const Tensor& value, int32 loc, switch (value.dtype()) { #define CASE(type) \ case DataTypeToEnum::value: \ - return DoParallelConcatUpdate(d, value, loc, output); + return DoParallelConcatUpdate(d, value, loc, output); TF_CALL_NUMBER_TYPES(CASE); TF_CALL_string(CASE); #undef CASE @@ -55,6 +58,23 @@ Status DoParallelConcat(const CPUDevice& d, const Tensor& value, int32 loc, } } +#ifdef TENSORFLOW_USE_SYCL +template <> +Status DoParallelConcat(const SyclDevice& d, const Tensor& value, int32 loc, + Tensor* output) { + CHECK_EQ(value.dtype(), output->dtype()); + switch (value.dtype()) { +#define CASE(type) \ + case DataTypeToEnum::value: \ + return DoParallelConcatUpdate(d, value, loc, output); + TF_CALL_GPU_NUMBER_TYPES_NO_HALF(CASE); +#undef CASE + default: + return errors::InvalidArgument("Unsupported data type: ", value.dtype()); + } +} +#endif // TENSORFLOW_USE_SYCL + } // end namespace functor namespace { @@ -152,6 +172,42 @@ TF_CALL_POD_STRING_TYPES(REGISTER_EMPTY) TF_CALL_POD_STRING_TYPES(REGISTER_PARALLEL_CONCAT); #undef REGISTER_PARALLEL_CONCAT +#ifdef TENSORFLOW_USE_SYCL +#define REGISTER_EMPTY(type) \ + REGISTER_KERNEL_BUILDER(Name("_ParallelConcatStart") \ + .Device(DEVICE_SYCL) \ + .TypeConstraint("dtype"), \ + ParallelConcatStart); +TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_EMPTY) +#undef REGISTER_EMPTY + +#define REGISTER_PARALLEL_CONCAT(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("ParallelConcat").Device(DEVICE_SYCL).TypeConstraint("T"), \ + FailureKernel); +TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_PARALLEL_CONCAT); +#undef REGISTER_PARALLEL_CONCAT + +#define REGISTER(type) \ + REGISTER_KERNEL_BUILDER(Name("_ParallelConcatUpdate") \ + .Device(DEVICE_SYCL) \ + .TypeConstraint("T"), \ + ParallelConcatUpdate); +TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER) +#undef REGISTER + +// Register versions that operate on int32 data on the CPU even though the op +// has been placed on the SYCL + +REGISTER_KERNEL_BUILDER(Name("_ParallelConcatUpdate") + .Device(DEVICE_SYCL) + .HostMemory("value") + .HostMemory("update") + .HostMemory("output") + .TypeConstraint("T"), + ParallelConcatUpdate); +#endif // TENSORFLOW_USE_SYCL + #if GOOGLE_CUDA typedef Eigen::GpuDevice GPUDevice; diff --git a/tensorflow/core/kernels/matmul_op.cc b/tensorflow/core/kernels/matmul_op.cc index 199e1605923863..8003f7ff67fd71 100644 --- a/tensorflow/core/kernels/matmul_op.cc +++ b/tensorflow/core/kernels/matmul_op.cc @@ -395,6 +395,7 @@ TF_CALL_half(REGISTER_GPU); .Label("eigen"), \ MatMulOp) TF_CALL_float(REGISTER_SYCL); +TF_CALL_double(REGISTER_SYCL); #endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/random_op.cc b/tensorflow/core/kernels/random_op.cc index 80b1be8d4cae49..e78f8e26211786 100644 --- a/tensorflow/core/kernels/random_op.cc +++ b/tensorflow/core/kernels/random_op.cc @@ -48,6 +48,9 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; +#ifdef TENSORFLOW_USE_SYCL +typedef Eigen::SyclDevice SYCLDevice; 
+#endif // TENSORFLOW_USE_SYCL namespace functor { using random::PhiloxRandom; @@ -549,4 +552,193 @@ TF_CALL_int64(REGISTER_INT); #endif // GOOGLE_CUDA +#ifdef TENSORFLOW_USE_SYCL + +namespace functor { + +using namespace cl; + +template +struct FillPhiloxRandomKernel; + +template +struct FillPhiloxRandomKernel { + typedef typename Distribution::ResultElementType T; + using write_accessor = sycl::accessor; + + FillPhiloxRandomKernel(write_accessor& data, random::PhiloxRandom& gen, Distribution& dist) + : data_(data), + gen_(gen), + dist_(dist) { + } + + void operator()(sycl::nd_item<1> item) { + const size_t kGroupSize = Distribution::kResultElementCount; + + const size_t item_id = item.get_global(0); + const size_t total_item_count = item.get_global_range(0); + size_t offset = item_id * kGroupSize; + gen_.Skip(item_id); + + const size_t size = data_.get_size() / sizeof(T); + T* data = ConvertToActualTypeSycl(T, data_); + + while (offset + kGroupSize <= size) { + const typename Distribution::ResultType samples = dist_(&gen_); + for (size_t i = 0; i < kGroupSize; ++i) { + data[offset + i] = samples[i]; + } + + offset += (total_item_count - 1) * kGroupSize; + gen_.Skip(total_item_count - 1); + } + + const typename Distribution::ResultType samples = dist_(&gen_); + for (size_t i = 0; i < kGroupSize; ++i) { + if (offset >= size) { + return; + } + data[offset] = samples[i]; + ++offset; + } + } + + private: + write_accessor data_; + random::PhiloxRandom gen_; + Distribution dist_; +}; + + +template +struct FillPhiloxRandomKernel { + typedef typename Distribution::ResultElementType T; + using write_accessor = sycl::accessor; + + FillPhiloxRandomKernel(write_accessor& data, random::PhiloxRandom& gen, Distribution& dist) + : data_(data), + gen_(gen), + dist_(dist) { + } + + void operator()(sycl::nd_item<1> item) { + using random::PhiloxRandom; + using random::SingleSampleAdapter; + + const size_t kReservedSamplesPerOutput = 256; + const size_t kGroupSize = Distribution::kResultElementCount; + const size_t kGeneratorSkipPerOutputGroup = kGroupSize * + kReservedSamplesPerOutput / + PhiloxRandom::kResultElementCount; + + const size_t item_id = item.get_global(0); + const size_t total_item_count = item.get_global_range(0); + size_t group_index = item_id; + size_t offset = group_index * kGroupSize; + + T* data = ConvertToActualTypeSycl(T, data_); + const size_t size = data_.get_size() / sizeof(T); + + while (offset < size) { + // Since each output takes a variable number of samples, we need to + // realign the generator to the beginning for the current output group + PhiloxRandom gen = gen_; + gen.Skip(group_index * kGeneratorSkipPerOutputGroup); + SingleSampleAdapter single_samples(&gen); + + const typename Distribution::ResultType samples = dist_(&single_samples); + + for (size_t i = 0; i < kGroupSize; ++i) { + if (offset >= size) { + return; + } + data[offset] = samples[i]; + ++offset; + } + + offset += (total_item_count - 1) * kGroupSize; + group_index += total_item_count; + } + } + + private: + write_accessor data_; + random::PhiloxRandom gen_; + Distribution dist_; +}; + +template +class FillRandomKernel; +// Partial specialization for SYCL to fill the entire region with randoms +// It splits the work into several tasks and run them in parallel +template +void FillPhiloxRandom::operator()( + OpKernelContext* context, const SYCLDevice& device, random::PhiloxRandom gen, + typename Distribution::ResultElementType* data, int64 size, + Distribution dist) { + + const size_t group_size = 
device.maxSyclThreadsPerBlock(); + const size_t group_count = (size + group_size - 1) / group_size; + + auto buffer = device.get_sycl_buffer(data); + + device.sycl_queue().submit([&](sycl::handler& cgh) { + auto access = buffer.template get_access(cgh); + + FillPhiloxRandomKernel task(access, gen, dist); + cgh.parallel_for>( + sycl::nd_range<1>(sycl::range<1>(group_count * group_size), sycl::range<1>(group_size)), + task + ); + }); +} + +} + +#define REGISTER(TYPE) \ + template struct functor::FillPhiloxRandom< \ + SYCLDevice, random::UniformDistribution >; \ + REGISTER_KERNEL_BUILDER( \ + Name("RandomUniform") \ + .Device(DEVICE_SYCL) \ + .HostMemory("shape") \ + .TypeConstraint("dtype"), \ + PhiloxRandomOp >); \ + REGISTER_KERNEL_BUILDER( \ + Name("RandomStandardNormal") \ + .Device(DEVICE_SYCL) \ + .HostMemory("shape") \ + .TypeConstraint("dtype"), \ + PhiloxRandomOp >); \ + REGISTER_KERNEL_BUILDER( \ + Name("TruncatedNormal") \ + .Device(DEVICE_SYCL) \ + .HostMemory("shape") \ + .TypeConstraint("dtype"), \ + PhiloxRandomOp< \ + SYCLDevice, \ + random::TruncatedNormalDistribution< \ + random::SingleSampleAdapter, TYPE> >); + +#define REGISTER_INT(IntType) \ + REGISTER_KERNEL_BUILDER(Name("RandomUniformInt") \ + .Device(DEVICE_SYCL) \ + .HostMemory("shape") \ + .HostMemory("minval") \ + .HostMemory("maxval") \ + .TypeConstraint("Tout"), \ + RandomUniformIntOp); + +TF_CALL_float(REGISTER); +TF_CALL_double(REGISTER); +TF_CALL_int32(REGISTER_INT); +TF_CALL_int64(REGISTER_INT); + +#undef REGISTER +#undef REGISTER_INT + +#endif // TENSORFLOW_USE_SYCL + } // end namespace tensorflow diff --git a/tensorflow/core/kernels/random_op.h b/tensorflow/core/kernels/random_op.h index b52901c38e3ac9..97bcaf1a49a37c 100644 --- a/tensorflow/core/kernels/random_op.h +++ b/tensorflow/core/kernels/random_op.h @@ -54,6 +54,18 @@ struct FillPhiloxRandom { }; #endif // GOOGLE_CUDA +#if TENSORFLOW_USE_SYCL +typedef Eigen::SyclDevice SYCLDevice; +// Declares the partially SYCL-specialized functor struct. 
+template +struct FillPhiloxRandom { + void operator()(OpKernelContext* ctx, const SYCLDevice& d, + random::PhiloxRandom gen, + typename Distribution::ResultElementType* data, int64 size, + Distribution dist); +}; +#endif // TENSORFLOW_USE_SYCL + } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/scatter_functor.h b/tensorflow/core/kernels/scatter_functor.h index 63add61ba7222e..c6e35fe329e1c1 100644 --- a/tensorflow/core/kernels/scatter_functor.h +++ b/tensorflow/core/kernels/scatter_functor.h @@ -75,6 +75,50 @@ struct Assign { } }; +#ifdef TENSORFLOW_USE_SYCL +template +struct AssignSYCL {}; +template <> +struct AssignSYCL { + template + static void Run(Device d, Params p, Update u) { + p.device(d) = u; + } +}; + +template <> +struct AssignSYCL { + template + static void Run(Device d, Params p, Update u) { + p.device(d) += u; + } +}; + +template <> +struct AssignSYCL { + template + static void Run(Device d, Params p, Update u) { + p.device(d) -= u; + } +}; + +template <> +struct AssignSYCL { + template + static void Run(Device d, Params p, Update u) { + p.device(d) = p * u; + } +}; + +template <> +struct AssignSYCL { + template + static void Run(Device d, Params p, Update u) { + p.device(d) = p / u; + } +}; +#endif // TENSORFLOW_USE_SYCL + } // namespace internal } // namespace scatter_op @@ -110,6 +154,31 @@ struct ScatterFunctorBase { } }; +#ifdef TENSORFLOW_USE_SYCL +template +struct ScatterFunctorBase { + Index operator()(OpKernelContext* c, const SYCLDevice& d, + typename TTypes::Matrix params, + typename TTypes::ConstMatrix updates, + typename TTypes::ConstFlat indices) { + // indices and params sizes were validated in DoCompute(). + const Index N = static_cast(indices.size()); + const Index limit = static_cast(params.dimension(0)); + for (Index i = 0; i < N; i++) { + // Grab the index and check its validity. An earlier version of the + // code checked it and then grabbed it from memory a second time, which + // was a security risk since it could have changed in between. + const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i)); + if (!FastBoundsCheck(index, limit)) return i; + // Copy last Ndim-1 dimensions of updates[i] to params[index] + scatter_op::internal::AssignSYCL::Run(d, params.template chip<0>(index), + updates.template chip<0>(i)); + } + return -1; + } +}; +#endif // TENSORFLOW_USE_SYCL + template struct ScatterFunctorBase { Index operator()(OpKernelContext* c, const CPUDevice& d, @@ -149,10 +218,27 @@ struct ScatterFunctorBase { template struct ScatterFunctor : ScatterFunctorBase{}; -#if TENSORFLOW_USE_SYCL -template -struct ScatterFunctor - : ScatterFunctorBase{}; + +#ifdef TENSORFLOW_USE_SYCL +template +struct ScatterFunctorSYCL { + Index operator()(OpKernelContext* c, const SYCLDevice& d, + typename TTypes::Matrix params, + typename TTypes::ConstMatrix updates, + typename TTypes::Flat indices) { + // indices and params sizes were validated in DoCompute(). 
+ const Index N = static_cast(indices.size()); + const Index limit = static_cast(params.dimension(0)); + for (Index i = 0; i < N; i++) { + const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i)); + if (!FastBoundsCheck(index, limit)) return i; + // Copy last Ndim-1 dimensions of updates[i] to params[index] + scatter_op::internal::AssignSYCL::Run( + d, params.template chip<0>(index), updates.template chip<0>(i)); + } + return -1; + } +}; #endif // TENSORFLOW_USE_SYCL } // namespace functor diff --git a/tensorflow/core/kernels/scatter_nd_op.cc b/tensorflow/core/kernels/scatter_nd_op.cc index 363de801baabf6..48565d8cb97e95 100644 --- a/tensorflow/core/kernels/scatter_nd_op.cc +++ b/tensorflow/core/kernels/scatter_nd_op.cc @@ -27,10 +27,17 @@ limitations under the License. #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/util.h" +#ifdef TENSORFLOW_USE_SYCL +#include "tensorflow/core/common_runtime/sycl/sycl_util.h" +#endif // TENSORFLOW_USE_SYCL + namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; +#ifdef TENSORFLOW_USE_SYCL +typedef Eigen::SyclDevice SYCLDevice; +#endif // TENSORFLOW_USE_SYCL // Check whether updates.shape = indices.shape[:batch_dim] + // params_shape[slice_dim:] @@ -138,6 +145,40 @@ static void PrepareAndValidateInputs(OpKernelContext* c, *num_updates = indices_shape.num_elements() / safe_slice_dim; } +template +class IndexFlattener { +public: + inline typename TTypes::ConstTensor + operator()(OpKernelContext*, const Tensor& indices) { + return indices.flat_inner_dims(); + } +}; + +#ifdef TENSORFLOW_USE_SYCL +template +class IndexFlattener { +public: + IndexFlattener() { indices_host_ = nullptr; } + ~IndexFlattener() { delete[] indices_host_; } + + inline typename TTypes::ConstTensor + operator()(OpKernelContext* c, const Tensor& indices) { + size_t num_indices = indices.NumElements(); + indices_host_ = new Index[num_indices]; + auto device = c->eigen_sycl_device(); + auto size = sizeof(Index) * num_indices; + auto src_ptr = GetBase(&indices); + device.memcpyDeviceToHost(indices_host_, static_cast(src_ptr), + size); + return typename TTypes::ConstTensor(indices_host_, + indices.shape().AsEigenDSizes<2>()); + } + +private: + Index* indices_host_; +}; +#endif + template class ScatterNdOp : public OpKernel { public: @@ -166,7 +207,8 @@ class ScatterNdOp : public OpKernel { &num_updates, &slice_size); if (!c->status().ok()) return; - auto indices_flat = indices.flat_inner_dims(); + IndexFlattener index_flattener; + auto indices_flat = index_flattener(c, indices); auto updates_flat = updates.shaped({num_updates, slice_size}); Tensor* out = nullptr; @@ -262,7 +304,8 @@ class ScatterNdUpdateOp : public OpKernel { &slice_dim, &num_updates, &slice_size); if (!c->status().ok()) return; - auto indices_flat = indices.flat_inner_dims(); + IndexFlattener index_flattener; + auto indices_flat = index_flattener(c, indices); auto updates_flat = updates.shaped({num_updates, slice_size}); auto params_matrix = params.template shaped( {params_shape.num_elements() / slice_size, slice_size}); @@ -419,6 +462,19 @@ TF_CALL_GPU_NUMBER_TYPES_NO_HALF(DECLARE_GPU_SPECS); #endif // GOOGLE_CUDA +#ifdef TENSORFLOW_USE_SYCL +#define REGISTER_SCATTER_ND_ADD_SUB_SYCL(type) \ + REGISTER_SCATTER_ND_ADD_SUB(type, SYCL); + +#define REGISTER_SCATTER_ND_UPDATE_SYCL(type) \ + REGISTER_SCATTER_ND_UPDATE(type, SYCL); + +TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_ADD_SUB_SYCL); 
+TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_UPDATE_SYCL); +#undef REGISTER_SCATTER_ND_ADD_SUB_SYCL +#undef REGISTER_SCATTER_ND_UPDATE_SYCL +#endif // TENSORFLOW_USE_SYCL + #undef REGISTER_SCATTER_ND_ADD #undef REGISTER_SCATTER_ND_ADD_SUB #undef REGISTER_SCATTER_ND_ADD_SUB_CPU diff --git a/tensorflow/core/kernels/scatter_nd_op_cpu_impl.h b/tensorflow/core/kernels/scatter_nd_op_cpu_impl.h index bbe2c6864ff40c..788797b668d347 100644 --- a/tensorflow/core/kernels/scatter_nd_op_cpu_impl.h +++ b/tensorflow/core/kernels/scatter_nd_op_cpu_impl.h @@ -38,6 +38,9 @@ limitations under the License. namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; +#ifdef TENSORFLOW_USE_SYCL +typedef Eigen::SyclDevice SYCLDevice; +#endif // TENSORFLOW_USE_SYCL class OpKernelContext; @@ -186,6 +189,92 @@ TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_MATH) #undef REGISTER_SCATTER_ND_INDEX #undef REGISTER_SCATTER_ND_FULL +#ifdef TENSORFLOW_USE_SYCL +// Implementation of update functor for SYCL. +template +struct ScatterNdFunctor { + Index operator()( + const SYCLDevice& d, const Index slice_size, + const Eigen::array output_shape_prefix, + typename TTypes::Tensor Tparams, + typename TTypes::ConstTensor Tindices, + typename TTypes::ConstTensor Tupdates, + typename TTypes::Tensor Toutput) { + // error_loc is -1 if there's no out-of-bounds index, + // otherwise it is the location of an OOB index in Tindices. + Index error_loc = -1; + + const Eigen::DenseIndex batch_size = Tindices.dimension(0); + + Index batch_strides[IXDIM]; + for (int dim = IXDIM - 1; dim >= 0; --dim) { + if (dim == IXDIM - 1) { + batch_strides[dim] = 1; + } else { + batch_strides[dim] = + batch_strides[dim + 1] * output_shape_prefix[dim + 1]; + } + } + + for (Eigen::DenseIndex loc = 0; loc < batch_size; ++loc) { + Index i = 0; + bool out_of_bounds = false; + for (int dim = 0; dim < IXDIM; ++dim) { + const Index ix_d = internal::SubtleMustCopy(Tindices(loc, dim)); + out_of_bounds |= !FastBoundsCheck(ix_d, output_shape_prefix[dim]); + i += ix_d * batch_strides[dim]; + } + if (TF_PREDICT_FALSE(out_of_bounds)) { + error_loc = loc; + break; + } else { + auto input_chip = Toutput.template chip<0>(i); + auto output_chip = input_chip.device(d); + auto update_chip = Tupdates.template chip<0>(loc); + update_executor::UpdateExecutor< + decltype(input_chip), decltype(update_chip), decltype(output_chip), + OP>::Execute(input_chip, update_chip, output_chip); + } + } + + return error_loc; + } +}; + +#define REGISTER_SCATTER_ND_FULL_SYCL(T, Index, op) \ + template Index \ + ScatterNdFunctor::operator()( \ + const SYCLDevice& d, const Index slice_size, \ + const Eigen::array \ + output_shape_prefix, \ + typename TTypes::Tensor Tparams, \ + typename TTypes::ConstTensor Tindices, \ + typename TTypes::ConstTensor Tupdates, \ + typename TTypes::Tensor Toutput) + +#define REGISTER_SCATTER_ND_INDEX_SYCL(type, op) \ + REGISTER_SCATTER_ND_FULL_SYCL(type, int32, op); \ + REGISTER_SCATTER_ND_FULL_SYCL(type, int64, op) + +#define REGISTER_SCATTER_ND_UPDATE_SYCL(type) \ + REGISTER_SCATTER_ND_INDEX_SYCL(type, scatter_nd_op::UpdateOp::ASSIGN); + +#define REGISTER_SCATTER_ND_MATH_SYCL(type) \ + REGISTER_SCATTER_ND_INDEX_SYCL(type, scatter_nd_op::UpdateOp::ADD); \ + REGISTER_SCATTER_ND_INDEX_SYCL(type, scatter_nd_op::UpdateOp::SUB); + +TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_UPDATE_SYCL) +TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_MATH_SYCL) +REGISTER_SCATTER_ND_UPDATE_SYCL(int32); +REGISTER_SCATTER_ND_MATH_SYCL(int32); + +#undef 
REGISTER_SCATTER_ND_MATH_SYCL +#undef REGISTER_SCATTER_ND_UPDATE_SYCL +#undef REGISTER_SCATTER_ND_INDEX_SYCL +#undef REGISTER_SCATTER_ND_FULL_SYCL + +#endif // TENSORFLOW_USE_SYCL + } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/scatter_op.cc b/tensorflow/core/kernels/scatter_op.cc index 51dad49cfec884..8607c7f95af79c 100644 --- a/tensorflow/core/kernels/scatter_op.cc +++ b/tensorflow/core/kernels/scatter_op.cc @@ -23,6 +23,10 @@ limitations under the License. #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/util.h" +#ifdef TENSORFLOW_USE_SYCL +#include "tensorflow/core/common_runtime/sycl/sycl_util.h" +#endif // TENSORFLOW_USE_SYCL + namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; @@ -131,6 +135,79 @@ class ScatterUpdateOp : public OpKernel { } }; +#ifdef TENSORFLOW_USE_SYCL +template +class ScatterUpdateOp : public OpKernel { + public: + explicit ScatterUpdateOp(OpKernelConstruction* c) : OpKernel(c) { + OP_REQUIRES_OK(c, c->GetAttr("use_locking", &use_exclusive_lock_)); + } + + void Compute(OpKernelContext* c) override { + if (use_exclusive_lock_) { + // Hold mutex while we apply updates + mutex_lock l(*c->input_ref_mutex(0)); + DoCompute(c); + } else { + DoCompute(c); + } + } + + private: + bool use_exclusive_lock_; + + void DoCompute(OpKernelContext* c) { + Tensor params = c->mutable_input(0, use_exclusive_lock_); + const Tensor& indices = c->input(1); + const Tensor& updates = c->input(2); + DoValidationChecking(c, params, indices, updates); + if (!c->status().ok()) return; + + // Check that we have enough index space + const int64 N_big = indices.NumElements(); + OP_REQUIRES(c, N_big <= std::numeric_limits::max(), + errors::InvalidArgument( + "indices has too many elements for ", + DataTypeString(DataTypeToEnum::v()), " indexing: ", + N_big, " > ", std::numeric_limits::max())); + const Index N = static_cast(indices.NumElements()); + OP_REQUIRES( + c, params.dim_size(0) <= std::numeric_limits::max(), + errors::InvalidArgument("params.shape[0] too large for ", + DataTypeString(DataTypeToEnum::v()), + " indexing: ", params.dim_size(0), " > ", + std::numeric_limits::max())); + + // We always return the input ref. + c->forward_ref_input_to_ref_output(0, 0); + + if (N > 0) { + auto index_size = indices.NumElements() * sizeof(Index); + Tensor indices_host = Tensor(indices.dtype(), indices.shape()); + + auto src_ptr = GetBase(&indices); + auto dst_ptr = GetBase(&indices_host); + + c->eigen_sycl_device().memcpyDeviceToHost( + dst_ptr, static_cast(src_ptr), index_size); + + auto indices_flat = indices_host.flat(); + auto params_flat = params.flat_outer_dims(); + auto updates_flat = updates.shaped({N, updates.NumElements() / N}); + + functor::ScatterFunctorSYCL functor; + const Index bad_i = functor(c, c->template eigen_device(), + params_flat, updates_flat, indices_flat); + OP_REQUIRES( + c, bad_i < 0, + errors::InvalidArgument( + "indices", SliceDebugString(indices.shape(), bad_i), " = ", + indices_flat(bad_i), " is not in [0, ", params.dim_size(0), ")")); + } + } +}; +#endif // TENSORFLOW_USE_SYCL + #define REGISTER_SCATTER_KERNEL_INDEX(type, index_type, dev, name, op) \ REGISTER_KERNEL_BUILDER(Name(name) \ .Device(DEVICE_##dev) \ diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc index 8e2d6dc74eb710..f6b6194f0abb10 100644 --- a/tensorflow/core/kernels/training_ops.cc +++ b/tensorflow/core/kernels/training_ops.cc @@ -23,6 +23,10 @@ limitations under the License. 
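The SYCL specialization of ScatterUpdateOp above first verifies that the 64-bit element counts fit in the chosen Index type before casting, then copies the indices tensor to the host with memcpyDeviceToHost so the scatter loop can read them on the CPU. The device copy is environment-specific, but the checked-narrowing step can be sketched on its own (hypothetical helper; C++17 std::optional stands in for OP_REQUIRES/errors::InvalidArgument):

// Illustrative sketch only: check that a 64-bit count fits in the narrower
// Index type before casting, as done for indices.NumElements() and
// params.dim_size(0) above.
#include <cstdint>
#include <iostream>
#include <limits>
#include <optional>

template <typename Index>
std::optional<Index> CheckedNarrow(std::int64_t value) {
  if (value > static_cast<std::int64_t>(std::numeric_limits<Index>::max())) {
    return std::nullopt;  // caller reports "too many elements for ... indexing"
  }
  return static_cast<Index>(value);
}

int main() {
  const auto ok = CheckedNarrow<std::int32_t>(1024);
  const auto too_big = CheckedNarrow<std::int32_t>(std::int64_t{1} << 40);
  std::cout << ok.has_value() << " " << too_big.has_value() << "\n";  // 1 0
}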
#include "tensorflow/core/kernels/training_op_helpers.h" #include "tensorflow/core/kernels/variable_ops.h" +#ifdef TENSORFLOW_USE_SYCL +#include "tensorflow/core/common_runtime/sycl/sycl_util.h" +#endif // TENSORFLOW_USE_SYCL + namespace tensorflow { using CPUDevice = Eigen::ThreadPoolDevice; @@ -50,11 +54,10 @@ struct ApplyGradientDescent { #ifdef TENSORFLOW_USE_SYCL template -struct ApplyGradientDescent { +struct ApplyGradientDescentSYCL { void operator()(const SYCLDevice& d, typename TTypes::Flat var, - typename TTypes::ConstScalar lr, - typename TTypes::ConstFlat grad) { - var.device(d) -= grad * lr(); + T lr, typename TTypes::ConstFlat grad) { + var.device(d) -= grad * lr; } }; #endif @@ -276,10 +279,24 @@ struct ApplyAdamNonCuda { } }; +#ifdef TENSORFLOW_USE_SYCL template -struct ApplyAdam : ApplyAdamNonCuda {}; +struct ApplyAdamSYCL { + void operator()(const SYCLDevice& d, typename TTypes::Flat var, + typename TTypes::Flat m, typename TTypes::Flat v, + T beta1_power, T beta2_power, T lr, T beta1, T beta2, T epsilon, + typename TTypes::ConstFlat grad) { + const T alpha = lr * Eigen::numext::sqrt(T(1) - beta2_power) / + (T(1) - beta1_power); + m.device(d) += (grad - m) * (T(1) - beta1); + v.device(d) += (grad.square() - v) * (T(1) - beta2); + var.device(d) -= (m * alpha) / (v.sqrt() + epsilon); + } +}; +#endif // TENSORFLOW_USE_SYCL + template -struct ApplyAdam : ApplyAdamNonCuda {}; +struct ApplyAdam : ApplyAdamNonCuda {}; template struct ApplyRMSProp { @@ -358,6 +375,51 @@ class ApplyGradientDescentOp : public OpKernel { bool use_exclusive_lock_; }; +#ifdef TENSORFLOW_USE_SYCL +template +class ApplyGradientDescentOp < SYCLDevice, T > : public OpKernel { + public: + explicit ApplyGradientDescentOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_)); + } + + void Compute(OpKernelContext* ctx) override { + auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0}); + Tensor var; + OP_REQUIRES_OK(ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var)); + + OP_REQUIRES( + ctx, var.IsInitialized(), + errors::FailedPrecondition( + "Attempting to use uninitialized variables: ", def().input(0))); + const Tensor& alpha_dev = ctx->input(1); + OP_REQUIRES(ctx, IsLegacyScalar(alpha_dev.shape()), + errors::InvalidArgument("alpha is not a scalar: ", + alpha_dev.shape().DebugString())); + const Tensor& delta = ctx->input(2); + OP_REQUIRES( + ctx, var.shape().IsSameSize(delta.shape()), + errors::InvalidArgument("var and delta do not have the same shape", + var.shape().DebugString(), " ", + delta.shape().DebugString())); + + auto device = ctx->eigen_sycl_device(); + auto size = sizeof(T); + T alpha = T(0); + auto src_ptr = GetBase(&alpha_dev); + device.memcpyDeviceToHost(&alpha, static_cast(src_ptr), size); + + functor::ApplyGradientDescentSYCL()(device, var.flat(), + alpha, delta.flat()); + + MaybeForwardRefInputToRefOutput(ctx, 0, 0); + } + + private: + bool use_exclusive_lock_; +}; +#endif // TENSORFLOW_USE_SYCL + #define REGISTER_KERNELS(D, T) \ REGISTER_KERNEL_BUILDER( \ Name("ApplyGradientDescent").Device(DEVICE_##D).TypeConstraint("T"), \ @@ -373,13 +435,6 @@ TF_CALL_half(REGISTER_CPU_KERNELS); TF_CALL_float(REGISTER_CPU_KERNELS); TF_CALL_double(REGISTER_CPU_KERNELS); -#ifdef TENSORFLOW_USE_SYCL -#define REGISTER_SYCL_KERNELS(T) REGISTER_KERNELS(SYCL, T); -TF_CALL_float(REGISTER_SYCL_KERNELS); -TF_CALL_double(REGISTER_SYCL_KERNELS); -#undef REGISTER_SYCL_KERNELS -#endif - #if GOOGLE_CUDA // 
Forward declarations of the functor specializations for GPU. namespace functor { @@ -400,6 +455,14 @@ REGISTER_KERNELS(GPU, Eigen::half); REGISTER_KERNELS(GPU, float); REGISTER_KERNELS(GPU, double); #endif + +#ifdef TENSORFLOW_USE_SYCL +#define REGISTER_SYCL_KERNELS(T) REGISTER_KERNELS(SYCL, T); +TF_CALL_float(REGISTER_SYCL_KERNELS); +TF_CALL_double(REGISTER_SYCL_KERNELS); +#undef REGISTER_SYCL_KERNELS +#endif // TENSORFLOW_USE_SYCL + #undef REGISTER_CPU_KERNELS #undef REGISTER_KERNELS @@ -2422,6 +2485,120 @@ class ApplyAdamOp : public OpKernel { bool use_nesterov_; }; +#ifdef TENSORFLOW_USE_SYCL +template +class ApplyAdamOp < SYCLDevice, T> : public OpKernel { + public: + explicit ApplyAdamOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_)); + } + + void Compute(OpKernelContext* ctx) override { + auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1, 2}); + + Tensor var; + OP_REQUIRES_OK(ctx, GetInputTensorFromVariable(ctx, 0, use_exclusive_lock_, &var)); + Tensor m; + OP_REQUIRES_OK(ctx, GetInputTensorFromVariable(ctx, 1, use_exclusive_lock_, &m)); + Tensor v; + OP_REQUIRES_OK(ctx, GetInputTensorFromVariable(ctx, 2, use_exclusive_lock_, &v)); + OP_REQUIRES( + ctx, var.IsInitialized(), + errors::FailedPrecondition( + "Attempting to use uninitialized variables: ", def().input(0))); + OP_REQUIRES( + ctx, m.IsInitialized(), + errors::FailedPrecondition( + "Attempting to use uninitialized variables: ", def().input(1))); + OP_REQUIRES( + ctx, v.IsInitialized(), + errors::FailedPrecondition( + "Attempting to use uninitialized variables: ", def().input(2))); + + const Tensor& beta1_power_dev = ctx->input(3); + const Tensor& beta2_power_dev = ctx->input(4); + const Tensor& lr_dev = ctx->input(5); + const Tensor& beta1_dev = ctx->input(6); + const Tensor& beta2_dev = ctx->input(7); + const Tensor& epsilon_dev = ctx->input(8); + + T beta1_power = 0; + T beta2_power = 0; + T lr = 0; + T beta1 = 0; + T beta2 = 0; + T epsilon = 0; + + auto device = ctx->eigen_sycl_device(); + auto size = sizeof(T); + auto src_ptr = GetBase(&beta1_power_dev); + device.memcpyDeviceToHost(&beta1_power, static_cast(src_ptr), size); + + src_ptr = GetBase(&beta2_power_dev); + device.memcpyDeviceToHost(&beta2_power, static_cast(src_ptr), size); + + src_ptr = GetBase(&lr_dev); + device.memcpyDeviceToHost(&lr, static_cast(src_ptr), size); + + src_ptr = GetBase(&beta1_dev); + device.memcpyDeviceToHost(&beta1, static_cast(src_ptr), size); + + src_ptr = GetBase(&beta2_dev); + device.memcpyDeviceToHost(&beta2, static_cast(src_ptr), size); + + src_ptr = GetBase(&epsilon_dev); + device.memcpyDeviceToHost(&epsilon, static_cast(src_ptr), size); + + + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1_power_dev.shape()), + errors::InvalidArgument("beta1_power is not a scalar: ", + beta1_power_dev.shape().DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2_power_dev.shape()), + errors::InvalidArgument("beta2_power is not a scalar: ", + beta2_power_dev.shape().DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_dev.shape()), + errors::InvalidArgument("lr is not a scalar : ", + lr_dev.shape().DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1_dev.shape()), + errors::InvalidArgument("beta1 is not a scalar: ", + beta1_dev.shape().DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2_dev.shape()), + errors::InvalidArgument("beta2 is not a scalar: ", + 
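The SYCL ApplyGradientDescentOp above cannot dereference the scalar learning rate directly because it lives in device memory, so it copies sizeof(T) bytes to the host with memcpyDeviceToHost and passes the value to the functor; the Adam kernel that follows repeats the same pattern for its six scalars. A sketch of that copy-then-update flow, with a plain callback standing in for the SYCL device copy (all names hypothetical):

// Illustrative sketch only: read a scalar hyper-parameter through a
// device-to-host copy callback, then apply the gradient-descent update.
#include <cstddef>
#include <functional>
#include <iostream>
#include <vector>

template <typename T>
using CopyScalarFn = std::function<void(T* dst, const void* src)>;

template <typename T>
void ApplyGradientDescent(std::vector<T>& var, const std::vector<T>& grad,
                          const void* alpha_device_ptr,
                          const CopyScalarFn<T>& copy_to_host) {
  T alpha = T(0);
  copy_to_host(&alpha, alpha_device_ptr);  // device -> host scalar copy
  for (std::size_t i = 0; i < var.size(); ++i) var[i] -= alpha * grad[i];
}

int main() {
  std::vector<float> var = {1.0f, 2.0f}, grad = {0.5f, 0.5f};
  float alpha_on_device = 0.1f;  // pretend this lives on the accelerator
  ApplyGradientDescent<float>(var, grad, &alpha_on_device,
                              [](float* dst, const void* src) {
                                *dst = *static_cast<const float*>(src);
                              });
  std::cout << var[0] << " " << var[1] << "\n";  // 0.95 1.95
}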
beta2_dev.shape().DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon_dev.shape()), + errors::InvalidArgument("epsilon is not a scalar: ", + epsilon_dev.shape().DebugString())); + + const Tensor& grad = ctx->input(9); + + OP_REQUIRES(ctx, var.shape().IsSameSize(m.shape()), + errors::InvalidArgument("var and m do not have the same shape", + var.shape().DebugString(), " ", + m.shape().DebugString())); + OP_REQUIRES(ctx, var.shape().IsSameSize(v.shape()), + errors::InvalidArgument("var and v do not have the same shape", + var.shape().DebugString(), " ", + v.shape().DebugString())); + OP_REQUIRES( + ctx, var.shape().IsSameSize(grad.shape()), + errors::InvalidArgument("var and grad do not have the same shape", + var.shape().DebugString(), " ", + grad.shape().DebugString())); + + functor::ApplyAdamSYCL()(device, var.flat(), m.flat(), + v.flat(), beta1_power, + beta2_power, lr, + beta1, beta2, + epsilon, grad.flat()); + + MaybeForwardRefInputToRefOutput(ctx, 0, 0); + } + + private: + bool use_exclusive_lock_; +}; +#endif // TENSORFLOW_USE_SYCL + using CPUDevice = Eigen::ThreadPoolDevice; using GPUDevice = Eigen::GpuDevice; diff --git a/tensorflow/core/kernels/transpose_functor_cpu.cc b/tensorflow/core/kernels/transpose_functor_cpu.cc index 97426efab923cb..248c11976e70c4 100644 --- a/tensorflow/core/kernels/transpose_functor_cpu.cc +++ b/tensorflow/core/kernels/transpose_functor_cpu.cc @@ -144,8 +144,42 @@ Status DoTranspose(const CPUDevice& d, const Tensor& in, #ifdef TENSORFLOW_USE_SYCL typedef Eigen::SyclDevice SYCLDevice; +template +void TransposeSYCL(const Device& d, const Tensor& in, + const gtl::ArraySlice perm, Tensor* out) { + switch (in.dims()) { + case 1: + internal::TransposeUsingEigen(d, in, perm, out); + break; + case 2: + internal::TransposeUsingEigen(d, in, perm, out); + break; + case 3: + internal::TransposeUsingEigen(d, in, perm, out); + break; + case 4: + internal::TransposeUsingEigen(d, in, perm, out); + break; + case 5: + internal::TransposeUsingEigen(d, in, perm, out); + break; + case 6: + internal::TransposeUsingEigen(d, in, perm, out); + break; + case 7: + internal::TransposeUsingEigen(d, in, perm, out); + break; + case 8: + internal::TransposeUsingEigen(d, in, perm, out); + break; + default: + LOG(FATAL) << "Unsupported TransposeUsingEigen for: " << in.dims(); + break; + } +} + template -struct internal::Transpose { +struct Transpose { static void run(const SYCLDevice& d, const Tensor& in, const gtl::ArraySlice perm, Tensor* out) { // Should add a specialized implementation for SYCLDevice here. 
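TransposeSYCL above dispatches on the runtime rank of the input tensor, instantiating TransposeUsingEigen for each supported NDIMS from 1 to 8 and failing loudly otherwise. The same switch-to-template pattern in standalone form, with shape permutation standing in for the actual Eigen shuffle (hypothetical names):

// Illustrative sketch only: runtime rank switch selecting a compile-time
// NDIMS instantiation.
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <vector>

template <int NDIMS>
std::vector<std::int64_t> PermuteShape(const std::vector<std::int64_t>& shape,
                                       const std::vector<int>& perm) {
  std::vector<std::int64_t> out(NDIMS);
  for (int d = 0; d < NDIMS; ++d) out[d] = shape[perm[d]];
  return out;
}

std::vector<std::int64_t> PermuteShapeDynamic(
    const std::vector<std::int64_t>& shape, const std::vector<int>& perm) {
  switch (shape.size()) {
    case 1: return PermuteShape<1>(shape, perm);
    case 2: return PermuteShape<2>(shape, perm);
    case 3: return PermuteShape<3>(shape, perm);
    case 4: return PermuteShape<4>(shape, perm);
    // ... ranks 5 through 8 follow the same pattern in the real kernel ...
    default: throw std::runtime_error("unsupported rank");
  }
}

int main() {
  const auto out = PermuteShapeDynamic({2, 3, 4}, {2, 0, 1});
  std::cout << out[0] << " " << out[1] << " " << out[2] << "\n";  // 4 2 3
}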
@@ -160,10 +194,36 @@ Status DoTranspose(const SYCLDevice& d, const Tensor& in, CHECK_EQ(in.dims(), perm.size()); CHECK_EQ(in.dtype(), out->dtype()); switch (in.dtype()) { + case DT_BOOL: + case DT_INT8: + case DT_QINT8: + case DT_QUINT8: + case DT_UINT8: + TransposeSYCL(d, in, perm, out); + break; + + case DT_BFLOAT16: + case DT_HALF: + case DT_INT16: + case DT_QINT16: + case DT_QUINT16: + case DT_UINT16: + TransposeSYCL(d, in, perm, out); + break; case DT_FLOAT: - case DT_DOUBLE: case DT_INT32: - internal::Transpose::run(d, in, perm, out); + case DT_QINT32: + TransposeSYCL(d, in, perm, out); + break; + + case DT_COMPLEX64: + case DT_DOUBLE: + case DT_INT64: + TransposeSYCL(d, in, perm, out); + break; + + case DT_COMPLEX128: + TransposeSYCL(d, in, perm, out); break; default: diff --git a/tensorflow/core/kernels/unique_op.cc b/tensorflow/core/kernels/unique_op.cc index d50e2060acfe64..b57e13a28c39bc 100644 --- a/tensorflow/core/kernels/unique_op.cc +++ b/tensorflow/core/kernels/unique_op.cc @@ -115,4 +115,23 @@ REGISTER_KERNEL_BUILDER(Name("Unique") .HostMemory("y") .HostMemory("idx"), UniqueOp); + +#ifdef TENSORFLOW_USE_SYCL +REGISTER_KERNEL_BUILDER(Name("Unique") + .Device(DEVICE_SYCL) + .TypeConstraint("T") + .TypeConstraint("out_idx") + .HostMemory("x") + .HostMemory("y") + .HostMemory("idx"), + UniqueOp); +REGISTER_KERNEL_BUILDER(Name("Unique") + .Device(DEVICE_SYCL) + .TypeConstraint("T") + .TypeConstraint("out_idx") + .HostMemory("x") + .HostMemory("y") + .HostMemory("idx"), + UniqueOp); +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/lib/random/random_distributions.h b/tensorflow/core/lib/random/random_distributions.h index 03b155344ce88a..c15a6436d6ea90 100644 --- a/tensorflow/core/lib/random/random_distributions.h +++ b/tensorflow/core/lib/random/random_distributions.h @@ -27,6 +27,7 @@ limitations under the License. 
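The SYCL DoTranspose switch above groups dtypes by element width (for example DT_FLOAT, DT_INT32 and DT_QINT32 share one case), which works because a transpose only relocates elements and never interprets their bits, so one instantiation per byte size is enough. A self-contained sketch of width-based dispatch for a 2-D transpose (illustrative only; names are hypothetical):

// Illustrative sketch only: dispatch a byte-wise 2-D transpose on element
// size so that all dtypes of a given width share one instantiation.
#include <cstddef>
#include <cstring>
#include <iostream>
#include <vector>

template <std::size_t kElemSize>
void Transpose2D(const unsigned char* src, unsigned char* dst, int rows,
                 int cols) {
  for (int r = 0; r < rows; ++r)
    for (int c = 0; c < cols; ++c)
      std::memcpy(dst + (static_cast<std::size_t>(c) * rows + r) * kElemSize,
                  src + (static_cast<std::size_t>(r) * cols + c) * kElemSize,
                  kElemSize);
}

void Transpose2DBySize(const void* in, void* out, int rows, int cols,
                       std::size_t elem_size) {
  const auto* src = static_cast<const unsigned char*>(in);
  auto* dst = static_cast<unsigned char*>(out);
  switch (elem_size) {
    case 1:  Transpose2D<1>(src, dst, rows, cols); break;   // bool, int8, uint8
    case 2:  Transpose2D<2>(src, dst, rows, cols); break;   // half, int16
    case 4:  Transpose2D<4>(src, dst, rows, cols); break;   // float, int32
    case 8:  Transpose2D<8>(src, dst, rows, cols); break;   // double, int64
    case 16: Transpose2D<16>(src, dst, rows, cols); break;  // complex128
    default: std::cerr << "unsupported element size\n"; break;
  }
}

int main() {
  std::vector<float> in = {1, 2, 3, 4, 5, 6};  // 2x3, row-major
  std::vector<float> out(6);
  Transpose2DBySize(in.data(), out.data(), 2, 3, sizeof(float));
  for (float v : out) std::cout << v << " ";  // 1 4 2 5 3 6
  std::cout << "\n";
}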
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/lib/random/philox_random.h" + namespace tensorflow { namespace random { @@ -373,7 +374,7 @@ class TruncatedNormalDistribution { BoxMullerFloat(x0, x1, &f[0], &f[1]); for (int i = 0; i < 2; ++i) { - if (fabs(f[i]) < kTruncateValue) { + if (Eigen::numext::abs(f[i]) < kTruncateValue) { results[index++] = Eigen::half(f[i]); if (index >= kResultElementCount) { return results; @@ -416,7 +417,7 @@ class TruncatedNormalDistribution { BoxMullerFloat(x0, x1, &f[0], &f[1]); for (int i = 0; i < 2; ++i) { - if (fabs(f[i]) < kTruncateValue) { + if (Eigen::numext::abs(f[i]) < kTruncateValue) { results[index++] = f[i]; if (index >= kResultElementCount) { return results; @@ -458,7 +459,7 @@ class TruncatedNormalDistribution { BoxMullerDouble(x0, x1, x2, x3, &d[0], &d[1]); for (int i = 0; i < 2; ++i) { - if (fabs(d[i]) < kTruncateValue) { + if (Eigen::numext::abs(d[i]) < kTruncateValue) { results[index++] = d[i]; if (index >= kResultElementCount) { return results; @@ -483,12 +484,12 @@ void BoxMullerFloat(uint32 x0, uint32 x1, float* f0, float* f1) { u1 = epsilon; } const float v1 = 2.0f * M_PI * Uint32ToFloat(x1); - const float u2 = sqrt(-2.0f * log(u1)); -#if defined(__linux__) - sincosf(v1, f0, f1); + const float u2 = Eigen::numext::sqrt(-2.0f * Eigen::numext::log(u1)); +#if defined(TENSORFLOW_USE_SYCL) || !defined(__linux__) + *f0 = Eigen::numext::sin(v1); + *f1 = Eigen::numext::cos(v1); #else - *f0 = sinf(v1); - *f1 = cosf(v1); + sincosf(v1, f0, f1); #endif *f0 *= u2; *f1 *= u2; @@ -509,12 +510,12 @@ void BoxMullerDouble(uint32 x0, uint32 x1, uint32 x2, uint32 x3, double* d0, u1 = epsilon; } const double v1 = 2 * M_PI * Uint64ToDouble(x2, x3); - const double u2 = sqrt(-2.0 * log(u1)); -#if defined(__linux__) - sincos(v1, d0, d1); + const double u2 = Eigen::numext::sqrt(-2.0 * Eigen::numext::log(u1)); +#if defined(TENSORFLOW_USE_SYCL) || !defined(__linux__) + *d0 = Eigen::numext::sin(v1); + *d1 = Eigen::numext::cos(v1); #else - *d0 = sin(v1); - *d1 = cos(v1); + sincos(v1, d0, d1); #endif *d0 *= u2; *d1 *= u2; diff --git a/tensorflow/python/debug/cli/analyzer_cli_test.py b/tensorflow/python/debug/cli/analyzer_cli_test.py index 8b191f332e8975..ce224fff208f8e 100644 --- a/tensorflow/python/debug/cli/analyzer_cli_test.py +++ b/tensorflow/python/debug/cli/analyzer_cli_test.py @@ -498,7 +498,8 @@ def setUpClass(cls): cls._is_gpu_available = test.is_gpu_available() if cls._is_gpu_available: - cls._main_device = "/job:localhost/replica:0/task:0/gpu:0" + gpu_name = test_util.gpu_device_name() + cls._main_device = "/job:localhost/replica:0/task:0" + gpu_name else: cls._main_device = "/job:localhost/replica:0/task:0/cpu:0" @@ -1461,7 +1462,8 @@ def setUpClass(cls): cls._is_gpu_available = test.is_gpu_available() if cls._is_gpu_available: - cls._main_device = "/job:localhost/replica:0/task:0/gpu:0" + gpu_name = test_util.gpu_device_name() + cls._main_device = "/job:localhost/replica:0/task:0" + gpu_name else: cls._main_device = "/job:localhost/replica:0/task:0/cpu:0" diff --git a/tensorflow/python/debug/lib/session_debug_testlib.py b/tensorflow/python/debug/lib/session_debug_testlib.py index ef1d1bb2f316b9..1b4444ef5a93e4 100644 --- a/tensorflow/python/debug/lib/session_debug_testlib.py +++ b/tensorflow/python/debug/lib/session_debug_testlib.py @@ -81,7 +81,8 @@ def setUpClass(cls): if test.is_gpu_available(): cls._expected_partition_graph_count = 2 cls._expected_num_devices = 2 - cls._main_device = 
"/job:localhost/replica:0/task:0/gpu:0" + gpu_name = test_util.gpu_device_name() + cls._main_device = "/job:localhost/replica:0/task:0" + gpu_name else: cls._expected_partition_graph_count = 1 cls._expected_num_devices = 1 diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index d4ab4ca7aa4ac8..98819fb8df89f6 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -1402,7 +1402,7 @@ def b1(i, x): self.assertEqual(45, rx.eval()) def _testWhileGrad_ColocateGradients(self, colocate): - gpu_dev_name = test.gpu_device_name() if test.is_gpu_available( + gpu_dev_name = test.gpu_device_name().lower() if test.is_gpu_available( ) else "/gpu:0" gpu_short_name = gpu_dev_name.split("/")[-1] diff --git a/third_party/sycl/crosstool/CROSSTOOL.tpl b/third_party/sycl/crosstool/CROSSTOOL.tpl index 19b6f3ae32cbb0..2a96cdbf95cf56 100755 --- a/third_party/sycl/crosstool/CROSSTOOL.tpl +++ b/third_party/sycl/crosstool/CROSSTOOL.tpl @@ -7,6 +7,11 @@ default_toolchain { toolchain_identifier: "local_linux" } +default_toolchain { + cpu: "arm" + toolchain_identifier: "local_arm" +} + toolchain { abi_version: "local" abi_libc_version: "local" @@ -49,6 +54,7 @@ toolchain { cxx_builtin_include_directory: "/usr/include" cxx_builtin_include_directory: "%{computecpp_toolkit_path}" + cxx_builtin_include_directory: "%{python_lib_path}" tool_path { name: "gcov" path: "/usr/bin/gcov" } @@ -101,3 +107,96 @@ toolchain { compiler_flag: "-DNDEBUG" } } + +toolchain { + abi_version: "local" + abi_libc_version: "local" + builtin_sysroot: "" + compiler: "compiler" + host_system_name: "local" + needsPic: true + supports_gold_linker: false + supports_incremental_linker: false + supports_fission: false + supports_interface_shared_objects: false + supports_normalizing_ar: false + supports_start_end_lib: false + supports_thin_archives: false + target_libc: "local" + target_cpu: "local" + target_system_name: "local" + toolchain_identifier: "local_arm" + + tool_path { name: "ar" path: "/usr/bin/ar" } + tool_path { name: "compat-ld" path: "/usr/bin/ld" } + tool_path { name: "cpp" path: "/usr/bin/cpp" } + tool_path { name: "dwp" path: "/usr/bin/dwp" } + tool_path { name: "gcc" path: "computecpp" } + # Use "-std=c++11" for nvcc. For consistency, force both the host compiler + # and the device compiler to use "-std=c++11". + cxx_flag: "-std=c++11" + linker_flag: "-Wl,-no-as-needed" + linker_flag: "-lstdc++" + linker_flag: "-B/usr/bin/" + + # TODO(bazel-team): In theory, the path here ought to exactly match the path + # used by gcc. That works because bazel currently doesn't track files at + # absolute locations and has no remote execution, yet. However, this will need + # to be fixed, maybe with auto-detection? + cxx_builtin_include_directory: "/usr/lib/gcc/" + cxx_builtin_include_directory: "/usr/lib" + cxx_builtin_include_directory: "/usr/lib64" + cxx_builtin_include_directory: "/usr/local/include" + cxx_builtin_include_directory: "/usr/include" + + cxx_builtin_include_directory: "%{computecpp_toolkit_path}" + cxx_builtin_include_directory: "%{python_lib_path}" + + tool_path { name: "gcov" path: "/usr/bin/gcov" } + + # C(++) compiles invoke the compiler (as that is the one knowing where + # to find libraries), but we provide LD so other rules can invoke the linker. 
+ tool_path { name: "ld" path: "/usr/bin/ld" } + + tool_path { name: "nm" path: "/usr/bin/nm" } + tool_path { name: "objcopy" path: "/usr/bin/objcopy" } + objcopy_embed_flag: "-I" + objcopy_embed_flag: "binary" + tool_path { name: "objdump" path: "/usr/bin/objdump" } + tool_path { name: "strip" path: "/usr/bin/strip" } + + # Make C++ compilation deterministic. Use linkstamping instead of these + # compiler symbols. + unfiltered_cxx_flag: "-Wno-builtin-macro-redefined" + unfiltered_cxx_flag: "-D__DATE__=\"redacted\"" + unfiltered_cxx_flag: "-D__TIMESTAMP__=\"redacted\"" + unfiltered_cxx_flag: "-D__TIME__=\"redacted\"" + + # All warnings are enabled. Maybe enable -Werror as well? + compiler_flag: "-Wall" + + # Anticipated future default. + linker_flag: "-Wl,-no-as-needed" + # Stamp the binary with a unique identifier. + linker_flag: "-Wl,--build-id=md5" + linker_flag: "-Wl,--hash-style=gnu" + + linking_mode_flags { mode: DYNAMIC } + + compilation_mode_flags { + mode: FASTBUILD + compiler_flag: "-O0" + } + + compilation_mode_flags { + mode: DBG + compiler_flag: "-g" + } + + compilation_mode_flags { + mode: OPT + compiler_flag: "-g0" + compiler_flag: "-O2" + compiler_flag: "-DNDEBUG" + } +} diff --git a/third_party/sycl/crosstool/computecpp.tpl b/third_party/sycl/crosstool/computecpp.tpl index 595e7136a61ba2..94c5e6aaad0cdb 100755 --- a/third_party/sycl/crosstool/computecpp.tpl +++ b/third_party/sycl/crosstool/computecpp.tpl @@ -1,8 +1,9 @@ #!/usr/bin/env python import os -import subprocess import sys +import tempfile +from subprocess import call, Popen, PIPE CPU_CXX_COMPILER = ('%{host_cxx_compiler}') CPU_C_COMPILER = ('%{host_c_compiler}') @@ -13,76 +14,81 @@ COMPUTECPP_DRIVER= COMPUTECPP_ROOT + 'bin/compute++' COMPUTECPP_INCLUDE = COMPUTECPP_ROOT + 'include' def main(): - compiler_flags = [] - - # remove -fsamotoze-coverage from string - if CPU_CXX_COMPILER.find("g++") != -1: - compiler_flags = [flag for flag in sys.argv[1:] if not flag.startswith(('-Wl,--no-undefined', '-fsanitize-coverage', '-Wno-unused-but-set-variable', '-Wignored-attributes'))] - else: - compiler_flags = [flag for flag in sys.argv[1:] if not flag.startswith(('-Wl,--no-undefined', '-Wno-unused-but-set-variable', '-Wignored-attributes'))] + remove_flags = ('-Wl,--no-undefined', '-Wno-unused-but-set-variable', '-Wignored-attributes') + # remove -fsamotoze-coverage from string with g++ + if 'g++' in CPU_CXX_COMPILER: + remove_flags += ('-fsanitize-coverage',) + compiler_flags = [flag for flag in sys.argv[1:] if not flag.startswith(remove_flags)] output_file_index = compiler_flags.index('-o') + 1 output_file_name = compiler_flags[output_file_index] - if(output_file_index == 1): + if output_file_index == 1: # we are linking - return subprocess.call([CPU_CXX_COMPILER] + compiler_flags + ['-Wl,--no-undefined']) + return call([CPU_CXX_COMPILER] + compiler_flags + ['-Wl,--no-undefined']) # find what we compile - compiling_cpp = 0 - if('-c' in compiler_flags): - compiled_file_index = compiler_flags.index('-c') + 1 - compited_file_name = compiler_flags[compiled_file_index] - if(compited_file_name.endswith(('.cc', '.c++', '.cpp', '.CPP', '.C', '.cxx'))): - compiling_cpp = 1; - - compiler_flags = compiler_flags + ['-D_GLIBCXX_USE_CXX11_ABI=0', '-DEIGEN_USE_SYCL=1', '-DTENSORFLOW_USE_SYCL', '-DEIGEN_HAS_C99_MATH'] - - if(compiling_cpp == 1): - # create a blacklist of folders that will be skipped when compiling with ComputeCpp - _skip = ["external", "llvm", ".cu.cc"] - # if compiling external project skip computecpp - if any(_folder 
in _skip for _folder in output_file_name): - return subprocess.call([CPU_CXX_COMPILER] + compiler_flags) - - if(compiling_cpp == 1): - # this is an optimisation that will check if compiled file has to be compiled with ComputeCpp - - _tmp_flags = [flag for flag in compiler_flags if not flag.startswith(('-o', output_file_name))] - # create preprocessed of the file - _cmd = " ".join([CPU_CXX_COMPILER] + _tmp_flags + ["-E"]) - # check if it has parallel_for< in it - _cmd += " | grep \".parallel_for\" > /dev/null" - ps = subprocess.call(_cmd, shell=True) - # if not call CXX compiler - if(ps != 0): - return subprocess.call([CPU_CXX_COMPILER] + compiler_flags) - - if(compiling_cpp == 1): - filename, file_extension = os.path.splitext(output_file_name) - bc_out = filename + '.sycl' - - # strip asan for the device - computecpp_device_compiler_flags = ['-sycl-compress-name', '-DTENSORFLOW_USE_SYCL', '-Wno-unused-variable', '-I', COMPUTECPP_INCLUDE, '-isystem', - COMPUTECPP_INCLUDE, '-std=c++11', '-sycl', '-emit-llvm', '-no-serial-memop', '-Xclang', '-cl-denorms-are-zero', '-Xclang', '-cl-fp32-correctly-rounded-divide-sqrt'] - computecpp_device_compiler_flags += [flag for flag in compiler_flags if not flag.startswith(('-fsanitize', '-march=native', '-mavx'))] - - x = subprocess.call([COMPUTECPP_DRIVER] + computecpp_device_compiler_flags ) - if(x == 0): - # dont want that in case of compiling with computecpp first - host_compiler_flags = [flag for flag in compiler_flags - if not flag.startswith(('-MF', '-MD',)) - if not '.d' in flag - ] - - host_compiler_flags[host_compiler_flags.index('-c')] = "--include" - - host_compiler_flags = ['-xc++', '-D_GLIBCXX_USE_CXX11_ABI=0', '-DTENSORFLOW_USE_SYCL', '-Wno-unused-variable', '-I', COMPUTECPP_INCLUDE, '-c', bc_out] + host_compiler_flags - x = subprocess.call([CPU_CXX_COMPILER] + host_compiler_flags) - return x - else: + compiling_cpp = False + if '-c' in compiler_flags: + compiled_file_index = compiler_flags.index('-c') + 1 + compiled_file_name = compiler_flags[compiled_file_index] + compiling_cpp = compiled_file_name.endswith(('.cc', '.c++', '.cpp', '.CPP', '.C', '.cxx')) + + # add -D_GLIBCXX_USE_CXX11_ABI=0 to the command line if you have custom installation of GCC/Clang + compiler_flags = compiler_flags + ['-DEIGEN_USE_SYCL=1', '-DTENSORFLOW_USE_SYCL', '-DEIGEN_HAS_C99_MATH'] + + if not compiling_cpp: # compile for C - return subprocess.call([CPU_C_COMPILER] + compiler_flags) + return call([CPU_C_COMPILER] + compiler_flags) + + # create a blacklist of folders that will be skipped when compiling with ComputeCpp + skip_extensions = [".cu.cc"] + skip_folders = ["tensorflow/compiler", "tensorflow/docs_src", "tensorflow/tensorboard", "third_party", "external", "hexagon"] + skip_folders = [(folder + '/') for folder in skip_folders] + # if compiling external project skip computecpp + if any(compiled_file_name.endswith(_ext) for _ext in skip_extensions) or any(_folder in output_file_name for _folder in skip_folders): + return call([CPU_CXX_COMPILER] + compiler_flags) + + # this is an optimisation that will check if compiled file has to be compiled with ComputeCpp + flags_without_output = list(compiler_flags) + del flags_without_output[output_file_index] # remove output_file_name + del flags_without_output[output_file_index - 1] # remove '-o' + # create preprocessed of the file and store it for later use + pipe = Popen([CPU_CXX_COMPILER] + flags_without_output + ["-E"], stdout=PIPE) + preprocessed_file_str = pipe.communicate()[0] + if pipe.returncode != 0: + return 
pipe.returncode + + # check if it has parallel_for in it + if not '.parallel_for' in preprocessed_file_str: + # call CXX compiler like usual + with tempfile.NamedTemporaryFile(suffix=".ii") as preprocessed_file: # Force '.ii' extension so that g++ does not preprocess the file again + preprocessed_file.write(preprocessed_file_str) + preprocessed_file.flush() + compiler_flags[compiled_file_index] = preprocessed_file.name + return call([CPU_CXX_COMPILER] + compiler_flags) + del preprocessed_file_str # save some memory as this string can be quite big + + filename, file_extension = os.path.splitext(output_file_name) + bc_out = filename + '.sycl' + + # strip asan for the device + computecpp_device_compiler_flags = ['-sycl-compress-name', '-Wno-unused-variable', + '-I', COMPUTECPP_INCLUDE, '-isystem', COMPUTECPP_INCLUDE, + '-std=c++11', '-sycl', '-emit-llvm', '-no-serial-memop', + '-Xclang', '-cl-denorms-are-zero', '-Xclang', '-cl-fp32-correctly-rounded-divide-sqrt'] + # disable flags enabling SIMD instructions + computecpp_device_compiler_flags += [flag for flag in compiler_flags if \ + not any(x in flag.lower() for x in ('-fsanitize', '=native', '=core2', 'msse', 'vectorize', 'mavx', 'mmmx', 'm3dnow', 'fma'))] + + x = call([COMPUTECPP_DRIVER] + computecpp_device_compiler_flags) + if x == 0: + # dont want that in case of compiling with computecpp first + host_compiler_flags = [flag for flag in compiler_flags if (not flag.startswith(('-MF', '-MD',)) and not '.d' in flag)] + host_compiler_flags[host_compiler_flags.index('-c')] = "--include" + host_compiler_flags = ['-xc++', '-Wno-unused-variable', '-I', COMPUTECPP_INCLUDE, '-c', bc_out] + host_compiler_flags + x = call([CPU_CXX_COMPILER] + host_compiler_flags) + return x if __name__ == '__main__': sys.exit(main()) diff --git a/third_party/sycl/sycl_configure.bzl b/third_party/sycl/sycl_configure.bzl index 6ad498487fef2a..7af063178e04af 100644 --- a/third_party/sycl/sycl_configure.bzl +++ b/third_party/sycl/sycl_configure.bzl @@ -5,18 +5,20 @@ * HOST_CXX_COMPILER: The host C++ compiler * HOST_C_COMPILER: The host C compiler * COMPUTECPP_TOOLKIT_PATH: The path to the ComputeCpp toolkit. + * PYTHON_LIB_PATH: The path to the python lib """ _HOST_CXX_COMPILER = "HOST_CXX_COMPILER" _HOST_C_COMPILER= "HOST_C_COMPILER" _COMPUTECPP_TOOLKIT_PATH = "COMPUTECPP_TOOLKIT_PATH" +_PYTHON_LIB_PATH = "PYTHON_LIB_PATH" def _enable_sycl(repository_ctx): if "TF_NEED_OPENCL" in repository_ctx.os.environ: enable_sycl = repository_ctx.os.environ["TF_NEED_OPENCL"].strip() return enable_sycl == "1" return False - + def auto_configure_fail(msg): """Output failure message when auto configuration fails.""" red = "\033[0;31m" @@ -55,7 +57,14 @@ def find_computecpp_root(repository_ctx): sycl_name = repository_ctx.os.environ[_COMPUTECPP_TOOLKIT_PATH].strip() if sycl_name.startswith("/"): return sycl_name - fail( "Cannot find SYCL compiler, please correct your path") + fail("Cannot find SYCL compiler, please correct your path") + +def find_python_lib(repository_ctx): + """Returns python path.""" + if _PYTHON_LIB_PATH in repository_ctx.os.environ: + return repository_ctx.os.environ[_PYTHON_LIB_PATH].strip() + fail("Environment variable PYTHON_LIB_PATH was not specified re-run ./configure") + def _check_lib(repository_ctx, toolkit_path, lib): """Checks if lib exists under sycl_toolkit_path or fail if it doesn't. 
@@ -168,12 +177,13 @@ def _sycl_autoconf_imp(repository_ctx):
     "%{host_c_compiler}" : find_c(repository_ctx),
   })

-  computecpp_root = find_computecpp_root(repository_ctx);
+  computecpp_root = find_computecpp_root(repository_ctx)
   _check_dir(repository_ctx, computecpp_root)

   _tpl(repository_ctx, "crosstool:CROSSTOOL",
   {
     "%{computecpp_toolkit_path}" : computecpp_root,
+    "%{python_lib_path}" : find_python_lib(repository_ctx),
   })

   # symlink libraries