[SPMD] Patch nn.Linear #5491
Changes from all commits
@@ -5,6 +5,7 @@
 import math
 import numpy as np
 import os
+import sys

 import torch
 from torch import nn
@@ -122,7 +122,9 @@ def _xla_patched_nn_linear_forward(m, input):
   return XLAPatchedLinear.apply(input, m.weight, m.bias)


-def apply_xla_patch_to_nn_linear(module):
+def apply_xla_patch_to_nn_linear(module,
+                                 patched_function=_xla_patched_nn_linear_forward
+                                ):
   """
   Recursively apply a patch to the forward pass of `nn.Linear` layers
   to enable using `XLAPatchedLinear.apply` as `torch.nn.functional.linear`,

@@ -144,7 +146,7 @@ def _try_patching_forward_method(m, forward_method_name="forward"):
     if getattr(forward_method, "__func__", None) != torch.nn.Linear.forward:
       return

-    patched_forward_method = MethodType(_xla_patched_nn_linear_forward, m)
+    patched_forward_method = MethodType(patched_function, m)
     m._nn_linear_forward_original = forward_method
     setattr(m, forward_method_name, patched_forward_method)

Review discussion on the new `patched_function` parameter:

- Will this function always be called, or only when FSDP is being used?
- It's called in the FSDP wrapper. For GSPMD we don't have a wrapper, so we have to call it explicitly.
- OK, so it is not enabled by default.
- @alanwaketan, did using einsum instead of matmul improve the benchmark training results? I recall that it didn't actually have an impact (or lowered MFU in @jonb377's results)? If so, we should add a comment explaining that this may reduce performance, too. And +1 to not enabling it by default.
- The lower MFU was from user-mode FSDP; I don't have a head-to-head benchmark on SPMD. @alanwaketan had some data suggesting it improves performance for SPMD.
- This is the solution on our end to stop the reshape in nn.Linear's matmul.
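As the discussion above notes, the patch is applied automatically only inside the FSDP wrapper; for GSPMD a user has to call `apply_xla_patch_to_nn_linear` explicitly. Below is a minimal usage sketch, not part of this PR: the toy model and the import paths are illustrative assumptions, and the second argument is the SPMD-side forward added by this PR. The patch is applied in place via `setattr` on each `nn.Linear`, so the module can be used directly afterwards.

# Illustrative sketch only (not from this PR). Import paths are assumptions
# about where apply_xla_patch_to_nn_linear and the SPMD forward live.
import torch.nn as nn
import torch_xla.experimental.xla_sharding as xs  # assumed module path
from torch_xla.distributed.fsdp.utils import apply_xla_patch_to_nn_linear  # assumed module path

# A toy model; any module tree containing nn.Linear layers is handled the same way.
model = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 10))

# Recursively patches each nn.Linear in place (setattr on its forward method),
# routing it through the einsum-based XLAPatchedLinear instead of F.linear.
apply_xla_patch_to_nn_linear(model, xs.xla_patched_nn_linear_forward)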
@@ -521,3 +521,42 @@ def apply(self, t: torch.Tensor):
     # TODO(yeounoh) use virtual device interface when available.
     assert (t.device == xm.xla_device())
     mark_sharding(t, self.mesh, self.partition_spec)
+
+
+class XLAPatchedLinear(torch.autograd.Function):
+  """
+  A patched version of `torch.nn.functional.linear` that uses einsum instead
+  of torch.matmul, which will flatten the tensors to 2D and collide the sharded
+  dimensions. The torch.matmul default behavior makes it very hard for the XLA
+  compiler to propagate the sharding annotation.
+
+  TODO (alanwaketan): Let's patch it on the dispatcher level.
+  """
+
+  @staticmethod
+  def forward(ctx, input, weight, bias=None):
+    # bias is an optional argument
+    ctx.save_for_backward(input, weight, bias)
+    with torch.no_grad():
+      product = torch.einsum('...n,mn->...m', input, weight)
+      if bias is None:
+        return product
+      return product + bias
+
+  @staticmethod
+  def backward(ctx, grad_output):
+    input, weight, bias = ctx.saved_tensors
+    grad_input = grad_weight = grad_bias = None
+
+    if ctx.needs_input_grad[0]:
+      grad_input = torch.einsum('...m,mn->...n', grad_output, weight)
+    if ctx.needs_input_grad[1]:
+      grad_weight = torch.einsum('...m,...n->mn', grad_output, input)
+    if bias is not None and ctx.needs_input_grad[2]:
+      grad_bias = torch.einsum('...m->m', grad_output)
+
+    return grad_input, grad_weight, grad_bias
+
+
+def xla_patched_nn_linear_forward(m, input):
+  return XLAPatchedLinear.apply(input, m.weight, m.bias)

Review discussion on `ctx.save_for_backward(input, weight, bias)`:

- Hmm, are we doing explicit gradient checkpointing here?
- I'm not familiar with how the nn module is written, but I mainly ask because the …
- No, it's just a convention in autograd functions to save the activations, so this is not gradient checkpointing or re-materialization.
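The einsum contractions above can be sanity-checked against `torch.nn.functional.linear` on plain CPU tensors, independent of XLA. The snippet below is an illustrative check added here (variable names and shapes are my own, not from the PR): the forward contraction reproduces F.linear, and the three backward contractions match autograd's gradients of y = x @ w.T + b, all while keeping the leading batch dimensions intact rather than flattening them to 2D.

# Illustrative sanity check (not part of this PR); runs on plain CPU tensors.
import torch
import torch.nn.functional as F

x = torch.randn(4, 7, 16, requires_grad=True)  # (..., in_features)
w = torch.randn(32, 16, requires_grad=True)    # (out_features, in_features)
b = torch.randn(32, requires_grad=True)

# Forward: '...n,mn->...m' contracts in_features directly, so the leading
# (4, 7) dimensions are preserved instead of being collapsed to 2D.
out = torch.einsum('...n,mn->...m', x, w) + b
assert torch.allclose(out, F.linear(x, w, b), atol=1e-4)

# Backward: the three einsums used in XLAPatchedLinear.backward agree with
# autograd's gradients of y = x @ w.T + b.
grad_out = torch.randn_like(out)
gi, gw, gb = torch.autograd.grad(F.linear(x, w, b), (x, w, b), grad_out)
assert torch.allclose(gi, torch.einsum('...m,mn->...n', grad_out, w), atol=1e-4)
assert torch.allclose(gw, torch.einsum('...m,...n->mn', grad_out, x), atol=1e-4)
assert torch.allclose(gb, torch.einsum('...m->m', grad_out), atol=1e-4)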