In [5]:
# 这是pad操作，sequences也是list。这个比较好理解，就是给list里的tensor都用padding_value来pad成最长的长度，并组合成一个tensor：
import torch
from torch.nn.utils.rnn import pad_sequence
a = torch.ones(3,8)
b = torch.ones(4,8)
c = torch.ones(5,8)
pad_sequence([a,b,c], batch_first=False, padding_value=0)

tensor([[[1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1.]],

        [[1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1.]],

        [[1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1.]],

        [[0., 0., 0., 0., 0., 0., 0., 0.],
         [1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1.]],

        [[0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [1., 1., 1., 1., 1., 1., 1., 1.]]])

In [4]:
# 这是pack操作，输入的sequences是tensor组成的list，要求按长度从大到小排序。官网的例子：
import torch
from torch.nn.utils.rnn import pack_sequence
a = torch.tensor([1,2,3])
b = torch.tensor([4,5])
c = torch.tensor([6])
pack_sequence([a,b,c])

PackedSequence(data=tensor([1, 4, 6, 2, 5, 3]), batch_sizes=tensor([3, 2, 1]), sorted_indices=None, unsorted_indices=None)

#### 其他两个常用函数

In [9]:
import torch as t
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

In [10]:
a = t.tensor([[1,2,3],[6,0,0],[4,5,0]]) #(batch_size, max_length)
lengths = t.tensor([3,1,2])

In [16]:
# 排序
a_lengths, idx = lengths.sort(0, descending=True)
print(a_lengths, idx)

tensor([3, 2, 1]) tensor([0, 2, 1])


In [18]:
_, un_idx = t.sort(idx, dim=0)
a = a[idx]
print(a)

tensor([[1, 2, 3],
        [4, 5, 0],
        [6, 0, 0]])


In [24]:
emb = t.nn.Embedding(20, 2, padding_idx=0)    # 第一行全为0.
lstm = t.nn.LSTM(input_size=2, hidden_size=4, batch_first=True)
a_input = emb(a)
a_input

tensor([[[-0.1765, -1.0014],
         [ 0.7083, -0.0463],
         [-0.2064, -0.6541]],

        [[-2.4080,  0.3991],
         [ 0.5537, -0.2426],
         [ 0.0000,  0.0000]],

        [[ 0.1955, -0.5838],
         [ 0.0000,  0.0000],
         [ 0.0000,  0.0000]]], grad_fn=<EmbeddingBackward0>)

In [25]:
a_packed_input = t.nn.utils.rnn.pack_padded_sequence(input=a_input, lengths=a_lengths, batch_first=True)
a_packed_input

PackedSequence(data=tensor([[-0.1765, -1.0014],
        [-2.4080,  0.3991],
        [ 0.1955, -0.5838],
        [ 0.7083, -0.0463],
        [ 0.5537, -0.2426],
        [-0.2064, -0.6541]], grad_fn=<PackPaddedSequenceBackward0>), batch_sizes=tensor([3, 2, 1]), sorted_indices=None, unsorted_indices=None)

In [26]:
a_packed_input2 = t.nn.utils.rnn.pack_padded_sequence(input=a_input, lengths=a_lengths) # batch_first=false
a_packed_input2

PackedSequence(data=tensor([[-0.1765, -1.0014],
        [ 0.7083, -0.0463],
        [-0.2064, -0.6541],
        [-2.4080,  0.3991],
        [ 0.5537, -0.2426],
        [ 0.1955, -0.5838]], grad_fn=<PackPaddedSequenceBackward0>), batch_sizes=tensor([3, 2, 1]), sorted_indices=None, unsorted_indices=None)

In [27]:
packed_out, _ = lstm(a_packed_input)
packed_out

PackedSequence(data=tensor([[-0.0748,  0.0699,  0.1413, -0.1009],
        [-0.1208,  0.1010,  0.4545,  0.2097],
        [-0.0854,  0.1046,  0.1260, -0.0984],
        [-0.1234,  0.1776,  0.1904, -0.1299],
        [-0.0959,  0.1811,  0.3077,  0.0080],
        [-0.1293,  0.2048,  0.3234, -0.1136]], grad_fn=<CatBackward0>), batch_sizes=tensor([3, 2, 1]), sorted_indices=None, unsorted_indices=None)

In [28]:
out, _ = pad_packed_sequence(packed_out)
out

tensor([[[-0.0748,  0.0699,  0.1413, -0.1009],
         [-0.1208,  0.1010,  0.4545,  0.2097],
         [-0.0854,  0.1046,  0.1260, -0.0984]],

        [[-0.1234,  0.1776,  0.1904, -0.1299],
         [-0.0959,  0.1811,  0.3077,  0.0080],
         [ 0.0000,  0.0000,  0.0000,  0.0000]],

        [[-0.1293,  0.2048,  0.3234, -0.1136],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000]]], grad_fn=<CopySlices>)

In [29]:
out = t.index_select(out, 0, un_idx)
out

tensor([[[-0.0748,  0.0699,  0.1413, -0.1009],
         [-0.1208,  0.1010,  0.4545,  0.2097],
         [-0.0854,  0.1046,  0.1260, -0.0984]],

        [[-0.1293,  0.2048,  0.3234, -0.1136],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000]],

        [[-0.1234,  0.1776,  0.1904, -0.1299],
         [-0.0959,  0.1811,  0.3077,  0.0080],
         [ 0.0000,  0.0000,  0.0000,  0.0000]]],
       grad_fn=<IndexSelectBackward0>)