Skip to content

Commit

Permalink
Enable stream-per-thread
Browse files: browse the repository at this point in the history
  • Loading branch information
wweic committed Feb 22, 2021
1 parent 84359a9 commit 8db2517
Show file tree
Hide file tree
Showing 6 changed files with 8 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/contrib/tf_op/tvm_dso_op_kernels.cc
Expand Up @@ -18,6 +18,7 @@
*/

#ifdef TF_TVMDSOOP_ENABLE_GPU
#define CUDA_API_PER_THREAD_DEFAULT_STREAM
#include <cuda_runtime.h>
#endif
#include <dlpack/dlpack.h>
Expand Down
1 change: 1 addition & 0 deletions src/runtime/contrib/cublas/cublas_utils.h
Expand Up @@ -24,6 +24,7 @@
#ifndef TVM_RUNTIME_CONTRIB_CUBLAS_CUBLAS_UTILS_H_
#define TVM_RUNTIME_CONTRIB_CUBLAS_CUBLAS_UTILS_H_

#define CUDA_API_PER_THREAD_DEFAULT_STREAM
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <cuda_runtime_api.h>
Expand Down
1 change: 1 addition & 0 deletions src/runtime/cuda/cuda_common.h
Expand Up @@ -24,6 +24,7 @@
#ifndef TVM_RUNTIME_CUDA_CUDA_COMMON_H_
#define TVM_RUNTIME_CUDA_CUDA_COMMON_H_

#define CUDA_API_PER_THREAD_DEFAULT_STREAM
#include <cuda_runtime.h>
#include <tvm/runtime/packed_func.h>

Expand Down
1 change: 1 addition & 0 deletions src/runtime/cuda/cuda_device_api.cc
Expand Up @@ -21,6 +21,7 @@
* \file cuda_device_api.cc
* \brief GPU specific API
*/
#define CUDA_API_PER_THREAD_DEFAULT_STREAM
#include <cuda.h>
#include <cuda_runtime.h>
#include <dmlc/thread_local.h>
Expand Down
1 change: 1 addition & 0 deletions src/runtime/cuda/cuda_module.cc
Expand Up @@ -22,6 +22,7 @@
*/
#include "cuda_module.h"

#define CUDA_API_PER_THREAD_DEFAULT_STREAM
#include <cuda.h>
#include <cuda_runtime.h>
#include <tvm/runtime/registry.h>
Expand Down
3 changes: 3 additions & 0 deletions src/target/opt/build_cuda_on.cc
Expand Up @@ -26,6 +26,7 @@
#if defined(__linux__)
#include <sys/stat.h>
#endif
#define CUDA_API_PER_THREAD_DEFAULT_STREAM
#include <cuda_runtime.h>
#include <nvrtc.h>

Expand Down Expand Up @@ -143,7 +144,9 @@ runtime::Module BuildCUDA(IRModule mod, Target target) {
}
std::string fmt = "ptx";
std::string ptx;
code = "#define CUDA_API_PER_THREAD_DEFAULT_STREAM \n\n" + code;
if (const auto* f = Registry::Get("tvm_callback_cuda_compile")) {
std::cout << "code is \n" << code << std::endl;
ptx = (*f)(code).operator std::string();
// Dirty matching to check PTX vs cubin.
// TODO(tqchen) more reliable checks
Expand Down

0 comments on commit 8db2517

Please sign in to comment.