[RFC] torch_xla2 dynamo integration #7255
Comments
This is something that I said in the May meeting (I think!) but I want to make sure it's durably recorded here: if for some reason you wanted to do CPU/CUDA computation with XLA, I think it would be substantially more idiomatic and natural for the alternate backend to be able to operate directly on traditional CPU/CUDA tensors. This presents a UX tension: an XLATensor which holds onto a jax.Array object is a plausible point in the design space, but when that jax.Array is CPU/CUDA, it's duplicative with traditional CPU/CUDA tensors... but also not really, because if you have a CPU tensor that internally holds a jax.Array, you suddenly get more expressivity thanks to the interoperability with JAX (but not really, e.g., for the point @Chillee raised that is mentioned here). So, this is what I'm hoping to see:
I'm interested in hearing more about the requirements and plan for interop with JAX transforms.
def f(jax_array_1, jax_array_2):
    # wraps jax_array_1, jax_array_2 into XLATensor2
    # calls torch ops on the wrapped tensors
    # returns the unwrapped result
Relevant issue when trying to trace through:
Discussed offline that it's not ideal to require using
Hi @ezyang, totally agree on the 3 points listed. Promoting CUDA tensors to XLATensor seems simple with dlpack. We'll pursue this direction.
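For illustration, here is a minimal sketch of the dlpack route mentioned above. The conversion calls are standard torch/jax dlpack APIs; exactly how the result gets wrapped into an XLATensor is not shown and would live inside torch_xla2:

```python
import torch
import jax.dlpack
from torch.utils.dlpack import to_dlpack, from_dlpack

# Sketch: view a CUDA torch.Tensor as a jax.Array (and back) without a copy.
t = torch.arange(4, device="cuda", dtype=torch.float32)
j = jax.dlpack.from_dlpack(to_dlpack(t))        # torch -> jax, zero-copy
back = from_dlpack(jax.dlpack.to_dlpack(j))     # jax -> torch, zero-copy
```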
Yes: jax transforms should work on this function.
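A hedged sketch of what that would look like. The `torch_xla2.tensor` module path, the `XLATensor2(jax_array)` constructor, and the `.jax()` unwrap below are assumed names for illustration, not necessarily the real torch_xla2 API:

```python
import jax
import jax.numpy as jnp
import torch
from torch_xla2 import tensor as tx  # module layout assumed for illustration

def f(a, b):
    x, y = tx.XLATensor2(a), tx.XLATensor2(b)  # wrap jax.Arrays (assumed ctor)
    out = torch.add(x, y) * 2                  # ordinary torch ops on the wrappers
    return out.jax()                           # unwrap back to a jax.Array (assumed accessor)

# Since f maps jax.Arrays to a jax.Array, jax transforms should compose with it:
jitted = jax.jit(f)
grads = jax.grad(lambda a, b: jnp.sum(f(a, b)))(jnp.ones(3), jnp.ones(3))
```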
Although, if we somehow make this
@qihqi what is the difference between torch_xla and torch_xla2?
Dynamo backend for torch_xla2
Goal
Have a dynamo backend backed by torch_xla2.
The users should be able to do the following:
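For example, something along these lines (a sketch only; the backend name string and how the torch_xla2 backend is registered are assumptions for illustration):

```python
import torch

model = torch.nn.Linear(8, 4)
x = torch.randn(2, 8)

# Compile with a dynamo backend provided by torch_xla2 (name is illustrative).
compiled = torch.compile(model, backend="torch_xla2")
out = compiled(x)  # the captured graph would execute via jax/XLA
```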
The above should run on TPU with low overhead.
Challenge
Usually the challenge of a dynamo backend is the compiler that
transforms an fx graph with torch (or ATen) ops into a compiled executable.
However, in our case, that piece is solved:
for every `call_function` node, we look up the corresponding Jax implementation
of said ATen op in a dictionary, and we just call it.
This is illustrated here: https://github.com/pytorch/xla/blob/master/experimental/torch_xla2/torch_xla2/export.py#L23
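In other words, the "compiler" is little more than an fx interpreter over a lookup table. A simplified sketch of the idea (the toy op table below is illustrative; the real table in torch_xla2 is far more complete, and kwargs handling is omitted):

```python
import jax.numpy as jnp
import torch
import torch.fx

# Toy dictionary mapping ATen ops to jax implementations.
ATEN_TO_JAX = {
    torch.ops.aten.add.Tensor: jnp.add,
    torch.ops.aten.mul.Tensor: jnp.multiply,
}

def run_graph_with_jax(gm: torch.fx.GraphModule, *jax_args):
    """Interpret the fx graph, dispatching every call_function node to jax."""
    env, args_iter = {}, iter(jax_args)
    for node in gm.graph.nodes:
        if node.op == "placeholder":
            env[node] = next(args_iter)
        elif node.op == "call_function":
            args = [env[a] if isinstance(a, torch.fx.Node) else a for a in node.args]
            env[node] = ATEN_TO_JAX[node.target](*args)
        elif node.op == "output":
            # node.args[0] holds the (possibly nested) output structure.
            return torch.fx.node.map_arg(node.args[0], lambda n: env[n])
```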
Now, the challenge is for dynamo to be able to 1. produce the graph; and 2.
not incur any data copies in this process.
Consider the following pseudocode:
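Roughly, something like this (a sketch with placeholder names; `my_backend` stands in for the torch_xla2 backend, and the body just runs the graph with torch for illustration):

```python
import torch
import torch.fx

def my_backend(gm: torch.fx.GraphModule, example_inputs):
    def compiled(*inputs):
        # The open question: what type do `inputs` have here at runtime?
        return gm(*inputs)  # placeholder: run the fx graph with torch itself
    return compiled

@torch.compile(backend=my_backend)
def f(inputs):
    return inputs * 2 + 1

f(torch.randn(4))
```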
What would be the type of inputs?
If `inputs` are of type `TensorSubclass`, then dynamo will attempt to trace through the
`__torch_dispatch__` method, and throw an error because it doesn't know what `_data` is and
what the operations on it are.
If `inputs` is of type `torch.Tensor`, then it works: dynamo
calls the backend, and the backend can produce the correct result. But
`inputs` needs to be converted to `TensorSubclass` first inside of
the backend, which usually means a data copy. This happens every time
the compiled backend is executed, and is therefore not desirable.
The desired behavior
When tracing, dynamo treats `TensorSubclass` as if it were a regular tensor
without the dispatch override; and when executing the compiled callable,
`TensorSubclass` is passed in as-is. We know that dynamo can do this with
some tensor subclasses, namely `FakeTensor`.
Let's list out the possible ways we could accomplish this behavior.
Option 1. Have the jax.Array object held in C++
Roughly, we would have a `Tensor` subclass in C++; this is very
similar to the `LazyTensor` subclass that is the current `XLATensor`.
This tensor can hold its own state in C++. In our case, that would
be a `PyObject*` that happens to point to either a `jnp.ndarray` or
jax's `Traced<ShapedArray>` during jax.jit. We might further reuse the
`XLA` dispatch key to route the operators to the jax implementations,
emulating what `__torch_dispatch__` does.
This way, eager mode will continue to work, and dynamo would work
because the Python class is still `torch.Tensor` (not a subclass), and
there is no Python logic in dispatching for dynamo to trace through.
Pros:
Cons:
We now need to deal with C++ builds. In particular, `torch` becomes a source
dependency instead of a pip dependency; meaning, again, we need to start by
building torch first and then build torch_xla2. This might be mitigated if
that subclass can be upstreamed.
Option 2. Modify dynamo to do the desired behavior
We have one instance where a `torch.Tensor` dispatch subclass
just works with dynamo, without dynamo making a fuss when it traces
`__torch_dispatch__`. This is `FakeTensor` (https://github.com/pytorch/pytorch/pull/100017/files).
The idea is to make dynamo trace as if the inputs are `FakeTensor` and
not `XLATensor`, and only after the creation of the fx graph and backend, dynamo
calls the compiled callable with `XLATensor`.
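From the backend's point of view, the desired end state would look roughly like this (a sketch reusing `run_graph_with_jax` from the earlier snippet; the `_elem` attribute and the `XLATensor2` constructor are assumed names for illustration):

```python
import torch
import torch.fx

def torch_xla2_backend(gm: torch.fx.GraphModule, example_inputs):
    # Under Option 2, dynamo traced gm as if the inputs were FakeTensor,
    # so gm is an ordinary aten-level graph.
    def compiled(*runtime_inputs):
        # ...but at runtime the inputs arrive as XLATensor2, unchanged,
        # so the wrapped jax.Arrays can be pulled out with no data copy.
        jax_args = [t._elem for t in runtime_inputs]   # attribute name assumed
        jax_out = run_graph_with_jax(gm, *jax_args)    # interpreter sketched above
        return XLATensor2(jax_out)                     # rewrap (assumed ctor)
    return compiled
```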
.Pros:
Cons:
Dynamo needs some way to know which `__torch_dispatch__` is desirable to trace through, and which is not.
Option 3. Register all the ops as custom ops
So currently dynamo traces `__torch_dispatch__`, and we don't like that
because it will find the operations on Jax arrays, and doesn't understand those.
What if we make dynamo able to understand what is inside?
The Black box python functions doc
points to the possibility of registering things that we don't want dynamo
to go into as custom ops. So we could, theoretically, do the following:
i.e. register `jaten.add` for `aten.add`. In `aten.add`'s `__torch_dispatch__`, we forward the call
from `aten.add` to `jaten.add`.
When dynamo attempts to go inside of `__torch_dispatch__`, it will find
`jaten.add`. Then it will record that in the `fx.Graph`.
Our backend will see the same ops but in a different namespace (`jaten`).
That is fine as long as we know how to look up their implementations.
Note: we probably also need to hook up gradients of custom ops via
`autograd.Function`.
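A rough sketch of what such a registration could look like, using `torch.library` and dlpack to hop between torch and jax. The op schema, the CPU dispatch key, and the forwarding shown in the trailing comment are assumptions for illustration, not the actual torch_xla2 wiring:

```python
import torch
from torch.library import Library
from torch.utils.dlpack import to_dlpack, from_dlpack
import jax.dlpack
import jax.numpy as jnp

# Define a "jaten" namespace and an add op whose implementation calls jax.
jaten = Library("jaten", "DEF")
jaten.define("add(Tensor a, Tensor b) -> Tensor")

def jaten_add(a, b):
    ja = jax.dlpack.from_dlpack(to_dlpack(a.contiguous()))
    jb = jax.dlpack.from_dlpack(to_dlpack(b.contiguous()))
    return from_dlpack(jax.dlpack.to_dlpack(jnp.add(ja, jb)))

jaten.impl("add", jaten_add, "CPU")

# Inside the subclass's __torch_dispatch__, aten.add would then be forwarded
# to the custom op, which is what dynamo would see and record:
#   if func is torch.ops.aten.add.Tensor:
#       return torch.ops.jaten.add(args[0], args[1])
```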
.Pros / Cons:
Haven't tried, so don't know if it's gonna work or not.
Current standing proposal
Current standing proposal is Option 2.
Meeting notes (so far):
2024-05-29
with @ezyang @williamwen42 @wconstab @JackCaoG @shauheen @Chillee @yanboliang
Went over the 3 options. Opinions split between options 1 and 2. People seem to agree that making 1 work is desired and that the work done to make it work is considered "cost of integration".
People also discussed a bit whether interoperability with Jax should be a valid use case:
i.e. a function like `f` above, which wraps jax arrays, calls torch, and unwraps the result,
is a valid Jax function; can it be used with `jax.grad` or `jax.jit`?
@Chillee raised a point that if we use `jax.grad` to get the gradient and train a model, it might yield different behavior if the user has custom backward hooks in their code.
@williamwen42 suggested to use this option for torchdynamo:
With this suggestion, dynamo tracing succeeded and called `backend`
with the correct tensor subclass and graph (desired behavior).
However, it raised an error when the backend attempted to construct `XLATensor2` for the return value.
Details on the script that was run: https://gist.github.com/qihqi/aa4fd50e5ef3cb96598433bd0f62817c?fbclid=IwAR3v5GQwYFUmlGxukEfriucav4f-ybMJ4yVA97I4cslQzCg8b7CF8VKIBac