In [1]:
import os

import dill as pickle
import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score, accuracy_score, roc_auc_score
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import train_test_split
from tabulate import tabulate

In [2]:
def load_sentence_embeddings(model='DistilBERT', features_path='features/', filename='sentence_embeddings'):
    """Load and concatenate the pickled embedding frames found in ``features_path``.

    Every directory entry whose name contains ``filename`` is unpickled and the
    results are concatenated with ``pd.concat``. Entries are visited in sorted
    name order so the row order of the result is reproducible (``os.listdir``
    returns names in arbitrary order).

    Args:
        model: Name of the embedding model. Only ``'DistilBERT'`` is supported.
        features_path: Directory that holds the pickled files.
        filename: Substring a file name must contain to be loaded.

    Returns:
        The concatenation of all matching unpickled objects.

    Raises:
        ValueError: If ``model`` is not supported, or (raised by ``pd.concat``)
            if no entry in ``features_path`` matches ``filename``.
    """
    if model != 'DistilBERT':
        # Previously this fell through and returned None, which only surfaced
        # later as an unrelated AttributeError in the caller.
        raise ValueError(f"Unsupported model {model!r}; only 'DistilBERT' is supported")

    frames = []
    for name in sorted(os.listdir(features_path)):
        if filename not in name:
            continue
        # NOTE(security): unpickling can execute arbitrary code -- only point
        # this at feature files you produced yourself.
        with open(os.path.join(features_path, name), 'rb') as handle:
            frames.append(pickle.load(handle))
    return pd.concat(frames)

In [3]:
# Load and concatenate every matching pickled embedding file under features_250/.
tst_df = load_sentence_embeddings(features_path="features_250/")

In [4]:
# Sanity check: (rows, columns) of the loaded frame.
tst_df.shape

(22332, 2)

In [5]:
# Preview the first rows: one embedding vector and one label per row.
tst_df.head()

Unnamed: 0,sentence_embeddings,label
0,"[-0.185443714261055, -0.11448108404874802, -0....",0
1,"[-0.3724660873413086, 0.04101637750864029, -0....",0
2,"[-0.41084980964660645, -0.1713167279958725, -0...",0
3,"[-0.14235153794288635, 0.19862940907478333, -0...",0
4,"[-0.47683459520339966, -0.040994927287101746, ...",0


In [6]:
# Number of rows per label value, to see how balanced the classes are.
tst_df.label.value_counts()

0    14888
1     7444
Name: label, dtype: int64

In [7]:
# Turn the column of per-row embedding lists into a single NumPy array
# (presumably 2-D, one row per sentence -- assumes all embeddings share a length).
features = np.array(tst_df["sentence_embeddings"].tolist())

In [8]:
# Target vector: the label column, kept as a pandas Series.
labels = tst_df['label']

In [9]:
# Fixed seed so the split -- and therefore every metric reported below -- is
# reproducible under Restart & Run All.
SPLIT_SEED = 42

# Stratify on the labels so train and test keep the same class ratio as the
# full data set (the value counts above show the classes are imbalanced).
train_features, test_features, train_labels, test_labels = train_test_split(
    features,
    labels,
    random_state=SPLIT_SEED,
    stratify=labels,
)

In [10]:
# Baseline classifier on the raw embeddings. max_iter is set far above the
# scikit-learn default, presumably to avoid convergence warnings -- TODO confirm.
lr_clf = LogisticRegression(max_iter=10000)
# fit() returns the estimator, so this cell also displays the fitted model.
lr_clf.fit(train_features, train_labels)

In [11]:
# Hard class predictions for the held-out split (used for F1 and accuracy).
predictions = lr_clf.predict(test_features)
# predict_proba columns follow lr_clf.classes_; with labels 0/1, column 1 is the
# probability of class 1 (used for ROC AUC).
probs = lr_clf.predict_proba(test_features)[:, 1]

In [12]:
# Results table: first row is the header, metric rows are appended after it.
result_table = [["F1", "Accuracy", "AUC"]]

In [13]:
# Append one row of test-set metrics, rounded to 3 decimals, in header order:
# F1 and accuracy use the hard predictions, AUC uses the class-1 probabilities.
# Re-running this cell appends another row, since it mutates the table in place.
result_table.append([
    round(metric(test_labels, scores), 3)
    for metric, scores in (
        (f1_score, predictions),
        (accuracy_score, predictions),
        (roc_auc_score, probs),
    )
])

In [14]:
# Show the raw table (header row followed by the metric rows).
result_table

[['F1', 'Accuracy', 'AUC'], [0.955, 0.97, 0.995]]

In [15]:
# Render the results as a text grid; the first row of result_table supplies the
# column headers. `tabulate` is imported in the notebook's top import cell.
print(tabulate(result_table, headers="firstrow", tablefmt="grid"))

+-------+------------+-------+
|    F1 |   Accuracy |   AUC |
| 0.955 |       0.97 | 0.995 |
+-------+------------+-------+
