Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 82 additions & 2 deletions include/taco/index_notation/index_notation.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef TACO_INDEX_NOTATION_H
#define TACO_INDEX_NOTATION_H

#include <functional>
#include <ostream>
#include <string>
#include <memory>
Expand Down Expand Up @@ -30,13 +31,15 @@ class Format;
class Schedule;

class IndexVar;
class WindowedIndexVar;
class TensorVar;

class IndexExpr;
class Assignment;
class Access;

struct AccessNode;
struct AccessWindow;
struct LiteralNode;
struct NegNode;
struct SqrtNode;
Expand Down Expand Up @@ -220,14 +223,25 @@ class Access : public IndexExpr {
Access() = default;
Access(const Access&) = default;
Access(const AccessNode*);
Access(const TensorVar& tensorVar, const std::vector<IndexVar>& indices={});
Access(const TensorVar &tensorVar, const std::vector<IndexVar> &indices = {},
const std::map<int, AccessWindow> &windows = {});

/// Return the Access expression's TensorVar.
const TensorVar &getTensorVar() const;

/// Returns the index variables used to index into the Access's TensorVar.
const std::vector<IndexVar>& getIndexVars() const;

/// hasWindowedModes returns true if any accessed modes are windowed.
bool hasWindowedModes() const;

/// Returns whether or not the input mode (0-indexed) is windowed.
bool isModeWindowed(int mode) const;

/// Return the {lower,upper} bound of the window on the input mode (0-indexed).
int getWindowLowerBound(int mode) const;
int getWindowUpperBound(int mode) const;

/// Assign the result of an expression to a left-hand-side tensor access.
/// ```
/// a(i) = b(i) * c(i);
Expand Down Expand Up @@ -800,11 +814,67 @@ class Multi : public IndexStmt {
/// Create a multi index statement.
Multi multi(IndexStmt stmt1, IndexStmt stmt2);

/// IndexVarInterface is a marker superclass for IndexVar-like objects.
/// It is intended to be used in situations where many IndexVar-like objects
/// must be stored together, like when building an Access AST node where some
/// of the access variables are windowed. Users of IndexVarInterface are
/// expected to inspect the dynamic type of the object they hold. The current
/// implementers of IndexVarInterface are:
/// * IndexVar
/// * WindowedIndexVar
/// If this set changes, make sure to update the match function.
class IndexVarInterface {
public:
  virtual ~IndexVarInterface() = default;

  /// match performs a dynamic case analysis over the implementers of
  /// IndexVarInterface, mimicking the dynamic type assertion of Go.
  ///
  /// If `ptr` points to an IndexVar, `ivarFunc` is invoked with the downcast
  /// pointer. Otherwise, if it points to a WindowedIndexVar, `wvarFunc` is
  /// invoked with the downcast pointer. If `ptr` is neither (which includes a
  /// null pointer), an internal assertion failure is reported.
  static void match(
      std::shared_ptr<IndexVarInterface> ptr,
      std::function<void(std::shared_ptr<IndexVar>)> ivarFunc,
      std::function<void(std::shared_ptr<WindowedIndexVar>)> wvarFunc
  ) {
    if (auto iptr = std::dynamic_pointer_cast<IndexVar>(ptr)) {
      ivarFunc(iptr);
      return;
    }
    if (auto wptr = std::dynamic_pointer_cast<WindowedIndexVar>(ptr)) {
      wvarFunc(wptr);
      return;
    }
    // The condition must be `false` here: passing the message string itself
    // as the condition would always evaluate to true (a string literal is a
    // non-null pointer), so the assertion would never fire.
    taco_iassert(false) << "IndexVarInterface was not IndexVar or WindowedIndexVar";
  }
};

/// A WindowedIndexVar is an IndexVar that has been restricted to a window.
/// For example, in
///   A(i) = B(i(2, 4))
/// the expression i(2, 4) is a WindowedIndexVar. The class is declared ahead
/// of IndexVar so that IndexVar can return objects of type WindowedIndexVar.
class WindowedIndexVar
    : public util::Comparable<WindowedIndexVar>,
      public IndexVarInterface {
public:
  /// Creates a window over `base` with bounds `lo` and `hi`.
  WindowedIndexVar(IndexVar base, int lo = -1, int hi = -1);
  ~WindowedIndexVar() = default;

  /// Returns the underlying IndexVar.
  IndexVar getIndexVar() const;

  /// Returns the lower bound of the window of this index variable.
  int getLowerBound() const;

  /// Returns the upper bound of the window of this index variable.
  int getUpperBound() const;

private:
  struct Content;
  std::shared_ptr<Content> content;
};

/// Index variables are used to index into tensors in index expressions, and
/// they represent iteration over the tensor modes they index into.
class IndexVar : public util::Comparable<IndexVar> {
class IndexVar : public util::Comparable<IndexVar>, public IndexVarInterface {
public:
IndexVar();
~IndexVar() = default;
IndexVar(const std::string& name);

/// Returns the name of the index variable.
Expand All @@ -813,6 +883,8 @@ class IndexVar : public util::Comparable<IndexVar> {
friend bool operator==(const IndexVar&, const IndexVar&);
friend bool operator<(const IndexVar&, const IndexVar&);

/// Indexing into an IndexVar returns a window into it.
WindowedIndexVar operator()(int lo, int hi);

private:
struct Content;
Expand All @@ -823,7 +895,15 @@ struct IndexVar::Content {
std::string name;
};

/// State held by a WindowedIndexVar: an index variable together with two
/// integer bounds.
/// NOTE(review): presumably filled in from the WindowedIndexVar constructor
/// arguments of the same names -- confirm against the implementation.
struct WindowedIndexVar::Content {
  IndexVar base;
  int lo;
  int hi;
};

/// Stream insertion operators for index variables: one for a type-erased
/// IndexVarInterface pointer, and one each for IndexVar and WindowedIndexVar.
std::ostream& operator<<(std::ostream& os,
                         const std::shared_ptr<IndexVarInterface>& var);
std::ostream& operator<<(std::ostream& os, const IndexVar& var);
std::ostream& operator<<(std::ostream& os, const WindowedIndexVar& var);

/// A suchthat statement provides a set of IndexVarRel that constrain
/// the iteration space for the child concrete index notation
Expand Down
21 changes: 19 additions & 2 deletions include/taco/index_notation/index_notation_nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,21 @@

namespace taco {

/// An AccessWindow holds the lower and upper bound of the window applied to a
/// single tensor mode. AccessNode stores its windowing information as
/// AccessWindow values; the struct is declared outside of AccessNode so that
/// it can be referenced externally.
struct AccessWindow {
  int lo;
  int hi;

  /// Two windows are equal when both their lower and upper bounds match.
  friend bool operator==(const AccessWindow& lhs, const AccessWindow& rhs) {
    const bool sameLower = lhs.lo == rhs.lo;
    const bool sameUpper = lhs.hi == rhs.hi;
    return sameLower && sameUpper;
  }
};

struct AccessNode : public IndexExprNode {
AccessNode(TensorVar tensorVar, const std::vector<IndexVar>& indices)
: IndexExprNode(tensorVar.getType().getDataType()), tensorVar(tensorVar), indexVars(indices) {}
AccessNode(TensorVar tensorVar, const std::vector<IndexVar>& indices, const std::map<int, AccessWindow>& windows={})
: IndexExprNode(tensorVar.getType().getDataType()), tensorVar(tensorVar), indexVars(indices), windowedModes(windows) {}

void accept(IndexExprVisitorStrict* v) const {
v->visit(this);
Expand All @@ -26,6 +37,12 @@ struct AccessNode : public IndexExprNode {

TensorVar tensorVar;
std::vector<IndexVar> indexVars;
std::map<int, AccessWindow> windowedModes;

protected:
/// Initialize an AccessNode with just a TensorVar. If this constructor is used,
/// then indexVars must be set afterwards.
explicit AccessNode(TensorVar tensorVar) : IndexExprNode(tensorVar.getType().getDataType()), tensorVar(tensorVar) {}
};

struct LiteralNode : public IndexExprNode {
Expand Down
3 changes: 2 additions & 1 deletion include/taco/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -688,9 +688,10 @@ struct Allocate : public StmtNode<Allocate> {
Expr num_elements;
Expr old_elements; // used for realloc in CUDA
bool is_realloc;
bool clear; // Whether to use calloc to allocate this memory.

static Stmt make(Expr var, Expr num_elements, bool is_realloc=false,
Expr old_elements=Expr());
Expr old_elements=Expr(), bool clear=false);

static const IRNodeType _type_info = IRNodeType::Allocate;
};
Expand Down
15 changes: 15 additions & 0 deletions include/taco/lower/iterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,17 @@ class Iterator : public util::Comparable<Iterator> {
/// Returns true if the iterator is defined, false otherwise.
bool defined() const;

/// Methods for querying and operating on windowed tensor modes.

/// isWindowed returns true if this iterator is operating over a window
/// of a tensor mode.
bool isWindowed() const;

/// getWindow{Lower,Upper}Bound return the {Lower,Upper} bound of the
/// window that this iterator operates over.
ir::Expr getWindowLowerBound() const;
ir::Expr getWindowUpperBound() const;

friend bool operator==(const Iterator&, const Iterator&);
friend bool operator<(const Iterator&, const Iterator&);
friend std::ostream& operator<<(std::ostream&, const Iterator&);
Expand All @@ -169,6 +180,10 @@ class Iterator : public util::Comparable<Iterator> {

Iterator(std::shared_ptr<Content> content);
void setChild(const Iterator& iterator) const;

friend class Iterators;
/// setWindowBounds sets the window bounds of this iterator.
void setWindowBounds(ir::Expr lo, ir::Expr hi);
};

/**
Expand Down
23 changes: 22 additions & 1 deletion include/taco/lower/lowerer_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -375,9 +375,30 @@ class LowererImpl : public util::Uncopyable {
/// Create an expression to index into a tensor value array.
ir::Expr generateValueLocExpr(Access access) const;

/// Expression that evaluates to true if none of the iteratators are exhausted
/// Expression that evaluates to true if none of the iterators are exhausted
ir::Expr checkThatNoneAreExhausted(std::vector<Iterator> iterators);

/// Expression that returns the beginning of a window to iterate over
/// in a compressed iterator. It is used when operating over windows of
/// tensors, instead of the full tensor.
ir::Expr searchForStartOfWindowPosition(Iterator iterator, ir::Expr start, ir::Expr end);

/// Statement that guards against going out of bounds of the window that
/// the input iterator was configured with.
ir::Stmt upperBoundGuardForWindowPosition(Iterator iterator, ir::Expr access);

/// Expression that recovers a canonical index variable from a position in
/// a windowed position iterator. A windowed position iterator iterates over
/// values in the range [lo, hi). This expression projects values in that
/// range back into the canonical range of [0, n).
ir::Expr projectWindowedPositionToCanonicalSpace(Iterator iterator, ir::Expr expr);

/// Expression that is the inverse of projectWindowedPositionToCanonicalSpace.
/// It takes an expression ranging through the canonical space of [0, n) and
/// projects it up to the windowed range of [lo, hi).
ir::Expr projectCanonicalSpaceToWindowedPosition(Iterator iterator, ir::Expr expr);

private:
bool assemble;
bool compute;
Expand Down
79 changes: 79 additions & 0 deletions include/taco/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,9 @@ class TensorBase {
/// Create an index expression that accesses (reads or writes) this tensor.
Access operator()(const std::vector<IndexVar>& indices);

/// Create a possibly windowed index expression that accesses (reads or writes) this tensor.
Access operator()(const std::vector<std::shared_ptr<IndexVarInterface>>& indices);

/// Create an index expression that accesses (reads) this (scalar) tensor.
Access operator()();

Expand Down Expand Up @@ -621,6 +624,20 @@ class Tensor : public TensorBase {
template <typename... IndexVars>
Access operator()(const IndexVars&... indices);

/// The below two Access methods are used to allow users to access tensors
/// with a mix of IndexVar's and WindowedIndexVar's. This allows natural
/// expressions like
/// A(i, j(1, 3)) = B(i(2, 4), j) * C(i(5, 7), j(7, 9))
/// to be constructed without adjusting the original API.

/// Create an index expression that accesses (reads, writes) this tensor.
template <typename... IndexVars>
Access operator()(const WindowedIndexVar& first, const IndexVars&... indices);

/// Create an index expression that accesses (reads, writes) this tensor.
template <typename... IndexVars>
Access operator()(const IndexVar& first, const IndexVars&... indices);

ScalarAccess<CType> operator()(const std::vector<int>& indices);

/// Create an index expression that accesses (reads) this tensor.
Expand All @@ -629,6 +646,15 @@ class Tensor : public TensorBase {

/// Assign an expression to a scalar tensor.
void operator=(const IndexExpr& expr);

private:
/// The _access method family is the template level implementation of
/// Access() expressions containing mixes of IndexVar and WindowedIndexVar objects.
template <typename First, typename... Rest>
std::vector<std::shared_ptr<IndexVarInterface>> _access(const First& first, const Rest&... rest);
std::vector<std::shared_ptr<IndexVarInterface>> _access();
template <typename... Args>
Access _access_wrapper(const Args&... args);
};

template <typename CType>
Expand Down Expand Up @@ -1084,6 +1110,59 @@ Access Tensor<CType>::operator()(const IndexVars&... indices) {
return TensorBase::operator()(std::vector<IndexVar>{indices...});
}

/// Base case of the _access() recursion. The _access() family walks a
/// variadic parameter pack one element at a time: each recursive instance
/// handles the first element and then calls _access() again on the rest. Once
/// the pack is empty this overload is selected and yields an empty vector of
/// IndexVarInterface pointers for the recursive callers to append to.
template <typename CType>
std::vector<std::shared_ptr<IndexVarInterface>> Tensor<CType>::_access() {
  return {};
}

/// Recursive case of _access(). It copies `first` into a shared_ptr, recurses
/// on the remaining arguments, and then appends the pointer for `first` to
/// the back of the vector returned by the recursion. The returned vector is
/// therefore in reverse argument order. Storing the pointer as a
/// shared_ptr<IndexVarInterface> only compiles when First derives from
/// IndexVarInterface, which acts as a compile-time check on the argument type.
template <typename CType>
template <typename First, typename... Rest>
std::vector<std::shared_ptr<IndexVarInterface>> Tensor<CType>::_access(const First& first, const Rest&... rest) {
  std::shared_ptr<IndexVarInterface> head = std::make_shared<First>(first);
  std::vector<std::shared_ptr<IndexVarInterface>> collected = this->_access(rest...);
  collected.push_back(head);
  return collected;
}

/// _access_wrapper collects the arguments via _access(), which yields them in
/// reverse order, restores the original argument order, and forwards the
/// result to the TensorBase access operator.
template <typename CType>
template <typename... Args>
Access Tensor<CType>::_access_wrapper(const Args&... args) {
  const auto reversed = this->_access(args...);
  // Copy back-to-front so that the result matches the order of `args`.
  const std::vector<std::shared_ptr<IndexVarInterface>> ordered(
      reversed.rbegin(), reversed.rend());
  return TensorBase::operator()(ordered);
}

/// Access overload selected when the first index is a plain IndexVar. A
/// separate overload exists for a leading WindowedIndexVar so that overload
/// resolution can tell the two cases apart; the remaining indices may be any
/// mix accepted by _access_wrapper.
template <typename CType>
template <typename... Rest>
Access Tensor<CType>::operator()(const IndexVar& first, const Rest&... indices) {
  return _access_wrapper(first, indices...);
}
/// Access overload selected when the first index is a WindowedIndexVar. The
/// remaining indices are forwarded unchanged to _access_wrapper.
template <typename CType>
template <typename... Rest>
Access Tensor<CType>::operator()(const WindowedIndexVar& first, const Rest&... indices) {
  return _access_wrapper(first, indices...);
}

template <typename CType>
ScalarAccess<CType> Tensor<CType>::operator()(const std::vector<int>& indices) {
taco_uassert(indices.size() == (size_t)getOrder())
Expand Down
8 changes: 7 additions & 1 deletion src/codegen/codegen_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,13 @@ void CodeGen_C::visit(const Allocate* op) {
stream << ", ";
}
else {
stream << "malloc(";
// If the allocation was requested to clear the allocated memory,
// use calloc instead of malloc.
if (op->clear) {
stream << "calloc(1, ";
} else {
stream << "malloc(";
}
}
stream << "sizeof(" << elementType << ")";
stream << " * ";
Expand Down
9 changes: 7 additions & 2 deletions src/codegen/codegen_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1293,9 +1293,14 @@ void CodeGen_CUDA::visit(const Call* op) {
stream << op->func << "(";
parentPrecedence = Precedence::CALL;

// Need to print cast to type so that arguments match
// Need to print cast to type so that arguments match.
if (op->args.size() > 0) {
if (op->type != op->args[0].type() || isa<Literal>(op->args[0])) {
// However, the binary search arguments take int* as their first
// argument. This pointer information isn't carried anywhere in
// the argument expressions, so we need to special case and not
// emit an invalid cast for that argument.
auto opIsBinarySearch = op->func == "taco_binarySearchAfter" || op->func == "taco_binarySearchBefore";
if (!opIsBinarySearch && (op->type != op->args[0].type() || isa<Literal>(op->args[0]))) {
stream << "(" << printCUDAType(op->type, false) << ") ";
}
op->args[0].accept(this);
Expand Down
Loading