
Conversation

@bertmaher (Contributor) commented Aug 17, 2021

@facebook-github-bot added the cla signed and oncall: jit (Add this issue/PR to JIT oncall triage queue) labels Aug 17, 2021
bertmaher added a commit that referenced this pull request Aug 17, 2021
@facebook-github-bot (Contributor) commented Aug 17, 2021

💊 CI failures summary and remediations

As of commit 82d2274 (more details on the Dr. CI page):


  • 5/5 failures possibly* introduced in this PR
    • 1/5 non-scanned failure(s)

🕵️ 4 new failures recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See GitHub Actions build linux-bionic-py3.8-gcc9-coverage / build (1/4)

Step: "Build PyTorch" (full log | diagnosis details | 🔁 rerun)

2021-08-20T03:19:03.3988576Z Build left local git repository checkout dirty
2021-08-20T03:18:55.6508299Z multiple input files                 13
2021-08-20T03:18:55.6508526Z 
2021-08-20T03:18:55.6509886Z Cache location                  S3, bucket: Bucket(name=ossci-compiler-cache-circleci-v2, base_url=http://ossci-compiler-cache-circleci-v2.s3.amazonaws.com/)
2021-08-20T03:18:55.6511149Z + assert_git_not_dirty
2021-08-20T03:18:55.6511754Z + [[ linux-bionic-py3.8-gcc9-coverage != *rocm* ]]
2021-08-20T03:18:55.6512543Z + [[ linux-bionic-py3.8-gcc9-coverage != *xla* ]]
2021-08-20T03:18:55.6513145Z ++ git status --porcelain
2021-08-20T03:19:03.3986367Z + git_status='?? third_party/breakpad/'
2021-08-20T03:19:03.3986935Z + [[ -n ?? third_party/breakpad/ ]]
2021-08-20T03:19:03.3987837Z + echo 'Build left local git repository checkout dirty'
2021-08-20T03:19:03.3988576Z Build left local git repository checkout dirty
2021-08-20T03:19:03.3989325Z + echo 'git status --porcelain:'
2021-08-20T03:19:03.3989786Z git status --porcelain:
2021-08-20T03:19:03.3990235Z + echo '?? third_party/breakpad/'
2021-08-20T03:19:03.3994996Z ?? third_party/breakpad/
2021-08-20T03:19:03.3995611Z + exit 1
2021-08-20T03:19:03.3996079Z + cleanup
2021-08-20T03:19:03.3996355Z + retcode=1
2021-08-20T03:19:03.3996646Z + set +x
2021-08-20T03:19:03.3997007Z =================== sccache compilation log ===================
2021-08-20T03:19:03.4154050Z =========== If your build fails, please take a look at the log above for possible reasons ===========

See CircleCI build pytorch_linux_xenial_py3_6_gcc5_4_build (2/4)

Step: "(Optional) Merge target branch" (full log | diagnosis details | 🔁 rerun)

Automatic merge failed; fix conflicts and then commit the result.
+ git merge --allow-unrelated-histories --no-edit --no-ff 0a66d5b3253fd2d2304f3897526db3c8fb139376
Auto-merging torch/testing/_internal/common_methods_invocations.py
Auto-merging torch/quantization/qconfig.py
CONFLICT (content): Merge conflict in torch/quantization/qconfig.py
Auto-merging torch/quantization/fx/prepare.py
CONFLICT (content): Merge conflict in torch/quantization/fx/prepare.py
Auto-merging torch/csrc/jit/tensorexpr/loopnest.cpp
Auto-merging test/test_jit.py
Auto-merging test/test_fx.py
Auto-merging test/cpp/tensorexpr/test_loopnest.cpp
Automatic merge failed; fix conflicts and then commit the result.


Exited with code exit status 1

See CircleCI build pytorch_linux_xenial_py3_clang7_asan_test1 (3/4)

Step: "Run tests" (full log | diagnosis details | 🔁 rerun)

Aug 20 03:47:46 SUMMARY: UndefinedBehaviorSanit.../jenkins/workspace/aten/src/ATen/Utils.cpp:20:3 in
Aug 20 03:47:46     #4 0x5558acdad15f  (/opt/conda/bin/python3.6+0x13015f)
Aug 20 03:47:46     #5 0x5558acdef8f2  (/opt/conda/bin/python3.6+0x1728f2)
Aug 20 03:47:46     #6 0x5558ace57cd5  (/opt/conda/bin/python3.6+0x1dacd5)
Aug 20 03:47:46     #7 0x5558ace59d5d  (/opt/conda/bin/python3.6+0x1dcd5d)
Aug 20 03:47:46     #8 0x5558ace59dbb  (/opt/conda/bin/python3.6+0x1dcdbb)
Aug 20 03:47:46     #9 0x5558ace5a926  (/opt/conda/bin/python3.6+0x1dd926)
Aug 20 03:47:46     #10 0x5558acd94196  (/opt/conda/bin/python3.6+0x117196)
Aug 20 03:47:46     #11 0x7feb0668f83f  (/lib/x86_64-linux-gnu/libc.so.6+0x2083f)
Aug 20 03:47:46     #12 0x5558ace2433d  (/opt/conda/bin/python3.6+0x1a733d)
Aug 20 03:47:46 
Aug 20 03:47:46 SUMMARY: UndefinedBehaviorSanitizer: undefined-behavior /var/lib/jenkins/workspace/aten/src/ATen/Utils.cpp:20:3 in 
Aug 20 03:47:46 + retcode=1
Aug 20 03:47:46 + set -e
Aug 20 03:47:46 + return 1
Aug 20 03:47:46 + [[ pytorch-linux-xenial-py3-clang7-asan-test1 == *-NO_AVX-* ]]
Aug 20 03:47:46 + [[ '' == \n\o\g\p\u\_\N\O\_\A\V\X ]]
Aug 20 03:47:46 + [[ pytorch-linux-xenial-py3-clang7-asan-test1 == *-NO_AVX2-* ]]
Aug 20 03:47:46 + [[ '' == \n\o\g\p\u\_\N\O\_\A\V\X\2 ]]
Aug 20 03:47:46 + [[ pytorch-linux-xenial-py3-clang7-asan-test1 == *-NO_AVX512-* ]]
Aug 20 03:47:46 + [[ '' == \n\o\g\p\u\_\N\O\_\A\V\X\5\1\2 ]]
Aug 20 03:47:46 + '[' -n https://github.com/pytorch/pytorch/pull/63386 ']'

See CircleCI build pytorch_xla_linux_bionic_py3_6_clang9_build (4/4)

Step: "(Optional) Merge target branch" (full log | diagnosis details | 🔁 rerun)

Automatic merge failed; fix conflicts and then commit the result.
+ git merge --allow-unrelated-histories --no-edit --no-ff 0a66d5b3253fd2d2304f3897526db3c8fb139376
Auto-merging torch/testing/_internal/common_methods_invocations.py
Auto-merging torch/quantization/qconfig.py
CONFLICT (content): Merge conflict in torch/quantization/qconfig.py
Auto-merging torch/quantization/fx/prepare.py
CONFLICT (content): Merge conflict in torch/quantization/fx/prepare.py
Auto-merging torch/csrc/jit/tensorexpr/loopnest.cpp
Auto-merging test/test_jit.py
Auto-merging test/test_fx.py
Auto-merging test/cpp/tensorexpr/test_loopnest.cpp
Automatic merge failed; fix conflicts and then commit the result.


Exited with code exit status 1


1 job timed out:

  • pytorch_linux_xenial_py3_clang7_asan_test1

ci.pytorch.org: 1 failed


This comment was automatically generated by Dr. CI.

Please report bugs/suggestions to the (internal) Dr. CI Users group.

@bertmaher (Contributor Author):
@bertmaher has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

bertmaher added a commit that referenced this pull request Aug 17, 2021
Pull Request resolved: #63386

Differential Revision: [D30360382](https://our.internmc.facebook.com/intern/diff/D30360382/)
ghstack-source-id: 136005946
@bertmaher bertmaher requested review from ZolotukhinM, huiguoo and navahgar and removed request for ZolotukhinM August 18, 2021 20:24
bertmaher added a commit that referenced this pull request Aug 18, 2021
@bertmaher (Contributor Author):
@bertmaher has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@bertmaher (Contributor Author):
@bertmaher has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

for (int64_t i = loops.size(); i > 0; i--) {
  auto const& loop = loops[i - 1];
  if (auto stop = to<IntImm>(loop->stop())) {
    grainSize *= stop->value();
Contributor:

This is assuming the loops are normalized at this point. While that might mostly be true, I'm not sure it's guaranteed.

Contributor Author:

I don't think anything's likely to un-normalize the loops before this point, but maybe I should simplify (and do stop-start) just to be safe. I guess the worst that happens is we miss a parallelization opportunity.
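For illustration, a minimal sketch of the "stop-start" variant mentioned above. It reuses the names from the hunk (to<IntImm>, grainSize) and assumes loop->start() exists alongside loop->stop(), as is typical for NNC's For nodes; this is not the PR's actual change:

// Illustration only: compute the trip count as (stop - start) so the
// grain size is correct even if a loop isn't normalized to start at 0.
for (int64_t i = loops.size(); i > 0; i--) {
  auto const& loop = loops[i - 1];
  auto start = to<IntImm>(loop->start());
  auto stop = to<IntImm>(loop->stop());
  if (start && stop) {
    grainSize *= stop->value() - start->value();
  }
}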

template <typename Bufs>
static void parallelizeOuterLoops(LoopNest& l, Bufs&& bufs) {
  for (auto const& buf : bufs) {
    auto loops = l.getLoopStmtsFor(buf);
Contributor:

Since this function is called after fuseAllLoops, it is possible that multiple buffers belong to the same loopnest. So we could be repeating this loop multiple times for the same loopnest. I understand that may not be incorrect at this point, but it could lead to bugs in the future.

IMO, we shouldn't be looking at output buffers and their loopnests. Instead, we should just take root_stmt in the given LoopNest and apply parallelization for all loopnests in that stmt. Wdyt?

Contributor Author:

That seems OK to me, sure.

Contributor Author:

I guess where things could get a little weird is if multiple buffers are updated at different levels of the loopnest, e.g.:

for i:
  y1[] = ...
  for j:
    y2[] = ...

The current approach sort of gives each buffer an "independent" chance to affect the loop parallelization. That doesn't seem terrible, tbh.

Contributor Author:

Actually the more I think about it, is there really any advantage to starting with the root stmt and working down? From my POV it just makes the code a lot more complicated; the way things happen now I just get a nice vector of loops leading to a buffer and try to flatten them. If it's not flattenable, it simply fails and I give up.

Although, maybe it's not too much work to build up my own vector starting from the root. Idk.

Contributor:

If you are okay with having the same set of loops being handled here for different bufs, then I have no objections to it.

Personally, I felt starting from the root_stmt might be better. We might need another API to extract all loops in the root_stmt, so maybe we can do this in the future.
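A rough sketch of what the root_stmt-based alternative could look like; collectRootLoops here stands in for the "another API" mentioned above and is hypothetical, not something that exists in this PR:

// Hypothetical sketch: walk every top-level loopnest in the root stmt once,
// instead of revisiting the same loopnest once per output buffer.
static void parallelizeOuterLoops(LoopNest& l) {
  // collectRootLoops (assumed) would return the top-level For loops
  // of the root block.
  for (auto const& loop : collectRootLoops(l.root_stmt())) {
    // Flatten `loop` together with its perfectly nested children, then
    // mark the flattened loop parallel if flattening succeeded.
  }
}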

Contributor:

If I understand correctly, going through all buffers would not miss parallelism opportunities like the following:

for i:
  for j1:
    y1 = ...  [data dependence exists between iterations]
  for j2:
    y2 = ...  [no data dependence between iterations]

i+j1 cannot be parallelized because there is a data dependence between iterations for y1, but i+j2 can be parallelized and we should not miss it. If this is the thing we're trying to do here, I guess we need a distribute transformation before flatten; flatten currently only handles perfectly nested loops.

Contributor Author:

Yeah the approach here will definitely miss opportunities where some nested loops are parallelizable. It's really kind of a best-effort thing to get simple elementwise fusions right, not a general solution to parallelism.
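As a sketch of the distribute-before-flatten idea from the previous comment: NNC has a LoopNest::distributeLoop transformation, but the exact workflow below is an assumption for illustration, not this PR's code:

// Illustration only: distribute the imperfect nest
//   for i { for j1 {...} for j2 {...} }
// into two perfect nests
//   for i { for j1 {...} }   and   for i { for j2 {...} }
// so the i+j2 nest can still be flattened and parallelized even though
// the i+j1 nest carries a cross-iteration dependence.
auto distributed = LoopNest::distributeLoop(loops.front());
for (auto const& nest : distributed) {
  // Flatten and parallelize each distributed nest independently.
}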

      continue;
    }
    // Try to flatten the outer loops and parallelize them if successful.
    For* flattened = nullptr;

Contributor:

Nit: ForPtr please ;)
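For context, the nit asks for NNC's pointer alias for IR nodes rather than a raw pointer; the alias spelling in the comment below is an assumption:

ForPtr flattened = nullptr;  // preferred; e.g. using ForPtr = std::shared_ptr<For> (assumed)
// rather than the raw pointer in the hunk above:
// For* flattened = nullptr;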

@navahgar (Contributor) left a comment:

LGTM

        callee(index, packed_data);
      }
    });
  } catch (...) {
Contributor:

Kinda curious about this place: why not terminate if there's an exception? I guess the execution of the remaining stmts would ultimately lead to wrong results?

Contributor Author:

Interesting point... if an exception happens here things are really screwed up, because we don't know how to unwind past llvm-generated frames. But no exceptions should be possible here, since we're just parallel-dispatching to our own kernel, which doesn't throw exceptions. So I was mainly putting the try-catch here to ensure that the compiler knew that it wouldn't need to unwind this frame.
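A self-contained illustration of the pattern described above; the function name and signature are assumptions about shape, not the PR's exact code. It shows a catch-all at the boundary between C++ and JIT-generated frames that fails fast instead of unwinding:

#include <cstdint>
#include <cstdlib>

// Entry point of the kind invoked from JIT-generated code. The callback is
// our own kernel and should never throw; the catch-all marks this frame as
// a no-unwind boundary, and any unexpected exception aborts rather than
// trying to unwind into LLVM-generated frames.
extern "C" void dispatch_entry(void (*callee)(int64_t, void*),
                               int64_t index,
                               void* packed_data) {
  try {
    callee(index, packed_data);
  } catch (...) {
    std::abort();  // cannot unwind past JIT frames; fail fast
  }
}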

@bertmaher (Contributor Author):
@bertmaher has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor):
@bertmaher merged this pull request in d6d86ef.

@facebook-github-bot (Contributor):
This pull request has been reverted by 37d60c0.

@facebook-github-bot facebook-github-bot deleted the gh/bertmaher/149/head branch August 24, 2021 14:16

Labels

cla signed · Merged · oncall: jit (Add this issue/PR to JIT oncall triage queue) · Reverted
