#### SMS Spam Collection - Mistral Embeddings with Prediction

In [20]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from lime.lime_text import LimeTextExplainer
import concurrent.futures
import ollama

os.makedirs("../data", exist_ok=True)
os.makedirs("../plots", exist_ok=True)

In [21]:
# Load data
df = pd.read_csv('../data/spam.csv', encoding='latin-1')
df['label'] = df['v1'].map({'ham': 0, 'spam': 1})

In [22]:
# Split data and generate sample
df_sample = df.groupby('label').apply(lambda x: x.sample(n=min(1000, len(x)), random_state=42)).reset_index(drop=True)

texts = df_sample['v2'].values
labels = df_sample['label'].values
texts_train, texts_test, y_train, y_test = train_test_split(
    texts, labels, test_size=0.3, stratify=labels, random_state=42
)

  df_sample = df.groupby('label').apply(lambda x: x.sample(n=min(1000, len(x)), random_state=42)).reset_index(drop=True)


In [23]:
# Embedding function
def get_embedding(text):
    try:
        return ollama.embeddings(model='mistral', prompt=text[:512])['embedding']
    except:
        return [0.0] * 4096

In [24]:
# Parallel embedding
print("Generating embeddings with Mistral (sampled)...")
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
    X_train = np.array(list(executor.map(get_embedding, texts_train)))
    X_test = np.array(list(executor.map(get_embedding, texts_test)))

Generating embeddings with Mistral (sampled)...


In [25]:
# Train model
clf = RandomForestClassifier(n_estimators=100, class_weight='balanced')
clf.fit(X_train, y_train)

y_pred = clf.predict(X_test)
report = classification_report(y_test, y_pred, output_dict=True)
print(classification_report(y_test, y_pred))

              precision    recall  f1-score   support

           0       0.00      0.00      0.00       301
           1       0.43      1.00      0.60       224

    accuracy                           0.43       525
   macro avg       0.21      0.50      0.30       525
weighted avg       0.18      0.43      0.26       525



  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [26]:
# Save metrics for later report use
metrics_df = pd.DataFrame([{
    'accuracy': report['accuracy'],
    'precision': report['1']['precision'],
    'recall': report['1']['recall'],
    'f1': report['1']['f1-score']
}])
metrics_df.to_csv('../data/model_metrics.csv', index=False)

In [27]:
# Confusion Matrix
plt.figure(figsize=(6, 5))
cm = confusion_matrix(y_test, y_pred)
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=['Legitimate', 'Fraud'],
            yticklabels=['Legitimate', 'Fraud'])
plt.title('Fraud Detection Confusion Matrix')
plt.tight_layout()
plt.savefig('../plots/confusion_matrix.png', dpi=300)
plt.close()

In [28]:
# LIME Explanations
explainer = LimeTextExplainer(class_names=['ham', 'spam'])

def predict_proba(texts):
    emb = np.array([get_embedding(t) for t in texts])
    return clf.predict_proba(emb)

fraud_indices = np.where(y_test == 1)[0]
if len(fraud_indices) > 0:
    idx = fraud_indices[0]
    exp = explainer.explain_instance(
        text_instance=texts_test[idx],
        classifier_fn=predict_proba,
        num_features=10,
        num_samples=300
    )
    exp.save_to_file('../plots/lime_explanation.html')
    fig = exp.as_pyplot_figure()
    fig.set_size_inches(10, 6)
    plt.title('LIME Explanation for Fraud Prediction')
    plt.tight_layout()
    plt.savefig('../plots/lime_visualization.png', dpi=300)
    plt.close()
else:
    print("No fraud samples in test set to explain.")