From 4bfe2c5be72f7a535a3f3b4c6d9ec87984306ca5 Mon Sep 17 00:00:00 2001 From: Mengwei Liu Date: Mon, 29 Sep 2025 15:19:27 -0700 Subject: [PATCH 1/2] Patch same metadata check to be able to accept tensor subclass without tensor metadata --- torchao/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchao/utils.py b/torchao/utils.py index 2a5857460f..76fb203c2e 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -525,6 +525,8 @@ def _(func, types, args, kwargs): ) def _same_metadata(self: TorchAOBaseTensor, src: TorchAOBaseTensor) -> bool: + if not (hasattr(self, "tensor_data_names") and hasattr(src, "tensor_data_names")): + return False _tensor_shape_match = all( getattr(self, t_name).shape == getattr(src, t_name).shape for t_name in self.tensor_data_names From 54de795a49b8e5af30c5d84a4f2ba8c1215d99b9 Mon Sep 17 00:00:00 2001 From: Mengwei Liu Date: Mon, 29 Sep 2025 15:22:24 -0700 Subject: [PATCH 2/2] more patch --- torchao/utils.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/torchao/utils.py b/torchao/utils.py index 76fb203c2e..d75dfd22fc 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -566,11 +566,16 @@ def _same_metadata(self: TorchAOBaseTensor, src: TorchAOBaseTensor) -> bool: def _(func, types, args, kwargs): self = args[0] src = args[1] - if _same_metadata(self, src): + has_self_meta = hasattr(self, "tensor_data_names") + has_src_meta = hasattr(src, "tensor_data_names") + if has_self_meta and has_src_meta and _same_metadata(self, src): self_tensors = self.__tensor_flatten__()[0] for tensor_name in self_tensors: getattr(self, tensor_name).copy_(getattr(src, tensor_name)) return + if not (has_self_meta and has_src_meta): + with torch._C._DisableTorchDispatch(): + return func(*args, **kwargs) raise ValueError( f"Not supported args for copy_ due to metadata mismatch: {args[0], args[1]}" )