# Importing Important Libraries

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoTokenizer

# BART doc: https://huggingface.co/docs/transformers/en/model_doc/bart
from transformers import BartTokenizer, BartForConditionalGeneration

# Simple LSTM-Based Seq2Seq Model

In [None]:
# Config
EMB_DIM = 256     # size of each token embedding vector
HID_DIM = 512     # size of the LSTM hidden/cell state
NUM_LAYERS = 1    # number of stacked LSTM layers
LR = 0.001        # learning rate handed to the optimizer
MAX_LEN = 30      # sequences are padded/truncated to this many tokens
device = "cuda" if torch.cuda.is_available() else "cpu"


def _first_token_id(*candidates, default):
    """Return the first candidate that is not None, falling back to `default`."""
    for candidate in candidates:
        if candidate is not None:
            return candidate
    return default


# Tokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
PAD_IDX = tokenizer.pad_token_id

# Pick the start/end markers by checking for None explicitly. Chaining the
# candidates with `or` would skip a perfectly valid token id of 0, because 0 is
# falsy, and silently fall through to the next candidate or the hard-coded fallback.
SOS_IDX = _first_token_id(tokenizer.cls_token_id, tokenizer.bos_token_id, default=101)
EOS_IDX = _first_token_id(tokenizer.sep_token_id, tokenizer.eos_token_id, default=102)
VOCAB_SIZE = tokenizer.vocab_size

# Dummy summarization data: every row pairs an input text with its reference summary.
# The rows are unzipped into two parallel lists, `texts` and `summaries`.
texts, summaries = map(list, zip(
    (
        "The cat sat on the mat and looked at the dog.",
        "Cat and dog on mat.",
    ),
    (
        "The stock market crashed due to inflation concerns.",
        "Market crashed from inflation.",
    ),
    (
        "Artificial intelligence is transforming many industries.",
        "AI changing industries.",
    ),
))

In [None]:
# Turn a list of sentences into a batch of token ids
def encode_batch(sentences):
    """Tokenize `sentences` and return their token ids on `device`.

    Every sentence is padded or truncated to MAX_LEN tokens, so the result is
    presumably a [len(sentences), MAX_LEN] tensor (TODO confirm). Only the
    `input_ids` entry of the tokenizer output is returned.
    """
    encoded = tokenizer(
        sentences,
        max_length=MAX_LEN,
        truncation=True,
        padding='max_length',
        return_tensors='pt',
    )
    input_ids = encoded['input_ids']
    return input_ids.to(device)

# Token-id tensors for the input texts (src) and the reference summaries (trg)
src, trg = (encode_batch(batch) for batch in (texts, summaries))

In [None]:
# Defining an Encoder
# nn.Module doc: https://docs.pytorch.org/docs/stable/generated/torch.nn.Module.html
class Encoder(nn.Module):
    """Embeds the source tokens and runs them through an LSTM.

    Only the final (hidden, cell) pair is handed to the caller; the per-step
    LSTM outputs are discarded.
    """

    def __init__(self, vocab_size, emb_dim, hid_dim, num_layers):
        super().__init__()

        # Token-id -> vector lookup table; the padding row is excluded from gradient updates
        # Doc: https://docs.pytorch.org/docs/stable/generated/torch.nn.Embedding.html
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=emb_dim,
            padding_idx=PAD_IDX,
        )

        # Recurrent layer that consumes batch-first input
        # Doc: https://docs.pytorch.org/docs/stable/generated/torch.nn.LSTM.html
        self.lstm = nn.LSTM(
            input_size=emb_dim,
            hidden_size=hid_dim,
            num_layers=num_layers,
            batch_first=True,
        )

    def forward(self, src):
        """Return the LSTM's final (hidden, cell) states for `src`.

        Assumes `src` holds token ids shaped [batch_size, src_len] — TODO confirm.
        """
        token_vectors = self.embedding(src)

        # The per-step outputs are not needed here, only the final states
        _, (hidden, cell) = self.lstm(token_vectors)
        return hidden, cell

# Defining a Decoder
class Decoder(nn.Module):
    """Single-step LSTM decoder with a linear language-model head."""

    def __init__(self, vocab_size, emb_dim, hid_dim, num_layers):
        super().__init__()

        # Same embedding + LSTM stack as the encoder ...
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=emb_dim,
            padding_idx=PAD_IDX,
        )
        self.lstm = nn.LSTM(
            input_size=emb_dim,
            hidden_size=hid_dim,
            num_layers=num_layers,
            batch_first=True,
        )

        # ... followed by a head that maps each hid_dim state to vocab_size logits
        self.fc_out = nn.Linear(in_features=hid_dim, out_features=vocab_size)

    def forward(self, input, hidden, cell):
        """Run one decoding step.

        `input` holds one token id per batch element; a length-1 time axis is
        added so the LSTM sees [batch_size, 1]. Returns the logits for the next
        token together with the updated (hidden, cell) states.
        """
        step_embedding = self.embedding(input.unsqueeze(1))
        step_output, (hidden, cell) = self.lstm(step_embedding, (hidden, cell))

        # Drop the length-1 time axis again before projecting to vocabulary logits
        logits = self.fc_out(step_output.squeeze(1))
        return logits, hidden, cell

# Combining encoder and decoder into a Seq2Seq Model
class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper that decodes the target one step at a time.

    The decoder must expose its output projection as `fc_out` (an nn.Linear);
    the vocabulary size is read from it instead of from a module-level constant.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, src, trg, teacher_forcing_ratio=1.0):
        """Return logits of shape [batch_size, trg_len, vocab_size].

        Args:
            src: source token ids, passed straight to the encoder.
            trg: target token ids, [batch_size, trg_len]; column 0 is the start token.
            teacher_forcing_ratio: probability (per decoding step) of feeding the
                ground-truth token as the next decoder input. Only applied while
                the module is in training mode; in eval mode the model always
                feeds back its own prediction.

        Position 0 of the result is never written and stays all-zero, because
        decoding starts from the start token and the first prediction is for
        position 1.
        """
        batch_size, trg_len = trg.shape
        vocab_size = self.decoder.fc_out.out_features

        # Buffer for the logits of every decoding step, created next to the targets
        outputs = torch.zeros(batch_size, trg_len, vocab_size, device=trg.device)

        # The encoder's final states initialise the decoder
        hidden, cell = self.encoder(src)

        # Decoding starts from the first target token (<bos>)
        decoder_input = trg[:, 0]

        for t in range(1, trg_len):
            # One decoder step: logits for position t plus updated states
            output, hidden, cell = self.decoder(decoder_input, hidden, cell)
            outputs[:, t] = output

            # Teacher forcing: while training, feed the ground-truth token so one
            # early mistake does not derail every later step. Without it the
            # decoder only ever sees its own (initially random) predictions.
            if not self.training or teacher_forcing_ratio <= 0.0:
                use_teacher = False
            elif teacher_forcing_ratio >= 1.0:
                use_teacher = True
            else:
                use_teacher = torch.rand(1).item() < teacher_forcing_ratio

            # Doc: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.argmax.html
            decoder_input = trg[:, t] if use_teacher else output.argmax(1)

        return outputs

In [None]:
# Initializing our Seq2Seq model
enc = Encoder(VOCAB_SIZE, EMB_DIM, HID_DIM, NUM_LAYERS)
dec = Decoder(VOCAB_SIZE, EMB_DIM, HID_DIM, NUM_LAYERS)
model = Seq2Seq(enc, dec).to(device)

# Defining optimizer and loss function for training
# Adam optimizer doc: https://docs.pytorch.org/docs/stable/generated/torch.optim.Adam.html
# CE Loss doc: https://docs.pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
optimizer = optim.Adam(model.parameters(), lr=LR)
# Padding positions in the target do not contribute to the loss
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)

# Training loop (tiny demo): full-batch updates on the three toy examples
EPOCHS = 10

# Be explicit about the mode: summarize() switches the model to eval mode
model.train()
for epoch in range(EPOCHS):

    # This zeroes-out the gradients computed at the previous step
    # Doc: https://docs.pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html
    optimizer.zero_grad()

    # Getting an output from the model
    output = model(src, trg) # Tensor of size [batch_size, trg_len, vocab_size]

    # Compute the CE-loss between the predicted logits and the ground-truth target tokens.
    # Position 0 is skipped: the model starts decoding at t=1 and never writes
    # output[:, 0], so those logits are all zeros. Scoring them against the start
    # token would only add a constant, gradient-free term to the reported loss.
    output = output[:, 1:].reshape(-1, VOCAB_SIZE)
    trg_y = trg[:, 1:].reshape(-1)
    loss = criterion(output, trg_y)

    # Computes gradients for one step of backward propagation
    # Doc: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.backward.html
    loss.backward()

    # Updates model weights based on computed gradients
    # Doc: https://docs.pytorch.org/docs/stable/generated/torch.optim.Optimizer.step.html
    optimizer.step()

    print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {loss.item():.4f}")

# Inference
def summarize(sentence, max_len=MAX_LEN):
    """Greedily decode a summary for a single sentence with the trained model.

    Decoding starts from SOS_IDX and stops after `max_len` steps, or earlier as
    soon as the model predicts EOS_IDX or PAD_IDX. Returns the decoded string
    with special tokens removed. Note that the global `model` is left in
    evaluation mode afterwards.
    """
    # Evaluation mode, and no gradient tracking for the whole decode
    model.eval()
    with torch.no_grad():

        # Encode the sentence once; its final states seed the decoder
        hidden, cell = model.encoder(encode_batch([sentence]))

        step_input = torch.tensor([SOS_IDX]).to(device)  # current decoder input
        generated = [step_input.item()]                  # token ids of the summary so far
        stop_ids = (EOS_IDX, PAD_IDX)

        for _ in range(max_len):
            logits, hidden, cell = model.decoder(step_input, hidden, cell)

            # Greedy choice; it also becomes the next decoder input (autoregressive)
            step_input = logits.argmax(1)
            next_id = step_input.item()

            # End-of-sentence or padding means the summary is finished
            if next_id in stop_ids:
                break
            generated.append(next_id)

        # Map the token ids back to text
        return tokenizer.decode(generated, skip_special_tokens=True)

# Print each toy input, the model's summary for it, and a divider line
print("\nSample Summaries:")
for text in texts:
    print(f"TEXT: {text}", f"SUMMARY: {summarize(text)}", "-" * 50, sep="\n")


Epoch 1/10, Loss: 10.3213
Epoch 2/10, Loss: 10.2706
Epoch 3/10, Loss: 10.1675
Epoch 4/10, Loss: 9.9936
Epoch 5/10, Loss: 9.8127
Epoch 6/10, Loss: 9.4919
Epoch 7/10, Loss: 8.6630
Epoch 8/10, Loss: 7.7821
Epoch 9/10, Loss: 6.8644
Epoch 10/10, Loss: 5.9014

Sample Summaries:
TEXT: The cat sat on the mat and looked at the dog.
SUMMARY: ai and.......
--------------------------------------------------
TEXT: The stock market crashed due to inflation concerns.
SUMMARY: ai and.......
--------------------------------------------------
TEXT: Artificial intelligence is transforming many industries.
SUMMARY: ai and.......
--------------------------------------------------


# Seq2Seq Generation: Summarization as an Example

In [None]:
# Tensor math runs on the CPU unless told otherwise; a GPU is much faster for it.
# Start from 'cpu' and switch to 'cuda' (NVIDIA GPU) when one is available.
# Tensors and models are moved with '.to(device)'.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

# .from_pretrained() loads a pre-trained tokenizer/model
# doc: https://huggingface.co/docs/transformers/v5.0.0rc1/en/main_classes/model#transformers.PreTrainedModel.from_pretrained
# Note: this rebinds `model`, the name that also holds the LSTM Seq2Seq model above.
tkr = BartTokenizer.from_pretrained('facebook/bart-large-cnn')
model = BartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn')
model = model.to(device)

In [None]:
# Some sample article excerpts that we want our model to summarize
# Sources: https://en.wikipedia.org/wiki/Tiger, https://en.wikipedia.org/wiki/Lion, https://en.wikipedia.org/wiki/Snake
# The snake excerpt has two paragraphs. The paragraph break is written as an
# explicit "\n" across two adjacent string literals, because a raw line break
# inside a quoted string is a syntax error.
text = [
    "The tiger (Panthera tigris) is a large cat and a member of the genus Panthera native to Asia. It has a powerful, muscular body with a large head and paws, a long tail and orange fur with black, mostly vertical stripes. It is traditionally classified into nine recent subspecies, though some recognise only two subspecies, mainland Asian tigers and the island tigers of the Sunda Islands. Throughout the tiger's range, it inhabits mainly forests, from coniferous and temperate broadleaf and mixed forests in the Russian Far East and Northeast China to tropical and subtropical moist broadleaf forests on the Indian subcontinent and Southeast Asia. The tiger is an apex predator and preys mainly on ungulates, which it takes by ambush. It lives a mostly solitary life and occupies home ranges, defending these from individuals of the same sex. The range of a male tiger overlaps with that of multiple females with whom he mates. Females give birth to usually two or three cubs that stay with their mother for about two years. When becoming independent, they leave their mother's home range and establish their own. ",
    "The lion (Panthera leo) is a large cat of the genus Panthera, currently ranging only in Sub-Saharan Africa and India. It has a muscular, broad-chested body; a short, rounded head; round ears; and a dark, hairy tuft at the tip of its tail. It is sexually dimorphic; adult male lions are larger than females and have a more prominent mane that usually obscures the ears and extends to the shoulders. The lion inhabits grasslands, savannahs, and shrublands. It is an apex and keystone predator, preying mostly on medium-sized and large ungulates. It is usually more diurnal than other wild cats, but when persecuted, it adapts to being active at night and at twilight. It is a social species, forming groups called prides. A lion pride consists of related females and cubs, and a few or one adult male who is unrelated to the females. Groups of female lions usually hunt together. Adult males often compete to keep or gain that membership in the pride. ",
    "Snakes are elongated limbless reptiles of the suborder Serpentes (/sɜːrˈpɛntiːz/).[2] Cladistically squamates, snakes are ectothermic, amniote vertebrates covered in overlapping scales much like other members of the group. Many species of snakes have skulls with several more joints than their lizard ancestors and relatives, enabling them to swallow prey much larger than their heads (cranial kinesis). To accommodate their narrow bodies, snakes' paired organs (such as kidneys) appear one in front of the other instead of side by side, and most only have one functional lung. Some species retain a pelvic girdle with a pair of vestigial claws on either side of the cloaca. Lizards have independently evolved elongate bodies without limbs or with greatly reduced limbs at least twenty-five times via convergent evolution, leading to many lineages of legless lizards.[3] These resemble snakes, but several common groups of legless lizards have eyelids and external ears, which snakes lack, although this rule is not universal (see Amphisbaenia, Dibamidae, and Pygopodidae). Living snakes are found on every continent except Antarctica, and on most smaller land masses; exceptions include some large islands, such as Ireland, Iceland, Greenland, and the islands of New Zealand, as well as many small islands of the Atlantic and central Pacific oceans.[4] Additionally, sea snakes are widespread throughout the Indian and Pacific oceans. \n"
    "Around thirty families are currently recognized, comprising about 520 genera and about more than 4,170 species.[5] They range in size from the tiny, 10.4 cm-long (4.1 in) Barbados threadsnake[6] to the reticulated python of 6.95 meters (22.8 ft) in length.[7] The fossil species Titanoboa cerrejonensis was 12.8 meters (42 ft) long.[8] Snakes are thought to have evolved from either burrowing or aquatic lizards, perhaps during the Jurassic period, with the earliest known fossils dating to between 143 and 167 Ma ago.[9][10] The diversity of modern snakes appeared during the Paleocene epoch (c. 66 to 56 Ma ago, after the Cretaceous–Paleogene extinction event). The oldest preserved descriptions of snakes can be found in the Brooklyn Papyrus. "
]

In [None]:
# The model works on token ids, not raw strings, so the articles are tokenized first.
# Each article is truncated/padded to exactly 256 tokens and returned as PyTorch tensors.
# Tokenizer doc: https://huggingface.co/docs/transformers/en/fast_tokenizers
# Padding and truncation guide: https://huggingface.co/docs/transformers/en/pad_truncation
tkr_out = tkr(text, return_tensors="pt", truncation=True, padding="max_length", max_length=256)

# The tokenizer output is dictionary-like and contains:
# 1. input_ids      - the text converted into integer tokens
# 2. attention_mask - 1 for real (non-padding) tokens, 0 for padding tokens
print(tkr_out)

{'input_ids': tensor([[    0,   133, 23921,    36,   510,   927,  1843,   102,   326,  1023,
          4663,    43,    16,    10,   739,  4758,     8,    10,   919,     9,
             5, 44878, 15148,   102,  3763,     7,  1817,     4,    85,    34,
            10,  2247,     6, 26163,   809,    19,    10,   739,   471,     8,
         40844,     6,    10,   251,  7886,     8,  8978, 15503,    19,   909,
             6,  2260, 12194, 26224,     4,    85,    16, 10341,  8967,    88,
          1117,   485,  2849, 42826,     6,   600,   103, 11865,   129,    80,
          2849, 42826,     6, 11280,  3102, 36054,     8,     5,  2946, 36054,
             9,     5, 12282,   102,  8594,     4, 13231,     5, 23921,    18,
          1186,     6,    24, 42226,  2629,  4412, 14275,     6,    31,  2764,
         14087,  1827,     8, 18586,   877,  4007, 24999,     8,  4281, 14275,
            11,     5,  1083,  4256,   953,     8,  9564,   436,     7, 10602,
             8, 30757,  6884,  3569, 3

In [None]:
# PyTorch cannot combine tensors that live on different devices, so both model
# inputs are moved to the device the model was placed on.
input_ids, attn_mask = (
    tkr_out[key].to(device) for key in ("input_ids", "attention_mask")
)

In [None]:
# With the inputs prepared, '.generate()' produces the summary token ids.
# At most 100 new tokens are generated per article; the attention mask is passed
# alongside the input ids.
# Generation doc: https://huggingface.co/docs/transformers/v5.0.0rc1/en/main_classes/text_generation#transformers.GenerationMixin.generate
output = model.generate(input_ids=input_ids, attention_mask=attn_mask, max_new_tokens=100)

In [None]:
# The .generate() method returns token ids (one row per input article), not text,
# so the result still needs to be decoded back into strings
print(output)

tensor([[    2,     0,   133, 23921,    36,   510,   927,  1843,   102,   326,
          1023,  4663,    43,    16,    10,   739,  4758,     8,    10,   919,
             9,     5, 44878, 15148,   102,  3763,     7,  1817,     4,    85,
            34,    10,  2247,     6, 26163,   809,    19,    10,   739,   471,
             8, 40844,     6,    10,   251,  7886,     8,  8978, 15503,    19,
           909,     6,  2260, 12194, 26224,     4,    85,    16, 10341,  8967,
            88,  1117,   485,  2849, 42826,     6,   600,   103, 11865,   129,
            80,     4,     2,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1],
        [    2,     0,   133, 15587,    36,   510,   927,  1843,   102,  2084,
           139,    43,    16,    10,   739,  4758,     9,     5, 44878, 15148,
           102,     6,   855,  6272,   129,    11,  4052,    12, 27692,  1327,
             8,   666,     4,    85,    34,    10, 

In [None]:
# The same tokenizer that encoded the input turns the generated ids back into text
# Important: When decoding a batched output, you must use .batch_decode(). Otherwise, use .decode()
# skip_special_tokens=True keeps special tokens out of the decoded strings
tkr.batch_decode(output, skip_special_tokens=True)

['The tiger (Panthera tigris) is a large cat and a member of the genus Panthera native to Asia. It has a powerful, muscular body with a large head and paws, a long tail and orange fur with black, mostly vertical stripes. It is traditionally classified into nine recent subspecies, though some recognise only two.',
 'The lion (Panthera leo) is a large cat of the genus Panthera, currently ranging only in Sub-Saharan Africa and India. It has a muscular, broad-chested body; a short, rounded head; round ears; and a dark, hairy tuft at the tip of its tail. It is sexually dimorphic; adult male lions are larger than females and have a more prominent mane.',
 "Snakes are ectothermic, amniote vertebrates covered in overlapping scales. Many species of snakes have skulls with several more joints than their lizard ancestors and relatives. To accommodate their narrow bodies, snakes' paired organs (such as kidneys) appear one in front of the other instead of side by side."]

# Generation/Decoding Strategies
Decoding strategies doc: https://huggingface.co/docs/transformers/en/generation_strategies

Sampling strategies doc: https://huggingface.co/docs/transformers/v5.0.0rc1/en/main_classes/text_generation#transformers.GenerationConfig
- Greedy Decoding
- Beam Search
- Top-K Sampling
- Nucleus Sampling (AKA: Top-P Sampling)


In [None]:
# Greedy Decoding - At each step, choose the token with the highest probability as the generated token
# NOTE: generate() falls back to the checkpoint's own generation config for every
# option that is not passed explicitly. facebook/bart-large-cnn is published with
# beam search enabled (num_beams=4 — TODO confirm against the checkpoint's config),
# so relying on the defaults does NOT give greedy decoding here. Request it explicitly:
greedy_output = model.generate(
    input_ids=input_ids,
    attention_mask=attn_mask,
    max_new_tokens=100,
    num_beams=1,     # a single hypothesis, i.e. no beam search
    do_sample=False  # always take the arg-max token instead of sampling
)
tkr.batch_decode(greedy_output, skip_special_tokens=True)

['The tiger (Panthera tigris) is a large cat and a member of the genus Panthera native to Asia. It has a powerful, muscular body with a large head and paws, a long tail and orange fur with black, mostly vertical stripes. It is traditionally classified into nine recent subspecies, though some recognise only two.',
 'The lion (Panthera leo) is a large cat of the genus Panthera, currently ranging only in Sub-Saharan Africa and India. It has a muscular, broad-chested body; a short, rounded head; round ears; and a dark, hairy tuft at the tip of its tail. It is sexually dimorphic; adult male lions are larger than females and have a more prominent mane.',
 "Snakes are ectothermic, amniote vertebrates covered in overlapping scales. Many species of snakes have skulls with several more joints than their lizard ancestors and relatives. To accommodate their narrow bodies, snakes' paired organs (such as kidneys) appear one in front of the other instead of side by side."]

In [None]:
# Beam search - several candidate sequences (beams) are kept alive in parallel and
# the one with the highest overall probability is returned.
# num_beams=1 would reduce this to greedy decoding. The library-wide default is 1,
# but a checkpoint's generation config can override it (presumably the case for
# facebook/bart-large-cnn — TODO confirm), so the value is set explicitly.
beam_output = model.generate(
    num_beams=5,
    max_new_tokens=100,
    input_ids=input_ids,
    attention_mask=attn_mask,
)
tkr.batch_decode(beam_output, skip_special_tokens=True)

['The tiger (Panthera tigris) is a large cat and a member of the genus Panthera native to Asia. It has a powerful, muscular body with a large head and paws, a long tail and orange fur with black, mostly vertical stripes. It is traditionally classified into nine recent subspecies, though some recognise only two subspecies.',
 'The lion (Panthera leo) is a large cat of the genus Panthera. It has a muscular, broad-chested body; a short, rounded head; round ears; and a dark, hairy tuft at the tip of its tail. Adult male lions are larger than females and have a more prominent mane that usually obscures the ears and extends to the shoulders.',
 "Snakes are ectothermic, amniote vertebrates covered in overlapping scales. Many species of snakes have skulls with several more joints than their lizard ancestors and relatives. To accommodate their narrow bodies, snakes' paired organs (such as kidneys) appear one in front of the other instead of side by side."]

In [None]:
# Top-K Sampling - only the K highest probability tokens are candidates at each step,
# and the next token is drawn at random from them
# NOTE: num_beams is pinned to 1. Options that are not passed fall back to the
# checkpoint's generation config, and facebook/bart-large-cnn is published with
# beam search enabled (TODO confirm), which would turn this into beam-search
# sampling instead of plain top-k sampling.
topk_output = model.generate(
    input_ids=input_ids,
    attention_mask=attn_mask,
    max_new_tokens=100,
    num_beams=1, # One sequence per input, no beam search
    do_sample=True, #Whenever a sampling technique is used, do_sample must be set to True
    top_k=50 # Number of highest probability tokens to consider
)
tkr.batch_decode(topk_output, skip_special_tokens=True)

['The tiger (Panthera tigris) is a large cat and a member of the genus Panthera native to Asia. It has a powerful, muscular body with a large head and paws, a long tail and orange fur with black, mostly vertical stripes. It is traditionally classified into nine recent subspecies, though some recognise only two subspecies.',
 'The lion (Panthera leo) is a large cat of the genus Panthera. It has a muscular, broad-chested body; a short, rounded head; round ears; and a dark, hairy tuft at the tip of its tail. Adult male lions are larger than females and have a more prominent mane that usually obscures the ears and extends to the shoulders.',
 "Snakes are ectothermic, amniote vertebrates covered in overlapping scales. Many species of snakes have skulls with several more joints than their lizard ancestors and relatives, enabling them to swallow prey much larger than their heads. To accommodate their narrow bodies, snakes' paired organs appear one in front of the other instead of side by side

In [None]:
# Nucleus Sampling (AKA: Top-p Sampling) - The smallest set of tokens whose probabilities add up to p (or more) is considered
# NOTE: two settings are pinned so that this really is pure nucleus sampling:
# - num_beams=1: otherwise the checkpoint's generation config (beam search for
#   facebook/bart-large-cnn — TODO confirm) would be combined with sampling.
# - top_k=0: switches the top-k filter off; the generation config otherwise
#   applies its own top_k (50 by default) on top of top_p.
nucleus_output = model.generate(
    input_ids=input_ids,
    attention_mask=attn_mask,
    max_new_tokens=100,
    num_beams=1,
    do_sample=True,
    top_k=0,
    top_p=0.8
)
tkr.batch_decode(nucleus_output, skip_special_tokens=True)

['The tiger (Panthera tigris) is a large cat and a member of the genus Panthera native to Asia. It has a powerful, muscular body with a large head and paws, a long tail and orange fur with black, mostly vertical stripes. It is traditionally classified into nine recent subspecies, though some recognise only two subspecies.',
 'The lion (Panthera leo) is a large cat of the genus Panthera, currently ranging only in Sub-Saharan Africa and India. It has a muscular, broad-chested body; a short, rounded head; round ears; and a dark, hairy tuft at the tip of its tail. It is sexually dimorphic; adult male lions are larger than females and have a more prominent mane.',
 "Snakes are ectothermic, amniote vertebrates covered in overlapping scales. Many species of snakes have skulls with several more joints than their lizard ancestors and relatives. To accommodate their narrow bodies, snakes' paired organs (such as kidneys) appear one in front of the other instead of side by side."]

# Bonus 1: Seq2Seq with Attention

In [None]:
# Turn a list of sentences into a batch of token ids
def encode_batch(sentences):
    """Tokenize `sentences` and return their token ids on `device`.

    Every sentence is padded or truncated to MAX_LEN tokens, so the result is
    presumably a [len(sentences), MAX_LEN] tensor (TODO confirm). Only the
    `input_ids` entry of the tokenizer output is returned.
    """
    encoded = tokenizer(
        sentences,
        max_length=MAX_LEN,
        truncation=True,
        padding='max_length',
        return_tensors='pt',
    )
    input_ids = encoded['input_ids']
    return input_ids.to(device)

# Token-id tensors for the input texts (src) and the reference summaries (trg)
src, trg = (encode_batch(batch) for batch in (texts, summaries))

In [None]:
# Defining an Encoder
# nn.Module doc: https://docs.pytorch.org/docs/stable/generated/torch.nn.Module.html
class Encoder(nn.Module):
    """Embeds the source tokens and runs them through an LSTM.

    Unlike an encoder that only hands over its final states, this one also
    returns the per-step LSTM outputs so the decoder can attend over them.
    """

    def __init__(self, vocab_size, emb_dim, hid_dim, num_layers):
        super().__init__()

        # Token-id -> vector lookup table; the padding row is excluded from gradient updates
        # Doc: https://docs.pytorch.org/docs/stable/generated/torch.nn.Embedding.html
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=emb_dim,
            padding_idx=PAD_IDX,
        )

        # Recurrent layer that consumes batch-first input
        # Doc: https://docs.pytorch.org/docs/stable/generated/torch.nn.LSTM.html
        self.lstm = nn.LSTM(
            input_size=emb_dim,
            hidden_size=hid_dim,
            num_layers=num_layers,
            batch_first=True,
        )

    def forward(self, src):
        """Return (per-step outputs, final hidden state, final cell state).

        Assumes `src` holds token ids shaped [batch_size, src_len] — TODO confirm.
        """
        token_vectors = self.embedding(src)
        step_outputs, (hidden, cell) = self.lstm(token_vectors)
        return step_outputs, hidden, cell

# Defining a Decoder
class Decoder(nn.Module):
    """Single-step LSTM decoder that attends over the encoder outputs."""

    def __init__(self, vocab_size, emb_dim, hid_dim, num_layers):
        super().__init__()

        # Same embedding + LSTM stack as the encoder
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=emb_dim,
            padding_idx=PAD_IDX,
        )
        self.lstm = nn.LSTM(
            input_size=emb_dim,
            hidden_size=hid_dim,
            num_layers=num_layers,
            batch_first=True,
        )

        # Attention layer (4 heads over hid_dim-sized states)
        # Doc: https://docs.pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html
        self.attention = nn.MultiheadAttention(embed_dim=hid_dim, num_heads=4, batch_first=True)

        # Language model head: takes the LSTM output concatenated with the attention
        # output (hence 2*hid_dim inputs) and produces vocab_size logits
        self.fc_out = nn.Linear(in_features=2 * hid_dim, out_features=vocab_size)

    def forward(self, input, hidden, cell, encoder_outputs):
        """Run one decoding step with attention.

        `input` holds one token id per batch element; a length-1 time axis is
        added so the LSTM sees [batch_size, 1]. Returns the logits for the next
        token, the updated (hidden, cell) states and the attention weights.
        """
        step_embedding = self.embedding(input.unsqueeze(1))
        lstm_out, (hidden, cell) = self.lstm(step_embedding, (hidden, cell))

        # Attention arguments are (query, key, value):
        #   query -> the current decoder state (what is being generated right now)
        #   key   -> every encoder state the query is matched against
        #   value -> every encoder state that enters the weighted sum
        attn_out, attn_weights = self.attention(lstm_out, encoder_outputs, encoder_outputs)

        # Join the decoder state and its attention context along the feature axis
        combined = torch.cat([lstm_out, attn_out], dim=-1)

        # Drop the length-1 time axis, then project the combined features to logits
        logits = self.fc_out(combined.squeeze(1))
        return logits, hidden, cell, attn_weights

# Combining encoder and decoder into a Seq2Seq Model
class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper (attention variant) that decodes step by step.

    The decoder must expose its output projection as `fc_out` (an nn.Linear);
    the vocabulary size is read from it instead of from a module-level constant.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, src, trg, teacher_forcing_ratio=1.0):
        """Return logits of shape [batch_size, trg_len, vocab_size].

        Args:
            src: source token ids, passed straight to the encoder.
            trg: target token ids, [batch_size, trg_len]; column 0 is the start token.
            teacher_forcing_ratio: probability (per decoding step) of feeding the
                ground-truth token as the next decoder input. Only applied while
                the module is in training mode; in eval mode the model always
                feeds back its own prediction.

        Position 0 of the result is never written and stays all-zero, because
        decoding starts from the start token and the first prediction is for
        position 1.
        """
        batch_size, trg_len = trg.shape
        vocab_size = self.decoder.fc_out.out_features

        # Buffer for the logits of every decoding step, created next to the targets
        outputs = torch.zeros(batch_size, trg_len, vocab_size, device=trg.device)

        # Per-step encoder outputs (for attention) plus the final states that
        # initialise the decoder
        encoder_outputs, hidden, cell = self.encoder(src)

        # Decoding starts from the first target token (<bos>)
        decoder_input = trg[:, 0]

        for t in range(1, trg_len):
            # One decoder step; the attention weights are not needed for the loss
            output, hidden, cell, _ = self.decoder(decoder_input, hidden, cell, encoder_outputs)
            outputs[:, t] = output

            # Teacher forcing: while training, feed the ground-truth token so one
            # early mistake does not derail every later step. Without it the
            # decoder only ever sees its own (initially random) predictions.
            if not self.training or teacher_forcing_ratio <= 0.0:
                use_teacher = False
            elif teacher_forcing_ratio >= 1.0:
                use_teacher = True
            else:
                use_teacher = torch.rand(1).item() < teacher_forcing_ratio

            # Doc: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.argmax.html
            decoder_input = trg[:, t] if use_teacher else output.argmax(1)

        return outputs

In [None]:
# Initializing our Seq2Seq model
enc = Encoder(VOCAB_SIZE, EMB_DIM, HID_DIM, NUM_LAYERS)
dec = Decoder(VOCAB_SIZE, EMB_DIM, HID_DIM, NUM_LAYERS)
model = Seq2Seq(enc, dec).to(device)

# Defining optimizer and loss function for training
# Adam optimizer doc: https://docs.pytorch.org/docs/stable/generated/torch.optim.Adam.html
# CE Loss doc: https://docs.pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
optimizer = optim.Adam(model.parameters(), lr=LR)
# Padding positions in the target do not contribute to the loss
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)

# Training loop (tiny demo): full-batch updates on the three toy examples
EPOCHS = 10

# Be explicit about the mode: summarize() switches the model to eval mode
model.train()
for epoch in range(EPOCHS):

    # This zeroes-out the gradients computed at the previous step
    # Doc: https://docs.pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html
    optimizer.zero_grad()

    # Getting an output from the model
    output = model(src, trg) # Tensor of size [batch_size, trg_len, vocab_size]

    # Compute the CE-loss between the predicted logits and the ground-truth target tokens.
    # Position 0 is skipped: the model starts decoding at t=1 and never writes
    # output[:, 0], so those logits are all zeros. Scoring them against the start
    # token would only add a constant, gradient-free term to the reported loss.
    output = output[:, 1:].reshape(-1, VOCAB_SIZE)
    trg_y = trg[:, 1:].reshape(-1)
    loss = criterion(output, trg_y)

    # Computes gradients for one step of backward propagation
    # Doc: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.backward.html
    loss.backward()

    # Updates model weights based on computed gradients
    # Doc: https://docs.pytorch.org/docs/stable/generated/torch.optim.Optimizer.step.html
    optimizer.step()

    print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {loss.item():.4f}")

# Inference
def summarize(sentence, max_len=MAX_LEN):
    """Greedily decode a summary for a single sentence.

    The sentence is tokenized and run through the encoder once. The decoder
    is then stepped one token at a time, starting from the start-of-sequence
    token and feeding each predicted token back in as the next input.

    Args:
        sentence: Raw text to summarize.
        max_len: Maximum number of decoding steps to run.

    Returns:
        The generated token ids decoded back to text, with special tokens
        removed.

    Note:
        The model is switched to evaluation mode and is left in that mode
        when the function returns.
    """
    # Evaluation mode; gradient tracking is disabled separately by no_grad()
    model.eval()

    # Predicting either of these ids ends generation
    stop_ids = (EOS_IDX, PAD_IDX)

    with torch.no_grad():

        # Tokenize the sentence (as a batch of one) and encode it
        encoder_outputs, hidden, cell = model.encoder(encode_batch([sentence]))

        # Decoding starts from the start-of-sequence token
        decoder_input = torch.tensor([SOS_IDX]).to(device)
        generated_ids = [decoder_input.item()]

        for _step in range(max_len):

            # One decoder step; the attention weights are not needed here
            logits, hidden, cell, _attn = model.decoder(
                decoder_input, hidden, cell, encoder_outputs
            )

            # Greedy choice: the highest-scoring token becomes the next input
            decoder_input = logits.argmax(1)
            next_id = decoder_input.item()

            # Stop without emitting the end-of-sentence / padding token
            if next_id in stop_ids:
                break

            generated_ids.append(next_id)

    # Use the tokenizer to decode ("de-tokenize") the generated ids
    return tokenizer.decode(generated_ids, skip_special_tokens=True)

# Print each input text next to the summary the model generates for it
print("\nSample Summaries:")
separator = "-" * 50
for text in texts:
    print(f"TEXT: {text}")
    print(f"SUMMARY: {summarize(text)}")
    print(separator)


Epoch 1/10, Loss: 10.3361
Epoch 2/10, Loss: 10.2277
Epoch 3/10, Loss: 9.9682
Epoch 4/10, Loss: 9.2848
Epoch 5/10, Loss: 7.4120
Epoch 6/10, Loss: 5.2189
Epoch 7/10, Loss: 4.4365
Epoch 8/10, Loss: 3.7568
Epoch 9/10, Loss: 3.7476
Epoch 10/10, Loss: 3.6522

Sample Summaries:
TEXT: The cat sat on the mat and looked at the dog.
SUMMARY: cat and and mat mat mat mat mat mat mat mat mat mat mat mat mat mat mat mat mat mat mat mat mat mat mat mat mat mat mat
--------------------------------------------------
TEXT: The stock market crashed due to inflation concerns.
SUMMARY: market crashed from from from from from from from from from from from from from from from from from from from from from from from from from from from from
--------------------------------------------------
TEXT: Artificial intelligence is transforming many industries.
SUMMARY: market changing industries industries industries industries industries industries industries industries industries industries industries industries ind

# Bonus 2: Writing/Reading Text To/From A File  

In [None]:
# Write a few sample articles to 'articles.txt', one article per line.
# Doc: https://docs.python.org/3/library/functions.html#open
#
# Every article must be a single-line string: an embedded line break would
# split one article across two lines of the file and break the
# "one line == one article" layout that the reading code relies on.
articles = [
    "Owls are birds from the order Strigiformes[1] (/ˈstrɪdʒəfɔːrmiːz/), which includes over 200 species of mostly solitary and nocturnal birds of prey typified by an upright stance, a large, broad head, binocular vision, binaural hearing, sharp talons, and feathers adapted for silent flight. Exceptions include the diurnal northern hawk-owl and the gregarious burrowing owl. Owls are divided into two families: the true (or typical) owl family, Strigidae, and the barn owl and bay owl family, Tytonidae.[2] Owls hunt mostly small mammals, insects, and other birds, although a few species specialize in hunting fish. They are found in all regions of the Earth except the polar ice caps and some remote islands. A group of owls is called a 'parliament'.[3]",
    "A frog is any member of a diverse and largely semiaquatic group of short-bodied, tailless amphibian vertebrates composing the order Anura[1] (coming from the Ancient Greek ἀνούρα, literally 'without tail'). Frog species with rough skin texture due to wart-like parotoid glands tend to be called toads, but the distinction between frogs and toads is informal and purely cosmetic, not from taxonomy or evolutionary history. Frogs are widely distributed, ranging from the tropics to subarctic regions, but the greatest concentration of species diversity is in tropical rainforest and associated wetlands. They account for around 88% of extant amphibian species, and are one of the five most diverse vertebrate orders. The oldest fossil \"proto-frog\" Triadobatrachus is known from the Early Triassic of Madagascar (250 million years ago), but molecular clock dating suggests their divergence from other amphibians may extend further back to the Permian, 265 million years ago. Adult frogs have a stout body, protruding eyes, anteriorly-attached tongue, limbs folded underneath, and no tail (the \"tail\" of tailed frogs is an extension of the male cloaca). Frogs have glandular skin, with secretions ranging from distasteful to toxic. Their skin varies in colour from well-camouflaged dappled brown, grey and green, to vivid patterns of bright red or yellow and black to show toxicity and ward off predators. Adult frogs live in both fresh water and on dry land; some species are adapted for living underground or in trees. As their skin is semi-permeable, making them susceptible to dehydration, they either live in moist niches or have special adaptations to deal with drier habitats. Frogs produce a wide range of vocalisations, particularly in their breeding season, and exhibit many different kinds of complex behaviors to attract mates, to fend off predators and to generally survive. Being oviparous anamniotes, frogs typically spawn their eggs in bodies of water. The eggs then hatch into fully aquatic larvae called tadpoles, which have tails and internal gills. A few species lay eggs on land or bypass the tadpole stage altogether. Tadpoles have highly specialised rasping mouth parts suitable for herbivorous, omnivorous or planktivorous diets. The life cycle is completed when they metamorphose into semiaquatic adults capable of terrestrial locomotion and hybrid respiration using both lungs aided by buccal pumping and gas exchange across the skin, and the larval tail regresses into an internal urostyle. Adult frogs generally have a carnivorous diet consisting of small invertebrates, especially insects, but omnivorous species exist and a few feed on plant matter. Frogs generally seize and ingest food by protruding their adhesive tongue and then swallow the item whole, often using their eyeballs and extraocular muscles to help pushing down the throat, and their digestive system is extremely efficient at converting what they eat into body mass. Being low-level consumers, both tadpoles and adult frogs are an important food source for other predators and a vital part of the food web dynamics of many of the world's ecosystems.",
    "The meerkat (Suricata suricatta) or suricate is a small mongoose found in southern Africa. It is characterised by a broad head, large eyes, a pointed snout, long legs, a thin tapering tail, and a brindled coat pattern. The head-and-body length is around 24–35 cm (9.4–13.8 in), and the weight is typically between 0.62 and 0.97 kg (1.4 and 2.1 lb). The coat is light grey to yellowish-brown with alternate, poorly defined light and dark bands on the back. Meerkats have foreclaws adapted for digging and have the ability to thermoregulate to survive in their harsh, dry habitat. Three subspecies are recognised. Meerkats are highly social, and form packs of two to 30 individuals each that occupy home ranges around 5 km2 (1.9 sq mi) in area. There is a social hierarchy—generally dominant individuals in a pack breed and produce offspring, and the nonbreeding, subordinate members provide altruistic care to the pups. Breeding occurs around the year, with peaks during heavy rainfall; after a gestation of 60 to 70 days, a litter of three to seven pups is born. They live in rock crevices in stony, often calcareous areas, and in large burrow systems in plains. The burrow systems, typically 5 m (16 ft) in diameter with around 15 openings, are large underground networks consisting of two to three levels of tunnels. These tunnels are around 7.5 cm (3.0 in) high at the top and wider below, and extend up to 1.5 m (5 ft) into the ground. Burrows have moderated internal temperatures and provide a comfortable microclimate that protects meerkats in harsh weather and at extreme temperatures. ",
]

# The articles contain non-ASCII characters (IPA, Greek, en/em dashes), so the
# encoding is pinned to UTF-8 instead of relying on the platform default,
# which is locale-dependent and may be unable to encode them.
with open("articles.txt", "w", encoding="utf-8") as f:
    for a in articles:
        # Write the article followed by a newline (\n) so that each line in the file corresponds to a unique article
        f.write(f"{a}\n")

In [None]:
# Read the articles back: the file holds one article per line, so readlines()
# returns one list entry per article (each entry keeps its trailing "\n").
# The encoding is pinned to UTF-8 because the articles contain non-ASCII
# characters and the platform default encoding is locale-dependent.
with open("articles.txt", "r", encoding="utf-8") as f:
    file_txt = f.readlines()

# Display the first article (notebook cell output)
file_txt[0]

"Owls are birds from the order Strigiformes[1] (/ˈstrɪdʒəfɔːrmiːz/), which includes over 200 species of mostly solitary and nocturnal birds of prey typified by an upright stance, a large, broad head, binocular vision, binaural hearing, sharp talons, and feathers adapted for silent flight. Exceptions include the diurnal northern hawk-owl and the gregarious burrowing owl. Owls are divided into two families: the true (or typical) owl family, Strigidae, and the barn owl and bay owl family, Tytonidae.[2] Owls hunt mostly small mammals, insects, and other birds, although a few species specialize in hunting fish. They are found in all regions of the Earth except the polar ice caps and some remote islands. A group of owls is called a 'parliament'.[3]\n"