Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions test/dynamo/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -9038,6 +9038,65 @@ def fn(img):
self.assertEqual(msg, "shape torch.Size([8, 8]) batch size 1.00")
self.assertEqual(res, img1 + torch.sin(img1))

def test_sourceless_namedtuple(self):
from collections import namedtuple

CustomDtype = namedtuple("CustomDtype", ["dtype", "higher_dtype"])

class CustomTensor(torch.Tensor):
_data: torch.Tensor
custom_dtype: CustomDtype
__torch_function__ = torch._C._disabled_torch_function_impl
__slots__ = [
"_data",
"custom_dtype",
]

def __new__(
cls,
data: torch.Tensor,
custom_dtype: CustomDtype,
):
self = torch.Tensor._make_wrapper_subclass(
cls,
data.size(),
strides=data.stride(),
storage_offset=data.storage_offset(),
dtype=custom_dtype.dtype,
layout=data.layout,
requires_grad=data.requires_grad,
device=data.device,
)
self._data = data
self.custom_dtype = custom_dtype
return self

def __tensor_flatten__(self):
meta = {
"custom_dtype": self.custom_dtype,
}
return ["_data"], meta

@staticmethod
def __tensor_unflatten__(
inner_tensors: dict, metadata, outer_size, outer_stride
):
return CustomTensor(
inner_tensors["_data"],
metadata["custom_dtype"],
)

@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs={}):
return func(*args, **kwargs)

@torch.compile(backend="eager", fullgraph=True)
def fn(x):
y = CustomTensor(x, CustomDtype(torch.float32, torch.bfloat16))
return y, y.custom_dtype

fn(torch.ones(2, 2, device="cpu"))

# Compiling autograd.Function traces fwd function twice, but the same unbacked symints were not identified
# as the same across the two tracings. This is an unlikely situation in real use cases, so we add another
# `test_validate_outputs_unbacked_by_custom_op` to mitigate it and keep this one as expected failure
Expand Down
6 changes: 6 additions & 0 deletions torch/_dynamo/variables/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3328,6 +3328,12 @@ def create(tx: "InstructionTranslator", value) -> VariableTracker:
)
elif isinstance(value, types.GenericAlias):
return TypingVariable(value)
elif is_namedtuple(value):
output = [
SourcelessBuilder.create(tx, getattr(value, name))
for name in namedtuple_fields(type(value))
]
return NamedTupleVariable(output, tuple_cls=type(value))
unimplemented_v2(
gb_type="Unexpected type in sourceless builder",
context=f"{value_type.__module__}.{value_type.__qualname__}",
Expand Down
Loading