### Imports & Function Definitions

In [2]:
import os
import pickle

# Change directory to project root.
# os.chdir with a relative path is resolved against the *current* working
# directory, so re-running this cell would climb two more levels every time.
# The starting directory is therefore recorded on first execution and the
# project root is always resolved relative to it, making the cell idempotent.
if "NOTEBOOK_START_DIR" not in globals():
    NOTEBOOK_START_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_START_DIR, "..", ".."))

# Function to load model and vectorizer
def load_model_and_vectorizer(model_path, vectorizer_path):
    """
    Unpickle a model and a vectorizer from two files on disk.

    The model file is read first, then the vectorizer file; each file is
    closed as soon as its object has been deserialized.

    Note: pickle can execute arbitrary code while loading, so only point this
    at artifact files you trust.

    Args:
        model_path: Path to the pickled model.
        vectorizer_path: Path to the pickled vectorizer.

    Returns:
        tuple: (model, vectorizer) as unpickled from the two files.
    """
    artifacts = []
    for artifact_path in (model_path, vectorizer_path):
        with open(artifact_path, "rb") as artifact_file:
            artifacts.append(pickle.load(artifact_file))
    loaded_model, loaded_vectorizer = artifacts
    return loaded_model, loaded_vectorizer


### Model Loading

In [3]:
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import torch

# Paths (replace these with your config.yaml paths if applicable)
vectorizer_path = "machine_learning/models/tfidf_vectorizer.pkl"

logistic_model_path = "machine_learning/models/logistic_regression.pkl"
naive_bayes_model_path = "machine_learning/models/naive_bayes.pkl"
svm_model_path = "machine_learning/models/svm.pkl"

distilbert_model_path = "machine_learning/models/distilbert/checkpoints/checkpoint-300"
distilbert_tokenizer_path = "machine_learning/models/distilbert/tokenizer"


# Load all models and vectorizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The three classical models share a single TF-IDF vectorizer, so it is read
# from disk exactly once (alongside the logistic regression model).
logistic_model, vectorizer = load_model_and_vectorizer(logistic_model_path, vectorizer_path)

# The remaining classifiers are unpickled on their own instead of re-reading
# the vectorizer file and throwing the result away.
# NOTE: pickle can execute arbitrary code while loading; only load trusted files.
with open(naive_bayes_model_path, "rb") as model_file:
    naive_bayes_model = pickle.load(model_file)
with open(svm_model_path, "rb") as model_file:
    svm_model = pickle.load(model_file)

# DistilBERT weights are moved to the selected device; the tokenizer is loaded
# from its own directory.
distilbert_model = DistilBertForSequenceClassification.from_pretrained(distilbert_model_path).to(device)
distilbert_tokenizer = DistilBertTokenizer.from_pretrained(distilbert_tokenizer_path)

print("All models and vectorizer loaded successfully.")

All models and vectorizer loaded successfully.


### Define Prediction Function

In [5]:
def predict_category(email_content, model, vectorizer):
    """
    Classify a single email with a vectorizer + model pair.

    The email is wrapped in a one-element list, passed through
    ``vectorizer.transform``, and the resulting features are handed to
    ``model.predict``.

    Args:
        email_content (str): The email content to predict.
        model: Trained classification model exposing ``predict``.
        vectorizer: Pre-trained TF-IDF vectorizer exposing ``transform``.

    Returns:
        The first (and only) entry of the model's predictions, i.e. the
        predicted category for this email.
    """
    single_document_batch = [email_content]
    features = vectorizer.transform(single_document_batch)
    predictions = model.predict(features)
    return predictions[0]

def predict_category_distilbert(email_content, model, tokenizer, label_mapping=None):
    """
    Predict the category of an email using DistilBERT.

    Args:
        email_content (str): The email content to predict.
        model: Fine-tuned DistilBERT model (a torch module whose output
            exposes ``logits``).
        tokenizer: Tokenizer for DistilBERT.
        label_mapping (dict, optional): Maps the predicted class index to a
            category name. Defaults to
            ``{0: "Work", 1: "Personal", 2: "Promotional", 3: "Urgent"}``;
            pass your own mapping if your dataset uses different labels.

    Returns:
        str: Predicted category.

    Raises:
        KeyError: If the predicted class index is not in ``label_mapping``.
    """
    if label_mapping is None:
        # Default mapping of class index to category (adjust to your dataset)
        label_mapping = {0: "Work", 1: "Personal", 2: "Promotional", 3: "Urgent"}

    model.eval()

    # Send the inputs to the device the model's weights actually live on,
    # rather than relying on a notebook-level `device` global that may be
    # undefined or out of sync with the model.
    model_device = next(model.parameters()).device

    inputs = tokenizer(
        email_content,
        truncation=True,
        padding=True,
        max_length=512,
        return_tensors="pt"
    ).to(model_device)

    # Inference only: no gradients needed
    with torch.no_grad():
        outputs = model(**inputs)
        predicted_class = torch.argmax(outputs.logits, dim=1).item()

    return label_mapping[predicted_class]


### Test the Prediction

In [6]:
# Sample message used to compare the four classifiers
example_email = '''
    Your Roblox Assesment will expire in 24 hrs! You must complete it as soon as possible!
    '''

# Run every classifier on the sample first. The three classical models share
# the TF-IDF vectorizer; DistilBERT uses its own tokenizer.
logistic_prediction = predict_category(example_email, logistic_model, vectorizer)
naive_bayes_prediction = predict_category(example_email, naive_bayes_model, vectorizer)
svm_prediction = predict_category(example_email, svm_model, vectorizer)
distilbert_prediction = predict_category_distilbert(
    example_email,
    distilbert_model,
    distilbert_tokenizer,
)

# Report the results side by side for comparison
print(f"Logistic Regression Prediction: {logistic_prediction}")
print(f"Naive Bayes Prediction: {naive_bayes_prediction}")
print(f"SVM Prediction: {svm_prediction}")
print(f"DistilBERT Prediction: {distilbert_prediction}")

Logistic Regression Prediction: Urgent
Naive Bayes Prediction: Urgent
SVM Prediction: Urgent
DistilBERT Prediction: Urgent


### Interactive Predictions

In [5]:
# Interactive loop for predictions.
# Reads one email per prompt and prints each model's prediction until the
# user types 'exit' (case-insensitive, surrounding whitespace ignored).
while True:
    try:
        email_content = input("Enter an email to predict its category (or type 'exit' to quit): ")
    except (EOFError, KeyboardInterrupt):
        # Stdin was closed or the kernel was interrupted while waiting for
        # input: leave the loop cleanly instead of surfacing a traceback.
        print("Exiting...")
        break

    # Normalised copy used only for control flow; the models still receive
    # the text exactly as typed.
    command = email_content.strip()
    if command.lower() == "exit":
        print("Exiting...")
        break
    if not command:
        # Blank input carries nothing to classify, so skip the model calls
        print("No email content entered; please try again.")
        print()
        continue

    print("Predictions:")
    print(f"  Logistic Regression: {predict_category(email_content, logistic_model, vectorizer)}")
    print(f"  Naive Bayes: {predict_category(email_content, naive_bayes_model, vectorizer)}")
    print(f"  SVM: {predict_category(email_content, svm_model, vectorizer)}")
    print(f"  DistilBERT: {predict_category_distilbert(email_content, distilbert_model, distilbert_tokenizer)}")
    print()


Exiting...
