In [40]:
import torch
from torch import nn

In [86]:
X = torch.LongTensor(
    [
        [[ 1.,  2.,  3.,  4.,  5.],
        [ 6.,  8.,  0.,  0.,  0.],
        [ 7.,  7.,  7.,  0.,  0.]]
        ,[[ 1.,  2.,  3.,  4.,  5.],
        [ 6.,  8.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.]]
    ]
)

In [89]:
batch_size, seq_len, _ = X.size()
print(X.size())

torch.Size([2, 3, 5])


In [80]:
word_embedding = nn.Embedding(
    num_embeddings=10,
    embedding_dim=5,
    padding_idx=0
)

In [81]:
lstm = nn.LSTM(
    input_size=5,
    hidden_size=10,
    batch_first=True,
)

In [82]:
hidden_to_tag = nn.Linear(10, 2)

In [94]:
X = X.view(batch_size * seq_len, -1)
X

tensor([[1, 2, 3, 4, 5],
        [6, 8, 0, 0, 0],
        [7, 7, 7, 0, 0],
        [1, 2, 3, 4, 5],
        [6, 8, 0, 0, 0],
        [0, 0, 0, 0, 0]])

In [98]:
X_lengths = (X > 0).sum(-1)
X_lengths, sorted_index = X_lengths.sort(0, descending=True)
X_lengths

tensor([5, 5, 3, 2, 2, 0])

In [175]:
X1 = word_embedding(X[sorted_index][X_lengths > 0])
X1

tensor([[[ 0.8975,  1.1562, -0.4889,  0.2837,  1.2762],
         [ 1.4487, -0.2751,  1.2067, -0.7646,  1.5479],
         [ 0.5451, -0.6426,  1.1354, -0.0422,  0.9274],
         [ 0.3077, -1.6722,  0.8468, -0.6089, -1.0505],
         [-0.1483,  1.9102,  0.0642, -0.0661, -1.3715]],

        [[ 0.8975,  1.1562, -0.4889,  0.2837,  1.2762],
         [ 1.4487, -0.2751,  1.2067, -0.7646,  1.5479],
         [ 0.5451, -0.6426,  1.1354, -0.0422,  0.9274],
         [ 0.3077, -1.6722,  0.8468, -0.6089, -1.0505],
         [-0.1483,  1.9102,  0.0642, -0.0661, -1.3715]],

        [[-2.0011,  1.8204,  0.2832, -0.5354, -1.6312],
         [-2.0011,  1.8204,  0.2832, -0.5354, -1.6312],
         [-2.0011,  1.8204,  0.2832, -0.5354, -1.6312],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000]],

        [[-1.1873,  0.2643, -1.1364, -1.7089,  0.0416],
         [-0.6830, -0.5823,  0.1573,  0.1522,  1.5413],
         [ 0.0000,  0.0000,  0.0000,  0.00

In [140]:
X_pad = X_lengths > 0

In [176]:
X2 = torch.nn.utils.rnn.pack_padded_sequence(X1, X_lengths[X_lengths > 0], batch_first=True)
X2

PackedSequence(data=tensor([[ 0.8975,  1.1562, -0.4889,  0.2837,  1.2762],
        [ 0.8975,  1.1562, -0.4889,  0.2837,  1.2762],
        [-2.0011,  1.8204,  0.2832, -0.5354, -1.6312],
        [-1.1873,  0.2643, -1.1364, -1.7089,  0.0416],
        [-1.1873,  0.2643, -1.1364, -1.7089,  0.0416],
        [ 1.4487, -0.2751,  1.2067, -0.7646,  1.5479],
        [ 1.4487, -0.2751,  1.2067, -0.7646,  1.5479],
        [-2.0011,  1.8204,  0.2832, -0.5354, -1.6312],
        [-0.6830, -0.5823,  0.1573,  0.1522,  1.5413],
        [-0.6830, -0.5823,  0.1573,  0.1522,  1.5413],
        [ 0.5451, -0.6426,  1.1354, -0.0422,  0.9274],
        [ 0.5451, -0.6426,  1.1354, -0.0422,  0.9274],
        [-2.0011,  1.8204,  0.2832, -0.5354, -1.6312],
        [ 0.3077, -1.6722,  0.8468, -0.6089, -1.0505],
        [ 0.3077, -1.6722,  0.8468, -0.6089, -1.0505],
        [-0.1483,  1.9102,  0.0642, -0.0661, -1.3715],
        [-0.1483,  1.9102,  0.0642, -0.0661, -1.3715]],
       grad_fn=<PackPaddedSequenceBackward>)

In [177]:
X3, _ = lstm(X2)
X3

PackedSequence(data=tensor([[-0.1297,  0.0264,  0.1290,  0.0980, -0.0395, -0.0021, -0.2415,  0.0254,
          0.0157,  0.0020],
        [-0.1297,  0.0264,  0.1290,  0.0980, -0.0395, -0.0021, -0.2415,  0.0254,
          0.0157,  0.0020],
        [-0.0828, -0.2182,  0.1457, -0.0035,  0.2346,  0.1295,  0.0822,  0.0652,
          0.1894, -0.2764],
        [-0.0393, -0.0736,  0.0674, -0.0145,  0.1641,  0.1616, -0.0362,  0.0855,
         -0.0150, -0.0988],
        [-0.0393, -0.0736,  0.0674, -0.0145,  0.1641,  0.1616, -0.0362,  0.0855,
         -0.0150, -0.0988],
        [-0.2667,  0.1126,  0.0945,  0.0630, -0.0982, -0.1466, -0.2590,  0.1123,
         -0.0125,  0.0628],
        [-0.2667,  0.1126,  0.0945,  0.0630, -0.0982, -0.1466, -0.2590,  0.1123,
         -0.0125,  0.0628],
        [-0.1101, -0.3333,  0.2437,  0.0088,  0.2958,  0.1664,  0.1547,  0.0564,
          0.2440, -0.3665],
        [-0.1667,  0.0448,  0.0783,  0.0559,  0.0524, -0.0742, -0.2005, -0.0718,
          0.1102, -0.0366],

In [183]:
X4, _ = torch.nn.utils.rnn.pad_packed_sequence(X3, batch_first=True)
X4

tensor([[[-0.1297,  0.0264,  0.1290,  0.0980, -0.0395, -0.0021, -0.2415,
           0.0254,  0.0157,  0.0020],
         [-0.2667,  0.1126,  0.0945,  0.0630, -0.0982, -0.1466, -0.2590,
           0.1123, -0.0125,  0.0628],
         [-0.2975,  0.1455,  0.1336,  0.0276, -0.1319, -0.2307, -0.2275,
           0.0975,  0.0986,  0.0998],
         [-0.1589,  0.0867,  0.1411, -0.1722, -0.2041, -0.2370, -0.0277,
           0.1874,  0.1327,  0.1257],
         [-0.1442, -0.1201,  0.3393, -0.0096,  0.0430, -0.0471, -0.0848,
           0.2032,  0.1596, -0.1078]],

        [[-0.1297,  0.0264,  0.1290,  0.0980, -0.0395, -0.0021, -0.2415,
           0.0254,  0.0157,  0.0020],
         [-0.2667,  0.1126,  0.0945,  0.0630, -0.0982, -0.1466, -0.2590,
           0.1123, -0.0125,  0.0628],
         [-0.2975,  0.1455,  0.1336,  0.0276, -0.1319, -0.2307, -0.2275,
           0.0975,  0.0986,  0.0998],
         [-0.1589,  0.0867,  0.1411, -0.1722, -0.2041, -0.2370, -0.0277,
           0.1874,  0.1327,  0.1257],

In [182]:
zero = torch.zeros(((X_lengths == 0).sum(), 5, 10))

In [184]:
X4 = torch.cat((X4, zero), 0)

In [216]:
X5 = torch.zeros(X4.size())
X5[sorted_index] = X4
X5

tensor([[[-0.1297,  0.0264,  0.1290,  0.0980, -0.0395, -0.0021, -0.2415,
           0.0254,  0.0157,  0.0020],
         [-0.2667,  0.1126,  0.0945,  0.0630, -0.0982, -0.1466, -0.2590,
           0.1123, -0.0125,  0.0628],
         [-0.2975,  0.1455,  0.1336,  0.0276, -0.1319, -0.2307, -0.2275,
           0.0975,  0.0986,  0.0998],
         [-0.1589,  0.0867,  0.1411, -0.1722, -0.2041, -0.2370, -0.0277,
           0.1874,  0.1327,  0.1257],
         [-0.1442, -0.1201,  0.3393, -0.0096,  0.0430, -0.0471, -0.0848,
           0.2032,  0.1596, -0.1078]],

        [[-0.0393, -0.0736,  0.0674, -0.0145,  0.1641,  0.1616, -0.0362,
           0.0855, -0.0150, -0.0988],
         [-0.1667,  0.0448,  0.0783,  0.0559,  0.0524, -0.0742, -0.2005,
          -0.0718,  0.1102, -0.0366],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000],

In [186]:
X5 = X5.view(-1, X5.shape[2])
X5

tensor([[-0.1297,  0.0264,  0.1290,  0.0980, -0.0395, -0.0021, -0.2415,  0.0254,
          0.0157,  0.0020],
        [-0.2667,  0.1126,  0.0945,  0.0630, -0.0982, -0.1466, -0.2590,  0.1123,
         -0.0125,  0.0628],
        [-0.2975,  0.1455,  0.1336,  0.0276, -0.1319, -0.2307, -0.2275,  0.0975,
          0.0986,  0.0998],
        [-0.1589,  0.0867,  0.1411, -0.1722, -0.2041, -0.2370, -0.0277,  0.1874,
          0.1327,  0.1257],
        [-0.1442, -0.1201,  0.3393, -0.0096,  0.0430, -0.0471, -0.0848,  0.2032,
          0.1596, -0.1078],
        [-0.0393, -0.0736,  0.0674, -0.0145,  0.1641,  0.1616, -0.0362,  0.0855,
         -0.0150, -0.0988],
        [-0.1667,  0.0448,  0.0783,  0.0559,  0.0524, -0.0742, -0.2005, -0.0718,
          0.1102, -0.0366],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  

In [201]:
hidden_to_tag(X5).view(batch_size, 3, 5, -1) 

tensor([[[[ 0.3575,  0.2728],
          [ 0.3814,  0.3348],
          [ 0.4167,  0.3077],
          [ 0.3565,  0.1878],
          [ 0.3819,  0.0784]],

         [[ 0.2913,  0.1325],
          [ 0.3766,  0.2446],
          [ 0.2939,  0.2013],
          [ 0.2939,  0.2013],
          [ 0.2939,  0.2013]],

         [[ 0.3074,  0.0279],
          [ 0.3041, -0.0462],
          [ 0.2984, -0.0824],
          [ 0.2939,  0.2013],
          [ 0.2939,  0.2013]]],


        [[[ 0.3575,  0.2728],
          [ 0.3814,  0.3348],
          [ 0.4167,  0.3077],
          [ 0.3565,  0.1878],
          [ 0.3819,  0.0784]],

         [[ 0.2913,  0.1325],
          [ 0.3766,  0.2446],
          [ 0.2939,  0.2013],
          [ 0.2939,  0.2013],
          [ 0.2939,  0.2013]],

         [[ 0.2939,  0.2013],
          [ 0.2939,  0.2013],
          [ 0.2939,  0.2013],
          [ 0.2939,  0.2013],
          [ 0.2939,  0.2013]]]], grad_fn=<ViewBackward>)

In [228]:
def _cat_lstm_last(output, length):
    return torch.cat((output[length - 1, 5:], output[0, :5]))

In [217]:
X5.shape

torch.Size([6, 5, 10])

In [219]:
embs = torch.Tensor(torch.zeros((batch_size, seq_len, X5.size(2))))
embs

tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]])

In [229]:
embs[sorted_index] = 
torch.Tensor([_cat_lstm_last(word_feat, length) for word_feat, length in zip(X5, X_lengths)])

ValueError: only one element tensors can be converted to Python scalars

In [247]:
a = [_cat_lstm_last(word_feat, length) for word_feat, length in zip(X5, X_lengths)]

a

[tensor([-0.0471, -0.0848,  0.2032,  0.1596, -0.1078, -0.1297,  0.0264,  0.1290,
          0.0980, -0.0395], grad_fn=<CatBackward>),
 tensor([ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.0393, -0.0736,  0.0674,
         -0.0145,  0.1641], grad_fn=<CatBackward>),
 tensor([ 0.1737,  0.2137,  0.0389,  0.2571, -0.3928, -0.0828, -0.2182,  0.1457,
         -0.0035,  0.2346], grad_fn=<CatBackward>),
 tensor([-0.1466, -0.2590,  0.1123, -0.0125,  0.0628, -0.1297,  0.0264,  0.1290,
          0.0980, -0.0395], grad_fn=<CatBackward>),
 tensor([-0.0742, -0.2005, -0.0718,  0.1102, -0.0366, -0.0393, -0.0736,  0.0674,
         -0.0145,  0.1641], grad_fn=<CatBackward>),
 tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], grad_fn=<CatBackward>)]

In [245]:
torch.stack(a)

tensor([[-0.0471, -0.0848,  0.2032,  0.1596, -0.1078, -0.1297,  0.0264,  0.1290,
          0.0980, -0.0395],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.0393, -0.0736,  0.0674,
         -0.0145,  0.1641],
        [ 0.1737,  0.2137,  0.0389,  0.2571, -0.3928, -0.0828, -0.2182,  0.1457,
         -0.0035,  0.2346],
        [-0.1466, -0.2590,  0.1123, -0.0125,  0.0628, -0.1297,  0.0264,  0.1290,
          0.0980, -0.0395],
        [-0.0742, -0.2005, -0.0718,  0.1102, -0.0366, -0.0393, -0.0736,  0.0674,
         -0.0145,  0.1641],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000]], grad_fn=<StackBackward>)