
tensor not a FakeTensor under FakeTensorMode and device('meta') #139092

@carmocca

Description


🐛 Describe the bug

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode(), torch.device("meta"):
    a = torch.tensor(3.0)  # not a FakeTensor
    b = torch.full(size=(), fill_value=3.0)  # workaround

print(a)  # tensor(..., device='meta', size=())
print(b)  # FakeTensor(..., device='meta', size=())
```

You would expect both calls to return FakeTensors.
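A further workaround (a sketch, assuming `FakeTensorMode.from_tensor` behaves as in current builds) is to build the scalar as a real tensor outside the mode and convert it explicitly, rather than relying on the mode to intercept `torch.tensor()`:

```python
import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode

fake_mode = FakeTensorMode()

# Construct a real CPU tensor first, then fake-ify it explicitly.
real = torch.tensor(3.0)
fake = fake_mode.from_tensor(real)

print(isinstance(fake, FakeTensor))  # True
```

This sidesteps the constructor interception entirely, at the cost of one real allocation.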

This issue causes a downstream assertion error:

```python
raise AssertionError(
    f"Please convert all Tensors to FakeTensors first or instantiate FakeTensorMode "
    f"with 'allow_non_fake_inputs'. Found in {render_call(func, args, kwargs)}"
)
```

in scripts that rely on this machinery.
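The error message itself points to the other escape hatch: constructing the mode with `allow_non_fake_inputs=True`, which lets real tensors be faked implicitly when an op touches them. A minimal sketch:

```python
import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode

x = torch.randn(2, 2)  # real tensor created outside the mode

# Without allow_non_fake_inputs=True, dispatching an op on x inside
# the mode raises the AssertionError above; with it, x is converted
# on the fly and the op returns a FakeTensor.
with FakeTensorMode(allow_non_fake_inputs=True):
    y = x + 1

print(isinstance(y, FakeTensor))  # True
```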

Versions

torch==2.6.0.dev20241018+cu118

cc @ezyang @chauhang @penguinwu @eellison @zou3519 @bdhirsh @yf225

Metadata

Labels: actionable, module: fakeTensor, module: pt2-dispatcher, oncall: pt2, triaged
