Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add LazyNVRTC #45674

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions aten/src/ATen/CMakeLists.txt
Expand Up @@ -51,6 +51,7 @@ file(GLOB cudnn_cpp "cudnn/*.cpp")

file(GLOB hip_h "hip/*.h" "hip/detail/*.h" "hip/*.cuh" "hip/detail/*.cuh" "hip/impl/*.h")
file(GLOB hip_cpp "hip/*.cpp" "hip/detail/*.cpp" "hip/impl/*.cpp")
list(REMOVE_ITEM hip_cpp "${CMAKE_CURRENT_SOURCE_DIR}/hip/detail/LazyNVRTC.cpp")
file(GLOB hip_hip "hip/*.hip" "hip/detail/*.hip" "hip/impl/*.hip")
file(GLOB hip_nvrtc_stub_h "hip/nvrtc_stub/*.h")
file(GLOB hip_nvrtc_stub_cpp "hip/nvrtc_stub/*.cpp")
Expand Down
10 changes: 9 additions & 1 deletion aten/src/ATen/cuda/detail/CUDAHooks.cpp
Expand Up @@ -28,6 +28,10 @@
#include <miopen/version.h>
#endif

#ifndef USE_ROCM
#include <ATen/cuda/detail/LazyNVRTC.h>
#endif

#include <cuda.h>

#include <sstream>
Expand Down Expand Up @@ -116,10 +120,14 @@ bool CUDAHooks::hasCuDNN() const {
return AT_CUDNN_ENABLED();
}

#ifdef USE_DIRECT_NVRTC
#if defined(USE_DIRECT_NVRTC)
static std::pair<std::unique_ptr<at::DynamicLibrary>, at::cuda::NVRTC*> load_nvrtc() {
return std::make_pair(nullptr, at::cuda::load_nvrtc());
}
#elif !defined(USE_ROCM)
static std::pair<std::unique_ptr<at::DynamicLibrary>, at::cuda::NVRTC*> load_nvrtc() {
return std::make_pair(nullptr, &at::cuda::detail::lazyNVRTC);
}
#else
static std::pair<std::unique_ptr<at::DynamicLibrary>, at::cuda::NVRTC*> load_nvrtc() {
#if defined(_WIN32)
Expand Down
171 changes: 171 additions & 0 deletions aten/src/ATen/cuda/detail/LazyNVRTC.cpp
@@ -0,0 +1,171 @@
#include <ATen/cuda/detail/LazyNVRTC.h>

#include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
#include <ATen/DynamicLibrary.h>
#include <stdexcept>

namespace at {
namespace cuda {
namespace detail {
namespace _stubs {

at::DynamicLibrary& getCUDALibrary() {
#if defined(_WIN32)
static at::DynamicLibrary lib("nvcuda.dll");
#else
static at::DynamicLibrary lib("libcuda.so.1");
#endif
return lib;
}

at::DynamicLibrary& getNVRTCLibrary() {
constexpr auto major = CUDA_VERSION / 1000;
constexpr auto minor = ( CUDA_VERSION / 10 ) % 10;
#if defined(_WIN32)
auto libname = std::string("nvrtc64_") + std::to_string(major) + std::to_string(minor) + "_0.dll";
#else
static auto libname = std::string("libnvrtc.so.") + std::to_string(major) + "." + std::to_string(minor);
#endif
static at::DynamicLibrary lib(libname.c_str());
return lib;
}

#define _STUB_1(LIB, NAME, RETTYPE, ARG1) \
RETTYPE NAME(ARG1 a1) { \
auto fn = reinterpret_cast<decltype(&NAME)>(get## LIB ## Library().sym(__func__)); \
if (!fn) \
throw std::runtime_error("Can't get " C10_STRINGIZE(NAME) ); \
lazyNVRTC.NAME = fn; \
return fn(a1); \
}

#define _STUB_2(LIB, NAME, RETTYPE, ARG1, ARG2) \
RETTYPE NAME(ARG1 a1, ARG2 a2) { \
auto fn = reinterpret_cast<decltype(&NAME)>(get## LIB ## Library().sym(__func__)); \
if (!fn) \
throw std::runtime_error("Can't get " C10_STRINGIZE(NAME) ); \
lazyNVRTC.NAME = fn; \
return fn(a1, a2); \
}

#define _STUB_3(LIB, NAME, RETTYPE, ARG1, ARG2, ARG3) \
RETTYPE NAME(ARG1 a1, ARG2 a2, ARG3 a3) { \
auto fn = reinterpret_cast<decltype(&NAME)>(get## LIB ## Library().sym(__func__)); \
malfet marked this conversation as resolved.
Show resolved Hide resolved
if (!fn) \
throw std::runtime_error("Can't get " C10_STRINGIZE(NAME) ); \
lazyNVRTC.NAME = fn; \
return fn(a1, a2, a3); \
}

#define _STUB_4(LIB, NAME, RETTYPE, ARG1, ARG2, ARG3, ARG4) \
RETTYPE NAME(ARG1 a1, ARG2 a2, ARG3 a3, ARG4 a4) { \
auto fn = reinterpret_cast<decltype(&NAME)>(get## LIB ## Library().sym(__func__)); \
if (!fn) \
throw std::runtime_error("Can't get " C10_STRINGIZE(NAME) ); \
lazyNVRTC.NAME = fn; \
return fn(a1, a2, a3, a4); \
}

#define CUDA_STUB1(NAME, A1) _STUB_1(CUDA, NAME, CUresult CUDAAPI, A1)
#define CUDA_STUB2(NAME, A1, A2) _STUB_2(CUDA, NAME, CUresult CUDAAPI, A1, A2)
#define CUDA_STUB3(NAME, A1, A2, A3) _STUB_3(CUDA, NAME, CUresult CUDAAPI, A1, A2, A3)
#define CUDA_STUB4(NAME, A1, A2, A3, A4) _STUB_4(CUDA, NAME, CUresult CUDAAPI, A1, A2, A3, A4)

#define NVRTC_STUB1(NAME, A1) _STUB_1(NVRTC, NAME, nvrtcResult, A1)
#define NVRTC_STUB2(NAME, A1, A2) _STUB_2(NVRTC, NAME, nvrtcResult, A1, A2)
#define NVRTC_STUB3(NAME, A1, A2, A3) _STUB_3(NVRTC, NAME, nvrtcResult, A1, A2, A3)

NVRTC_STUB2(nvrtcVersion, int*, int*);
NVRTC_STUB2(nvrtcAddNameExpression, nvrtcProgram, const char * const);

nvrtcResult nvrtcCreateProgram(nvrtcProgram *prog,
const char *src,
const char *name,
int numHeaders,
const char * const *headers,
const char * const *includeNames) {
auto fn = reinterpret_cast<decltype(&nvrtcCreateProgram)>(getNVRTCLibrary().sym(__func__));
if (!fn)
throw std::runtime_error("Can't get nvrtcCreateProgram");
lazyNVRTC.nvrtcCreateProgram = fn;
return fn(prog, src, name, numHeaders, headers, includeNames);
}

NVRTC_STUB1(nvrtcDestroyProgram, nvrtcProgram *);
NVRTC_STUB2(nvrtcGetPTXSize, nvrtcProgram, size_t *);
NVRTC_STUB2(nvrtcGetPTX, nvrtcProgram, char *);
NVRTC_STUB3(nvrtcCompileProgram, nvrtcProgram, int, const char * const *);
_STUB_1(NVRTC, nvrtcGetErrorString, const char *, nvrtcResult);
NVRTC_STUB2(nvrtcGetProgramLogSize,nvrtcProgram, size_t*);
NVRTC_STUB2(nvrtcGetProgramLog, nvrtcProgram, char *);
NVRTC_STUB3(nvrtcGetLoweredName, nvrtcProgram, const char *, const char **);

CUDA_STUB2(cuModuleLoadData, CUmodule *, const void *);
CUDA_STUB3(cuModuleGetFunction, CUfunction *, CUmodule, const char *);
CUDA_STUB4(cuOccupancyMaxActiveBlocksPerMultiprocessor, int *, CUfunction, int, size_t);
CUDA_STUB2(cuGetErrorString, CUresult, const char **);
CUDA_STUB1(cuCtxGetCurrent, CUcontext *);
CUDA_STUB1(cuModuleUnload, CUmodule);
CUDA_STUB3(cuDevicePrimaryCtxGetState, CUdevice, unsigned int *, int *);
CUDA_STUB4(cuLinkCreate, unsigned int, CUjit_option *, void **, CUlinkState *);
CUDA_STUB3(cuLinkComplete, CUlinkState, void **, size_t *);

// Irregularly shaped functions
CUresult CUDAAPI cuLaunchKernel(CUfunction f,
unsigned int gridDimX,
unsigned int gridDimY,
unsigned int gridDimZ,
unsigned int blockDimX,
unsigned int blockDimY,
unsigned int blockDimZ,
unsigned int sharedMemBytes,
CUstream hStream,
void **kernelParams,
void **extra) {
auto fn = reinterpret_cast<decltype(&cuLaunchKernel)>(getCUDALibrary().sym(__func__));
if (!fn)
throw std::runtime_error("Can't get cuLaunchKernel");
lazyNVRTC.cuLaunchKernel = fn;
return fn(f,
gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ,
sharedMemBytes, hStream, kernelParams, extra);
}

CUresult CUDAAPI cuModuleLoadDataEx(CUmodule *module,
const void *image,
unsigned int numOptions,
CUjit_option *options,
void **optionValues) {
auto fn = reinterpret_cast<decltype(&cuModuleLoadDataEx)>(getCUDALibrary().sym(__func__));
if (!fn)
throw std::runtime_error("Can't get cuModuleLoadDataEx");
lazyNVRTC.cuModuleLoadDataEx = fn;
return fn(module, image, numOptions, options, optionValues);
}

CUresult CUDAAPI
cuLinkAddData(CUlinkState state,
CUjitInputType type,
void *data,
size_t size,
const char *name,
unsigned int numOptions,
CUjit_option *options,
void **optionValues) {
auto fn = reinterpret_cast<decltype(&cuLinkAddData)>(getCUDALibrary().sym(__func__));
if (!fn)
throw std::runtime_error("Can't get cuLinkAddData");
lazyNVRTC.cuLinkAddData = fn;
return fn(state, type, data, size, name, numOptions, options, optionValues);
}

} // namespace _stubs

NVRTC lazyNVRTC = {
#define _REFERENCE_MEMBER(name) _stubs::name,
AT_FORALL_NVRTC(_REFERENCE_MEMBER)
#undef _REFERENCE_MEMBER
};
} // namespace detail
} // namespace cuda
} // namespace at
11 changes: 11 additions & 0 deletions aten/src/ATen/cuda/detail/LazyNVRTC.h
@@ -0,0 +1,11 @@
#pragma once
#include <ATen/detail/CUDAHooksInterface.h>
namespace at { namespace cuda {
// Forward-declares at::cuda::NVRTC
struct NVRTC;

namespace detail {
extern NVRTC lazyNVRTC;
}

}} // at::cuda::detail
1 change: 1 addition & 0 deletions torch/csrc/jit/codegen/cuda/executor.cpp
Expand Up @@ -11,6 +11,7 @@
#include <ATen/core/LegacyTypeDispatch.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
#include <c10/core/DeviceGuard.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/cuda/CUDAStream.h>
Expand Down
1 change: 1 addition & 0 deletions torch/csrc/jit/codegen/cuda/executor_utils.cpp
@@ -1,5 +1,6 @@
#include <ATen/CUDAGeneratorImpl.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>

#include <c10/cuda/CUDACachingAllocator.h>

Expand Down
3 changes: 2 additions & 1 deletion torch/csrc/jit/codegen/cuda/executor_utils.h
@@ -1,11 +1,12 @@
#pragma once

#include <ATen/core/ivalue.h>
#include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>

#include <c10/core/DeviceType.h>
#include <c10/util/Exception.h>

#include <cuda.h>

#include <torch/csrc/jit/ir/ir.h>

#include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
Expand Down