Skip to content

Commit 4fd7744

Browse files
committed
*: add support for windowing of tensors
This commit adds support for windowing of tensors in the existing index notation DSL. For example: ``` A(i, j) = B(i(1, 4), j) * C(i, j(5, 10)) ``` causes `B` to be windowed along its first mode, and `C` to be windowed along its second mode. In this commit any mix of windowed and non-windowed modes are supported, along with windowing the same tensor in different ways in the same expression. The windowing expressions correspond to the `:` operator to slice dimensions in `numpy`. Currently, only windowing by integers is supported. Windowing is achieved by tying windowing information to particular `Iterator` objects, as these are created for each `Tensor`-`IndexVar` pair. When iterating over an `Iterator` that may be windowed, extra steps are taken to either generate an index into the windowed space, or to recover an index from a point in the windowed space.
1 parent 468ad7f commit 4fd7744

File tree

14 files changed

+756
-23
lines changed

14 files changed

+756
-23
lines changed

include/taco/index_notation/index_notation.h

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#ifndef TACO_INDEX_NOTATION_H
22
#define TACO_INDEX_NOTATION_H
33

4+
#include <functional>
45
#include <ostream>
56
#include <string>
67
#include <memory>
@@ -30,6 +31,7 @@ class Format;
3031
class Schedule;
3132

3233
class IndexVar;
34+
class WindowedIndexVar;
3335
class TensorVar;
3436

3537
class IndexExpr;
@@ -228,6 +230,16 @@ class Access : public IndexExpr {
228230
/// Returns the index variables used to index into the Access's TensorVar.
229231
const std::vector<IndexVar>& getIndexVars() const;
230232

233+
/// hasWindowedModes returns true if any accessed modes are windowed.
234+
bool hasWindowedModes() const;
235+
236+
/// Returns whether or not the input mode (0-indexed) is windowed.
237+
bool isModeWindowed(int mode) const;
238+
239+
/// Return the {lower,upper} bound of the window on the input mode (0-indexed).
240+
int getWindowLowerBound(int mode) const;
241+
int getWindowUpperBound(int mode) const;
242+
231243
/// Assign the result of an expression to a left-hand-side tensor access.
232244
/// ```
233245
/// a(i) = b(i) * c(i);
@@ -800,11 +812,67 @@ class Multi : public IndexStmt {
800812
/// Create a multi index statement.
801813
Multi multi(IndexStmt stmt1, IndexStmt stmt2);
802814

815+
/// IndexVarInterface is a marker superclass for IndexVar-like objects.
816+
/// It is intended to be used in situations where many IndexVar-like objects
817+
/// must be stored together, like when building an Access AST node where some
818+
/// of the access variables are windowed. Use cases for IndexVarInterface
819+
/// will inspect the underlying type of the IndexVarInterface. For sake of
820+
/// completeness, the current implementers of IndexVarInterface are:
821+
/// * IndexVar
822+
/// * WindowedIndexVar
823+
/// If this set changes, make sure to update the match function.
824+
class IndexVarInterface {
825+
public:
826+
virtual ~IndexVarInterface() = default;
827+
828+
/// match performs a dynamic case analysis of the implementers of IndexVarInterface
829+
/// as a utility for handling the different values within. It mimics the dynamic
830+
/// type assertion of Go.
831+
static void match(
832+
std::shared_ptr<IndexVarInterface> ptr,
833+
std::function<void(std::shared_ptr<IndexVar>)> ivarFunc,
834+
std::function<void(std::shared_ptr<WindowedIndexVar>)> wvarFunc
835+
) {
836+
auto iptr = std::dynamic_pointer_cast<IndexVar>(ptr);
837+
auto wptr = std::dynamic_pointer_cast<WindowedIndexVar>(ptr);
838+
if (iptr != nullptr) {
839+
ivarFunc(iptr);
840+
} else if (wptr != nullptr) {
841+
wvarFunc(wptr);
842+
} else {
843+
taco_iassert("IndexVarInterface was not IndexVar or WindowedIndexVar");
844+
}
845+
}
846+
};
847+
848+
/// WindowedIndexVar represents an IndexVar that has been windowed. For example,
849+
/// A(i) = B(i(2, 4))
850+
/// In this case, i(2, 4) is a WindowedIndexVar. WindowedIndexVar is defined
851+
/// before IndexVar so that IndexVar can return objects of type WindowedIndexVar.
852+
class WindowedIndexVar : public util::Comparable<WindowedIndexVar>, public IndexVarInterface {
853+
public:
854+
WindowedIndexVar(IndexVar base, int lo = -1, int hi = -1);
855+
~WindowedIndexVar() = default;
856+
857+
/// getIndexVar returns the underlying IndexVar.
858+
IndexVar getIndexVar() const;
859+
860+
/// get{Lower,Upper}Bound returns the {lower,upper} bound of the window of
861+
/// this index variable.
862+
int getLowerBound() const;
863+
int getUpperBound() const;
864+
865+
private:
866+
struct Content;
867+
std::shared_ptr<Content> content;
868+
};
869+
803870
/// Index variables are used to index into tensors in index expressions, and
804871
/// they represent iteration over the tensor modes they index into.
805-
class IndexVar : public util::Comparable<IndexVar> {
872+
class IndexVar : public util::Comparable<IndexVar>, public IndexVarInterface {
806873
public:
807874
IndexVar();
875+
~IndexVar() = default;
808876
IndexVar(const std::string& name);
809877

810878
/// Returns the name of the index variable.
@@ -813,6 +881,8 @@ class IndexVar : public util::Comparable<IndexVar> {
813881
friend bool operator==(const IndexVar&, const IndexVar&);
814882
friend bool operator<(const IndexVar&, const IndexVar&);
815883

884+
/// Indexing into an IndexVar returns a window into it.
885+
WindowedIndexVar operator()(int lo, int hi);
816886

817887
private:
818888
struct Content;
@@ -823,7 +893,15 @@ struct IndexVar::Content {
823893
std::string name;
824894
};
825895

896+
struct WindowedIndexVar::Content {
897+
IndexVar base;
898+
int lo;
899+
int hi;
900+
};
901+
902+
std::ostream& operator<<(std::ostream&, const std::shared_ptr<IndexVarInterface>&);
826903
std::ostream& operator<<(std::ostream&, const IndexVar&);
904+
std::ostream& operator<<(std::ostream&, const WindowedIndexVar&);
827905

828906
/// A suchthat statement provides a set of IndexVarRel that constrain
829907
/// the iteration space for the child concrete index notation

include/taco/index_notation/index_notation_nodes.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,23 @@ struct AccessNode : public IndexExprNode {
2626

2727
TensorVar tensorVar;
2828
std::vector<IndexVar> indexVars;
29+
30+
// An AccessNode carries the windowing information for an IndexVar + TensorVar
31+
// combination. windowedModes contains the lower and upper bounds of each
32+
// windowed mode (0-indexed).
33+
struct Window {
34+
int lo;
35+
int hi;
36+
friend bool operator==(const Window& a, const Window& b) {
37+
return a.lo == b.lo && a.hi == b.hi;
38+
}
39+
};
40+
std::map<int, Window> windowedModes;
41+
42+
protected:
43+
/// Initialize an AccessNode with just a TensorVar. If this constructor is used,
44+
/// then indexVars must be set afterwards.
45+
explicit AccessNode(TensorVar tensorVar) : IndexExprNode(tensorVar.getType().getDataType()), tensorVar(tensorVar) {}
2946
};
3047

3148
struct LiteralNode : public IndexExprNode {

include/taco/lower/iterator.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,17 @@ class Iterator : public util::Comparable<Iterator> {
159159
/// Returns true if the iterator is defined, false otherwise.
160160
bool defined() const;
161161

162+
/// Methods for querying and operating on windowed tensor modes.
163+
164+
/// isWindowed returns true if this iterator is operating over a window
165+
/// of a tensor mode.
166+
bool isWindowed() const;
167+
168+
/// getWindow{Lower,Upper}Bound return the {Lower,Upper} bound of the
169+
/// window that this iterator operates over.
170+
ir::Expr getWindowLowerBound() const;
171+
ir::Expr getWindowUpperBound() const;
172+
162173
friend bool operator==(const Iterator&, const Iterator&);
163174
friend bool operator<(const Iterator&, const Iterator&);
164175
friend std::ostream& operator<<(std::ostream&, const Iterator&);
@@ -169,6 +180,10 @@ class Iterator : public util::Comparable<Iterator> {
169180

170181
Iterator(std::shared_ptr<Content> content);
171182
void setChild(const Iterator& iterator) const;
183+
184+
friend class Iterators;
185+
/// setWindowBounds sets the window bounds of this iterator.
186+
void setWindowBounds(ir::Expr lo, ir::Expr hi);
172187
};
173188

174189
/**

include/taco/lower/lowerer_impl.h

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -375,9 +375,30 @@ class LowererImpl : public util::Uncopyable {
375375
/// Create an expression to index into a tensor value array.
376376
ir::Expr generateValueLocExpr(Access access) const;
377377

378-
/// Expression that evaluates to true if none of the iteratators are exhausted
378+
/// Expression that evaluates to true if none of the iterators are exhausted
379379
ir::Expr checkThatNoneAreExhausted(std::vector<Iterator> iterators);
380380

381+
/// Expression that returns the beginning of a window to iterate over
382+
/// in a compressed iterator. It is used when operating over windows of
383+
/// tensors, instead of the full tensor.
384+
ir::Expr searchForStartOfWindowPosition(Iterator iterator, ir::Expr start, ir::Expr end);
385+
386+
/// Statement that guards against going out of bounds of the window that
387+
/// the input iterator was configured with.
388+
ir::Stmt upperBoundGuardForWindowPosition(Iterator iterator, ir::Expr access);
389+
390+
/// Expression that recovers a canonical index variable from a position in
391+
/// a windowed position iterator. A windowed position iterator iterates over
392+
/// values in the range [lo, hi). This expression projects values in that
393+
/// range back into the canonical range of [0, n).
394+
ir::Expr projectWindowedPositionToCanonicalSpace(Iterator iterator, ir::Expr expr);
395+
396+
// projectCanonicalSpaceToWindowedPosition is the opposite of
397+
// projectWindowedPositionToCanonicalSpace. It takes an expression ranging
398+
// through the canonical space of [0, n) and projects it up to the windowed
399+
// range of [lo, hi).
400+
ir::Expr projectCanonicalSpaceToWindowedPosition(Iterator iterator, ir::Expr expr);
401+
381402
private:
382403
bool assemble;
383404
bool compute;

include/taco/tensor.h

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,9 @@ class TensorBase {
386386
/// Create an index expression that accesses (reads or writes) this tensor.
387387
Access operator()(const std::vector<IndexVar>& indices);
388388

389+
/// Create a possibly windowed index expression that accesses (reads or writes) this tensor.
390+
Access operator()(const std::vector<std::shared_ptr<IndexVarInterface>>& indices);
391+
389392
/// Create an index expression that accesses (reads) this (scalar) tensor.
390393
Access operator()();
391394

@@ -621,6 +624,20 @@ class Tensor : public TensorBase {
621624
template <typename... IndexVars>
622625
Access operator()(const IndexVars&... indices);
623626

627+
/// The below two Access methods are used to allow users to access tensors
628+
/// with a mix of IndexVar's and WindowedIndexVar's. This allows natural
629+
/// expressions like
630+
/// A(i, j(1, 3)) = B(i(2, 4), j) * C(i(5, 7), j(7, 9))
631+
/// to be constructed without adjusting the original API.
632+
633+
/// Create an index expression that accesses (reads, writes) this tensor.
634+
template <typename... IndexVars>
635+
Access operator()(const WindowedIndexVar& first, const IndexVars&... indices);
636+
637+
/// Create an index expression that accesses (reads, writes) this tensor.
638+
template <typename... IndexVars>
639+
Access operator()(const IndexVar& first, const IndexVars&... indices);
640+
624641
ScalarAccess<CType> operator()(const std::vector<int>& indices);
625642

626643
/// Create an index expression that accesses (reads) this tensor.
@@ -629,6 +646,15 @@ class Tensor : public TensorBase {
629646

630647
/// Assign an expression to a scalar tensor.
631648
void operator=(const IndexExpr& expr);
649+
650+
private:
651+
/// The _access method family is the template level implementation of
652+
/// Access() expressions containing mixes of IndexVar and WindowedIndexVar objects.
653+
template <typename First, typename... Rest>
654+
std::vector<std::shared_ptr<IndexVarInterface>> _access(const First& first, const Rest&... rest);
655+
std::vector<std::shared_ptr<IndexVarInterface>> _access();
656+
template <typename... Args>
657+
Access _access_wrapper(const Args&... args);
632658
};
633659

634660
template <typename CType>
@@ -1084,6 +1110,63 @@ Access Tensor<CType>::operator()(const IndexVars&... indices) {
10841110
return TensorBase::operator()(std::vector<IndexVar>{indices...});
10851111
}
10861112

1113+
/// The _access() methods perform primitive recursion on the input variadic template.
1114+
/// This means that each instance of the _access method matches on the first element
1115+
/// of the variadic template parameter pack, performs an "action", then recurses
1116+
/// with the remaining elements in the parameter pack through a recursive call
1117+
/// to _access. Since this is recursion, we need a base case. The empty argument
1118+
/// instance of _access returns an empty value of the desired type, in this case
1119+
/// a vector of IndexVarInterface.
1120+
template <typename CType>
1121+
std::vector<std::shared_ptr<IndexVarInterface>> Tensor<CType>::_access() {
1122+
return std::vector<std::shared_ptr<IndexVarInterface>>{};
1123+
}
1124+
1125+
/// The recursive case of _access matches on the first element, and attempts to
1126+
/// create a shared_ptr out of it. It then makes a recursive call to get a
1127+
/// vector with the rest of the elements. Then, it pushes the first element onto
1128+
/// the back of the vector -- this check ensures that the type First is indeed
1129+
/// a member of IndexVarInterface.
1130+
template <typename CType>
1131+
template <typename First, typename... Rest>
1132+
std::vector<std::shared_ptr<IndexVarInterface>> Tensor<CType>::_access(const First& first, const Rest&... rest) {
1133+
auto var = std::make_shared<First>(first);
1134+
auto ret = _access(rest...);
1135+
ret.push_back(var);
1136+
return ret;
1137+
}
1138+
1139+
/// _access_wrapper just calls into _access and reverses the result to get the initial
1140+
/// order of the arguments.
1141+
template <typename CType>
1142+
template <typename... Args>
1143+
Access Tensor<CType>::_access_wrapper(const Args&... args) {
1144+
auto resultReversed = this->_access(args...);
1145+
std::vector<std::shared_ptr<IndexVarInterface>> result;
1146+
result.reserve(resultReversed.size());
1147+
for (auto it = resultReversed.rbegin(); it != resultReversed.rend(); it++) {
1148+
result.push_back(*it);
1149+
}
1150+
return TensorBase::operator()(result);
1151+
}
1152+
1153+
/// We have to case on whether the first argument is an IndexVar or a WindowedIndexVar
1154+
/// so that the template engine can differentiate between the two versions.
1155+
// TODO (rohany): I think that there is a chance here that I might not need these
1156+
// two methods if I have _access. I think that instead I would just have to remove
1157+
// the other operator() methods that also take in IndexVar... so that there isn't
1158+
// any confusion.
1159+
template <typename CType>
1160+
template <typename... IndexVars>
1161+
Access Tensor<CType>::operator()(const IndexVar& first, const IndexVars&... indices) {
1162+
return this->_access_wrapper(first, indices...);
1163+
}
1164+
template <typename CType>
1165+
template <typename... IndexVars>
1166+
Access Tensor<CType>::operator()(const WindowedIndexVar& first, const IndexVars&... indices) {
1167+
return this->_access_wrapper(first, indices...);
1168+
}
1169+
10871170
template <typename CType>
10881171
ScalarAccess<CType> Tensor<CType>::operator()(const std::vector<int>& indices) {
10891172
taco_uassert(indices.size() == (size_t)getOrder())

src/codegen/codegen_c.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -516,7 +516,7 @@ void CodeGen_C::visit(const Allocate* op) {
516516
stream << ", ";
517517
}
518518
else {
519-
stream << "malloc(";
519+
stream << "calloc(1, ";
520520
}
521521
stream << "sizeof(" << elementType << ")";
522522
stream << " * ";

src/error/error_checks.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,14 @@ std::pair<bool, string> dimensionsTypecheck(const std::vector<IndexVar>& resultV
5353
for (size_t mode = 0; mode < readNode->indexVars.size(); mode++) {
5454
IndexVar var = readNode->indexVars[mode];
5555
Dimension dimension = readNode->tensorVar.getType().getShape().getDimension(mode);
56+
57+
// If this access has windowed modes, use the dimensions of those windows
58+
// as the shape, rather than the shape of the underlying tensor.
59+
auto a = Access(readNode);
60+
if (a.isModeWindowed(mode)) {
61+
dimension = Dimension(a.getWindowUpperBound(mode) - a.getWindowLowerBound(mode));
62+
}
63+
5664
if (util::contains(indexVarDims,var) && indexVarDims.at(var) != dimension) {
5765
errors.push_back(addDimensionError(var, indexVarDims.at(var), dimension));
5866
} else {

0 commit comments

Comments
 (0)