# Build neural network model from scratch

Predict summaries using LSTM build from scratch using PyTorch

In [1]:
# %% Import necessary libraries
import pandas as pd
import numpy as np
import time
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


## Load the dataset

In [2]:
dataset_dir = 'datasets'
df = pd.read_csv(f'{dataset_dir}/podcast_with_summary.csv')

In [3]:
# the lengths of input and output must be the same for out model
max_length = 1024

# Use Hugging Face tokenizer
model_name = "bert-base-uncased"  # Replace with a model suitable for tokenization
tokenizer = AutoTokenizer.from_pretrained(model_name)

def tokenize_text(texts, max_length):
    return tokenizer(
        texts,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt"
    )

# Tokenize inputs and summaries
input_tokens = tokenize_text(df['text_short'].tolist(), max_length)
summary_tokens = tokenize_text(df['summary'].tolist(), max_length)

X = input_tokens['input_ids']
Y = summary_tokens['input_ids']

In [4]:
from sklearn.model_selection import train_test_split

X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.1, random_state=42)

In [5]:
#Create PyTorch Dataset to load data
class TextSummaryDataset(Dataset):
    def __init__(self, inputs, targets):
        self.inputs = inputs
        self.targets = targets

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return self.inputs[idx], self.targets[idx]

train_dataset = TextSummaryDataset(X_train, Y_train)
test_dataset = TextSummaryDataset(X_test, Y_test)

# I experimented with different batch sizes and found 32 to be the best
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32)

In [6]:
X = input_tokens['input_ids']
Y = summary_tokens['input_ids']

# Verify dimensions
print(f"Input shape: {X.shape}, Target shape: {Y.shape}")

Input shape: torch.Size([319, 1024]), Target shape: torch.Size([319, 1024])


In [7]:
print(tokenizer.decode(X[0]))
print(tokenizer.decode(Y[0]))

[CLS] as part of mit course 6s099, artificial general intelligence, i ' ve gotten the chance to sit down with max tegmark. he is a professor here at mit. he ' s a physicist, spent a large part of his career studying the mysteries of our cosmological universe. but he ' s also studied and delved into the beneficial possibilities and the existential risks of artificial intelligence. amongst many other things, he is the cofounder of the future of life institute, author of two books, both of which i highly recommend. first, our mathematical universe. second is life 3. 0. he ' s truly an out of the box thinker and a fun personality, so i really enjoy talking to him. if you ' d like to see more of these videos in the future, please subscribe and also click the little bell icon to make sure you don ' t miss any videos. also, twitter, linkedin, agi. mit. edu if you wanna watch other lectures or conversations like this one. better yet, go read max ' s book, life 3. 0. chapter seven on goals is m

## Build the LSTM

In [8]:

class LSTMSummarizer(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_units):
        super(LSTMSummarizer, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=tokenizer.pad_token_id)
        self.lstm = nn.LSTM(embedding_dim, hidden_units, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(hidden_units * 2, vocab_size)

    def forward(self, x):
        embedded = self.embedding(x)
        lstm_out, _ = self.lstm(embedded)
        out = self.fc(lstm_out)  # Predict for each time step
        return out

embedding_dim = 128
hidden_units = 256
vocab_size = tokenizer.vocab_size
model = LSTMSummarizer(vocab_size, embedding_dim, hidden_units).to('cuda' if torch.cuda.is_available() else 'cpu')

In [9]:
from torch.optim.lr_scheduler import StepLR
# I was trying to use mixed precision training but was unable to get it to work
from torch.cuda.amp import autocast, GradScaler

# Define optimizer and loss function
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Training loop
device = 'cuda' if torch.cuda.is_available() else 'cpu'

scheduler = StepLR(optimizer, step_size=5, gamma=0.1)

def train_model(model, train_loader, criterion, optimizer, scheduler, epochs=10):
    model.train()
    for epoch in range(epochs):
        epoch_loss = 0
        for inputs, targets in tqdm(train_loader):
            # Move inputs and targets to gpu
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            
            outputs = model(inputs)
            
            outputs = outputs.view(-1, outputs.size(-1))
            targets = targets.view(-1)
            
            loss = criterion(outputs, targets)
            loss.backward()

            optimizer.step()

            epoch_loss += loss.item()

        scheduler.step()

        print(f"Epoch {epoch + 1}/{epochs}, Loss: {epoch_loss / len(train_loader):.4f}")



In [10]:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
epochs = 100
train_model(model, train_loader, criterion, optimizer, scheduler, epochs=epochs)

100%|██████████| 9/9 [00:01<00:00,  4.73it/s]


Epoch 1/100, Loss: 10.3128


100%|██████████| 9/9 [00:01<00:00,  5.86it/s]


Epoch 2/100, Loss: 10.2691


100%|██████████| 9/9 [00:01<00:00,  6.17it/s]


Epoch 3/100, Loss: 10.2206


100%|██████████| 9/9 [00:01<00:00,  6.67it/s]


Epoch 4/100, Loss: 10.1567


100%|██████████| 9/9 [00:01<00:00,  6.66it/s]


Epoch 5/100, Loss: 10.0589


100%|██████████| 9/9 [00:01<00:00,  6.84it/s]


Epoch 6/100, Loss: 9.9679


100%|██████████| 9/9 [00:01<00:00,  6.50it/s]


Epoch 7/100, Loss: 9.9470


100%|██████████| 9/9 [00:01<00:00,  6.77it/s]


Epoch 8/100, Loss: 9.9246


100%|██████████| 9/9 [00:01<00:00,  6.60it/s]


Epoch 9/100, Loss: 9.9002


100%|██████████| 9/9 [00:01<00:00,  6.94it/s]


Epoch 10/100, Loss: 9.8742


100%|██████████| 9/9 [00:01<00:00,  6.61it/s]


Epoch 11/100, Loss: 9.8568


100%|██████████| 9/9 [00:01<00:00,  6.48it/s]


Epoch 12/100, Loss: 9.8542


100%|██████████| 9/9 [00:01<00:00,  6.58it/s]


Epoch 13/100, Loss: 9.8512


100%|██████████| 9/9 [00:01<00:00,  6.94it/s]


Epoch 14/100, Loss: 9.8485


100%|██████████| 9/9 [00:01<00:00,  6.73it/s]


Epoch 15/100, Loss: 9.8456


100%|██████████| 9/9 [00:01<00:00,  6.72it/s]


Epoch 16/100, Loss: 9.8439


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 17/100, Loss: 9.8435


100%|██████████| 9/9 [00:01<00:00,  6.99it/s]


Epoch 18/100, Loss: 9.8432


100%|██████████| 9/9 [00:01<00:00,  6.96it/s]


Epoch 19/100, Loss: 9.8432


100%|██████████| 9/9 [00:01<00:00,  6.77it/s]


Epoch 20/100, Loss: 9.8428


100%|██████████| 9/9 [00:01<00:00,  7.03it/s]


Epoch 21/100, Loss: 9.8429


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 22/100, Loss: 9.8427


100%|██████████| 9/9 [00:01<00:00,  7.06it/s]


Epoch 23/100, Loss: 9.8425


100%|██████████| 9/9 [00:01<00:00,  7.02it/s]


Epoch 24/100, Loss: 9.8427


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 25/100, Loss: 9.8428


100%|██████████| 9/9 [00:01<00:00,  7.14it/s]


Epoch 26/100, Loss: 9.8426


100%|██████████| 9/9 [00:01<00:00,  7.05it/s]


Epoch 27/100, Loss: 9.8426


100%|██████████| 9/9 [00:01<00:00,  7.04it/s]


Epoch 28/100, Loss: 9.8427


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 29/100, Loss: 9.8425


100%|██████████| 9/9 [00:01<00:00,  7.16it/s]


Epoch 30/100, Loss: 9.8426


100%|██████████| 9/9 [00:01<00:00,  7.01it/s]


Epoch 31/100, Loss: 9.8427


100%|██████████| 9/9 [00:01<00:00,  7.07it/s]


Epoch 32/100, Loss: 9.8428


100%|██████████| 9/9 [00:01<00:00,  7.16it/s]


Epoch 33/100, Loss: 9.8424


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 34/100, Loss: 9.8426


100%|██████████| 9/9 [00:01<00:00,  6.95it/s]


Epoch 35/100, Loss: 9.8425


100%|██████████| 9/9 [00:01<00:00,  7.13it/s]


Epoch 36/100, Loss: 9.8426


100%|██████████| 9/9 [00:01<00:00,  7.17it/s]


Epoch 37/100, Loss: 9.8424


100%|██████████| 9/9 [00:01<00:00,  7.04it/s]


Epoch 38/100, Loss: 9.8424


100%|██████████| 9/9 [00:01<00:00,  6.93it/s]


Epoch 39/100, Loss: 9.8426


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 40/100, Loss: 9.8424


100%|██████████| 9/9 [00:01<00:00,  7.16it/s]


Epoch 41/100, Loss: 9.8427


100%|██████████| 9/9 [00:01<00:00,  7.14it/s]


Epoch 42/100, Loss: 9.8424


100%|██████████| 9/9 [00:01<00:00,  6.93it/s]


Epoch 43/100, Loss: 9.8427


100%|██████████| 9/9 [00:01<00:00,  7.16it/s]


Epoch 44/100, Loss: 9.8425


100%|██████████| 9/9 [00:01<00:00,  7.16it/s]


Epoch 45/100, Loss: 9.8429


100%|██████████| 9/9 [00:01<00:00,  7.16it/s]


Epoch 46/100, Loss: 9.8427


100%|██████████| 9/9 [00:01<00:00,  6.92it/s]


Epoch 47/100, Loss: 9.8428


100%|██████████| 9/9 [00:01<00:00,  7.16it/s]


Epoch 48/100, Loss: 9.8426


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 49/100, Loss: 9.8426


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 50/100, Loss: 9.8428


100%|██████████| 9/9 [00:01<00:00,  6.95it/s]


Epoch 51/100, Loss: 9.8427


100%|██████████| 9/9 [00:01<00:00,  7.13it/s]


Epoch 52/100, Loss: 9.8428


100%|██████████| 9/9 [00:01<00:00,  7.16it/s]


Epoch 53/100, Loss: 9.8427


100%|██████████| 9/9 [00:01<00:00,  7.16it/s]


Epoch 54/100, Loss: 9.8426


100%|██████████| 9/9 [00:01<00:00,  6.92it/s]


Epoch 55/100, Loss: 9.8426


100%|██████████| 9/9 [00:01<00:00,  7.14it/s]


Epoch 56/100, Loss: 9.8426


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 57/100, Loss: 9.8426


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 58/100, Loss: 9.8426


100%|██████████| 9/9 [00:01<00:00,  6.93it/s]


Epoch 59/100, Loss: 9.8427


100%|██████████| 9/9 [00:01<00:00,  7.18it/s]


Epoch 60/100, Loss: 9.8426


100%|██████████| 9/9 [00:01<00:00,  7.14it/s]


Epoch 61/100, Loss: 9.8427


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 62/100, Loss: 9.8427


100%|██████████| 9/9 [00:01<00:00,  6.93it/s]


Epoch 63/100, Loss: 9.8425


100%|██████████| 9/9 [00:01<00:00,  7.14it/s]


Epoch 64/100, Loss: 9.8426


100%|██████████| 9/9 [00:01<00:00,  6.90it/s]


Epoch 65/100, Loss: 9.8428


100%|██████████| 9/9 [00:01<00:00,  6.78it/s]


Epoch 66/100, Loss: 9.8425


100%|██████████| 9/9 [00:01<00:00,  6.92it/s]


Epoch 67/100, Loss: 9.8427


100%|██████████| 9/9 [00:01<00:00,  7.04it/s]


Epoch 68/100, Loss: 9.8426


100%|██████████| 9/9 [00:01<00:00,  7.03it/s]


Epoch 69/100, Loss: 9.8424


100%|██████████| 9/9 [00:01<00:00,  7.16it/s]


Epoch 70/100, Loss: 9.8426


100%|██████████| 9/9 [00:01<00:00,  6.84it/s]


Epoch 71/100, Loss: 9.8426


100%|██████████| 9/9 [00:01<00:00,  7.04it/s]


Epoch 72/100, Loss: 9.8425


100%|██████████| 9/9 [00:01<00:00,  7.14it/s]


Epoch 73/100, Loss: 9.8427


100%|██████████| 9/9 [00:01<00:00,  7.09it/s]


Epoch 74/100, Loss: 9.8426


100%|██████████| 9/9 [00:01<00:00,  7.12it/s]


Epoch 75/100, Loss: 9.8425


100%|██████████| 9/9 [00:01<00:00,  7.14it/s]


Epoch 76/100, Loss: 9.8425


100%|██████████| 9/9 [00:01<00:00,  7.16it/s]


Epoch 77/100, Loss: 9.8428


100%|██████████| 9/9 [00:01<00:00,  7.10it/s]


Epoch 78/100, Loss: 9.8425


100%|██████████| 9/9 [00:01<00:00,  7.08it/s]


Epoch 79/100, Loss: 9.8425


100%|██████████| 9/9 [00:01<00:00,  7.17it/s]


Epoch 80/100, Loss: 9.8424


100%|██████████| 9/9 [00:01<00:00,  7.16it/s]


Epoch 81/100, Loss: 9.8425


100%|██████████| 9/9 [00:01<00:00,  7.05it/s]


Epoch 82/100, Loss: 9.8425


100%|██████████| 9/9 [00:01<00:00,  7.12it/s]


Epoch 83/100, Loss: 9.8426


100%|██████████| 9/9 [00:01<00:00,  7.18it/s]


Epoch 84/100, Loss: 9.8426


100%|██████████| 9/9 [00:01<00:00,  7.13it/s]


Epoch 85/100, Loss: 9.8428


100%|██████████| 9/9 [00:01<00:00,  7.04it/s]


Epoch 86/100, Loss: 9.8427


100%|██████████| 9/9 [00:01<00:00,  7.05it/s]


Epoch 87/100, Loss: 9.8426


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 88/100, Loss: 9.8428


100%|██████████| 9/9 [00:01<00:00,  7.14it/s]


Epoch 89/100, Loss: 9.8427


100%|██████████| 9/9 [00:01<00:00,  7.04it/s]


Epoch 90/100, Loss: 9.8426


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 91/100, Loss: 9.8427


100%|██████████| 9/9 [00:01<00:00,  7.18it/s]


Epoch 92/100, Loss: 9.8425


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 93/100, Loss: 9.8425


100%|██████████| 9/9 [00:01<00:00,  6.96it/s]


Epoch 94/100, Loss: 9.8423


100%|██████████| 9/9 [00:01<00:00,  7.12it/s]


Epoch 95/100, Loss: 9.8426


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 96/100, Loss: 9.8427


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 97/100, Loss: 9.8424


100%|██████████| 9/9 [00:01<00:00,  7.03it/s]


Epoch 98/100, Loss: 9.8423


100%|██████████| 9/9 [00:01<00:00,  7.16it/s]


Epoch 99/100, Loss: 9.8427


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]

Epoch 100/100, Loss: 9.8427





## Run inference

In [11]:
import os

# Ensure the directory exists
output_dir = "./results/pytorch"
os.makedirs(output_dir, exist_ok=True)

output_length = 200

def run_inference(text):
    model.eval()
    input_tokens = tokenize_text([text], max_length)
    input_ids = input_tokens['input_ids'].to(device)
    attention_mask = input_tokens['attention_mask'].to(device)

    generated_tokens = []
    current_input = input_ids

    with torch.no_grad():
        for _ in range(output_length):
            logits = model(current_input)  # Shape: (batch_size, seq_len, vocab_size)
            next_token_logits = logits[:, -1, :]  # Get logits for the last token

            # Using argmax causes the model to generate the same text repeatedly
            # next_token_id = torch.argmax(next_token_logits, dim=-1).item()            
            
            # Use multinomial sampling to generate a diverse set of outputs
            next_token_probs = torch.nn.functional.softmax(next_token_logits, dim=-1)
            next_token_id = torch.multinomial(next_token_probs, num_samples=1).item()

            if next_token_id == tokenizer.pad_token_id:  # Stop at padding token
                break

            generated_tokens.append(next_token_id)

            # Prepare input for the next iteration
            next_token = torch.tensor([[next_token_id]], device=device)
            current_input = torch.cat((current_input, next_token), dim=1)

    return tokenizer.decode(generated_tokens, skip_special_tokens=True)

In [12]:
test_df = pd.read_csv(f'{dataset_dir}/podcast_with_summary_test.csv')

# print some of the inference results
for i in range(5):
    test_text = test_df['text_short'][i]
    reference_summary = test_df['summary'][i]
    predicted_summary = run_inference(test_text)
    print(f"Test Text: {test_text}")
    print(f"Reference Summary: {reference_summary}")
    print(f"Predicted Summary: {predicted_summary}")
    print("")

Test Text: The following is a conversation with Andrew Ng, one of the most impactful educators, researchers, innovators, and leaders in artificial intelligence and technology space in general. He cofounded Coursera and Google Brain, launched Deep Learning AI, Landing AI, and the AI Fund, and was the chief scientist at Baidu. As a Stanford professor and with Coursera and Deep Learning AI, he has helped educate and inspire millions of students, including me. This is the Artificial Intelligence Podcast. If you enjoy it, subscribe on YouTube, give it five stars on Apple Podcast, support it on Patreon, or simply connect with me on Twitter at Lex Friedman, spelled F R I D M A N. As usual, I'll do one or two minutes of ads now and never any ads in the middle that can break the flow of the conversation. I hope that works for you and doesn't hurt the listening experience. This show is presented by Cash App, the number one finance app in the App Store. When you get it, use code LEXPODCAST.
Refer

## Generate metrics

In [13]:
from SharedUtils import evaluate_and_save_metrics

# Evaluate the model
def evaluate_model(df, model, name):
    model.eval()
    reference_summaries = []
    predicted_summaries = []
    total_time = 0

    for _, row in df.iterrows():
        test_text = row['text_short']
        reference_summary = row['summary']

        start_time = time.time()
        predicted_summary = run_inference(test_text)
        end_time = time.time()
        elapsed_time = end_time - start_time
        total_time += elapsed_time

        reference_summaries.append(reference_summary)
        predicted_summaries.append(predicted_summary)

    model_name = "pytorch"
    filename = f"{epochs}_epochs"
    rouge_results, bleu_results = evaluate_and_save_metrics(
        model_name,
        name,
        filename,
        reference_summaries,
        predicted_summaries,
        total_time
    )

    print(rouge_results)
    print(bleu_results)

    results_df = pd.DataFrame({
        'summary': reference_summaries,
        'summary_tuned': predicted_summaries
    })
    results_df.to_csv(f"./results/{model_name}/{name}/{filename}_summaries.csv")

    print(f"Evaluation completed for {name}.")
    print(f"Total time (seconds): {total_time}")
    print(f"Total time (minutes): {total_time / 60}")

In [14]:
# Evaluate on test set
evaluate_model(test_df, model, "test_dataset")

{'rouge1': 0.0008847516330128638, 'rouge2': 0.0, 'rougeL': 0.0008846137339065631, 'rougeLsum': 0.0008834528771283641}
{'bleu': 0.0, 'precisions': [0.00070577856197618, 0.0, 0.0, 0.0], 'brevity_penalty': 1.0, 'length_ratio': 8.433779761904763, 'translation_length': 11335, 'reference_length': 1344}
Evaluation completed for test_dataset.
Total time (seconds): 138.99433064460754
Total time (minutes): 2.3165721774101256


In [15]:
# Evaluate on whole dataset
whole_df = pd.read_csv(f'{dataset_dir}/podcast_with_summary.csv')
evaluate_model(whole_df, model, "whole_dataset")

{'rouge1': 0.0010012046775948248, 'rouge2': 0.0, 'rougeL': 0.0009987140807244028, 'rougeLsum': 0.001001297924903576}
{'bleu': 0.0, 'precisions': [0.0004081053266616984, 0.0, 0.0, 0.0], 'brevity_penalty': 1.0, 'length_ratio': 8.35180794309425, 'translation_length': 56358, 'reference_length': 6748}
Evaluation completed for whole_dataset.
Total time (seconds): 695.0681042671204
Total time (minutes): 11.584468404452005
