This notebook loads the pretrained weights of OpenAI's GPT-2 model into our LLM and then fine-tunes the LLM to classify text messages as spam or not spam (ham).

In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import gpt_model
import pretrain_model
import tiktoken
import spam_dataset
import load_pretrained_weights
from torch.utils.data import DataLoader
from gpt_download import download_and_load_gpt2

In [None]:
# Load OpenAI's GPT-2 (124M) checkpoint through the project's gpt_download helper
# into `settings` and `params`, then print a few pieces of it for inspection.
settings, params = download_and_load_gpt2(model_size="124M", models_dir="gpt2")
token_embedding_weights = params["wte"]

print("Settings: ", settings)
print()
print("Params keys: ", params.keys())
print("Params token embedding weights: ", token_embedding_weights)
print("Token embedding weights shape: ", token_embedding_weights.shape)

In [None]:
# Initialize our LLM with the same architecture hyperparameters as GPT-2 "124M",
# so the pretrained weights loaded later fit its parameter shapes.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = tiktoken.get_encoding("gpt2")

gpt_config = {
    "vocab_size": 50257,     # size of the GPT-2 BPE vocabulary
    "context_length": 1024,  # number of supported token positions
    "emb_dim": 768,          # embedding width
    "num_heads": 12,
    "num_layers": 12,
    "drop_rate": 0.0,        # dropout disabled
    "qkv_bias": True,        # bias terms in the query/key/value projections
}
gpt = gpt_model.GPTModel(**gpt_config)
gpt.eval()

In [None]:
# Create a dataframe from the SMS spam collection: tab-separated "<label>\t<text>"
# lines with no header row.
import csv

data_file_path = "SMSSpamCollection"
# The messages are raw text, not CSV-quoted fields. With pandas' default quoting,
# a message that starts with a double quote is parsed as a quoted field and can
# swallow the following line(s), silently dropping rows. QUOTE_NONE keeps every
# line as exactly one row.
df = pd.read_csv(
    data_file_path,
    sep="\t",
    header=None,
    names=["Label", "Text"],
    quoting=csv.QUOTE_NONE,
)
df["Label"].value_counts()

In [None]:
# Create a balanced dataset (equal numbers of spam and ham instances)

def create_balanced_dataset(df):
    """Return a dataframe with equally many "spam" and "ham" rows.

    Every spam row is kept; the ham rows are randomly downsampled (with a
    fixed ``random_state=123`` for reproducibility) to the number of spam
    rows. The result lists the sampled ham rows first, then the spam rows in
    their original order, and keeps the original index labels.
    """
    spam_rows = df[df["Label"] == "spam"]
    ham_rows = df[df["Label"] == "ham"]
    sampled_ham = ham_rows.sample(len(spam_rows), random_state=123)
    return pd.concat([sampled_ham, spam_rows])

balanced_df = create_balanced_dataset(df)
balanced_label_counts = balanced_df["Label"].value_counts()
print(balanced_label_counts)  # spam and ham counts should now be equal

In [None]:
# split dataset: 70% for training, 10% for validation, 20% for testing

def random_split(df, train_frac, validation_frac):
    """Shuffle ``df`` and cut it into train / validation / test partitions.

    The rows are shuffled with a fixed seed (``random_state=123``) and the
    index is reset to 0..n-1. The first ``int(n * train_frac)`` rows form the
    training split, the next ``int(n * validation_frac)`` rows the validation
    split, and all remaining rows the test split.

    Returns:
        tuple: ``(train_df, validation_df, test_df)``.
    """
    shuffled = df.sample(frac=1, random_state=123).reset_index(drop=True)
    n_rows = len(shuffled)
    val_start = int(n_rows * train_frac)
    test_start = val_start + int(n_rows * validation_frac)

    return (
        shuffled.iloc[:val_start],
        shuffled.iloc[val_start:test_start],
        shuffled.iloc[test_start:],
    )

train_df, validation_df, test_df = random_split(balanced_df, 0.7, 0.1)

# Persist each split so SpamDataset can read it back from disk.
split_files = {
    "train.csv": train_df,
    "validation.csv": validation_df,
    "test.csv": test_df,
}
for csv_name, split_df in split_files.items():
    split_df.to_csv(csv_name, index=None)

# The training dataset is built with max_length=None; the validation and test
# datasets reuse the resulting train_dataset.max_length.
# NOTE(review): presumably SpamDataset derives max_length from the longest
# training message when given None — confirm in spam_dataset.
train_dataset = spam_dataset.SpamDataset(
    csv_file="train.csv", max_length=None, tokenizer=tokenizer
)
val_dataset, test_dataset = (
    spam_dataset.SpamDataset(
        csv_file=eval_csv, max_length=train_dataset.max_length, tokenizer=tokenizer
    )
    for eval_csv in ("validation.csv", "test.csv")
)

In [None]:
# Create DataLoaders for the training, validation, and test datasets.
num_workers = 0
batch_size = 8
torch.manual_seed(123)  # makes the shuffled order of the training batches reproducible

train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    drop_last=True,  # drop a trailing partial batch so every training batch has batch_size rows
)

val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    drop_last=False,
)

test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    drop_last=False,
)

# Sanity-check the batch shapes. A single batch is enough: iterating (loading
# and collating) the entire training set just to look at its last batch is
# wasted work.
input_batch, target_batch = next(iter(train_loader))

print("Input batch dimensions: ", input_batch.shape)
print("Label batch dimensions: ", target_batch.shape)
print()
print(f"{len(train_loader)} training batches")
print(f"{len(val_loader)} validation batches")
print(f"{len(test_loader)} testing batches")

In [None]:

# Copy OpenAI's pretrained GPT-2 weights into our model, then generate a short
# continuation as a sanity check of the loaded weights.
load_pretrained_weights.load_weights_into_gpt(gpt, params)
gpt.eval()

test_text = "Every effort moves you"
prompt_ids = pretrain_model.text_to_token_ids(test_text, tokenizer)
token_ids = pretrain_model.generate(
    model=gpt,
    index=prompt_ids,
    max_new_tokens=15,
    context_size=1024,  # matches the model's context_length
)
generated_text = pretrain_model.token_ids_to_text(token_ids, tokenizer)
print(generated_text)

In [None]:
# Ask the pretrained (not yet fine-tuned) model to classify a message purely
# through a natural-language instruction.
test_text = (
    "Is the following text 'spam'? Answer with 'yes' or 'no'."
    " 'You are a winner you have been specially"
    " selected to receive $1000 cash or a $2000 award.'"
)
prompt_ids = pretrain_model.text_to_token_ids(test_text, tokenizer)
token_ids = pretrain_model.generate(
    model=gpt, index=prompt_ids, max_new_tokens=23, context_size=1024
)
print(pretrain_model.token_ids_to_text(token_ids, tokenizer))

In [None]:
# Freeze the model (i.e., make all layers non-trainable).
for weight in gpt.parameters():
    weight.requires_grad_(False)

In [None]:
# Replace the output layer (gpt.out_head): instead of mapping the hidden states
# to 50,257 dimensions (the vocabulary size), map them to 2 classes
# (index 1 = spam, index 0 = not spam).
torch.manual_seed(123)  # reproducible initialization of the new head
num_classes = 2
gpt.out_head = torch.nn.Linear(768, num_classes)  # in_features = embedding dimension

In [None]:
# Make the last transformer block and the final LayerNorm trainable again so they
# are fine-tuned along with the new output head.
for module in (gpt.transformer_blocks[-1], gpt.final_norm):
    for param in module.parameters():
        param.requires_grad = True

In [None]:
def calc_accuracy_loader(data_loader, model, device, num_batches=None):
    """Compute classification accuracy over (a prefix of) a data loader.

    The logits at the last sequence position are used as class scores; the
    arg-max is compared against the target label.

    Args:
        data_loader: Sized iterable of ``(input_batch, target_batch)`` pairs.
        model: Classifier whose output is indexed as ``[:, -1, :]``, i.e.
            (batch, seq_len, num_classes). It is switched to eval mode.
        device: Device the batches are moved to before the forward pass.
        num_batches: Evaluate at most this many batches; ``None`` means all.

    Returns:
        Fraction of correctly classified examples, or ``nan`` when no example
        was evaluated (empty loader or ``num_batches <= 0``) — consistent with
        ``calc_loss_loader`` instead of raising ZeroDivisionError.
    """
    model.eval()
    correct_predictions, num_examples = 0, 0

    if num_batches is None:
        num_batches = len(data_loader)
    else:
        num_batches = min(num_batches, len(data_loader))

    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i >= num_batches:
            break
        input_batch = input_batch.to(device)
        target_batch = target_batch.to(device)

        with torch.no_grad():
            logits = model(input_batch)[:, -1, :]  # logits of the last output token
        predicted_labels = torch.argmax(logits, dim=-1)

        num_examples += predicted_labels.shape[0]
        correct_predictions += (predicted_labels == target_batch).sum().item()

    if num_examples == 0:
        return float("nan")
    return correct_predictions / num_examples

In [None]:
# Move the model to the GPU when one is available, then measure accuracy on the
# first 10 batches of each split before any fine-tuning.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
gpt.to(device)

torch.manual_seed(123)  # fixes the shuffled order of the training batches drawn below
train_accuracy = calc_accuracy_loader(train_loader, gpt, device, num_batches=10)
val_accuracy = calc_accuracy_loader(val_loader, gpt, device, num_batches=10)
test_accuracy = calc_accuracy_loader(test_loader, gpt, device, num_batches=10)

for split_name, accuracy in (
    ("Training", train_accuracy),
    ("Validation", val_accuracy),
    ("Test", test_accuracy),
):
    print(f"{split_name} accuracy: {accuracy*100:.2f}%")

In [None]:
def calc_loss_batch(input_batch, target_batch, model, device):
    """Return the cross-entropy loss of ``model`` on a single batch.

    Only the logits at the last sequence position are used as class scores.
    The returned tensor keeps its autograd graph unless called under no_grad.
    """
    input_batch = input_batch.to(device)
    target_batch = target_batch.to(device)
    logits = model(input_batch)[:, -1, :]  # logits of the last output token
    return torch.nn.functional.cross_entropy(logits, target_batch)

def calc_loss_loader(data_loader, model, device, num_batches=None):
    """Average ``calc_loss_batch`` over (a prefix of) a data loader.

    Args:
        data_loader: Sized iterable of ``(input_batch, target_batch)`` pairs.
        model: Classifier passed through to ``calc_loss_batch``.
        device: Device the batches are moved to.
        num_batches: Average over at most this many batches; ``None`` means all.

    Returns:
        Mean batch loss as a Python float, or ``nan`` when there is nothing
        to average (empty loader or ``num_batches <= 0``).
    """
    if num_batches is None:
        num_batches = len(data_loader)
    else:
        num_batches = min(num_batches, len(data_loader))
    if num_batches <= 0:
        # Previously 0 raised ZeroDivisionError and negatives returned -0.0.
        return float("nan")

    total_loss = 0.0
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i >= num_batches:
            break
        total_loss += calc_loss_batch(input_batch, target_batch, model, device).item()
    return total_loss / num_batches

In [None]:
# Initial losses before fine-tuning, estimated on the first 5 batches of each split.
with torch.no_grad():  # disable gradient tracking because we're not training yet
    train_loss, val_loss, test_loss = (
        calc_loss_loader(loader, gpt, device, num_batches=5)
        for loader in (train_loader, val_loader, test_loader)
    )

for split_name, split_loss in zip(
    ("Training", "Validation", "Test"), (train_loss, val_loss, test_loss)
):
    print(f"{split_name} loss: {split_loss:.3f}")

In [None]:
def evaluate_model(
    model,
    train_loader,
    val_loader,
    device,
    eval_iter
):
    """Estimate training and validation loss on the first ``eval_iter`` batches.

    The model is put in eval mode and gradient tracking is disabled while
    measuring; the model is switched back to train mode before returning.

    Returns:
        tuple: ``(train_loss, val_loss)``.
    """
    model.eval()
    with torch.no_grad():
        losses = tuple(
            calc_loss_loader(loader, model, device, num_batches=eval_iter)
            for loader in (train_loader, val_loader)
        )
    model.train()
    return losses

def train_classifier(
    model,
    train_loader,
    val_loader,
    optimizer,
    device,
    num_epochs,
    eval_freq,
    eval_iter
):
    """Fine-tune ``model`` as a classifier and record losses and accuracies.

    Each batch: reset gradients, compute the last-token cross-entropy loss,
    backpropagate, and step the optimizer. Every ``eval_freq`` optimizer steps
    (starting with the very first one) the train/validation loss is estimated
    on ``eval_iter`` batches and printed. After each epoch the train/validation
    accuracy is estimated on ``eval_iter`` batches and printed.

    Returns:
        tuple: ``(train_losses, val_losses, train_accs, val_accs, examples_seen)``.
    """
    train_losses, val_losses = [], []
    train_accs, val_accs = [], []
    examples_seen = 0
    global_step = -1  # incremented before the check, so step 0 is evaluated

    for epoch_idx in range(num_epochs):
        epoch_num = epoch_idx + 1
        model.train()

        for inputs, targets in train_loader:
            optimizer.zero_grad()  # clear gradients left over from the previous batch
            batch_loss = calc_loss_batch(inputs, targets, model, device)
            batch_loss.backward()
            optimizer.step()

            examples_seen += inputs.shape[0]
            global_step += 1

            if global_step % eval_freq != 0:
                continue
            tr_loss, va_loss = evaluate_model(model, train_loader, val_loader, device, eval_iter)
            train_losses.append(tr_loss)
            val_losses.append(va_loss)
            print(f"Epoch {epoch_num}, Step {global_step:06d}: "
                  f"Train loss: {tr_loss:.3f}, "
                  f"Validation loss: {va_loss:.3f}")

        epoch_train_acc = calc_accuracy_loader(train_loader, model, device, num_batches=eval_iter)
        epoch_val_acc = calc_accuracy_loader(val_loader, model, device, num_batches=eval_iter)
        print(f"Train accuracy: {epoch_train_acc*100:.2f}%")
        print(f"Validation accuracy: {epoch_val_acc*100:.2f}%")
        train_accs.append(epoch_train_acc)
        val_accs.append(epoch_val_acc)

    return train_losses, val_losses, train_accs, val_accs, examples_seen

In [None]:
import time

start_time = time.time()
torch.manual_seed(123)  # reproducible shuffling of the training batches
# All parameters are handed to AdamW; frozen ones never receive gradients.
optimizer = torch.optim.AdamW(gpt.parameters(), lr=5e-5, weight_decay=0.1)
num_epochs = 5

train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier(
    gpt,
    train_loader,
    val_loader,
    optimizer,
    device,
    num_epochs=num_epochs,
    eval_freq=50,  # estimate losses every 50 optimizer steps...
    eval_iter=5,   # ...using the first 5 batches of each loader
)

end_time = time.time()
execution_time_minutes = (end_time - start_time) / 60
print(f"Training completed in {execution_time_minutes:.2f} minutes.")

In [None]:
def classify(
    text,
    model,
    tokenizer,
    device,
    max_length=None,
    pad_token_id=50256
):
    """Classify ``text`` as "spam" or "not spam" with the fine-tuned model.

    The text is tokenized and truncated to the target length. It is then
    right-padded with ``pad_token_id`` to exactly that length, and the arg-max
    of the logits at the last position selects the class (index 1 = spam).

    Args:
        text: Message to classify.
        model: Fine-tuned classifier exposing ``position_embedding``, whose
            weight has one row per supported token position.
        tokenizer: Object with an ``encode(str) -> list[int]`` method.
        device: Device on which to run the forward pass.
        max_length: Length to truncate/pad to; should match the padded length
            used for training. It is capped at the model's supported context
            length. ``None`` means "use the full supported context length"
            (previously ``None`` crashed inside ``min``).
        pad_token_id: Padding token id (50256 is GPT-2's ``<|endoftext|>``).

    Returns:
        "spam" if the predicted class index is 1, otherwise "not spam".
    """
    model.eval()

    # Assumes position_embedding is an nn.Embedding(context_length, emb_dim):
    # its weight is (context_length, emb_dim), so the supported context length
    # is dim 0. The old code read dim 1, i.e. the embedding width.
    supported_context_length = model.position_embedding.weight.shape[0]
    if max_length is None:
        target_length = supported_context_length
    else:
        target_length = min(max_length, supported_context_length)

    input_ids = tokenizer.encode(text)[:target_length]  # truncate if too long
    # Pad only up to target_length so the sequence never exceeds the context.
    input_ids += [pad_token_id] * (target_length - len(input_ids))

    input_tensor = torch.tensor(input_ids, device=device).unsqueeze(0)  # add batch dim

    with torch.no_grad():
        logits = model(input_tensor)[:, -1, :]  # logits of the last output token

    predicted_label = torch.argmax(logits, dim=-1).item()
    return "spam" if predicted_label == 1 else "not spam"

In [None]:
# A message that looks like typical spam; padded/truncated to the training length.
text1 = (
    "You are a winner you have been specially"
    " selected to receive $1000 cash or a $2000 award"
)
label1 = classify(text1, gpt, tokenizer, device, max_length=train_dataset.max_length)
print(label1)

In [None]:
# An ordinary personal message; padded/truncated to the training length.
text2 = (
    "Hey, just wanted to check if we're still on"
    " for dinner tonight? Let me know!"
)
label2 = classify(text2, gpt, tokenizer, device, max_length=train_dataset.max_length)
print(label2)

In [None]:
# Save the fine-tuned weights. The filename was misspelled "classifer.pth" and
# did not match the "classifier.pth" used when loading the model.
torch.save(gpt.state_dict(), "classifier.pth")

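To load the model later, first create `model` with the same architecture (including the 2-class `out_head`), then load the saved weights: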

model_state_dict = torch.load("classifier.pth", map_location=device)
model.load_state_dict(model_state_dict)