Skip to content
Permalink
Browse files

[triton-c] predicate in assignment statement now propagates to rhs computations
  • Loading branch information...
ptillet committed Apr 27, 2019
1 parent 4b77b76 commit af58b8bd81f6c01581a2bd26459941e5b7c998b6
@@ -69,7 +69,7 @@ void matmul(restrict read_only fp32 *A, restrict read_only fp32 *B, fp32 *C,
*pcount = countp1;
}
else {
*pc = c + (checkc ? *pc : 0);
@checkc *pc = c + *pc;
*pcount = countp1;
}
__atomic_cas(plock, 1, 0);
@@ -410,13 +410,13 @@ class statement: public block_item{
// AST node for a statement made of a single expression plus an optional
// second expression passed in as `mask` (nullptr when absent).
// NOTE(review): this is a diff hunk whose +/- markers were lost. The `mask_`
// lines look like the removed version and the `pred_` lines like the added
// one (a rename of the member) -- confirm against the commit.
class expression_statement: public statement{
public:
// Stores both arguments after a C-style cast to expression*; no type check
// is performed here.
expression_statement(node *expr, node *mask = nullptr)
: expr_((expression*)expr), mask_((expression*)mask){ }
: expr_((expression*)expr), pred_((expression*)mask){ }

// Declared here, defined out of line.
ir::value* codegen(ir::module * mod) const;

private:
// The wrapped expression.
expression *expr_;
// Optional second expression; may be nullptr (constructor default).
expression *mask_;
expression *pred_;
};


@@ -335,15 +335,15 @@ class mask_inst: public instruction {
};

// merge
class merge_inst: public instruction {
class psi_inst: public instruction {
private:
std::string repr_impl() const { return "merge"; }
merge_inst(ir::value *mask_true, ir::value *value_true,
psi_inst(ir::value *mask_true, ir::value *value_true,
ir::value *mask_false, ir::value *value_false,
const std::string &name, instruction *next);

public:
static merge_inst* create(ir::value *mask_true, ir::value *value_true,
static psi_inst* create(ir::value *mask_true, ir::value *value_true,
ir::value *mask_false, ir::value *value_false,
const std::string &name = "", instruction *next = nullptr);
ir::value *get_mask_true() { return get_operand(0); }
@@ -70,6 +70,7 @@ class jit {
shmem_barriers.run(module);
}
vectorize.run(module);
ir::print(module, std::cout);
}

codegen::tune tune;
@@ -320,23 +320,33 @@ ir::value* compound_statement::codegen(ir::module* mod) const{
/* expression statement */
// Generates IR for an expression statement.
//
// Without a predicate this simply forwards to expr_->codegen(mod).
// With a predicate (post-commit lines):
//   1. the wrapped expression is asserted to be an assignment_expression;
//   2. a mask instruction is created from the predicate value;
//   3. the expression is generated, and every instruction between the block
//      position recorded beforehand and the builder's insert point receives
//      mask result 0 as its mask predicate;
//   4. if the expression's type is void its value is returned directly;
//      otherwise it is combined with an undef value through
//      builder.create_merge, bound to the assigned name in the module, and
//      returned.
//
// NOTE(review): this hunk lost its +/- diff markers, so pre-commit lines
// (mask_, true_value, merge) and post-commit lines (pred_, expr, psi) appear
// interleaved below -- confirm against the commit before reading it as one
// function body.
ir::value* expression_statement::codegen(ir::module *mod) const{
ir::builder &builder = mod->get_builder();
// Pre-commit version (apparently removed): only the instruction returned by
// the expression itself received the mask predicate.
if(mask_) {
ir::value *pred = mask_->codegen(mod);
ir::mask_inst *mask = (ir::mask_inst*)builder.create_mask(pred);
ir::value *true_value = expr_->codegen(mod);
// Post-commit version (apparently added) starts here.
ir::basic_block *block = builder.get_insert_block();
if(pred_) {
// check that it is an assignment
assignment_expression *assignment = dynamic_cast<assignment_expression*>(expr_);
assert(assignment);

ir::type *ty = true_value->get_type();
if(auto *itn = dynamic_cast<ir::instruction*>(true_value))
itn->set_mask_pred(mask->get_result(0));
// generate mask
ir::value *pred = pred_->codegen(mod);
ir::mask_inst *mask = (ir::mask_inst*)builder.create_mask(pred);
// generate expression
// Record how many instructions the block holds before generating the
// expression, so the newly emitted ones can be located afterwards.
unsigned szbegin = block->get_inst_list().size();
ir::value *expr = expr_->codegen(mod);
ir::basic_block::iterator begin = block->begin();
std::advance(begin, szbegin);
// set mask
ir::type *ty = expr->get_type();
// NOTE(review): the walk assumes the builder's insert point is still inside
// `block` after expression codegen -- confirm this holds for all expressions.
for(auto it = begin; it != builder.get_insert_point(); it++)
(*it)->set_mask_pred(mask->get_result(0));
// if(auto *itn = dynamic_cast<ir::instruction*>(expr))
// itn->set_mask_pred(mask->get_result(0));
// A void-typed expression is returned as-is; no merge is created for it.
if(ty->is_void_ty())
return true_value;
ir::merge_inst *merge = (ir::merge_inst*)builder.create_merge(mask->get_result(0), true_value,
return expr;
// merge with psi
// Mask result 0 is paired with the computed value, mask result 1 with an
// undef value of the same type.
ir::psi_inst *psi = (ir::psi_inst*)builder.create_merge(mask->get_result(0), expr,
mask->get_result(1), ir::undef_value::get(ty));
// The assignment's lvalue is cast to named_expression without a check; its
// name is rebound in the module to the merged value.
std::string name = ((named_expression*)assignment->lvalue())->id()->name();
mod->set_value(name, merge);
return merge;
mod->set_value(name, psi);
return psi;
}
// No predicate: plain code generation of the wrapped expression.
return expr_->codegen(mod);
}
@@ -690,7 +690,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
});
}
// merge
else if(auto *merge = dynamic_cast<ir::merge_inst*>(ins)) {
else if(auto *merge = dynamic_cast<ir::psi_inst*>(ins)) {
distributed_tile* mask_tile_true = (distributed_tile*)tmap_.at(merge->get_mask_true());
distributed_tile *value_tile_true = (distributed_tile*)tmap_.at(merge->get_value_true());
distributed_tile* mask_tile_false = (distributed_tile*)tmap_.at(merge->get_mask_false());
@@ -951,7 +951,7 @@ void selection::run(ir::module &src, Module &dst) {
dst_builder.SetInsertPoint(parent);
for(ir::instruction *i: block->get_inst_list()){
BasicBlock *current = dst_builder.GetInsertBlock();
bool phi_inserted = (dynamic_cast<ir::phi_node*>(i) || dynamic_cast<ir::merge_inst*>(i)) && !current->empty();
bool phi_inserted = (dynamic_cast<ir::phi_node*>(i) || dynamic_cast<ir::psi_inst*>(i)) && !current->empty();
if(phi_inserted && current->getFirstNonPHI())
dst_builder.SetInsertPoint(&*current->getFirstNonPHI());
lower_instruction(i, dst_builder);
@@ -253,7 +253,7 @@ std::string cu_module::compile_llvm_module(llvm::Module* module) {
// Builds a cu_module from an LLVM module: passes it through
// compile_llvm_module and delegates to the source-string constructor.
cu_module::cu_module(driver::context * context, llvm::Module* ll_module): cu_module(context, compile_llvm_module(ll_module)) { }

cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
std::cout << source << std::endl;
// std::cout << source << std::endl;
cu_context::context_switcher ctx_switch(*context);
// JIT compile source-code
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
@@ -92,7 +92,7 @@ value *builder::create_mask(value *pred, const std::string &name){
}

// Creates a merge/psi instruction from the (mask_true, value_true) and
// (mask_false, value_false) pairs and passes it to insert(), returning the
// result.
// NOTE(review): diff hunk without +/- markers; the merge_inst line looks like
// the removed version and the psi_inst line like its replacement -- confirm
// against the commit.
value *builder::create_merge(value *mask_true, value *value_true, value *mask_false, value *value_false, const std::string &name) {
return insert(merge_inst::create(mask_true, value_true, mask_false, value_false, name));
return insert(psi_inst::create(mask_true, value_true, mask_false, value_false, name));
}


@@ -334,7 +334,7 @@ mask_inst* mask_inst::create(value *pred, const std::string &name, instruction *
}

// merge_inst
merge_inst::merge_inst(value *mask_true, value *value_true,
psi_inst::psi_inst(value *mask_true, value *value_true,
value *mask_false, value *value_false,
const std::string &name, instruction *next)
: instruction(value_true->get_type(), 4, 1, name, next) {
@@ -344,10 +344,10 @@ merge_inst::merge_inst(value *mask_true, value *value_true,
set_operand(3, value_false);
}

// Factory: heap-allocates the instruction with `new` from the two
// mask/value pairs, the name, and `next`, and returns the raw pointer.
// NOTE(review): diff hunk without +/- markers; the merge_inst lines look like
// the removed version and the psi_inst lines like the renamed replacement --
// confirm against the commit.
merge_inst* merge_inst::create(value *mask_true, value *value_true,
psi_inst* psi_inst::create(value *mask_true, value *value_true,
value *mask_false, value *value_false,
const std::string &name, instruction *next) {
return new merge_inst(mask_true, value_true, mask_false, value_false, name, next);
return new psi_inst(mask_true, value_true, mask_false, value_false, name, next);
}


0 comments on commit af58b8b

Please sign in to comment.
You can’t perform that action at this time.