In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

The Sequence-to-Sequence (Seq2Seq) model is a type of neural network architecture widely used in machine learning for tasks that map one sequence of data to another. An encoder compresses the input sequence into a fixed-size representation (here, the final hidden state of a GRU), and a decoder then generates the output sequence from it one token at a time; the input and output sequences may have different lengths. Seq2Seq models have had a significant impact in areas such as natural language processing (NLP), machine translation, speech recognition, and time-series prediction.

In [2]:
class Encoder(nn.Module):
    """GRU encoder that summarises a token sequence by its final hidden state.

    Args:
        input_dim: size of the source vocabulary (rows of the embedding table).
        emb_dim: dimensionality of each token embedding.
        hidden_dim: number of features in the GRU hidden state.
    """

    def __init__(self, input_dim, emb_dim, hidden_dim):
        super().__init__()
        # Embedding is created before the GRU so seeded initialisation order is unchanged.
        self.embedding = nn.Embedding(input_dim, emb_dim)
        self.rnn = nn.GRU(emb_dim, hidden_dim)

    def forward(self, src):
        """Embed ``src``, run it through the GRU and return only the final hidden state.

        The GRU is built without ``batch_first``, so ``src`` is expected as
        (seq_len, batch) token ids; the returned hidden state then has shape
        (1, batch, hidden_dim). The per-step GRU outputs are discarded.
        """
        _, final_hidden = self.rnn(self.embedding(src))
        return final_hidden

In [3]:
class Decoder(nn.Module):
    """Single-step GRU decoder: one token per sequence in, next-token logits out.

    Args:
        output_dim: target vocabulary size (embedding rows and number of logits).
        emb_dim: dimensionality of each token embedding.
        hidden_dim: number of features in the GRU hidden state.
    """

    def __init__(self, output_dim, emb_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.rnn = nn.GRU(emb_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, input, hidden):
        """Advance the decoder by exactly one time step.

        Args:
            input: token ids for the current step, expected shape (batch,).
            hidden: previous GRU hidden state, expected shape (1, batch, hidden_dim).

        Returns:
            ``(prediction, hidden)``: unnormalised scores of shape
            (batch, output_dim) and the updated GRU hidden state.
        """
        # Prepend a length-1 time axis so the GRU sees a one-step sequence.
        step_embedded = self.embedding(input.unsqueeze(0))
        rnn_out, new_hidden = self.rnn(step_embedded, hidden)
        # Drop the time axis again before projecting onto the vocabulary.
        logits = self.fc(rnn_out.squeeze(0))
        return logits, new_hidden

In [4]:
class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper that greedily decodes a fixed number of steps.

    Args:
        encoder: module whose ``forward(src)`` returns the initial decoder hidden state.
        decoder: module whose ``forward(input, hidden)`` returns ``(logits, hidden)``
            and which exposes its output projection as ``decoder.fc``.
        device: device on which the start-token batch (and empty results) are created.
    """

    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, trg=None, max_len=10, teacher_forcing_ratio=0.5,
                sos_idx=0, return_logits=False):
        """Decode ``max_len`` steps, starting from token ``sos_idx``.

        Args:
            src: source token ids; the batch size is read from ``src.shape[1]``
                (i.e. (seq_len, batch) layout).
            trg: optional target token ids, (trg_len, batch). At step ``t < trg_len``
                the ground-truth ``trg[t]`` replaces the model's prediction as the
                next decoder input with probability ``teacher_forcing_ratio``.
            max_len: number of decoding steps.
            teacher_forcing_ratio: per-step probability of teacher forcing.
            sos_idx: token id fed to the decoder at the first step (default 0).
            return_logits: if True, return the stacked decoder scores of shape
                (max_len, batch, vocab) instead of argmax token ids, so the result
                can be passed to a loss such as cross-entropy.

        Returns:
            LongTensor of shape (max_len, batch) holding the greedy token ids, or the
            stacked logits described above when ``return_logits`` is True. If
            ``max_len <= 0``, an empty tensor with leading dimension 0 is returned.
        """
        batch_size = src.shape[1]
        trg_vocab_size = self.decoder.fc.out_features

        if max_len <= 0:
            # torch.stack/torch.cat on an empty list raise, so build the empty result directly.
            if return_logits:
                return torch.empty(0, batch_size, trg_vocab_size, device=self.device)
            return torch.empty(0, batch_size, dtype=torch.long, device=self.device)

        hidden = self.encoder(src)
        step_input = torch.full((batch_size,), sos_idx, dtype=torch.long, device=self.device)

        tokens = []
        logits = []
        for t in range(max_len):
            output, hidden = self.decoder(step_input, hidden)
            top1 = output.argmax(1)
            tokens.append(top1)
            if return_logits:
                # Only keep per-step scores (and their autograd graph) when asked for.
                logits.append(output)

            # Short-circuit order matters: the RNG is only drawn while a target token exists.
            use_teacher = (
                trg is not None
                and t < trg.shape[0]
                and torch.rand(1).item() < teacher_forcing_ratio
            )
            step_input = trg[t] if use_teacher else top1

        if return_logits:
            return torch.stack(logits, dim=0)
        return torch.stack(tokens, dim=0)

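- Batch size & vocab size: the batch size is read from `src.shape[1]`; the target vocabulary size comes from the decoder's output layer (`decoder.fc.out_features`).

- Encoding: input sequence → encoder → context vector (the encoder's final hidden state).
- Start token: the decoder's first input is token 0 for every sequence in the batch.
- Loop over `max_len` steps:
  - The decoder predicts a score for every token in the vocabulary.
  - `top1` is the token with the highest score (argmax over the vocabulary).
  - `top1` is appended to the outputs.
  - Teacher forcing: with probability `teacher_forcing_ratio` (and only while the step index is within the target length), the true target token `trg[t]` is fed as the next input instead of the prediction.
- Return: the per-step predictions are combined into a `(max_len, batch)` tensor of token IDs.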

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch so weight initialisation, the random src/trg batches and the
# teacher-forcing coin flips are reproducible on Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

VOCAB_SIZE = 10
EMB_DIM = 8
HID_DIM = 16
SEQ_LEN = 5
BATCH_SIZE = 2

enc = Encoder(VOCAB_SIZE, EMB_DIM, HID_DIM)
dec = Decoder(VOCAB_SIZE, EMB_DIM, HID_DIM)
model = Seq2Seq(enc, dec, device).to(device)

# Token ids are drawn from [1, VOCAB_SIZE), presumably to keep id 0 free for the
# decoder's start token. Sampling on the CPU keeps the values identical with or without a GPU.
src = torch.randint(1, VOCAB_SIZE, (SEQ_LEN, BATCH_SIZE)).to(device)
trg = torch.randint(1, VOCAB_SIZE, (SEQ_LEN, BATCH_SIZE)).to(device)

# The model is untrained, so its predictions are effectively arbitrary; this cell
# only demonstrates the data flow, so gradients are not tracked.
with torch.no_grad():
    outputs = model(src, trg, max_len=SEQ_LEN, teacher_forcing_ratio=0.7)

# Transpose from (seq_len, batch) to (batch, seq_len) so each row is one sequence.
print("Source sequence (input tokens):")
print(src.T)
print("\nTarget sequence (true tokens):")
print(trg.T)
print("\nPredicted sequence (model output tokens):")
print(outputs.T)

Source sequence (input tokens):
tensor([[1, 3, 4, 6, 3],
        [2, 1, 1, 7, 2]])

Target sequence (true tokens):
tensor([[8, 1, 1, 9, 3],
        [5, 2, 1, 3, 3]])

Predicted sequence (model output tokens):
tensor([[3, 4, 6, 6, 4],
        [3, 3, 3, 3, 3]])


# Applications

- Machine Translation: Translates text from one language to another (e.g., English to French).
- Text Summarization: Produces concise summaries of documents or news articles.
- Speech Recognition: Transcribes spoken language into text.
- Image Captioning: Generates captions for images by combining visual features with sequence generation.
- Time-Series Prediction: Predicts future sequences based on past temporal data.