diff --git a/torchao/dtypes/nf4tensor.py b/torchao/dtypes/nf4tensor.py index 5542a9de58..4081bcf1d1 100644 --- a/torchao/dtypes/nf4tensor.py +++ b/torchao/dtypes/nf4tensor.py @@ -1129,6 +1129,26 @@ def _(*args, **kwargs): return out +@implements_torch_function(torch.Tensor.view_as) +def function_view_as(*args, **kwargs): + """Handle view_as for NF4Tensor. + + When view_as is called (typically by autograd internals), we need to return + a fresh NF4Tensor without autograd metadata to avoid conflicts. + """ + tensor = args[0] + + # Create a new NF4Tensor with detached inner tensors to avoid autograd conflicts + updated_attrs = {} + tensor_attrs, _ = tensor.__tensor_flatten__() + for attr in tensor_attrs: + inner_tensor = getattr(tensor, attr) + # Detach to create a fresh tensor without autograd metadata + updated_attrs[attr] = inner_tensor.detach() + + return NF4Tensor(*construct_nf4_args(tensor, updated_attrs)) + + @torch._dynamo.allow_in_graph def nf4_constructor( tensor_meta: SubclassTensorArgs,