Get more fusion after autodiff uses SumToSize #14957

Open · wants to merge 34 commits into base: master

Conversation

@t-vi
Contributor

t-vi commented Dec 9, 2018

Here is a fresh attempt at getting some fusion back in autodiff-generated graphs in the presence of SumToSize.

  • The sum-to-size operator is now aten::_grad_sum_to_size, to allow symbolic script differentiation (which in turn needs to use it in place of sum_to_size to signal that it strictly operates on gradients). It is also used in the autodiff code, replacing prim::SumToSize.
  • _grad_sum_to_size is now fusable. Cats, which are fused afterwards thanks to Adam's simplification of the code, are only fused if there is no _grad_sum_to_size in the fusion group.
  • I push the _grad_sum_to_size out of the fusion group when compiling and record the desired summations in the KernelSpec. The reasoning is the following:
    • As autodiff is a repeated application of the chain rule, we always have the pattern grad_in = mm(A, grad_out), with A often diagonal for cases interesting to the fuser, whence it is grad_in = a * grad_out (a pointwise multiplication). We know that only grad_out may have AutodiffGradSumToSize applied, so we can commute AutodiffGradSumToSize with the mul (and div and neg are of similar origin); see the numeric sketch after this list.
    • For type_as, the gradient might be giving the type, so we just skip SumToSize there.
    • add (inserted as prim::AutogradAdd) adds gradients when the forward used the same value in several places. It is non-broadcasting, so we know that the two arguments have the same sizes as the inputs, which is convenient because we don't have to do bookkeeping of the two parts.
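
To illustrate the commutation relied on above: summing a gradient to size commutes with pointwise multiplication by the local derivative, so _grad_sum_to_size can be applied after the fused pointwise computation. A minimal numeric check (a sketch, not code from this PR; the shapes are arbitrary):

import torch

a = torch.randn(5, 1)             # local derivative, already input-shaped
grad_out = torch.randn(5, 4)      # broadcast-shaped incoming gradient
target = a.shape                  # size the gradient must be summed back to

lhs = (a * grad_out).sum_to_size(*target)  # sum applied after the mul (fused form)
rhs = a * grad_out.sum_to_size(*target)    # sum commuted past the mul (original form)
assert torch.allclose(lhs, rhs)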

Details:

  • During fusion, the Tensor arguments are always kept as the first parameters of the fusion group to accommodate indexing assumptions in the fuser.
  • The rewriting of the fusion group to record the necessary output transformation and eliminate _grad_sum_to_size from the fusion group is now in the fuser compile step.
  • In the execution step, the arguments are split into Tensor / non-Tensor, and the non-Tensor args are mostly forgotten about, except for doing sum_to_size at the end (see the sketch after this list). This will want to be improved if/when we fuse non-constant scalar arguments.
  • In a number of places in the fuser, the non-Tensor arguments to the fusion group needed to be ignored.
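
As a rough illustration of the execution-step handling (a Python sketch with hypothetical helper names, not the PR's C++ code): the Tensor arguments feed the fused kernel, and the recorded sizes are only consulted afterwards to sum each raw output to size.

import torch

def run_with_sum_to_size(kernel, args, output_sizes):
    # Tensor arguments go to the fused kernel; non-Tensor args are dropped here.
    tensor_args = [a for a in args if isinstance(a, torch.Tensor)]
    raw_outputs = kernel(*tensor_args)
    # Apply the recorded sum-to-size (if any) to each raw kernel output.
    return [out if size is None else out.sum_to_size(*size)
            for out, size in zip(raw_outputs, output_sizes)]

# Stand-in for a fused pointwise kernel with a single output.
fused = lambda a, g: (a * g,)
grads = run_with_sum_to_size(fused, [torch.randn(5, 4), torch.randn(5, 4), 3], [(5, 1)])
print(grads[0].shape)  # torch.Size([5, 1])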

Thank you, @apaszke for the insightful discussion. All bad ideas and errors are my own.

@apaszke
Member

apaszke left a comment

I think the removal of SumToSize nodes is happening way too early. Basically, you shouldn't think of FusionGroups as graphs that have already been fused and will conform to those semantics, but as graphs eligible to be fused. That means that we still want to preserve the original semantics of the code, because it might turn out that our fusion guesses were wrong, and we will end up running a deoptimized version of the original code. Instead, we should allow putting them in FusionGroups and simply remove them right before a kernel is compiled (once we know that the fusion is valid, etc.).

Finally, marking those nodes as fusible is a bad idea, because the only reason to put them in a fusion group is that you are certain it will help you perform more fusions. That should be checked and processed similarly to how we deal with rearranging chunk nodes.

@t-vi

Contributor

t-vi commented Dec 10, 2018

Thanks for your comments Adam!

  • I'll rename the prim::GradSumToSize.
  • So I'll move the graph rewriting into the fuser codegen.
  • For "when to fuse SumToSize", would it be OK to put them into the fusion group if there aren't any FusedConcat nodes in there?
    This would mean that we might end up with SumToSize at the top of the fusion group, which we would undo before the fusion, but I'm a bit wary that GraphFuser.run will get considerably more complicated if we split out the scan phase as done for chunk.
@t-vi

Contributor

t-vi commented Dec 13, 2018

Hmh. I need to rebase.
So I think I'm not fusing sumtosize any more when concat is close.
I'm not as sure about the "when to relocate sumtosize": If I move that into kernel generation, is it still safe to move the sumtosize to outside the fusion group? I'll try that next, but I'm still a bit sceptical about it.

apaszke and others added some commits Dec 11, 2018

Add support for batch_norm fusion to the JIT
We don't support reductions yet, but simply decomposing batch_norm
into a kernel that computes the stats, and then fusing everything else
with ReLU and following pointwise ops, provides nice speedups.

Note that this is only limited to inference mode for now, because we
don't support convolutions and batch norm in AD, so the fuser isn't
applied to those parts.
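
A sketch of the inference-mode pattern this targets (an assumed example, not taken from the commit): batch_norm followed by ReLU and further pointwise ops, which can be fused once the stats are computed in a separate kernel.

import torch
import torch.nn.functional as F

@torch.jit.script
def bn_relu_scale(x, weight, bias, mean, var):
    # Inference-mode batch norm (training=False), followed by pointwise ops.
    y = F.batch_norm(x, mean, var, weight, bias, False, 0.1, 1e-5)
    return torch.relu(y) * 2.0
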
Separate aten::cat fusion from pointwise fusion
That makes the definition of a "fusable node" much simpler,
as we don't need to keep considering whether something has to be an
"exit node" at every step. The fuser now tries to maximize the
pointwise fusions first, and proceeds to prepending chunks and appending
concats only once a fixed point is reached.

This patch not only makes the fuser much simpler to reason about, it also
makes it significantly easier to implement features like SumToSize
fusion, which improves the performance of derivative graphs.
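
For instance, a function shaped like the following (an assumed example, not from the commit) now has its pointwise ops fused first, and the aten::cat consuming them is appended to the fusion group afterwards:

import torch

@torch.jit.script
def pointwise_then_cat(x, y):
    # Pointwise producers first; the concat consuming them is appended later.
    return torch.cat([x * 2.0 + 1.0, torch.sigmoid(y)], dim=0)
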
@t-vi

Contributor

t-vi commented Jan 10, 2019

So now I get a fully fused backward for the IoU-loss again.
The PR is based on #15633 (simplify cat fusion), which in turn depends on #15897 (batch norm fusion). The latter was merged and reverted yesterday.

I'm not quite sure how to rearrange the PR commits to facilitate review; if you have a preference, let me know. I'm leaning towards squashing everything (beyond #15633) into a large commit.

I'll edit the start comment to describe the PR a bit more.

@t-vi changed the title from "[WIP] Get more fusion after autodiff fix (SumToSize)" to "Get more fusion after autodiff uses SumToSize" on Jan 10, 2019

@@ -339,6 +344,12 @@ bool runFusion(const int64_t key, Stack& stack) {
std::vector<at::Tensor> outputs;
launchFusion(*(*maybe_kernel), device, inputs, outputs);

for (size_t i = 0; i < outputs.size(); i++) {

@ngimel

ngimel Jan 15, 2019

Contributor

I don't know if it affects perf at all, but at this point with runtime shapes information you already know when in fact there are no reductions required, so this at::sum_to can be skipped altogether.

@t-vi

t-vi Jan 15, 2019

Contributor

Thanks!

I would hope that sum_to is fast if there isn't anything to be done, even if it compares number of tensors * average dimension sizes.

That said, what would be the condition to skip the entire thing?

@ngimel

ngimel Jan 15, 2019

Contributor

It should be fast; it's just that I've been seeing cases lately where the CPU cannot get ahead of the GPU and all the 20-50 us latencies arising even in cheap ops are exposed. I think a sufficient condition is that no argument was expanded: https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/fuser/executor.cpp#L110
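
In other words (a sketch of the suggested fast path with hypothetical names, not code from this PR): if no kernel input had to be expanded at launch time, the outputs already have the requested sizes and the per-output sum_to can be skipped.

import torch

def maybe_sum_to_size(output, target_size, any_input_expanded):
    # Fast path: nothing was broadcast, so the output already has target_size.
    if not any_input_expanded:
        return output
    return output.sum_to_size(*target_size)

out = maybe_sum_to_size(torch.randn(5, 4), (5, 1), any_input_expanded=True)
print(out.shape)  # torch.Size([5, 1])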

@ngimel

Contributor

ngimel commented Jan 17, 2019

@t-vi, with this PR applied, the generated backward kernel for the following function

def fn1(x,y,z):
    a = x+y+z     
    return torch.sigmoid(a)

separately materializes the gradients for x, y, and z:

      float n0 = t0.data[t0_offset];
      float n1 = t1.data[t1_offset];
      int n5 = 1;
      float n6 = -n1;
      float n7 = n6 + ((float) n5)*((float) n5);
      float n8 = n7 * n1;
      float n9 = n8 * n0;
      t2.data[t2_offset] = n9;
      t3.data[t3_offset] = n9;
      t4.data[t4_offset] = n9;

In current master, without this PR, the fusion group has a single output, which later serves as an input to 3 SumToSize ops.

@t-vi

Contributor

t-vi commented Jan 17, 2019

Thanks, @ngimel, for raising this. I'll see how to fix that. My understanding is that we need to deduplicate the outputs for the kernel, but not for the _grad_sum_to_size application after running the fused kernel. That in turn means we have different outputs for the fused kernel vs. the fusion group.
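
A sketch of that idea (hypothetical names, not the PR's implementation): the fused kernel keeps one copy of each distinct output, and the executor maps the fusion-group outputs back to kernel outputs, applying the recorded sum-to-size per original output.

import torch

grad_a = torch.randn(4, 3)              # single deduplicated kernel output
kernel_outputs = [grad_a]
# (kernel output index, target size) per fusion-group output, recorded when
# _grad_sum_to_size was squeezed out of the group
output_map = [(0, (4, 3)), (0, (3,)), (0, (1, 3))]

grad_x, grad_y, grad_z = [kernel_outputs[i].sum_to_size(*size)
                          for i, size in output_map]
print(grad_x.shape, grad_y.shape, grad_z.shape)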

t-vi added some commits Jan 19, 2019

Deduplicate fusion kernel outputs
After squeezing out the _grad_sum_to_size during kernel compilation,
we may end up with duplicate outputs.
For example, the backward of
    def fn1(x,y,z):
        a = x+y+z
        return torch.sigmoid(a)
has that.
Thank you @ngimel for noting and providing the example!
@t-vi

Contributor

t-vi commented Jan 19, 2019

So I added the output deduplication in the fuser and a test using @ngimel's example (thanks again!).
