77 changes: 76 additions & 1 deletion test/test_gmm.py
@@ -7,7 +7,7 @@
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
-from torch_xla.experimental.custom_kernel import gmm, _make_group_metadata, _histogram, tgmm, gmm_backward
+from torch_xla.experimental.custom_kernel import gmm, _make_group_metadata, _histogram, tgmm, gmm_backward, GMM
from torch_xla import runtime as xr
from torch_xla._internal import tpu

@@ -374,6 +374,81 @@ def test_gmm_backward(self):
    # Make sure gmm doesn't fallback.
    self.assertNotIn("aten::", met.short_metrics_report())

  @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")

Collaborator: do you need a TPU version check here?

Collaborator (Author): lol, good question.

Collaborator (Author): v2 is pretty happy on the tree.

Collaborator: Interesting... I thought Pallas is not supported on v2.

Collaborator (Author): gmm is just mm... the kernel is simple. Others fuse softmax, etc.
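(For context on the thread above: gmm is a grouped matrix multiplication over contiguous row-blocks of lhs, which, per the author's comment, is why the kernel is simple enough to run on v2. The plain-PyTorch sketch below shows what the operation computes; the helper name and the assumption that group_sizes partitions the rows of lhs mirror how the test's _reference_gmm helper is used and are illustrative, not the actual Pallas kernel.)

import torch

def reference_gmm(lhs, rhs, group_sizes):
  # lhs: [m, k], rhs: [num_groups, k, n], group_sizes: [num_groups] summing to m.
  # Multiply each contiguous row-block of lhs by its group's rhs slice.
  start = 0
  outs = []
  for g, size in enumerate(group_sizes.tolist()):
    outs.append(lhs[start:start + size] @ rhs[g])
    start += size
  return torch.cat(outs, dim=0)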

  def test_gmm_backward_2(self):
    self._init_test_cases()
    for test_case in self.tests_cases:
      num_groups = test_case['num_groups']
      k = test_case['k']
      m = test_case['m']
      n = test_case['n']
      lhs_dtype = rhs_dtype = torch.bfloat16

      torch.manual_seed(42)
      lhs = torch.rand(m, k, dtype=lhs_dtype, requires_grad=True)
      rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype, requires_grad=True)
      group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
      lhs.retain_grad()
      rhs.retain_grad()

      ref_out = self._reference_gmm(lhs, rhs, group_sizes)
      ref_out.sum().backward()

      torch.manual_seed(42)
      lhs_xla = torch.rand(m, k, dtype=lhs_dtype, requires_grad=True).to("xla")
      rhs_xla = torch.rand(
          num_groups, k, n, dtype=rhs_dtype, requires_grad=True).to("xla")
      lhs_xla.retain_grad()
      rhs_xla.retain_grad()

      out = GMM.apply(lhs_xla, rhs_xla, group_sizes.to("xla"))
      out.sum().backward()

      self.assertTrue(torch.allclose(ref_out, out.cpu()))
      self.assertTrue(torch.allclose(lhs.grad, lhs_xla.grad.cpu()))
      self.assertTrue(torch.allclose(rhs.grad, rhs_xla.grad.cpu()))

    # Make sure gmm doesn't fallback.
    self.assertNotIn("aten::", met.short_metrics_report())

  @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
  def test_gmm_backward_3(self):
    self._init_test_cases()
    for test_case in self.tests_cases:
      num_groups = test_case['num_groups']
      k = test_case['k']
      m = test_case['m']
      n = test_case['n']
      lhs_dtype = rhs_dtype = torch.bfloat16

      torch.manual_seed(42)
      lhs = torch.rand(m, k, dtype=lhs_dtype, requires_grad=True)
      rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype, requires_grad=True)
      group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
      lhs.retain_grad()
      rhs.retain_grad()

      ref_out = self._reference_gmm(lhs, rhs, group_sizes)
      ref_out.sum().backward()

      torch.manual_seed(42)
      lhs_xla = torch.rand(m, k, dtype=lhs_dtype, requires_grad=True).to("xla")
      rhs_xla = torch.rand(
          num_groups, k, n, dtype=rhs_dtype, requires_grad=True).to("xla")
      lhs_xla.retain_grad()
      rhs_xla.retain_grad()

      out = GMM.apply(lhs_xla, rhs_xla, group_sizes.to("xla"))
      grad_out = torch.ones_like(out)
      torch.autograd.backward([out], [grad_out, lhs_xla, rhs_xla])

      self.assertTrue(torch.allclose(ref_out, out.cpu()))
      self.assertTrue(torch.allclose(lhs.grad, lhs_xla.grad.cpu()))
      self.assertTrue(torch.allclose(rhs.grad, rhs_xla.grad.cpu()))

    # Make sure gmm doesn't fallback.
    self.assertNotIn("aten::", met.short_metrics_report())


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
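(A reading of the two new tests, not part of the diff: test_gmm_backward_2 drives the backward pass through out.sum().backward(), while test_gmm_backward_3 supplies an explicit grad_out of ones through torch.autograd.backward. Because the gradient of sum(out) with respect to out is all ones, both paths hand GMM.backward the same incoming gradient. A minimal CPU-only sketch of that equivalence:)

import torch

# sum().backward() and an explicit all-ones grad_output produce the same gradients.
x = torch.rand(3, 4, requires_grad=True)
y = x * 2.0
g1 = torch.autograd.grad(y.sum(), x, retain_graph=True)[0]
g2 = torch.autograd.grad(y, x, grad_outputs=torch.ones_like(y))[0]
assert torch.allclose(g1, g2)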
16 changes: 16 additions & 0 deletions torch_xla/experimental/custom_kernel.py
@@ -825,6 +825,22 @@ def gmm_backward(grad, lhs, rhs, group_sizes, tiling=(512, 512, 512)):
  return grad_lhs, grad_rhs


class GMM(torch.autograd.Function):

  @staticmethod
  def forward(ctx, lhs, rhs, group_sizes, tiling=(512, 512, 512)):
    ctx.save_for_backward(lhs, rhs, group_sizes)
    ctx.tiling = tiling
    return gmm(lhs, rhs, group_sizes, tiling)

  @staticmethod
  def backward(ctx, grad_output):
    lhs, rhs, group_sizes = ctx.saved_tensors
    grad_lhs, grad_rhs = gmm_backward(grad_output, lhs, rhs, group_sizes,
                                      ctx.tiling)
    return grad_lhs, grad_rhs, None, None


def non_xla_attetion(q, k, v, attention_type):
  # This will be called when dynamo uses fake tensors to construct the fake output.
  # We need to make sure the output tensor's shape is correct.
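(A hedged usage sketch of the new GMM autograd wrapper, modeled on the tests above. The shapes, dtypes, and group_sizes values are illustrative assumptions; only GMM.apply(lhs, rhs, group_sizes) and the trailing None gradients for group_sizes/tiling come from the diff itself.)

import torch
from torch_xla.experimental.custom_kernel import GMM

# Shapes, dtypes, and group sizes below are made up for illustration.
lhs = torch.rand(512, 128, dtype=torch.bfloat16, requires_grad=True).to("xla")
rhs = torch.rand(4, 128, 256, dtype=torch.bfloat16, requires_grad=True).to("xla")
lhs.retain_grad()
rhs.retain_grad()
group_sizes = torch.tensor([128, 128, 128, 128], dtype=torch.int32).to("xla")

out = GMM.apply(lhs, rhs, group_sizes)  # forward calls gmm(...)
out.sum().backward()                    # backward calls gmm_backward(...)
# group_sizes and tiling are non-differentiable, hence the two trailing None
# gradients returned by GMM.backward.
print(lhs.grad.shape, rhs.grad.shape)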