In [None]:
from datasets import load_dataset
# Load configuration "3.0.0" of the `cnn_dailymail` dataset.
dataset = load_dataset(path="cnn_dailymail", name="3.0.0")


In [None]:
import re
import torch
from transformers import AutoTokenizer


In [None]:
def clean_text(text):
    """Normalize a piece of text for tokenization.

    Non-string or blank input yields an empty string. Otherwise the text is
    lowercased, every character that is not a letter, digit, one of
    ``. , ! ? ' "`` or whitespace is dropped, and runs of whitespace are
    collapsed to a single space with the ends trimmed.
    """
    if isinstance(text, str) and text.strip() != "":
        lowered = text.lower()
        # Whitelist filter: anything outside the allowed set is deleted
        # (not replaced by a space), so "well-known" becomes "wellknown".
        filtered = re.sub(r"[^a-zA-Z0-9.,!?'\"\s]", "", lowered)
        return re.sub(r"\s+", " ", filtered).strip()

    # Missing, non-string, or whitespace-only input.
    return ""

In [None]:
# Tokenizer of the "t5-small" checkpoint; shared by article and summary encoding.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="t5-small")

def preprocess_data(examples):
    """Clean and tokenize a batch of articles and their highlight summaries.

    ``examples["article"]`` and ``examples["highlights"]`` are each passed
    through ``clean_text`` and then the module-level ``tokenizer``. Articles
    are padded/truncated to 512 tokens, summaries to 128.

    Returns:
        dict with ``input_ids`` and ``attention_mask`` (from the articles) and
        ``labels`` (summary token ids), each a ``torch.long`` tensor.
    """

    def _encode(texts, limit):
        # Clean first, then pad/truncate every sequence to exactly `limit` tokens.
        return tokenizer(
            [clean_text(entry) for entry in texts],
            max_length=limit,
            truncation=True,
            padding="max_length",
        )

    article_enc = _encode(examples["article"], 512)
    summary_enc = _encode(examples["highlights"], 128)

    as_long = lambda values: torch.tensor(values, dtype=torch.long)
    return {
        "input_ids": as_long(article_enc["input_ids"]),
        "attention_mask": as_long(article_enc["attention_mask"]),
        "labels": as_long(summary_enc["input_ids"]),
    }


In [None]:
# Apply the cleaning + tokenization step to the dataset in batches.
tokenized_dataset = dataset.map(function=preprocess_data, batched=True)


In [None]:
from torch.utils.data import DataLoader

def collate_fn(batch):
    """Collate a list of tokenized examples into a dict of ``torch.long`` tensors.

    Args:
        batch: Sequence of mappings, each holding ``input_ids``,
            ``attention_mask`` and ``labels``. Each value may be a plain
            list of ints or an already-built tensor.

    Returns:
        dict with keys ``input_ids``, ``attention_mask`` and ``labels``; each
        value stacks the per-example rows along a new leading batch dimension.
    """

    def _stack_field(key):
        rows = [item[key] for item in batch]
        # torch.tensor() cannot build a tensor from a list of multi-element
        # tensors, so tensor rows are stacked instead. Plain-list rows keep
        # the original torch.tensor() path.
        if any(isinstance(row, torch.Tensor) for row in rows):
            return torch.stack([torch.as_tensor(row, dtype=torch.long) for row in rows])
        return torch.tensor(rows, dtype=torch.long)

    return {key: _stack_field(key) for key in ("input_ids", "attention_mask", "labels")}

# Training batches are shuffled; validation batches keep dataset order.
# Both loaders assemble their batches through `collate_fn`.
train_dataloader = DataLoader(
    dataset=tokenized_dataset["train"],
    batch_size=8,
    shuffle=True,
    collate_fn=collate_fn,
)
val_dataloader = DataLoader(
    dataset=tokenized_dataset["validation"],
    batch_size=4,
    collate_fn=collate_fn,
)


In [None]:
# Pull a single training batch and print it to sanity-check the collated structure.
batch = next(iter(train_dataloader), None)
if batch is not None:
    print("Batch structure:", batch)


Encoder-Decoder Model (LSTM encoder, attention, LSTM decoder)


In [None]:
import torch.nn as nn

class Encoder(nn.Module):
    """Embeds source token ids and encodes them with a multi-layer LSTM.

    Args:
        vocab_size: Number of entries in the embedding table.
        embedding_dim: Size of each embedding vector (the LSTM input size).
        hidden_dim: LSTM hidden size per direction.
        n_layers: Number of stacked LSTM layers.
        bidirectional: Whether the LSTM runs in both directions (default True).
        dropout: Probability for the embedding dropout and, when
            ``n_layers > 1``, for the LSTM's inter-layer dropout.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, n_layers, bidirectional=True, dropout=0.5):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # nn.LSTM only applies dropout between stacked layers and warns when a
        # non-zero value is given for a single layer; pass 0 in that case.
        lstm_dropout = dropout if n_layers > 1 else 0.0
        self.lstm = nn.LSTM(
            embedding_dim, hidden_dim, n_layers,
            bidirectional=bidirectional, batch_first=True, dropout=lstm_dropout
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        """Encode a batch of source sequences.

        Args:
            src: Integer token ids; the LSTM is ``batch_first``, so the batch
                dimension comes first.

        Returns:
            Tuple ``(outputs, hidden, cell)`` exactly as produced by
            ``nn.LSTM``: per-step outputs plus the final hidden and cell
            states (the states are stacked over layers and directions).
        """
        embedded = self.dropout(self.embedding(src))
        outputs, (hidden, cell) = self.lstm(embedded)
        return outputs, hidden, cell


In [None]:
class Attention(nn.Module):
    """Additive attention scoring a decoder state against every encoder output.

    The scoring layer takes ``hidden_dim * 2 + hidden_dim`` input features:
    encoder outputs of width ``hidden_dim * 2`` concatenated with a decoder
    state of width ``hidden_dim``.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        encoder_feature_dim = hidden_dim * 2
        self.attn = nn.Linear(encoder_feature_dim + hidden_dim, hidden_dim)
        self.v = nn.Linear(hidden_dim, 1, bias=False)

    def forward(self, hidden, encoder_outputs):
        """Return attention weights over the source positions.

        Args:
            hidden: Decoder state used as the query, one vector per batch item.
            encoder_outputs: Encoder outputs with source positions on dim 1.

        Returns:
            Weights with one row per batch item, softmax-normalized across
            the source positions.
        """
        steps = encoder_outputs.size(1)
        # Copy the query once per source position so it can be concatenated
        # with the encoder outputs along the feature axis.
        query = hidden[:, None, :].repeat(1, steps, 1)
        features = torch.cat((query, encoder_outputs), dim=2)
        scores = self.v(torch.tanh(self.attn(features))).squeeze(2)
        return torch.softmax(scores, dim=1)

In [None]:
class Decoder(nn.Module):
    """One-step LSTM decoder that attends over the encoder outputs.

    NOTE(review): a later cell of this notebook defines ``Decoder`` again;
    that later definition replaces this one at runtime. Keep only one.

    Args:
        vocab_size: Number of entries in the target embedding table.
        embedding_dim: Size of each target-token embedding.
        hidden_dim: Decoder LSTM hidden size; encoder outputs are expected
            to be ``hidden_dim * 2`` wide.
        output_dim: Number of output classes produced per step.
        n_layers: Number of stacked LSTM layers.
        dropout: Probability for the embedding dropout and the LSTM dropout.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, n_layers, dropout=0.5):
        super().__init__()
        context_dim = hidden_dim * 2
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(
            input_size=embedding_dim + context_dim,
            hidden_size=hidden_dim,
            num_layers=n_layers,
            batch_first=True,
            dropout=dropout,
        )
        self.fc_out = nn.Linear(context_dim + hidden_dim + embedding_dim, output_dim)
        self.dropout = nn.Dropout(dropout)
        self.attention = Attention(hidden_dim)

    def forward(self, input, hidden, cell, encoder_outputs):
        """Decode a single target step.

        Args:
            input: Token ids for the current step, one per batch item.
            hidden: LSTM hidden state; its last layer is the attention query.
            cell: LSTM cell state.
            encoder_outputs: Encoder outputs with source positions on dim 1.

        Returns:
            Tuple ``(prediction, hidden, cell, attn_weights)``: the output
            scores for this step, the updated LSTM states, and the attention
            weights over the source positions.
        """
        # A length-1 time axis lets the batch_first LSTM consume a single step.
        token_emb = self.dropout(self.embedding(input.unsqueeze(1)))

        weights = self.attention(hidden[-1], encoder_outputs)
        # Attention-weighted sum of the encoder outputs.
        context = torch.bmm(weights.unsqueeze(1), encoder_outputs).squeeze(1)

        lstm_in = torch.cat((token_emb, context.unsqueeze(1)), dim=2)
        lstm_out, (hidden, cell) = self.lstm(lstm_in, (hidden, cell))

        features = torch.cat((lstm_out.squeeze(1), context, token_emb.squeeze(1)), dim=1)
        return self.fc_out(features), hidden, cell, weights

In [None]:
class Decoder(nn.Module):
    """One-step LSTM decoder with attention over the encoder outputs.

    NOTE(review): ``Decoder`` is also defined in an earlier cell of this
    notebook. This later definition is the one in effect once both cells
    have run; delete one of the two to avoid silent shadowing.

    Args:
        vocab_size: Number of entries in the target embedding table.
        embedding_dim: Size of each target-token embedding.
        hidden_dim: Decoder LSTM hidden size; encoder outputs are expected
            to be ``hidden_dim * 2`` wide.
        output_dim: Number of output classes produced per step.
        n_layers: Number of stacked LSTM layers.
        dropout: Probability for the embedding dropout and, when
            ``n_layers > 1``, for the LSTM's inter-layer dropout.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, n_layers, dropout=0.5):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # nn.LSTM only applies dropout between stacked layers and warns when a
        # non-zero value is given for a single layer; pass 0 in that case.
        lstm_dropout = dropout if n_layers > 1 else 0.0
        self.lstm = nn.LSTM(
            embedding_dim + hidden_dim * 2, hidden_dim,
            n_layers, batch_first=True, dropout=lstm_dropout
        )
        self.fc_out = nn.Linear(hidden_dim * 2 + hidden_dim + embedding_dim, output_dim)
        self.dropout = nn.Dropout(dropout)
        self.attention = Attention(hidden_dim)

    def forward(self, input, hidden, cell, encoder_outputs):
        """Decode a single target step.

        Args:
            input: Token ids for the current step, one per batch item.
            hidden: LSTM hidden state; its last layer is the attention query.
            cell: LSTM cell state.
            encoder_outputs: Encoder outputs with source positions on dim 1.

        Returns:
            Tuple ``(prediction, hidden, cell, attn_weights)``: the output
            scores for this step, the updated LSTM states, and the attention
            weights over the source positions.
        """
        input = input.unsqueeze(1)  # Add a length-1 time dimension
        embedded = self.dropout(self.embedding(input))

        # Query the attention with the top layer's hidden state, then build
        # the attention-weighted sum of the encoder outputs.
        attn_weights = self.attention(hidden[-1], encoder_outputs)
        attn_applied = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs).squeeze(1)

        rnn_input = torch.cat((embedded, attn_applied.unsqueeze(1)), dim=2)
        outputs, (hidden, cell) = self.lstm(rnn_input, (hidden, cell))

        # The output layer sees the LSTM output, the context, and the embedding.
        prediction = self.fc_out(torch.cat((outputs.squeeze(1), attn_applied, embedded.squeeze(1)), dim=1))
        return prediction, hidden, cell, attn_weights

In [None]:
import random

# Define hyperparameters
input_dim = len(tokenizer.get_vocab())  # Vocabulary size taken from the tokenizer
output_dim = input_dim  # Source and target share the same tokenizer vocabulary
emb_dim = 256
enc_hid_dim = 512
dec_hid_dim = 512  # Decoder/Attention size their layers for encoder outputs of width dec_hid_dim * 2
n_layers = 2  # Shared by the encoder and decoder LSTMs
dropout = 0.5
device = torch.device('cpu')

# Build the encoder and decoder with keyword arguments. Encoder's signature is
# (vocab_size, embedding_dim, hidden_dim, n_layers, bidirectional, dropout), so
# the previous positional call (input_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout)
# bound n_layers=512 and bidirectional=0.5.
encoder = Encoder(
    vocab_size=input_dim,
    embedding_dim=emb_dim,
    hidden_dim=enc_hid_dim,
    n_layers=n_layers,
    bidirectional=True,
    dropout=dropout,
)
decoder = Decoder(
    vocab_size=input_dim,
    embedding_dim=emb_dim,
    hidden_dim=dec_hid_dim,
    output_dim=output_dim,
    n_layers=n_layers,
    dropout=dropout,
)

# Instantiate Seq2Seq model
# NOTE(review): no cell above this one defines `Seq2Seq` -- confirm it is
# defined or imported before this cell runs (Restart & Run All would raise NameError).
seq2seq = Seq2Seq(encoder, decoder, device)

# Loss ignores padded label positions; optimizer uses Adam's default settings.
num_epochs = 100
pad_idx = tokenizer.pad_token_id
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
optimizer = torch.optim.Adam(seq2seq.parameters())

# Training loop
for epoch in range(num_epochs):
    seq2seq.train()  # Set model to training mode
    epoch_loss = 0

    for batch in train_dataloader:
        # Move the tensors the model consumes to the target device.
        # The batch's attention_mask is not passed to seq2seq(...), so it is
        # not transferred here.
        input_ids = batch["input_ids"].to(device)
        labels = batch["labels"].to(device)

        optimizer.zero_grad()  # Zero gradients

        # Forward pass
        outputs = seq2seq(input_ids, labels)

        # Flatten to (positions, classes) vs (positions,) for the loss.
        # reshape() also works when the model output is non-contiguous,
        # where view() would raise.
        outputs = outputs.reshape(-1, outputs.size(-1))
        labels = labels.reshape(-1)
        loss = criterion(outputs, labels)

        loss.backward()  # Backpropagation
        torch.nn.utils.clip_grad_norm_(seq2seq.parameters(), max_norm=1)  # Gradient clipping
        optimizer.step()

        # Accumulate batch loss
        epoch_loss += loss.item()

    # max(..., 1) avoids a ZeroDivisionError if the loader yields no batches.
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss / max(len(train_dataloader), 1):.4f}")
