#### Importing libraries

In [1]:
import pandas as pd
import numpy as np
import re
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import shap

In [2]:
from transformers import (
    BertTokenizer,
    BertForSequenceClassification,
   
    BertModel
)

#### loading data

In [3]:
# Load the two e-mail corpora and stack them into one frame with a fresh index.
# NOTE(review): these are machine-specific absolute paths; adjust them to your
# local copies of the CSV files before running.
df1 = pd.read_csv("C:/Users/Dell/Desktop/XAI_Model/SpamAssasin.csv")
df2 = pd.read_csv("C:/Users/Dell/Desktop/XAI_Model/CEAS_08.csv")
df = pd.concat([df1, df2], ignore_index=True)

In [4]:
df.head()

Unnamed: 0,sender,receiver,date,subject,body,label,urls
0,Robert Elz <kre@munnari.OZ.AU>,Chris Garrigues <cwg-dated-1030377287.06fa6d@D...,"Thu, 22 Aug 2002 18:26:25 +0700",Re: New Sequences Window,"Date: Wed, 21 Aug 2002 10:54:46 -0500 ...",0,1
1,Steve Burt <Steve_Burt@cursor-system.com>,"""'zzzzteana@yahoogroups.com'"" <zzzzteana@yahoo...","Thu, 22 Aug 2002 12:46:18 +0100",[zzzzteana] RE: Alexander,"Martin A posted:\nTassos Papadopoulos, the Gre...",0,1
2,"""Tim Chapman"" <timc@2ubh.com>",zzzzteana <zzzzteana@yahoogroups.com>,"Thu, 22 Aug 2002 13:52:38 +0100",[zzzzteana] Moscow bomber,Man Threatens Explosion In Moscow \n\nThursday...,0,1
3,Monty Solomon <monty@roscom.com>,undisclosed-recipient: ;,"Thu, 22 Aug 2002 09:15:25 -0400",[IRR] Klez: The Virus That Won't Die,Klez: The Virus That Won't Die\n \nAlready the...,0,1
4,Stewart Smith <Stewart.Smith@ee.ed.ac.uk>,zzzzteana@yahoogroups.com,"Thu, 22 Aug 2002 14:38:22 +0100",Re: [zzzzteana] Nothing like mama used to make,"> in adding cream to spaghetti carbonara, whi...",0,1


In [5]:
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 44963 entries, 0 to 44962
Data columns (total 7 columns):
 #   Column    Non-Null Count  Dtype 
---  ------    --------------  ----- 
 0   sender    44963 non-null  object
 1   receiver  44291 non-null  object
 2   date      44963 non-null  object
 3   subject   44919 non-null  object
 4   body      44962 non-null  object
 5   label     44963 non-null  int64 
 6   urls      44963 non-null  int64 
dtypes: int64(2), object(5)
memory usage: 2.4+ MB


#### Data cleaning/preprocessing

In [6]:
# Drop rows with a missing receiver, subject, body or label (mutates df in place).
df.dropna(subset=['receiver', 'subject', 'body', 'label'], inplace=True)

In [7]:
df.isnull().sum()

sender      0
receiver    0
date        0
subject     0
body        0
label       0
urls        0
dtype: int64

In [8]:
df['urls'].nunique()

2

In [9]:
def advanced_text_cleaning(text):
    """Normalise one raw e-mail field (subject or body) into a flat string.

    The text is lower-cased, header-like fragments and HTML tags are removed,
    URLs / e-mail addresses / long digit runs / repeated punctuation are
    replaced by upper-case placeholders, every character outside
    ``[a-zA-Z0-9\\s!?$.]`` becomes a space, and whitespace is collapsed.

    Because the final character filter does not keep ``_``, placeholders end
    up as two words in the output (e.g. ``URL TOKEN``, ``EMAIL TOKEN``).

    Args:
        text: Raw value; missing values (``None`` / ``NaN``) are accepted.

    Returns:
        The cleaned string, or ``""`` for a missing value.
    """
    if pd.isna(text):
        return ""
    text = str(text).lower()
    # Remove header-style fragments up to the end of their line.
    # NOTE(review): the pattern is not anchored to the start of a line, so a
    # substring such as "mailto:" also triggers removal up to the next newline
    # -- confirm this is intended before tightening it.
    text = re.sub(r'(from:|to:|subject:|date:|reply-to:|message-id:).*?\n', '', text)
    text = re.sub(r'<[^>]+>', ' ', text)
    text = re.sub(r'http[s]?://\S+', ' URL_TOKEN ', text)
    # Fix: the TLD class used to be ``[A-Z|a-z]``, which also accepted a literal
    # "|" as part of the address. The text is already lower-cased here.
    text = re.sub(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[a-z]{2,}\b', ' EMAIL_TOKEN ', text)
    # Any digit run of 7 or more characters is treated as a phone number.
    text = re.sub(r'[\+]?[1-9]?[0-9]{7,15}', ' PHONE_TOKEN ', text)
    text = re.sub(r'[!]{3,}', ' MULTIPLE_EXCLAMATION ', text)
    text = re.sub(r'[\?]{3,}', ' MULTIPLE_QUESTION ', text)
    text = re.sub(r'[$]{2,}', ' MULTIPLE_DOLLAR ', text)
    # NOTE(review): "_" is not in the keep-set, so the placeholders inserted
    # above are split into two words by this step.
    text = re.sub(r'[^a-zA-Z0-9\s!?$.]', ' ', text)
    text = re.sub(r'\s+', ' ', text).strip()
    return text

In [10]:
# text-cleaning
# Clean subject and body separately, then join them with a single space into
# the `text` column that both the tokenizer and the feature extractor consume.
df['subject_clean'] = df['subject'].apply(advanced_text_cleaning)
df['body_clean'] = df['body'].apply(advanced_text_cleaning)
df['text'] = df['subject_clean'] + ' ' + df['body_clean']

In [11]:
#label-cleaning
# Force a strict binary target: exactly 1 stays 1, every other value becomes 0.
df['label'] = df['label'].apply(lambda x: 1 if x == 1 else 0)

In [12]:
# Keep only rows whose cleaned text is longer than 10 characters.
df = df[df['text'].str.len() > 10]

In [13]:
# Keep only the model inputs/target and drop any row still missing one of them.
df = df[['text', 'label', 'urls']].dropna()

In [14]:
df.info()

<class 'pandas.core.frame.DataFrame'>
Index: 44251 entries, 0 to 44962
Data columns (total 3 columns):
 #   Column  Non-Null Count  Dtype 
---  ------  --------------  ----- 
 0   text    44251 non-null  object
 1   label   44251 non-null  int64 
 2   urls    44251 non-null  int64 
dtypes: int64(2), object(1)
memory usage: 1.4+ MB


#### Feature engineering & scaling

In [15]:
def extract_email_features_row(text, url_count=0):
    """Compute the hand-crafted tabular features for one cleaned e-mail text.

    Args:
        text: Cleaned e-mail text (subject + body).
        url_count: URL indicator/count for the e-mail; copied into the
            features and thresholded into ``has_url``.

    Returns:
        dict mapping feature name to value: 12 size/punctuation/URL features
        followed by one ``contains_<phrase>`` flag per suspicious phrase, in a
        fixed insertion order.
    """
    n_chars = len(text)
    n_words = len(text.split())
    n_exclaim = text.count('!')
    n_question = text.count('?')
    n_upper = sum(map(str.isupper, text))

    features = {
        'char_count': n_chars,
        'word_count': n_words,
        # Sentence terminators are approximated by '.', '!' and '?'.
        'sentence_count': text.count('.') + n_exclaim + n_question,
        # Characters (spaces included) per whitespace-separated word; the
        # epsilon guards against division by zero for empty text.
        'avg_word_length': n_chars / (n_words + 1e-5),
        'exclamation_count': n_exclaim,
        'question_count': n_question,
        'dollar_count': text.count('$'),
        'capital_count': n_upper,
        'capital_ratio': n_upper / (n_chars + 1e-5),
        'url_count': url_count,
        'has_url': int(url_count > 0),
        'has_attachment': int('attachment' in text.lower()),
    }

    # sus-words detection: plain substring match against the text as given.
    suspicious_phrases = (
        'urgent', 'immediate', 'act now', 'limited time', 'click here',
        'free', 'winner', 'prize', 'verify', 'confirm', 'suspended',
    )
    features.update({
        f'contains_{phrase.replace(" ", "_")}': int(phrase in text)
        for phrase in suspicious_phrases
    })
    return features

In [16]:
# One feature dict per row (text paired with its `urls` value), assembled into
# a frame whose row order matches df.
# NOTE(review): this relies on extract_email_features_row returning a dict; the
# explanation section later redefines it to return a DataFrame, so re-running
# this cell after that redefinition will not produce the same frame.
feature_dicts = [extract_email_features_row(text, url_count) 
                for text, url_count in zip(df['text'], df['urls'])]

feature_df = pd.DataFrame(feature_dicts)

In [17]:
print(f"Number of features: {len(feature_df.columns)}")
print(f"Dataset shape: {df.shape}")
print(f"Feature matrix shape: {feature_df.shape}")

Number of features: 23
Dataset shape: (44251, 3)
Feature matrix shape: (44251, 23)


In [18]:
print("\nFeature columns:")
print(feature_df.columns.tolist())


Feature columns:
['char_count', 'word_count', 'sentence_count', 'avg_word_length', 'exclamation_count', 'question_count', 'dollar_count', 'capital_count', 'capital_ratio', 'url_count', 'has_url', 'has_attachment', 'contains_urgent', 'contains_immediate', 'contains_act_now', 'contains_limited_time', 'contains_click_here', 'contains_free', 'contains_winner', 'contains_prize', 'contains_verify', 'contains_confirm', 'contains_suspended']


In [19]:
print("\nFeature statistics:")
print(feature_df.describe())


Feature statistics:
          char_count    word_count  sentence_count  avg_word_length  \
count   44251.000000  44251.000000    44251.000000     44251.000000   
mean     1319.883415    226.736548       18.595376         5.903886   
std      3633.188196    543.760521       57.077018         2.085177   
min        21.000000      2.000000        0.000000         2.414971   
25%       252.000000     44.000000        3.000000         5.320312   
50%       561.000000    100.000000        8.000000         5.584158   
75%      1473.000000    262.000000       19.000000         5.999996   
max    230437.000000  21713.000000     3359.000000        53.559313   

       exclamation_count  question_count  dollar_count  capital_count  \
count       44251.000000    44251.000000  44251.000000   44251.000000   
mean            1.088540        1.024813      0.631624      43.780932   
std             5.313654        2.829572      8.222032     277.029665   
min             0.000000        0.000000      0

#### Data Scaling

In [20]:
# Standardise every tabular feature column.
# NOTE(review): the scaler is fitted on all rows before the train/test split
# further down, so test-set statistics influence the scaling.
scaler = StandardScaler()
scaled_features = scaler.fit_transform(feature_df)

#### Data splitting

In [21]:
X_text = df['text'].tolist()
X_tabular = scaled_features
y = df['label'].values

In [22]:
# 80/20 split, stratified on the label with a fixed seed. Text, tabular features
# and labels are split in one call so the three stay row-aligned.
X_text_train, X_text_test, X_tab_train, X_tab_test, y_train, y_test = train_test_split(
    X_text, X_tabular, y, test_size=0.2, stratify=y, random_state=42
)

#### Tokenization

In [23]:
# Tokenizer matching the "bert-base-uncased" encoder used by the model.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# Every text is padded/truncated to this many tokens by tokenize_texts.
max_length = 256

In [24]:
def tokenize_texts(texts):
    """Encode `texts` with the module-level tokenizer as fixed-length PyTorch tensors.

    Every sequence is truncated and padded to exactly `max_length` tokens.
    """
    encode_options = {
        "padding": "max_length",
        "truncation": True,
        "max_length": max_length,
        "return_tensors": "pt",
    }
    return tokenizer(texts, **encode_options)

#### PYTORCH DATASET 

In [25]:
class HybridDataset(Dataset):
    """Dataset pairing tokenized text with tabular features and a class label.

    All texts are tokenized once in the constructor; indexing only slices the
    pre-built tensors.
    """

    def __init__(self, texts, tabular_features, labels):
        # Eager tokenization of the full list of texts.
        self.encodings = tokenize_texts(texts)
        self.tabular = torch.tensor(tabular_features, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.long)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # Tokenizer outputs first, then the extra keys the training loop reads.
        sample = {name: tensor[idx] for name, tensor in self.encodings.items()}
        sample.update(tabular=self.tabular[idx], labels=self.labels[idx])
        return sample

In [26]:
# HybridDataset tokenizes inside its constructor, so these two lines encode the
# whole train and test text lists up front.
train_dataset = HybridDataset(X_text_train, X_tab_train, y_train)
test_dataset = HybridDataset(X_text_test, X_tab_test, y_test)

In [27]:
len(train_dataset)

35400

In [28]:
len(test_dataset)

8851

In [29]:
import warnings
warnings.filterwarnings('ignore')

#### Hybrid Model (BERT + Tabular Features)

In [30]:
class BERTWithTabular(nn.Module):
    """BERT encoder fused with a small projection of the tabular features.

    The BERT pooled output is concatenated with a 64-unit ReLU projection of
    the tabular vector, passed through dropout (p=0.3) and mapped to two
    class logits.
    """

    def __init__(self, tabular_dim):
        super().__init__()
        self.bert = BertModel.from_pretrained("bert-base-uncased")
        self.dropout = nn.Dropout(0.3)
        self.tabular_fc = nn.Linear(tabular_dim, 64)
        self.classifier = nn.Linear(self.bert.config.hidden_size + 64, 2)

    def forward(self, input_ids, attention_mask, tabular, **kwargs):
        # Extra keyword arguments are accepted and ignored.
        text_repr = self.bert(
            input_ids=input_ids, attention_mask=attention_mask
        ).pooler_output
        tab_repr = torch.relu(self.tabular_fc(tabular))
        joint = self.dropout(torch.cat((text_repr, tab_repr), dim=1))
        return self.classifier(joint)

#### training and model SETUP

In [31]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [32]:
device

device(type='cpu')

In [33]:
model = BERTWithTabular(tabular_dim=X_tabular.shape[1]).to(device)

In [34]:
model


BERTWithTabular(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwi

In [35]:
# One AdamW optimiser (lr=2e-5) over every parameter of the model, i.e. the
# BERT encoder and the tabular/classifier layers share the same learning rate.
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
# Cross-entropy over the two output logits.
loss_fn = nn.CrossEntropyLoss()

In [36]:
# Batches of 128; only the training loader shuffles.
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128)

#### TRAINING

In [37]:
# Per-epoch history lists (filled by the training loop, read by the plotting and
# summary cells) plus early-stopping state.
# NOTE: best_val_loss, patience and counter are assigned again in the
# "Training configuration" cell below.
train_losses, val_losses, val_accuracies = [], [], []
best_val_loss = float('inf')
patience, counter = 2, 0

In [38]:
def train_epoch(model, train_loader, optimizer, loss_fn, device):
    """Run one optimisation pass over `train_loader`.

    Each batch must provide 'input_ids', 'attention_mask', 'tabular' and
    'labels'. A progress line is printed for every step.

    Returns:
        Mean of the per-batch losses for the epoch.
    """
    model.train()
    n_batches = len(train_loader)
    step_losses = []

    for step, batch in enumerate(train_loader, start=1):
        print(f"[Train] Step {step}/{n_batches}")

        # Move the four tensors the model and loss need onto the target device.
        on_device = {
            key: batch[key].to(device)
            for key in ('input_ids', 'attention_mask', 'tabular', 'labels')
        }

        optimizer.zero_grad()
        logits = model(
            on_device['input_ids'], on_device['attention_mask'], on_device['tabular']
        )
        batch_loss = loss_fn(logits, on_device['labels'])
        batch_loss.backward()
        optimizer.step()

        step_losses.append(batch_loss.item())

    return sum(step_losses) / n_batches

In [39]:
def evaluate(model, test_loader, loss_fn, device):
    """Score `model` on `test_loader` without tracking gradients.

    Each batch must provide 'input_ids', 'attention_mask', 'tabular' and
    'labels'. A progress line is printed for every step.

    Returns:
        (mean per-batch loss, list of predicted class ids, list of true labels)
    """
    model.eval()
    n_batches = len(test_loader)
    step_losses = []
    all_preds, all_labels = [], []

    with torch.no_grad():
        for step, batch in enumerate(test_loader, start=1):
            print(f"[Eval] Step {step}/{n_batches}")

            on_device = {
                key: batch[key].to(device)
                for key in ('input_ids', 'attention_mask', 'tabular', 'labels')
            }
            targets = on_device['labels']

            logits = model(
                on_device['input_ids'], on_device['attention_mask'], on_device['tabular']
            )
            step_losses.append(loss_fn(logits, targets).item())

            # Predicted class = index of the larger logit.
            all_preds.extend(torch.argmax(logits, dim=1).cpu().numpy())
            all_labels.extend(targets.cpu().numpy())

    return sum(step_losses) / n_batches, all_preds, all_labels

#### Training loop

In [40]:
# Training configuration
# Re-running this cell resets the early-stopping state (best_val_loss, counter);
# the history lists from the earlier cell are not cleared here.
num_epochs = 3
best_val_loss = float('inf')
patience = 2  # epochs without validation-loss improvement before stopping
counter = 0

print("-" * 100)
print(f"Device: {device}")
print("-" * 100)

----------------------------------------------------------------------------------------------------
Device: cpu
----------------------------------------------------------------------------------------------------


In [None]:
# Training loop
# Fix: accuracy_score is otherwise only imported in a later evaluation cell, so
# a fresh top-to-bottom run hit a NameError right after the first epoch's
# training pass. Import it where it is first needed.
from sklearn.metrics import accuracy_score

for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")
    train_loss = train_epoch(model, train_loader, optimizer, loss_fn, device)
    # eval on validation set
    # NOTE(review): test_loader doubles as the validation set, so checkpoint
    # selection and early stopping both see the data used for final metrics.
    val_loss, val_preds, val_labels = evaluate(model, test_loader, loss_fn, device)

    # validation accuracy
    val_acc = accuracy_score(val_labels, val_preds)

    # metrics
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    val_accuracies.append(val_acc)

    # epoch results
    print(f"  Train Loss: {train_loss:.4f}")
    print(f"  Val Loss: {val_loss:.4f}")
    print(f"  Val Accuracy: {val_acc:.4f}")

    # early-stopping check: keep the checkpoint with the lowest validation loss
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        counter = 0
        # Save best model
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'epoch': epoch,
            'val_loss': val_loss,
            'val_accuracy': val_acc
        }, "best_model.pt")
        print(f"  New best model saved! (Val Loss: {val_loss:.4f})")
    else:
        counter += 1
        print(f"  No improvement ({counter}/{patience})")

        # Stop once `patience` consecutive epochs failed to improve.
        if counter >= patience:
            print("  Early stopping triggered!")
            break

    print("-" * 40)
print("Training completed!")

Epoch 1/3
[Train] Step 1/277


#### Eval

In [None]:
# loading best model for final evaluation
# map_location keeps the load working when the checkpoint was written on a
# different device than the current one (same call style as the checkpoint load
# in the explanation section).
checkpoint = torch.load("best_model.pt", map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
print(f"Best model from epoch {checkpoint['epoch'] + 1} loaded (Val Loss: {checkpoint['val_loss']:.4f})")

In [None]:
# Switch the model to evaluation mode and create the accumulators that the
# final test-set pass fills.
model.eval()
all_preds, all_labels = [], []
all_probas = []

In [None]:
# Final pass over the test set, collecting hard predictions, true labels and
# the positive-class probability for each sample.
# The accumulators are reset and eval mode is set here so that re-running this
# cell does not append a second copy of the predictions.
model.eval()
all_preds, all_labels, all_probas = [], [], []

with torch.no_grad():
    for batch in test_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        tabular = batch['tabular'].to(device)
        labels = batch['labels'].to(device)

        logits = model(input_ids, attention_mask, tabular)

        # Get predictions and probabilities
        preds = torch.argmax(logits, dim=1)
        probas = torch.softmax(logits, dim=1)[:, 1]  # Probability of positive class

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
        all_probas.extend(probas.cpu().numpy())

In [None]:
from sklearn.metrics import accuracy_score, classification_report, f1_score, roc_auc_score, confusion_matrix, ConfusionMatrixDisplay

# final metrics
# Accuracy and F1 use the hard predictions; AUC-ROC uses the positive-class
# probabilities collected in all_probas.
final_accuracy = accuracy_score(all_labels, all_preds)
final_f1 = f1_score(all_labels, all_preds)
final_auc = roc_auc_score(all_labels, all_probas)

print(f"final Accuracy = {final_accuracy:.4f}")
print(f"final F1 Score = {final_f1:.4f}")
print(f"final AUC-ROC = {final_auc:.4f}")

In [None]:
print("\nClassification Report:")
print(classification_report(all_labels, all_preds, target_names=['Ham', 'Spam']))

In [None]:
# confusion matrix
cm = confusion_matrix(all_labels, all_preds)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=['Ham', 'Spam'])
disp.plot(cmap=plt.cm.Blues)
plt.title("Confusion Matrix")
plt.show()

#### Training/Validation Loss, Accuracy plots

In [None]:
# training curves: three side-by-side panels on one 15x5 figure
fig, (ax_loss, ax_acc, ax_gap) = plt.subplots(1, 3, figsize=(15, 5))

#1: train&val Loss
ax_loss.plot(range(1, len(train_losses) + 1), train_losses, 'b-', label='Training Loss', linewidth=2)
ax_loss.plot(range(1, len(val_losses) + 1), val_losses, 'r-', label='Validation Loss', linewidth=2)
ax_loss.set_title('Training and Validation Loss')
ax_loss.set_xlabel('Epoch')
ax_loss.set_ylabel('Loss')
ax_loss.legend()
ax_loss.grid(True, alpha=0.3)

#2: val Accuracy
ax_acc.plot(range(1, len(val_accuracies) + 1), val_accuracies, 'g-', label='Validation Accuracy', linewidth=2)
ax_acc.set_title('Validation Accuracy')
ax_acc.set_xlabel('Epoch')
ax_acc.set_ylabel('Accuracy')
ax_acc.legend()
ax_acc.grid(True, alpha=0.3)

#3: train-val Loss, with the gap between the two curves shaded
epochs = range(1, len(train_losses) + 1)
ax_gap.plot(epochs, train_losses, 'b-', label='Training Loss', linewidth=2)
ax_gap.plot(epochs, val_losses, 'r-', label='Validation Loss', linewidth=2)
ax_gap.fill_between(epochs, train_losses, val_losses, alpha=0.2, color='gray')
ax_gap.set_title('Loss Comparison')
ax_gap.set_xlabel('Epoch')
ax_gap.set_ylabel('Loss')
ax_gap.legend()
ax_gap.grid(True, alpha=0.3)

fig.tight_layout()
plt.show()

#### TRAINING SUMMARY

In [None]:
print(f"epochs completed = {len(train_losses)}")
print(f"best validation loss = {best_val_loss:.4f}")
print(f"best validation accuracy = {max(val_accuracies):.4f}")
print(f"final training loss = {train_losses[-1]:.4f}")
print(f"final validation loss = {val_losses[-1]:.4f}")
print(f"final validation accuracy = {val_accuracies[-1]:.4f}")

#### SHAP INFERENCE

In [None]:
import pandas as pd
import numpy as np
import re
import torch
import shap
from transformers import BertTokenizer, BertModel
from shap.maskers import Text
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler

In [None]:
# ========== cleaning ==========
def advanced_text_cleaning(text):
    """Normalise one raw e-mail field (subject or body) into a flat string.

    The text is lower-cased, header-like fragments and HTML tags are removed,
    URLs / e-mail addresses / long digit runs / repeated punctuation are
    replaced by upper-case placeholders, every character outside
    ``[a-zA-Z0-9\\s!?$.]`` becomes a space, and whitespace is collapsed.

    Because the final character filter does not keep ``_``, placeholders end
    up as two words in the output (e.g. ``URL TOKEN``, ``EMAIL TOKEN``).

    Args:
        text: Raw value; missing values (``None`` / ``NaN``) are accepted.

    Returns:
        The cleaned string, or ``""`` for a missing value.
    """
    if pd.isna(text):
        return ""
    text = str(text).lower()
    # Remove header-style fragments up to the end of their line.
    # NOTE(review): the pattern is not anchored to the start of a line, so a
    # substring such as "mailto:" also triggers removal up to the next newline
    # -- confirm this is intended before tightening it.
    text = re.sub(r'(from:|to:|subject:|date:|reply-to:|message-id:).*?\n', '', text)
    text = re.sub(r'<[^>]+>', ' ', text)
    text = re.sub(r'http[s]?://\S+', ' URL_TOKEN ', text)
    # Fix: the TLD class used to be ``[A-Z|a-z]``, which also accepted a literal
    # "|" as part of the address. The text is already lower-cased here.
    text = re.sub(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[a-z]{2,}\b', ' EMAIL_TOKEN ', text)
    # Any digit run of 7 or more characters is treated as a phone number.
    text = re.sub(r'[\+]?[1-9]?[0-9]{7,15}', ' PHONE_TOKEN ', text)
    text = re.sub(r'[!]{3,}', ' MULTIPLE_EXCLAMATION ', text)
    text = re.sub(r'[\?]{3,}', ' MULTIPLE_QUESTION ', text)
    text = re.sub(r'[$]{2,}', ' MULTIPLE_DOLLAR ', text)
    # NOTE(review): "_" is not in the keep-set, so the placeholders inserted
    # above are split into two words by this step.
    text = re.sub(r'[^a-zA-Z0-9\s!?$.]', ' ', text)
    text = re.sub(r'\s+', ' ', text).strip()
    return text

# ========== Feature Engineering ==========
def extract_email_features_row(text, url_count=0):
    """Compute the hand-crafted tabular features for one cleaned e-mail text.

    Args:
        text: Cleaned e-mail text (subject + body).
        url_count: URL indicator/count for the e-mail; copied into the
            features and thresholded into ``has_url``.

    Returns:
        A single-row DataFrame: 12 size/punctuation/URL columns followed by
        one ``contains_<phrase>`` flag per suspicious phrase, in a fixed
        column order.
    """
    n_chars = len(text)
    n_words = len(text.split())
    n_exclaim = text.count('!')
    n_question = text.count('?')
    n_upper = sum(map(str.isupper, text))

    features = {
        'char_count': n_chars,
        'word_count': n_words,
        # Sentence terminators are approximated by '.', '!' and '?'.
        'sentence_count': text.count('.') + n_exclaim + n_question,
        # Characters (spaces included) per whitespace-separated word; the
        # epsilon guards against division by zero for empty text.
        'avg_word_length': n_chars / (n_words + 1e-5),
        'exclamation_count': n_exclaim,
        'question_count': n_question,
        'dollar_count': text.count('$'),
        'capital_count': n_upper,
        'capital_ratio': n_upper / (n_chars + 1e-5),
        'url_count': url_count,
        'has_url': int(url_count > 0),
        'has_attachment': int('attachment' in text.lower()),
    }

    # Plain substring match against the text as given.
    suspicious_phrases = (
        'urgent', 'immediate', 'act now', 'limited time', 'click here',
        'free', 'winner', 'prize', 'verify', 'confirm', 'suspended',
    )
    features.update({
        f'contains_{phrase.replace(" ", "_")}': int(phrase in text)
        for phrase in suspicious_phrases
    })
    return pd.DataFrame([features])

In [None]:
# ========== Model (Only forward compatible for SHAP) ==========
class BERTWithTabular(torch.nn.Module):
    """BERT encoder fused with a small projection of the tabular features.

    Layer names and sizes mirror the training-time definition so that the
    saved state dict loads into this class.
    """

    def __init__(self, tabular_dim):
        super().__init__()
        self.bert = BertModel.from_pretrained("bert-base-uncased")
        self.dropout = torch.nn.Dropout(0.3)
        self.tabular_fc = torch.nn.Linear(tabular_dim, 64)
        self.classifier = torch.nn.Linear(self.bert.config.hidden_size + 64, 2)

    def forward(self, input_ids, attention_mask, tabular):
        # Text branch: pooled BERT output.
        text_repr = self.bert(
            input_ids=input_ids, attention_mask=attention_mask
        ).pooler_output
        # Tabular branch: 64-unit ReLU projection.
        tab_repr = torch.relu(self.tabular_fc(tabular))
        joint = self.dropout(torch.cat((text_repr, tab_repr), dim=1))
        return self.classifier(joint)

In [None]:
# ========== Loading Trained Model ==========
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 23 = number of columns produced by extract_email_features_row
# (12 base features + 11 suspicious-phrase flags); it must match the
# checkpoint's tabular_fc input size.
model = BERTWithTabular(tabular_dim=23).to(device)

In [None]:
# Restore the best weights saved by the training loop onto the current device
# and switch to evaluation mode for the explanation runs.
checkpoint = torch.load("best_model.pt", map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

In [None]:
# ========== SHAP Prediction Wrapper ==========
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
max_len = 256

def predict_shap_raw(texts):
    """Return the BERT pooled output for one text or a batch of texts.

    Note that this is the encoder's pooled embedding (one row per text), not
    the classifier's class scores; the tabular branch is not involved.

    Args:
        texts: A string, a list of strings, or a NumPy array of strings.

    Returns:
        NumPy array with one pooled-output row per input text.
    """
    # Normalise the input to a list; a NumPy array is converted first so that
    # a resulting bare string is still wrapped by the second check.
    if isinstance(texts, np.ndarray):
        texts = texts.tolist()
    if isinstance(texts, str):
        texts = [texts]

    batch_enc = tokenizer(
        texts,
        padding=True,
        truncation=True,
        return_tensors="pt",
        max_length=max_len,
    ).to(device)
    with torch.no_grad():
        pooled = model.bert(**batch_enc).pooler_output
    return pooled.cpu().numpy()

In [None]:
# ========== Inference Pipeline ==========
def run_shap_inference(row, scaler, return_shap=True):
    """Classify one e-mail and optionally explain the decision with SHAP.

    Fix: the explanation used to be computed on the BERT pooled embedding,
    which is not the quantity the prediction is based on. It is now computed
    on the full model's class probabilities, with this row's tabular features
    held fixed while SHAP masks the text.

    Args:
        row: Mapping / Series with 'subject' and 'body', and optionally 'urls'.
        scaler: Fitted scaler applied to the hand-crafted feature frame.
        return_shap: When True, add a SHAP text explanation to the result.

    Returns:
        dict with 'text', 'predicted_class', 'confidence', 'probabilities'
        and, if requested, 'shap_values' (explanation for this one text with
        outputs named "Legitimate" and "Phishing").
    """
    subject_clean = advanced_text_cleaning(row['subject'])
    body_clean = advanced_text_cleaning(row['body'])
    full_text = subject_clean + " " + body_clean
    print(f"[DEBUG] Text length: {len(full_text)}")
    print(f"[DEBUG] Text preview: {full_text[:300]}")

    # Tabular features for the unmasked e-mail, shape (1, n_features).
    features_df = extract_email_features_row(full_text, row.get("urls", 0))
    tabular_tensor = torch.tensor(scaler.transform(features_df), dtype=torch.float32).to(device)

    def predict_proba(texts):
        """Class probabilities of the full model for a batch of texts."""
        if isinstance(texts, np.ndarray):
            texts = texts.tolist()
        if isinstance(texts, str):
            texts = [texts]
        texts = [str(t) for t in texts]

        enc = tokenizer(texts, padding=True, truncation=True, return_tensors="pt", max_length=max_len).to(device)
        with torch.no_grad():
            # The same tabular vector is paired with every (masked) text.
            logits = model(
                enc['input_ids'],
                enc['attention_mask'],
                tabular_tensor.expand(len(texts), -1),
            )
            return torch.softmax(logits, dim=1).cpu().numpy()

    # Model prediction
    probs = predict_proba([full_text])[0]
    predicted_class = int(np.argmax(probs))
    confidence = probs[predicted_class]

    result = {
        "text": full_text,
        "predicted_class": predicted_class,
        "confidence": confidence,
        "probabilities": probs.tolist()
    }

    # SHAP explanation
    if return_shap:
        text_masker = Text(tokenizer=tokenizer)
        explainer = shap.Explainer(
            predict_proba,
            masker=text_masker,
            output_names=["Legitimate", "Phishing"],
        )
        shap_values = explainer([full_text])  # wrap in list
        result["shap_values"] = shap_values[0]

    return result

In [None]:
# Reload the raw corpora for the explanation section.
# NOTE(review): these absolute paths differ from the ones used in the training
# section's loading cell; confirm both point at the same CSV files.
df1 = pd.read_csv("C:/Users/LOQ/Desktop/SpamAssasin.csv")
df2 = pd.read_csv("C:/Users/LOQ/Desktop/CEAS_08.csv")
df = pd.concat([df1, df2], ignore_index=True)
# Same missing-value filter as in the training section (mutates df in place).
df.dropna(subset=['receiver', 'subject', 'body', 'label'], inplace=True)

In [None]:
# Rebuild the cleaned `text` column and apply the same >10-character filter as
# in the training section.
df['subject_clean'] = df['subject'].apply(advanced_text_cleaning)
df['body_clean'] = df['body'].apply(advanced_text_cleaning)
df['text'] = df['subject_clean'] + ' ' + df['body_clean']
df = df[df['text'].str.len() > 10]

In [None]:
# Fit scaler on all features
# Fix: pass each row's `urls` value. Without it every row got url_count=0 and
# has_url=0, so this scaler's statistics for those two columns did not match
# the scaler used when the model was trained.
feature_df = pd.concat(
    [extract_email_features_row(t, u) for t, u in zip(df['text'], df['urls'])],
    ignore_index=True,
)

In [None]:
# Re-create the feature scaler for inference; it is fitted on the full feature
# frame built in the previous cell.
scaler = StandardScaler()
scaler.fit(feature_df)

In [None]:
# ========= pick one sample for testing/inference =========
sample_row = df.iloc[20000]

In [None]:
sample_row['text']

In [None]:
# Run inference + SHAP
result = run_shap_inference(sample_row, scaler)

In [None]:
# ========= print & visualize =========
print(f"Prediction = {'Phishing' if result['predicted_class'] == 1 else 'Legitimate'}")
print(f"Confidence = {result['confidence']:.3f}")
print(f"Probabilities = {result['probabilities']}")

In [None]:
# SHAP-visualization
shap.plots.text(result['shap_values'])

#### LIME INFERENCE

In [None]:
from lime.lime_text import LimeTextExplainer

In [None]:
def run_lime_inference(row, scaler, num_features=10):
    """Classify one e-mail and explain the decision with LIME.

    Fixes over the previous version:
      * the row's URL count is now used for every prediction (it was silently
        replaced by 0, including for the reported prediction);
      * perturbed texts are scored in batches instead of one forward pass per
        text;
      * unused base-feature tensors were removed.

    Args:
        row: Mapping / Series with 'subject' and 'body', and optionally 'urls'.
        scaler: Fitted scaler applied to the hand-crafted feature frame.
        num_features: Number of words LIME reports in the explanation.

    Returns:
        dict with 'text', 'predicted_class', 'confidence', 'probabilities'
        and 'lime_explanation'.
    """
    subject_clean = advanced_text_cleaning(row['subject'])
    body_clean = advanced_text_cleaning(row['body'])
    full_text = subject_clean + " " + body_clean
    # URL information is not recoverable from the text alone, so it is taken
    # from the row and kept constant across LIME's perturbations.
    url_count = row.get("urls", 0)

    # Define LIME prediction wrapper
    def lime_predict(texts, batch_size=32):
        """Return an (n_texts, 2) array of class probabilities."""
        texts = [str(txt) for txt in texts]
        chunks = []
        for start in range(0, len(texts), batch_size):
            batch_texts = texts[start:start + batch_size]

            # Text-derived tabular features are recomputed for each perturbed sample.
            features_df = pd.concat(
                [extract_email_features_row(txt, url_count) for txt in batch_texts],
                ignore_index=True,
            )
            tabular_tensor = torch.tensor(
                scaler.transform(features_df), dtype=torch.float32
            ).to(device)

            # Tokenize
            enc = tokenizer(batch_texts, padding=True, truncation=True, return_tensors="pt", max_length=max_len).to(device)

            with torch.no_grad():
                logits = model(enc['input_ids'], enc['attention_mask'], tabular_tensor)
                chunks.append(torch.softmax(logits, dim=1).cpu().numpy())

        if not chunks:
            return np.empty((0, 2), dtype=np.float32)
        return np.concatenate(chunks, axis=0)

    # explainer
    explainer = LimeTextExplainer(class_names=["Legitimate", "Phishing"])
    explanation = explainer.explain_instance(full_text, lime_predict, num_features=num_features)

    # prediction
    pred_probs = lime_predict([full_text])[0]
    predicted_class = int(np.argmax(pred_probs))
    confidence = pred_probs[predicted_class]

    result = {
        "text": full_text,
        "predicted_class": predicted_class,
        "confidence": confidence,
        "probabilities": pred_probs.tolist(),
        "lime_explanation": explanation
    }

    return result

In [None]:
sample_row = df.iloc[20000]
result_lime = run_lime_inference(sample_row, scaler)

In [None]:
result_lime['lime_explanation'].show_in_notebook()