Skip to content
Browse files

[driver] added spirv-llvm dispatch functions

  • Loading branch information...
ptillet committed May 2, 2019
1 parent 70f49a5 commit 208d1525de4214d5c5589474bcb96d3e9b15c26b
Showing with 45 additions and 9 deletions.
  1. +2 −1 examples/cpp/dot.cpp
  2. +22 −6 include/triton/driver/dispatch.h
  3. +20 −0 lib/driver/dispatch.cpp
  4. +1 −2 lib/jit.cpp
@@ -103,7 +103,6 @@ int main() {
stream->write(da, true, 0, ha);
stream->write(db, true, 0, hb);
stream->write(dc, true, 0, hc);
stream->write(dlocks, true, 0, hlocks);

@@ -116,6 +115,8 @@ int main() {
unsigned nthreads = info.num_threads;
unsigned GZ = jit.get_int("GZ");
std::array<size_t, 3> grid = {(M + TM - 1)/TM, (N + TN - 1)/TN, GZ};
// init locks
stream->write(dlocks, true, 0, hlocks);
// set argument
kernel->setArg(0, da);
kernel->setArg(1, db);
@@ -38,6 +38,11 @@
#include <iostream>
#include <stdexcept>

namespace llvm {
class PassRegistry;
class Module;

namespace triton
namespace driver
@@ -85,6 +90,7 @@ class dispatch
static bool cuinit();
static bool cublasinit();
static bool cudnninit();
static bool spvllvminit();
static void release();

// OpenCL
@@ -123,10 +129,9 @@ class dispatch
static cl_program clCreateProgramWithSource(cl_context, cl_uint, const char **, const size_t *, cl_int *);
static cl_int clReleaseKernel(cl_kernel);

static CUresult cuCtxGetCurrent(CUcontext *pctx);
static CUresult cuCtxSetCurrent(CUcontext ctx);

static CUresult cuCtxDestroy_v2(CUcontext ctx);
static CUresult cuEventCreate(CUevent *phEvent, unsigned int Flags);
static CUresult cuDeviceGet(CUdevice *device, int ordinal);
@@ -139,7 +144,6 @@ class dispatch
static CUresult cuDeviceGetName(char *name, int len, CUdevice dev);
static CUresult cuDeviceGetPCIBusId(char *id, int len, CUdevice dev);
static CUresult cuModuleGetGlobal_v2(CUdeviceptr *dptr, size_t* bytes, CUmodule hmod, const char *name);

static CUresult cuMemcpyHtoDAsync_v2(CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount, CUstream hStream);
static CUresult cuModuleLoad(CUmodule *module, const char *fname);
static CUresult cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, CUstream hStream, void **kernelParams, void **extra);
@@ -161,12 +165,12 @@ class dispatch
static CUresult cuPointerGetAttribute(void * data, CUpointer_attribute attribute, CUdeviceptr ptr);
static CUresult cuCtxGetDevice(CUdevice* result);
static CUresult cuMemsetD8Async(CUdeviceptr dst, unsigned char x, size_t N, CUstream stream);

static nvmlReturn_t nvmlDeviceGetHandleByPciBusId_v2( const char* pciBusId, nvmlDevice_t* device);
static nvmlReturn_t nvmlDeviceGetClockInfo(nvmlDevice_t device, nvmlClockType_t type, unsigned int *clock);
static nvmlReturn_t nvmlDeviceGetMaxClockInfo(nvmlDevice_t device, nvmlClockType_t type, unsigned int *clock);
static nvmlReturn_t nvmlDeviceSetApplicationsClocks(nvmlDevice_t device, unsigned int mem_clock, unsigned int sm_clock);

static cublasHandle_t cublasHandle(driver::cu_context const & ctx);
static cublasStatus_t cublasCreate_v2(cublasHandle_t* h);
static cublasStatus_t cublasGetStream_v2(cublasHandle_t h, cudaStream_t *streamId);
@@ -175,7 +179,7 @@ class dispatch
static cublasStatus_t cublasDgemm_v2 (cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, double* alpha, const double *A, int lda, const double *B, int ldb, double* beta, double *C, int ldc);
static cublasStatus_t cublasHgemm (cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, half* alpha, const half *A, int lda, const half *B, int ldb, half* beta, half *C, int ldc);
static cublasStatus_t cublasGemmEx(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const void *alpha, const void *A, cudaDataType Atype, int lda, const void *B, cudaDataType Btype, int ldb, const void *beta, void *C, cudaDataType Ctype, int ldc, cudaDataType computeType, cublasGemmAlgo_t algo);

static cudnnHandle_t cudnnHandle(driver::cu_context const & ctx);
static cudnnStatus_t cudnnCreatePoolingDescriptor(cudnnPoolingDescriptor_t *poolingDesc);
static cudnnStatus_t cudnnCreateConvolutionDescriptor(cudnnConvolutionDescriptor_t* convDesc);
@@ -196,6 +200,10 @@ class dispatch
static cudnnStatus_t cudnnSetStream(cudnnHandle_t handle, cudaStream_t streamId);
static cudnnStatus_t cudnnTransformTensor(cudnnHandle_t handle, const void *alpha, const cudnnTensorDescriptor_t xDesc, const void *x, const void *beta, const cudnnTensorDescriptor_t yDesc, void *y);

// SPIR-V libraries
static int initializeLLVMToSPIRVPass(llvm::PassRegistry &);
static bool writeSpirv(llvm::Module *M, std::ostream &OS, std::string &ErrMsg);


// Libraries
@@ -204,6 +212,10 @@ class dispatch
static void* nvml_;
static void* cublas_;
static void* cudnn_;
static void* vulkan_;
static void* spvllvm_;
static void* spvcross_;
static void* opengl_;

// OpenCL functions
static void* clBuildProgram_;
@@ -310,6 +322,10 @@ class dispatch
static void* cudnnPoolingForward_;
static void* cudnnSetStream_;
static void* cudnnTransformTensor_;

static void* initializeLLVMToSPIRVPass_;
static void* writeSpirv_;

@@ -158,6 +158,12 @@ bool dispatch::cudnninit(){
return cudnn_ != nullptr;

bool dispatch::spvllvminit(){
spvllvm_ = dlopen("", RTLD_LAZY);
return spvllvm_ != nullptr;

CUDA_DEFINE1(CUresult, cuCtxDestroy_v2, CUcontext)
CUDA_DEFINE2(CUresult, cuEventCreate, CUevent *, unsigned int)
@@ -292,6 +298,15 @@ OCL_DEFINE5(cl_mem, clCreateBuffer, cl_context, cl_mem_flags, size_t, void *, cl
OCL_DEFINE5(cl_program, clCreateProgramWithSource, cl_context, cl_uint, const char **, const size_t *, cl_int *)
OCL_DEFINE1(cl_int, clReleaseKernel, cl_kernel)

int dispatch::initializeLLVMToSPIRVPass(llvm::PassRegistry &registry){
return f_impl<dispatch::spvllvminit>(spvllvm_, initializeLLVMToSPIRVPass, initializeLLVMToSPIRVPass_, "initializeLLVMToSPIRVPass", std::ref(registry));

bool dispatch::writeSpirv(llvm::Module *M, std::ostream &OS, std::string &ErrMsg){
return f_impl<dispatch::spvllvminit>(spvllvm_, writeSpirv, writeSpirv_, "writeSpirv", M, std::ref(OS), std::ref(ErrMsg));

// Release
void dispatch::release(){
@@ -313,6 +328,7 @@ void* dispatch::cuda_;
void* dispatch::nvml_;
void* dispatch::cublas_;
void* dispatch::cudnn_;
void* dispatch::spvllvm_;

void* dispatch::clBuildProgram_;
@@ -421,5 +437,9 @@ void* dispatch::cudnnPoolingForward_;
void* dispatch::cudnnSetStream_;
void* dispatch::cudnnTransformTensor_;

void* dispatch::initializeLLVMToSPIRVPass_;
void* dispatch::writeSpirv_;

@@ -71,12 +71,11 @@ std::unique_ptr<llvm::Module> jit::make_llvm_module(ir::module &module, passes_w
llvm::Module* result = new llvm::Module(module.get_name(), llvm_context_);, *result);
// launch information
launch_information info;
launch_information& info = launch_info_map_[result->getName()];
for(unsigned i = 0; i < passes.tune.get_num_global_range(); i++)
info.num_threads = passes.tune.get_num_threads();
launch_info_map_.insert({result->getName(), info});
return std::unique_ptr<llvm::Module>(result);

0 comments on commit 208d152

Please sign in to comment.
You can’t perform that action at this time.