Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added List as an option to the unflattened_size #49838

Closed
wants to merge 11 commits into from
31 changes: 16 additions & 15 deletions test/test_nn.py
Expand Up @@ -9283,18 +9283,19 @@ def test_flatten(self):
def test_unflatten(self):
tensor_input = torch.randn(2, 50)

# Unflatten Tensor
# Unflatten Tensor (unflattened_size as a tuple of ints and list of ints)

unflatten = nn.Unflatten(dim=1, unflattened_size=(2, 5, 5))
tensor_output = unflatten(tensor_input)
self.assertEqual(tensor_output.size(), torch.Size([2, 2, 5, 5]))
for us in ((2, 5, 5), [2, 5, 5]):
unflatten = nn.Unflatten(dim=1, unflattened_size=us)
tensor_output = unflatten(tensor_input)
self.assertEqual(tensor_output.size(), torch.Size([2, 2, 5, 5]))

# Unflatten NamedTensor

unflatten = nn.Unflatten(dim='features', unflattened_size=(('C', 2), ('H', 5), ('W', 5)))
named_tensor_input = tensor_input.refine_names('N', 'features')
named_tensor_output = unflatten(named_tensor_input)
self.assertEqual(tensor_output.size(), torch.Size([2, 2, 5, 5]))
self.assertEqual(named_tensor_output.size(), torch.Size([2, 2, 5, 5]))
jonykarki marked this conversation as resolved.
Show resolved Hide resolved

def test_unflatten_invalid_arg(self):
# Wrong type for unflattened_size (tuple of floats)
Expand All @@ -9304,26 +9305,26 @@ def test_unflatten_invalid_arg(self):
r"unflattened_size must be tuple of ints, but found element of type float at pos 2"):
nn.Unflatten(dim=1, unflattened_size=(2, 5, 5.0))

# Wrong type for unflattened_size (list of lists and list of tuples)
for us in ([['C', 2], ['W', 5], ['H', 5]], [('C', 2), ('W', 5), ('H', 5)]):
with self.assertRaisesRegex(
TypeError,
r"unflattened_size must be a tuple of tuples, but found type list"):
nn.Unflatten(dim='features', unflattened_size=us)

# Wrong type for unflattened_size (tuple of lists)

with self.assertRaisesRegex(
TypeError,
r"unflattened_size must be tuple of tuples, but found element of type list at pos 0"):
nn.Unflatten(dim='features', unflattened_size=(['C', 2], ['W', 5], ['H', 5]))

# Wrong type for unflattened_size (list of ints)

with self.assertRaisesRegex(
TypeError,
r"unflattened_size must be a tuple of ints, but found type list"):
nn.Unflatten(dim=1, unflattened_size=[2, 5, 5])

# Wrong type for unflattened_size (list of lists)
# Wrong type for unflattened_size (tuple of dicts)

with self.assertRaisesRegex(
TypeError,
r"unflattened_size must be a tuple of tuples, but found type list"):
nn.Unflatten(dim='features', unflattened_size=[['C', 2], ['W', 5], ['H', 5]])
r"unflattened_size must be tuple of tuples, but found element of type dict at pos 0"):
nn.Unflatten(dim='features', unflattened_size=({'C': 2}, {'W': 5}, {'H': 5}))

def test_layer_norm_grads_with_create_graph_flag(self):
atol = 1e-5
Expand Down
26 changes: 12 additions & 14 deletions torch/nn/modules/flatten.py
Expand Up @@ -2,7 +2,7 @@

from typing import Tuple, Union
from torch import Tensor
from torch import Size
from torch.types import _size


class Flatten(Module):
Expand Down Expand Up @@ -53,16 +53,16 @@ class Unflatten(Module):
be either `int` or `str` when `Tensor` or `NamedTensor` is used, respectively.

* :attr:`unflattened_size` is the new shape of the unflattened dimension of the tensor and it can be
a `tuple` of ints or `torch.Size` for `Tensor` input or a `NamedShape` (tuple of `(name, size)` tuples)
for `NamedTensor` input.
a `tuple` of ints or a `list` of ints or `torch.Size` for `Tensor` input; a `NamedShape`
(tuple of `(name, size)` tuples) for `NamedTensor` input.

Shape:
- Input: :math:`(N, *dims)`
- Output: :math:`(N, C_{\text{out}}, H_{\text{out}}, W_{\text{out}})`

Args:
dim (Union[int, str]): Dimension to be unflattened
unflattened_size (Union[torch.Size, NamedShape]): New shape of the unflattened dimension
unflattened_size (Union[torch.Size, Tuple, List, NamedShape]): New shape of the unflattened dimension

Examples:
>>> input = torch.randn(2, 50)
Expand All @@ -71,33 +71,31 @@ class Unflatten(Module):
>>> nn.Linear(50, 50),
>>> nn.Unflatten(1, (2, 5, 5))
>>> )
>>> output = m(output)
>>> output = m(input)
>>> output.size()
torch.Size([2, 2, 5, 5])
>>> # With torch.Size
>>> m = nn.Sequential(
>>> nn.Linear(50, 50),
>>> nn.Unflatten(1, torch.Size([2, 5, 5]))
>>> )
>>> output = m(output)
>>> output = m(input)
>>> output.size()
torch.Size([2, 2, 5, 5])
>>> # With namedshape (tuple of tuples)
>>> m = nn.Sequential(
>>> nn.Linear(50, 50),
>>> nn.Unflatten('features', (('C', 2), ('H', 50), ('W',50)))
>>> )
>>> output = m(output)
>>> input = torch.randn(2, 50, names=('N', 'features'))
>>> unflatten = nn.Unflatten('features', (('C', 2), ('H', 5), ('W', 5)))
>>> output = unflatten(input)
>>> output.size()
torch.Size([2, 2, 5, 5])
"""
NamedShape = Tuple[Tuple[str, int]]

__constants__ = ['dim', 'unflattened_size']
dim: Union[int, str]
unflattened_size: Union[Size, NamedShape]
unflattened_size: Union[_size, NamedShape]

def __init__(self, dim: Union[int, str], unflattened_size: Union[Size, NamedShape]) -> None:
def __init__(self, dim: Union[int, str], unflattened_size: Union[_size, NamedShape]) -> None:
super(Unflatten, self).__init__()

if isinstance(dim, int):
Expand All @@ -121,7 +119,7 @@ def _require_tuple_tuple(self, input):
"but found type {}".format(type(input).__name__))

def _require_tuple_int(self, input):
if (isinstance(input, tuple)):
if (isinstance(input, (tuple, list))):
for idx, elem in enumerate(input):
if not isinstance(elem, int):
raise TypeError("unflattened_size must be tuple of ints, " +
Expand Down