
Conversation

@nickgg (Contributor) commented Aug 11, 2020

Insert the registerizer into the Cuda Codegen pass list, to enable scalar replacement and close the gap in simple reduction performance.

First up, the good stuff. Benchmark before:

          Column sum          Caffe2             NNC          Simple          Better
           (10, 100)          5.7917          9.7037          6.9386          6.0448
          (100, 100)          5.9338          14.972          7.1139          6.3254
        (100, 10000)          21.453          741.54          145.74          12.555
        (1000, 1000)          8.0678          122.75          22.833          9.0778

             Row sum          Caffe2             NNC          Simple          Better
           (10, 100)          5.4502          7.9661          6.1469          5.5587
          (100, 100)          5.7613          13.897           21.49          5.5808
        (100, 10000)          21.702          82.398          75.462          22.793
        (1000, 1000)          22.527             129          176.51          22.517

After:

          Column sum          Caffe2             NNC          Simple          Better
           (10, 100)          6.0458          9.4966          7.1094           6.056
          (100, 100)          5.9299          9.1482          7.1693           6.593
        (100, 10000)          21.739          121.97          162.63          14.376
        (1000, 1000)          9.2374           29.01          26.883          10.127

             Row sum          Caffe2             NNC          Simple          Better
           (10, 100)          5.9773          8.1792          7.2307          5.8941
          (100, 100)          6.1456          9.3155          24.563          5.8163
        (100, 10000)          25.384          30.212          88.531          27.185
        (1000, 1000)          26.517          32.702          209.31          26.537

The speedup is about 3-8x depending on the size of the data (increasing with bigger inputs).

The gap between NNC and Simple is closed or reversed; the remaining issue appears to be kernel launch overhead. Next up is getting us closer to the Better kernel.

It required a lot of refactoring and bug fixing along the way:

  • Refactored flattening of parallelized loops out of the CudaPrinter and into its own stage, so we can transform the graph in the stage between flattening and printing (where registerization occurs).
  • Made AtomicAddFuser less pessimistic: it now recognizes that if an Add to a buffer is dependent on all used Block and Thread vars then it has no overlap and does not need to be atomic. This allows registerization to apply to these stores.
  • Fixed PrioritizeLoad mutator so that it does not attempt to separate the Store and Load to the same buffer (i.e. reduction case).
  • Moved CudaAnalysis earlier in the process, allowing later stages to use the analyzed bufs.
  • Fixed a bug in the Registerizer where when adding a default initializer statement it would use the dtype of the underlying var (which is always kHandle) instead of the dtype of the Buf.
  • Fixed a bug in the IRMutator where the logic for Allocate statements was inverted: they were replaced only if they did not change.
  • Added simplification of simple Division patterns to the IRSimplifier.
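For background, the core transformation the registerizer performs is classic scalar replacement. A minimal hand-written sketch of the before/after shapes, in plain C++ with illustrative names (not the actual NNC IR):

```cpp
#include <cstddef>
#include <vector>

// Before registerization: every iteration loads and stores out[0].
// In a generated Cuda kernel that is a global-memory round trip per
// iteration.
float sum_unregisterized(const std::vector<float>& a) {
    float out[1] = {0.f};
    for (std::size_t i = 0; i < a.size(); ++i) {
        out[0] = out[0] + a[i];  // buffer load + store each iteration
    }
    return out[0];
}

// After registerization: the repeated buffer access is replaced by a
// scalar that can live in a register, with a single store at the end.
float sum_registerized(const std::vector<float>& a) {
    float out[1];
    float out_reg = 0.f;  // default initializer for the scalar
    for (std::size_t i = 0; i < a.size(); ++i) {
        out_reg = out_reg + a[i];  // stays in a register
    }
    out[0] = out_reg;  // single write back to the buffer
    return out[0];
}
```

Both compute the same result; the second form is what scalar replacement aims to produce.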
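The AtomicAddFuser change boils down to a coverage check over the parallel axes. A hypothetical sketch of the rule (`needsAtomicAdd` and the string-set representation are illustrative, not the real NNC analysis):

```cpp
#include <set>
#include <string>

// A store by many Cuda threads needs atomicAdd only if two threads can
// hit the same element. If the store's index expression depends on
// every Block/Thread var the kernel is parallelized over, each thread
// writes a distinct element and a plain add is safe.
bool needsAtomicAdd(const std::set<std::string>& indexVars,
                    const std::set<std::string>& launchVars) {
    for (const auto& v : launchVars) {
        if (indexVars.count(v) == 0) {
            return true;  // a parallel axis missing from the index: threads can collide
        }
    }
    return false;  // index covers every parallel axis: unique destination per thread
}
```

For example, a store indexed by blockIdx.x alone in a kernel launched over blockIdx.x and threadIdx.x still needs the atomic; a store indexed by both does not, and so remains eligible for registerization.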
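The summary does not list which Division patterns the IRSimplifier gained, so here is a hedged sketch of the usual simple rules, over a toy expression type (`Expr`, `simplifyDiv` and the rule set are assumptions, not taken from the diff):

```cpp
#include <optional>
#include <string>

// Toy expression: an integer constant or a named variable.
struct Expr {
    bool isConst;
    long value;        // meaningful when isConst
    std::string name;  // meaningful when !isConst
};
static Expr C(long v) { return {true, v, ""}; }
static Expr V(const std::string& n) { return {false, 0, n}; }

// Simple division patterns (assumed):
//   0 / x   -> 0   (assuming x != 0; division by zero is UB anyway)
//   x / 1   -> x
//   x / x   -> 1   (same variable)
//   c1 / c2 -> folded constant
std::optional<Expr> simplifyDiv(const Expr& lhs, const Expr& rhs) {
    if (lhs.isConst && lhs.value == 0) return C(0);
    if (rhs.isConst && rhs.value == 1) return lhs;
    if (!lhs.isConst && !rhs.isConst && lhs.name == rhs.name) return C(1);
    if (lhs.isConst && rhs.isConst && rhs.value != 0)
        return C(lhs.value / rhs.value);
    return std::nullopt;  // no simple pattern applies; keep the Div node
}
```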

@nickgg requested a review from apaszke as a code owner August 11, 2020 20:19
@facebook-github-bot added the "oncall: jit" (Add this issue/PR to JIT oncall triage queue) label Aug 11, 2020
@nickgg requested review from bertmaher and zheng-xq August 11, 2020 20:20
dr-ci bot commented Aug 11, 2020

💊 CI failures summary and remediations

As of commit e7690af (more details on the Dr. CI page):


  • 2/2 failures possibly* introduced in this PR
    • 2/2 non-CircleCI failure(s)

Extra GitHub checks: 1 failed


codecov.io: 1 failed



A reviewer (Contributor) commented:

Minor: how about the case where thread_idx is outside block_idx? It is okay to ignore this for now. But maybe a TODO to remind ourselves of this case down the road.

for i in (0..10):      # threadIdx
   for j in (0..100):  # blockIdx
      t1 = alloc(1)
      ...

@nickgg (Author) replied:

Good point, this is an important case. I'll come back to it.

A reviewer (Contributor) commented:

Minor: Maybe a TODO to handle the case where the size of the shared memory is dynamic. Then the dynamic size needs to be provided at the kernel launch time.

A reviewer (Contributor) commented:

Any reason why "metavars" is not a const ref, similar to "thread_local_bufs"? The most common reason for a const pointer is that it could be nullptr, but it doesn't seem to be the case here.

@nickgg (Author) replied:

The only reason was to match the local style of the file: storing as a unique_ptr and passing as a pointer (see CudaAnalysis and its use by CudaPrinter).

NBD to change this one, but I wouldn't want to fold a general cuda codegen cleanup into this PR.

A reviewer (Contributor) commented:

Minor: style-wise, I mildly prefer "std::vector<const Expr*> block_extents = ..." over "const ... &" here. With return-value optimization, both do the same thing. But the first makes it easier to tell the scope of the temporary object, while the second relies on a special rule in C++ to extend its lifetime. Both are correct; some people might just trip over why the 2nd form is correct.

Either way is fine with me, if you are more comfortable with that style.
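For readers following along, the two forms being compared can be sketched as follows (`makeExtents` is a stand-in for whatever returns the extents by value):

```cpp
#include <vector>

std::vector<int> makeExtents() { return {1, 2, 3}; }

bool stylesAgree() {
    // Style 1: plain value. Copy elision / RVO means no extra copy,
    // and the object's lifetime is the obvious block scope.
    std::vector<int> by_value = makeExtents();

    // Style 2: const ref bound to the returned temporary. Also correct,
    // but only because C++ extends the temporary's lifetime to match
    // the reference's scope -- the special rule mentioned above.
    const std::vector<int>& by_ref = makeExtents();

    return by_value == by_ref;
}
```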

@nickgg (Author) replied:

Not particularly bothered either way, I can switch to value.

A reviewer (Contributor) commented:

Minor: this seems to be idiomatic in our code base now. Maybe we should add an overloaded "immediateEquals" that does this?
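One possible shape for the suggested overload, using toy stand-ins for the IR classes (the real Expr/IntImm live in torch/csrc/jit/tensorexpr, and the existing immediateEquals API isn't shown in this thread, so treat the signature as a guess):

```cpp
// Toy stand-ins for the NNC IR node classes.
struct Expr {
    virtual ~Expr() = default;
};
struct IntImm : Expr {
    explicit IntImm(int v) : value(v) {}
    int value;
};

// Suggested overload: fold the repeated "is this an immediate with
// value v" check into one helper instead of open-coding the cast.
bool immediateEquals(const Expr* e, int v) {
    const IntImm* imm = dynamic_cast<const IntImm*>(e);
    return imm != nullptr && imm->value == v;
}
```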

@nickgg (Author) replied:

Sure, makes sense - in a follow up though.

A reviewer (Contributor) commented:

Great feature! This seems so fundamental that it deserves its own directed tests. I would like to see a list of more explicit expressions showing which will become AtomicAdd and which won't.

@nickgg (Author) replied:

You're right, this needs a lot more testing. The whole Cuda Codegen does. We'll have to come back to it.

@nickgg force-pushed the registerizerCuda branch 8 times, most recently from 092de19 to d652db0, on August 27, 2020 21:02
@facebook-github-bot left a comment:

@nickgg has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

codecov bot commented Aug 28, 2020

Codecov Report

Merging #42878 into master will decrease coverage by 0.00%.
The diff coverage is n/a.


@@            Coverage Diff             @@
##           master   #42878      +/-   ##
==========================================
- Coverage   69.41%   69.40%   -0.01%     
==========================================
  Files         378      378              
  Lines       46602    46602              
==========================================
- Hits        32347    32345       -2     
- Misses      14255    14257       +2     
Impacted Files Coverage Δ
torch/testing/_internal/common_utils.py 76.61% <0.00%> (-0.19%) ⬇️

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update 73dcfc5...e7690af.

@facebook-github-bot left a comment:

@nickgg has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@nickgg merged this pull request in 1390cad.
