@@ -28,6 +28,7 @@
 import torch
 import torch._dynamo as torchdynamo
 import torch.nn as nn
+import torch.nn.functional as F
 import torch.utils._pytree as pytree
 from functorch import grad, jacrev, make_fx, vjp, vmap
 from functorch.compile import (
@@ -7199,6 +7200,27 @@ def fn_(x):
         torch.compile(fn, backend="inductor", fullgraph=True)(x)
         torch.compile(fn_, backend="inductor", fullgraph=True)(x)
 
+    def test_layer_norm(self):
+        def fn(x):
+            return F.layer_norm(x, normalized_shape=(8,))
+
+        x = torch.randn(2, 4, 8)
+        eager = fn(x)
+        aot_eager = torch.compile(backend="aot_eager")(fn)(x)
+        self.assertEqual(eager, aot_eager, atol=0, rtol=0)
+
+    @unittest.expectedFailure
+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
+    def test_rms_norm(self):
+        # Only CUDA rms norm fails to be decomposed
+        def fn(x):
+            return F.rms_norm(x, normalized_shape=(8,))
+
+        x = torch.randn(2, 4, 8, device="cuda")
+        eager = fn(x)
+        aot_eager = torch.compile(backend="aot_eager")(fn)(x)
+        self.assertEqual(eager, aot_eager, atol=0, rtol=0)
+
     def test_subclass_parameters(self):
         class _M(torch.nn.Module):
             def __init__(self):