Load PjRT Plugin API for Custom PjRT Backend #5046
Comments
@will-cromar Can you take a look?
We don't support the xla_extension pybind module in PyTorch/XLA. You have to load the plugin in C++ during client initialization; see xla/third_party/xla_client/pjrt_computation_client.cc, lines 109 to 113 at commit bc28ca5.
You can use the XPU PR as a reference to add a new plugin: #4891. I'm planning on updating this logic to allow dynamic plugin loading and remove all of those device type enums/lists, but that may have to wait until after we remove XRT.
Thanks @will-cromar! We will try doing this and update you.
Neuron PR for this: #5428
Neuron PR for the import hook to be added: #5429
❓ Questions and Help
We are currently trying to use the load_pjrt_plugin() API to load custom C-API PJRT plugins, but the function is not exposed in torch_xla. It is defined in the xla.cc file of the TensorFlow XLA commit used for torch_xla 2.0: https://github.com/tensorflow/tensorflow/blob/f7759359f8420d3ca7b9fd19493f2a01bd47b4ef/tensorflow/compiler/xla/python/xla.cc#L325. The binding is part of the xla_extension.so pybind11 module, but that module isn't shipped with torch_xla.
Is there another way to call this API from Python to load our C-API PJRT plugin with the current torch_xla 2.0, or is there some other API exposed for loading PJRT plugins?
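For context, the binding in the linked xla.cc would be invoked roughly like the sketch below. Everything here is hedged: the top-level `xla_extension` import path and the exact `load_pjrt_plugin(device_type, library_path)` signature are assumptions based on the linked source, and torch_xla 2.0 does not ship the module at all, which is the problem this issue describes.

```python
def try_load_pjrt_plugin(device_type, library_path):
    """Attempt to load a PJRT C-API plugin via the xla_extension pybind
    module. Sketch only: the module name and binding signature are
    assumptions, and the module is not shipped with torch_xla 2.0."""
    try:
        import xla_extension  # hypothetical import path; builds differ
    except ImportError:
        return "xla_extension unavailable in this build"
    xla_extension.load_pjrt_plugin(device_type, library_path)
    return "loaded"

print(try_load_pjrt_plugin("my_device", "/path/to/my_pjrt_plugin.so"))
```

In a torch_xla 2.0 environment this falls through to the ImportError branch, which is why the thread's answer is to register the plugin in C++ (as in the XPU PR) instead.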