From e482c70a3dbeca70cc5164ac28b87c2c6906edf3 Mon Sep 17 00:00:00 2001 From: Jony Karki <25265687+jonykarki@users.noreply.github.com> Date: Tue, 29 Dec 2020 16:42:53 -0800 Subject: [PATCH] added List as an option to the unflattened_size (#49838) Summary: Fixes https://github.com/pytorch/pytorch/issues/49743 Pull Request resolved: https://github.com/pytorch/pytorch/pull/49838 Reviewed By: mruberry Differential Revision: D25727971 Pulled By: ngimel fbshipit-source-id: 60142dae84ef107f0083676a2a78ce6b0472b7e1 --- test/test_nn.py | 31 ++++++++++++++++--------------- torch/nn/modules/flatten.py | 26 ++++++++++++-------------- 2 files changed, 28 insertions(+), 29 deletions(-) diff --git a/test/test_nn.py b/test/test_nn.py index 1d63be6e3075..386ba369dca6 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -9283,18 +9283,19 @@ def test_flatten(self): def test_unflatten(self): tensor_input = torch.randn(2, 50) - # Unflatten Tensor + # Unflatten Tensor (unflattened_size as a tuple of ints and list of ints) - unflatten = nn.Unflatten(dim=1, unflattened_size=(2, 5, 5)) - tensor_output = unflatten(tensor_input) - self.assertEqual(tensor_output.size(), torch.Size([2, 2, 5, 5])) + for us in ((2, 5, 5), [2, 5, 5]): + unflatten = nn.Unflatten(dim=1, unflattened_size=us) + tensor_output = unflatten(tensor_input) + self.assertEqual(tensor_output.size(), torch.Size([2, 2, 5, 5])) # Unflatten NamedTensor unflatten = nn.Unflatten(dim='features', unflattened_size=(('C', 2), ('H', 5), ('W', 5))) named_tensor_input = tensor_input.refine_names('N', 'features') named_tensor_output = unflatten(named_tensor_input) - self.assertEqual(tensor_output.size(), torch.Size([2, 2, 5, 5])) + self.assertEqual(named_tensor_output.size(), torch.Size([2, 2, 5, 5])) def test_unflatten_invalid_arg(self): # Wrong type for unflattened_size (tuple of floats) @@ -9304,6 +9305,13 @@ def test_unflatten_invalid_arg(self): r"unflattened_size must be tuple of ints, but found element of type float at pos 2"): nn.Unflatten(dim=1, unflattened_size=(2, 5, 5.0)) + # Wrong type for unflattened_size (list of lists and list of tuples) + for us in ([['C', 2], ['W', 5], ['H', 5]], [('C', 2), ('W', 5), ('H', 5)]): + with self.assertRaisesRegex( + TypeError, + r"unflattened_size must be a tuple of tuples, but found type list"): + nn.Unflatten(dim='features', unflattened_size=us) + # Wrong type for unflattened_size (tuple of lists) with self.assertRaisesRegex( @@ -9311,19 +9319,12 @@ def test_unflatten_invalid_arg(self): r"unflattened_size must be tuple of tuples, but found element of type list at pos 0"): nn.Unflatten(dim='features', unflattened_size=(['C', 2], ['W', 5], ['H', 5])) - # Wrong type for unflattened_size (list of ints) - - with self.assertRaisesRegex( - TypeError, - r"unflattened_size must be a tuple of ints, but found type list"): - nn.Unflatten(dim=1, unflattened_size=[2, 5, 5]) - - # Wrong type for unflattened_size (list of lists) + # Wrong type for unflattened_size (tuple of dicts) with self.assertRaisesRegex( TypeError, - r"unflattened_size must be a tuple of tuples, but found type list"): - nn.Unflatten(dim='features', unflattened_size=[['C', 2], ['W', 5], ['H', 5]]) + r"unflattened_size must be tuple of tuples, but found element of type dict at pos 0"): + nn.Unflatten(dim='features', unflattened_size=({'C': 2}, {'W': 5}, {'H': 5})) def test_layer_norm_grads_with_create_graph_flag(self): atol = 1e-5 diff --git a/torch/nn/modules/flatten.py b/torch/nn/modules/flatten.py index c06b7a5534f6..dd491ba99620 100644 --- a/torch/nn/modules/flatten.py +++ b/torch/nn/modules/flatten.py @@ -2,7 +2,7 @@ from typing import Tuple, Union from torch import Tensor -from torch import Size +from torch.types import _size class Flatten(Module): @@ -53,8 +53,8 @@ class Unflatten(Module): be either `int` or `str` when `Tensor` or `NamedTensor` is used, respectively. * :attr:`unflattened_size` is the new shape of the unflattened dimension of the tensor and it can be - a `tuple` of ints or `torch.Size` for `Tensor` input or a `NamedShape` (tuple of `(name, size)` tuples) - for `NamedTensor` input. + a `tuple` of ints or a `list` of ints or `torch.Size` for `Tensor` input; a `NamedShape` + (tuple of `(name, size)` tuples) for `NamedTensor` input. Shape: - Input: :math:`(N, *dims)` @@ -62,7 +62,7 @@ class Unflatten(Module): Args: dim (Union[int, str]): Dimension to be unflattened - unflattened_size (Union[torch.Size, NamedShape]): New shape of the unflattened dimension + unflattened_size (Union[torch.Size, Tuple, List, NamedShape]): New shape of the unflattened dimension Examples: >>> input = torch.randn(2, 50) @@ -71,7 +71,7 @@ class Unflatten(Module): >>> nn.Linear(50, 50), >>> nn.Unflatten(1, (2, 5, 5)) >>> ) - >>> output = m(output) + >>> output = m(input) >>> output.size() torch.Size([2, 2, 5, 5]) >>> # With torch.Size @@ -79,15 +79,13 @@ class Unflatten(Module): >>> nn.Linear(50, 50), >>> nn.Unflatten(1, torch.Size([2, 5, 5])) >>> ) - >>> output = m(output) + >>> output = m(input) >>> output.size() torch.Size([2, 2, 5, 5]) >>> # With namedshape (tuple of tuples) - >>> m = nn.Sequential( - >>> nn.Linear(50, 50), - >>> nn.Unflatten('features', (('C', 2), ('H', 50), ('W',50))) - >>> ) - >>> output = m(output) + >>> input = torch.randn(2, 50, names=('N', 'features')) + >>> unflatten = nn.Unflatten('features', (('C', 2), ('H', 5), ('W', 5))) + >>> output = unflatten(input) >>> output.size() torch.Size([2, 2, 5, 5]) """ @@ -95,9 +93,9 @@ class Unflatten(Module): __constants__ = ['dim', 'unflattened_size'] dim: Union[int, str] - unflattened_size: Union[Size, NamedShape] + unflattened_size: Union[_size, NamedShape] - def __init__(self, dim: Union[int, str], unflattened_size: Union[Size, NamedShape]) -> None: + def __init__(self, dim: Union[int, str], unflattened_size: Union[_size, NamedShape]) -> None: super(Unflatten, self).__init__() if isinstance(dim, int): @@ -121,7 +119,7 @@ def _require_tuple_tuple(self, input): "but found type {}".format(type(input).__name__)) def _require_tuple_int(self, input): - if (isinstance(input, tuple)): + if (isinstance(input, (tuple, list))): for idx, elem in enumerate(input): if not isinstance(elem, int): raise TypeError("unflattened_size must be tuple of ints, " +