Skip to content

Commit

Permalink
move TypedStorage handling to assertEqual (#89557)
Browse files Browse the repository at this point in the history
#85303 added a patch to `torch.testing.assert_close` to handle `torch.storage.TypedStorage`'s. This change is not reflected in the docs and is not intended for the public API. This PR removes the patch once again and moves the behavior to `TestCase.assertEqual` instead. Meaning, `TypedStorage`'s are again not supported by the public API, but the behavior is the same for all internal use cases.

Pull Request resolved: #89557
Approved by: https://github.com/kurtamohler, https://github.com/mruberry
  • Loading branch information
pmeier authored and pytorchmergebot committed Dec 12, 2022
1 parent 17941b1 commit 7bb97c4
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 28 deletions.
2 changes: 2 additions & 0 deletions test/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,6 +675,8 @@ def test_load_error_msg(self):
with self.assertRaisesRegex(AttributeError, expected_err_msg):
torch.load(resource)

# FIXME: See https://github.com/pytorch/pytorch/issues/90497
@unittest.expectedFailure
def test_save_different_dtype_unallocated(self):
devices = ['cpu']
if torch.cuda.is_available():
Expand Down
28 changes: 1 addition & 27 deletions torch/testing/_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -1076,35 +1076,9 @@ def originate_pairs(
Returns:
(List[Pair]): Originated pairs.
"""
if isinstance(actual, torch.TypedStorage) and isinstance(
expected, torch.TypedStorage
):
actual_len = actual._size()
expected_len = expected._size()
if actual_len != expected_len:
raise ErrorMeta(
AssertionError,
f"The length of the sequences mismatch: {actual_len} != {expected_len}",
id=id,
)

pairs = []
for idx in range(actual_len):
pairs.extend(
originate_pairs(
actual._getitem(idx),
expected._getitem(idx),
pair_types=pair_types,
sequence_types=sequence_types,
mapping_types=mapping_types,
id=(*id, idx),
**options,
)
)
return pairs
# We explicitly exclude str's here since they are self-referential and would cause an infinite recursion loop:
# "a" == "a"[0][0]...
elif (
if (
isinstance(actual, sequence_types)
and not isinstance(actual, str)
and isinstance(expected, sequence_types)
Expand Down
24 changes: 23 additions & 1 deletion torch/testing/_internal/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1879,6 +1879,28 @@ def _process_inputs(self, actual, expected, *, id, allow_subclasses):
return actual, expected


class TypedStoragePair(TensorLikePair):
    """Pair for :class:`torch.storage.TypedStorage` inputs.

    Compares two typed storages by viewing each one as a tensor and
    delegating to :class:`TensorLikePair`.
    """

    def __init__(self, actual, expected, *, rtol_override=0.0, atol_override=0.0, **other_parameters):
        self._check_inputs_isinstance(actual, expected, cls=torch.storage.TypedStorage)
        super().__init__(actual, expected, **other_parameters)
        # Overrides can only relax the tolerances chosen by the parent class,
        # never tighten them.
        if rtol_override > self.rtol:
            self.rtol = rtol_override
        if atol_override > self.atol:
            self.atol = atol_override

    def _to_tensor(self, typed_storage):
        # Quantized dtypes are mapped to plain integer dtypes of the same bit
        # width — presumably because torch.tensor cannot view raw storage as a
        # quantized dtype directly (NOTE(review): confirm).
        quantized_to_int = {
            torch.quint8: torch.uint8,
            torch.quint4x2: torch.uint8,
            torch.quint2x4: torch.uint8,
            torch.qint32: torch.int32,
            torch.qint8: torch.int8,
        }
        dtype = quantized_to_int.get(typed_storage.dtype, typed_storage.dtype)
        return torch.tensor(
            typed_storage._untyped_storage,
            dtype=dtype,
            device=typed_storage.device,
        )


class UnittestPair(Pair):
"""Fallback ABC pair that handles non-numeric inputs.
Expand Down Expand Up @@ -2864,14 +2886,14 @@ def to_list(input):
RelaxedBooleanPair,
RelaxedNumberPair,
TensorOrArrayPair,
TypedStoragePair,
StringPair,
SetPair,
TypePair,
ObjectPair,
),
sequence_types=(
Sequence,
torch.storage.TypedStorage,
Sequential,
ModuleList,
ParameterList,
Expand Down

0 comments on commit 7bb97c4

Please sign in to comment.