In [1]:
# %% Import necessary libraries
import pandas as pd
import numpy as np
import time
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer
from tqdm import tqdm

# %% Read dataset
dataset_dir = 'datasets'
df = pd.read_csv(f'{dataset_dir}/podcast_with_summary.csv')


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# %% Preprocess text data
max_length = 1024
# the lengths of input and output must be the same for out model

# Use Hugging Face tokenizer
model_name = "bert-base-uncased"  # Replace with a model suitable for tokenization
tokenizer = AutoTokenizer.from_pretrained(model_name)

def tokenize_text(texts, max_length):
    return tokenizer(
        texts,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt"
    )

# Tokenize inputs and summaries
input_tokens = tokenize_text(df['text_short'].tolist(), max_length)
summary_tokens = tokenize_text(df['summary'].tolist(), max_length)

X = input_tokens['input_ids']
Y = summary_tokens['input_ids']

# %% Train-test split
from sklearn.model_selection import train_test_split

X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.1, random_state=42)


In [3]:

# %% Create PyTorch Dataset
class TextSummaryDataset(Dataset):
    def __init__(self, inputs, targets):
        self.inputs = inputs
        self.targets = targets

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return self.inputs[idx], self.targets[idx]

train_dataset = TextSummaryDataset(X_train, Y_train)
test_dataset = TextSummaryDataset(X_test, Y_test)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32)

In [4]:
X = input_tokens['input_ids']
Y = summary_tokens['input_ids']

# Verify dimensions
print(f"Input shape: {X.shape}, Target shape: {Y.shape}")

Input shape: torch.Size([319, 1024]), Target shape: torch.Size([319, 1024])


In [5]:
print(tokenizer.decode(X[0]))
print(tokenizer.decode(Y[0]))

[CLS] as part of mit course 6s099, artificial general intelligence, i ' ve gotten the chance to sit down with max tegmark. he is a professor here at mit. he ' s a physicist, spent a large part of his career studying the mysteries of our cosmological universe. but he ' s also studied and delved into the beneficial possibilities and the existential risks of artificial intelligence. amongst many other things, he is the cofounder of the future of life institute, author of two books, both of which i highly recommend. first, our mathematical universe. second is life 3. 0. he ' s truly an out of the box thinker and a fun personality, so i really enjoy talking to him. if you ' d like to see more of these videos in the future, please subscribe and also click the little bell icon to make sure you don ' t miss any videos. also, twitter, linkedin, agi. mit. edu if you wanna watch other lectures or conversations like this one. better yet, go read max ' s book, life 3. 0. chapter seven on goals is m

In [6]:

class LSTMSummarizer(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_units):
        super(LSTMSummarizer, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=tokenizer.pad_token_id)
        self.lstm = nn.LSTM(embedding_dim, hidden_units, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(hidden_units * 2, vocab_size)

    def forward(self, x):
        embedded = self.embedding(x)
        lstm_out, _ = self.lstm(embedded)
        out = self.fc(lstm_out)  # Predict for each time step
        return out

embedding_dim = 128
hidden_units = 256
vocab_size = tokenizer.vocab_size
model = LSTMSummarizer(vocab_size, embedding_dim, hidden_units).to('cuda' if torch.cuda.is_available() else 'cpu')


In [7]:
from torch.optim.lr_scheduler import StepLR
from torch.cuda.amp import autocast, GradScaler

# %% Define optimizer and loss function
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# %% Training loop
device = 'cuda' if torch.cuda.is_available() else 'cpu'

scheduler = StepLR(optimizer, step_size=5, gamma=0.1)

def train_model(model, train_loader, criterion, optimizer, scheduler, epochs=10):
    model.train()
    for epoch in range(epochs):
        epoch_loss = 0
        for inputs, targets in tqdm(train_loader):
            # Move inputs and targets to device
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            
            outputs = model(inputs)  # Shape: (batch_size, seq_len, vocab_size)
            
            # Reshape outputs and targets for loss computation
            outputs = outputs.view(-1, outputs.size(-1))  # (batch_size * seq_len, vocab_size)
            targets = targets.view(-1)  # (batch_size * seq_len)
            
            # Compute loss
            loss = criterion(outputs, targets)
            loss.backward()

            optimizer.step()

            epoch_loss += loss.item()

        # Scheduler step
        scheduler.step()

        print(f"Epoch {epoch + 1}/{epochs}, Loss: {epoch_loss / len(train_loader):.4f}")



In [8]:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
train_model(model, train_loader, criterion, optimizer, scheduler, epochs=500)

100%|██████████| 9/9 [00:01<00:00,  5.41it/s]


Epoch 1/500, Loss: 10.3117


100%|██████████| 9/9 [00:01<00:00,  6.65it/s]


Epoch 2/500, Loss: 10.2650


100%|██████████| 9/9 [00:01<00:00,  7.02it/s]


Epoch 3/500, Loss: 10.2122


100%|██████████| 9/9 [00:01<00:00,  7.03it/s]


Epoch 4/500, Loss: 10.1405


100%|██████████| 9/9 [00:01<00:00,  7.16it/s]


Epoch 5/500, Loss: 10.0263


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 6/500, Loss: 9.9193


100%|██████████| 9/9 [00:01<00:00,  7.16it/s]


Epoch 7/500, Loss: 9.8940


100%|██████████| 9/9 [00:01<00:00,  7.02it/s]


Epoch 8/500, Loss: 9.8660


100%|██████████| 9/9 [00:01<00:00,  7.06it/s]


Epoch 9/500, Loss: 9.8353


100%|██████████| 9/9 [00:01<00:00,  7.02it/s]


Epoch 10/500, Loss: 9.8021


100%|██████████| 9/9 [00:01<00:00,  6.96it/s]


Epoch 11/500, Loss: 9.7802


100%|██████████| 9/9 [00:01<00:00,  7.05it/s]


Epoch 12/500, Loss: 9.7762


100%|██████████| 9/9 [00:01<00:00,  7.06it/s]


Epoch 13/500, Loss: 9.7725


100%|██████████| 9/9 [00:01<00:00,  7.13it/s]


Epoch 14/500, Loss: 9.7690


100%|██████████| 9/9 [00:01<00:00,  7.02it/s]


Epoch 15/500, Loss: 9.7651


100%|██████████| 9/9 [00:01<00:00,  7.07it/s]


Epoch 16/500, Loss: 9.7630


100%|██████████| 9/9 [00:01<00:00,  7.14it/s]


Epoch 17/500, Loss: 9.7626


100%|██████████| 9/9 [00:01<00:00,  7.14it/s]


Epoch 18/500, Loss: 9.7624


100%|██████████| 9/9 [00:01<00:00,  6.94it/s]


Epoch 19/500, Loss: 9.7621


100%|██████████| 9/9 [00:01<00:00,  7.05it/s]


Epoch 20/500, Loss: 9.7619


100%|██████████| 9/9 [00:01<00:00,  6.99it/s]


Epoch 21/500, Loss: 9.7617


100%|██████████| 9/9 [00:01<00:00,  7.11it/s]


Epoch 22/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.13it/s]


Epoch 23/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 24/500, Loss: 9.7611


100%|██████████| 9/9 [00:01<00:00,  7.12it/s]


Epoch 25/500, Loss: 9.7612


100%|██████████| 9/9 [00:01<00:00,  7.06it/s]


Epoch 26/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  6.97it/s]


Epoch 27/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.05it/s]


Epoch 28/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.00it/s]


Epoch 29/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.03it/s]


Epoch 30/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.05it/s]


Epoch 31/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.16it/s]


Epoch 32/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 33/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 34/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.13it/s]


Epoch 35/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.04it/s]


Epoch 36/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.07it/s]


Epoch 37/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.04it/s]


Epoch 38/500, Loss: 9.7616


100%|██████████| 9/9 [00:01<00:00,  7.02it/s]


Epoch 39/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  6.95it/s]


Epoch 40/500, Loss: 9.7612


100%|██████████| 9/9 [00:01<00:00,  7.03it/s]


Epoch 41/500, Loss: 9.7611


100%|██████████| 9/9 [00:01<00:00,  7.14it/s]


Epoch 42/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.18it/s]


Epoch 43/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.12it/s]


Epoch 44/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  6.91it/s]


Epoch 45/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.02it/s]


Epoch 46/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.10it/s]


Epoch 47/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  6.89it/s]


Epoch 48/500, Loss: 9.7611


100%|██████████| 9/9 [00:01<00:00,  7.09it/s]


Epoch 49/500, Loss: 9.7612


100%|██████████| 9/9 [00:01<00:00,  7.16it/s]


Epoch 50/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.03it/s]


Epoch 51/500, Loss: 9.7612


100%|██████████| 9/9 [00:01<00:00,  7.01it/s]


Epoch 52/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.03it/s]


Epoch 53/500, Loss: 9.7611


100%|██████████| 9/9 [00:01<00:00,  7.06it/s]


Epoch 54/500, Loss: 9.7610


100%|██████████| 9/9 [00:01<00:00,  7.06it/s]


Epoch 55/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.03it/s]


Epoch 56/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.01it/s]


Epoch 57/500, Loss: 9.7612


100%|██████████| 9/9 [00:01<00:00,  7.19it/s]


Epoch 58/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.13it/s]


Epoch 59/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.17it/s]


Epoch 60/500, Loss: 9.7618


100%|██████████| 9/9 [00:01<00:00,  7.13it/s]


Epoch 61/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.17it/s]


Epoch 62/500, Loss: 9.7612


100%|██████████| 9/9 [00:01<00:00,  7.12it/s]


Epoch 63/500, Loss: 9.7612


100%|██████████| 9/9 [00:01<00:00,  7.16it/s]


Epoch 64/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.05it/s]


Epoch 65/500, Loss: 9.7616


100%|██████████| 9/9 [00:01<00:00,  7.13it/s]


Epoch 66/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  6.94it/s]


Epoch 67/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.04it/s]


Epoch 68/500, Loss: 9.7616


100%|██████████| 9/9 [00:01<00:00,  6.93it/s]


Epoch 69/500, Loss: 9.7610


100%|██████████| 9/9 [00:01<00:00,  7.03it/s]


Epoch 70/500, Loss: 9.7612


100%|██████████| 9/9 [00:01<00:00,  6.96it/s]


Epoch 71/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.01it/s]


Epoch 72/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.05it/s]


Epoch 73/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.01it/s]


Epoch 74/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.00it/s]


Epoch 75/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.03it/s]


Epoch 76/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.00it/s]


Epoch 77/500, Loss: 9.7612


100%|██████████| 9/9 [00:01<00:00,  7.04it/s]


Epoch 78/500, Loss: 9.7611


100%|██████████| 9/9 [00:01<00:00,  6.94it/s]


Epoch 79/500, Loss: 9.7612


100%|██████████| 9/9 [00:01<00:00,  7.04it/s]


Epoch 80/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.04it/s]


Epoch 81/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 82/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 83/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.03it/s]


Epoch 84/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  6.95it/s]


Epoch 85/500, Loss: 9.7612


100%|██████████| 9/9 [00:01<00:00,  7.01it/s]


Epoch 86/500, Loss: 9.7611


100%|██████████| 9/9 [00:01<00:00,  6.95it/s]


Epoch 87/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  6.95it/s]


Epoch 88/500, Loss: 9.7616


100%|██████████| 9/9 [00:01<00:00,  7.07it/s]


Epoch 89/500, Loss: 9.7612


100%|██████████| 9/9 [00:01<00:00,  6.95it/s]


Epoch 90/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  6.96it/s]


Epoch 91/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.05it/s]


Epoch 92/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.03it/s]


Epoch 93/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.16it/s]


Epoch 94/500, Loss: 9.7616


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 95/500, Loss: 9.7612


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 96/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.08it/s]


Epoch 97/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  6.99it/s]


Epoch 98/500, Loss: 9.7612


100%|██████████| 9/9 [00:01<00:00,  6.94it/s]


Epoch 99/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.02it/s]


Epoch 100/500, Loss: 9.7612


100%|██████████| 9/9 [00:01<00:00,  7.04it/s]


Epoch 101/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.05it/s]


Epoch 102/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 103/500, Loss: 9.7612


100%|██████████| 9/9 [00:01<00:00,  6.96it/s]


Epoch 104/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.02it/s]


Epoch 105/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.03it/s]


Epoch 106/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  6.94it/s]


Epoch 107/500, Loss: 9.7616


100%|██████████| 9/9 [00:01<00:00,  7.05it/s]


Epoch 108/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.14it/s]


Epoch 109/500, Loss: 9.7612


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 110/500, Loss: 9.7617


100%|██████████| 9/9 [00:01<00:00,  7.16it/s]


Epoch 111/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.13it/s]


Epoch 112/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.14it/s]


Epoch 113/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.19it/s]


Epoch 114/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.14it/s]


Epoch 115/500, Loss: 9.7611


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 116/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.04it/s]


Epoch 117/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.04it/s]


Epoch 118/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.07it/s]


Epoch 119/500, Loss: 9.7612


100%|██████████| 9/9 [00:01<00:00,  7.06it/s]


Epoch 120/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.01it/s]


Epoch 121/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.02it/s]


Epoch 122/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.06it/s]


Epoch 123/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.01it/s]


Epoch 124/500, Loss: 9.7612


100%|██████████| 9/9 [00:01<00:00,  7.07it/s]


Epoch 125/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  6.92it/s]


Epoch 126/500, Loss: 9.7612


100%|██████████| 9/9 [00:01<00:00,  7.06it/s]


Epoch 127/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  6.99it/s]


Epoch 128/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  6.94it/s]


Epoch 129/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  6.95it/s]


Epoch 130/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.03it/s]


Epoch 131/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  6.96it/s]


Epoch 132/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.02it/s]


Epoch 133/500, Loss: 9.7612


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 134/500, Loss: 9.7612


100%|██████████| 9/9 [00:01<00:00,  7.14it/s]


Epoch 135/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 136/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  6.97it/s]


Epoch 137/500, Loss: 9.7612


100%|██████████| 9/9 [00:01<00:00,  7.04it/s]


Epoch 138/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.11it/s]


Epoch 139/500, Loss: 9.7611


100%|██████████| 9/9 [00:01<00:00,  6.95it/s]


Epoch 140/500, Loss: 9.7612


100%|██████████| 9/9 [00:01<00:00,  7.01it/s]


Epoch 141/500, Loss: 9.7617


100%|██████████| 9/9 [00:01<00:00,  7.09it/s]


Epoch 142/500, Loss: 9.7612


100%|██████████| 9/9 [00:01<00:00,  7.13it/s]


Epoch 143/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.17it/s]


Epoch 144/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.16it/s]


Epoch 145/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.06it/s]


Epoch 146/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.01it/s]


Epoch 147/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.03it/s]


Epoch 148/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.05it/s]


Epoch 149/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  6.95it/s]


Epoch 150/500, Loss: 9.7612


100%|██████████| 9/9 [00:01<00:00,  7.03it/s]


Epoch 151/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.02it/s]


Epoch 152/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  6.98it/s]


Epoch 153/500, Loss: 9.7616


100%|██████████| 9/9 [00:01<00:00,  7.00it/s]


Epoch 154/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.04it/s]


Epoch 155/500, Loss: 9.7616


100%|██████████| 9/9 [00:01<00:00,  7.01it/s]


Epoch 156/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  6.96it/s]


Epoch 157/500, Loss: 9.7617


100%|██████████| 9/9 [00:01<00:00,  7.03it/s]


Epoch 158/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.04it/s]


Epoch 159/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.04it/s]


Epoch 160/500, Loss: 9.7611


100%|██████████| 9/9 [00:01<00:00,  6.90it/s]


Epoch 161/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.10it/s]


Epoch 162/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 163/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 164/500, Loss: 9.7612


100%|██████████| 9/9 [00:01<00:00,  7.16it/s]


Epoch 165/500, Loss: 9.7610


100%|██████████| 9/9 [00:01<00:00,  7.14it/s]


Epoch 166/500, Loss: 9.7616


100%|██████████| 9/9 [00:01<00:00,  7.04it/s]


Epoch 167/500, Loss: 9.7616


100%|██████████| 9/9 [00:01<00:00,  7.05it/s]


Epoch 168/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  6.94it/s]


Epoch 169/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.03it/s]


Epoch 170/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.04it/s]


Epoch 171/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.04it/s]


Epoch 172/500, Loss: 9.7612


100%|██████████| 9/9 [00:01<00:00,  7.18it/s]


Epoch 173/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.14it/s]


Epoch 174/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 175/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.03it/s]


Epoch 176/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.07it/s]


Epoch 177/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.02it/s]


Epoch 178/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.14it/s]


Epoch 179/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  6.99it/s]


Epoch 180/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  6.98it/s]


Epoch 181/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.05it/s]


Epoch 182/500, Loss: 9.7618


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 183/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.14it/s]


Epoch 184/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.03it/s]


Epoch 185/500, Loss: 9.7616


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 186/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.04it/s]


Epoch 187/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.07it/s]


Epoch 188/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.11it/s]


Epoch 189/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.07it/s]


Epoch 190/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.02it/s]


Epoch 191/500, Loss: 9.7612


100%|██████████| 9/9 [00:01<00:00,  7.05it/s]


Epoch 192/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.04it/s]


Epoch 193/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.02it/s]


Epoch 194/500, Loss: 9.7612


100%|██████████| 9/9 [00:01<00:00,  6.98it/s]


Epoch 195/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.01it/s]


Epoch 196/500, Loss: 9.7617


100%|██████████| 9/9 [00:01<00:00,  7.02it/s]


Epoch 197/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  6.96it/s]


Epoch 198/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  6.91it/s]


Epoch 199/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.03it/s]


Epoch 200/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.05it/s]


Epoch 201/500, Loss: 9.7612


100%|██████████| 9/9 [00:01<00:00,  7.05it/s]


Epoch 202/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.04it/s]


Epoch 203/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.05it/s]


Epoch 204/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.07it/s]


Epoch 205/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.03it/s]


Epoch 206/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.13it/s]


Epoch 207/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.18it/s]


Epoch 208/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.02it/s]


Epoch 209/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.04it/s]


Epoch 210/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.05it/s]


Epoch 211/500, Loss: 9.7617


100%|██████████| 9/9 [00:01<00:00,  7.02it/s]


Epoch 212/500, Loss: 9.7616


100%|██████████| 9/9 [00:01<00:00,  7.06it/s]


Epoch 213/500, Loss: 9.7611


100%|██████████| 9/9 [00:01<00:00,  7.16it/s]


Epoch 214/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.16it/s]


Epoch 215/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.14it/s]


Epoch 216/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 217/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 218/500, Loss: 9.7616


100%|██████████| 9/9 [00:01<00:00,  7.16it/s]


Epoch 219/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.13it/s]


Epoch 220/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.16it/s]


Epoch 221/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.17it/s]


Epoch 222/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.13it/s]


Epoch 223/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 224/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 225/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 226/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  6.93it/s]


Epoch 227/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.19it/s]


Epoch 228/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.12it/s]


Epoch 229/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.04it/s]


Epoch 230/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.05it/s]


Epoch 231/500, Loss: 9.7611


100%|██████████| 9/9 [00:01<00:00,  7.14it/s]


Epoch 232/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 233/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.05it/s]


Epoch 234/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.03it/s]


Epoch 235/500, Loss: 9.7612


100%|██████████| 9/9 [00:01<00:00,  7.07it/s]


Epoch 236/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 237/500, Loss: 9.7616


100%|██████████| 9/9 [00:01<00:00,  7.13it/s]


Epoch 238/500, Loss: 9.7616


100%|██████████| 9/9 [00:01<00:00,  7.16it/s]


Epoch 239/500, Loss: 9.7612


100%|██████████| 9/9 [00:01<00:00,  7.03it/s]


Epoch 240/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.07it/s]


Epoch 241/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.06it/s]


Epoch 242/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.01it/s]


Epoch 243/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.06it/s]


Epoch 244/500, Loss: 9.7612


100%|██████████| 9/9 [00:01<00:00,  7.01it/s]


Epoch 245/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.07it/s]


Epoch 246/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.13it/s]


Epoch 247/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.03it/s]


Epoch 248/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.02it/s]


Epoch 249/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.04it/s]


Epoch 250/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.05it/s]


Epoch 251/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  6.94it/s]


Epoch 252/500, Loss: 9.7612


100%|██████████| 9/9 [00:01<00:00,  7.04it/s]


Epoch 253/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 254/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.03it/s]


Epoch 255/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.03it/s]


Epoch 256/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.05it/s]


Epoch 257/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  6.97it/s]


Epoch 258/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  6.99it/s]


Epoch 259/500, Loss: 9.7616


100%|██████████| 9/9 [00:01<00:00,  6.98it/s]


Epoch 260/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  6.99it/s]


Epoch 261/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.03it/s]


Epoch 262/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.03it/s]


Epoch 263/500, Loss: 9.7616


100%|██████████| 9/9 [00:01<00:00,  6.98it/s]


Epoch 264/500, Loss: 9.7612


100%|██████████| 9/9 [00:01<00:00,  7.05it/s]


Epoch 265/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.00it/s]


Epoch 266/500, Loss: 9.7612


100%|██████████| 9/9 [00:01<00:00,  6.97it/s]


Epoch 267/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.09it/s]


Epoch 268/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.08it/s]


Epoch 269/500, Loss: 9.7611


100%|██████████| 9/9 [00:01<00:00,  7.03it/s]


Epoch 270/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.05it/s]


Epoch 271/500, Loss: 9.7612


100%|██████████| 9/9 [00:01<00:00,  7.01it/s]


Epoch 272/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 273/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 274/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.17it/s]


Epoch 275/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.14it/s]


Epoch 276/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.17it/s]


Epoch 277/500, Loss: 9.7611


100%|██████████| 9/9 [00:01<00:00,  7.13it/s]


Epoch 278/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.19it/s]


Epoch 279/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.13it/s]


Epoch 280/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.14it/s]


Epoch 281/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.07it/s]


Epoch 282/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.07it/s]


Epoch 283/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  6.97it/s]


Epoch 284/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.05it/s]


Epoch 285/500, Loss: 9.7610


100%|██████████| 9/9 [00:01<00:00,  7.05it/s]


Epoch 286/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.05it/s]


Epoch 287/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.03it/s]


Epoch 288/500, Loss: 9.7616


100%|██████████| 9/9 [00:01<00:00,  7.16it/s]


Epoch 289/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.18it/s]


Epoch 290/500, Loss: 9.7611


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 291/500, Loss: 9.7616


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 292/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 293/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.16it/s]


Epoch 294/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.14it/s]


Epoch 295/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.17it/s]


Epoch 296/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.09it/s]


Epoch 297/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.08it/s]


Epoch 298/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.04it/s]


Epoch 299/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  6.97it/s]


Epoch 300/500, Loss: 9.7616


100%|██████████| 9/9 [00:01<00:00,  7.01it/s]


Epoch 301/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.03it/s]


Epoch 302/500, Loss: 9.7612


100%|██████████| 9/9 [00:01<00:00,  7.03it/s]


Epoch 303/500, Loss: 9.7616


100%|██████████| 9/9 [00:01<00:00,  7.17it/s]


Epoch 304/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.17it/s]


Epoch 305/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  6.89it/s]


Epoch 306/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.05it/s]


Epoch 307/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  6.96it/s]


Epoch 308/500, Loss: 9.7612


100%|██████████| 9/9 [00:01<00:00,  7.05it/s]


Epoch 309/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.11it/s]


Epoch 310/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  6.94it/s]


Epoch 311/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.04it/s]


Epoch 312/500, Loss: 9.7609


100%|██████████| 9/9 [00:01<00:00,  7.04it/s]


Epoch 313/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.14it/s]


Epoch 314/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  6.98it/s]


Epoch 315/500, Loss: 9.7618


100%|██████████| 9/9 [00:01<00:00,  7.01it/s]


Epoch 316/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.03it/s]


Epoch 317/500, Loss: 9.7612


100%|██████████| 9/9 [00:01<00:00,  6.94it/s]


Epoch 318/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.04it/s]


Epoch 319/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.12it/s]


Epoch 320/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.00it/s]


Epoch 321/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.00it/s]


Epoch 322/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.07it/s]


Epoch 323/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.05it/s]


Epoch 324/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.00it/s]


Epoch 325/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.03it/s]


Epoch 326/500, Loss: 9.7612


100%|██████████| 9/9 [00:01<00:00,  7.04it/s]


Epoch 327/500, Loss: 9.7610


100%|██████████| 9/9 [00:01<00:00,  6.99it/s]


Epoch 328/500, Loss: 9.7612


100%|██████████| 9/9 [00:01<00:00,  6.97it/s]


Epoch 329/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.17it/s]


Epoch 330/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.04it/s]


Epoch 331/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.03it/s]


Epoch 332/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.05it/s]


Epoch 333/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.05it/s]


Epoch 334/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.02it/s]


Epoch 335/500, Loss: 9.7612


100%|██████████| 9/9 [00:01<00:00,  6.97it/s]


Epoch 336/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.12it/s]


Epoch 337/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.17it/s]


Epoch 338/500, Loss: 9.7612


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 339/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.17it/s]


Epoch 340/500, Loss: 9.7616


100%|██████████| 9/9 [00:01<00:00,  7.13it/s]


Epoch 341/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.16it/s]


Epoch 342/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 343/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.14it/s]


Epoch 344/500, Loss: 9.7616


100%|██████████| 9/9 [00:01<00:00,  7.16it/s]


Epoch 345/500, Loss: 9.7611


100%|██████████| 9/9 [00:01<00:00,  7.14it/s]


Epoch 346/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.16it/s]


Epoch 347/500, Loss: 9.7612


100%|██████████| 9/9 [00:01<00:00,  7.06it/s]


Epoch 348/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.01it/s]


Epoch 349/500, Loss: 9.7612


100%|██████████| 9/9 [00:01<00:00,  7.05it/s]


Epoch 350/500, Loss: 9.7610


100%|██████████| 9/9 [00:01<00:00,  6.94it/s]


Epoch 351/500, Loss: 9.7612


100%|██████████| 9/9 [00:01<00:00,  7.03it/s]


Epoch 352/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.06it/s]


Epoch 353/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.13it/s]


Epoch 354/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.18it/s]


Epoch 355/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 356/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.14it/s]


Epoch 357/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.14it/s]


Epoch 358/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.17it/s]


Epoch 359/500, Loss: 9.7611


100%|██████████| 9/9 [00:01<00:00,  7.13it/s]


Epoch 360/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.16it/s]


Epoch 361/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.13it/s]


Epoch 362/500, Loss: 9.7612


100%|██████████| 9/9 [00:01<00:00,  7.00it/s]


Epoch 363/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  6.98it/s]


Epoch 364/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.05it/s]


Epoch 365/500, Loss: 9.7616


100%|██████████| 9/9 [00:01<00:00,  7.02it/s]


Epoch 366/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  6.94it/s]


Epoch 367/500, Loss: 9.7611


100%|██████████| 9/9 [00:01<00:00,  7.16it/s]


Epoch 368/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.16it/s]


Epoch 369/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.03it/s]


Epoch 370/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.03it/s]


Epoch 371/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.06it/s]


Epoch 372/500, Loss: 9.7611


100%|██████████| 9/9 [00:01<00:00,  7.03it/s]


Epoch 373/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 374/500, Loss: 9.7612


100%|██████████| 9/9 [00:01<00:00,  6.93it/s]


Epoch 375/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.06it/s]


Epoch 376/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 377/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 378/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.03it/s]


Epoch 379/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.00it/s]


Epoch 380/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  6.99it/s]


Epoch 381/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  6.91it/s]


Epoch 382/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.03it/s]


Epoch 383/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  6.92it/s]


Epoch 384/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.07it/s]


Epoch 385/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.02it/s]


Epoch 386/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  6.94it/s]


Epoch 387/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.14it/s]


Epoch 388/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.07it/s]


Epoch 389/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.04it/s]


Epoch 390/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.02it/s]


Epoch 391/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.03it/s]


Epoch 392/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.06it/s]


Epoch 393/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.06it/s]


Epoch 394/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  6.99it/s]


Epoch 395/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.08it/s]


Epoch 396/500, Loss: 9.7616


100%|██████████| 9/9 [00:01<00:00,  7.13it/s]


Epoch 397/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.14it/s]


Epoch 398/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  6.96it/s]


Epoch 399/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.13it/s]


Epoch 400/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.04it/s]


Epoch 401/500, Loss: 9.7612


100%|██████████| 9/9 [00:01<00:00,  7.10it/s]


Epoch 402/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.09it/s]


Epoch 403/500, Loss: 9.7611


100%|██████████| 9/9 [00:01<00:00,  6.97it/s]


Epoch 404/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.03it/s]


Epoch 405/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.13it/s]


Epoch 406/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 407/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.16it/s]


Epoch 408/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.14it/s]


Epoch 409/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.05it/s]


Epoch 410/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  6.94it/s]


Epoch 411/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.06it/s]


Epoch 412/500, Loss: 9.7612


100%|██████████| 9/9 [00:01<00:00,  7.11it/s]


Epoch 413/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.05it/s]


Epoch 414/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.02it/s]


Epoch 415/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.04it/s]


Epoch 416/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.18it/s]


Epoch 417/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.13it/s]


Epoch 418/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  6.94it/s]


Epoch 419/500, Loss: 9.7612


100%|██████████| 9/9 [00:01<00:00,  7.05it/s]


Epoch 420/500, Loss: 9.7618


100%|██████████| 9/9 [00:01<00:00,  7.00it/s]


Epoch 421/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  6.96it/s]


Epoch 422/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.04it/s]


Epoch 423/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 424/500, Loss: 9.7612


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 425/500, Loss: 9.7616


100%|██████████| 9/9 [00:01<00:00,  7.02it/s]


Epoch 426/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.04it/s]


Epoch 427/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  6.96it/s]


Epoch 428/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.05it/s]


Epoch 429/500, Loss: 9.7611


100%|██████████| 9/9 [00:01<00:00,  6.94it/s]


Epoch 430/500, Loss: 9.7612


100%|██████████| 9/9 [00:01<00:00,  7.11it/s]


Epoch 431/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.04it/s]


Epoch 432/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.04it/s]


Epoch 433/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.02it/s]


Epoch 434/500, Loss: 9.7611


100%|██████████| 9/9 [00:01<00:00,  6.96it/s]


Epoch 435/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  6.93it/s]


Epoch 436/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.16it/s]


Epoch 437/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 438/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.16it/s]


Epoch 439/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.07it/s]


Epoch 440/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.00it/s]


Epoch 441/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.04it/s]


Epoch 442/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.05it/s]


Epoch 443/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  6.99it/s]


Epoch 444/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.01it/s]


Epoch 445/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.13it/s]


Epoch 446/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.02it/s]


Epoch 447/500, Loss: 9.7612


100%|██████████| 9/9 [00:01<00:00,  7.04it/s]


Epoch 448/500, Loss: 9.7612


100%|██████████| 9/9 [00:01<00:00,  7.04it/s]


Epoch 449/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  6.94it/s]


Epoch 450/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.03it/s]


Epoch 451/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.03it/s]


Epoch 452/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.16it/s]


Epoch 453/500, Loss: 9.7611


100%|██████████| 9/9 [00:01<00:00,  7.19it/s]


Epoch 454/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 455/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.14it/s]


Epoch 456/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.14it/s]


Epoch 457/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 458/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 459/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.17it/s]


Epoch 460/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.13it/s]


Epoch 461/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.05it/s]


Epoch 462/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.03it/s]


Epoch 463/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.06it/s]


Epoch 464/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.11it/s]


Epoch 465/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.06it/s]


Epoch 466/500, Loss: 9.7612


100%|██████████| 9/9 [00:01<00:00,  7.02it/s]


Epoch 467/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.14it/s]


Epoch 468/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.00it/s]


Epoch 469/500, Loss: 9.7611


100%|██████████| 9/9 [00:01<00:00,  7.12it/s]


Epoch 470/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.03it/s]


Epoch 471/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.07it/s]


Epoch 472/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.12it/s]


Epoch 473/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.04it/s]


Epoch 474/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 475/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.13it/s]


Epoch 476/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  6.97it/s]


Epoch 477/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.10it/s]


Epoch 478/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.03it/s]


Epoch 479/500, Loss: 9.7615


100%|██████████| 9/9 [00:01<00:00,  7.08it/s]


Epoch 480/500, Loss: 9.7611


100%|██████████| 9/9 [00:01<00:00,  7.05it/s]


Epoch 481/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  6.91it/s]


Epoch 482/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.16it/s]


Epoch 483/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.09it/s]


Epoch 484/500, Loss: 9.7616


100%|██████████| 9/9 [00:01<00:00,  6.99it/s]


Epoch 485/500, Loss: 9.7612


100%|██████████| 9/9 [00:01<00:00,  7.01it/s]


Epoch 486/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.08it/s]


Epoch 487/500, Loss: 9.7610


100%|██████████| 9/9 [00:01<00:00,  7.17it/s]


Epoch 488/500, Loss: 9.7612


100%|██████████| 9/9 [00:01<00:00,  7.16it/s]


Epoch 489/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.13it/s]


Epoch 490/500, Loss: 9.7616


100%|██████████| 9/9 [00:01<00:00,  7.16it/s]


Epoch 491/500, Loss: 9.7616


100%|██████████| 9/9 [00:01<00:00,  7.14it/s]


Epoch 492/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.15it/s]


Epoch 493/500, Loss: 9.7611


100%|██████████| 9/9 [00:01<00:00,  7.17it/s]


Epoch 494/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.14it/s]


Epoch 495/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.17it/s]


Epoch 496/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.14it/s]


Epoch 497/500, Loss: 9.7613


100%|██████████| 9/9 [00:01<00:00,  7.17it/s]


Epoch 498/500, Loss: 9.7611


100%|██████████| 9/9 [00:01<00:00,  7.13it/s]


Epoch 499/500, Loss: 9.7614


100%|██████████| 9/9 [00:01<00:00,  7.16it/s]

Epoch 500/500, Loss: 9.7615





In [9]:
import os

# Ensure the directory exists
output_dir = "./results/pytorch"
os.makedirs(output_dir, exist_ok=True)

output_length = 200

def run_inference(text):
    model.eval()
    input_tokens = tokenize_text([text], max_length)
    input_ids = input_tokens['input_ids'].to(device)
    attention_mask = input_tokens['attention_mask'].to(device)

    generated_tokens = []
    current_input = input_ids

    with torch.no_grad():
        for _ in range(output_length):
            logits = model(current_input)  # Shape: (batch_size, seq_len, vocab_size)
            next_token_logits = logits[:, -1, :]  # Get logits for the last token
            # next_token_id = torch.argmax(next_token_logits, dim=-1).item()            
            next_token_probs = torch.nn.functional.softmax(next_token_logits, dim=-1)
            next_token_id = torch.multinomial(next_token_probs, num_samples=1).item()

            if next_token_id == tokenizer.pad_token_id:  # Stop at padding token
                break

            generated_tokens.append(next_token_id)

            # Prepare input for the next iteration
            next_token = torch.tensor([[next_token_id]], device=device)
            current_input = torch.cat((current_input, next_token), dim=1)

    return tokenizer.decode(generated_tokens, skip_special_tokens=True)


# %% Evaluate the model
def evaluate_model(df, model, name):
    model.eval()
    reference_summaries = []
    predicted_summaries = []
    total_time = 0

    for _, row in df.iterrows():
        test_text = row['text_short']
        reference_summary = row['summary']

        start_time = time.time()
        predicted_summary = run_inference(test_text)
        end_time = time.time()
        elapsed_time = end_time - start_time
        total_time += elapsed_time

        reference_summaries.append(reference_summary)
        predicted_summaries.append(predicted_summary)

    results_df = pd.DataFrame({
        'summary': reference_summaries,
        'summary_tuned': predicted_summaries
    })
    results_df.to_csv(f"./results/pytorch/{name}_summaries.csv")

    print(f"Evaluation completed for {name}.")
    print(f"Total time (seconds): {total_time}")
    print(f"Total time (minutes): {total_time / 60}")


In [10]:
test_df = pd.read_csv(f'{dataset_dir}/podcast_with_summary_test.csv')

# print some of the inference results
for i in range(5):
    test_text = test_df['text_short'][i]
    reference_summary = test_df['summary'][i]
    predicted_summary = run_inference(test_text)
    print(f"Test Text: {test_text}")
    print(f"Reference Summary: {reference_summary}")
    print(f"Predicted Summary: {predicted_summary}")
    print("")

Test Text: The following is a conversation with Andrew Ng, one of the most impactful educators, researchers, innovators, and leaders in artificial intelligence and technology space in general. He cofounded Coursera and Google Brain, launched Deep Learning AI, Landing AI, and the AI Fund, and was the chief scientist at Baidu. As a Stanford professor and with Coursera and Deep Learning AI, he has helped educate and inspire millions of students, including me. This is the Artificial Intelligence Podcast. If you enjoy it, subscribe on YouTube, give it five stars on Apple Podcast, support it on Patreon, or simply connect with me on Twitter at Lex Friedman, spelled F R I D M A N. As usual, I'll do one or two minutes of ads now and never any ads in the middle that can break the flow of the conversation. I hope that works for you and doesn't hurt the listening experience. This show is presented by Cash App, the number one finance app in the App Store. When you get it, use code LEXPODCAST.
Refer

In [11]:
# %% Evaluate on test set
evaluate_model(test_df, model, "test_dataset")

Evaluation completed for test_dataset.
Total time (seconds): 139.41584014892578
Total time (minutes): 2.3235973358154296
