Skip to content

Commit

Permalink
Removes Var Kind constructors
Browse files Browse the repository at this point in the history
  • Loading branch information
fredrikbk committed May 20, 2017
1 parent 372661d commit 2c543dd
Show file tree
Hide file tree
Showing 8 changed files with 36 additions and 49 deletions.
2 changes: 1 addition & 1 deletion apps/tensor_times_vector/tensor_times_vector.cpp
Expand Up @@ -26,7 +26,7 @@ int main(int argc, char* argv[]) {
c.pack();

// Form a tensor-vector multiplication expression
Var i, j, k(Var::Sum);
Var i, j, k;
A(i,j) = B(i,j,k) * c(k);

// Compile the expression
Expand Down
32 changes: 6 additions & 26 deletions include/taco/expr.h
Expand Up @@ -23,35 +23,15 @@ class ExprVisitorStrict;
/// represent iteration over a tensor dimension.
class Var : public util::Comparable<Var> {
public:
enum Kind { Free, Sum };
Var();
Var(const std::string& name);

private:
struct Content {
Var::Kind kind;
std::string name;
};

public:
Var(Kind kind = Kind::Free);
Var(const std::string& name, Kind kind = Kind::Free);

std::string getName() const {return content->name;}

Kind getKind() const {return content->kind;}

bool isFree() const {return content->kind == Free;}

bool isReduction() const {return content->kind != Free;}

friend bool operator==(const Var& l, const Var& r) {
return l.content == r.content;
}

friend bool operator<(const Var& l, const Var& r) {
return l.content < r.content;
}
std::string getName() const;
friend bool operator==(const Var&, const Var&);
friend bool operator<(const Var&, const Var&);

private:
struct Content;
std::shared_ptr<Content> content;
};

Expand Down
20 changes: 17 additions & 3 deletions src/expr.cpp
Expand Up @@ -8,12 +8,26 @@ using namespace std;
namespace taco {

// class Var
Var::Var(const std::string& name, Kind kind) : content(new Content) {
struct Var::Content {
std::string name;
};

Var::Var() : Var(util::uniqueName('i')) {}

Var::Var(const std::string& name) : content(new Content) {
content->name = name;
content->kind = kind;
}

Var::Var(Kind kind) : Var(util::uniqueName('i'), kind) {
std::string Var::getName() const {
return content->name;
}

bool operator==(const Var& a, const Var& b) {
return a.content == b.content;
}

bool operator<(const Var& a, const Var& b) {
return a.content < b.content;
}

std::ostream& operator<<(std::ostream& os, const Var& var) {
Expand Down
2 changes: 1 addition & 1 deletion src/parser/parser.cpp
Expand Up @@ -313,7 +313,7 @@ bool Parser::hasIndexVar(std::string name) const {
Var Parser::getIndexVar(string name) const {
taco_iassert(name != "");
if (!hasIndexVar(name)) {
Var var(name, (content->parsingLhs ? Var::Free : Var::Sum));
Var var(name);
content->indexVars.insert({name, var});

// dimensionSizes can also store index var sizes
Expand Down
6 changes: 0 additions & 6 deletions src/tensor.cpp
Expand Up @@ -536,12 +536,6 @@ void TensorBase::setExpr(const vector<taco::Var>& indexVars, taco::Expr expr) {
})
);

// Check that the index variables on the left-hand-side are free
for (auto& indexVar : indexVars) {
taco_uassert(indexVar.getKind() == Var::Free) <<
"Can only use free index variables to index the left-hand-side";
}

// The following are index expressions we don't currently support, but that
// are planned for the future.
// We don't yet support distributing tensors. That is, every free variable
Expand Down
18 changes: 9 additions & 9 deletions test/expr_factory.cpp
Expand Up @@ -40,7 +40,7 @@ MatrixMultiplyFactory::operator()(Tensors& operands, Format outFormat) {
Tensor<double> A({operands[0].getDimensions()[0],
operands[1].getDimensions()[1]}, outFormat);

Var i("i"), j("j"), k("k", Var::Sum);
Var i("i"), j("j"), k("k");
A(i,j) = operands[0](i,k) * operands[1](k,j);

return A;
Expand All @@ -53,7 +53,7 @@ MatrixTransposeMultiplyFactory::operator()(Tensors& operands,

Tensor<double> A(operands[0].getDimensions(), outFormat);

Var i("i"), j("j"), k("k", Var::Sum);
Var i("i"), j("j"), k("k");
A(i,j) = operands[0](k,i) * operands[0](k,j);

return A;
Expand All @@ -66,7 +66,7 @@ MatrixColumnSquaredNormFactory::operator()(Tensors& operands,

Tensor<double> A({operands[0].getDimensions()[1]}, outFormat);

Var i("i"), j("j", Var::Sum);
Var i("i"), j("j");
A(i) = operands[0](j,i) * operands[0](j,i);

return A;
Expand All @@ -91,7 +91,7 @@ MTTKRP1Factory::operator()(Tensors& operands, Format outFormat) {
Tensor<double> A({operands[0].getDimensions()[0],
operands[1].getDimensions()[1]}, outFormat);

Var i("i"), j("j"), k("k", Var::Sum), l("l", Var::Sum);
Var i("i"), j("j"), k("k"), l("l");
A(i,j) = operands[0](i,k,l) * operands[2](l,j) * operands[1](k,j);

return A;
Expand All @@ -104,7 +104,7 @@ MTTKRP2Factory::operator()(Tensors& operands, Format outFormat) {
Tensor<double> A({operands[0].getDimensions()[1],
operands[1].getDimensions()[1]}, outFormat);

Var i("i"), j("j"), k("k", Var::Sum), l("l", Var::Sum);
Var i("i"), j("j"), k("k"), l("l");
A(i,j) = operands[0](k,i,l) * operands[2](l,j) * operands[1](k,j);

return A;
Expand All @@ -117,7 +117,7 @@ MTTKRP3Factory::operator()(Tensors& operands, Format outFormat) {
Tensor<double> A({operands[0].getDimensions()[2],
operands[1].getDimensions()[1]}, outFormat);

Var i("i"), j("j"), k("k", Var::Sum), l("l", Var::Sum);
Var i("i"), j("j"), k("k"), l("l");
A(i,j) = operands[0](k,l,i) * operands[2](l,j) * operands[1](k,j);

return A;
Expand All @@ -129,7 +129,7 @@ TensorSquaredNormFactory::operator()(Tensors& operands, Format outFormat) {

Tensor<double> A({}, outFormat);

Var i("i", Var::Sum), j("j", Var::Sum), k("k", Var::Sum);
Var i("i"), j("j"), k("k");
A() = operands[0](i,j,k) * operands[0](i,j,k);

return A;
Expand All @@ -142,7 +142,7 @@ FactorizedTensorSquaredNormFactory::operator()(Tensors& operands,

Tensor<double> A({}, outFormat);

Var i("i", Var::Sum), j("j", Var::Sum);
Var i("i"), j("j");
A() = operands[0](i) * operands[0](j) * operands[1](i,j) *
operands[2](i,j) * operands[3](i,j);

Expand All @@ -156,7 +156,7 @@ FactorizedTensorInnerProductFactory::operator()(Tensors& operands,

Tensor<double> A({}, outFormat);

Var i("i", Var::Sum), j("j", Var::Sum), k("k", Var::Sum), r("r", Var::Sum);
Var i("i"), j("j"), k("k"), r("r");
A() = operands[0](i,j,k) * operands[1](r) * operands[2](i,r) *
operands[3](j,r) * operands[4](k,r);

Expand Down
2 changes: 1 addition & 1 deletion test/expr_storage-tests.cpp
Expand Up @@ -16,7 +16,7 @@ typedef std::vector<IndexArray> Index; // [0,2] index arrays per Index
typedef std::vector<Index> Indices; // One Index per level

Var i("i"), j("j"), m("m"), n("n");
Var k("k", Var::Sum), l("l", Var::Sum);
Var k("k"), l("l");


struct TestData {
Expand Down
3 changes: 1 addition & 2 deletions test/storage_alloc-tests.cpp
Expand Up @@ -52,8 +52,7 @@ TEST_P(alloc, storage) {
ASSERT_STORAGE_EQUALS(expectedIndices, expectedValues, tensor);
}

Var i("i"), j("j"), m("m"), n("n");
Var k("k", Var::Sum), l("l", Var::Sum);
Var i("i"), j("j"), m("m"), n("n"), k("k"), l("l");

IndexArray dlab_indices() {
IndexArray indices;
Expand Down

0 comments on commit 2c543dd

Please sign in to comment.