Skip to content

Commit 39c340e

Browse files
ezyang authored and pytorchmergebot committed
Add failing bitwise equivalence UT for aot_eager on rms_norm (#164280)
Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: #164280 Approved by: https://github.com/albanD
1 parent cfd46d1 commit 39c340e

File tree

1 file changed

+22
-0
lines changed

1 file changed

+22
-0
lines changed

test/functorch/test_aotdispatch.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import torch
2929
import torch._dynamo as torchdynamo
3030
import torch.nn as nn
31+
import torch.nn.functional as F
3132
import torch.utils._pytree as pytree
3233
from functorch import grad, jacrev, make_fx, vjp, vmap
3334
from functorch.compile import (
@@ -7199,6 +7200,27 @@ def fn_(x):
71997200
torch.compile(fn, backend="inductor", fullgraph=True)(x)
72007201
torch.compile(fn_, backend="inductor", fullgraph=True)(x)
72017202

7203+
def test_layer_norm(self):
    """Check aot_eager is bitwise-identical to eager for layer_norm."""

    def f(t):
        return F.layer_norm(t, normalized_shape=(8,))

    inp = torch.randn(2, 4, 8)
    expected = f(inp)
    compiled = torch.compile(f, backend="aot_eager")
    # atol=0 / rtol=0 demands exact bitwise equality, not approximate closeness.
    self.assertEqual(expected, compiled(inp), atol=0, rtol=0)
7211+
7212+
@unittest.expectedFailure
@unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
def test_rms_norm(self):
    """Bitwise-equivalence check for rms_norm under aot_eager.

    Marked expectedFailure: only the CUDA rms norm fails to be
    decomposed, so the compiled result diverges bitwise from eager.
    """

    def f(t):
        return F.rms_norm(t, normalized_shape=(8,))

    inp = torch.randn(2, 4, 8, device="cuda")
    expected = f(inp)
    compiled = torch.compile(f, backend="aot_eager")
    # atol=0 / rtol=0 demands exact bitwise equality, not approximate closeness.
    self.assertEqual(expected, compiled(inp), atol=0, rtol=0)
7223+
72027224
def test_subclass_parameters(self):
72037225
class _M(torch.nn.Module):
72047226
def __init__(self):

0 commit comments

Comments
 (0)