[RFC] torch_xla2 dynamo integration #7255

Open
qihqi opened this issue Jun 12, 2024 · 7 comments

@qihqi
Collaborator

qihqi commented Jun 12, 2024

Dynamo backend for torch_xla2

Goal

Have a dynamo backend backed by torch_xla2.

The users should be able to do the following:

m = model ...
m_compiled = torch.compile(m, backend='torch_xla2_compile')  # backend name TBD
result = m_compiled(*inputs)

The above should run on TPU with low overhead.

Challenge

Usually the challenge of a dynamo backend is the compiler that
transforms an fx graph of torch (or ATen) ops into the compiled executable.
However, in our case, that piece is solved.

For every call_function node, we look up the Jax implementation of
the corresponding ATen op in a dictionary,
and we just call it.

This is illustrated here: https://github.com/pytorch/xla/blob/master/experimental/torch_xla2/torch_xla2/export.py#L23
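The lookup described above can be sketched without torch or Jax. This is a toy model, not the code in export.py: the Node class, the LOWERINGS table and run_graph are hypothetical names, and plain Python callables stand in for the Jax implementations.

```python
import operator
from dataclasses import dataclass


@dataclass
class Node:
    """Hypothetical, simplified stand-in for torch.fx.Node."""
    name: str
    op: str                # "placeholder", "call_function", or "output"
    target: object = None  # op name for call_function nodes
    args: tuple = ()       # node names (str) or constants


# Lowering table: op name -> implementation. In torch_xla2 the keys would be
# ATen ops and the values Jax functions; Python operators stand in here.
LOWERINGS = {
    "aten.add": operator.add,
    "aten.mul": operator.mul,
}


def run_graph(nodes, *inputs):
    """Walk the nodes in order, looking up each call_function in LOWERINGS."""
    env = {}
    remaining = iter(inputs)
    for node in nodes:
        if node.op == "placeholder":
            env[node.name] = next(remaining)
        elif node.op == "call_function":
            impl = LOWERINGS[node.target]
            # String arguments refer to earlier nodes; anything else is a constant.
            args = [env[a] if isinstance(a, str) else a for a in node.args]
            env[node.name] = impl(*args)
        elif node.op == "output":
            return env[node.args[0]]
    raise ValueError("graph has no output node")


# Graph for (x + y) * 2.
graph = [
    Node("x", "placeholder"),
    Node("y", "placeholder"),
    Node("s", "call_function", "aten.add", ("x", "y")),
    Node("p", "call_function", "aten.mul", ("s", 2)),
    Node("out", "output", args=("p",)),
]
```

Calling run_graph(graph, 3, 4) evaluates (3 + 4) * 2. The point is that no compiler is involved: each node costs one dictionary lookup and one call.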

Now, the challenge is for dynamo to be able to 1. produce the graph; and 2.
not incur any data copies in this process.

Consider the following pseudocode:

import jax
import torch

class TensorSubclass(torch.Tensor):  # XLATensor2 in torch_xla2
  _data: jax.Array

  @classmethod
  def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
      new_data = ...  # do stuff with _data, get new data
      return TensorSubclass(new_data)

def dynamo_backend(fx, sample):
  compiled = ...  # compile fx into a callable that manipulates jax.Array
  def returned_callable(*inputs):  # dynamo passes inputs positionally
    datas = [i._data for i in inputs]
    res = compiled(*datas)
    return TensorSubclass(res)
  return returned_callable

model = torch.compile(model, backend=dynamo_backend)
inputs = ...  # a list of TensorSubclass, or a list of torch.Tensor?
model(*inputs)

What would be the type of inputs?
If inputs are of type TensorSubclass, then dynamo
will attempt to trace through the __torch_dispatch__ method,
and throw an error because it does not know what _data is or how to
handle the operations on it.

If inputs are of type torch.Tensor, then it works: dynamo
calls the backend, and the backend can produce the correct result.
But the inputs need to be converted to TensorSubclass first inside
the backend, which usually means a data copy. This happens every time
the compiled callable is executed, and is therefore not desirable.
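To make the per-call cost concrete, here is a toy model of the two cases. Everything in it is hypothetical: Wrapped stands in for the tensor subclass, from_plain for the torch.Tensor to jax.Array conversion, and compiled for the compiled executable.

```python
class Wrapped:
    """Hypothetical stand-in for a tensor subclass that owns backend data."""
    conversions = 0  # counts how many times plain data was converted

    def __init__(self, data):
        self._data = data

    @classmethod
    def from_plain(cls, values):
        # Stand-in for torch.Tensor -> backend array; allocates a new buffer.
        cls.conversions += 1
        return cls(list(values))


def compiled(a, b):
    # Stand-in for the compiled executable operating on raw backend data.
    return [x + y for x, y in zip(a, b)]


def callable_for_plain_inputs(*inputs):
    # Plain inputs: every call converts (copies) each input first.
    wrapped = [Wrapped.from_plain(i) for i in inputs]
    return Wrapped(compiled(*(w._data for w in wrapped)))


def callable_for_wrapped_inputs(*inputs):
    # Wrapped inputs: the backend data is used as-is, nothing is converted.
    return Wrapped(compiled(*(w._data for w in inputs)))
```

With plain inputs the conversion count grows by the number of inputs on every call; with wrapped inputs it stays flat, which is the property the backend wants.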

The Desired behavior

When tracing, dynamo should treat TensorSubclass as if it were a regular tensor
without a dispatch override; when executing the compiled callable,
TensorSubclass should be passed in as-is. We know that dynamo can do this with
at least one tensor subclass, namely FakeTensor.

Let's list out the possible ways we could accomplish this behavior.

Option 1. Have the jax.Array object held in C++

Roughly we would have a Tensor subclass in C++, this is very
similar to the LazyTensor subclass that is the current XLATensor.
This tensor can hold its own state in C++. In our case, that would
be a PyObject* that happens to point to either jnp.ndarray or
jax's Traced<ShapedArray> during jax.jit. We might further reuse the
XLA dispatch key to route the operators to the jax implementation,
emulating what __torch_dispatch__ does.

This way, eager mode will continue to work, and dynamo would work
because the Python class is still torch.Tensor (not a subclass), and
there is no Python logic in dispatching for dynamo to trace through.

Pros:

  • Very clear that this will work.

Cons:

  • We now need to deal with C++ builds. In particular, torch becomes a source
    dependency instead of a pip dependency; that is, we again need to build
    torch first and then build torch_xla2. This might be mitigated if that
    subclass can be upstreamed.

Option 2. Modify dynamo to do the desired behavior

We have one instance where a torch.Tensor dispatch subclass
just works with dynamo, without dynamo making a fuss when it traces
__torch_dispatch__. This is FakeTensor. (https://github.com/pytorch/pytorch/pull/100017/files)

The idea is to make dynamo trace as if the inputs were FakeTensor and
not XLATensor. Only after the fx graph is created and the backend has run
does dynamo call the compiled callable with XLATensor.

Pros:

  • Likely pure python changes.

Cons:

  • We also need to design a mechanism to distinguish tensor subclasses that
    dynamo should trace through from those it should not.
  • Likely a significant amount of work.

Option 3. Register All the ops as custom_ops

So currently dynamo traces __torch_dispatch__, and we don't like that
because it will find the operations on Jax arrays, and doesn't understand those.

What if we make dynamo able to understand what is inside?
The "Black box python functions" doc
points to the possibility of registering things that we don't want dynamo
to go into as custom ops. So we could, theoretically, do the following:

  1. Register the jax impl of an ATen op as a custom op,
    e.g. register jaten.add for aten.add.
  2. For meta kernels, just call the meta kernel of aten.add.
  3. In __torch_dispatch__, forward the call from aten.add to jaten.add.

When dynamo attempts to go inside of __torch_dispatch__, it will find
jaten.add. Then it will record that in the fx.Graph.

Our backend will see the same ops but in a different namespace (jaten).
That is fine as long as we know how to look up its implementation.
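The namespace forwarding in Option 3 can be sketched as a toy registry. This does not use the real custom op API; REGISTRY, register, forward_to_jaten and dispatch are hypothetical names, and a Python list stands in for the fx.Graph that dynamo would record into.

```python
REGISTRY = {}  # qualified op name -> implementation


def register(name):
    """Register an implementation under a qualified op name."""
    def decorator(fn):
        REGISTRY[name] = fn
        return fn
    return decorator


@register("jaten.add")
def jaten_add(x, y):
    # Stand-in for the Jax implementation of aten.add.
    return x + y


def forward_to_jaten(op_name):
    """Map an op in the aten namespace to its jaten counterpart."""
    namespace, _, op = op_name.partition(".")
    if namespace != "aten":
        raise ValueError(f"expected an aten op, got {op_name!r}")
    return f"jaten.{op}"


def dispatch(op_name, *args, trace=None):
    """Toy __torch_dispatch__: forward aten.* to jaten.* and record the target."""
    target = forward_to_jaten(op_name)
    if trace is not None:
        trace.append(target)  # what a tracer would record in the graph
    return REGISTRY[target](*args)
```

A backend that receives the recorded names only needs the same REGISTRY to find each implementation, which is why seeing jaten.add instead of aten.add in the graph is not a problem.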

Note: we probably also need to hook up gradients of custom ops via torch.autograd.Function.

Pros / Cons:
Haven't tried it; don't know whether it will work.

Current standing proposal

The current standing proposal is Option 2.

Meeting notes (so far):

2024-05-29

with @ezyang @williamwen42 @wconstab @JackCaoG @shauheen @Chillee @yanboliang

Went over the 3 options. Opinions were split between options 1 and 2. People seem to agree that making 1 work is desired, and the work done to make it work is considered "cost of integration".
People also discussed a bit whether interoperability with Jax should be a valid use case,
e.g.

def f(jax_array_1, jax_array_2):
   wrap jax_array_1, jax_array_2 into XLATensor2
   call torch
   return unwrapped result

This is a valid Jax function; can it be used with jax.grad or jax.jit?

@Chillee raised a point that if we use jax.grad to get the gradient and train a model, it might yield different behavior if the user has custom backward hooks in their code.

@williamwen42 suggested using this option for torchdynamo:

torch._dynamo.config.traceable_tensor_subclasses.add(
    torch_xla2.tensor.XLATensor2)

With this suggestion, dynamo tracing succeeded and called backend with the correct tensor subclass and graph (desired behavior).

However, it raised an error when the backend attempted to construct XLATensor2 for return value.

Details on the script ran: https://gist.github.com/qihqi/aa4fd50e5ef3cb96598433bd0f62817c?fbclid=IwAR3v5GQwYFUmlGxukEfriucav4f-ybMJ4yVA97I4cslQzCg8b7CF8VKIBac

@qihqi qihqi self-assigned this Jun 12, 2024
@ezyang
Collaborator

ezyang commented Jun 17, 2024

This is something that I said in the May meeting (I think!) but I want to make sure it's durably recorded here: if for some reason you wanted to do CPU/CUDA computation with XLA, I think it would be substantially more idiomatic and natural for the alternate backend to be able to operate directly on traditional CPU/CUDA tensors.

This presents a UX tension with XLATensor. An XLATensor that holds onto a jax.Array object is a plausible point in the design space, but when that jax.Array is CPU/CUDA, it's duplicative with traditional CPU/CUDA tensors... but also not really, because if you have a CPU tensor that internally holds a jax.Array, you suddenly get more expressivity because of the interoperability with JAX (but not really, e.g., for the point @Chillee raised that is mentioned here).

So, this is what I'm hoping to see:

  • XLATensor is a Python tensor subclass which wraps a jax.Array. In the limit, it is a full, eager-mode compatible translation layer that translates PyTorch API calls into equivalent JAX API calls. If you have some PyTorch code and you want to jax.jit it, it should work as long as that code works when passed XLATensors instead. Things like jax.grad would not work with backward hooks, but this is simply "as designed" (and will need to be emphasized in user documentation; it's worth noting that the interaction here is pretty similar to the interaction of PyTorch autograd and functorch grad, cc @zou3519).
  • The XLA Dynamo backend will promote plain CPU/CUDA tensors to XLATensors with relatively little runtime overhead (in particular, it shouldn't be necessary to copy the tensors into XLA's workspace)
  • Dynamo can deal with passed-in XLATensors a la Option 2. They can be handled specially, similarly to FakeTensor.

@zou3519
Contributor

zou3519 commented Jun 17, 2024

I'm interested in hearing more about the requirements and plan for interop with JAX transforms.

  • Do you want jax transforms to work over the following function? (this seems reasonable to do even without torch.compile because when doing JAX tracing, XLATensor2 will desugar into JAX ops).
  • Should torch.compile work over the following function?
  • can we mix and match torch.compile and JAX transforms?
def f(jax_array_1, jax_array_2):
   wraps jax_array_1, jax_array_2 into XLATensor2
   call torch
   return unwrapped

@williamwen42
Member

Relevant issue when trying to trace through XLATensor2 without using traceable_tensor_subclasses: pytorch/pytorch#128160.

Discussed offline that it's not ideal to require using traceable_tensor_subclasses for this case.

@qihqi
Collaborator Author

qihqi commented Jun 17, 2024

Hi @ezyang, totally agree on the 3 points listed. Promoting CUDA tensors to XLATensor seems simple with dlpack. We'll pursue this direction.

@qihqi
Collaborator Author

qihqi commented Jun 17, 2024

I'm interested in hearing more about the requirements and plan for interop with JAX transforms.

  • Do you want jax transforms to work over the following function? (this seems reasonable to do even without torch.compile because when doing JAX tracing, XLATensor2 will desugar into JAX ops).
  • Should torch.compile work over the following function?
  • can we mix and match torch.compile and JAX transforms?
def f(jax_array_1, jax_array_2):
   wraps jax_array_1, jax_array_2 into XLATensor2
   call torch
   return unwrapped

Yes: jax transforms should work on this function.

can we mix and match torch.compile and JAX transforms?
I never dared to wish for this to be honest. Right now, if I call a jax function from a torch program using call_jax I don't expect dynamo to be able to handle it.

Although, if we somehow make this call_jax into a custom_op (a la https://colab.sandbox.google.com/drive/1xCh5BNHxGnutqGLMHaHwm47cbDL9CB1g) it should just work? The issue is that higher-order ops (call_jax takes callables as input) cannot use this custom op API; they need to use HigherOrderOps, which ARE traced through by dynamo. Presumably, dynamo needs to trace to figure out the returned dtype and shape. So another approach would be to make HigherOrderOps untraced by having the implementer declare the returned shape / dtype.

@LakeFeiLiu

LakeFeiLiu commented Jul 31, 2024

@qihqi what is the difference between torch_xla and torch_xla2?

@qihqi
Collaborator Author

qihqi commented Aug 27, 2024
