
Conversation

bdhirsh (Collaborator) commented May 11, 2021

This PR updates pytorch/xla to use a boxed kernel to implement the CPU fallback, rather than relying on code-generation (see the pytorch wiki on boxing/unboxing)

Summary

In the corresponding PyTorch-side PR, I rewrote the CPU fallbacks for XLA to use a single boxed kernel, instead of code-generating individual CPU fallback kernels for every operator.

This lets us kill a bunch of codegen logic in PyTorch and simplifies what remains, but it means that pytorch/XLA needs to do slightly more work to get access to CPU fallback kernels. I added some convenience helper functions in PyTorch core to keep that extra work minimal.

Registering the CPU Fallback kernel
In torch_xla/csrc/aten_cpu_fallback.cpp, I added a boxed CPU fallback kernel with XLA-specific logging, to preserve the same logging behavior that we had before. The kernel is just a function with signature void (const c10::OperatorHandle&, torch::jit::Stack*) that logs some information using XLA macros and then calls into the actual CPU fallback kernel implemented in PyTorch core (at::native::cpu_fallback). I then register that fallback to the dispatcher under the XLA key.
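Roughly, the shape of that file is something like the sketch below (the XLA logging-header includes are omitted and the logging is simplified; TORCH_LIBRARY_IMPL and makeFromBoxedFunction are the standard registration APIs in PyTorch core, not something new in this PR):

```cpp
#include <ATen/native/CPUFallback.h>
#include <torch/library.h>

namespace torch_xla {

// Boxed fallback: a single function that works for any operator.
// It does some XLA-specific logging and then defers to the generic
// CPU fallback implemented in PyTorch core.
void xla_cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  XLA_FN_TRACK(3);
  TF_VLOG(3) << "Falling back to CPU for " << c10::toString(op.operator_name());
  at::native::cpu_fallback(op, stack);
}

// Register the boxed kernel as the fallback for every operator under the XLA key.
TORCH_LIBRARY_IMPL(_, XLA, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&xla_cpu_fallback>());
}

}  // namespace torch_xla
```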

It technically would have been possible to have the codegen do all of that for you (the boxed kernel logging + dispatcher registration), but the logging is all XLA-specific and seems more reasonable to write directly in pytorch/XLA.

Calling the CPU Fallback
There are also a bunch of places where pytorch/XLA has to explicitly call the CPU fallback, depending on e.g. the op's input shapes. When each operator's CPU fallback was code-generated, we used to call the fallback like this:

return AtenXlaTypeDefault::add(a, b);

Now, we need to call into the boxed kernel. I added a convenience helper function to make calling into it easier, which looks like this:

return at::native::call_fallback_fn<&xla_cpu_fallback, ATEN_OP2(add, Tensor)>::call(a, b, 1.0);

where xla_cpu_fallback is the boxed fallback kernel with XLA-specific logging, and ATEN_OP2 is a new macro that provides the helper function with all of the metadata it needs to infer some extra information and call the boxed fallback.
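For illustration, in a hand-written lowering the pattern ends up looking roughly like this (the ShapeIsSupported predicate and LowerAddWithXla function are made up for the example; only call_fallback_fn, ATEN_OP2, and xla_cpu_fallback come from the actual change, and the include path is approximate):

```cpp
#include <ATen/native/CPUFallback.h>

// Hypothetical helpers, for illustration only.
bool ShapeIsSupported(const at::Tensor& a, const at::Tensor& b);
at::Tensor LowerAddWithXla(const at::Tensor& a, const at::Tensor& b,
                           const at::Scalar& alpha);

// The boxed fallback kernel registered in aten_cpu_fallback.cpp.
void xla_cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack);

at::Tensor add(const at::Tensor& self, const at::Tensor& other,
               const at::Scalar& alpha) {
  if (!ShapeIsSupported(self, other)) {
    // Route this particular call through the boxed CPU fallback.
    return at::native::call_fallback_fn<
        &xla_cpu_fallback, ATEN_OP2(add, Tensor)>::call(self, other, alpha);
  }
  return LowerAddWithXla(self, other, alpha);
}
```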

Performance
It's also worth calling out perf, since the boxed fallback is a little slower than the unboxed, code-generated CPU fallback kernels. I put a more detailed analysis at the bottom of the description of pytorch/pytorch#58065, but the boxed fallback looks like it's on the order of 10-20% slower. I'm hoping that this isn't a huge concern, since we probably want to write XLA lowerings for operators that are perf-critical anyway.

bdhirsh force-pushed the boxed_cpu_fallback branch from fabfda5 to 4d55c99 on May 11, 2021 22:45
bdhirsh force-pushed the make_codegen_backend_agnostic_minus_fallbacks branch 2 times, most recently from 6e49287 to fc9f54d on May 12, 2021 14:13
bdhirsh force-pushed the boxed_cpu_fallback branch from 4d55c99 to 4dd71a6 on May 12, 2021 14:13
bdhirsh force-pushed the boxed_cpu_fallback branch 2 times, most recently from bd976a1 to 1085c8f on May 19, 2021 17:59
bdhirsh force-pushed the make_codegen_backend_agnostic_minus_fallbacks branch from c8c7681 to d2ba37d on May 19, 2021 18:52
bdhirsh force-pushed the boxed_cpu_fallback branch from 1085c8f to a8487b1 on May 19, 2021 18:53
bdhirsh force-pushed the boxed_cpu_fallback branch from a8487b1 to 5e3536f on May 19, 2021 23:35
bdhirsh force-pushed the make_codegen_backend_agnostic_minus_fallbacks branch from 49791eb to c223040 on May 24, 2021 21:54
Base automatically changed from make_codegen_backend_agnostic_minus_fallbacks to master on May 26, 2021 20:23
bdhirsh force-pushed the boxed_cpu_fallback branch from 5e3536f to 0002e0a on May 27, 2021 23:30
XLA_FN_TRACK(3);
const auto name = c10::toString(op.operator_name());

// Manually applying the XLA_COUNTER macro.
bdhirsh (Collaborator, Author):
So, after staring at some failing C++ tests for a while, I finally realized that some tests were failing because of my usage of the XLA_COUNTER macro here :(

It's defined in the XLA repo here, and it defines a local static counter that's unique to the name that you pass in. That means it silently does bad things if you call the macro with different operator names from the same piece of source code: if I call xla_cpu_fallback() once with add and then with mul, the counter for add will get incremented twice.

Maybe it's worth a patch to xla to add a different version of the macro? For now, I just hardcoded what the macro does here, with a global mapping of counters for every op that the CPU fallback is called with.
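For reference, the workaround looks roughly like this (a sketch, assuming the Counter interface from the XLA client metrics library; the helper name is made up):

```cpp
// One counter per operator name, created lazily the first time that op hits
// the CPU fallback. This replaces XLA_COUNTER, whose local static counter is
// keyed to the call site rather than to the name passed in.
static std::unordered_map<std::string, ::xla::metrics::Counter*>
    _cpu_fallback_counters;

static void IncrementFallbackCounter(const c10::OperatorHandle& op) {
  auto name = c10::toString(op.operator_name());
  if (_cpu_fallback_counters.find(name) == _cpu_fallback_counters.end()) {
    _cpu_fallback_counters[name] = new ::xla::metrics::Counter(name);
  }
  _cpu_fallback_counters[name]->AddValue(1);
}
```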

Collaborator:

nice catch! XLA_COUNTER is designed for being used with different names, I think it is fine to leave it as it is.

bdhirsh requested review from JackCaoG and ailzhang on June 2, 2021 20:56
JackCaoG (Collaborator) left a comment:

Mostly LGTM, some minor comments.


namespace torch_xla {

std::unordered_map<std::string, ::xla::metrics::Counter*>
Collaborator:
I think we should make this map static since we don't expect code outside of this file to access it. wdyt?

bdhirsh (Collaborator, Author):

good catch, static sounds good

JackCaoG (Collaborator) commented Jun 3, 2021

BTW, reading the boxing doc you shared, I have a question. Are all pytorch/xla ops boxed kernels?

bdhirsh (Collaborator, Author) commented Jun 3, 2021

> BTW, reading the boxing doc you shared, I have a question. Are all pytorch/xla ops boxed kernels?

Nope, all of the pytorch/xla kernels (lowerings) are unboxed kernels, since each kernel is specialized for a specific operator. The main advantage of a boxed kernel is that you can write it once and it's supposed to work for all operators. It does that by having a very specific schema: void (const c10::OperatorHandle&, torch::jit::Stack*), where every operator can be represented as an OperatorHandle in the boxed world, and its arguments are represented as IValue objects that are stored on the jit stack.
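To make the contrast concrete, the two kinds of kernel signatures look like this (the unboxed one is just an example for add):

```cpp
// Unboxed: one kernel per operator, specialized to that operator's C++ signature.
at::Tensor add_xla_lowering(const at::Tensor& self, const at::Tensor& other,
                            const at::Scalar& alpha);

// Boxed: a single kernel that can serve any operator; the operator comes in as
// an OperatorHandle and its arguments arrive as IValues on the JIT stack.
void boxed_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack);
```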

This CPU fallback is actually one of the first main usages of a boxed fallback kernel, but we do have a couple of others: there's a boxed fallback for batching in-tree, and there are some ongoing features being developed that use boxed fallbacks, like conjugation of complex tensors and dispatching to Python.

JackCaoG (Collaborator) left a comment:

Thanks @bdhirsh !

JackCaoG (Collaborator) commented Jun 4, 2021

FYI, if https://github.com/pytorch/xla/pull/2936/files#diff-5e65c3c1d847191cb691d1874732e971f09fa1aad7a980a555c3b0504a5b6470R2454 merges first (seems like it might, since the pytorch PR is ready), you will need to fix the fallback call here.

ailzhang (Contributor) left a comment:

Looks great! Thanks @bdhirsh !

}
}

// Call the actual boxed CPU fallback.
Contributor:

It would be nice to note which device the tensors/IValues are on here (I assume it's still XLA?).

bdhirsh (Collaborator, Author):

Yup, they should all be XLA (although technically, if we had a meta function that allowed mixed-device inputs, they could be mixed).

I could tack it onto TF_VLOG(3) << ivalue.toTensor().toString(); if you think that would be useful. Is TF_VLOG(3) the right macro to use for general-purpose XLA logging?

bdhirsh force-pushed the boxed_cpu_fallback branch from 36b745f to df190b8 on June 10, 2021 23:26
bdhirsh force-pushed the boxed_cpu_fallback branch from df190b8 to 9cfa18e on June 25, 2021 13:36
bdhirsh merged commit cac652c into master on Jun 26, 2021
bdhirsh deleted the boxed_cpu_fallback branch on June 26, 2021 02:35