In [1]:
#imports
import os

import joblib
import numpy as np
import pandas as pd
import torch
from sentence_transformers import SentenceTransformer
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, accuracy_score, precision_score, recall_score, f1_score
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Read the augmented e-mail CSV, verify both needed columns exist, then split
# the frame into the text feature (X) and the target (y).
df = pd.read_csv('../../data/analysis/emails_augmented.csv')
assert {'body_no_stopwords', 'label'}.issubset(df.columns), "Missing required columns."
X, y = df['body_no_stopwords'], df['label']

In [3]:
# Hold out 20% of the rows for testing; the fixed random_state keeps the split repeatable.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
)

In [4]:
class SBERTTransformer(BaseEstimator, TransformerMixin):
    """Scikit-learn style transformer that embeds texts with a SentenceTransformer.

    Parameters
    ----------
    model_name : str, default='all-MiniLM-L6-v2'
        Name handed to ``SentenceTransformer(...)`` when the encoder is loaded.
    batch_size : int, default=32
        Batch size forwarded to ``SentenceTransformer.encode``.

    Attributes
    ----------
    model : SentenceTransformer or None
        The loaded encoder. ``None`` until ``fit`` or the first ``transform`` call.
    """

    def __init__(self, model_name='all-MiniLM-L6-v2', batch_size=32):
        self.model_name = model_name
        self.batch_size = batch_size
        self.model = None

    def _load_model(self):
        """Instantiate the encoder named by ``model_name`` and store it on ``self.model``."""
        from sentence_transformers import SentenceTransformer
        self.model = SentenceTransformer(self.model_name)
        return self.model

    def fit(self, X, y=None):
        """Load the encoder named by ``model_name``; ``X`` and ``y`` are not used.

        Returns
        -------
        self
        """
        self._load_model()
        return self

    def transform(self, X):
        """Encode every item of ``X`` and return the result of ``encode``.

        ``X`` is materialised with ``list(X)`` before being passed to the
        encoder; the progress bar is disabled.
        """
        # Nothing is learned in fit(), so load the encoder on demand rather than
        # failing with "'NoneType' object has no attribute 'encode'" when
        # transform() is called on an instance that was never fitted.
        model = self.model if self.model is not None else self._load_model()
        return model.encode(list(X), batch_size=self.batch_size, show_progress_bar=False)

In [5]:
# Use up to 16 CPU threads for PyTorch, but never more than the number of CPUs
# the OS reports, so the setting does not oversubscribe smaller machines.
# (`os.cpu_count()` may return None, hence the fallback to 1.)
NUM_TORCH_THREADS = min(16, os.cpu_count() or 1)
torch.set_num_threads(NUM_TORCH_THREADS)
print(torch.get_num_threads())

16


In [7]:
# Embed the train split and then the test split with the same pretrained encoder.
model = SentenceTransformer('all-MiniLM-L6-v2')
X_train_emb, X_test_emb = (
    model.encode(split.tolist(), batch_size=64, show_progress_bar=True)
    for split in (X_train, X_test)
)


  return forward_call(*args, **kwargs)
Batches: 100%|██████████| 1026/1026 [15:57<00:00,  1.07it/s]
Batches: 100%|██████████| 257/257 [03:52<00:00,  1.11it/s]


In [8]:
import os
from joblib import dump

# Write the train/test embeddings and their labels to disk, one joblib file each.
output_dir = '../../output/embeddings'
os.makedirs(output_dir, exist_ok=True)

for filename, artifact in {
    'X_train_emb.joblib': X_train_emb,
    'X_test_emb.joblib': X_test_emb,
    'y_train.joblib': y_train,
    'y_test.joblib': y_test,
}.items():
    dump(artifact, os.path.join(output_dir, filename))

print("Embeddings saved successfully!")


Embeddings saved successfully!
