@@ -386,6 +386,9 @@ class TensorBase {
386386 // / Create an index expression that accesses (reads or writes) this tensor.
387387 Access operator ()(const std::vector<IndexVar>& indices);
388388
389+ // / Create a possibly windowed index expression that accesses (reads or writes) this tensor.
390+ Access operator ()(const std::vector<std::shared_ptr<IndexVarInterface>>& indices);
391+
389392 // / Create an index expression that accesses (reads) this (scalar) tensor.
390393 Access operator ()();
391394
@@ -621,6 +624,20 @@ class Tensor : public TensorBase {
621624 template <typename ... IndexVars>
622625 Access operator ()(const IndexVars&... indices);
623626
627+ // / The below two Access methods are used to allow users to access tensors
628+ // / with a mix of IndexVar's and WindowedIndexVar's. This allows natural
629+ // / expressions like
630+ // / A(i, j(1, 3)) = B(i(2, 4), j) * C(i(5, 7), j(7, 9))
631+ // / to be constructed without adjusting the original API.
632+
633+ // / Create an index expression that accesses (reads, writes) this tensor.
634+ template <typename ... IndexVars>
635+ Access operator ()(const WindowedIndexVar& first, const IndexVars&... indices);
636+
637+ // / Create an index expression that accesses (reads, writes) this tensor.
638+ template <typename ... IndexVars>
639+ Access operator ()(const IndexVar& first, const IndexVars&... indices);
640+
624641 ScalarAccess<CType> operator ()(const std::vector<int >& indices);
625642
626643 // / Create an index expression that accesses (reads) this tensor.
@@ -629,6 +646,15 @@ class Tensor : public TensorBase {
629646
630647 // / Assign an expression to a scalar tensor.
631648 void operator =(const IndexExpr& expr);
649+
650+ private:
651+ // / The _access method family is the template level implementation of
652+ // / Access() expressions containing mixes of IndexVar and WindowedIndexVar objects.
653+ template <typename First, typename ... Rest>
654+ std::vector<std::shared_ptr<IndexVarInterface>> _access (const First& first, const Rest&... rest);
655+ std::vector<std::shared_ptr<IndexVarInterface>> _access ();
656+ template <typename ... Args>
657+ Access _access_wrapper (const Args&... args);
632658};
633659
634660template <typename CType>
@@ -1084,6 +1110,63 @@ Access Tensor<CType>::operator()(const IndexVars&... indices) {
10841110 return TensorBase::operator ()(std::vector<IndexVar>{indices...});
10851111}
10861112
1113+ // / The _access() methods perform primitive recursion on the input variadic template.
1114+ // / This means that each instance of the _access method matches on the first element
1115+ // / of the variadic template parameter pack, performs an "action", then recurses
1116+ // / with the remaining elements in the parameter pack through a recursive call
1117+ // / to _access. Since this is recursion, we need a base case. The empty argument
1118+ // / instance of _access returns an empty value of the desired type, in this case
1119+ // / a vector of IndexVarInterface.
1120+ template <typename CType>
1121+ std::vector<std::shared_ptr<IndexVarInterface>> Tensor<CType>::_access() {
1122+ return std::vector<std::shared_ptr<IndexVarInterface>>{};
1123+ }
1124+
1125+ // / The recursive case of _access matches on the first element, and attempts to
1126+ // / create a shared_ptr out of it. It then makes a recursive call to get a
1127+ // / vector with the rest of the elements. Then, it pushes the first element onto
1128+ // / the back of the vector -- this check ensures that the type First is indeed
1129+ // / a member of IndexVarInterface.
1130+ template <typename CType>
1131+ template <typename First, typename ... Rest>
1132+ std::vector<std::shared_ptr<IndexVarInterface>> Tensor<CType>::_access(const First& first, const Rest&... rest) {
1133+ auto var = std::make_shared<First>(first);
1134+ auto ret = _access (rest...);
1135+ ret.push_back (var);
1136+ return ret;
1137+ }
1138+
1139+ // / _access_wrapper just calls into _access and reverses the result to get the initial
1140+ // / order of the arguments.
1141+ template <typename CType>
1142+ template <typename ... Args>
1143+ Access Tensor<CType>::_access_wrapper(const Args&... args) {
1144+ auto resultReversed = this ->_access (args...);
1145+ std::vector<std::shared_ptr<IndexVarInterface>> result;
1146+ result.reserve (resultReversed.size ());
1147+ for (auto it = resultReversed.rbegin (); it != resultReversed.rend (); it++) {
1148+ result.push_back (*it);
1149+ }
1150+ return TensorBase::operator ()(result);
1151+ }
1152+
1153+ // / We have to case on whether the first argument is an IndexVar or a WindowedIndexVar
1154+ // / so that the template engine can differentiate between the two versions.
1155+ // TODO (rohany): I think that there is a chance here that I might not need these
1156+ // two methods if I have _access. I think that instead I would just have to remove
1157+ // the other operator() methods that also take in IndexVar... so that there isn't
1158+ // any confusion.
1159+ template <typename CType>
1160+ template <typename ... IndexVars>
1161+ Access Tensor<CType>::operator ()(const IndexVar& first, const IndexVars&... indices) {
1162+ return this ->_access_wrapper (first, indices...);
1163+ }
1164+ template <typename CType>
1165+ template <typename ... IndexVars>
1166+ Access Tensor<CType>::operator ()(const WindowedIndexVar& first, const IndexVars&... indices) {
1167+ return this ->_access_wrapper (first, indices...);
1168+ }
1169+
10871170template <typename CType>
10881171ScalarAccess<CType> Tensor<CType>::operator ()(const std::vector<int >& indices) {
10891172 taco_uassert (indices.size () == (size_t )getOrder ())
0 commit comments