Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions include/taco/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,9 @@ enum class IRNodeType {
BlankLine,
Print,
GetProperty,
Break,
Sort
Continue,
Sort,
Break
};

enum class TensorProperty {
Expand Down Expand Up @@ -719,7 +720,14 @@ struct BlankLine : public StmtNode<BlankLine> {
static const IRNodeType _type_info = IRNodeType::BlankLine;
};

/** Breaks current loop */
/** Continues past current iteration of current loop */
struct Continue : public StmtNode<Continue> {
static Stmt make();

static const IRNodeType _type_info = IRNodeType::Continue;
};

/** Breaks out of the current loop */
struct Break : public StmtNode<Break> {
static Stmt make();

Expand Down
3 changes: 2 additions & 1 deletion include/taco/ir/ir_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,11 @@ class IRPrinter : public IRVisitorStrict {
virtual void visit(const Free*);
virtual void visit(const Comment*);
virtual void visit(const BlankLine*);
virtual void visit(const Break*);
virtual void visit(const Continue*);
virtual void visit(const Print*);
virtual void visit(const GetProperty*);
virtual void visit(const Sort*);
virtual void visit(const Break*);

std::ostream &stream;
int indent;
Expand Down
3 changes: 2 additions & 1 deletion include/taco/ir/ir_rewriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,11 @@ class IRRewriter : public IRVisitorStrict {
virtual void visit(const Free* op);
virtual void visit(const Comment* op);
virtual void visit(const BlankLine* op);
virtual void visit(const Break* op);
virtual void visit(const Continue* op);
virtual void visit(const Print* op);
virtual void visit(const GetProperty* op);
virtual void visit(const Sort *op);
virtual void visit(const Break *op);
};

}}
Expand Down
9 changes: 6 additions & 3 deletions include/taco/ir/ir_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,11 @@ struct Allocate;
struct Free;
struct Comment;
struct BlankLine;
struct Break;
struct Continue;
struct Print;
struct GetProperty;
struct Sort;
struct Break;

/// Extend this class to visit every node in the IR.
class IRVisitorStrict {
Expand Down Expand Up @@ -96,10 +97,11 @@ class IRVisitorStrict {
virtual void visit(const Free*) = 0;
virtual void visit(const Comment*) = 0;
virtual void visit(const BlankLine*) = 0;
virtual void visit(const Break*) = 0;
virtual void visit(const Continue*) = 0;
virtual void visit(const Print*) = 0;
virtual void visit(const GetProperty*) = 0;
virtual void visit(const Sort*) = 0;
virtual void visit(const Break*) = 0;
};


Expand Down Expand Up @@ -150,10 +152,11 @@ class IRVisitor : public IRVisitorStrict {
virtual void visit(const Free* op);
virtual void visit(const Comment* op);
virtual void visit(const BlankLine* op);
virtual void visit(const Break* op);
virtual void visit(const Continue* op);
virtual void visit(const Print* op);
virtual void visit(const GetProperty* op);
virtual void visit(const Sort* op);
virtual void visit(const Break* op);
};

}}
Expand Down
2 changes: 1 addition & 1 deletion src/codegen/codegen_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1142,7 +1142,7 @@ void CodeGen_CUDA::visit(const Sqrt* op) {
stream << ")";
}

void CodeGen_CUDA::visit(const Break*) {
void CodeGen_CUDA::visit(const Continue*) {
doIndent();
if(!isHostFunction && deviceFunctionLoopDepth == 0) {
// can't break out of kernel
Expand Down
2 changes: 1 addition & 1 deletion src/codegen/codegen_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class CodeGen_CUDA : public CodeGen {
void visit(const Call*);
void visit(const Store*);
void visit(const Assign*);
void visit(const Break*);
void visit(const Continue*);
void visit(const Free* op);
std::string printDeviceFuncName(const std::vector<std::pair<std::string, Expr>> currentParameters, int index);
void printDeviceFuncCall(const std::vector<std::pair<std::string, Expr>> currentParameters, Expr blockSize, int index, Expr gridSize);
Expand Down
11 changes: 9 additions & 2 deletions src/ir/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -786,6 +786,11 @@ Stmt BlankLine::make() {
return new BlankLine;
}

// Continue
Stmt Continue::make() {
return new Continue;
}

// Break
Stmt Break::make() {
return new Break;
Expand Down Expand Up @@ -954,14 +959,16 @@ template<> void StmtNode<Comment>::accept(IRVisitorStrict *v)
const { v->visit((const Comment*)this); }
template<> void StmtNode<BlankLine>::accept(IRVisitorStrict *v)
const { v->visit((const BlankLine*)this); }
template<> void StmtNode<Break>::accept(IRVisitorStrict *v)
const { v->visit((const Break*)this); }
template<> void StmtNode<Continue>::accept(IRVisitorStrict *v)
const { v->visit((const Continue*)this); }
template<> void StmtNode<Print>::accept(IRVisitorStrict *v)
const { v->visit((const Print*)this); }
template<> void ExprNode<GetProperty>::accept(IRVisitorStrict *v)
const { v->visit((const GetProperty*)this); }
template<> void StmtNode<Sort>::accept(IRVisitorStrict *v)
const { v->visit((const Sort*)this); }
template<> void StmtNode<Break>::accept(IRVisitorStrict *v)
const { v->visit((const Break*)this); }

// printing methods
std::ostream& operator<<(std::ostream& os, const Stmt& stmt) {
Expand Down
7 changes: 6 additions & 1 deletion src/ir/ir_printer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -558,9 +558,14 @@ void IRPrinter::visit(const BlankLine*) {
stream << endl;
}

void IRPrinter::visit(const Continue*) {
doIndent();
stream << "continue;" << endl;
}

void IRPrinter::visit(const Break*) {
doIndent();
stream << "continue;" << endl; // TODO: add continue statement
stream << "break;" << endl;
}

void IRPrinter::visit(const Print* op) {
Expand Down
4 changes: 4 additions & 0 deletions src/ir/ir_rewriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,10 @@ void IRRewriter::visit(const BlankLine* op) {
stmt = op;
}

void IRRewriter::visit(const Continue* op) {
stmt = op;
}

void IRRewriter::visit(const Break* op) {
stmt = op;
}
Expand Down
3 changes: 3 additions & 0 deletions src/ir/ir_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,9 @@ void IRVisitor::visit(const Comment*) {
void IRVisitor::visit(const BlankLine*) {
}

void IRVisitor::visit(const Continue*) {
}

void IRVisitor::visit(const Break*) {
}

Expand Down
4 changes: 2 additions & 2 deletions src/lower/lowerer_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ Stmt LowererImpl::lowerForall(Forall forall)
if (isa<ir::Literal>(ir::simplify(iterBounds[0])) && ir::simplify(iterBounds[0]).as<ir::Literal>()->equalsScalar(0)) {
guardCondition = maxGuard;
}
ir::Stmt guard = ir::IfThenElse::make(guardCondition, ir::Break::make());
ir::Stmt guard = ir::IfThenElse::make(guardCondition, ir::Continue::make());
recoverySteps.push_back(guard);
}

Expand Down Expand Up @@ -438,7 +438,7 @@ Stmt LowererImpl::lowerForall(Forall forall)
}
if (!hasDirectDivBound) {
Stmt guard = IfThenElse::make(Gte::make(indexVarToExprMap[varToRecover], underivedBounds[varToRecover][1]),
Break::make());
Continue::make());
recoverySteps.push_back(guard);
}
}
Expand Down