Skip to content

Commit

Permalink
[triton-c] added implicit conversion to bool in while/for loops
Browse files Browse the repository at this point in the history
  • Loading branch information
ptillet committed Apr 28, 2019
1 parent af58b8b commit 93f5350
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 10 deletions.
8 changes: 4 additions & 4 deletions examples/cpp/dot.cpp
Expand Up @@ -10,7 +10,7 @@ R"(
const tunable int32 TM = {16, 32, 64, 128};
const tunable int32 TN = {16, 32, 64, 128};
const tunable int32 TK = {8};
const tunable int32 GZ = {2};
const tunable int32 GZ = {1};
void matmul(restrict read_only fp32 *A, restrict read_only fp32 *B, fp32 *C,
int32 M, int32 N, int32 K,
Expand Down Expand Up @@ -57,7 +57,7 @@ void matmul(restrict read_only fp32 *A, restrict read_only fp32 *B, fp32 *C,
int32 ridy = get_range_id(1);
fp32* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
int32 *plock = locks + ridx + ridy*grid0;
while(__atomic_cas(plock, 0, 1) == 1);
while(__atomic_cas(plock, 0, 1));
int32 *pcount = plock + grid0*grid1;
int32 count = *pcount;
int32 countp1 = select(count == GZ - 1, 0, count + 1);
Expand All @@ -82,7 +82,7 @@ int main() {
triton::jit jit(context);

// matrix multiplication parameters
int32_t M = 256, N = 256, K = 2048;
int32_t M = 512, N = 512, K = 512;
std::vector<float> hc(M*N);
std::vector<float> rc(M*N);
std::vector<float> ha(M*K);
Expand Down Expand Up @@ -144,7 +144,7 @@ int main() {

// just-in-time compile source-code
std::vector<unsigned> params = {
16, 2, 64, 16, 2, 64, 16, 8, 2, 2, 8, 8, 8, 4
16, 2, 64, 16, 2, 64, 16, 8, 2, 2, 8, 8, 8, 1
};
// jit.autotune("matmul",src, benchmark);
jit.add_module("matmul", src, params);
Expand Down
2 changes: 1 addition & 1 deletion include/triton/jit.h
Expand Up @@ -70,7 +70,7 @@ class jit {
shmem_barriers.run(module);
}
vectorize.run(module);
ir::print(module, std::cout);
// ir::print(module, std::cout);
}

codegen::tune tune;
Expand Down
10 changes: 5 additions & 5 deletions lib/ast/lowering.cpp
Expand Up @@ -351,7 +351,7 @@ ir::value* expression_statement::codegen(ir::module *mod) const{
return expr_->codegen(mod);
}

/* Iteration statement */
/* For statement */
ir::value* iteration_statement::codegen(ir::module *mod) const{
ir::builder &builder = mod->get_builder();
ir::context &ctx = mod->get_context();
Expand All @@ -362,11 +362,11 @@ ir::value* iteration_statement::codegen(ir::module *mod) const{
mod->set_continue_fn([&](){
if(exec_)
exec_->codegen(mod);
ir::value *cond = stop_->codegen(mod);
ir::value *cond = explicit_cast(builder, stop_->codegen(mod), ir::type::get_int1_ty(ctx));
return builder.create_cond_br(cond, loop_bb, next_bb);
});
init_->codegen(mod);
ir::value *cond = stop_->codegen(mod);
ir::value *cond = explicit_cast(builder, stop_->codegen(mod), ir::type::get_int1_ty(ctx));
builder.create_cond_br(cond, loop_bb, next_bb);
// builder.create_br(loop_bb);
builder.set_insert_point(loop_bb);
Expand All @@ -390,10 +390,10 @@ ir::value* while_statement::codegen(ir::module* mod) const{
ir::basic_block *loop_bb = ir::basic_block::create(ctx, "loop", fn);
ir::basic_block *next_bb = ir::basic_block::create(ctx, "postloop", fn);
mod->set_continue_fn([&](){
ir::value *cond = cond_->codegen(mod);
ir::value *cond = explicit_cast(builder, cond_->codegen(mod), ir::type::get_int1_ty(ctx));
return builder.create_cond_br(cond, loop_bb, next_bb);
});
ir::value *cond = cond_->codegen(mod);
ir::value *cond = explicit_cast(builder, cond_->codegen(mod), ir::type::get_int1_ty(ctx));
builder.create_cond_br(cond, loop_bb, next_bb);
builder.set_insert_point(loop_bb);
if(!is_terminator(statements_->codegen(mod)))
Expand Down

0 comments on commit 93f5350

Please sign in to comment.