Labels
actionable, module: fakeTensor, module: pt2-dispatcher (PT2 dispatcher-related issues, e.g. aotdispatch, functionalization, faketensor, custom-op), oncall: pt2, triaged (this issue has been looked at by a team member, triaged, and prioritized into an appropriate module)
Description
🐛 Describe the bug
```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode(), torch.device("meta"):
    a = torch.tensor(3.0)                         # not a FakeTensor
    b = torch.full(size=tuple(), fill_value=3.0)  # workaround

print(a)  # tensor(..., device='meta', size=())
print(b)  # FakeTensor(..., device='meta', size=())
```

You would expect both to become FakeTensors.
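For reference, the `torch.full` workaround can be checked directly. This is a minimal sketch assuming a build where `FakeTensorMode` and `FakeTensor` are importable from `torch._subclasses.fake_tensor`:

```python
import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode

with FakeTensorMode(), torch.device("meta"):
    # torch.full goes through the regular factory-function dispatch path,
    # so FakeTensorMode intercepts it and returns a FakeTensor
    b = torch.full(size=(), fill_value=3.0)

assert isinstance(b, FakeTensor)
```

`torch.tensor(3.0)` under the same mode takes a different constructor path, which is why it escapes the mode and stays a plain meta tensor.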
This discrepancy causes a downstream assertion error when the non-fake tensor is later passed to an op under the mode:

`pytorch/torch/_subclasses/fake_tensor.py`, lines 2230 to 2233 in 01b055a:

```python
raise AssertionError(
    f"Please convert all Tensors to FakeTensors first or instantiate FakeTensorMode "
    f"with 'allow_non_fake_inputs'. Found in {render_call(func, args, kwargs)}"
)
```
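As the error message suggests, there are two escape hatches besides the `torch.full` workaround. This is a sketch, assuming the `FakeTensorMode.from_tensor` method and the `allow_non_fake_inputs` constructor flag behave as their names indicate:

```python
import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode

real = torch.tensor(3.0)  # a real tensor created outside any mode

# Option 1: convert real tensors explicitly before using them under the mode
mode = FakeTensorMode()
fake = mode.from_tensor(real)  # explicit real -> fake conversion
assert isinstance(fake, FakeTensor)

# Option 2: let the mode auto-convert non-fake inputs instead of raising
with FakeTensorMode(allow_non_fake_inputs=True):
    out = real + 1.0  # real input is faked on the fly
assert isinstance(out, FakeTensor)
```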
Versions

`torch==2.6.0.dev20241018+cu118`
cc @ezyang @chauhang @penguinwu @eellison @zou3519 @bdhirsh @yf225