324 changes: 312 additions & 12 deletions tools/taco.cpp
@@ -27,7 +27,9 @@
#include "taco/util/env.h"
#include "taco/util/collections.h"
#include "taco/cuda.h"
#include "taco/index_notation/transformations.h"
#include "taco/index_notation/index_notation_visitor.h"
#include "taco/index_notation/index_notation_nodes.h"
#include <stdexcept>  // std::invalid_argument, thrown by setSchedulingCommands

using namespace std;
using namespace taco;
@@ -112,6 +114,11 @@ static void printUsageInfo() {
"long, longlong, float, double, complexfloat, complexdouble. "
"Examples: A:uint16, b:long and D:complexfloat.");
cout << endl;
printFlag("s=\"<command>(<params>)\"",
"Specify a scheduling command to apply to the generated code. "
"Parameters take the form of a comma-delimited list. "
"Examples: split(i,i0,i1,16), precompute(A(i,j)*x(j),i,i).");
cout << endl;
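// Example invocation (hypothetical; multiple -s flags are applied in order):
//   taco "y(i)=A(i,j)*x(j)" -s="split(i,i0,i1,32)" -s="parallelize(i0,CPUThread,NoRaces)"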
printFlag("c",
"Generate compute kernel that simultaneously does assembly.");
cout << endl;
@@ -201,6 +208,261 @@ static void printCommandLine(ostream& os, int argc, char* argv[]) {
}
}

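// Parse whitespace-delimited scheduling commands from `in` (pre-tokenized from
// the -s flags in main) and apply each one to `stmt` in order. Returns true if
// any command requested a GPU parallel unit, so the caller can enable CUDA
// code generation.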
static bool setSchedulingCommands(istream& in, parser::Parser& parser, IndexStmt& stmt) {
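// Look up an already-defined index variable by name via the provenance graph.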
auto findVar = [&stmt](string name) {
ProvenanceGraph graph(stmt);
for (auto v : graph.getAllIndexVars()) {
if (v.getName() == name) {
return v;
}
}

throw std::invalid_argument("Index variable not defined in statement.");
};

bool isGPU = false;

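// Consume commands until an unrecognized token or end of input is reached.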
while (true) {
string command;
in >> command;

if (command == "pos") {
string i, ipos;
in >> i;
in >> ipos;

string tensor;
in >> tensor;

for (auto a : getArgumentAccesses(stmt)) {
if (a.getTensorVar().getName() == tensor) {
IndexVar derived(ipos);
stmt = stmt.pos(findVar(i), derived, a);
goto end;
}
}
taco_uerror << "Tensor not defined in statement.";

} else if (command == "fuse") {
string i, j, f;
in >> i;
in >> j;
in >> f;

IndexVar fused(f);
stmt = stmt.fuse(findVar(i), findVar(j), fused);

} else if (command == "split") {
string i, i1, i2;
in >> i;
in >> i1;
in >> i2;

size_t splitFactor;
in >> splitFactor;

IndexVar split1(i1);
IndexVar split2(i2);
stmt = stmt.split(findVar(i), split1, split2, splitFactor);

// } else if (command == "divide") {
// string i, i1, i2;
// in >> i;
// in >> i1;
// in >> i2;

// size_t divideFactor;
// in >> divideFactor;

// IndexVar divide1(i1);
// IndexVar divide2(i2);
// stmt = stmt.divide(findVar(i), divide1, divide2, divideFactor);

} else if (command == "precompute") {
string exprStr, i, iw;
in >> exprStr;
in >> i;
in >> iw;

IndexVar orig = findVar(i);
IndexVar pre;
try {
pre = findVar(iw);
} catch (const std::invalid_argument&) {
pre = IndexVar(iw);
}

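// Visitor that finds the sub-expression whose printed form (ignoring spaces)
// matches the user-supplied string, so that expression can be precomputed.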
struct GetExpr : public IndexNotationVisitor {
using IndexNotationVisitor::visit;

string exprStr;
IndexExpr expr;

void setExprStr(string input) {
exprStr = input;
exprStr.erase(remove(exprStr.begin(), exprStr.end(), ' '), exprStr.end());
}

string toString(IndexExpr e) {
stringstream tempStream;
tempStream << e;
string tempStr = tempStream.str();
tempStr.erase(remove(tempStr.begin(), tempStr.end(), ' '), tempStr.end());
return tempStr;
}

void visit(const AccessNode* node) {
IndexExpr currentExpr(node);
if (toString(currentExpr) == exprStr) {
expr = currentExpr;
}
else {
IndexNotationVisitor::visit(node);
}
}

void visit(const UnaryExprNode* node) {
IndexExpr currentExpr(node);
if (toString(currentExpr) == exprStr) {
expr = currentExpr;
}
else {
IndexNotationVisitor::visit(node);
}
}

void visit(const BinaryExprNode* node) {
IndexExpr currentExpr(node);
if (toString(currentExpr) == exprStr) {
expr = currentExpr;
}
else {
IndexNotationVisitor::visit(node);
}
}
};

GetExpr visitor;
visitor.setExprStr(exprStr);
stmt.accept(&visitor);
if (!visitor.expr.defined()) {
taco_uerror << "Expression not found in statement.";
goto end;
}

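// Size the workspace to the domain of the original index variable, falling
// back to the variable's own dimension if no explicit domain is recorded.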
Dimension dim;
auto domains = stmt.getIndexVarDomains();
auto it = domains.find(orig);
if (it != domains.end()) {
dim = it->second;
} else {
dim = Dimension(orig);
}

TensorVar workspace("workspace", Type(Float64, {dim}), Dense);
stmt = stmt.precompute(visitor.expr, orig, pre, workspace);

} else if (command == "reorder") {
string line;
getline(in, line);
stringstream temp;
temp << line;

vector<IndexVar> reorderedVars;
string var;
while (temp >> var) {
reorderedVars.push_back(findVar(var));
}

stmt = stmt.reorder(reorderedVars);

} else if (command == "bound") {
string i, i1;
in >> i;
in >> i1;

size_t bound;
in >> bound;

string type;
in >> type;

BoundType bound_type;
if (type == "MinExact") {
bound_type = BoundType::MinExact;
} else if (type == "MinConstraint") {
bound_type = BoundType::MinConstraint;
} else if (type == "MaxExact") {
bound_type = BoundType::MaxExact;
} else if (type == "MaxConstraint") {
bound_type = BoundType::MaxConstraint;
} else {
taco_uerror << "Bound type not defined.";
goto end;
}

IndexVar bound1(i1);
stmt = stmt.bound(findVar(i), bound1, bound, bound_type);

} else if (command == "unroll") {
string i;
in >> i;

size_t unrollFactor;
in >> unrollFactor;

stmt = stmt.unroll(findVar(i), unrollFactor);

} else if (command == "parallelize") {
string i, unit, strategy;
in >> i;
in >> unit;
in >> strategy;

ParallelUnit parallel_unit;
if (unit == "NotParallel") {
parallel_unit = ParallelUnit::NotParallel;
} else if (unit == "GPUBlock") {
parallel_unit = ParallelUnit::GPUBlock;
isGPU = true;
} else if (unit == "GPUWarp") {
parallel_unit = ParallelUnit::GPUWarp;
isGPU = true;
} else if (unit == "GPUThread") {
parallel_unit = ParallelUnit::GPUThread;
isGPU = true;
} else if (unit == "CPUThread") {
parallel_unit = ParallelUnit::CPUThread;
} else if (unit == "CPUVector") {
parallel_unit = ParallelUnit::CPUVector;
} else {
taco_uerror << "Parallel hardware not defined.";
goto end;
}

OutputRaceStrategy output_race_strategy;
if (strategy == "IgnoreRaces") {
output_race_strategy = OutputRaceStrategy::IgnoreRaces;
} else if (strategy == "NoRaces") {
output_race_strategy = OutputRaceStrategy::NoRaces;
} else if (strategy == "Atomics") {
output_race_strategy = OutputRaceStrategy::Atomics;
} else if (strategy == "Temporary") {
output_race_strategy = OutputRaceStrategy::Temporary;
} else if (strategy == "ParallelReduction") {
output_race_strategy = OutputRaceStrategy::ParallelReduction;
} else {
taco_uerror << "Race strategy not defined.";
goto end;
}

stmt = stmt.parallelize(findVar(i), parallel_unit, output_race_strategy);

} else {
break;
}

end:;
}

return isGPU;
}

int main(int argc, char* argv[]) {
if (argc < 2) {
printUsageInfo();
@@ -228,6 +490,8 @@ int main(int argc, char* argv[]) {
bool readKernels = false;
bool cuda = false;

bool setSchedule = false;

ParallelSchedule sched = ParallelSchedule::Static;
int chunkSize = 0;
int nthreads = 0;
@@ -256,6 +520,8 @@

vector<string> kernelFilenames;

vector<string> scheduleCommands;

for (int i = 1; i < argc; i++) {
string arg = argv[i];
vector<string> argparts = util::split(arg, "=");
@@ -543,6 +809,30 @@ int main(int argc, char* argv[]) {
else if ("-print-kernels" == argName) {
printKernels = true;
}
else if ("-s" == argName) {
setSchedule = true;
bool insideCall = false;
bool parsingExpr = false;

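// Tokenize the command in place: strip the call's own parentheses and commas
// (e.g. "split(i,i0,i1,16)" becomes "split i i0 i1 16") while preserving the
// parentheses and commas of a precompute expression such as "A(i,j)*x(j)".
// Deeper parenthesis nesting inside the expression is not handled.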
std::replace_if(argValue.begin(), argValue.end(), [&insideCall, &parsingExpr](char c) {
if (c == '(') {
if (insideCall) {
parsingExpr = true; // need to handle precompute case specially
} else {
insideCall = true;
return true;
}
} else if (c == ',') {
return !parsingExpr;
} else if (c == ')') {
bool previous = parsingExpr;
parsingExpr = false;
return !previous;
}
return false;
}, ' ');
scheduleCommands.push_back(argValue);
}
else {
if (exprStr.size() != 0) {
printUsageInfo();
@@ -623,16 +913,6 @@ int main(int argc, char* argv[]) {
}
}

ir::Stmt assemble;
ir::Stmt compute;
ir::Stmt evaluate;
@@ -645,6 +925,26 @@
stmt = reorderLoopsTopologically(stmt);
stmt = insertTemporaries(stmt);
stmt = parallelizeOuterLoop(stmt);

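// Apply any user-provided -s scheduling commands on top of the default
// transformations above; GPU commands also enable CUDA code generation.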
if (setSchedule) {
stringstream scheduleStream;
for (string command : scheduleCommands) {
scheduleStream << command << endl;
}

cuda |= setSchedulingCommands(scheduleStream, parser, stmt);
}

if (cuda) {
if (!CUDA_BUILT && benchmark) {
return reportError("TACO must be built for CUDA (cmake -DCUDA=ON ..) to benchmark", 2);
}
set_CUDA_codegen_enabled(true);
}
else {
set_CUDA_codegen_enabled(false);
}

stmt = scalarPromote(stmt);
if (printConcrete) {
cout << stmt << endl;
@@ -749,7 +1049,7 @@ int main(int argc, char* argv[]) {
" * For both, the `_COO_pos` arrays contain two elements, where the first is 0\n"
" * and the second is the number of nonzeros in the tensor.\n"
" */";

vector<ir::Stmt> packs;
for (auto a : getArgumentAccesses(stmt)) {
TensorVar tensor = a.getTensorVar();