## Importing Libraries


In [None]:
import torch
import torch.nn as nn
import math
import pandas as pd

from typing import List, Dict, Tuple
from collections import Counter
import re

from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

import html
import unicodedata

from tqdm import tqdm


## Creating Input Embeddings

In [None]:
class InputEmbeddings(nn.Module):
    """Token-embedding lookup whose output is scaled by sqrt(dim_model).

    Args:
        dim_model: Width of every embedding vector.
        vocab_size: Number of rows in the embedding table (valid ids are
            ``0 .. vocab_size - 1``).
    """

    def __init__(self, dim_model: int, vocab_size: int):
        super().__init__()
        self.dim_model = dim_model
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, dim_model)

    def forward(self, x):
        """Look up integer token ids ``x`` and multiply by sqrt(dim_model)."""
        # The sqrt(dim_model) factor follows the original Transformer paper; it
        # enlarges the embeddings relative to the positional encodings added next.
        vectors = self.embedding(x)
        return vectors * math.sqrt(self.dim_model)

## Positional Encoding

In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch of embeddings, then apply dropout.

    Args:
        dim_model: Embedding width. Odd widths are supported (the last column is a sine).
        seq_len: Number of positions in the precomputed table; longer inputs are rejected.
        dropout: Dropout probability applied after the addition.
    """

    def __init__(self, dim_model: int, seq_len: int, dropout: float):
        super().__init__()
        # Dimension of our model
        self.dim_model = dim_model
        # Maximum sequence length the encoding table covers
        self.seq_len = seq_len
        # Dropout to prevent overfitting
        self.dropout = nn.Dropout(dropout)

        # Column vector of positions, shape [seq_len, 1], so it broadcasts against divider
        token_positions = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
        # 1 / 10000^(2i / dim_model) for each even index 2i, computed in log space for stability
        divider = torch.exp(torch.arange(0, dim_model, 2).float() * (-math.log(10000.0) / dim_model))
        # Matrix initialized with zeros, shape [seq_len, dim_model]
        positional_encoding = torch.zeros(seq_len, dim_model)
        # Sine on even columns (0, 2, 4, ...)
        positional_encoding[:, 0::2] = torch.sin(token_positions * divider)
        # Cosine on odd columns (1, 3, 5, ...). For an odd dim_model there is one
        # fewer odd column than even column, so trim divider to match.
        positional_encoding[:, 1::2] = torch.cos(token_positions * divider[: dim_model // 2])

        # Leading batch axis -> [1, seq_len, dim_model], broadcastable over the batch
        positional_encoding = positional_encoding.unsqueeze(0)

        # A buffer moves with the module (cpu/gpu) and is saved in state_dict, but is not trained
        self.register_buffer('positional_encoding', positional_encoding)

    def forward(self, x):
        """Return ``dropout(x + PE[:seq_len])`` for ``x`` shaped [batch, seq_len, dim_model].

        Raises:
            ValueError: if ``x`` is longer than the precomputed table.
        """
        seq_len = x.shape[1]
        if seq_len > self.seq_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds the positional-encoding table size {self.seq_len}"
            )
        # Encodings for the first seq_len positions
        pos_encoding_for_seq = self.positional_encoding[:, :seq_len, :].to(x.device)
        x = x + pos_encoding_for_seq
        return self.dropout(x)

## Layer Normalization

In [None]:
class LayerNormalization(nn.Module):
    """Normalize over the last dimension with one learnable scalar gain and bias.

    ``alpha`` and ``bias`` each hold a single value shared by every feature.
    The spread is the unbiased standard deviation (``Tensor.std`` default), and
    epsilon is added to that std, not to the variance.
    """

    def __init__(self):
        super().__init__()
        # Guards against division by zero for constant inputs
        self.epsilon = 10**-6
        # Learnable gain, initialised to 1
        self.alpha = nn.Parameter(torch.ones(1))
        # Learnable shift, initialised to 0
        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        centered = x - x.mean(dim=-1, keepdim=True)
        spread = x.std(dim=-1, keepdim=True) + self.epsilon
        return self.alpha * centered / spread + self.bias

## Feed Forward

In [None]:
class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear.

    Args:
        dim_model: Input and output width.
        hidden_dim: Width of the inner (expanded) layer.
        dropout: Dropout probability applied to the hidden activations.
    """

    def __init__(self, dim_model: int, hidden_dim: int, dropout: float):
        super().__init__()
        # Expansion to the wider hidden layer
        self.input_projection = nn.Linear(dim_model, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        # Projection back to the model width
        self.output_projection = nn.Linear(hidden_dim, dim_model)

    def forward(self, x):
        hidden = self.dropout(torch.relu(self.input_projection(x)))
        return self.output_projection(hidden)

## Multihead Attention

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Mask convention: entries of ``mask`` equal to 0 (or False) are blocked by
    setting their score to -1e9 before the softmax. Every non-zero / True entry
    may be attended to. Additive 0/-inf masks are therefore NOT supported: their
    0 entries would be treated as blocked.

    A 3-D mask ``[batch, 1 or query_len, key_len]`` gets a head axis inserted.
    Any other mask must already broadcast to ``[batch, heads, query_len, key_len]``.

    After each forward call, ``self.attention_scores`` holds the post-softmax
    attention weights (before dropout), shape ``[batch, heads, query_len, key_len]``.
    """

    def __init__(self, dim_model: int, num_heads: int, dropout: float):
        super().__init__()
        self.dim_model = dim_model
        self.num_heads = num_heads

        # Each head gets an equal slice of the model dimension
        assert dim_model % num_heads == 0, 'Model Dimensions are not divisible by number of heads'

        # Dimension of each head's key, query and value vectors
        self.dim_k = dim_model // num_heads

        # Weight matrices for query, key, value and output
        self.w_q = nn.Linear(dim_model, dim_model)
        self.w_k = nn.Linear(dim_model, dim_model)
        self.w_v = nn.Linear(dim_model, dim_model)
        self.w_o = nn.Linear(dim_model, dim_model)

        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, t):
        """Reshape [batch, seq, dim_model] -> [batch, heads, seq, dim_k]."""
        batch_size, seq_len, _ = t.shape
        return t.view(batch_size, seq_len, self.num_heads, self.dim_k).transpose(1, 2)

    def forward(self, q, k, v, mask):
        """Attend from queries ``q`` to keys ``k`` and values ``v``.

        Args:
            q: [batch, query_len, dim_model].
            k: [batch, key_len, dim_model].
            v: [batch, key_len, dim_model].
            mask: Keep-mask as described in the class docstring, or None.

        Returns:
            Tensor of shape [batch, query_len, dim_model].
        """
        query = self._split_heads(self.w_q(q))
        key = self._split_heads(self.w_k(k))
        value = self._split_heads(self.w_v(v))

        # QK^T / sqrt(d_k): [batch, heads, query_len, key_len]
        attention_scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.dim_k)

        if mask is not None:
            if mask.dim() == 3:
                # [batch, 1 or query_len, key_len] -> insert a broadcastable head axis
                mask = mask.unsqueeze(1)
            attention_scores = attention_scores.masked_fill(mask == 0, -1e9)

        # Softmax first, then dropout on the probabilities. Dropping the raw
        # scores instead would turn some -1e9 entries into 0 and leak attention
        # onto blocked (padded / future) positions.
        attention_weights = attention_scores.softmax(dim=-1)
        self.attention_scores = attention_weights
        x = torch.matmul(self.dropout(attention_weights), value)

        # [batch, heads, query_len, dim_k] -> [batch, query_len, dim_model] (concatenate heads)
        h = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.num_heads * self.dim_k)
        return self.w_o(h)



## Single Encoder Block

In [None]:
class EncoderBlock(nn.Module):
    """One pre-norm encoder layer.

    Self-attention, then the feed-forward network. Each sub-layer reads a
    layer-normalized input, and its dropped-out output is added back to the
    residual stream.
    """

    def __init__(self, self_attention: MultiHeadAttention, feed_forward: FeedForward, dropout: float):
        super().__init__()
        self.self_attention = self_attention
        self.feed_forward = feed_forward
        self.norm1 = LayerNormalization()
        self.norm2 = LayerNormalization()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        # The same normalized tensor serves as query, key and value
        normed = self.norm1(x)
        x = x + self.dropout(self.self_attention(normed, normed, normed, mask))
        return x + self.dropout(self.feed_forward(self.norm2(x)))

## Complete Encoder

In [None]:
class Encoder(nn.Module):
    """Stack of ``num_layers`` encoder blocks followed by a final layer norm.

    ``self_attention`` and ``feed_forward`` serve as templates. Each block gets
    its own deep copy with freshly initialised parameters, so the layers do not
    share weights.

    Args:
        num_layers: Number of encoder blocks.
        self_attention: Template attention module, copied per layer.
        feed_forward: Template feed-forward module, copied per layer.
        dropout: Residual dropout probability inside each block.
    """

    def __init__(self, num_layers: int, self_attention: MultiHeadAttention, feed_forward: FeedForward, dropout: float):
        super().__init__()
        self.layers = nn.ModuleList([
            EncoderBlock(
                self_attention=self._clone(self_attention),
                feed_forward=self._clone(feed_forward),
                dropout=dropout,
            )
            for _ in range(num_layers)
        ])

        self.norm = LayerNormalization()

    @staticmethod
    def _clone(module):
        """Return an independent deep copy of ``module`` with re-initialised parameters."""
        import copy

        clone = copy.deepcopy(module)
        # Fresh init per layer (e.g. nn.Linear.reset_parameters) so layers do not start identical
        for sub in clone.modules():
            if hasattr(sub, 'reset_parameters'):
                sub.reset_parameters()
        return clone

    def forward(self, x, mask):
        """Run ``x`` through every encoder block in order, then apply the final norm."""
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)

## Single Decoder Block

In [None]:
class DecoderBlock(nn.Module):
    """One pre-norm decoder layer.

    The sub-layers run in this order:
    - self-attention under ``tgt_mask``;
    - cross-attention from the decoder stream to ``encoder_output`` under ``src_mask``;
    - the feed-forward network.

    Each sub-layer reads a layer-normalized input, and its dropped-out output is
    added back to the residual stream.
    """

    def __init__(self, self_attention_block: MultiHeadAttention, cross_attention_block: MultiHeadAttention, feed_forward_block: FeedForward, dropout: float):
        super().__init__()
        self.self_attention_block = self_attention_block
        self.cross_attention_block = cross_attention_block
        self.feed_forward_block = feed_forward_block
        self.norm1 = LayerNormalization()
        self.norm2 = LayerNormalization()
        self.norm3 = LayerNormalization()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        normed = self.norm1(x)
        x = x + self.dropout(self.self_attention_block(normed, normed, normed, tgt_mask))

        # Queries come from the decoder; keys and values are the encoder output as given
        cross = self.cross_attention_block(self.norm2(x), encoder_output, encoder_output, src_mask)
        x = x + self.dropout(cross)

        return x + self.dropout(self.feed_forward_block(self.norm3(x)))

## Complete Decoder

In [None]:
class Decoder(nn.Module):
    """Stack of ``num_layers`` decoder blocks followed by a final layer norm.

    The attention and feed-forward arguments serve as templates. Each block gets
    its own deep copy with freshly initialised parameters, so the layers do not
    share weights.

    Args:
        num_layers: Number of decoder blocks.
        self_attention_block: Template self-attention module, copied per layer.
        cross_attention_block: Template cross-attention module, copied per layer.
        feed_forward_block: Template feed-forward module, copied per layer.
        dropout: Residual dropout probability inside each block.
    """

    def __init__(self, num_layers: int, self_attention_block: MultiHeadAttention, cross_attention_block: MultiHeadAttention, feed_forward_block: FeedForward, dropout: float):
        super().__init__()

        self.layers = nn.ModuleList([
            DecoderBlock(
                self_attention_block=self._clone(self_attention_block),
                cross_attention_block=self._clone(cross_attention_block),
                feed_forward_block=self._clone(feed_forward_block),
                dropout=dropout,
            )
            for _ in range(num_layers)
        ])

        self.norm = LayerNormalization()

    @staticmethod
    def _clone(module):
        """Return an independent deep copy of ``module`` with re-initialised parameters."""
        import copy

        clone = copy.deepcopy(module)
        # Fresh init per layer (e.g. nn.Linear.reset_parameters) so layers do not start identical
        for sub in clone.modules():
            if hasattr(sub, 'reset_parameters'):
                sub.reset_parameters()
        return clone

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        """Run ``x`` through every decoder block in order, then apply the final norm."""
        for layer in self.layers:
            x = layer(x, encoder_output, src_mask, tgt_mask)
        return self.norm(x)


## Linear Layer with Log-Softmax

In [None]:
class Linearprojectionlayer(nn.Module):
    """Project decoder states to log-probabilities over the vocabulary.

    Args:
        dim_model: Width of the incoming hidden states.
        vocab_size: Number of output classes.
    """

    def __init__(self, dim_model: int, vocab_size: int):
        super().__init__()
        self.proj = nn.Linear(dim_model, vocab_size)

    def forward(self, x):
        logits = self.proj(x)
        return logits.log_softmax(dim=-1)


## Complete Transformer Class

In [None]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer returning per-position vocabulary log-probabilities.

    Masks passed to ``encode``/``decode``/``forward`` follow MultiHeadAttention's
    convention: zero/False entries are blocked, and non-zero/True entries may be
    attended to.

    Args:
        encoder_vocab_size: Size of the source embedding table.
        decoder_vocab_size: Size of the target embedding table and of the output projection.
        dim_model: Model width shared by all sub-layers.
        num_heads: Attention heads per attention block.
        num_encoder_layers: Number of encoder blocks.
        num_decoder_layers: Number of decoder blocks.
        hidden_dim: Inner width of the feed-forward sub-layers.
        dropout: Dropout probability used throughout.
        max_seq_length: Length of the positional-encoding tables (longest encodable sequence).
    """

    def __init__(
            self,
            encoder_vocab_size: int,
            decoder_vocab_size: int,
            dim_model: int,
            num_heads: int,
            num_encoder_layers: int,
            num_decoder_layers: int,
            hidden_dim: int,
            dropout: float,
            max_seq_length: int
        ):
        super().__init__()
        self.encoder_vocab_size = encoder_vocab_size
        self.decoder_vocab_size = decoder_vocab_size
        # Creating the embedding layers
        self.encoder_embedding = InputEmbeddings(dim_model, encoder_vocab_size)
        self.decoder_embedding = InputEmbeddings(dim_model, decoder_vocab_size)

        # Creating the positional encoding layers
        self.encoder_position = PositionalEncoding(dim_model, max_seq_length, dropout)
        self.decoder_position = PositionalEncoding(dim_model, max_seq_length, dropout)

        # Attention and feed-forward templates for the encoder
        encoder_self_attention = MultiHeadAttention(dim_model, num_heads, dropout)
        encoder_feed_forward = FeedForward(dim_model, hidden_dim, dropout)

        # Attention and feed-forward templates for the decoder
        decoder_self_attention = MultiHeadAttention(dim_model, num_heads, dropout)
        decoder_cross_attention = MultiHeadAttention(dim_model, num_heads, dropout)
        decoder_feed_forward = FeedForward(dim_model, hidden_dim, dropout)

        # Creating the encoder and decoder
        self.encoder = Encoder(num_encoder_layers, encoder_self_attention, encoder_feed_forward, dropout)

        self.decoder = Decoder(num_decoder_layers, decoder_self_attention, decoder_cross_attention, decoder_feed_forward, dropout)

        # Final projection to vocabulary log-probabilities
        self.linear = Linearprojectionlayer(dim_model, decoder_vocab_size)

    def encode(self, src, src_mask):
        """Embed source ids, add positions, and run the encoder stack."""
        src = self.encoder_embedding(src)
        src = self.encoder_position(src)
        return self.encoder(src, src_mask)

    def decode(self, tgt, encoder_output, src_mask, tgt_mask):
        """Embed target ids, add positions, and run the decoder stack against ``encoder_output``."""
        tgt = self.decoder_embedding(tgt)
        tgt = self.decoder_position(tgt)
        return self.decoder(tgt, encoder_output, src_mask, tgt_mask)

    def forward(self, src, tgt, src_mask, tgt_mask):
        """Teacher-forced pass: log-probabilities of shape [batch, tgt_len, decoder_vocab_size]."""
        encoder_output = self.encode(src, src_mask)
        decoder_output = self.decode(tgt, encoder_output, src_mask, tgt_mask)
        return self.linear(decoder_output)

    def generate(self, src, src_mask, max_length=100, temperature=1.0,
                 start_token: int = 2, end_token: int = 3, pad_token: int = 0):
        """Sample target sequences autoregressively.

        Switches the model to eval mode and leaves it there. Autograd is not
        disabled here; wrap the call in ``torch.no_grad()`` for inference.

        Args:
            src: [batch, src_len] source ids.
            src_mask: Source keep-mask accepted by ``encode``.
            max_length: Maximum number of tokens to generate. It is capped at the
                decoder's positional-encoding length, the longest input ``decode``
                can accept.
            temperature: Divides the output log-probabilities before the softmax.
            start_token: Id every sequence starts with.
            end_token: Id that finishes a sequence. Generation stops once every
                sequence in the batch has produced it.
            pad_token: Id appended to sequences that have already finished.

        Returns:
            LongTensor [batch, generated_len] starting with ``start_token``.
        """
        self.eval()
        encoder_output = self.encode(src, src_mask)
        device = src.device
        batch_size = src.size(0)
        max_length = min(max_length, self.decoder_position.seq_len)

        tgt = torch.full((batch_size, 1), start_token, dtype=torch.long, device=device)
        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

        for _ in range(max_length):
            cur_len = tgt.size(1)
            # Boolean causal keep-mask (True = may attend). An additive 0/-inf mask
            # would be inverted by MultiHeadAttention's "mask == 0 is blocked" rule.
            tgt_mask = torch.tril(
                torch.ones(cur_len, cur_len, dtype=torch.bool, device=device)
            ).view(1, 1, cur_len, cur_len)

            decoder_output = self.decode(tgt, encoder_output, src_mask, tgt_mask)
            # Log-probabilities for the last position only: [batch, vocab]
            log_probs = self.linear(decoder_output[:, -1])

            # Apply temperature and sample one token per sequence
            probs = (log_probs / temperature).softmax(dim=-1)
            next_token = torch.multinomial(probs, 1)
            # Sequences that already emitted end_token are padded from then on
            next_token = next_token.masked_fill(finished.unsqueeze(1), pad_token)

            tgt = torch.cat([tgt, next_token], dim=1)

            finished |= next_token.squeeze(1) == end_token
            if bool(finished.all()):
                break

        return tgt

## Testing the Transformer

In [None]:
# # Model hyperparameters
# BATCH_SIZE = 2
# SEQ_LENGTH = 10
# SRC_VOCAB_SIZE = 5000
# TGT_VOCAB_SIZE = 5000
# DIM_MODEL = 512
# NUM_HEADS = 8
# NUM_ENCODER_LAYERS = 3
# NUM_DECODER_LAYERS = 3
# HIDDEN_DIM = 2048
# DROPOUT = 0.1
# MAX_SEQ_LENGTH = 100

# # Create model
# model = Transformer(
#     encoder_vocab_size=SRC_VOCAB_SIZE,
#     decoder_vocab_size=TGT_VOCAB_SIZE,
#     dim_model=DIM_MODEL,
#     num_heads=NUM_HEADS,
#     num_encoder_layers=NUM_ENCODER_LAYERS,
#     num_decoder_layers=NUM_DECODER_LAYERS,
#     hidden_dim=HIDDEN_DIM,
#     dropout=DROPOUT,
#     max_seq_length=MAX_SEQ_LENGTH
# )

# # Create sample input data
# src = torch.randint(1, SRC_VOCAB_SIZE, (BATCH_SIZE, SEQ_LENGTH))  # Random source sequences
# tgt = torch.randint(1, TGT_VOCAB_SIZE, (BATCH_SIZE, SEQ_LENGTH))  # Random target sequences

# # Create masks with correct shapes
# src_mask = torch.ones(BATCH_SIZE, 1, 1, SEQ_LENGTH)  # Shape: [batch_size, 1, 1, seq_len]

# # Create causal mask for decoder
# tgt_mask = torch.triu(torch.ones((SEQ_LENGTH, SEQ_LENGTH)) * float('-inf'), diagonal=1)
# tgt_mask = tgt_mask.expand(BATCH_SIZE, 1, SEQ_LENGTH, SEQ_LENGTH)


# # Forward pass
# print("Input shapes:")
# print(f"Source: {src.shape}")
# print(f"Target: {tgt.shape}")
# print(f"Source mask: {src_mask.shape}")
# print(f"Target mask: {tgt_mask.shape}")

# output = model(src, tgt, src_mask, tgt_mask)

# print("\nOutput shape:", output.shape)
# print("\nExpected output shape:", (BATCH_SIZE, SEQ_LENGTH, TGT_VOCAB_SIZE))

## Tokenizer

In [None]:

class Tokenizer:
    """Word-level tokenizer for speaker-annotated dialogues.

    Text is lower-cased, split on whitespace, and reduced to runs of word
    characters. A whitespace-delimited segment containing ':' is treated as
    ``speaker:rest`` and encoded as ``<SPEAKER> <speaker words> : <rest words>``.

    Ids 0-5 are reserved for the special tokens. ``build_vocab`` assigns further
    ids by token frequency until ``max_vocab_size`` entries exist in total.
    """

    # Compiled once; matches maximal runs of word characters
    _WORD_PATTERN = re.compile(r'\b\w+\b')

    def __init__(self, max_vocab_size: int = 50000):
        self.max_vocab_size = max_vocab_size
        # Special tokens with fixed ids; '<SPEAKER>' and ':' bracket a speaker name
        self.word2idx = {
            '<PAD>': 0,
            '<UNK>': 1,
            '<START>': 2,
            '<END>': 3,
            '<SPEAKER>': 4,
            ':': 5,
        }
        self.idx2word = {idx: word for word, idx in self.word2idx.items()}
        # Running token counts accumulated over every build_vocab call
        self.vocab = Counter()

    def _words(self, text: str) -> List[str]:
        """Lower-case ``text`` and return its word-character runs."""
        return self._WORD_PATTERN.findall(text.lower())

    def tokenize(self, text: str):
        """Split ``text`` into lower-cased tokens, marking ``speaker:`` segments."""
        tokens = []
        for segment in text.split():
            if ':' not in segment:
                tokens.extend(self._words(segment))
                continue
            speaker, rest = segment.split(':', 1)
            tokens.append('<SPEAKER>')
            tokens.extend(self._words(speaker))
            tokens.append(':')
            if rest:
                tokens.extend(self._words(rest))
        return tokens

    def build_vocab(self, texts: List[str]):
        """Count tokens in ``texts`` and give ids to the most frequent new tokens.

        Entries already in the vocabulary (special tokens, or words added by an
        earlier call) keep their ids. New tokens receive consecutive ids in
        descending frequency until ``max_vocab_size`` entries exist.
        """
        for text in texts:
            self.vocab.update(self.tokenize(text))

        next_idx = len(self.word2idx)
        for word, _ in self.vocab.most_common():
            if next_idx >= self.max_vocab_size:
                break
            if word in self.word2idx:
                # tokenize() emits '<SPEAKER>' and ':' so they are counted too;
                # re-numbering them (or words from a previous call) would leave
                # stale idx2word entries and waste vocabulary slots.
                continue
            self.word2idx[word] = next_idx
            self.idx2word[next_idx] = word
            next_idx += 1

    def text_to_sequence(self, text: str):
        """Encode ``text`` as ``[<START>, ids..., <END>]``; unknown tokens become ``<UNK>``."""
        unk = self.word2idx['<UNK>']
        body = [self.word2idx.get(token, unk) for token in self.tokenize(text)]
        return [self.word2idx['<START>'], *body, self.word2idx['<END>']]

    def sequence_to_text(self, sequence: List[int]):
        """Decode ids back to text, rendering speaker names as ``name:``.

        ``<START>``, ``<END>``, ``<PAD>`` and ``<UNK>`` are dropped. Ids missing
        from the vocabulary are treated as ``<UNK>``. Elements pass through
        ``int()``, so 0-d tensors are accepted.
        """
        tokens = []
        is_speaker = False
        current_speaker = []
        hidden = {'<START>', '<END>', '<PAD>', '<UNK>'}

        for idx in sequence:
            token = self.idx2word.get(int(idx), '<UNK>')

            if token == '<SPEAKER>':
                is_speaker = True
                if current_speaker:  # flush a speaker name that never got its ':'
                    tokens.append(' '.join(current_speaker))
                current_speaker = []
            elif token == ':':
                is_speaker = False
                speaker_name = ' '.join(current_speaker)
                tokens.append(f"{speaker_name}:")
                current_speaker = []
            elif token not in hidden:
                if is_speaker:
                    current_speaker.append(token)
                else:
                    tokens.append(token)

        if current_speaker:  # any remaining speaker words
            tokens.append(' '.join(current_speaker))

        return ' '.join(tokens)


## Testing My Tokenizer

In [None]:
# Toy corpus for exercising the tokenizer end to end
texts = [
    "hello world this is a test",
    "another test with more words",
    "hello again world testing tokenizer",
    "this is the final test sentence",
]

# Small cap (special tokens included) keeps the printed mapping readable
tokenizer = Tokenizer(max_vocab_size=20)
tokenizer.build_vocab(texts)

# Vocabulary size and the full word -> id table
print(f"\nVocabulary size: {len(tokenizer.word2idx)}")
print("\nWord to index mapping:")
print("\n".join(f"{word}: {idx}" for word, idx in tokenizer.word2idx.items()))

# Round trip: text -> ids -> text
test_text = "hello world test"
sequence = tokenizer.text_to_sequence(test_text)
print(f"\nConverting '{test_text}' to sequence:")
print(f"Sequence: {sequence}")
recovered_text = tokenizer.sequence_to_text(sequence)
print(f"Recovered text: '{recovered_text}'")

# Out-of-vocabulary words map to <UNK> and are dropped again on decoding
unknown_text = "hello unknown word test"
unknown_sequence = tokenizer.text_to_sequence(unknown_text)
print("\nTesting unknown word handling:")
print(f"Input text: '{unknown_text}'")
print(f"Sequence with unknown: {unknown_sequence}")
print(f"Recovered text: '{tokenizer.sequence_to_text(unknown_sequence)}'")


Vocabulary size: 20

Word to index mapping:
<PAD>: 0
<UNK>: 1
<START>: 2
<END>: 3
<SPEAKER>: 4
:: 5
test: 6
hello: 7
world: 8
this: 9
is: 10
a: 11
another: 12
with: 13
more: 14
words: 15
again: 16
testing: 17
tokenizer: 18
the: 19

Converting 'hello world test' to sequence:
Sequence: [2, 7, 8, 6, 3]
Recovered text: 'hello world test'

Testing unknown word handling:
Input text: 'hello unknown word test'
Sequence with unknown: [2, 7, 1, 1, 6, 3]
Recovered text: 'hello test'


## Data Preparation

## Data Cleaning

In [None]:
def clean_text(text: str) -> str:
    """Normalize raw dialogue/summary text before tokenization.

    The steps, in order:
    - decode HTML entities;
    - apply NFKC Unicode normalization;
    - delete ``http...``/``www...`` tokens and ``<file_...>`` attachment placeholders;
    - collapse every whitespace run to a single space.
    """
    normalized = unicodedata.normalize('NFKC', html.unescape(text))
    without_links = re.sub(r'http\S+|www\S+|https\S+', '', normalized, flags=re.MULTILINE)
    without_files = re.sub(r'<file_\w+>', '', without_links)
    return ' '.join(without_files.split()).strip()

### Dataset Class

In [None]:
class SamsumDataset(Dataset):
    """Map-style dataset of (dialogue, summary) pairs.

    Texts are passed through ``clean_text`` once, at construction. Tokenization
    happens lazily in ``__getitem__`` via ``tokenizer.text_to_sequence``.
    """

    def __init__(self, dialogues, summaries, tokenizer):
        self.dialogues = [clean_text(dialogue) for dialogue in dialogues]
        self.summaries = [clean_text(summary) for summary in summaries]
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.dialogues)

    def __getitem__(self, idx):
        """Return ``{'dialogue': LongTensor, 'summary': LongTensor}`` of token ids for item ``idx``."""
        encode = self.tokenizer.text_to_sequence
        return {
            'dialogue': torch.tensor(encode(self.dialogues[idx])),
            'summary': torch.tensor(encode(self.summaries[idx])),
        }


In [None]:
# Longest token sequence the model is built for: passed as max_seq_length when the
# Transformer is constructed, which sizes its positional-encoding tables.
MAX_SEQ_LENGTH = 512

In [None]:
def prepare_batch_with_padding(batch, max_seq_length=None):
    """Collate dataset items into right-padded id tensors.

    Args:
        batch: Sequence of dicts holding 1-D ``'dialogue'`` and ``'summary'`` tensors.
        max_seq_length: If truthy, each sequence is cut to its first
            ``max_seq_length`` ids before padding. Any trailing end marker is
            lost when a sequence is cut.

    Returns:
        Dict with ``'dialogues'`` and ``'summaries'``, each shaped
        ``[batch, longest]`` and padded with 0.
    """
    limit = max_seq_length if max_seq_length else None
    dialogues = [item['dialogue'][:limit] for item in batch]
    summaries = [item['summary'][:limit] for item in batch]

    return {
        'dialogues': pad_sequence(dialogues, batch_first=True, padding_value=0),
        'summaries': pad_sequence(summaries, batch_first=True, padding_value=0),
    }


In [None]:
def load_and_prepare_data(train_path, val_path, test_path, tokenizer, batch_size=32, max_seq_length=512):
    """Load CSV splits, build the tokenizer vocabulary, and return DataLoaders.

    Each CSV must have ``dialogue`` and ``summary`` columns; missing values become ''.
    The vocabulary is built from the *cleaned* training dialogues and summaries,
    i.e. the same text the datasets encode. It therefore does not spend slots on
    URLs, raw HTML-entity names or ``<file_...>`` placeholders that ``clean_text``
    removes.

    Returns:
        (train_loader, val_loader, test_loader). Only the training loader shuffles.
        Every loader pads batches and truncates sequences to ``max_seq_length``.
    """
    frames = {}
    for split, path in (('train', train_path), ('val', val_path), ('test', test_path)):
        df = pd.read_csv(path)
        df['dialogue'] = df['dialogue'].fillna('').astype(str)
        df['summary'] = df['summary'].fillna('').astype(str)
        frames[split] = df

    def make_dataset(df):
        return SamsumDataset(df['dialogue'].tolist(), df['summary'].tolist(), tokenizer)

    print("Building vocabulary...")
    # The dataset cleans its text on construction and only tokenizes lazily, so
    # the vocabulary can be built afterwards from exactly the text it will encode.
    train_dataset = make_dataset(frames['train'])
    tokenizer.build_vocab(train_dataset.dialogues + train_dataset.summaries)
    print(f"Vocabulary size: {len(tokenizer.word2idx)}")

    val_dataset = make_dataset(frames['val'])
    test_dataset = make_dataset(frames['test'])

    def collate(batch):
        return prepare_batch_with_padding(batch, max_seq_length)

    def make_loader(dataset, shuffle):
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate)

    train_loader = make_loader(train_dataset, shuffle=True)
    val_loader = make_loader(val_dataset, shuffle=False)
    test_loader = make_loader(test_dataset, shuffle=False)

    return train_loader, val_loader, test_loader

## Testing the Data Preparation Pipeline

In [None]:
# Locations of the SAMSum CSV splits
train_path = "/content/samsum-train.csv"
val_path = "/content/samsum-validation.csv"
test_path = "/content/samsum-test.csv"

# Fresh tokenizer; load_and_prepare_data fills in its vocabulary
tokenizer = Tokenizer(max_vocab_size=30000)

train_loader, val_loader, test_loader = load_and_prepare_data(
    train_path, val_path, test_path, tokenizer, batch_size=32,
)

# Print sample batch
# for batch in train_loader:
#     print("\nSample batch shapes:")
#     print(f"Dialogues: {batch['dialogues'].shape}")
#     print(f"Summaries: {batch['summaries'].shape}")

#     # Print first dialogue and summary from batch
#     print("\nSample dialogue and summary:")
#     dialogue_seq = batch['dialogues'][0].tolist()  # Get first dialogue
#     summary_seq = batch['summaries'][0].tolist()  # Get first summary

#     print("\nDialogue:")
#     print(tokenizer.sequence_to_text(dialogue_seq))
#     print("\nSummary:")
#     print(tokenizer.sequence_to_text(summary_seq))
#     break

Building vocabulary...
Vocabulary size: 29998


## Training Pipeline

### Creating Masks

In [None]:
def create_masks(src, tgt, pad_idx: int = 0):
    """Build boolean keep-masks for a source/target batch.

    The masks use MultiHeadAttention's convention: True means "may attend", and
    False (== 0) positions are filled with -1e9 before the softmax. An additive
    0/-inf float mask must not be used: its 0 entries would be the ones blocked.

    Args:
        src: [batch, src_len] source token ids.
        tgt: [batch, tgt_len] decoder-input token ids.
        pad_idx: Padding id to block.

    Returns:
        (src_mask [batch, 1, 1, src_len], tgt_mask [batch, 1, tgt_len, tgt_len]),
        both ``torch.bool`` on the inputs' devices.
    """
    # Source: block padded key positions
    src_mask = (src != pad_idx).unsqueeze(1).unsqueeze(2)

    # Target: block padded key positions ...
    tgt_padding_mask = (tgt != pad_idx).unsqueeze(1).unsqueeze(2)
    # ... and every future position (lower-triangular keep-mask)
    tgt_len = tgt.size(1)
    causal_mask = torch.tril(torch.ones(tgt_len, tgt_len, dtype=torch.bool, device=tgt.device))

    # Broadcast [batch, 1, 1, T] & [T, T] -> [batch, 1, T, T]
    tgt_mask = tgt_padding_mask & causal_mask

    return src_mask, tgt_mask

### Training Functions

In [None]:
def calculate_accuracy(output, target, pad_idx=0):
    """Fraction of non-padding positions whose arg-max prediction equals the target.

    Args:
        output: [num_tokens, num_classes] scores; predictions are the arg-max over dim 1.
        target: [num_tokens] gold ids.
        pad_idx: Target id excluded from both numerator and denominator.

    Returns:
        A Python float; 0.0 when every target position is padding.
    """
    predicted = output.max(dim=1).indices
    keep = target != pad_idx

    total = keep.sum().float()
    if not total > 0:
        return 0.0

    hits = ((predicted == target) * keep).sum().float()
    return (hits / total).item()

In [None]:
def train_epoch(model, train_loader, optimizer, criterion, device, scheduler):
    """Run one teacher-forced optimization pass over ``train_loader``.

    Gradients are clipped to norm 1.0, and the scheduler is stepped after every
    batch.

    Returns:
        (mean loss, mean token accuracy), each averaged over batches.
    """
    model.train()
    running_loss = 0
    running_acc = 0
    batch_count = len(train_loader)

    progress_bar = tqdm(train_loader, desc='Training')

    for step, batch in enumerate(progress_bar, start=1):
        src = batch['dialogues'].to(device)
        tgt = batch['summaries'].to(device)

        # Teacher forcing: the decoder reads tokens [0, n-1) and predicts tokens [1, n)
        decoder_in, decoder_target = tgt[:, :-1], tgt[:, 1:]

        src_mask, tgt_mask = create_masks(src, decoder_in)
        src_mask, tgt_mask = src_mask.to(device), tgt_mask.to(device)

        optimizer.zero_grad()
        log_probs = model(src, decoder_in, src_mask, tgt_mask)

        # Flatten to [batch * len, vocab] / [batch * len] for the loss
        flat_log_probs = log_probs.view(-1, log_probs.size(-1))
        flat_target = decoder_target.reshape(-1)

        loss = criterion(flat_log_probs, flat_target)
        acc = calculate_accuracy(flat_log_probs, flat_target, pad_idx=0)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

        running_loss += loss.item()
        running_acc += acc

        progress_bar.set_description(
            f'Training: loss={running_loss / step:.4f}, acc={running_acc / step:.4f}, lr={scheduler.get_last_lr()[0]:.6f}'
        )

    return running_loss / batch_count, running_acc / batch_count

In [None]:
def validate(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` in eval mode without gradient tracking.

    Returns:
        (mean loss, mean token accuracy), each averaged over batches.
    """
    model.eval()
    loss_sum = 0
    acc_sum = 0
    batch_count = len(val_loader)

    progress_bar = tqdm(val_loader, desc='Validating')

    with torch.no_grad():
        for step, batch in enumerate(progress_bar, start=1):
            src = batch['dialogues'].to(device)
            tgt = batch['summaries'].to(device)
            decoder_in, decoder_target = tgt[:, :-1], tgt[:, 1:]

            src_mask, tgt_mask = create_masks(src, decoder_in)
            src_mask, tgt_mask = src_mask.to(device), tgt_mask.to(device)

            log_probs = model(src, decoder_in, src_mask, tgt_mask)
            flat_log_probs = log_probs.view(-1, log_probs.size(-1))
            flat_target = decoder_target.reshape(-1)

            loss = criterion(flat_log_probs, flat_target)
            acc = calculate_accuracy(flat_log_probs, flat_target, pad_idx=0)

            loss_sum += loss.item()
            acc_sum += acc

            progress_bar.set_description(f'Validating: loss={loss_sum / step:.4f}, acc={acc_sum / step:.4f}')

    return loss_sum / batch_count, acc_sum / batch_count

## Training

In [None]:
# ---- Model and optimisation hyperparameters ----
BATCH_SIZE = 32
# Embedding-table sizes. Token ids must be < these values, so they must be at
# least the tokenizer's max_vocab_size (30000 in the data-pipeline cell).
SRC_VOCAB_SIZE = 30_000
TGT_VOCAB_SIZE = 30_000
DIM_MODEL = 256
NUM_HEADS = 4            # DIM_MODEL must be divisible by NUM_HEADS
NUM_ENCODER_LAYERS = 4
NUM_DECODER_LAYERS = 4
HIDDEN_DIM = 1024        # feed-forward inner width (4 x DIM_MODEL)
DROPOUT = 0.1
# Maximum learning rate for OneCycleLR.
# NOTE(review): 1e-2 is far above typical Adam learning rates for Transformers;
# confirm training is stable.
MAX_LR = 0.01
NUM_EPOCHS = 5

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


In [None]:
# Build the model on the selected device
model = Transformer(
    encoder_vocab_size=SRC_VOCAB_SIZE,
    decoder_vocab_size=TGT_VOCAB_SIZE,
    dim_model=DIM_MODEL,
    num_heads=NUM_HEADS,
    num_encoder_layers=NUM_ENCODER_LAYERS,
    num_decoder_layers=NUM_DECODER_LAYERS,
    hidden_dim=HIDDEN_DIM,
    dropout=DROPOUT,
    max_seq_length=MAX_SEQ_LENGTH,
).to(device)

# Adam; OneCycleLR below replaces this initial lr with max_lr / div_factor
optimizer = torch.optim.Adam(model.parameters(), lr=MAX_LR)

# train_epoch steps the scheduler once per batch, so the schedule spans every batch of every epoch
total_steps = len(train_loader) * NUM_EPOCHS

# Warm up for 30% of the steps, then cosine-anneal
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=MAX_LR,
    total_steps=total_steps,
    pct_start=0.3,
    div_factor=25,
    final_div_factor=1e4,
    anneal_strategy='cos',
)

# Padding id 0 is excluded from the loss. The model already returns log-probabilities;
# CrossEntropyLoss re-applies log_softmax, which leaves log-probabilities unchanged.
criterion = nn.CrossEntropyLoss(ignore_index=0)

In [None]:
best_val_loss = float('inf')

for epoch in range(NUM_EPOCHS):
    print(f'\nEpoch {epoch + 1}/{NUM_EPOCHS}')

    train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device, scheduler)
    val_loss, val_acc = validate(model, val_loader, criterion, device)

    print(f'Epoch {epoch + 1}/{NUM_EPOCHS} - '
          f'train_loss: {train_loss:.4f}, train_acc: {train_acc:.4f}, '
          f'val_loss: {val_loss:.4f}, val_acc: {val_acc:.4f}, '
          f'lr: {scheduler.get_last_lr()[0]:.6f}')

    # Keep only the checkpoint with the lowest validation loss so far
    # (written as "not <" so a NaN loss is never saved, matching a plain "<" test)
    if not (val_loss < best_val_loss):
        continue

    best_val_loss = val_loss
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'val_loss': val_loss,
    }
    torch.save(checkpoint, 'best_transformer_model.pt')
    print('Model saved!')


In [None]:
def generate_summary(model, tokenizer, dialogue, device, max_length=100, temperature=1.0):
    """Clean and encode ``dialogue``, sample a summary with ``model.generate``, and decode it.

    The input goes through the same ``clean_text`` step the training datasets use.
    It is then truncated to the encoder's positional-encoding length, mirroring
    the truncation applied to training batches. Runs under ``torch.no_grad()``
    with the model in eval mode.

    Returns:
        The decoded summary, or "" if anything fails (the error is printed).
        This is deliberately best-effort.
    """
    model.eval()
    with torch.no_grad():
        try:
            dialogue_seq = tokenizer.text_to_sequence(clean_text(dialogue))
            # Longer inputs would not fit the encoder's positional-encoding table
            dialogue_seq = dialogue_seq[: model.encoder_position.seq_len]
            src = torch.tensor([dialogue_seq], device=device)

            # Padding keep-mask, shape [1, 1, 1, src_len]
            src_mask = (src != 0).unsqueeze(1).unsqueeze(2)

            generated_seq = model.generate(
                src,
                src_mask,
                max_length=max_length,
                temperature=temperature
            )

            # Take the single batch row; squeeze() would turn a length-1 result into a scalar
            return tokenizer.sequence_to_text(generated_seq[0].tolist())
        except Exception as e:
            print(f"Error in generate_summary: {str(e)}")
            return ""

def test_model_generation():
    """Load the best checkpoint into the global ``model`` and print sample summaries at three temperatures."""
    checkpoint = torch.load('best_transformer_model.pt', map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    test_dialogue = "John: Hi Mary, how are you? Mary: I'm good, thanks! How about you? John: I'm doing great, thanks for asking!"

    # Lower temperature -> sharper sampling distribution; higher -> more varied output
    for temp in (0.7, 1.0, 1.3):
        summary = generate_summary(model, tokenizer, test_dialogue, device, temperature=temp)
        print(f"\nTemperature: {temp}")
        print(f"Dialogue:\n{test_dialogue}\n")
        print(f"Generated Summary:\n{summary}")

