In [None]:
import csv
import os
import sqlite3

import numpy as np
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline

# Google Colab quirk: after installing, restart the session, then run the notebook again without reinstalling the packages.
# %pip install -q torch transformers accelerate bitsandbytes sentence-transformers
# %pip install -q --upgrade transformers

In [None]:
from google.colab import drive, userdata

# Mount Google Drive so the mail archive and the results CSV under MyDrive are reachable.
drive.mount('/content/drive')

# Hugging Face access token, read from the Colab secret named 'HF_TOKEN'.
HF_TOKEN = userdata.get('HF_TOKEN')
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"

In [None]:
# Load the tokenizer and a 4-bit quantized copy of the model:
# NF4 weights, double quantization, bfloat16 compute dtype.
tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    token=HF_TOKEN,
    quantization_config=bnb_config,
    low_cpu_mem_usage=True,
)

In [None]:
# Build the text-generation pipeline on top of the already-loaded 4-bit model and its
# tokenizer. Passing `model_id` here instead would load a second, unquantized bfloat16
# copy of the model into memory alongside the quantized one.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)

In [None]:
!unzip /content/drive/MyDrive/DDD/mail.zip

# Data notes
- Some email bodies are `None` (missing).
- Some inputs are very long; available RAM allows prompts of up to roughly 5,000 tokens.

In [None]:
def load_sqlite_data(labeled_only=True, db_dir="/content"):
    """Load mails from the extracted SQLite databases into a DataFrame.

    Args:
        labeled_only: If True, read ``mails.sqlite`` and keep only mails that have a
            row in both the ``gpt`` and ``gt`` tables, adding their ``is_spam`` values
            as ``is_spam_gpt`` and ``is_spam_gt``. If False, read every row of the
            ``mails`` table in ``mailinator.sqlite``.
        db_dir: Directory containing the ``.sqlite`` files.

    Returns:
        A DataFrame in which missing ``body`` and ``subject`` values are replaced by "".

    Raises:
        FileNotFoundError: If the expected database file does not exist.
    """
    if labeled_only:
        db_path = os.path.join(db_dir, "mails.sqlite")
        query = """
        SELECT mails.*, gpt.is_spam as is_spam_gpt, gt.is_spam as is_spam_gt
        FROM mails
        JOIN gpt ON gpt.id = mails.id
        JOIN gt ON gt.id = mails.id
        """
    else:
        db_path = os.path.join(db_dir, "mailinator.sqlite")
        query = "SELECT * FROM mails"

    # sqlite3.connect silently creates an empty database for a missing path, which
    # would only surface later as a confusing "no such table" error.
    if not os.path.exists(db_path):
        raise FileNotFoundError(f"SQLite database not found: {db_path}")

    conn = sqlite3.connect(db_path)
    try:
        mails = pd.read_sql_query(query, conn)
    finally:
        conn.close()

    # Build new columns rather than chained `fillna(inplace=True)`, which pandas
    # deprecates and which does not modify the frame under Copy-on-Write.
    return mails.assign(
        body=mails["body"].fillna(""),
        subject=mails["subject"].fillna(""),
    )


In [None]:
df_mails = load_sqlite_data(labeled_only=True)
# Fail fast here rather than deep inside the labeling loop if the join produced no rows.
assert len(df_mails) > 0, "No labeled mails loaded - check that mail.zip was extracted to /content"

In [None]:
# Earlier spam / not-spam prompt draft; not referenced in the visible notebook.
# NOTE(review): its truncation description (2500 characters, <TRUNCATED>) does not match
# mail_to_message, which trims bodies longer than 4096 characters and inserts <TRIMMED>.
tmp = '''
You will be given the sender's name, sender's email address, email subject and email body. Hyperlinks that were a part of the email body are replaced with <LINK>.
If the email body is too long, it will be cut-off after 2500 characters, followed by <TRUNCATED>.
You will label the email after <<<>>>. If the email is spam, respond only with "true", otherwise respond only with "false". Do not provide any explanations or notes.'''


In [None]:
# System prompt for the relevance classifier. Every few-shot Input/Output is valid JSON,
# and the Outputs use lowercase true/false so they match the required reply format.
system_prompt = '''
You are an email classification assistant. Your task is to label emails as either "relevant" or "irrelevant" based on the following criteria:

Irrelevant Emails:
- Spam
- Advertisements
- News articles
- Newsletters
- Unsolicited workshop invitations
- Mass mailings

Relevant Emails:
- Account-related notifications
- System notifications intended for website administrators
- Invoices
- Emails containing personal information of the recipient
- Emails with security-sensitive data, such as usernames, passwords, secret tokens, or online meeting access codes

Classification Rules:
1. Emails with direct references to existing accounts, such as "update your profile," "see your account," or "log in," are classified as relevant.
2. Emails that reference the recipient's previous visits, usage, purchases, or registrations on a website are classified as relevant.
3. Emails that only invite recipients to create an account, without referring to an existing account, are classified as irrelevant.
4. Emails mentioning newsletter subscriptions with phrases like "unsubscribe" or "manage email preferences," but without indicating the recipient has an existing account, are classified as irrelevant.

Examples:

Input:
{
    "sender": "notifications-noreply@linkedin.com",
    "sender_name": "LinkedIn",
    "subject": "You appeared in 1 search. Review who searched you.",
    "body": "diana, this is what you’ve missed on LinkedIn. See all searches: <LINK>. See your profile: <LINK> Unsubscribe: <LINK>."
}

Output:
{"relevant": true}


Input:
{
    "sender": "wordpress@roskus.es",
    "sender_name": "WordPress",
    "subject": "[Wordfence Alert] Problems found on roskus.es",
    "body": "Some of your plug-ins are out of date <LINK>!"
}

Output:
{"relevant": true}


Input:
{
    "sender": "invoicing@squareupsandbox.com",
    "sender_name": "Owebest",
    "subject": "You paid an invoice! (#001508)",
    "body": "Hello John wick, You have paid invoice #001508 from Owebest Test for $49.00. Give us your feedback on your recent purchase!"
}

Output:
{"relevant": true}


Input:
{
    "sender": "noreply@sugarbabes.com",
    "sender_name": "Sugar Babes",
    "subject": "See who messaged you!",
    "body": "These sugar babes searched for your profile <LINK>. Log in <LINK>. Unsubscribe <LINK>."
}

Output:
{"relevant": true}


Input:
{
    "sender": "enews@choicehomewarranty.com",
    "sender_name": "CHOICE Warranty",
    "subject": "LUCKY YOU! Early Access Is Yours, Dudely",
    "body": "Hi Dudely, Click now to score this exclusive St. Patty's home warranty deal <LINK>. Sign-up <LINK>."
}

Output:
{"relevant": false}


Input:
{
    "sender": "news@bbcnews.com",
    "sender_name": "BBC",
    "subject": "News for This Week",
    "body": "New scientific discoveries <LINK>. Crime is going down <LINK>. Unsubscribe <LINK>"
}

Output:
{"relevant": false}


Provide your response in the JSON format `{"relevant": true}` or `{"relevant": false}`. Reply only with the JSON label. Do not provide any explanation.
'''.strip()

In [None]:
SENDER_LIMIT = 512
SUBJECT_LIMIT = 1024
BODY_LIMIT = 4096
# Head/tail length kept around the <TRIMMED> marker; derived so the two limits stay in sync.
BODY_LIMIT_HALF = BODY_LIMIT // 2


def _mail_field(mail, key, limit=None):
    """Return ``mail[key]`` as a string, or ``''`` when it is missing or empty.

    Missing covers ``None`` and NaN-like values (checked with ``pd.isna``). NaN is
    truthy, so a plain ``if value`` test would let it through and slicing would fail.
    When ``limit`` is given, only the first ``limit`` characters are kept.
    """
    value = mail[key]
    if pd.isna(value) or not value:
        return ''
    text = str(value)
    return text if limit is None else text[:limit]


def mail_to_message(mail):
    """Render one mail as the plain-text user prompt for the classifier.

    Args:
        mail: A mapping (e.g. a DataFrame row) with ``sender``, ``sender_name``,
            ``subject`` and ``body`` keys. Missing values render as empty strings.

    Returns:
        One ``key: value`` line per field. Sender fields are capped at SENDER_LIMIT
        characters and the subject at SUBJECT_LIMIT. A body longer than BODY_LIMIT is
        reduced to its first and last BODY_LIMIT_HALF characters joined by ``<TRIMMED>``.
    """
    sender = _mail_field(mail, 'sender', SENDER_LIMIT)
    sender_name = _mail_field(mail, 'sender_name', SENDER_LIMIT)
    subject = _mail_field(mail, 'subject', SUBJECT_LIMIT)
    body = _mail_field(mail, 'body')
    if len(body) > BODY_LIMIT:
        # Keep both the head and the tail of long bodies.
        body = f'{body[:BODY_LIMIT_HALF]} <TRIMMED> {body[-BODY_LIMIT_HALF:]}'

    return (
        f'sender: {sender}\n'
        f'sender_name: {sender_name}\n'
        f'subject: {subject}\n'
        f'body: {body}'
    )

def label_email(system_prompt, user_prompt, history=None, max_new_tokens=50):
    """Ask the loaded chat model to label one email and return its raw reply.

    Args:
        system_prompt: Instructions placed in the ``system`` message.
        user_prompt: The email rendered as text; sent as the final ``user`` message.
        history: Optional few-shot turns. Each item is a dict with ``'user'`` and
            ``'assistant'`` strings, inserted in order before ``user_prompt``.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded generated text only (prompt tokens excluded, special tokens skipped).
    """
    messages = [{"role": "system", "content": system_prompt}]
    for turn in history or []:
        messages.append({"role": "user", "content": turn['user']})
        messages.append({"role": "assistant", "content": turn['assistant']})
    messages.append({"role": "user", "content": user_prompt})

    # Put the inputs on the model's own device instead of assuming "cuda".
    input_ids = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            # A single unpadded sequence: every position is real input.
            attention_mask=torch.ones_like(input_ids),
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.eos_token_id,
        )
    # generate() returns prompt + continuation; decode only the continuation.
    return tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)


In [None]:
res_filepath = "/content/drive/MyDrive/DDD/mail_labels.csv"
FLUSH_EVERY = 40  # append buffered predictions to the CSV every N mails


def _append_rows(path, rows):
    """Append ``(id, label)`` rows to the CSV at ``path``; no-op when ``rows`` is empty.

    Uses ``csv.writer`` so labels containing commas, quotes or newlines (raw model
    output) are quoted instead of corrupting the file.
    """
    if not rows:
        return
    with open(path, 'a', newline='', encoding='utf-8') as file:
        csv.writer(file).writerows(rows)


res = []
file_batch = []

# NOTE(review): the CSV is opened in append mode, so re-running this cell appends a
# second copy of every label; delete the file first or de-duplicate on read.
for i, (_, mail) in enumerate(df_mails.iterrows()):
    user_prompt = mail_to_message(mail)
    pred_label = label_email(system_prompt, user_prompt)

    res.append((mail.id, pred_label))
    file_batch.append((mail.id, pred_label))
    # f-string rather than `+`: is_spam_gt need not be a str.
    print(f"{i}:{pred_label}\t{mail.is_spam_gt}")

    if i > 0 and i % FLUSH_EVERY == 0:
        _append_rows(res_filepath, file_batch)
        file_batch = []

# The in-loop check only fires on multiples of FLUSH_EVERY; write the remaining tail.
_append_rows(res_filepath, file_batch)
file_batch = []