#### **Tutorial on using PyTorch's PackedSequence object**

Note: this is adapted from [HarshTrivedi's GitHub demo](https://github.com/HarshTrivedi/packing-unpacking-pytorch-minimal-tutorial). I've recast it as a Jupyter notebook that uses a GRU.

We want to run a GRU on a batch of 3 character sequences: ["long_str", "tiny", "medium"]. The steps are listed below. The ones marked with **\*** are specific to padding and packing, and are the ones of most interest here.

* Construct the vocabulary
* Convert the sequences to indexed data (a list of instances, where each instance is a list of character indices)
* Create the model
* Pad data **\***
* Sort instances **\***
* Embed the instances **\***
* Call pack_padded_sequence with the embedded instances and sequence lengths **\***
* Run the model **\***
* Call pad_packed_sequence (unpack) **\***
* Summary of Shape Transformations

In [0]:
seqs =  ["long_str", "tiny", "medium"]

#### **1. Construct the vocabulary**

The vocabulary (a set of tokens) will be the characters in the sequences. We add "\<pad\>" to represent the padding character / token.

In [10]:
# make sure that the index for padding character is 0
vocab = ["<pad>"] + sorted(set([tok for seq in seqs for tok in seq]))

print("vocab:", vocab)

vocab: ['<pad>', '_', 'd', 'e', 'g', 'i', 'l', 'm', 'n', 'o', 'r', 's', 't', 'u', 'y']


#### **2. Convert the sequences to indexed data**

Each character is replaced by its position in the vocabulary, so each sequence becomes a list of character indices.

In [26]:
seqs_idx = [ [vocab.index(tok) for tok in seq] for seq in seqs]

print("seqs_idxs:", seqs_idx)
# seqs_idxs => [[6, 9, 8, 4, 1, 11, 12, 10],
#               [12, 5, 8, 14],
#               [7, 3, 2, 5, 13, 7]]

# print('\n'.join(' '.join(map(str, lst)) for lst in seqs_idx))

seqs_idxs: [[6, 9, 8, 4, 1, 11, 12, 10], [12, 5, 8, 14], [7, 3, 2, 5, 13, 7]]
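`vocab.index(tok)` scans the whole list for every character. That is fine for a 15-token vocabulary, but it is linear in the vocabulary size. A minimal sketch of the usual alternative, a dictionary built once (the name `char_to_idx` is our own, not from the original demo):

```python
seqs = ["long_str", "tiny", "medium"]

# index 0 is reserved for the padding token, as in the vocab above
vocab = ["<pad>"] + sorted({tok for seq in seqs for tok in seq})

# build the lookup table once: O(1) per character instead of a list scan
char_to_idx = {tok: idx for idx, tok in enumerate(vocab)}

seqs_idx = [[char_to_idx[tok] for tok in seq] for seq in seqs]
print(seqs_idx)
# [[6, 9, 8, 4, 1, 11, 12, 10], [12, 5, 8, 14], [7, 3, 2, 5, 13, 7]]
```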


#### **3. Create the model**

Here we create a character embedding layer and a GRU, both from PyTorch's `torch.nn` module.

In [0]:
import torch
import torch.nn as nn

# create the embedding, embedding dimension = 4
embed_dim = 4
embed = nn.Embedding(len(vocab), embed_dim)

# create the GRU: input_size (the size of each input vector) equals the embedding dimension, and the hidden state has size hidden_size = 5
input_size = embed_dim
hidden_size = 5
gru = nn.GRU(input_size, hidden_size, batch_first = True)

#### **4. Pad data**

We pad the data. The padding character is "\<pad\>", with index 0. Each sequence is padded up to the maximum length, which is the length of the longest sequence.

In [34]:
# get the lengths of all the sequences
seq_lengths = torch.LongTensor([len(seq) for seq in seqs_idx])
# seq_lengths = [8, 4, 6]
# seq_lengths_max = 8

# all-padding (index 0) tensor of shape (batch_size X max_seq_len)
seqs_tensor = torch.zeros((len(seqs_idx), seq_lengths.max().item()), dtype=torch.long)
# seqs_tensor => [[0 0 0 0 0 0 0 0]
#                [0 0 0 0 0 0 0 0]
#                [0 0 0 0 0 0 0 0]]

# copy each sequence into the start of its row; the remaining positions stay as padding
for idx, (seq, seq_len) in enumerate(zip(seqs_idx, seq_lengths)):
    seqs_tensor[idx, :seq_len] = torch.LongTensor(seq)

# seqs_tensor => [[ 6  9  8  4  1 11 12 10]         # long_str
#                [12  5  8 14  0  0  0  0]          # tiny
#                [ 7  3  2  5 13  7  0  0]]         # medium
# seqs_tensor.shape : (batch_size X max_seq_len) = (3 X 8)
print("seqs_tensor:", seqs_tensor)
print("seqs_tensor.shape:", seqs_tensor.shape)

seqs_tensor: tensor([[ 6,  9,  8,  4,  1, 11, 12, 10],
        [12,  5,  8, 14,  0,  0,  0,  0],
        [ 7,  3,  2,  5, 13,  7,  0,  0]])
seqs_tensor.shape: torch.Size([3, 8])
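PyTorch also ships a helper for exactly this padding step, `torch.nn.utils.rnn.pad_sequence`. A minimal sketch that should produce the same `seqs_tensor` as the manual loop above, assuming the same `seqs_idx`:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

seqs_idx = [[6, 9, 8, 4, 1, 11, 12, 10], [12, 5, 8, 14], [7, 3, 2, 5, 13, 7]]

# one 1-D LongTensor per sequence, plus the true lengths
seq_tensors = [torch.tensor(seq, dtype=torch.long) for seq in seqs_idx]
seq_lengths = torch.tensor([len(seq) for seq in seqs_idx])

# pad to the longest sequence with index 0 (<pad>); batch_first gives (batch, max_len)
seqs_tensor = pad_sequence(seq_tensors, batch_first=True, padding_value=0)
print(seqs_tensor.shape)  # torch.Size([3, 8])
```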


#### **5. Sort instances**

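We sort the instances by sequence length in descending order, because by default (`enforce_sorted=True`) pack_padded_sequence() expects the batch to be sorted by decreasing length. `perm_idx` records the permutation that was applied.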

In [0]:
seq_lengths, perm_idx = seq_lengths.sort(0, descending = True)

seqs_tensor = seqs_tensor[perm_idx]
# seqs_tensor => [[ 6  9  8  4  1 11 12 10]          # long_str
#                [ 7  3  2  5 13  7  0  0]           # medium
#                [12  5  8 14  0  0  0  0]]          # tiny
# seqs_tensor.shape : (batch_size X max_seq_len) = (3 X 8)
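After this step every downstream result (`seqs_op`, `h_t`) is in the sorted order long_str, medium, tiny, not the original order. If you need to map results back, one minimal approach is to invert `perm_idx`: the argsort of a permutation is its inverse. The names `unperm_idx` and `restored` are our own:

```python
import torch

seq_lengths = torch.tensor([8, 4, 6])
seqs_tensor = torch.tensor([[ 6,  9,  8,  4,  1, 11, 12, 10],   # long_str
                            [12,  5,  8, 14,  0,  0,  0,  0],   # tiny
                            [ 7,  3,  2,  5, 13,  7,  0,  0]])  # medium

sorted_lengths, perm_idx = seq_lengths.sort(0, descending=True)
sorted_tensor = seqs_tensor[perm_idx]      # long_str, medium, tiny

# argsort of a permutation gives its inverse permutation
unperm_idx = perm_idx.argsort()
restored = sorted_tensor[unperm_idx]       # long_str, tiny, medium again
print(torch.equal(restored, seqs_tensor))  # True
```

The same indexing works on any batch-indexed result, e.g. `seqs_op[unperm_idx]` with `batch_first=True`, or `h_t[:, unperm_idx]` for the hidden state, whose batch dimension is the second one.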

#### **6. Embed the instances**

In [40]:
seqs_tensor_embedded = embed(seqs_tensor)
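# seqs_tensor_embedded =>   (illustrative values from a random initialisation; they differ from run to run and across the cells below)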
#                       [[[ 0.1114, -0.1977, -0.0224, -1.0467]     l
#                         [ 0.7710,  0.2153, -1.3473,  0.8217]     o
#                         [-1.8285, -2.1818, -1.5927, -1.7026]     n
#                         [-0.8871,  0.2909, -0.0493, -1.0087]     g
#                         [ 0.9886,  0.8594, -0.2939, -0.0761]     _
#                         [ 0.3097, -1.2243, -0.7324, -0.7734]     s
#                         [ 0.0948, -0.1665, -1.0248, -2.0838]     t
#                         [-1.1538,  0.7745, -0.0513, -0.1554]]    r

#                        [[ 0.16031227 -0.08209462 -0.16297023  0.48121014]     m
#                         [-0.7303265  -0.857339    0.58913064 -1.1068314 ]     e
#                         [ 0.48159844 -1.4886451   0.92639893  0.76906884]     d
#                         [ 0.27616557 -1.224429   -1.342848   -0.7495876 ]     i
#                         [ 0.01795524 -0.59048957 -0.53800726 -0.6611691 ]     u
#                         [ 0.16031227 -0.08209462 -0.16297023  0.48121014]     m
#                         [ 0.2691206  -0.43435425  0.87935454 -2.2269666 ]     <pad>
#                         [ 0.2691206  -0.43435425  0.87935454 -2.2269666 ]]    <pad>

#                        [[ 0.64004815  0.45813003  0.3476034  -0.03451729]     t
#                         [ 0.27616557 -1.224429   -1.342848   -0.7495876 ]     i
#                         [-0.6000342   1.1732816   0.19938554 -1.5976517 ]     n
#                         [-1.284392    0.68294704  1.4064184  -0.42879772]     y
#                         [ 0.2691206  -0.43435425  0.87935454 -2.2269666 ]     <pad>
#                         [ 0.2691206  -0.43435425  0.87935454 -2.2269666 ]     <pad>
#                         [ 0.2691206  -0.43435425  0.87935454 -2.2269666 ]     <pad>
#                         [ 0.2691206  -0.43435425  0.87935454 -2.2269666 ]]]   <pad>
# seqs_tensor_embedded.shape : (batch_size X max_seq_len X embedding_dim) = (3 X 8 X 4)
# print("seqs_tensor_embedded:", seqs_tensor_embedded)
print("seqs_tensor_embedded.shape:", seqs_tensor_embedded.shape)


seqs_tensor_embedded.shape: torch.Size([3, 8, 4])
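The `<pad>` rows above are a learned, non-zero vector because the embedding was created without `padding_idx`. After packing, the GRU never sees those positions, so this is harmless here. If you want pad embeddings to be zero and excluded from training anyway, `nn.Embedding` accepts `padding_idx`. A sketch of that variant (not used by the rest of this notebook):

```python
import torch
import torch.nn as nn

vocab_size, embed_dim = 15, 4   # the sizes used in this notebook
# padding_idx=0: row 0 (<pad>) starts as zeros and receives no gradient
embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0)

tokens = torch.tensor([[12, 5, 8, 14, 0, 0, 0, 0]])   # "tiny" followed by padding
out = embed(tokens)
print(out[0, 4:])   # the four <pad> positions are all zeros

# the <pad> row's gradient stays zero, so it is never updated
out.sum().backward()
print(embed.weight.grad[0])
```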


#### **7. Call pack_padded_sequence**

We call pack_padded_sequence() with the embedded instances and the sorted sequence lengths.
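Recent PyTorch versions also give pack_padded_sequence() an `enforce_sorted` argument. With `enforce_sorted=False`, the manual sort in step 5 is unnecessary: PyTorch sorts internally, stores the permutation in the PackedSequence's `sorted_indices` and `unsorted_indices`, and both pad_packed_sequence() and the GRU's `h_t` come back in the original batch order. A self-contained sketch; the seed and layer sizes only mirror this notebook:

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)
embed = nn.Embedding(15, 4)
gru = nn.GRU(4, 5, batch_first=True)

# unsorted batch in the original order: long_str, tiny, medium
seqs_tensor = torch.tensor([[ 6,  9,  8,  4,  1, 11, 12, 10],
                            [12,  5,  8, 14,  0,  0,  0,  0],
                            [ 7,  3,  2,  5, 13,  7,  0,  0]])
seq_lengths = torch.tensor([8, 4, 6])

# enforce_sorted=False: PyTorch sorts internally and records the permutation
packed = pack_padded_sequence(embed(seqs_tensor), seq_lengths,
                              batch_first=True, enforce_sorted=False)
packed_out, h_t = gru(packed)

# padded outputs and lengths come back in the ORIGINAL batch order
seqs_op, op_lengths = pad_packed_sequence(packed_out, batch_first=True)
print(packed.sorted_indices)   # tensor([0, 2, 1])
print(op_lengths)              # tensor([8, 4, 6])
```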

In [0]:
from torch.nn.utils.rnn import pack_padded_sequence

packed_input = pack_padded_sequence(seqs_tensor_embedded, seq_lengths, batch_first = True)

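# packed_input is a PackedSequence (a NamedTuple); the attributes used here are data and batch_sizes
# (recent versions also have sorted_indices and unsorted_indices, which are None here because the batch was sorted beforehand)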
#
# packed_input.data =>
#                         [[-0.77578706 -1.8080667  -1.1168439   1.1059115 ]     l
#                          [ 0.01795524 -0.59048957 -0.53800726 -0.6611691 ]     m
#                          [-0.6470658  -0.6266589  -1.7463604   1.2675372 ]     t
#                          [ 0.16031227 -0.08209462 -0.16297023  0.48121014]     o
#                          [ 0.40524676  0.98665565 -0.08621677 -1.1728264 ]     e
#                          [-1.284392    0.68294704  1.4064184  -0.42879772]     i
#                          [ 0.64004815  0.45813003  0.3476034  -0.03451729]     n
#                          [ 0.27616557 -1.224429   -1.342848   -0.7495876 ]     d
#                          [ 0.64004815  0.45813003  0.3476034  -0.03451729]     n
#                          [-0.23622951  2.0361056   0.15435742 -0.04513785]     g
#                          [ 0.16031227 -0.08209462 -0.16297023  0.48121014]     i
#                          [-0.22739866 -0.45782727 -0.6643252   0.25129375]     y
#                          [-0.7303265  -0.857339    0.58913064 -1.1068314 ]     _
#                          [-1.6334635  -0.6100042   1.7509955  -1.931793  ]     u
#                          [ 0.27616557 -1.224429   -1.342848   -0.7495876 ]     s
#                          [-0.6000342   1.1732816   0.19938554 -1.5976517 ]     m
#                          [-0.6000342   1.1732816   0.19938554 -1.5976517 ]     t
#                          [ 0.48159844 -1.4886451   0.92639893  0.76906884]]    r
# packed_input.data.shape : (batch_sum_seq_len X embedding_dim) = (18 X 4)
#
# packed_input.batch_sizes => [ 3,  3,  3,  3,  2,  2,  1,  1]
# visualization :
# l  o  n  g  _  s  t  r   #(long_str)
# m  e  d  i  u  m         #(medium)
# t  i  n  y               #(tiny)
# 3  3  3  3  2  2  1  1   (sum = 18 [batch_sum_seq_len])

# print("packed_input.data:", packed_input.data)
# print("packed_input.batch_sizes:", packed_input.batch_sizes)
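The `batch_sizes` visualization above can be reproduced by hand: `batch_sizes[t]` is the number of sequences that are longer than `t`, and `data` stacks time step 0 of every active sequence, then time step 1, and so on. A small sketch that checks both against pack_padded_sequence(), using dummy integer-valued features (our own choice) so the rows are easy to trace:

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence

# lengths after sorting in step 5: long_str, medium, tiny
seq_lengths = torch.tensor([8, 6, 4])
max_len = int(seq_lengths.max())

# batch_sizes[t] = number of sequences still "active" at time step t
batch_sizes = torch.tensor([int((seq_lengths > t).sum()) for t in range(max_len)])
print(batch_sizes.tolist())   # [3, 3, 3, 3, 2, 2, 1, 1]

# dummy features: padded[b, t, 0] = 8 * b + t
padded = torch.arange(3 * max_len, dtype=torch.float).view(3, max_len, 1)
packed = pack_padded_sequence(padded, seq_lengths, batch_first=True)
assert torch.equal(packed.batch_sizes, batch_sizes)

# packed.data is time-major: the first n rows of step t, for each step in order
manual_data = torch.cat([padded[:n, t] for t, n in enumerate(batch_sizes.tolist())])
assert torch.equal(packed.data, manual_data)
```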

#### **8. Run the model**

In [0]:
packed_output, h_t = gru(packed_input)

# packed_output is a PackedSequence (a NamedTuple) with the same batch_sizes as packed_input; its data holds the GRU outputs
#
# packed_output.data :
#                          [[-0.00947162  0.07743231  0.20343193  0.29611713  0.07992904]   l
#                           [ 0.08596145  0.09205993  0.20892891  0.21788561  0.00624391]   m
#                           [ 0.16861682  0.07807446  0.18812777 -0.01148055 -0.01091915]   t
#                           [ 0.20994528  0.17932937  0.17748171  0.05025435  0.15717036]   o
#                           [ 0.01364102  0.11060348  0.14704391  0.24145307  0.12879576]   e
#                           [ 0.02610307  0.00965587  0.31438383  0.246354    0.08276576]   i
#                           [ 0.09527554  0.14521319  0.1923058  -0.05925677  0.18633027]   n
#                           [ 0.09872741  0.13324396  0.19446367  0.4307988  -0.05149471]   d
#                           [ 0.03895474  0.08449443  0.18839942  0.02205326  0.23149511]   n
#                           [ 0.14620507  0.07822411  0.2849248  -0.22616537  0.15480657]   g
#                           [ 0.00884941  0.05762182  0.30557525  0.373712    0.08834908]   i
#                           [ 0.12460691  0.21189159  0.04823487  0.06384943  0.28563985]   y
#                           [ 0.01368293  0.15872964  0.03759198 -0.13403234  0.23890573]   _
#                           [ 0.00377969  0.05943518  0.2961751   0.35107893  0.15148178]   u
#                           [ 0.00737647  0.17101538  0.28344846  0.18878219  0.20339936]   s
#                           [ 0.0864429   0.11173367  0.3158251   0.37537992  0.11876849]   m
#                           [ 0.17885767  0.12713005  0.28287745  0.05562563  0.10871304]   t
#                           [ 0.09486895  0.12772645  0.34048414  0.25930756  0.12044918]]  r
# packed_output.data.shape : (batch_sum_seq_len X hidden_dim) = (18 X 5)

# packed_output.batch_sizes => [ 3,  3,  3,  3,  2,  2,  1,  1] (same as packed_input.batch_sizes)
# visualization :
# l  o  n  g  _  s  t  r   #(long_str)
# m  e  d  i  u  m         #(medium)
# t  i  n  y               #(tiny)
# 3  3  3  3  2  2  1  1   (sum = 18 [batch_sum_seq_len])

# print("packed_output.data:", packed_output.data)
# print("packed_output.batch_sizes:", packed_output.batch_sizes)
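To see why the padding and packing steps matter, here is a self-contained sketch that feeds the same padded batch to a GRU with and without packing and compares the final hidden states. The seed and sizes are arbitrary and only mirror this notebook:

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

torch.manual_seed(0)
embed = nn.Embedding(15, 4)
gru = nn.GRU(4, 5, batch_first=True)

# sorted batch from step 5: long_str, medium, tiny
seqs_tensor = torch.tensor([[ 6,  9,  8,  4,  1, 11, 12, 10],
                            [ 7,  3,  2,  5, 13,  7,  0,  0],
                            [12,  5,  8, 14,  0,  0,  0,  0]])
seq_lengths = torch.tensor([8, 6, 4])
embedded = embed(seqs_tensor)

# without packing: the GRU also consumes the <pad> time steps
_, h_padded = gru(embedded)
# with packing: each sequence stops at its true length
_, h_packed = gru(pack_padded_sequence(embedded, seq_lengths, batch_first=True))

# long_str has no padding, so both agree for it (up to float rounding)
print(torch.allclose(h_padded[0, 0], h_packed[0, 0], atol=1e-6))   # True
# tiny's state was pushed through four extra <pad> steps without packing
print((h_padded[0, 2] - h_packed[0, 2]).abs().max())
```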

#### **9. Call pad_packed_sequence**

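pad_packed_sequence() turns the packed output back into a padded tensor of shape (batch_size X max_seq_len X hidden_dim), filling the padded positions with zeros, and also returns the sequence lengths. If you only need one vector per sequence, the final hidden state h_t returned by the GRU is enough. Both are in the sorted order from step 5 (long_str, medium, tiny).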

In [54]:
from torch.nn.utils.rnn import pad_packed_sequence

seqs_op, seqs_op_len = pad_packed_sequence(packed_output, batch_first = True)

# output:
# seqs_op =>
#                          [[[-0.00947162  0.07743231  0.20343193  0.29611713  0.07992904]   l
#                            [ 0.20994528  0.17932937  0.17748171  0.05025435  0.15717036]   o
#                            [ 0.09527554  0.14521319  0.1923058  -0.05925677  0.18633027]   n
#                            [ 0.14620507  0.07822411  0.2849248  -0.22616537  0.15480657]   g
#                            [ 0.01368293  0.15872964  0.03759198 -0.13403234  0.23890573]   _
#                            [ 0.00737647  0.17101538  0.28344846  0.18878219  0.20339936]   s
#                            [ 0.17885767  0.12713005  0.28287745  0.05562563  0.10871304]   t
#                            [ 0.09486895  0.12772645  0.34048414  0.25930756  0.12044918]]  r

#                           [[ 0.08596145  0.09205993  0.20892891  0.21788561  0.00624391]   m
#                            [ 0.01364102  0.11060348  0.14704391  0.24145307  0.12879576]   e
#                            [ 0.09872741  0.13324396  0.19446367  0.4307988  -0.05149471]   d
#                            [ 0.00884941  0.05762182  0.30557525  0.373712    0.08834908]   i
#                            [ 0.00377969  0.05943518  0.2961751   0.35107893  0.15148178]   u
#                            [ 0.0864429   0.11173367  0.3158251   0.37537992  0.11876849]   m
#                            [ 0.          0.          0.          0.          0.        ]   <pad>
#                            [ 0.          0.          0.          0.          0.        ]]  <pad>

#                           [[ 0.16861682  0.07807446  0.18812777 -0.01148055 -0.01091915]   t
#                            [ 0.02610307  0.00965587  0.31438383  0.246354    0.08276576]   i
#                            [ 0.03895474  0.08449443  0.18839942  0.02205326  0.23149511]   n
#                            [ 0.12460691  0.21189159  0.04823487  0.06384943  0.28563985]   y
#                            [ 0.          0.          0.          0.          0.        ]   <pad>
#                            [ 0.          0.          0.          0.          0.        ]   <pad>
#                            [ 0.          0.          0.          0.          0.        ]   <pad>
#                            [ 0.          0.          0.          0.          0.        ]]] <pad>
# seqs_op.shape : ( batch_size X max_seq_len X hidden_dim) = (3 X 8 X 5)
# print("seqs_op:", seqs_op)
# print("seqs_op.shape:", seqs_op.shape)

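# If you only need the final hidden state: h_t has shape (num_layers X batch_size X hidden_dim), and h_t[-1] is the last layer's state for each sequence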
print("h_t:", h_t[-1])

h_t: tensor([[ 0.2999,  0.3441, -0.2094, -0.1904,  0.0267],
        [-0.2112, -0.0114,  0.0353,  0.3477, -0.2878],
        [ 0.1597,  0.1993,  0.0709, -0.6301, -0.1129]],
       grad_fn=<SelectBackward>)
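For a single-layer, unidirectional GRU, `h_t[-1]` holds, for each sequence, the output at its last non-padded time step, i.e. `seqs_op[i, seq_lengths[i] - 1]`. The printed `h_t` values above do not match the last rows of `seqs_op`, presumably because they come from a different run. A sketch that checks the relationship on fresh random data (the random `embedded` tensor is a stand-in for `seqs_tensor_embedded`):

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)
gru = nn.GRU(4, 5, batch_first=True)   # single layer, unidirectional

seq_lengths = torch.tensor([8, 6, 4])  # sorted, as in step 5
embedded = torch.randn(3, 8, 4)        # stand-in for seqs_tensor_embedded
packed_output, h_t = gru(pack_padded_sequence(embedded, seq_lengths, batch_first=True))
seqs_op, seqs_op_len = pad_packed_sequence(packed_output, batch_first=True)

# pick seqs_op[i, seqs_op_len[i] - 1, :] for every sequence i
batch_idx = torch.arange(seqs_op.size(0))
last_outputs = seqs_op[batch_idx, seqs_op_len - 1]

print(torch.allclose(last_outputs, h_t[-1]))   # True
```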


#### **Summary of shape transformations**

* (batch_size X max_seq_len) ---> SORT (by length) ---> (batch_size X max_seq_len)
* (batch_size X max_seq_len) ---> EMBED ---> (batch_size X max_seq_len X embedding_dim)
* (batch_size X max_seq_len X embedding_dim) ---> PACK ---> (batch_sum_seq_len X embedding_dim)
* (batch_sum_seq_len X embedding_dim) ---> GRU ---> (batch_sum_seq_len X hidden_dim)
* (batch_sum_seq_len X hidden_dim) ---> UNPACK ---> (batch_size X max_seq_len X hidden_dim)