[TensorExpr] Fix lowerings for aten::view and aten::reshape. (#65852)
Summary: Pull Request resolved: #65852

Test Plan: Imported from OSS

Reviewed By: navahgar

Differential Revision: D31286024

Pulled By: ZolotukhinM

fbshipit-source-id: eb5b5f2ed86b6f325f09904e841815b8183b4e1d
Mikhail Zolotukhin authored and facebook-github-bot committed Oct 12, 2021
1 parent 60a2a29 commit 6864146
Showing 4 changed files with 13 additions and 44 deletions.
test/test_jit_fuser_te.py (6 changes: 2 additions & 4 deletions)

@@ -2027,6 +2027,7 @@ def bn_neither(i, x):
     'remainder',
     'remainder.autodiffed',
     'reshape',
+    'reshape_as',
     'round',
     'rsub',
     'rsub.rsub_tensor',
@@ -2045,6 +2046,7 @@ def bn_neither(i, x):
     'trunc',
     'unsqueeze',
     'view',
+    'view_as',
     'where',
 ]
 
@@ -2060,8 +2062,6 @@ def bn_neither(i, x):
     # Reference: https://github.com/pytorch/pytorch/pull/59442/checks?check_run_id=2746156896
     't',
     'conj'
-    'view',
-    'reshape',
 ]
 
 def get_name(op):
@@ -2072,8 +2072,6 @@ def get_name(op):
 
 class TestNNCOpInfo(TestCase):
     def te_compile(self, device, dtype, op):
-        # If adding new OpInfo tests cause this test to fail, add it into here
-        skip_ops = ['view', 'reshape']
         if op.name in skip_ops:
             return
         sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
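For context, works_list names OpInfo entries that TestNNCOpInfo is expected to compile through NNC end to end, while skip_ops opts specific ops out; this change removes the local view/reshape opt-out. A rough sketch of what the new 'view_as' entry implies, assuming the op_db machinery this test file already imports ('view_as' is assumed to exist in op_db, as the new works_list entry requires):

# Sketch only, simplified from TestNNCOpInfo.te_compile: each works_list
# name refers to an OpInfo, and the test compiles every one of its sample
# inputs through the NNC/TensorExpr compiler, expecting no failures.
import torch
from torch.testing._internal.common_methods_invocations import op_db

op = next(o for o in op_db if o.name == 'view_as')
samples = list(op.sample_inputs('cpu', torch.float32, requires_grad=False))
print(f"'view_as' provides {len(samples)} sample inputs for the fuser test")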
torch/csrc/jit/tensorexpr/lowerings.cpp (10 changes: 6 additions & 4 deletions)

@@ -1399,10 +1399,12 @@ RegisterNNCLoweringsFunction aten_expand(
 
   // TODO: convert to schema, add a test
   // RegisterNNCLoweringsFunction aten_flatten({"aten::flatten"}, computeFlatten);
-  // TODO: convert to schema, add a test
-  // RegisterNNCLoweringsFunction aten_view(
-  //     {"aten::view", "aten::reshape"},
-  //     computeReshape);
+  RegisterNNCLoweringsFunction aten_view(
+      {"aten::reshape(Tensor(a) self, int[] shape) -> (Tensor(a))",
+       "aten::reshape_as(Tensor(a) self, Tensor other) -> (Tensor(a))",
+       "aten::view(Tensor(a) self, int[] size) -> (Tensor(a))",
+       "aten::view_as(Tensor(a) self, Tensor other) -> (Tensor(a))"},
+      computeReshape);
 
   // aten::mm is a subset of aten::matmul where both inputs are rank 2
   RegisterNNCLoweringsFunction aten_matmul(
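With these schemas registered, graphs containing aten::view, aten::view_as, aten::reshape, and aten::reshape_as can be claimed by the TensorExpr fuser instead of falling back to the interpreter. A minimal smoke test in that spirit (illustrative only, not part of the PR; it relies on the same fuser toggle test_jit_fuser_te.py uses and on the profiling executor's usual warm-up behavior):

# Illustrative sketch: a scripted function whose body contains aten::view,
# run enough times for the profiling executor to apply the optimized plan.
import torch

torch._C._jit_set_texpr_fuser_enabled(True)  # same toggle the test file uses

@torch.jit.script
def f(x):
    # aten::view sits between pointwise ops, so the whole body is fusible.
    return ((x * 2).view(6, 4) + 1).relu()

x = torch.rand(3, 8)
for _ in range(3):  # profiling runs happen before fusion kicks in
    out = f(x)
assert torch.allclose(out, ((x * 2).view(6, 4) + 1).relu())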
torch/csrc/jit/tensorexpr/operators/misc.cpp (36 changes: 5 additions & 31 deletions)

@@ -382,12 +382,11 @@ Tensor computeExpand(
       });
 }
 
-static Tensor computeReshapeHelper(
+Tensor computeReshape(
     const std::vector<ArgValue>& inputs,
     const std::vector<ExprHandle>& outputShape,
     const c10::optional<ScalarType>& outputType,
-    at::Device device,
-    const IntList& view_dims) {
+    at::Device device) {
   auto A = c10::get<BufHandle>(inputs[0]);
   if (A.ndim() == 0) {
     return Compute(
@@ -403,7 +402,7 @@ static Tensor computeReshapeHelper(
       c10::fmap<DimArg>(outputShape),
       [&](const std::vector<VarHandle>& axes) {
         std::vector<VarHandle> new_axes;
-        assert(view_dims.size() == axes.size());
+        assert(outputShape.size() == axes.size());
         /*
         Example for the index transformation. Assume we have a tensor A and
         its view B:
@@ -419,11 +418,9 @@
           idx = i5 + i4*2 + i3*2 + i2*18 + i1*18
           B[i1,i2,i3,i4,i5] = A[idx/(3*2), (idx/3)%2, idx%3]
         */
-        // NOLINTNEXTLINE(clang-diagnostic-unused-variable)
-        ExprHandle cur_stride = 1;
         std::vector<ExprPtr> dims, indices;
-        for (size_t idx = 0; idx < view_dims.size(); idx++) {
-          dims.push_back(alloc<LongImm>(view_dims[idx]));
+        for (size_t idx = 0; idx < outputShape.size(); idx++) {
+          dims.push_back(outputShape[idx].node());
           indices.push_back(axes[idx].node());
         }
         ExprHandle flat_idx = ExprHandle(flatten_index(dims, indices));
@@ -448,29 +445,6 @@
       });
 }
 
-Tensor computeFlatten(
-    const std::vector<ArgValue>& inputs,
-    const std::vector<ExprHandle>& outputShape,
-    const c10::optional<ScalarType>& outputType,
-    at::Device device) {
-  std::vector<int64_t> view_dims;
-  for (const auto dim : c10::irange(outputShape.size())) {
-    view_dims.push_back(outputShape[dim].AsNode<LongImm>()->value());
-  }
-  return computeReshapeHelper(
-      inputs, outputShape, outputType, device, view_dims);
-}
-
-Tensor computeReshape(
-    const std::vector<ArgValue>& inputs,
-    const std::vector<ExprHandle>& outputShape,
-    const c10::optional<ScalarType>& outputType,
-    at::Device device) {
-  const auto& view_dims = c10::get<IntList>(inputs[1]);
-  return computeReshapeHelper(
-      inputs, outputShape, outputType, device, view_dims);
-}
-
 static std::pair<ScalarType, std::vector<BufHandle>> processCatList(
     const std::vector<BufHandle>& bufList) {
   if (bufList.size() == 0) {
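The core of computeReshape is the index transformation sketched in the comment above: flatten the output (view) index against the output dims, then split that flat index back apart with division and modulo against the input dims, exactly as in the comment's A[idx/(3*2), (idx/3)%2, idx%3] example. The same arithmetic in standalone form (a Python sketch with made-up shapes; the real code builds NNC IR through flatten_index rather than computing integers):

import numpy as np

def flatten_index(dims, indices):
    # Row-major flattening: idx = i_n + i_{n-1}*d_n + i_{n-2}*d_n*d_{n-1} + ...
    idx, stride = 0, 1
    for d, i in zip(reversed(dims), reversed(indices)):
        idx += i * stride
        stride *= d
    return idx

def unflatten_index(dims, idx):
    # Inverse of the above: peel dimensions off from the right via mod/div.
    out = []
    for d in reversed(dims):
        out.append(idx % d)
        idx //= d
    return tuple(reversed(out))

A = np.arange(24).reshape(2, 3, 4)   # "input" buffer
view_dims = [4, 3, 2]                # requested output shape
B = A.reshape(view_dims)
for b_idx in np.ndindex(*view_dims):
    flat = flatten_index(view_dims, b_idx)
    assert B[b_idx] == A[unflatten_index(list(A.shape), flat)]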
torch/csrc/jit/tensorexpr/operators/misc.h (5 changes: 0 additions & 5 deletions)

@@ -61,11 +61,6 @@ Tensor computeExpand(
     const std::vector<ExprHandle>& outputShape,
     const c10::optional<ScalarType>& outputType,
     at::Device device);
-Tensor computeFlatten(
-    const std::vector<ArgValue>& inputs,
-    const std::vector<ExprHandle>& outputShape,
-    const c10::optional<ScalarType>& outputType,
-    at::Device device);
 Tensor computeReshape(
     const std::vector<ArgValue>& inputs,
     const std::vector<ExprHandle>& outputShape,
