Skip to content

Commit

Permalink
added List as an option to the unflattened_size (#49838)
Browse files Browse the repository at this point in the history
Summary:
Fixes #49743

Pull Request resolved: #49838

Reviewed By: mruberry

Differential Revision: D25727971

Pulled By: ngimel

fbshipit-source-id: 60142dae84ef107f0083676a2a78ce6b0472b7e1
  • Loading branch information
jonykarki authored and facebook-github-bot committed Dec 30, 2020
1 parent 01b57e1 commit e482c70
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 29 deletions.
31 changes: 16 additions & 15 deletions test/test_nn.py
Expand Up @@ -9283,18 +9283,19 @@ def test_flatten(self):
def test_unflatten(self):
tensor_input = torch.randn(2, 50)

# Unflatten Tensor
# Unflatten Tensor (unflattened_size as a tuple of ints and list of ints)

unflatten = nn.Unflatten(dim=1, unflattened_size=(2, 5, 5))
tensor_output = unflatten(tensor_input)
self.assertEqual(tensor_output.size(), torch.Size([2, 2, 5, 5]))
for us in ((2, 5, 5), [2, 5, 5]):
unflatten = nn.Unflatten(dim=1, unflattened_size=us)
tensor_output = unflatten(tensor_input)
self.assertEqual(tensor_output.size(), torch.Size([2, 2, 5, 5]))

# Unflatten NamedTensor

unflatten = nn.Unflatten(dim='features', unflattened_size=(('C', 2), ('H', 5), ('W', 5)))
named_tensor_input = tensor_input.refine_names('N', 'features')
named_tensor_output = unflatten(named_tensor_input)
self.assertEqual(tensor_output.size(), torch.Size([2, 2, 5, 5]))
self.assertEqual(named_tensor_output.size(), torch.Size([2, 2, 5, 5]))

def test_unflatten_invalid_arg(self):
# Wrong type for unflattened_size (tuple of floats)
Expand All @@ -9304,26 +9305,26 @@ def test_unflatten_invalid_arg(self):
r"unflattened_size must be tuple of ints, but found element of type float at pos 2"):
nn.Unflatten(dim=1, unflattened_size=(2, 5, 5.0))

# Wrong type for unflattened_size (list of lists and list of tuples)
for us in ([['C', 2], ['W', 5], ['H', 5]], [('C', 2), ('W', 5), ('H', 5)]):
with self.assertRaisesRegex(
TypeError,
r"unflattened_size must be a tuple of tuples, but found type list"):
nn.Unflatten(dim='features', unflattened_size=us)

# Wrong type for unflattened_size (tuple of lists)

with self.assertRaisesRegex(
TypeError,
r"unflattened_size must be tuple of tuples, but found element of type list at pos 0"):
nn.Unflatten(dim='features', unflattened_size=(['C', 2], ['W', 5], ['H', 5]))

# Wrong type for unflattened_size (list of ints)

with self.assertRaisesRegex(
TypeError,
r"unflattened_size must be a tuple of ints, but found type list"):
nn.Unflatten(dim=1, unflattened_size=[2, 5, 5])

# Wrong type for unflattened_size (list of lists)
# Wrong type for unflattened_size (tuple of dicts)

with self.assertRaisesRegex(
TypeError,
r"unflattened_size must be a tuple of tuples, but found type list"):
nn.Unflatten(dim='features', unflattened_size=[['C', 2], ['W', 5], ['H', 5]])
r"unflattened_size must be tuple of tuples, but found element of type dict at pos 0"):
nn.Unflatten(dim='features', unflattened_size=({'C': 2}, {'W': 5}, {'H': 5}))

def test_layer_norm_grads_with_create_graph_flag(self):
atol = 1e-5
Expand Down
26 changes: 12 additions & 14 deletions torch/nn/modules/flatten.py
Expand Up @@ -2,7 +2,7 @@

from typing import Tuple, Union
from torch import Tensor
from torch import Size
from torch.types import _size


class Flatten(Module):
Expand Down Expand Up @@ -53,16 +53,16 @@ class Unflatten(Module):
be either `int` or `str` when `Tensor` or `NamedTensor` is used, respectively.
* :attr:`unflattened_size` is the new shape of the unflattened dimension of the tensor and it can be
a `tuple` of ints or `torch.Size` for `Tensor` input or a `NamedShape` (tuple of `(name, size)` tuples)
for `NamedTensor` input.
a `tuple` of ints or a `list` of ints or `torch.Size` for `Tensor` input; a `NamedShape`
(tuple of `(name, size)` tuples) for `NamedTensor` input.
Shape:
- Input: :math:`(N, *dims)`
- Output: :math:`(N, C_{\text{out}}, H_{\text{out}}, W_{\text{out}})`
Args:
dim (Union[int, str]): Dimension to be unflattened
unflattened_size (Union[torch.Size, NamedShape]): New shape of the unflattened dimension
unflattened_size (Union[torch.Size, Tuple, List, NamedShape]): New shape of the unflattened dimension
Examples:
>>> input = torch.randn(2, 50)
Expand All @@ -71,33 +71,31 @@ class Unflatten(Module):
>>> nn.Linear(50, 50),
>>> nn.Unflatten(1, (2, 5, 5))
>>> )
>>> output = m(output)
>>> output = m(input)
>>> output.size()
torch.Size([2, 2, 5, 5])
>>> # With torch.Size
>>> m = nn.Sequential(
>>> nn.Linear(50, 50),
>>> nn.Unflatten(1, torch.Size([2, 5, 5]))
>>> )
>>> output = m(output)
>>> output = m(input)
>>> output.size()
torch.Size([2, 2, 5, 5])
>>> # With namedshape (tuple of tuples)
>>> m = nn.Sequential(
>>> nn.Linear(50, 50),
>>> nn.Unflatten('features', (('C', 2), ('H', 50), ('W',50)))
>>> )
>>> output = m(output)
>>> input = torch.randn(2, 50, names=('N', 'features'))
>>> unflatten = nn.Unflatten('features', (('C', 2), ('H', 5), ('W', 5)))
>>> output = unflatten(input)
>>> output.size()
torch.Size([2, 2, 5, 5])
"""
NamedShape = Tuple[Tuple[str, int]]

__constants__ = ['dim', 'unflattened_size']
dim: Union[int, str]
unflattened_size: Union[Size, NamedShape]
unflattened_size: Union[_size, NamedShape]

def __init__(self, dim: Union[int, str], unflattened_size: Union[Size, NamedShape]) -> None:
def __init__(self, dim: Union[int, str], unflattened_size: Union[_size, NamedShape]) -> None:
super(Unflatten, self).__init__()

if isinstance(dim, int):
Expand All @@ -121,7 +119,7 @@ def _require_tuple_tuple(self, input):
"but found type {}".format(type(input).__name__))

def _require_tuple_int(self, input):
if (isinstance(input, tuple)):
if (isinstance(input, (tuple, list))):
for idx, elem in enumerate(input):
if not isinstance(elem, int):
raise TypeError("unflattened_size must be tuple of ints, " +
Expand Down

0 comments on commit e482c70

Please sign in to comment.