Use torch_xla.experimental.compile for all examples #7642
Conversation
@will-cromar I think this is ready for review
```
@@ -33,23 +33,27 @@ def __init__(self):
  self.model = torchvision.models.resnet50().to(self.device)
  self.optimizer = optim.SGD(self.model.parameters(), weight_decay=1e-4)
  self.loss_fn = nn.CrossEntropyLoss()
  self.compiled_step_fn = torch_xla.experimental.compile(self.step_fn)
```
I'm still not sure I understand why `compile` is separate from `torch_xla.step`. The main difference that I can see is that you flip eager off then on, which I think you can just add to `torch_xla.step`. `contextlib.contextmanager` already handles the rest of the plumbing you have there, such as wrapping and exception handling.
xla/torch_xla/experimental/eager.py
Lines 41 to 59 in 3078d16
```python
@functools.wraps(func)  # Keep function's name, docstring, etc.
def wrapper(*args, **kwargs):
  # compile should only be called with
  assert torch_xla._XLAC._get_use_eager_mode() == True
  torch_xla._XLAC._set_use_eager_mode(False)
  # clear the pending graph if any
  torch_xla.sync()
  try:
    # Target Function Execution
    result = func(*args, **kwargs)
    # Sync the graph generated by the target function.
    torch_xla.sync()
  except Exception as e:
    # Handle exceptions (if needed)
    print(f"Error in target function: {e}")
    raise  # Re-raise the exception
  torch_xla._XLAC._set_use_eager_mode(True)
  return result
```
vs
Lines 57 to 71 in 3078d16
```python
@contextlib.contextmanager
def step():
  """Wraps code that should be dispatched to the runtime.

  Experimental: `xla.step` is still a work in progress. Some code that currently
  works with `xla.step` but does not follow best practices will become errors in
  future releases. See https://github.com/pytorch/xla/issues/6751 for context.
  """
  # Clear pending operations
  xm.mark_step()

  try:
    yield
  finally:
    xm.mark_step()
```
https://docs.python.org/3/library/contextlib.html#contextlib.contextmanager
https://docs.python.org/3/library/contextlib.html#contextlib.ContextDecorator
`step` already has a cautionary note in the docstring that we will be changing the public API over time.
I guess I don't know how to turn a context manager into a wrapper around the function. All I want is for the API to look like `torch_xla.experimental.compile()`. I think it can share the same implementation as `step`.
oh ok, I guess you meant to use `@functools.wraps(func)` on top of `step`... let me try that.
ok, they are actually different: `step` is a context manager that doesn't take a fn and arguments as input. It also doesn't return the fn's output back.
> ok they are actually different, step is a context manager that doesn't take fn and arguments as input. It also doesn't return the fn's output back.

Context managers and decorators are ~interchangeable in Python thanks to `ContextDecorator`:
```python
@torch_xla.experimental.compile
def f(x):
  return 2 * x

@torch_xla.step()
def g(x):
  return 2 * x
```
The only semantic difference is that you have to instantiate the context manager first (`step()`). It appropriately wraps the function's inputs, outputs, docstring, signature, types, etc. too. If we want a "no arguments" form of `step`, then I can make that work. Even something silly like this works:
```python
stepwithnoargs = torch_xla.step()  # not recommended for readability lol

@stepwithnoargs
def g(x):
  """docstring"""
  return 2 * x
```
Python has almost exactly what we want in the standard library already, so let's use it.
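The `ContextDecorator` behavior described above can be demonstrated with nothing but the stdlib. This sketch uses a hypothetical `step()` stand-in (an `events` list in place of `xm.mark_step()`, not the real torch_xla API) to show that a `contextlib.contextmanager`-based manager also works as a decorator and preserves the function's metadata:

```python
import contextlib

events = []

# Hypothetical stand-in for torch_xla.step(): a generator-based context
# manager, which contextlib also makes usable as a decorator.
@contextlib.contextmanager
def step():
  events.append("mark_step")  # stands in for xm.mark_step()
  try:
    yield
  finally:
    events.append("mark_step")

@step()
def g(x):
  """docstring"""
  return 2 * x

print(g(21))      # 42
print(g.__doc__)  # docstring -- functools.wraps metadata is preserved
```

Each call to `g` re-enters a fresh instance of the context manager, so `events` gains two `"mark_step"` entries per call.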
so many new python hacks haha, let me read into this a bit
I understand how decorators and context managers are interchangeable now. I am still not sure how to implement the API I want, `new_fn = torch_xla.experimental.compile(func)`, with the decorator; this is a bit mind-twisting. Are you going to make `step` take an optional argument and branch from there? I will update `step` to handle eager mode, and you can probably refactor from there.
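For what it's worth, the `new_fn = compile(func)` shape falls out of `ContextDecorator` directly: calling a generator-based context manager instance on a function returns the wrapped function. A stdlib-only sketch (the `step` here is a hypothetical stand-in, not the real torch_xla one):

```python
import contextlib

@contextlib.contextmanager
def step():
  # Hypothetical stand-in for torch_xla.step(); the real one calls
  # xm.mark_step() before and after the wrapped region.
  yield

def compile(func):
  # A @contextlib.contextmanager function returns a ContextDecorator,
  # so calling the instance on `func` yields a wrapped function that
  # runs inside the context manager on every call. (Shadowing the
  # builtin `compile` here only to mirror the proposed API name.)
  return step()(func)

def f(x):
  return 2 * x

compiled_f = compile(f)  # the `new_fn = compile(func)` shape
print(compiled_f(3))     # 6
```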
Force-pushed 953f95d to a007c69
Here is my plan, @will-cromar, let me know what you think:

- Use `torch_xla.experimental.compile` (and `torch_xla.step()`) to wrap their step fn, which will `mark_step` the outside region for them.

Pretty much we can stage the effort for this UX migration. For this to work I need to make `torch_xla.step` handle eager mode; I will do that in a follow-up PR.