-
Notifications
You must be signed in to change notification settings - Fork 21.3k
/
flatten.py
134 lines (110 loc) · 4.79 KB
/
flatten.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
from .module import Module
from typing import Tuple, Union
from torch import Tensor
from torch.types import _size
class Flatten(Module):
    r"""
    Collapses a contiguous range of dims of the input into a single dim.
    Meant to be dropped into an :class:`~nn.Sequential` pipeline.

    Shape:
        - Input: :math:`(N, *dims)`
        - Output: :math:`(N, \prod *dims)` (for the default case).

    Args:
        start_dim: first dim to flatten (default = 1).
        end_dim: last dim to flatten (default = -1).

    Examples::
        >>> input = torch.randn(32, 1, 5, 5)
        >>> m = nn.Sequential(
        >>>     nn.Conv2d(1, 32, 5, 1, 1),
        >>>     nn.Flatten()
        >>> )
        >>> output = m(input)
        >>> output.size()
        torch.Size([32, 288])
    """
    __constants__ = ['start_dim', 'end_dim']
    start_dim: int
    end_dim: int

    def __init__(self, start_dim: int = 1, end_dim: int = -1) -> None:
        super().__init__()
        # Both bounds are stored verbatim and handed to Tensor.flatten in forward().
        self.start_dim, self.end_dim = start_dim, end_dim

    def forward(self, input: Tensor) -> Tensor:
        """Return ``input`` with dims ``start_dim..end_dim`` merged into one."""
        return input.flatten(self.start_dim, self.end_dim)

    def extra_repr(self) -> str:
        """Describe the configured dim range for the module's repr."""
        return f'start_dim={self.start_dim}, end_dim={self.end_dim}'
class Unflatten(Module):
    r"""
    Expands one dim of a tensor into a desired shape. Meant to be dropped into
    an :class:`~nn.Sequential` pipeline.

    * :attr:`dim` selects the dimension of the input tensor to be unflattened. It is an
      `int` when a plain `Tensor` is used and a `str` when a `NamedTensor` is used.

    * :attr:`unflattened_size` is the shape that replaces the selected dimension. For
      `Tensor` input it is a `tuple` of ints, a `list` of ints or a `torch.Size`; for
      `NamedTensor` input it is a `NamedShape` (tuple of `(name, size)` tuples).

    Shape:
        - Input: :math:`(*, S_{\text{dim}}, *)`, where :math:`S_{\text{dim}}` is the size at
          dimension :attr:`dim` and :math:`*` means any number of other dimensions.
        - Output: :math:`(*, U_1, ..., U_n, *)`, where :math:`U` = :attr:`unflattened_size`.

    Args:
        dim (Union[int, str]): Dimension to be unflattened
        unflattened_size (Union[torch.Size, Tuple, List, NamedShape]): New shape of the unflattened dimension

    Examples:
        >>> input = torch.randn(2, 50)
        >>> # With tuple of ints
        >>> m = nn.Sequential(
        >>>     nn.Linear(50, 50),
        >>>     nn.Unflatten(1, (2, 5, 5))
        >>> )
        >>> output = m(input)
        >>> output.size()
        torch.Size([2, 2, 5, 5])
        >>> # With torch.Size
        >>> m = nn.Sequential(
        >>>     nn.Linear(50, 50),
        >>>     nn.Unflatten(1, torch.Size([2, 5, 5]))
        >>> )
        >>> output = m(input)
        >>> output.size()
        torch.Size([2, 2, 5, 5])
        >>> # With namedshape (tuple of tuples)
        >>> input = torch.randn(2, 50, names=('N', 'features'))
        >>> unflatten = nn.Unflatten('features', (('C', 2), ('H', 5), ('W', 5)))
        >>> output = unflatten(input)
        >>> output.size()
        torch.Size([2, 2, 5, 5])
    """
    NamedShape = Tuple[Tuple[str, int]]

    __constants__ = ['dim', 'unflattened_size']
    dim: Union[int, str]
    unflattened_size: Union[_size, NamedShape]

    def __init__(self, dim: Union[int, str], unflattened_size: Union[_size, NamedShape]) -> None:
        super().__init__()

        # The kind of `dim` decides which shape format is legal:
        # positional (int) dims take ints, named (str) dims take (name, size) tuples.
        if isinstance(dim, int):
            validate = self._require_tuple_int
        elif isinstance(dim, str):
            validate = self._require_tuple_tuple
        else:
            raise TypeError("invalid argument type for dim parameter")
        validate(unflattened_size)

        self.dim = dim
        self.unflattened_size = unflattened_size

    def _require_tuple_tuple(self, input):
        """Raise ``TypeError`` unless ``input`` is a tuple whose elements are all tuples."""
        if not isinstance(input, tuple):
            raise TypeError("unflattened_size must be a tuple of tuples, "
                            f"but found type {type(input).__name__}")
        for pos, item in enumerate(input):
            if not isinstance(item, tuple):
                raise TypeError("unflattened_size must be tuple of tuples, "
                                f"but found element of type {type(item).__name__} at pos {pos}")

    def _require_tuple_int(self, input):
        """Raise ``TypeError`` unless ``input`` is a tuple or list whose elements are all ints."""
        if not isinstance(input, (tuple, list)):
            raise TypeError(f"unflattened_size must be a tuple of ints, but found type {type(input).__name__}")
        for pos, item in enumerate(input):
            if not isinstance(item, int):
                raise TypeError("unflattened_size must be tuple of ints, "
                                f"but found element of type {type(item).__name__} at pos {pos}")

    def forward(self, input: Tensor) -> Tensor:
        """Return ``input`` with dimension ``dim`` expanded to ``unflattened_size``."""
        return input.unflatten(self.dim, self.unflattened_size)

    def extra_repr(self) -> str:
        """Describe the configured dim and target shape for the module's repr."""
        return f'dim={self.dim}, unflattened_size={self.unflattened_size}'