In [1]:
import urllib.request
import zipfile
from pathlib import Path

url = "https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip"

zip_path = Path("sms_spam_collection.zip")
extract_path = Path("sms_spam_collection")
data_file = extract_path / "SMSSpamCollection.tsv"


def download_and_extract_data(url, zip_path, extract_path):
    if not zip_path.exists():
        print("Downloading dataset...")
        urllib.request.urlretrieve(url, zip_path)
        print("Download complete.")
    else:
        print("Dataset already downloaded.")

    if not extract_path.exists():
        print("Extracting dataset...")
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(extract_path)
        print("Extraction complete.")
    else:
        print("Dataset already extracted.")


download_and_extract_data(url, zip_path, extract_path)

Dataset already downloaded.
Dataset already extracted.


# Use Spam data to fine-tune a GPT-2 model for text classification


In [2]:
import tiktoken

tokenizer = tiktoken.get_encoding("gpt2")

In [3]:
import pandas as pd

data_file = "./sms_spam_collection/SMSSpamCollection"

df = pd.read_csv(data_file, sep="\t", header=None, names=["label", "text"])
df.head()

Unnamed: 0,label,text
0,ham,"Go until jurong point, crazy.. Available only ..."
1,ham,Ok lar... Joking wif u oni...
2,spam,Free entry in 2 a wkly comp to win FA Cup fina...
3,ham,U dun say so early hor... U c already then say...
4,ham,"Nah I don't think he goes to usf, he lives aro..."


In [4]:
# we can see that the dataset is imbalanced, with more "ham" than "spam" messages

def balance_dataset(df):
    spam = df[df["label"] == "spam"]
    ham = df[df["label"] == "ham"].sample(n=len(spam), random_state=42)
    balanced_df = pd.concat([spam, ham]).sample(frac=1, random_state=42).reset_index(drop=True)
    return balanced_df

balanced_df = balance_dataset(df)

In [5]:
balanced_df["label"] = balanced_df["label"].map({"ham": 0, "spam": 1})
balanced_df.head()

Unnamed: 0,label,text
0,0,The evo. I just had to download flash. Jealous?
1,0,Hi Dear Call me its urgnt. I don't know whats ...
2,0,Full heat pa:-) i have applyed oil pa.
3,0,Gokila is talking with you aha:)
4,0,"Dude u knw also telugu..thts gud..k, gud nyt.."


In [6]:
def splt_dataset(df, train_frac=0.7, val_frac=0.1):
    df = df.sample(frac=1, random_state=42).reset_index(drop=True)  # Shuffle the dataset
    train_end = int(len(df) * train_frac)
    val_end = int(len(df) * (train_frac + val_frac))
    train_df = df[:train_end]
    val_df = df[train_end:val_end]
    test_df = df[val_end:]
    return train_df, val_df, test_df

train_df, val_df, test_df = splt_dataset(balanced_df)

print(f"Train size: {len(train_df)}, Val size: {len(val_df)}, Test size: {len(test_df)}")
    

Train size: 1045, Val size: 150, Test size: 299


In [7]:
from pathlib import Path

dataset_path = Path("../data/sms_spam")
dataset_path.mkdir(parents=True, exist_ok=True)
train_df.to_csv(dataset_path / "train.csv", index=False)
val_df.to_csv(dataset_path / "val.csv", index=False)
test_df.to_csv(dataset_path / "test.csv", index=False)

In [8]:
import torch
from torch.utils.data import Dataset, DataLoader

class SpamDataset(Dataset):
    def __init__(self, dataset_path, tokenizer, max_length=None, pad_token_id=50_256):
        super().__init__()
        self.data = pd.read_csv(dataset_path)
        self.tokenizer = tokenizer
        self.encoded_texts = [
            tokenizer.encode(text) for text in self.data["text"].values
        ]

        self.labels = self.data["label"].values

        if max_length is None:
            self.max_length = max(len(ids) for ids in self.encoded_texts)
        else:
            self.max_length = max_length
            self.encoded_texts = [
                ids[:max_length] for ids in self.encoded_texts
            ]

        self.encoded_texts = [
            ids + [pad_token_id] * (self.max_length - len(ids))
            for ids in self.encoded_texts
        ]

    def __len__(self):
        return len(self.encoded_texts)
    
    def __getitem__(self, idx):
        encoded_text = torch.tensor(self.encoded_texts[idx], dtype=torch.long)
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return encoded_text, label
        

In [9]:
train_dataset = SpamDataset(dataset_path / "train.csv", tokenizer)
val_dataset = SpamDataset(dataset_path / "val.csv", tokenizer, max_length=train_dataset.max_length)
test_dataset = SpamDataset(dataset_path / "test.csv", tokenizer, max_length=train_dataset.max_length)

len(train_dataset), len(val_dataset), len(test_dataset)

(1045, 150, 299)

In [10]:
dl_train = DataLoader(train_dataset, batch_size=8, shuffle=True, drop_last=True)
dl_val = DataLoader(val_dataset, batch_size=8, shuffle=False)
dl_test = DataLoader(test_dataset, batch_size=8, shuffle=False)



# Load the model

In [11]:
from llm_from_papers.models import GPT2Model

torch.manual_seed(42)

config = {
    "vocab_size": 50257,
    "context_size": 1024,
    "embed_dim": 768,
    "num_heads": 12,
    "num_layers": 12,
    "dropout": 0.1,
    "qkv_bias": True,
}

model = GPT2Model(config)
model.load_state_dict(torch.load("gpt2_124M.pth")['model_weights'])

for param in model.parameters():
    param.requires_grad = False


In [12]:
model.eval()

model.lm_head = torch.nn.Linear(config["embed_dim"], 2)  # 2 classes: ham and spam

In [13]:
for param in model.final_norm.parameters():
    param.requires_grad = True

for param in model.transfomer_blocks[-1].parameters():
    param.requires_grad = True


In [14]:
# count non trainable parameters
non_trainable_params = sum(p.numel() for p in model.parameters() if not p.requires_grad)
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Non-trainable parameters: {non_trainable_params / 1e6:.2f}M")
print(f"Trainable parameters: {trainable_params/1e6:.2f}M")

Non-trainable parameters: 117.35M
Trainable parameters: 7.09M


In [None]:
from tqdm import tqdm

def train_model(model, optimizer, dl_train, dl_val, device, nb_epoch):
    train_res, val_res = [], []
    criterion = torch.nn.CrossEntropyLoss()
    model.to(device)
    nb_examples, global_step = 0, 0

    for epoch in range(nb_epoch):
        model.train()
        
        for batch in tqdm(dl_train):
            inputs, labels = batch
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            pred = model(inputs)  # (B, C)

            # We use the logits of the last token for classification
            logits = pred[:, -1, :]

            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()

            nb_examples += inputs.size(0)
            global_step += 1

            if global_step % 100 == 0:
                print(
                    f"Epoch {epoch+1}/{nb_epoch} - Step {global_step} - Loss: {loss.item():.4f} - Examples seen: {nb_examples}"
                )

        model.eval()
        correct, total = 0, 0
        val_loss = 0.0
        with torch.no_grad():
            for batch in tqdm(dl_val):
                inputs, labels = batch
                inputs, labels = inputs.to(device), labels.to(device)

                pred = model(inputs)  # (B, C)
                logits = pred[:, -1, :]

                loss = criterion(logits, labels)
                val_loss += loss.item()

                _, predicted = torch.max(logits, dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        val_loss /= len(dl_val)
        val_accuracy = correct / total
        val_res.append((val_loss, val_accuracy))

        print(
            f"Epoch {epoch+1}/{nb_epoch} - Validation Loss: {val_loss:.4f} - Validation Accuracy: {val_accuracy:.4f}"
        )

    return train_res, val_res

import torch.optim as optim

optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4, weight_decay=0.1)

device = "cpu"

if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"

train_res, val_res = train_model(model, optimizer, dl_train, dl_val, device, nb_epoch=5)


 78%|███████▊  | 101/130 [00:09<00:02, 10.07it/s]

Epoch 1/5 - Step 100 - Loss: 0.5670 - Examples seen: 800


100%|██████████| 130/130 [00:12<00:00, 10.12it/s]


Epoch 1/5 - Validation Loss: 0.7231 - Validation Accuracy: 0.4533


 55%|█████▌    | 72/130 [00:07<00:05, 10.68it/s]

Epoch 2/5 - Step 200 - Loss: 0.2266 - Examples seen: 1600


100%|██████████| 130/130 [00:12<00:00, 10.20it/s]


Epoch 2/5 - Validation Loss: 0.5900 - Validation Accuracy: 0.7200


 32%|███▏      | 42/130 [00:04<00:08, 10.47it/s]

Epoch 3/5 - Step 300 - Loss: 0.6567 - Examples seen: 2400


100%|██████████| 130/130 [00:12<00:00, 10.22it/s]


Epoch 3/5 - Validation Loss: 0.3688 - Validation Accuracy: 0.8333


  9%|▉         | 12/130 [00:01<00:10, 10.90it/s]

Epoch 4/5 - Step 400 - Loss: 0.1018 - Examples seen: 3200


 86%|████████▌ | 112/130 [00:10<00:01, 10.68it/s]

Epoch 4/5 - Step 500 - Loss: 0.3421 - Examples seen: 4000


100%|██████████| 130/130 [00:12<00:00, 10.26it/s]


Epoch 4/5 - Validation Loss: 0.4604 - Validation Accuracy: 0.8067


 63%|██████▎   | 82/130 [00:07<00:04, 10.39it/s]

Epoch 5/5 - Step 600 - Loss: 0.5139 - Examples seen: 4800


100%|██████████| 130/130 [00:12<00:00, 10.26it/s]


Epoch 5/5 - Validation Loss: 0.3225 - Validation Accuracy: 0.8467


In [16]:
# look at test accuracy
model.eval()
correct, total = 0, 0
with torch.no_grad():
    for batch in dl_test:
        inputs, labels = batch
        inputs, labels = inputs.to(device), labels.to(device)

        pred = model(inputs)  # (B, C)
        logits = pred[:, -1, :]

        _, predicted = torch.max(logits, dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
test_accuracy = correct / total
print(f"Test Accuracy: {test_accuracy:.4f}")

Test Accuracy: 0.8696


In [17]:
def predict_spam(text, model=model, tokenizer=tokenizer, device=device):

    inputs = tokenizer.encode(text)
    inputs = torch.tensor(inputs, dtype=torch.long).unsqueeze(0).to(device)
    with torch.no_grad():
        pred = model(inputs)
        logits = pred[:, -1, :]
        predicted_class = torch.argmax(logits, dim=1).item()
        print(f"Message: '{text}' - Predicted class: {'spam' if predicted_class == 1 else 'ham'}")

In [18]:
not_a_spam = "Hey, are we still on for lunch tomorrow?"
spam = "Congratulations! You've won a $1000 Walmart gift card. Click here to claim now."

predict_spam(not_a_spam)
predict_spam(spam)



Message: 'Hey, are we still on for lunch tomorrow?' - Predicted class: ham
Message: 'Congratulations! You've won a $1000 Walmart gift card. Click here to claim now.' - Predicted class: spam


In [19]:
# is it working in french?
not_a_spam_fr = "Salut, on est toujours bon pour le déjeuner demain ?"
spam_fr = "Félicitations ! Vous avez gagné une carte cadeau Walmart de 1000 $. Cliquez ici pour la réclamer maintenant."
predict_spam(not_a_spam_fr)
predict_spam(spam_fr)

# no it is not working in french. We would need to add french examples to the training dataset.

Message: 'Salut, on est toujours bon pour le déjeuner demain ?' - Predicted class: spam
Message: 'Félicitations ! Vous avez gagné une carte cadeau Walmart de 1000 $. Cliquez ici pour la réclamer maintenant.' - Predicted class: spam
