Skip to content
Permalink
Browse files

[driver] now passing std::unique_ptr<llvm::Module> instead of cloning the LLVM module when compiling it
  • Loading branch information...
ptillet committed Sep 5, 2019
1 parent 0a6329e commit 18848cbb71e620805fac0dfb741de0951c30f6c4
Showing with 28 additions and 50 deletions.
  1. +4 −12 CMakeLists.txt
  2. +0 −1 include/triton/driver/backend.h
  3. +6 −6 include/triton/driver/module.h
  4. +0 −8 lib/driver/backend.cc
  5. +16 −15 lib/driver/module.cc
  6. +2 −8 lib/runtime/function.cc
@@ -8,10 +8,9 @@ option(BUILD_TESTS "Build C++ Triton tests" ON)
option(BUILD_PYTHON_MODULE "Build Python Triton bindings" OFF)

# LLVM
find_package(LLVM REQUIRED CONFIG)
find_package(LLVM REQUIRED)
include_directories(${LLVM_INCLUDE_DIRS})
add_definitions(${LLVM_DEFINITIONS})
llvm_map_components_to_libnames(llvm_libs all)

# Default build type
if(NOT CMAKE_BUILD_TYPE)
@@ -21,7 +20,7 @@ endif()

# Compiler flags
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${LLVM_CXXFLAGS} -std=c++11")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")

# Tests
if(BUILD_TESTS)
@@ -53,13 +52,6 @@ endif()
# Triton
file(GLOB_RECURSE LIBTRITON_SRC lib/*.cc)
add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
target_link_libraries(triton LLVM)

# Warning level
#if(MSVC)
# target_compile_options(triton PRIVATE /W4)
#else()
# target_compile_options(triton PRIVATE -Wno-unused-parameter -Wall -Wextra -pedantic)
#endif()

link_directories(${LLVM_LIBRARY_DIRS})
target_link_libraries(triton ${LLVM_LIBRARIES})

@@ -66,7 +66,6 @@ struct backend

public:
static void release();
static driver::module* get(driver::stream* stream, std::string const & name, llvm::Module *src);

private:
static std::map<std::tuple<driver::stream*, std::string>, driver::module*> cache_;
@@ -38,9 +38,9 @@ class module: public polymorphic_resource<CUmodule, cl_program, host_module_t> {
module(driver::context* ctx, CUmodule mod, bool has_ownership);
module(driver::context* ctx, cl_program mod, bool has_ownership);
module(driver::context* ctx, host_module_t mod, bool has_ownership);
static module* create(driver::context* ctx, llvm::Module *src);
static module* create(driver::context* ctx, std::unique_ptr<llvm::Module> src);
driver::context* context() const;
void compile_llvm_module(llvm::Module* module, const std::string& triple,
void compile_llvm_module(std::unique_ptr<llvm::Module> module, const std::string& triple,
const std::string &proc, std::string layout,
llvm::SmallVectorImpl<char> &buffer,
const std::string &features,
@@ -53,22 +53,22 @@ class module: public polymorphic_resource<CUmodule, cl_program, host_module_t> {
// CPU
class host_module: public module{
public:
host_module(driver::context* context, llvm::Module *module);
host_module(driver::context* context, std::unique_ptr<llvm::Module> module);
};

// OpenCL
class ocl_module: public module{

public:
ocl_module(driver::context* context, llvm::Module *module);
ocl_module(driver::context* context, std::unique_ptr<llvm::Module> module);
};

// CUDA
class cu_module: public module {
std::string compile_llvm_module(llvm::Module* module, driver::device* device);
std::string compile_llvm_module(std::unique_ptr<llvm::Module> module, driver::device* device);

public:
cu_module(driver::context* context, llvm::Module *module);
cu_module(driver::context* context, std::unique_ptr<llvm::Module> module);
cu_module(driver::context* context, const std::string& source);
cu_buffer* symbol(const char * name) const;

@@ -103,14 +103,6 @@ void backend::modules::release(){
cache_.clear();
}

// Returns the driver module cached under the key (stream, name).
// On a miss, a module is created from `src` for the stream's context via
// driver::module::create and stored in cache_ before being returned; on a hit,
// `src` is not used.
// The map is searched once and the resulting iterator serves both paths
// (the previous form searched it again in insert() and at()).
driver::module* backend::modules::get(driver::stream* stream, std::string const & name, llvm::Module* src){
  std::tuple<driver::stream*, std::string> key(stream, name);
  auto it = cache_.find(key);
  if(it == cache_.end()){
    // the module is created before insertion, so a throwing create() leaves the cache untouched
    it = cache_.insert({key, driver::module::create(stream->context(), src)}).first;
  }
  return it->second;
}

std::map<std::tuple<driver::stream*, std::string>, driver::module*> backend::modules::cache_;

/*-----------------------------------*/
@@ -76,16 +76,16 @@ driver::context* module::context() const {
return ctx_;
}

module* module::create(driver::context* ctx, llvm::Module *src) {
module* module::create(driver::context* ctx, std::unique_ptr<llvm::Module> src) {
switch(ctx->backend()){
case CUDA: return new cu_module(ctx, src);
case OpenCL: return new ocl_module(ctx, src);
case Host: return new host_module(ctx, src);
case CUDA: return new cu_module(ctx, std::move(src));
case OpenCL: return new ocl_module(ctx, std::move(src));
case Host: return new host_module(ctx, std::move(src));
default: throw std::runtime_error("unknown backend");
}
}

void module::compile_llvm_module(llvm::Module* module, const std::string& triple,
void module::compile_llvm_module(std::unique_ptr<llvm::Module> module, const std::string& triple,
const std::string &proc, std::string layout,
llvm::SmallVectorImpl<char> &buffer,
const std::string& features,
@@ -133,7 +133,7 @@ void module::compile_llvm_module(llvm::Module* module, const std::string& triple
// Host //
/* ------------------------ */

host_module::host_module(driver::context * context, llvm::Module* src): module(context, host_module_t(), true) {
host_module::host_module(driver::context * context, std::unique_ptr<llvm::Module> src): module(context, host_module_t(), true) {
init_llvm();
// host info
// std::string triple = llvm::sys::getDefaultTargetTriple();
@@ -147,7 +147,7 @@ host_module::host_module(driver::context * context, llvm::Module* src): module(c
llvm::Type *args_ty = llvm::Type::getInt8PtrTy(ctx)->getPointerTo();
llvm::Type *int32_ty = llvm::Type::getInt32Ty(ctx);
llvm::FunctionType *main_ty = llvm::FunctionType::get(void_ty, {args_ty, int32_ty, int32_ty, int32_ty}, false);
llvm::Function* main = llvm::Function::Create(main_ty, llvm::Function::ExternalLinkage, "main", src);
llvm::Function* main = llvm::Function::Create(main_ty, llvm::Function::ExternalLinkage, "main", &*src);
llvm::Function* fn = src->getFunction("matmul");
llvm::FunctionType *fn_ty = fn->getFunctionType();
std::vector<llvm::Value*> fn_args(fn_ty->getNumParams());
@@ -169,10 +169,9 @@ host_module::host_module(driver::context * context, llvm::Module* src): module(c


// create execution engine
auto cloned = llvm::CloneModule(*src);
for(llvm::Function& fn: cloned->functions())
for(llvm::Function& fn: src->functions())
hst_->functions[fn.getName()] = &fn;
llvm::EngineBuilder builder(std::move(cloned));
llvm::EngineBuilder builder(std::move(src));
builder.setErrorStr(&hst_->error);
builder.setMCJITMemoryManager(llvm::make_unique<llvm::SectionMemoryManager>());
builder.setOptLevel(llvm::CodeGenOpt::Aggressive);
@@ -185,7 +184,7 @@ host_module::host_module(driver::context * context, llvm::Module* src): module(c
// OpenCL //
/* ------------------------ */

ocl_module::ocl_module(driver::context * context, llvm::Module* src): module(context, cl_program(), true) {
ocl_module::ocl_module(driver::context * context, std::unique_ptr<llvm::Module> src): module(context, cl_program(), true) {
throw std::runtime_error("not supported");
// init_llvm();
// llvm::SmallVector<char, 0> buffer;
@@ -217,18 +216,20 @@ ocl_module::ocl_module(driver::context * context, llvm::Module* src): module(con
// CUDA //
/* ------------------------ */

std::string cu_module::compile_llvm_module(llvm::Module* module, driver::device* device) {
std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module, driver::device* device) {
// options
auto options = llvm::cl::getRegisteredOptions();
// for(auto& opt: options)
// std::cout << opt.getKey().str() << std::endl;
static_cast<llvm::cl::opt<bool>*>(options["nvptx-short-ptr"])->setValue(true);
auto* short_ptr = static_cast<llvm::cl::opt<bool>*>(options["nvptx-short-ptr"]);
assert(short_ptr);
short_ptr->setValue(true);
// compute capability
auto cc = ((driver::cu_device*)device)->compute_capability();
std::string sm = "sm_" + std::to_string(cc.first) + std::to_string(cc.second);
// create
llvm::SmallVector<char, 0> buffer;
module::compile_llvm_module(module, "nvptx64-nvidia-cuda", sm, "", buffer, "ptx63", Assembly);
module::compile_llvm_module(std::move(module), "nvptx64-nvidia-cuda", sm, "", buffer, "ptx63", Assembly);
std::string result(buffer.begin(), buffer.end());
size_t start_replace = result.find(".version");
size_t end_replace = result.find('\n', start_replace);
@@ -237,7 +238,7 @@ std::string cu_module::compile_llvm_module(llvm::Module* module, driver::device*
return result;
}

cu_module::cu_module(driver::context * context, llvm::Module* ll_module): cu_module(context, compile_llvm_module(ll_module, context->device())) { }
// Builds a CUDA module from LLVM IR: ownership of `ll_module` is moved into
// compile_llvm_module(), and the string it returns is forwarded to the
// constructor taking (context, source).
// NOTE(review): compile_llvm_module is declared as a non-static member, so it
// is invoked here before the delegated-to constructor has run -- presumably it
// reads no instance state; confirm, or consider declaring it static.
cu_module::cu_module(driver::context * context, std::unique_ptr<llvm::Module> ll_module)
  : cu_module(context,
              compile_llvm_module(std::move(ll_module), context->device())) {
}

cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
// std::cout << source << std::endl;
@@ -188,6 +188,7 @@ function::caller function::autotune(driver::stream* stream, const grid_fn_ty& gr

std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::context *context, const options_t& opt) {
std::unique_ptr<codegen::target> target = context->device()->make_target();

// create passes
codegen::analysis::grids grids(opt.num_warps);
codegen::analysis::meminfo shmem_info;
@@ -201,20 +202,13 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
codegen::transform::reassociate reassociate(&alignment_info, &grids);
codegen::selection selection(&shmem_allocation, &grids, &shmem_info, &alignment_info, target.get());



// run passes
peephole.run(module);
dce.run(module);
// ir::print(module, std::cout);
alignment_info.run(module);
grids.run(module);
// ir::print(module, std::cout);

reassociate.run(module);
dce.run(module);
// ir::print(module, std::cout);

peephole.run(module);

if(target->is_gpu()){
@@ -233,7 +227,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
std::unique_ptr<llvm::Module> llvm(new llvm::Module(module.get_name(), ctx));
selection.run(module, *llvm);
// return binary
std::unique_ptr<driver::module> res(driver::module::create(context, llvm.get()));
std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm)));
return res;
}

0 comments on commit 18848cb

Please sign in to comment.
You can’t perform that action at this time.