
Merge branch 'c-reduction'

ptillet committed Sep 12, 2019
2 parents d7be0ed + 7f2bc5b commit 981ffb6d85b9f264c4986fbfe10a623d6c3cfe2a
@@ -136,7 +136,7 @@ class builder{
   value *create_dot(value *A, value *B, value *C, const std::string &name = "");
   value *create_trans(value *A, const std::vector<constant_int *> &perm = {}, const std::string &name = "");
   value *create_sqrt(value *A, const std::string &name = "");
-  value *create_reduce(value *A, unsigned axis, const std::string &name = "");
+  value *create_reduce(value *A, reduce_inst::op_t op, unsigned axis, const std::string &name = "");
   value *create_select(value *pred, value *if_value, value *else_value, const std::string &name = "");
   // Intrinsics
   value *create_copy_to_shared(value *arg, const std::string &name = "");
@@ -611,19 +611,28 @@ class sqrt_inst: public builtin_inst {
 };
 
 class reduce_inst: public builtin_inst {
+public:
+  enum op_t{
+    ADD, SUB, MAX, MIN,
+    FADD, FSUB, FMAX, FMIN
+  };
+
 private:
   static type* get_res_type(value *arg, unsigned axis);
+  static std::string to_str(op_t op);
 
 private:
-  reduce_inst(value* arg, unsigned axis, const std::string& name, instruction* next);
-  std::string repr_impl() const { return "reduce"; }
+  reduce_inst(value* arg, op_t op, unsigned axis, const std::string& name, instruction* next);
+  std::string repr_impl() const { return "red<" + std::to_string(axis_) + ">"; }
 
 public:
-  static instruction* create(value *arg, unsigned axis, const std::string &name = "", instruction *next = nullptr);
+  static instruction* create(value *arg, op_t op, unsigned axis, const std::string &name = "", instruction *next = nullptr);
   unsigned get_axis() const { return axis_; }
+  op_t get_op() const { return op_; }
 
 private:
   unsigned axis_;
+  op_t op_;
 };
 
 class select_inst: public builtin_inst {
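
Usage sketch (illustrative, not part of this commit): building reductions through the new op-aware API. `b` is assumed to be an existing ir::builder positioned inside a function, and `A` an existing tile value.

  // float-add reduction of tile A along axis 0
  ir::value *sum = b.create_reduce(A, ir::reduce_inst::FADD, /*axis=*/0, "sum");
  // integer max-reduction along axis 1
  ir::value *vmax = b.create_reduce(A, ir::reduce_inst::MAX, /*axis=*/1, "vmax");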
@@ -418,29 +418,33 @@ class UnaryOp : public Expr {
   friend class LValAssigner;
 
 public:
-  static UnaryOp* New(int op, Expr* operand, QualType type=nullptr);
+  static UnaryOp* New(int op, Expr* operand, QualType type=nullptr, int info=0);
   virtual ~UnaryOp() {}
   virtual void Accept(Visitor* v);
   virtual bool IsLVal();
   ::Type *Convert();
+  static int encodeRed(int ax, int tag);
+  static void decodeRed(int info, int& ax, int& tag);
   void TypeChecking();
   void IncDecOpTypeChecking();
   void AddrOpTypeChecking();
   void DerefOpTypeChecking();
+  void ReduceOpTypeChecking();
   void TransOpTypeChecking();
   void UnaryArithmOpTypeChecking();
   void CastOpTypeChecking();
 
 protected:
-  UnaryOp(int op, Expr* operand, QualType type=nullptr)
-      : Expr(operand->Tok(), type), op_(op) {
+  UnaryOp(int op, Expr* operand, QualType type=nullptr, int info=0)
+      : Expr(operand->Tok(), type), op_(op), info_(info) {
     operand_ = operand;
     if (op_ != Token::CAST && op_ != Token::ADDR) {
      operand_ = MayCast(operand);
    }
  }
 
  int op_;
+  int info_;
  Expr* operand_;
};

@@ -131,6 +131,8 @@ class Token {
 
   // TILE ARITHMETICS BEGIN
   NEWAXIS,
+  MAX,
+  MIN,
   // TILE ARITHMETICS END
 
   ALIGNAS, // _Alignas
@@ -180,6 +182,7 @@ class Token {
   PLUS,
   MINUS,
   CAST,
+  REDUCE,
 
   // For preprocessor
   PP_IF,
@@ -70,7 +70,7 @@ class function {
   struct options_space_t {
     typedef std::pair<std::string, std::vector<std::string>> define_t;
     std::vector<define_t> defines;
-    std::vector<size_t> num_warps;
+    std::vector<int> num_warps;
   };
 
   struct options_t {
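
Usage sketch (illustrative, not part of this commit): populating the autotuning space with the now-`int` warp counts. The field names come from the struct above; the define name and values are made up, and the struct is assumed to be accessed as a member type of `function`.

  function::options_space_t space;
  space.defines.push_back({"TM", {"64", "128"}});
  space.num_warps = {2, 4, 8};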
@@ -59,16 +59,7 @@ void grids::init_c_graph(ir::instruction *v) {
     shapes = atom->get_operand(0)->get_type()->get_tile_shapes();
   else if(dynamic_cast<ir::downcast_inst*>(v))
     return;
-  else if(auto *reduce = dynamic_cast<ir::reduce_inst*>(v)) {
-    unsigned axis = reduce->get_axis();
-    ir::value *arg = reduce->get_operand(0);
-    auto in_shapes = arg->get_type()->get_tile_shapes();
-    unsigned current = 0;
-    for(unsigned i = 0; i < in_shapes.size(); i++){
-      if(i == axis)
-        continue;
-      add_constraint({reduce, current++}, {arg, i});
-    }
+  else if(dynamic_cast<ir::reduce_inst*>(v)) {
     return;
   }
   else
@@ -244,7 +235,6 @@ void grids::run(ir::module &mod) {
     unsigned size = i->get_type()->get_tile_num_elements();
     /* HMMA parameters*/
     if(fragments_.at({i, 0}) == HMMA_FRAGMENT_C){
-
       /* fragments per warp */
       // try to make things as square as possible to maximize data re-use
       std::vector<unsigned> fpw = {1, 1, 1};
@@ -285,7 +275,6 @@ void grids::run(ir::module &mod) {
 
   if(num_warps_ != effective_num_warps)
     throw std::runtime_error("cannot create a kernel with this amount of warps");
-
 }
 
 /* Scan-line */
@@ -923,78 +923,96 @@ void selection::lower_downcast(ir::downcast_inst *x, LLVMContext &ctx, Function
 }
 
 void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
-  ir::instruction *ins = (ir::instruction*)x;
   Module *module = fn->getParent();
   std::map<indices_t, Value*> partial;
-  ir::value *op = x->get_operand(0);
-  distributed_tile* op_tile = (distributed_tile*)tmap_.at(op);
+  ir::value *arg = x->get_operand(0);
+  distributed_tile* arg_tile = (distributed_tile*)tmap_.at(arg);
+  ir::reduce_inst::op_t op = x->get_op();
+  auto accumulate = [&](Value* x, Value *y) -> Value* {
+    switch(op) {
+      case ir::reduce_inst::ADD: return builder.CreateAdd(x, y);
+      case ir::reduce_inst::SUB: return builder.CreateSub(x, y);
+      case ir::reduce_inst::MAX: return builder.CreateMaximum(x, y);
+      case ir::reduce_inst::MIN: return builder.CreateMinimum(x, y);
+      case ir::reduce_inst::FADD: return builder.CreateFAdd(x, y);
+      case ir::reduce_inst::FSUB: return builder.CreateFSub(x, y);
+      case ir::reduce_inst::FMAX: return builder.CreateSelect(builder.CreateFCmpOGT(x, y), x, y);
+      case ir::reduce_inst::FMIN: return builder.CreateSelect(builder.CreateFCmpOLT(x, y), x, y);
+      default: break;
+    }
+    assert(false);
+    return nullptr;
+  };
 
   unsigned axis = x->get_axis();
 
   // reduce within thread
-  op_tile->for_each([&](indices_t idx) {
+  arg_tile->for_each([&](indices_t idx) {
     indices_t pidx = idx;
-    pidx.erase(pidx.begin() + axis);
-    Value *current = op_tile->get_value(idx);
+    pidx[axis] = builder.getInt32(0);
+    Value *current = arg_tile->get_value(idx);
     // current partial result is not initialized -- create
     if(partial.find(pidx) == partial.end())
       partial[pidx] = current;
     // current partial result is initialized -- accumulate
     else
-      partial[pidx] = builder.CreateFAdd(partial[pidx], current);
+      partial[pidx] = accumulate(partial[pidx], current);
   });
 
+  // depth
+  unsigned shape_ax = arg->get_type()->get_tile_shapes()[axis];
+  unsigned per_thread = arg_tile->axis(axis).values.size();
+  unsigned depth = shape_ax / per_thread;
+
+  // shapes
+  auto shared_shapes = arg_tile->get_shapes();
+  shared_shapes[axis] = depth;
+
   // reduce within blocks
   unsigned addr_space = sh_mem_ptr_->getType()->getPointerAddressSpace();
   Type *res_ty = builder.getFloatTy();
   Value *base_ptr = builder.CreateBitCast(sh_mem_ptr_, PointerType::get(res_ty, addr_space));
   for(auto& x: partial) {
     // current element being computed
-    Value *lane = axes_.at(params_->get_param_group(op, axis)).thread_id;
+    Value *lane = axes_.at(params_->get_param_group(arg, axis)).thread_id;
     Value *&result = x.second;
     indices_t write_idx = x.first;
-    write_idx.insert(write_idx.begin() + axis, lane);
-
+    write_idx[axis] = lane;
     // shared memory write pointer
-    Value *write_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), write_idx);
+    Value *write_offset = shared_tile::shared_offset(builder, shared_shapes, write_idx);
     Value *write_ptr = builder.CreateGEP(base_ptr, write_offset);
-
     // initialize shared memory
     tgt_->add_barrier(module, builder);
     builder.CreateStore(result, write_ptr);
     // build result
-    unsigned depth = params_->get_param(op, "wpt.d" + std::to_string(axis))->get_value();
     for(unsigned i = depth/2; i > 0; i >>= 1){
       // current indices
       indices_t current(write_idx.size(), builder.getInt32(0));
       current[axis] = builder.getInt32(i);
       // shared memory offset
-      Value *read_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), current);
+      Value *read_offset = shared_tile::shared_offset(builder, shared_shapes, current);
       Value *is_active = builder.CreateICmpULT(lane, builder.getInt32(i));
       read_offset = builder.CreateSelect(is_active, read_offset, builder.getInt32(0));
       // shared memory read pointer
      Value *read_ptr = builder.CreateGEP(write_ptr, read_offset);
      tgt_->add_barrier(module, builder);
      Value *next = builder.CreateLoad(read_ptr);
      // accumulate
-      result = builder.CreateFAdd(result, next);
+      result = accumulate(result, next);
      // write back
      builder.CreateStore(result, write_ptr);
    }
-
-    // result is on the first lane of shared memory
-    indices_t final = write_idx;
-    final[axis] = builder.getInt32(0);
-    Value *read_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), final);
-    Value *read_ptr = builder.CreateGEP(base_ptr, read_offset);
-    tgt_->add_barrier(module, builder);
-    result = builder.CreateLoad(read_ptr);
-    if(tmap_.find(ins) == tmap_.end())
-      vmap_[ins] = result;
-    else{
-      distributed_tile *ti = (distributed_tile*)tmap_[ins];
-      ti->set_value(x.first, result);
-    }
  }
+  tgt_->add_barrier(module, builder);
+
+  // result is on the first lane of shared memory
+  distributed_tile* x_tile = (distributed_tile*)tmap_.at(x);
+  x_tile->for_each([&](indices_t idx) {
+    indices_t red_idx = idx;
+    red_idx.insert(red_idx.begin() + axis, builder.getInt32(0));
+    Value *read_offset = shared_tile::shared_offset(builder, shared_shapes, red_idx);
+    Value *read_ptr = builder.CreateGEP(base_ptr, read_offset);
+    tgt_->add_barrier(module, builder);
+    x_tile->set_value(idx, builder.CreateLoad(read_ptr));
+  });
 }
 
 void selection::lower_dynamic_program_idx(ir::nv_dynamic_program_idx_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
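
For intuition, the lowering above emits a standard shared-memory tree reduction: each thread first folds the elements it owns along the reduced axis, the per-thread partials are staged in a scratch buffer of size depth = shape[axis] / per_thread, and that buffer is halved log2(depth) times until lane 0 holds the result. A self-contained C++ sketch of the same schedule on plain memory (illustrative only; no Triton or LLVM types):

  #include <cassert>
  #include <vector>

  // Sketch of the log-step schedule emitted by lower_reduce.
  // `buf` stands in for the shared-memory scratch along the reduced axis;
  // its size plays the role of `depth` and must be a power of two.
  float tree_reduce(std::vector<float> buf) {
    const size_t depth = buf.size();
    assert(depth && (depth & (depth - 1)) == 0 && "depth must be a power of two");
    for (size_t i = depth / 2; i > 0; i >>= 1) {
      // every "lane" below i folds in the element i slots away;
      // on the GPU a barrier separates these rounds
      for (size_t lane = 0; lane < i; ++lane)
        buf[lane] += buf[lane + i];
    }
    return buf[0];  // the result ends up on lane 0
  }

For example, tree_reduce({1, 2, 3, 4}) folds {1,2,3,4} to {4,6} to {10} and returns 10.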
@@ -323,8 +323,8 @@ value *builder::create_sqrt(value *A, const std::string &name) {
   return insert(sqrt_inst::create(A, name));
 }
 
-value *builder::create_reduce(value *A, unsigned axis, const std::string &name) {
-  return insert(reduce_inst::create(A, axis, name));
+value *builder::create_reduce(value *A, reduce_inst::op_t op, unsigned axis, const std::string &name) {
+  return insert(reduce_inst::create(A, op, axis, name));
 }
 
 value *builder::create_select(value *pred, value *if_value, value *else_value, const std::string &name){
@@ -615,6 +615,23 @@ instruction* sqrt_inst::create(value *arg, const std::string &name, instruction
 //===----------------------------------------------------------------------===//
 // reduce instructions
 //===----------------------------------------------------------------------===//
 
+std::string reduce_inst::to_str(op_t op) {
+  switch (op) {
+    case ADD: return "+";
+    case SUB: return "-";
+    case MAX: return "imax";
+    case MIN: return "imin";
+    case FADD: return "+";
+    case FSUB: return "-";
+    case FMAX: return "fmax";
+    case FMIN: return "fmin";
+    default: break;
+  }
+  assert(false);
+  return "";
+}
+
 type* reduce_inst::get_res_type(value *arg, unsigned axis) {
   ir::tile_type::tile_shapes_t shapes = arg->get_type()->get_tile_shapes();
   shapes.erase(shapes.begin() + axis);
@@ -625,14 +642,15 @@ type* reduce_inst::get_res_type(value *arg, unsigned axis) {
   return tile_type::get(scalar_ty, shapes);
 }
 
-reduce_inst::reduce_inst(value *arg, unsigned axis, const std::string &name, instruction *next)
+reduce_inst::reduce_inst(value *arg, op_t op, unsigned axis, const std::string &name, instruction *next)
   : builtin_inst(get_res_type(arg, axis), 1, 1, name, next),
+    op_(op),
     axis_(axis){
   set_operand(0, arg);
 }
 
-instruction* reduce_inst::create(value *arg, unsigned axis, const std::string &name, instruction *next) {
-  return new reduce_inst(arg, axis, name, next);
+instruction* reduce_inst::create(value *arg, op_t op, unsigned axis, const std::string &name, instruction *next) {
+  return new reduce_inst(arg, op, axis, name, next);
 }


@@ -448,6 +448,8 @@ void BinaryOp::RangeOpTypeChecking() {
 }
 
 void BinaryOp::MaskedDerefOpTypeChecking() {
+//  auto lhsTileType = lhs_->Type()->ToTile();
+//  auto rhsTileType = rhs_->Type()->ToTile();
   ::Type* lhsScalType = TryExtractScalarType(this, lhs_);
   ::Type* rhsScalType = TryExtractScalarType(this, rhs_);
   auto lhsType = lhsScalType->ToArithm();
@@ -572,15 +574,27 @@ void BinaryOp::AssignOpTypeChecking() {
  * Unary Operators
  */
 
-UnaryOp* UnaryOp::New(int op, Expr* operand, QualType type) {
-  auto ret = new (unaryOpPool.Alloc()) UnaryOp(op, operand, type);
+UnaryOp* UnaryOp::New(int op, Expr* operand, QualType type, int info) {
+  auto ret = new (unaryOpPool.Alloc()) UnaryOp(op, operand, type, info);
   ret->pool_ = &unaryOpPool;
 
   ret->TypeChecking();
   return ret;
 }
 
+
+int UnaryOp::encodeRed(int ax, int tag) {
+  int result = 0;
+  result |= ax;
+  result |= tag << 16;
+  return result;
+}
+
+void UnaryOp::decodeRed(int info, int& ax, int& tag) {
+  ax = info & 0x0000FFFF;
+  tag = (info & 0xFFFF0000) >> 16;
+}
+
 bool UnaryOp::IsLVal() {
   // Only deref('*') could be lvalue;
   return op_ == Token::DEREF;
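
The reduction axis and operation tag are packed into the low and high 16 bits of `info_`. A quick round-trip sketch (illustrative values; assumes the axis fits in 16 bits, which encodeRed itself relies on):

  // pack axis 1 with the MAX tag, then recover both
  int info = UnaryOp::encodeRed(/*ax=*/1, /*tag=*/Token::MAX);
  int ax, tag;
  UnaryOp::decodeRed(info, ax, tag);
  assert(ax == 1 && tag == Token::MAX);  // low 16 bits: axis; high 16 bits: tag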
@@ -626,6 +640,9 @@ void UnaryOp::TypeChecking() {
   case '^':
     return TransOpTypeChecking();
 
+  case Token::REDUCE:
+    return ReduceOpTypeChecking();
+
   default:
     assert(false);
   }
@@ -663,6 +680,16 @@ void UnaryOp::DerefOpTypeChecking() {
   type_ = ScalarOrLikeTile(operand_, pointerType->Derived().GetPtr());
 }
 
+void UnaryOp::ReduceOpTypeChecking() {
+  int ax, tag;
+  decodeRed(info_, ax, tag);
+  auto tileType = operand_->Type()->ToTile();
+  if(!tileType)
+    Error(this, "array expected for reduction operation");
+  auto shape = tileType->Shape();
+  shape.erase(shape.begin() + ax);
+  type_ = TileType::New(shape, tileType->Derived());
+}
+
 void UnaryOp::TransOpTypeChecking() {
   auto tileType = operand_->Type()->ToTile();
 
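As the type checker above shows, a reduction simply drops the reduced dimension from the operand's tile shape. A plain C++ mirror of that bookkeeping (shapes made up for illustration):

  std::vector<int> shape = {8, 16};   // operand tile shape
  int ax = 0;                         // axis decoded from info_
  shape.erase(shape.begin() + ax);    // result shape: {16}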