Skip to content
This repository was archived by the owner on Aug 1, 2025. It is now read-only.
This repository was archived by the owner on Aug 1, 2025. It is now read-only.

[aot_eager] [accuracy] [tts_angular] Dtype mismatch #1889

@anijain2305

Description

@anijain2305

🐛 Describe the bug

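Accuracy failure - Dtype mismatch: under autocast, the output dtype of the compiled module differs from the output dtype of the eager module.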

Minified repro

from math import inf
import torch
from torch import tensor, device
import torch.fx as fx
import torch._dynamo
from torch._dynamo.testing import rand_strided
from torch._dynamo.debug_utils import run_fwd_maybe_bwd
from torch._dynamo.debug_utils import same_two_models

# REPLACEABLE COMMENT FOR TESTING PURPOSES

# Input spec for the repro: a single entry unpacked below as
# (sh, st, dt, dev, rg) -- presumably shape, stride, dtype, device, requires_grad.
args = [((2, 50, 256), (12800, 256, 1), torch.float16, 'cuda', False)]
# Turn each spec into a tensor with rand_strided, then apply the requires_grad flag.
args = [rand_strided(sh, st, dt, dev).requires_grad_(rg) for (sh, st, dt, dev, rg) in args]


from torch.nn import *
class Repro(torch.nn.Module):
    """Minified repro module: take index -1 along dim 1, then L2-normalize."""

    def __init__(self):
        super().__init__()

    def forward(self, _stack0 : torch.Tensor):
        """Select the last entry along dim 1 of ``_stack0`` and normalize it.

        The selection is normalized with ``p=2`` over ``dim=1`` of the
        selected tensor. The result is returned wrapped in a 1-tuple.
        """
        # `[:, -1]` builds the same index tuple as (slice(None, None, None), -1).
        last_entry = _stack0[:, -1]
        unit = torch.nn.functional.normalize(last_entry, p=2, dim=1)
        return (unit,)



# Reference (eager) module.
mod = Repro()
# Same module instance wrapped by TorchDynamo with the "aot_inductor_debug" backend.
opt_mod = torch._dynamo.optimize("aot_inductor_debug")(mod)
# Alternative backend, left disabled:
# opt_mod = torch._dynamo.optimize("aot_eager")(mod)


# Put both the eager and the optimized module into eval mode before comparing.
mod.eval()
opt_mod.eval()

class AccuracyError(Exception):
    """Exception type for accuracy failures.

    Not raised anywhere in the visible script; the dtype check below uses a
    plain ``assert`` instead.
    """

# Run the eager and the optimized module on the same inputs under CUDA autocast
# and compare only the dtypes of their outputs (values are not compared).
with torch.cuda.amp.autocast(enabled=True):
    # Each module returns a 1-tuple; [0] extracts the tensor.
    ref = mod(*args)[0]
    res = opt_mod(*args)[0]
    print(f"{ref.dtype}, {res.dtype}")
    # Fails with both dtypes in the message when the optimized output dtype
    # differs from the eager one.
    assert ref.dtype ==  res.dtype, f"{ref.dtype}, {res.dtype}"





Metadata

Metadata

Labels

bug: Something isn't working

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions