Skip to content

Dynamo treats dataclasses as UserDefinedVariable, prevents proxying into graph #133858

@bdhirsh

Description

@bdhirsh

One restriction around authoring subclasses today is that if you want to construct a subclass in-graph:

(1) if you have some constant metadata on your subclass in the form of a NamedTuple, this is allowed. Dynamo can proxy the constructor into the graph, with a NamedTuple as one of the input nodes to the constructor

(2) Using a dataclass instead of a NamedTuple will break. This is because dynamo ends up treating the dataclass as a UserDefinedObject, and dynamo graph breaks when we try to proxy it into the graph.

This originally came from FSDP2 + NF4 repro here. I tried making a smaller repro below:

import torch
import torch.utils._pytree as pytree
from torch.utils._python_dispatch import return_and_correct_aliasing
from dataclasses import dataclass
from collections import namedtuple
from typing import Tuple

@dataclass
class SubclassTensorArgs:
    original_shape: torch.Size
    original_strides: Tuple
    storage_offset: int
    dtype: torch.dtype
    device: torch.device

SubclassTensorArgs2 = namedtuple("SubclassTensorArgs2", [
    "original_shape",
    "original_strides",
    "storage_offset",
    "dtype",
    "device",
])


# A simple tensor subclass that holds two tensors internally, and runs every op on both tensors.
class TwoTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, a, b, meta):
        assert (
            a.device == b.device
            and a.layout == b.layout
            and a.requires_grad == b.requires_grad
            and a.dtype == b.dtype
        )
        # I guess it would be more accurate to represent the shape as torch.cat(a, b).shape
        shape = a.shape
        kwargs = {}
        kwargs["strides"] = a.stride()
        kwargs["storage_offset"] = a.storage_offset()
        kwargs["device"] = a.device
        kwargs["layout"] = a.layout
        kwargs["requires_grad"] = a.requires_grad
        kwargs["dtype"] = a.dtype
        out = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)

        assert a.shape == b.shape
        assert a.stride() == b.stride()
        assert a.storage_offset() == b.storage_offset()
        return out

    def __init__(self, a, b, meta):
        self.a = a
        self.b = b
        self.meta = meta

    def __repr__(self):
        a_repr = repr(self.a)
        b_repr = repr(self.b)
        return f"TwoTensor({a_repr}, {b_repr}, {repr(self.meta)})"

    def __tensor_flatten__(self):
        return ["a", "b"], self.meta

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
        a, b = inner_tensors["a"], inner_tensors["b"]
        return TwoTensor(a, b, meta)

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        if kwargs is None:
            kwargs = {}
        args_a = pytree.tree_map_only(TwoTensor, lambda x: x.a, args)
        args_b = pytree.tree_map_only(TwoTensor, lambda x: x.b, args)

        kwargs_a = pytree.tree_map_only(TwoTensor, lambda x: x.a, kwargs)
        kwargs_b = pytree.tree_map_only(TwoTensor, lambda x: x.b, kwargs)

        out_a = func(*args_a, **kwargs_a)
        out_b = func(*args_b, **kwargs_b)
        out_a_flat, spec = pytree.tree_flatten(out_a)
        out_b_flat = pytree.tree_leaves(out_b)
        # for aten ops that return non-tensors, just assume that
        # our two inner tensors return the same value
        out_flat = [
            TwoTensor(o_a, o_b) if isinstance(o_a, torch.Tensor) else o_a
            for o_a, o_b in zip(out_a_flat, out_b_flat)
        ]
        out = pytree.tree_unflatten(out_flat, spec)
        return out


@torch._dynamo.allow_in_graph
def create_two_tensor(a, b, meta):
    return TwoTensor(a, b, meta)

@torch.compile(backend="aot_eager")
def f_bad(x_a):
    #meta = SubclassTensorArgs(x_a.shape, x_a.stride(), x_a.storage_offset(), x_a.dtype, x_a.device)
    meta = SubclassTensorArgs(x_a.shape, x_a.stride(), x_a.dtype, x_a.device)
    x = create_two_tensor(x_a, x_a.clone(), meta)
    x_a * x_a

@torch.compile(backend="aot_eager")
def f_good(x_a):
    meta = SubclassTensorArgs2(x_a.shape, x_a.stride(), x_a.storage_offset(), x_a.dtype, x_a.device)
    x = create_two_tensor(x_a, x_a.clone(), meta)
    x_a * x_a

x_a = torch.randn(4)
# namedtuple is ok
#out = f_good(x_a)
# dataclass not ok
out = f_bad(x_a)

cc @Chillee @ezyang @zou3519 @albanD @samdow @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames

Metadata

Metadata

Assignees

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions