In [1]:
import torch
import torch.nn as nn

# MaxPool1d

In [2]:
# torch.nn.MaxPool1d(kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False)

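Applies a 1D max pooling over an input signal composed of several input planes: each output value is the maximum of a window of `kernel_size` elements (spaced `dilation` apart) sliding along the last dimension. If `stride` is left as `None` it defaults to `kernel_size`, and `padding` adds implicit **negative-infinity** padding on both sides, so a padded position can never be the maximum.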

- Input: $(N, C, L_{in})$ or $(C, L_{in})$, where N is the batch size, C is the number of features or channels, and L is the sequence length.

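- Output: $(N, C, L_{out})$ or $(C, L_{out})$, where

  $$L_{out} = \left\lfloor \frac{L_{in} + 2 \times \text{padding} - \text{dilation} \times (\text{kernel\_size} - 1) - 1}{\text{stride}} + 1 \right\rfloor$$

  With `ceil_mode=True` the floor is replaced by a ceiling. A final window that would start inside the right padding is still dropped.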

## 2D input data

In [3]:
number_feature = 4
sequence_length = 6

# Fill a (features, length) grid with 0, 1, 2, ... in row-major order so every row is a
# strictly increasing sequence -- the max of any window is then easy to read off by eye.
x = torch.arange(number_feature * sequence_length, dtype=torch.double).view(number_feature, sequence_length)
print(x)

tensor([[ 0.,  1.,  2.,  3.,  4.,  5.],
        [ 6.,  7.,  8.,  9., 10., 11.],
        [12., 13., 14., 15., 16., 17.],
        [18., 19., 20., 21., 22., 23.]], dtype=torch.float64)


In [4]:
# Window of 3, sliding one position per step, no padding, dilation=1 (the default):
# L_out = (6 - 3) / 1 + 1 = 4 windows per feature.
max_pool_1d = nn.MaxPool1d(kernel_size=3, stride=1, padding=0, dilation=1)
output = max_pool_1d(x)
# feature 0 windows: [0,1,2] -> 2, [1,2,3] -> 3, [2,3,4] -> 4, [3,4,5] -> 5
print(output)

tensor([[ 2.,  3.,  4.,  5.],
        [ 8.,  9., 10., 11.],
        [14., 15., 16., 17.],
        [20., 21., 22., 23.]], dtype=torch.float64)


In [5]:
# Window of 3, jumping two positions per step, no padding, dilation=1:
# L_out = floor((6 - 3) / 2) + 1 = 2, so the incomplete trailing window [4, 5] is discarded.
max_pool_1d_with_stride = nn.MaxPool1d(kernel_size=3, stride=2, padding=0, dilation=1)
output_with_stride = max_pool_1d_with_stride(x)
# feature 0 windows: [0,1,2] -> 2, [2,3,4] -> 4
print(output_with_stride)

tensor([[ 2.,  4.],
        [ 8., 10.],
        [14., 16.],
        [20., 22.]], dtype=torch.float64)


In [6]:
# kernel_size=3, stride=1, padding=1, dilation=1
# L_out = (6 + 2*1 - 3) / 1 + 1 = 6, i.e. padding=1 preserves the sequence length.
max_pool_1d_with_padding = nn.MaxPool1d(kernel_size=3, stride=1, padding=1)
output_with_padding = max_pool_1d_with_padding(x)
# MaxPool1d pads implicitly with -inf (NOT 0), so a padded slot can never be the maximum.
# Sanity check: on an all-negative input, zero padding would make the edge outputs 0,
# whereas -inf padding keeps every output negative.
assert (max_pool_1d_with_padding(x - 100) < 0).all()
# for feature 0: [-inf,0,1] -> 1, [0,1,2] -> 2, [1,2,3] -> 3, [2,3,4] -> 4, [3,4,5] -> 5, [4,5,-inf] -> 5
print(output_with_padding)

tensor([[ 1.,  2.,  3.,  4.,  5.,  5.],
        [ 7.,  8.,  9., 10., 11., 11.],
        [13., 14., 15., 16., 17., 17.],
        [19., 20., 21., 22., 23., 23.]], dtype=torch.float64)


In [7]:
# Window of 3 taps spaced 2 apart (dilation=2), stride 1, no padding.
# Effective span = dilation * (kernel_size - 1) + 1 = 5, so L_out = 6 - 5 + 1 = 2.
max_pool_1d_with_dilation = nn.MaxPool1d(kernel_size=3, stride=1, padding=0, dilation=2)
output_with_dilation = max_pool_1d_with_dilation(x)
# feature 0 windows: [0,2,4] -> 4, [1,3,5] -> 5
print(output_with_dilation)

tensor([[ 4.,  5.],
        [10., 11.],
        [16., 17.],
        [22., 23.]], dtype=torch.float64)


In [8]:
# Same geometry as the stride-2 cell (kernel_size=3, stride=2, padding=0, dilation=1), but
# ceil_mode=True rounds L_out up: ceil((6 - 3) / 2) + 1 = 3, so the partial window [4, 5] is kept.
max_pool_1d_ceil_mode = nn.MaxPool1d(
    kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=True
)
output_ceil_mode = max_pool_1d_ceil_mode(x)
# feature 0 windows: [0,1,2] -> 2, [2,3,4] -> 4, [4,5] -> 5
print(output_ceil_mode)

tensor([[ 2.,  4.,  5.],
        [ 8., 10., 11.],
        [14., 16., 17.],
        [20., 22., 23.]], dtype=torch.float64)




## 3D input data

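With a batched $(N, C, L_{in})$ input, `nn.MaxPool1d` pools each $(C, L_{in})$ sample independently, exactly as it does for the 2D input above. The modules defined in the previous section are reused unchanged.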

In [9]:
sample_size = 3

# Batch of `sample_size` samples with the same (features, length) layout as `x`, numbered
# consecutively across the whole batch so values keep increasing from one sample to the next.
x_3d = torch.arange(sample_size * number_feature * sequence_length, dtype=torch.double).view(
    sample_size, number_feature, sequence_length
)
print(x_3d)

tensor([[[ 0.,  1.,  2.,  3.,  4.,  5.],
         [ 6.,  7.,  8.,  9., 10., 11.],
         [12., 13., 14., 15., 16., 17.],
         [18., 19., 20., 21., 22., 23.]],

        [[24., 25., 26., 27., 28., 29.],
         [30., 31., 32., 33., 34., 35.],
         [36., 37., 38., 39., 40., 41.],
         [42., 43., 44., 45., 46., 47.]],

        [[48., 49., 50., 51., 52., 53.],
         [54., 55., 56., 57., 58., 59.],
         [60., 61., 62., 63., 64., 65.],
         [66., 67., 68., 69., 70., 71.]]], dtype=torch.float64)


In [10]:
# kernel_size=3, stride=1, padding=0, dilation=1 -- the same module used on the 2D input
output_3d = max_pool_1d(x_3d)
# The batch dimension is carried through untouched: pooling the whole batch at once must equal
# pooling each (C, L) sample on its own and stacking the results.
assert torch.equal(output_3d, torch.stack([max_pool_1d(sample) for sample in x_3d]))
print(output_3d)

tensor([[[ 2.,  3.,  4.,  5.],
         [ 8.,  9., 10., 11.],
         [14., 15., 16., 17.],
         [20., 21., 22., 23.]],

        [[26., 27., 28., 29.],
         [32., 33., 34., 35.],
         [38., 39., 40., 41.],
         [44., 45., 46., 47.]],

        [[50., 51., 52., 53.],
         [56., 57., 58., 59.],
         [62., 63., 64., 65.],
         [68., 69., 70., 71.]]], dtype=torch.float64)


In [11]:
# Reuse the stride-2 module (kernel_size=3, stride=2, padding=0, dilation=1) on the batch;
# each of the 3 samples is pooled independently, giving shape (3, 4, 2).
output_with_stride_3d = max_pool_1d_with_stride(x_3d)
print(output_with_stride_3d)

tensor([[[ 2.,  4.],
         [ 8., 10.],
         [14., 16.],
         [20., 22.]],

        [[26., 28.],
         [32., 34.],
         [38., 40.],
         [44., 46.]],

        [[50., 52.],
         [56., 58.],
         [62., 64.],
         [68., 70.]]], dtype=torch.float64)


In [12]:
# Reuse the padded module (kernel_size=3, stride=1, padding=1, dilation=1) on the batch;
# padding=1 keeps the length at 6, so the result has shape (3, 4, 6).
output_with_padding_3d = max_pool_1d_with_padding(x_3d)
print(output_with_padding_3d)

tensor([[[ 1.,  2.,  3.,  4.,  5.,  5.],
         [ 7.,  8.,  9., 10., 11., 11.],
         [13., 14., 15., 16., 17., 17.],
         [19., 20., 21., 22., 23., 23.]],

        [[25., 26., 27., 28., 29., 29.],
         [31., 32., 33., 34., 35., 35.],
         [37., 38., 39., 40., 41., 41.],
         [43., 44., 45., 46., 47., 47.]],

        [[49., 50., 51., 52., 53., 53.],
         [55., 56., 57., 58., 59., 59.],
         [61., 62., 63., 64., 65., 65.],
         [67., 68., 69., 70., 71., 71.]]], dtype=torch.float64)


In [13]:
# Reuse the dilated module (kernel_size=3, stride=1, padding=0, dilation=2) on the batch;
# the effective span of 5 leaves 2 windows per row, so the result has shape (3, 4, 2).
output_with_dilation_3d = max_pool_1d_with_dilation(x_3d)
print(output_with_dilation_3d)

tensor([[[ 4.,  5.],
         [10., 11.],
         [16., 17.],
         [22., 23.]],

        [[28., 29.],
         [34., 35.],
         [40., 41.],
         [46., 47.]],

        [[52., 53.],
         [58., 59.],
         [64., 65.],
         [70., 71.]]], dtype=torch.float64)


In [14]:
# Reuse the ceil-mode module (kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=True)
# on the batch; the partial trailing window is kept, so the result has shape (3, 4, 3).
output_ceil_mode_3d = max_pool_1d_ceil_mode(x_3d)
print(output_ceil_mode_3d)

tensor([[[ 2.,  4.,  5.],
         [ 8., 10., 11.],
         [14., 16., 17.],
         [20., 22., 23.]],

        [[26., 28., 29.],
         [32., 34., 35.],
         [38., 40., 41.],
         [44., 46., 47.]],

        [[50., 52., 53.],
         [56., 58., 59.],
         [62., 64., 65.],
         [68., 70., 71.]]], dtype=torch.float64)
