Boxed cpu fallback #2945
Conversation
XLA_FN_TRACK(3);
const auto name = c10::toString(op.operator_name());

// Manually applying the XLA_COUNTER macro.
So, after staring at some failing C++ tests for a while, I finally realized that some tests were failing because of my usage of the XLA_COUNTER macro here :(

It's defined in the XLA repo here, and it defines a local static counter that's unique to the name you pass in. That means it silently does the wrong thing if you call the macro with different operator names from the same piece of source code: if I call xla_cpu_fallback() once with add, then with mul, the counter for add gets incremented twice.

Maybe it's worth a patch to XLA to add a different version of the macro? For now, I just hardcoded what the macro does here, with a global mapping of counters for every op that the CPU fallback is called with.
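Roughly, the hand-rolled bookkeeping is a sketch like this (the helper name is made up, and I'm assuming the same ::xla::metrics::Counter constructor / AddValue() calls that the XLA_COUNTER macro expands to):

```cpp
#include <string>
#include <unordered_map>
// Metrics header as used elsewhere in pytorch/xla (path assumed).
#include "tensorflow/compiler/xla/xla_client/metrics.h"

// One map of per-op counters keyed by operator name, so calling the fallback
// with different ops bumps different counters.
std::unordered_map<std::string, ::xla::metrics::Counter*> _cpu_fallback_counters;

void bump_cpu_fallback_counter(const std::string& name) {
  auto it = _cpu_fallback_counters.find(name);
  if (it == _cpu_fallback_counters.end()) {
    // Lazily create a counter the first time we see this operator name.
    it = _cpu_fallback_counters
             .emplace(name, new ::xla::metrics::Counter(name))
             .first;
  }
  it->second->AddValue(1);
}
```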
nice catch! XLA_COUNTER is designed for being used with different names; I think it is fine to leave it as it is.
Mostly LGTM, some minor comments.
torch_xla/csrc/aten_cpu_fallback.cpp (Outdated)

namespace torch_xla {

std::unordered_map<std::string, ::xla::metrics::Counter*>
I think we should make this map static since we don't expect code outside of this file to access it. wdyt?
good catch, static sounds good
BTW, reading the boxing doc you shared, I have a question: are all pytorch/xla ops boxed kernels?
Nope, all of the pytorch/xla kernels (lowerings) are unboxed kernels, since each kernel is specialized for a specific operator. The main advantage of a boxed kernel is that you can write it once and it's supposed to work for all operators. It does that by having a very specific schema (sketched below). This CPU fallback is actually one of the first main usages of a boxed fallback kernel, but we do have a couple of others: we have a boxed fallback for batching in-tree, and there are some ongoing features being developed that use boxed fallbacks: conjugation of complex tensors, and dispatching to Python.
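For reference, every boxed kernel is just a function with this one fixed signature (the function name here is made up):

```cpp
// A boxed kernel receives the operator handle plus a stack of boxed IValues,
// so one function body can service any operator, instead of taking a typed
// argument list the way an unboxed kernel does.
void my_boxed_kernel(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  // Inspect `op` and the arguments on `stack` generically,
  // then push the results back onto `stack`.
}
```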
Thanks @bdhirsh !
FYI: if https://github.com/pytorch/xla/pull/2936/files#diff-5e65c3c1d847191cb691d1874732e971f09fa1aad7a980a555c3b0504a5b6470R2454 merges first (seems like it might, since the pytorch PR is ready), you will need to fix the fallback call here.
Looks great! Thanks @bdhirsh !
}
}

// Call the actual boxed CPU fallback.
It would be nice to note here which device the tensors/ivalues are on. (I assume it's still XLA?)
Yup, they should all be XLA (although technically, if we had a meta function that allowed mixed device inputs, they could be mixed).

I could tack it onto TF_VLOG(3) << ivalue.toTensor().toString(); if you think that would be useful. Is TF_VLOG(3) the right macro to use for general-purpose XLA logging?
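Something like this sketch inside the fallback's argument loop (the loop structure here is just illustrative):

```cpp
for (const c10::IValue& ivalue : *stack) {
  if (ivalue.isTensor()) {
    // Log which device each tensor argument lives on before falling back to CPU.
    TF_VLOG(3) << "CPU fallback arg device: " << ivalue.toTensor().device();
  }
}
```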
This PR updates pytorch/xla to use a boxed kernel to implement the CPU fallback, rather than relying on code generation (see the pytorch wiki on boxing/unboxing).
Summary
In the corresponding PyTorch-side PR, I rewrote the CPU fallbacks for XLA to use a single boxed kernel, instead of code-generating individual CPU fallback kernels for every operator.
This lets us kill a bunch of codegen logic in PyTorch, and simplifies a lot of the codegen, but it means that pytorch/XLA needs to do slightly more work in order to get access to CPU fallback kernels. I added some convenience helper functions in pytorch core to make the amount of extra work minimal.
Registering the CPU Fallback kernel
In torch_xla/csrc/aten_cpu_fallback.cpp, I added a boxed CPU fallback kernel with XLA-specific logging, to preserve the same logging behavior that we had before. The kernel is just a function with signature void (const c10::OperatorHandle&, torch::jit::Stack*) that logs some information using XLA macros and then calls into the actual CPU fallback kernel implemented in PyTorch core (at::native::cpu_fallback). I then register that fallback to the dispatcher under the XLA key; a sketch is below.

It technically would have been possible to have the codegen do all of that for you (the boxed kernel logging + dispatcher registration), but the logging is all XLA-specific and seems more reasonable to write directly in pytorch/XLA.
Calling the CPU Fallback
There are also a bunch of places where pytorch/XLA has to explicitly call the CPU fallback, depending on e.g. the op's input shapes. When each operator's CPU fallback was code-generated, we used to call the fallback like this:
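A sketch of the old pattern (the exact generated class and function names came from the codegen'd aten_xla_type_default files, so treat the names here as illustrative):

```cpp
// Old: each op had its own code-generated CPU fallback entry point.
return AtenXlaTypeDefault::add(self, other, alpha);
```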
Now, we need to call into the boxed kernel. I added a convenience helper function to make calling into it easier, which looks like this:
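A sketch, assuming the call_fallback_fn / ATEN_OP2 helpers added on the PyTorch side:

```cpp
// New: route through the single boxed fallback kernel via the helper.
return at::native::call_fallback_fn<&xla_cpu_fallback, ATEN_OP2(add, Tensor)>::call(
    self, other, alpha);
```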
Where xla_cpu_fallback is the boxed fallback kernel with XLA-specific logging, and ATEN_OP2 is a new macro that provides the helper function with all of the metadata it needs to infer some extra information and call the boxed fallback.

Performance
It's also worth calling out perf, since the boxed fallback is a little slower than the unboxed, code-generated CPU fallback kernels. I put a more detailed analysis at the bottom of the description of pytorch/pytorch#58065, but the boxed fallback looks like it's on the order of 10-20% slower. I'm hoping that this isn't a huge concern, since we probably want to write XLA lowerings for operators that are perf-critical.