In [1]:
import torch

import torch.nn as nn

from  torch.nn.utils.rnn import pad_sequence,pad_packed_sequence, pack_padded_sequence

In [2]:
# Suppose our vocab: {0: PAD, 1: deep, 2: learning, 3: is, 4: fun, 5: powerful, 6: and}

# Variable-length sequences (already tokenized)

seqs = [
    torch.tensor([1, 2, 3]),          # "deep learning is"     → length 3
    torch.tensor([4, 6, 5]),          # "fun and powerful"     → length 3
    torch.tensor([1, 2])              # "deep learning"        → length 2
]

# Lengths before padding. They are derived from `seqs` rather than typed by hand,
# so editing the example sequences can never leave them out of sync.
# pack_padded_sequence relies on these values to decide which positions are real tokens.
seq_lengths = torch.tensor([len(seq) for seq in seqs])

Padding and Packing

In [3]:
# Right-pad every sequence with PAD (id 0) up to the longest one.
# batch_first=False gives a (max_len, batch) tensor, the layout nn.LSTM expects by default.
padded_seqs = pad_sequence(
    seqs,
    padding_value=0,
    batch_first=False,
)
print("Padded:\n", padded_seqs)

# Pack using the true lengths so the LSTM only steps over real tokens, never PADs.
# enforce_sorted=False lets the batch stay in its original (not length-sorted) order.
packed_input = pack_padded_sequence(
    input=padded_seqs,
    lengths=seq_lengths,
    batch_first=False,
    enforce_sorted=False,
)

Padded:
 tensor([[1, 4, 1],
        [2, 6, 2],
        [3, 5, 0]])


LSTM Definition

In [4]:
class PackedLSTM(nn.Module):
    """Embedding layer followed by a single-layer LSTM that consumes a PackedSequence.

    Args:
        vocab_size: number of rows in the embedding table (token ids 0..vocab_size-1).
        embedding_dim: size of each embedding vector; also the LSTM input size.
        hidden_dim: size of the LSTM hidden and cell states.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super().__init__()

        # padding_idx=0 marks id 0 (PAD) as the padding embedding: per the nn.Embedding
        # docs it starts as zeros and receives no gradient.
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)

        # Default batch_first=False, matching the (max_len, batch) layout used when packing.
        self.lstm = nn.LSTM(embedding_dim, hidden_dim)

    def forward(self, packed_input):
        """Embed a packed batch of token ids and run the LSTM over it.

        Args:
            packed_input: PackedSequence of integer token ids, e.g. produced by
                pack_padded_sequence (sorted or with enforce_sorted=False).

        Returns:
            packed_output: PackedSequence of LSTM outputs. It carries the same
                batch_sizes and sort permutation as packed_input, so
                pad_packed_sequence restores the caller's batch order.
            (h_n, c_n): final hidden and cell states, in the caller's batch order.
        """
        # packed_input.data holds only the real (non-PAD) token ids, so embedding it
        # yields one vector per real token: shape [sum(lengths), embedding_dim].
        embedded = self.embedding(packed_input.data)

        # Re-wrap the embeddings as a PackedSequence. The sort permutation must be
        # carried over too. With enforce_sorted=False the packed batch is reordered
        # by length. Without sorted/unsorted indices, nn.LSTM would return h_n/c_n
        # in that sorted order, and pad_packed_sequence could not undo the reordering.
        packed_embedded = torch.nn.utils.rnn.PackedSequence(
            embedded,
            packed_input.batch_sizes,
            packed_input.sorted_indices,
            packed_input.unsorted_indices,
        )

        packed_output, (h_n, c_n) = self.lstm(packed_embedded)

        return packed_output, (h_n, c_n)

Running the Model

In [5]:
# Model hyperparameters; vocab_size = 7 covers ids 0-6 of the toy vocab (0 = PAD).
vocab_size = 7

embedding_dim = 8

hidden_dim = 16


# Seed before building the model so its randomly initialised weights are
# reproducible on Restart & Run All.
torch.manual_seed(0)

model = PackedLSTM(vocab_size, embedding_dim, hidden_dim)

packed_output, (h_n, c_n) = model(packed_input)


# To view the output normally (if needed): unpack to a (max_len, batch, hidden_dim) tensor.

padded_output, output_lengths = pad_packed_sequence(packed_output, batch_first=False)

# Sanity checks: unpacking recovers the original lengths, and there is one final
# hidden state per sequence (single-layer, unidirectional LSTM).
assert output_lengths.tolist() == seq_lengths.tolist()
assert h_n.shape == (1, len(seqs), hidden_dim)

print("LSTM output shape (padded):", padded_output.shape)

LSTM output shape (padded): torch.Size([3, 3, 16])


| Step                     | Action                                     |
| ------------------------ | ------------------------------------------ |
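| `pad_sequence()`         | Pads all input sequences to the length of the longest one. |
| `pack_padded_sequence()` | Converts padded input into packed form.    |
| `embedding(packed.data)` | Embeds only the packed (real) token ids and re-wraps them as a `PackedSequence`. |
| `lstm(packed)`           | LSTM processes only the non-padded tokens. |
| `pad_packed_sequence()`  | Unpacks the output back into padded form and also returns each sequence's length. |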