# Use a foundation model to build a spam email classifier
A foundation model serves as a fundamental building block for potentially endless applications. One of them is the development of a spam email classifier using only the prompt. By leveraging the capabilities of a foundation model, this project aims to accurately identify and filter out unwanted emails.

Tasks:
1. Identify and gather relevant data
2. Build and evaluate the spam email classifier
3. Build an improved classifier?

# Step 1: Identify and gather relevant data
To train and test the spam email classifier, you will need a dataset of emails that are labeled as spam or not. It's important to identify and gather a suitable dataset that represents a wide range of spam and non-spam emails.

In [None]:
# find a spam dataset from huggingface at https://huggingface.co/datasets and load it using the datasets library

from datasets import load_dataset

dataset = load_dataset("sms_spam", split="train")

for entry in dataset.select(range(3)):
    sms = entry["sms"]
    label = entry["label"]
    print(f"label={label}, sms={sms}")
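
Before prompting, it can help to check how balanced the labels are, since accuracy alone is easy to misread on a skewed dataset. The sketch below is a minimal example: `label_distribution` is a hypothetical helper, and `toy_labels` is an invented stand-in for the label column, which you would presumably pass as `dataset["label"]` with the real data.

```python
from collections import Counter

def label_distribution(labels):
    """Return the fraction of entries carrying each label ID."""
    counts = Counter(labels)
    total = sum(counts.values())
    return {label: count / total for label, count in counts.items()}

# Toy stand-in for the label column (assumption: 0 = NOT SPAM, 1 = SPAM)
toy_labels = [0, 0, 0, 1, 0, 0, 1, 0]
distribution = label_distribution(toy_labels)
print(distribution)
```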

Create dictionaries to convert between numerical label IDs and labels

In [None]:
# Dictionaries to convert between labels and IDs
id2label = {0: "NOT SPAM", 1: "SPAM"}
label2id = {"NOT SPAM": 0, "SPAM": 1}

for entry in dataset.select(range(3)):
    sms = entry["sms"]
    label = entry["label"]
    print(f"label={id2label[label]}, sms={sms}")

# Step 2: Build and evaluate the spam email classifier
Using the foundation model and the prepared dataset, you can create a spam email classifier.
Write a prompt that asks the model to classify a small batch of messages as either "spam" or "not spam". For easier parsing, we can ask the LLM to respond in JSON.

In [None]:
# Start with a helper function to format SMS messages for the LLM

# include_labels toggles whether the output includes the spam/not spam label
def get_sms_messages_string(dataset, item_numbers, include_labels=False):
    sms_message_string = ""
    for item_number, entry in zip(item_numbers, dataset.select(item_numbers)):
        sms = entry["sms"]
        label_id = entry["label"]

        if include_labels:
            sms_message_string += (
                f"{item_number} (label={id2label[label_id]}) -> {sms}\n"
            )
        else:
            sms_message_string += f"{item_number} -> {sms}\n"
    return sms_message_string

print(get_sms_messages_string(dataset, range(3), include_labels=True))

Now write code that will produce the prompt. The prompt should include a few SMS messages to be labeled as well as instructions for the LLM.

Some LLMs will also format the output as JSON if requested, e.g. "Respond in JSON format".

In [None]:
data_range = range(7,15)

# Get a few messages and format them as a string
sms_messages_string = get_sms_messages_string(dataset, data_range) # By default, include_labels is False

# Construct a query to send to the LLM including the sms messages.
# Ask that it respond in JSON format

query =f"""
Look at the following messages and determine whether they are spam or not. Respond in JSON using the following format: {{"0": "NOT SPAM", "1": "SPAM"}}.
\nMessages:\n
{sms_messages_string}
"""
print(query)
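
An LLM's reply usually arrives as raw text, sometimes wrapped in extra prose or a Markdown code fence, so it needs to be turned into a dictionary before it can be scored. The following is a minimal sketch of one way to do that with the standard library; `parse_llm_json` is a hypothetical helper and `raw_reply` is an invented example reply, not real model output.

```python
import json
import re

def parse_llm_json(raw_text):
    """Extract a JSON object from an LLM reply and return it as a dict."""
    # Grab everything from the first "{" to the last "}" so that surrounding
    # prose or Markdown code fences are ignored.
    match = re.search(r"\{.*\}", raw_text, flags=re.DOTALL)
    if match is None:
        raise ValueError("No JSON object found in the reply")
    return json.loads(match.group(0))

# Invented reply that wraps the JSON in a Markdown fence
fence = "`" * 3
raw_reply = f'Here you go:\n{fence}json\n{{"7": "NOT SPAM", "8": "SPAM"}}\n{fence}'
parsed = parse_llm_json(raw_reply)
print(parsed)
```

If the reply contains malformed JSON, `json.loads` raises `json.JSONDecodeError`, which is a reasonable signal to re-run the prompt.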

In [None]:
# Use the above prompt in an LLM to get a response

# Accurate predictions
# response = {
#   "7": "NOT SPAM",
#   "8": "SPAM",
#   "9": "SPAM",
#   "10": "NOT SPAM",
#   "11": "SPAM",
#   "12": "SPAM",
#   "13": "NOT SPAM",
#   "14": "NOT SPAM"
# }

# Less accurate predictions
response = {
    "7": "NOT SPAM",
    "8": "SPAM",
    "9": "SPAM",
    "10": "NOT SPAM",
    "11": "SPAM",
    "12": "SPAM",
    "13": "NOT SPAM",
    "14": "SPAM",
}

# Estimate the accuracy of the classifier by comparing the response to the labels in the dataset
def get_accuracy(response, dataset, original_indices):
    correct = 0
    total = 0

    # iterate over the responses and store the entry number and prediction label in variables
    # if the entry number is not in the range of SMS messages passed into the LLM, skip 
    for entry_number, prediction in response.items():
        if int(entry_number) not in original_indices:
            continue

        # Get the ID from the entry's label, convert string to an integer
        label_id = dataset[int(entry_number)]["label"]
        # Get the label for the ID
        label = id2label[label_id]

        # Print every comparison, then count it only if the prediction matches
        print(f"prediction: {prediction.lower()} || Actual: {label.lower()}")
        if prediction.lower() == label.lower():
            correct += 1

        total += 1

    try:
        accuracy = correct/total
    except ZeroDivisionError:
        print("No matching results found!")
        return
    
    return round(accuracy, 2)

print(f"Accuracy: {get_accuracy(response, dataset, data_range)}")
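
Accuracy treats every mistake the same, but for spam filtering it often matters whether the errors are missed spam or legitimate messages flagged as spam. As a rough sketch, precision and recall for the "SPAM" class can be computed from the same kind of prediction dictionary. `precision_recall` is a hypothetical helper, and the toy dictionaries below are invented for illustration rather than taken from the dataset.

```python
def precision_recall(predictions, labels, positive="SPAM"):
    """Compute precision and recall for the positive class.

    predictions and labels are dicts mapping message index -> label string.
    """
    tp = fp = fn = 0
    for key, predicted in predictions.items():
        actual = labels[key]
        if predicted == positive and actual == positive:
            tp += 1  # spam correctly flagged
        elif predicted == positive:
            fp += 1  # legitimate message flagged as spam
        elif actual == positive:
            fn += 1  # spam that slipped through
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    return precision, recall

# Invented toy data: one true positive, one false positive, one false negative
toy_labels = {"0": "SPAM", "1": "NOT SPAM", "2": "SPAM", "3": "NOT SPAM"}
toy_predictions = {"0": "SPAM", "1": "SPAM", "2": "NOT SPAM", "3": "NOT SPAM"}
precision, recall = precision_recall(toy_predictions, toy_labels)
print(f"precision={precision}, recall={recall}")
```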

# Step 3: Build an improved classifier?
If you provide the LLM with some examples of how to complete a task, it will sometimes perform better. Let's try that.

In [None]:
# Get a few labeled messages and format them as a string
labeled_range = range(54,60)

sms_messages_string_w_labels = get_sms_messages_string(dataset, labeled_range, include_labels=True)

# Get a few unlabeled messages and format them as a string
unlabeled_range = range(7,15)
sms_message_string_no_labels = get_sms_messages_string(dataset, unlabeled_range) # include_labels is False by default

# Construct a query to send to the LLM including the labeled messages as well as the unlabeled.
# Ask it to respond in JSON format
query = f"""
Use the labeled message samples to determine if the unlabeled messages are spam or not. Respond in JSON using the following format: {{"0": "NOT SPAM", "1": "SPAM"}}.
\nUnlabeled Messages:\n
{sms_message_string_no_labels}

\nLabeled Messages:\n
{sms_messages_string_w_labels}
"""

print(query)
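
A fixed slice of labeled messages may happen to contain mostly one class, which gives the model few examples of the other. One possible refinement, sketched here under that assumption, is to pick the same number of examples per label. `pick_balanced_examples` is a hypothetical helper and `toy_entries` is invented data shaped like the dataset entries (`"sms"` and `"label"` keys).

```python
def pick_balanced_examples(entries, per_label=2):
    """Return indices of the first `per_label` entries found for each label ID."""
    picked = {}
    for index, entry in enumerate(entries):
        bucket = picked.setdefault(entry["label"], [])
        if len(bucket) < per_label:
            bucket.append(index)
    # Flatten the per-label buckets into one sorted list of indices
    return sorted(i for bucket in picked.values() for i in bucket)

# Invented toy entries (assumption: 0 = NOT SPAM, 1 = SPAM)
toy_entries = [
    {"sms": "See you at lunch", "label": 0},
    {"sms": "Running late", "label": 0},
    {"sms": "Call me later", "label": 0},
    {"sms": "WIN a free prize now", "label": 1},
    {"sms": "Ok", "label": 0},
    {"sms": "Claim your reward today", "label": 1},
]
example_indices = pick_balanced_examples(toy_entries, per_label=2)
print(example_indices)
```

The resulting list of indices could then be passed to `get_sms_messages_string` with `include_labels=True` in place of `labeled_range`.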

In [None]:
# Accurate response
# {
#   "7": "NOT SPAM",
#   "8": "SPAM",
#   "9": "SPAM",
#   "10": "NOT SPAM",
#   "11": "SPAM",
#   "12": "SPAM",
#   "13": "NOT SPAM",
#   "14": "NOT SPAM"
# }

# Less accurate response
response = {
  "7": "SPAM",
  "8": "SPAM",
  "9": "SPAM",
  "10": "NOT SPAM",
  "11": "SPAM",
  "12": "SPAM",
  "13": "SPAM",
  "14": "NOT SPAM"
}

# Check the accuracy for the unlabeled messages
print(f"Accuracy: {get_accuracy(response, dataset, unlabeled_range)}")

In [None]:
# Show the misclassified items
def print_misclassified_messages(response, dataset):
    for entry_number, prediction in response.items():
        # Get the ID from the entry's label, convert string to an integer
        label_id = dataset[int(entry_number)]["label"]
        # Get the label for the ID
        label = id2label[label_id]

        if prediction.lower() != label.lower():
            sms = dataset[int(entry_number)]["sms"] # get the message
            print(f"---\nMessage: {sms} \nLabel: {label} || Prediction: {prediction}\n---\n")

# The function prints its own output and returns None, so call it directly
print_misclassified_messages(response, dataset)