Skip to content

Commit

Permalink
Minor bugfix in transform_rfactor.cpp (#1715)
Browse files Browse the repository at this point in the history
`found_non_rfactor_reduction` is used to detect errors when all reduction dims are marked as rfactors. However, this code is not finding non-rfactor reduction, but instead arbitrary reduction. Fortunately, other parts of our code could detect the same error, so this bug does not have any real effect. But still, I think we need to fix this.
  • Loading branch information
zasdfgbnm committed May 19, 2022
1 parent 3675c70 commit 4ceeee5
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions torch/csrc/jit/codegen/cuda/transform_rfactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,10 +196,11 @@ TensorDomain* TransformRFactor::runReplay(
{
size_t i = 0;
for (auto id : orig_td->domain()) {
if (axes_set.find(i++) != axes_set.end())
if (axes_set.find(i++) != axes_set.end()) {
rfactor_axes.emplace(id);
if (id->isReduction())
} else if (id->isReduction()) {
found_non_rfactor_reduction = true;
}
}
}

Expand Down

0 comments on commit 4ceeee5

Please sign in to comment.