1 change: 1 addition & 0 deletions test/spmd/test_xla_sharding.py
@@ -5,6 +5,7 @@
import math
import numpy as np
import os
import sys

import torch
from torch import nn
120 changes: 120 additions & 0 deletions test/test_operations.py
@@ -32,9 +32,11 @@
import torch_xla.core.xla_builder as xb
import torch_xla.core.xla_op_registry as xor
import torch_xla.distributed.data_parallel as dp
from torch_xla.distributed.fsdp.utils import apply_xla_patch_to_nn_linear
import torch_xla.debug.metrics as met
import torch_xla.debug.model_comparator as mc
import torch_xla.distributed.parallel_loader as pl
import torch_xla.experimental.xla_sharding as xs
from torch_xla import runtime as xr
import torch_xla.test.test_utils as xtu
import torch_xla.utils.utils as xu
@@ -1656,6 +1658,124 @@ def test_conv2d_backward(self):
    self.assertTrue(
        torch.allclose(conv.weight.grad.cpu(), torch.tensor([[[[2077.0]]]])))

  def test_patched_linear_3D(self):
    linear_cpu = nn.Linear(2, 4, bias=False)
    input_cpu = torch.randn(4, 3, 2, requires_grad=True)
    input_cpu.retain_grad()
    output_cpu = linear_cpu(input_cpu)

    # It looks like nn.Module.to is in-place.
    linear = copy.deepcopy(linear_cpu).to('xla')
    apply_xla_patch_to_nn_linear(linear, xs.xla_patched_nn_linear_forward)
    input = copy.deepcopy(input_cpu).to('xla')
    input.retain_grad()
    output = linear(input)

    # Make sure that we don't have any reshapes in the patched linear.
    hlo = torch_xla._XLAC._get_xla_tensors_hlo([output])
    self.assertNotIn("reshape", hlo)

    # Make sure the forward result is correct.
    self.assertTrue(torch.allclose(output.cpu(), output_cpu))

    # Now work on the backward.
    linear_cpu.weight.retain_grad()
    loss_cpu = output_cpu.sum()
    loss_cpu.backward()

    loss = output.sum()
    loss.backward()

    self.assertTrue(
        torch.allclose(linear.weight.grad.cpu(), linear_cpu.weight.grad))
    self.assertTrue(torch.allclose(input.grad.cpu(), input_cpu.grad))

  def test_patched_linear_3D_bias(self):
    linear_cpu = nn.Linear(2, 4)
    input_cpu = torch.randn(4, 3, 2)
    output_cpu = linear_cpu(input_cpu)

    # It looks like nn.Module.to is in-place.
    linear = copy.deepcopy(linear_cpu).to('xla')
    apply_xla_patch_to_nn_linear(linear, xs.xla_patched_nn_linear_forward)
    input = copy.deepcopy(input_cpu).to('xla')
    output = linear(input)

    # We will have some reshapes on the bias. So skip the check here.
    # Make sure the forward result is correct.
    self.assertTrue(torch.allclose(output.cpu(), output_cpu))

    # Now work on the backward.
    linear_cpu.weight.retain_grad()
    loss_cpu = output_cpu.sum()
    loss_cpu.backward()

    loss = output.sum()
    loss.backward()

    self.assertTrue(
        torch.allclose(linear.bias.grad.cpu(), linear_cpu.bias.grad))

  def test_patched_linear_2D_bias(self):
    linear_cpu = nn.Linear(2, 4)
    input_cpu = torch.randn(4, 2, requires_grad=True)
    input_cpu.retain_grad()
    output_cpu = linear_cpu(input_cpu)

    # It looks like nn.Module.to is in-place.
    linear = copy.deepcopy(linear_cpu).to('xla')
    apply_xla_patch_to_nn_linear(linear, xs.xla_patched_nn_linear_forward)
    input = copy.deepcopy(input_cpu).to('xla')
    input.retain_grad()
    output = linear(input)

    # Make sure the forward result is correct.
    self.assertTrue(torch.allclose(output.cpu(), output_cpu))

    # Now work on the backward.
    linear_cpu.weight.retain_grad()
    loss_cpu = output_cpu.sum()
    loss_cpu.backward()

    loss = output.sum()
    loss.backward()

    self.assertTrue(
        torch.allclose(linear.weight.grad.cpu(), linear_cpu.weight.grad))
    self.assertTrue(torch.allclose(input.grad.cpu(), input_cpu.grad))
    self.assertTrue(
        torch.allclose(linear.bias.grad.cpu(), linear_cpu.bias.grad))

  def test_patched_linear_1D_bias(self):
    linear_cpu = nn.Linear(2, 4)
    input_cpu = torch.randn(2, requires_grad=True)
    input_cpu.retain_grad()
    output_cpu = linear_cpu(input_cpu)

    # It looks like nn.Module.to is in-place.
    linear = copy.deepcopy(linear_cpu).to('xla')
    apply_xla_patch_to_nn_linear(linear, xs.xla_patched_nn_linear_forward)
    input = copy.deepcopy(input_cpu).to('xla')
    input.retain_grad()
    output = linear(input)

    # Make sure the forward result is correct.
    self.assertTrue(torch.allclose(output.cpu(), output_cpu))

    # Now work on the backward.
    linear_cpu.weight.retain_grad()
    loss_cpu = output_cpu.sum()
    loss_cpu.backward()

    loss = output.sum()
    loss.backward()

    self.assertTrue(
        torch.allclose(linear.weight.grad.cpu(), linear_cpu.weight.grad))
    self.assertTrue(torch.allclose(input.grad.cpu(), input_cpu.grad))
    self.assertTrue(
        torch.allclose(linear.bias.grad.cpu(), linear_cpu.bias.grad))


class MNISTComparator(nn.Module):

6 changes: 4 additions & 2 deletions torch_xla/distributed/fsdp/utils.py
@@ -122,7 +122,9 @@ def _xla_patched_nn_linear_forward(m, input):
  return XLAPatchedLinear.apply(input, m.weight, m.bias)


-def apply_xla_patch_to_nn_linear(module):
+def apply_xla_patch_to_nn_linear(module,
+                                 patched_function=_xla_patched_nn_linear_forward
+                                ):
  """
  Recursively apply a patch to the forward pass of `nn.Linear` layers
  to enable using `XLAPatchedLinear.apply` as `torch.nn.functional.linear`,

@@ -144,7 +146,7 @@ def _try_patching_forward_method(m, forward_method_name="forward"):
    if getattr(forward_method, "__func__", None) != torch.nn.Linear.forward:
      return

-    patched_forward_method = MethodType(_xla_patched_nn_linear_forward, m)
+    patched_forward_method = MethodType(patched_function, m)
    m._nn_linear_forward_original = forward_method
    setattr(m, forward_method_name, patched_forward_method)

Review thread on the new `patched_function` parameter:

Collaborator: Will this function always be called, or only when FSDP is being used?

Collaborator Author: It's called in the FSDP wrapper. For GSPMD, we don't have a wrapper, so we have to call it explicitly.

Collaborator: OK, so it is not enabled by default.

Contributor: @alanwaketan, did it (using einsum instead of matmul) improve the benchmark training results? I recall that it didn't actually have an impact (or lowered MFU in @jonb377's results?). If so, we should add a comment explaining that this may reduce performance, too. And +1 to not enabling it by default.

Collaborator: The lower MFU was from user-mode FSDP; I don't have a head-to-head benchmark on SPMD. @alanwaketan had some data suggesting it improves performance for SPMD.

Collaborator Author: This is the solution on our end to stop the reshape in nn.Linear's matmul.
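For context (not part of this diff): per the thread above, a GSPMD user applies the patch explicitly, since there is no wrapper to do it for them. A minimal sketch of what that call might look like, using the helpers touched in this PR — the model, its dimensions, and the device placement below are illustrative assumptions:

```python
import torch
import torch.nn as nn
import torch_xla.experimental.xla_sharding as xs
from torch_xla.distributed.fsdp.utils import apply_xla_patch_to_nn_linear

# Hypothetical model; any module containing nn.Linear layers would work.
model = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 128))
model = model.to('xla')

# Without an FSDP wrapper, the patch is applied explicitly so that nn.Linear
# forwards go through the einsum-based XLAPatchedLinear instead of matmul.
apply_xla_patch_to_nn_linear(model, xs.xla_patched_nn_linear_forward)

# Forward pass on an XLA tensor; sharding annotations (e.g. from xs.mark_sharding)
# can now propagate through the linear layers without being flattened away.
out = model(torch.randn(4, 3, 128).to('xla'))
```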
39 changes: 39 additions & 0 deletions torch_xla/experimental/xla_sharding.py
@@ -521,3 +521,42 @@ def apply(self, t: torch.Tensor):
    # TODO(yeounoh) use virtual device interface when available.
    assert (t.device == xm.xla_device())
    mark_sharding(t, self.mesh, self.partition_spec)


class XLAPatchedLinear(torch.autograd.Function):
  """
  A patched version of `torch.nn.functional.linear` that uses einsum instead
  of torch.matmul, which would flatten the tensors to 2D and collide the
  sharded dimensions. The default torch.matmul behavior makes it very hard
  for the XLA compiler to propagate the sharding annotation.

  TODO (alanwaketan): Let's patch it on the dispatcher level.
  """

  @staticmethod
  def forward(ctx, input, weight, bias=None):
    # bias is an optional argument
    ctx.save_for_backward(input, weight, bias)
    with torch.no_grad():
      product = torch.einsum('...n,mn->...m', input, weight)
      if bias is None:
        return product
      return product + bias

Review thread on `ctx.save_for_backward(input, weight, bias)`:

Collaborator: Hmm, are we doing explicit gradient checkpointing here?

Collaborator: I'm not familiar with how the nn module is written; I mainly ask because of the torch.no_grad below.

Collaborator Author: No, it's just the convention in an autograd function to save the activations, so this is not gradient checkpointing or re-materialization.

  @staticmethod
  def backward(ctx, grad_output):
    input, weight, bias = ctx.saved_tensors
    grad_input = grad_weight = grad_bias = None

    if ctx.needs_input_grad[0]:
      grad_input = torch.einsum('...m,mn->...n', grad_output, weight)
    if ctx.needs_input_grad[1]:
      grad_weight = torch.einsum('...m,...n->mn', grad_output, input)
    if bias is not None and ctx.needs_input_grad[2]:
      grad_bias = torch.einsum('...m->m', grad_output)

    return grad_input, grad_weight, grad_bias


def xla_patched_nn_linear_forward(m, input):
  return XLAPatchedLinear.apply(input, m.weight, m.bias)
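
As a side note (not in the PR): a minimal CPU-only sketch checking that the einsum formulation used above matches `torch.nn.functional.linear` on a 3D input, and that the backward einsums reduce to the usual linear-layer gradients. The shapes and tolerances here are illustrative assumptions:

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: batch dims (4, 3), in_features=2, out_features=5.
input = torch.randn(4, 3, 2)
weight = torch.randn(5, 2)
bias = torch.randn(5)

# Forward: einsum('...n,mn->...m') is input @ weight.T, i.e. F.linear without bias.
einsum_out = torch.einsum('...n,mn->...m', input, weight) + bias
assert torch.allclose(einsum_out, F.linear(input, weight, bias), atol=1e-6)

# Backward: the einsums in XLAPatchedLinear.backward match the standard gradients.
grad_output = torch.randn(4, 3, 5)
grad_input = torch.einsum('...m,mn->...n', grad_output, weight)
grad_weight = torch.einsum('...m,...n->mn', grad_output, input)
grad_bias = torch.einsum('...m->m', grad_output)

assert torch.allclose(grad_input, grad_output @ weight, atol=1e-6)
assert torch.allclose(grad_weight,
                      grad_output.reshape(-1, 5).t() @ input.reshape(-1, 2),
                      atol=1e-6)
assert torch.allclose(grad_bias, grad_output.sum(dim=(0, 1)), atol=1e-6)
```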