Skip to content

Commit

Permalink
Fix torch.load(..., weights_only=True) for NT (#112516)
Browse files Browse the repository at this point in the history
Found when looking into #112509
Pull Request resolved: #112516
Approved by: https://github.com/soulitzer
  • Loading branch information
jbschlosser authored and pytorchmergebot committed Nov 2, 2023
1 parent 85e9363 commit 51a3838
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 3 deletions.
5 changes: 3 additions & 2 deletions test/test_nestedtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2995,7 +2995,8 @@ def test_split_with_sizes(self, device):

@dtypes(torch.float, torch.double, torch.half)
@parametrize("requires_grad", [False, True])
def test_serialization(self, device, dtype, requires_grad):
@parametrize("weights_only", [False, True])
def test_serialization(self, device, dtype, requires_grad, weights_only):

def compare_metadata(nt1, nt2):
self.assertEqual(nt1._nested_tensor_size(), nt2._nested_tensor_size())
Expand All @@ -3008,7 +3009,7 @@ def compare_metadata(nt1, nt2):
buffer = io.BytesIO()
serialized = torch.save(a, buffer)
buffer.seek(0)
b = torch.load(buffer)
b = torch.load(buffer, weights_only=weights_only)
# should be both conceptually equal and metadata equivalent
self.assertEqual(a, b)
compare_metadata(a, b)
Expand Down
2 changes: 1 addition & 1 deletion torch/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ def _reduce_ex_internal(self, proto):
self._nested_tensor_strides(),
self._nested_tensor_storage_offsets(),
)
return (torch._nested_view_from_buffer, args_nested)
return (torch._utils._rebuild_nested_tensor, args_nested)
elif (
self.data_ptr() == 0
and type(self) is not torch.Tensor
Expand Down
4 changes: 4 additions & 0 deletions torch/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,10 @@ def _rebuild_sparse_tensor(layout, data):
raise NotImplementedError(f"rebuilding sparse tensor for layout {layout}")


def _rebuild_nested_tensor(buffer, sizes, strides, storage_offsets):
return torch._nested_view_from_buffer(buffer, sizes, strides, storage_offsets)


def _rebuild_device_tensor_from_numpy(data, dtype, device, requires_grad):
tensor = torch.from_numpy(data).to(dtype=dtype, device=device)
tensor.requires_grad = requires_grad
Expand Down
1 change: 1 addition & 0 deletions torch/_weights_only_unpickler.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def _get_allowed_globals():
torch._utils._rebuild_tensor_v2,
torch._utils._rebuild_sparse_tensor,
torch._utils._rebuild_meta_tensor_no_storage,
torch._utils._rebuild_nested_tensor,
]:
rc[f"torch._utils.{f.__name__}"] = f

Expand Down

0 comments on commit 51a3838

Please sign in to comment.