from math import inf
import torch
from torch import tensor, device
import torch.fx as fx
import torch._dynamo
from torch._dynamo.testing import rand_strided
from torch._dynamo.debug_utils import run_fwd_maybe_bwd
from torch._dynamo.debug_utils import same_two_models
# REPLACEABLE COMMENT FOR TESTING PURPOSES
# Input specs: one (shape, stride, dtype, device, requires_grad) tuple per
# graph input, materialized below as strided random tensors.
arg_specs = [((2, 50, 256), (12800, 256, 1), torch.float16, 'cuda', False)]
args = [
    rand_strided(shape, stride, dtype, dev).requires_grad_(grad)
    for (shape, stride, dtype, dev, grad) in arg_specs
]
from torch.nn import *
class Repro(torch.nn.Module):
    """Minimal repro module: take the last slice of the input along dim 1
    and L2-normalize it along dim 1 of the result.

    forward returns a 1-tuple, matching the captured FX graph's output
    convention.
    """

    def __init__(self):
        super().__init__()

    def forward(self, _stack0 : torch.Tensor):
        # Equivalent to _stack0[(slice(None, None, None), -1)]: keep all of
        # dim 0, pick the final index of dim 1.
        last = _stack0[:, -1]
        _stack0 = None  # drop the reference, mirroring the generated graph
        # Unit L2 norm per row.
        out = torch.nn.functional.normalize(last, p=2, dim=1)
        return (out,)
mod = Repro()
# Compile with the aot_inductor_debug backend; `mod` stays eager and serves
# as the accuracy reference.
opt_mod = torch._dynamo.optimize("aot_inductor_debug")(mod)
# opt_mod = torch._dynamo.optimize("aot_eager")(mod)
# Put both the eager and the compiled module into inference mode.
for module in (mod, opt_mod):
    module.eval()
class AccuracyError(Exception):
    """Raised when the compiled module's output diverges from eager mode."""
# Run the reference (eager) and compiled modules under autocast and verify
# the compiled path preserves the output dtype.
# NOTE(review): torch.cuda.amp.autocast is the legacy spelling (newer torch
# prefers torch.amp.autocast("cuda", ...)); kept as-is for repro fidelity.
with torch.cuda.amp.autocast(enabled=True):
    ref = mod(*args)[0]
    res = opt_mod(*args)[0]
print(f"{ref.dtype}, {res.dtype}")
# Raise AccuracyError (defined above for exactly this purpose) instead of a
# bare assert: asserts are stripped under `python -O`, silently skipping the
# whole point of the repro.
if ref.dtype != res.dtype:
    raise AccuracyError(f"{ref.dtype}, {res.dtype}")