Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 91 additions & 1 deletion include/taco/index_notation/index_notation.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef TACO_INDEX_NOTATION_H
#define TACO_INDEX_NOTATION_H

#include <functional>
#include <ostream>
#include <string>
#include <memory>
Expand Down Expand Up @@ -30,6 +31,7 @@ class Format;
class Schedule;

class IndexVar;
class WindowedIndexVar;
class TensorVar;

class IndexExpr;
Expand Down Expand Up @@ -228,6 +230,22 @@ class Access : public IndexExpr {
/// Returns the index variables used to index into the Access's TensorVar.
const std::vector<IndexVar>& getIndexVars() const;

/// hasWindowedModes returns true if any accessed modes are windowed.
bool hasWindowedModes() const;

/// Returns whether or not the input mode (0-indexed) is windowed.
bool isModeWindowed(int mode) const;

/// Return the {lower,upper} bound of the window on the input mode (0-indexed).
int getWindowLowerBound(int mode) const;
int getWindowUpperBound(int mode) const;

/// getWindowDimension returns the dimension size of a window.
int getWindowDimension(int mode) const;

/// getStride returns the stride of a window.
int getStride(int mode) const;

/// Assign the result of an expression to a left-hand-side tensor access.
/// ```
/// a(i) = b(i) * c(i);
Expand Down Expand Up @@ -800,11 +818,72 @@ class Multi : public IndexStmt {
/// Create a multi index statement.
Multi multi(IndexStmt stmt1, IndexStmt stmt2);

/// IndexVarInterface is a marker superclass for IndexVar-like objects.
/// It is intended to be used in situations where many IndexVar-like objects
/// must be stored together, like when building an Access AST node where some
/// of the access variables are windowed. Use cases for IndexVarInterface
/// will inspect the underlying type of the IndexVarInterface. For sake of
/// completeness, the current implementers of IndexVarInterface are:
/// * IndexVar
/// * WindowedIndexVar
/// If this set changes, make sure to update the match function.
class IndexVarInterface {
public:
  virtual ~IndexVarInterface() = default;

  /// match performs a dynamic case analysis of the implementers of
  /// IndexVarInterface, mimicking the dynamic type assertion of Go.
  ///
  /// ptr      - the object whose dynamic type is inspected.
  /// ivarFunc - called with the downcast pointer when ptr is an IndexVar.
  /// wvarFunc - called with the downcast pointer when ptr is a
  ///            WindowedIndexVar.
  ///
  /// Exactly one callback is invoked. If ptr is null, or is neither an
  /// IndexVar nor a WindowedIndexVar, an internal assertion is raised.
  static void match(
      std::shared_ptr<IndexVarInterface> ptr,
      std::function<void(std::shared_ptr<IndexVar>)> ivarFunc,
      std::function<void(std::shared_ptr<WindowedIndexVar>)> wvarFunc
  ) {
    // The WindowedIndexVar cast is only attempted when the IndexVar cast fails.
    if (auto iptr = std::dynamic_pointer_cast<IndexVar>(ptr)) {
      ivarFunc(iptr);
    } else if (auto wptr = std::dynamic_pointer_cast<WindowedIndexVar>(ptr)) {
      wvarFunc(wptr);
    } else {
      // The asserted condition must be `false`. Passing the message string as
      // the condition would always evaluate to true (a string literal is a
      // non-null pointer) and the assertion would never fire.
      taco_iassert(false) << "IndexVarInterface was not IndexVar or WindowedIndexVar";
    }
  }
};

/// WindowedIndexVar represents an IndexVar that has been windowed. For example,
/// A(i) = B(i(2, 4))
/// In this case, i(2, 4) is a WindowedIndexVar. WindowedIndexVar is defined
/// before IndexVar so that IndexVar can return objects of type WindowedIndexVar.
class WindowedIndexVar : public util::Comparable<WindowedIndexVar>, public IndexVarInterface {
public:
/// Construct a windowed view of `base` from a lower bound, upper bound and
/// stride. Every argument after `base` is optional.
/// NOTE(review): the meaning of the default bounds of -1 is not visible in
/// this header -- confirm against the constructor's definition.
WindowedIndexVar(IndexVar base, int lo = -1, int hi = -1, int stride = 1);
~WindowedIndexVar() = default;

/// getIndexVar returns the underlying IndexVar.
IndexVar getIndexVar() const;

/// get{Lower,Upper}Bound returns the {lower,upper} bound of the window of
/// this index variable.
int getLowerBound() const;
int getUpperBound() const;
/// getStride returns the stride to access the window by.
int getStride() const;

/// getWindowSize returns the number of elements in the window.
int getWindowSize() const;

private:
// Content is only declared here. Its definition appears later in this header,
// after IndexVar, because it stores an IndexVar by value.
struct Content;
std::shared_ptr<Content> content;
};

/// Index variables are used to index into tensors in index expressions, and
/// they represent iteration over the tensor modes they index into.
class IndexVar : public util::Comparable<IndexVar> {
class IndexVar : public util::Comparable<IndexVar>, public IndexVarInterface {
public:
IndexVar();
~IndexVar() = default;
IndexVar(const std::string& name);

/// Returns the name of the index variable.
Expand All @@ -813,6 +892,8 @@ class IndexVar : public util::Comparable<IndexVar> {
friend bool operator==(const IndexVar&, const IndexVar&);
friend bool operator<(const IndexVar&, const IndexVar&);

/// Indexing into an IndexVar returns a window into it.
WindowedIndexVar operator()(int lo, int hi, int stride = 1);

private:
struct Content;
Expand All @@ -823,7 +904,16 @@ struct IndexVar::Content {
std::string name;
};

/// Backing state of a WindowedIndexVar. The fields mirror the arguments of
/// the WindowedIndexVar constructor (base, lo, hi, stride).
struct WindowedIndexVar::Content {
// The index variable that is being windowed.
IndexVar base;
// Lower bound of the window.
int lo;
// Upper bound of the window.
int hi;
// Stride used when accessing the window.
int stride;
};

/// Stream-insertion overloads for the IndexVar-like types: a shared pointer to
/// any IndexVarInterface, a plain IndexVar, and a WindowedIndexVar. Only the
/// declarations are visible here; the printed format is determined by their
/// definitions.
std::ostream& operator<<(std::ostream&, const std::shared_ptr<IndexVarInterface>&);
std::ostream& operator<<(std::ostream&, const IndexVar&);
std::ostream& operator<<(std::ostream&, const WindowedIndexVar&);

/// A suchthat statement provides a set of IndexVarRel that constrain
/// the iteration space for the child concrete index notation
Expand Down
18 changes: 18 additions & 0 deletions include/taco/index_notation/index_notation_nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,24 @@ struct AccessNode : public IndexExprNode {

TensorVar tensorVar;
std::vector<IndexVar> indexVars;

// An AccessNode carries the windowing information for an IndexVar + TensorVar
// combination. windowedModes maps each windowed mode (0-indexed) to the lower
// bound, upper bound and stride of its window.
struct Window {
  int lo;
  int hi;
  int stride;

  /// Two windows compare equal only when their lower bound, upper bound and
  /// stride all match.
  friend bool operator==(const Window& lhs, const Window& rhs) {
    if (lhs.lo != rhs.lo) {
      return false;
    }
    if (lhs.hi != rhs.hi) {
      return false;
    }
    return lhs.stride == rhs.stride;
  }
};
std::map<int, Window> windowedModes;

protected:
/// Build an AccessNode from a TensorVar alone. The expression's data type is
/// taken from the tensor's type. indexVars is not populated by this
/// constructor and must be set afterwards.
explicit AccessNode(TensorVar tensorVar)
    : IndexExprNode(tensorVar.getType().getDataType()),
      tensorVar(tensorVar) {}
};

struct LiteralNode : public IndexExprNode {
Expand Down
27 changes: 27 additions & 0 deletions include/taco/lower/iterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,29 @@ class Iterator : public util::Comparable<Iterator> {
/// Returns true if the iterator is defined, false otherwise.
bool defined() const;

/// Methods for querying and operating on windowed tensor modes.

/// isWindowed returns true if this iterator is operating over a window
/// of a tensor mode.
bool isWindowed() const;

/// isStrided returns true if this iterator has a stride != 1. Currently
/// only windowed iterators can have strides.
bool isStrided() const;

/// getWindow{Lower,Upper}Bound return the {Lower,Upper} bound of the
/// window that this iterator operates over.
ir::Expr getWindowLowerBound() const;
ir::Expr getWindowUpperBound() const;

/// getStride returns an Expr holding the stride that this iterator is
/// configured with.
ir::Expr getStride() const;

/// getWindowVar returns a Var specific to the window that this iterator
/// is operating over. It can be used as temporary storage.
ir::Expr getWindowVar() const;

friend bool operator==(const Iterator&, const Iterator&);
friend bool operator<(const Iterator&, const Iterator&);
friend std::ostream& operator<<(std::ostream&, const Iterator&);
Expand All @@ -169,6 +192,10 @@ class Iterator : public util::Comparable<Iterator> {

Iterator(std::shared_ptr<Content> content);
void setChild(const Iterator& iterator) const;

friend class Iterators;
/// setWindowBounds sets the window bounds of this iterator.
void setWindowBounds(ir::Expr lo, ir::Expr hi, ir::Expr stride);
};

/**
Expand Down
29 changes: 28 additions & 1 deletion include/taco/lower/lowerer_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -375,9 +375,36 @@ class LowererImpl : public util::Uncopyable {
/// Create an expression to index into a tensor value array.
ir::Expr generateValueLocExpr(Access access) const;

/// Expression that evaluates to true if none of the iteratators are exhausted
/// Expression that evaluates to true if none of the iterators are exhausted
ir::Expr checkThatNoneAreExhausted(std::vector<Iterator> iterators);

/// Expression that returns the beginning of a window to iterate over
/// in a compressed iterator. It is used when operating over windows of
/// tensors, instead of the full tensor.
ir::Expr searchForStartOfWindowPosition(Iterator iterator, ir::Expr start, ir::Expr end);

/// Statement that guards against going out of bounds of the window that
/// the input iterator was configured with.
ir::Stmt upperBoundGuardForWindowPosition(Iterator iterator, ir::Expr access);

/// Expression that recovers a canonical index variable from a position in
/// a windowed position iterator. A windowed position iterator iterates over
/// values in the range [lo, hi). This expression projects values in that
/// range back into the canonical range of [0, n).
ir::Expr projectWindowedPositionToCanonicalSpace(Iterator iterator, ir::Expr expr);

/// projectCanonicalSpaceToWindowedPosition is the opposite of
/// projectWindowedPositionToCanonicalSpace. It takes an expression ranging
/// through the canonical space of [0, n) and projects it up to the windowed
/// range of [lo, hi).
ir::Expr projectCanonicalSpaceToWindowedPosition(Iterator iterator, ir::Expr expr);

/// strideBoundsGuard inserts a guard against accessing values from an
/// iterator that don't fit in the stride that the iterator is configured
/// with. It takes a boolean incrementPosVar to control whether the outer
/// loop iterator variable should be incremented when the guard is fired.
ir::Stmt strideBoundsGuard(Iterator iterator, ir::Expr access, bool incrementPosVar);

private:
bool assemble;
bool compute;
Expand Down
83 changes: 83 additions & 0 deletions include/taco/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,9 @@ class TensorBase {
/// Create an index expression that accesses (reads or writes) this tensor.
Access operator()(const std::vector<IndexVar>& indices);

/// Create a possibly windowed index expression that accesses (reads or writes) this tensor.
Access operator()(const std::vector<std::shared_ptr<IndexVarInterface>>& indices);

/// Create an index expression that accesses (reads) this (scalar) tensor.
Access operator()();

Expand Down Expand Up @@ -621,6 +624,20 @@ class Tensor : public TensorBase {
template <typename... IndexVars>
Access operator()(const IndexVars&... indices);

/// The below two Access methods are used to allow users to access tensors
/// with a mix of IndexVar's and WindowedIndexVar's. This allows natural
/// expressions like
/// A(i, j(1, 3)) = B(i(2, 4), j) * C(i(5, 7), j(7, 9))
/// to be constructed without adjusting the original API.

/// Create an index expression that accesses (reads, writes) this tensor.
template <typename... IndexVars>
Access operator()(const WindowedIndexVar& first, const IndexVars&... indices);

/// Create an index expression that accesses (reads, writes) this tensor.
template <typename... IndexVars>
Access operator()(const IndexVar& first, const IndexVars&... indices);

ScalarAccess<CType> operator()(const std::vector<int>& indices);

/// Create an index expression that accesses (reads) this tensor.
Expand All @@ -629,6 +646,15 @@ class Tensor : public TensorBase {

/// Assign an expression to a scalar tensor.
void operator=(const IndexExpr& expr);

private:
/// The _access method family is the template level implementation of
/// Access() expressions containing mixes of IndexVar and WindowedIndexVar objects.
template <typename First, typename... Rest>
std::vector<std::shared_ptr<IndexVarInterface>> _access(const First& first, const Rest&... rest);
std::vector<std::shared_ptr<IndexVarInterface>> _access();
template <typename... Args>
Access _access_wrapper(const Args&... args);
};

template <typename CType>
Expand Down Expand Up @@ -1084,6 +1110,63 @@ Access Tensor<CType>::operator()(const IndexVars&... indices) {
return TensorBase::operator()(std::vector<IndexVar>{indices...});
}

/// The _access() overloads walk the variadic argument pack by primitive
/// recursion: each call peels off the first argument, handles it, and recurses
/// on the remainder. This zero-argument overload is the base case that ends
/// the recursion; it yields an empty vector of IndexVarInterface pointers.
template <typename CType>
std::vector<std::shared_ptr<IndexVarInterface>> Tensor<CType>::_access() {
  std::vector<std::shared_ptr<IndexVarInterface>> none;
  return none;
}

/// Recursive case of _access. The first argument is copied into a shared_ptr,
/// the remaining arguments are collected by a recursive call, and the first
/// argument is then appended to that result. The returned vector therefore
/// holds the arguments in reverse order.
template <typename CType>
template <typename First, typename... Rest>
std::vector<std::shared_ptr<IndexVarInterface>> Tensor<CType>::_access(const First& first, const Rest&... rest) {
  // Storing the pointer as shared_ptr<IndexVarInterface> only compiles when
  // First is an implementer of IndexVarInterface.
  std::shared_ptr<IndexVarInterface> head = std::make_shared<First>(first);
  auto collected = _access(rest...);
  collected.push_back(head);
  return collected;
}

/// _access_wrapper collects the arguments through _access, which hands them
/// back in reverse, restores the caller's original argument order, and then
/// forwards the vector to TensorBase::operator().
template <typename CType>
template <typename... Args>
Access Tensor<CType>::_access_wrapper(const Args&... args) {
  const auto reversed = this->_access(args...);
  // Copying through reverse iterators undoes the reversal.
  const std::vector<std::shared_ptr<IndexVarInterface>> inOrder(reversed.rbegin(), reversed.rend());
  return TensorBase::operator()(inOrder);
}

/// We have to case on whether the first argument is an IndexVar or a WindowedIndexVar
/// so that the template engine can differentiate between the two versions.
// TODO (rohany): I think that there is a chance here that I might not need these
// two methods if I have _access. I think that instead I would just have to remove
// the other operator() methods that also take in IndexVar... so that there isn't
// any confusion.
/// Access whose first index is a plain IndexVar; all indices are forwarded to
/// _access_wrapper.
template <typename CType>
template <typename... IndexVars>
Access Tensor<CType>::operator()(const IndexVar& first, const IndexVars&... indices) {
  Access access = _access_wrapper(first, indices...);
  return access;
}
/// Access whose first index is a WindowedIndexVar; all indices are forwarded
/// to _access_wrapper.
template <typename CType>
template <typename... IndexVars>
Access Tensor<CType>::operator()(const WindowedIndexVar& first, const IndexVars&... indices) {
  Access access = _access_wrapper(first, indices...);
  return access;
}

template <typename CType>
ScalarAccess<CType> Tensor<CType>::operator()(const std::vector<int>& indices) {
taco_uassert(indices.size() == (size_t)getOrder())
Expand Down
2 changes: 1 addition & 1 deletion src/codegen/codegen_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,7 @@ void CodeGen_C::visit(const Allocate* op) {
stream << ", ";
}
else {
stream << "malloc(";
stream << "calloc(1, ";
}
stream << "sizeof(" << elementType << ")";
stream << " * ";
Expand Down
8 changes: 8 additions & 0 deletions src/error/error_checks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,14 @@ std::pair<bool, string> dimensionsTypecheck(const std::vector<IndexVar>& resultV
for (size_t mode = 0; mode < readNode->indexVars.size(); mode++) {
IndexVar var = readNode->indexVars[mode];
Dimension dimension = readNode->tensorVar.getType().getShape().getDimension(mode);

// If this access has windowed modes, use the dimensions of those windows
// as the shape, rather than the shape of the underlying tensor.
auto a = Access(readNode);
if (a.isModeWindowed(mode)) {
dimension = Dimension(a.getWindowDimension(mode));
}

if (util::contains(indexVarDims,var) && indexVarDims.at(var) != dimension) {
errors.push_back(addDimensionError(var, indexVarDims.at(var), dimension));
} else {
Expand Down
Loading