Description
Pardon the naive question; I'm trying to understand how to implement a basic tensor subclass.
The problem I'm encountering is that the tensor subclass loses its attributes after calling torch.save on a state dict containing the subclass, likely due to the use of swap_tensors.
Minimal repro:
from io import BytesIO

import torch
from torch._ops import OpOverload
from torchao.dtypes.nf4tensor import _INNER_TENSOR_NAMES_FOR_SHARDING, NF4Tensor, to_nf4

aten = torch.ops.aten


class SimpleTensor(torch.Tensor):
    """Minimal wrapper subclass holding a single dense tensor in `inner_tensor`."""

    @staticmethod
    def __new__(cls, inner_tensor, *args, **kwargs):
        kwargs["device"] = inner_tensor.device
        kwargs["layout"] = inner_tensor.layout
        kwargs["dtype"] = inner_tensor.dtype
        kwargs["requires_grad"] = inner_tensor.requires_grad
        print(f"New SimpleTensor: {kwargs}")
        return torch.Tensor._make_wrapper_subclass(cls, inner_tensor.shape, **kwargs)  # type: ignore[attr-defined]

    def __init__(self, inner_tensor, *args, **kwargs):
        self.inner_tensor = inner_tensor

    def __repr__(self):
        return f"SimpleTensor({self.inner_tensor.shape})"

    def __tensor_flatten__(self):
        return ["inner_tensor"], None

    @staticmethod
    def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride):
        return SimpleTensor(inner_tensors["inner_tensor"])

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = {} if kwargs is None else kwargs
        try:
            print(f"calling {func.__name__} with args: {[type(arg) for arg in args]} and kwargs: {kwargs}")
            with torch._C.DisableTorchFunctionSubclass():
                return func(*args, **kwargs)
        except Exception as e:
            print(f"ERR: subclass doesn't implement {func}")
            raise e

    def __torch_dispatch__(self, func: OpOverload, types, args=(), kwargs=None):
        FUNCS = [aten.detach.default, aten.copy_.default]
        print(f"dispatching {func._schema.name} {func._opname} {func._overloadname} with {len(args)} args: {[type(arg) for arg in args]} and kwargs: {kwargs}")
        print(f"Func in implemented funcs: {func in FUNCS}")
        if func is torch.ops.aten.detach.default:
            print(f"returning {args[0]}")
            return args[0]
        if func is aten.copy_.default:
            print(f"copying {args[0]} to {args[1]}")
            original = args[0]
            copy_in = args[1]
            original.inner_tensor.copy_(copy_in.inner_tensor)
            return
        return func(*args, **kwargs)


torch.serialization.add_safe_globals([SimpleTensor])
###
dtype = torch.bfloat16
device = "cuda"
batch_size = 2
in_features = 256
out_features = 128
original_tensor = torch.randn(out_features, in_features, dtype=dtype, device=device)

print("\n=================== SimpleTensor =================================\n")
simple_tensor = SimpleTensor(original_tensor)
try:
    print(f"Simple tensor: {simple_tensor.inner_tensor.shape}")
except Exception as e:
    print(f"Simple tensor error: {e}")

torch.utils.swap_tensors(original_tensor, simple_tensor)
try:
    print(f"Swapped tensor: {original_tensor.inner_tensor.shape}")
except Exception as e:
    print(f"Swapped tensor error: {e}")

buffer = BytesIO()
torch.save({"weight": original_tensor}, buffer)
buffer.seek(0)
try:
    state_dict = torch.load(buffer)
except Exception as e:
    print(f"State load error: {e}")
try:
    restored_tensor = state_dict['weight']
    print(f"Restored tensor: {restored_tensor.inner_tensor.shape}")
except Exception as e:
    print(f"Restored tensor error: {e}")

print("\n=================== NF4Tensor =================================\n")
original_tensor = torch.randn(out_features, in_features, dtype=dtype, device=device)
nf4_tensor = to_nf4(original_tensor)
try:
    for name in _INNER_TENSOR_NAMES_FOR_SHARDING:
        print(f"NF4 tensor {name}: {getattr(nf4_tensor, name).shape}")
except Exception as e:
    print(f"NF4 tensor error: {e}")

torch.utils.swap_tensors(original_tensor, nf4_tensor)
try:
    for name in _INNER_TENSOR_NAMES_FOR_SHARDING:
        print(f"Swapped tensor {name}: {getattr(original_tensor, name).shape}")
except Exception as e:
    print(f"Swapped tensor error: {e}")

buffer = BytesIO()
torch.save({"weight": original_tensor}, buffer)
buffer.seek(0)
state_dict = torch.load(buffer)
try:
    restored_tensor = state_dict['weight']
    for name in _INNER_TENSOR_NAMES_FOR_SHARDING:
        print(f"State dict {name}: {getattr(restored_tensor, name).shape}")
except Exception as e:
print(f"State dict error: {e}")Running the above gives the following prints an error while loading the state dict for SimpleTensor with weights_only=True even after registering SimpleTensor as safe (torch.serialization.add_safe_globals([SimpleTensor])):
State load error: Weights only load failed. In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
Please file an issue with the following so that we can make `weights_only=True` compatible with your use case: WeightsUnpickler error: Unsupported operand 48
If I set weights_only=False, the state dict loads and the tensor comes back as a SimpleTensor, but accessing its attribute gives the following error:
Restored tensor error: 'SimpleTensor' object has no attribute 'inner_tensor'
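
A crude workaround that should sidestep this (sketch below, reusing the variables from the repro; presumably not the intended pattern) is to save only the plain inner tensor and re-wrap it after loading:

# Workaround sketch: serialize the plain inner tensor and rebuild the subclass on load.
plain = torch.randn(out_features, in_features, dtype=dtype, device=device)
wrapped = SimpleTensor(plain)

buffer = BytesIO()
torch.save({"weight": wrapped.inner_tensor}, buffer)  # save the plain tensor, not the wrapper
buffer.seek(0)
loaded = torch.load(buffer)  # plain tensors load fine with the default weights_only=True
rewrapped = SimpleTensor(loaded["weight"])  # re-wrap manually
print(rewrapped.inner_tensor.shape)

That at least suggests it's only the wrapper-level attributes that get dropped, not the underlying data.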
NF4Tensor, on the other hand, saves and loads just fine.
Are there particular ops that need to be implemented in order to serialize a subclass?
The issue, I think, arises from the use of swap_tensors, which I've seen used in torchtune here and mentioned as needed when loading subclasses with multiple wrapped tensors.
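
For context, the swapping pattern I'm referring to looks roughly like this (paraphrased from memory rather than the exact torchtune code; the nn.Linear here is just a stand-in for the real model):

import torch.nn as nn

# Stand-in module; in the real setup this would be a layer of the actual model.
linear = nn.Linear(in_features, out_features, bias=False, dtype=dtype, device=device)

# Quantize an incoming full-precision weight and swap it into the existing parameter,
# so anything already holding a reference to linear.weight sees the subclass afterwards.
checkpoint_weight = torch.randn(out_features, in_features, dtype=dtype, device=device)
new_param = nn.Parameter(to_nf4(checkpoint_weight), requires_grad=False)
torch.utils.swap_tensors(linear.weight, new_param)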
This is with torch 2.6.
Thanks!