In [51]:
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import classification_report
from sklearn.svm import SVC

import pickle
import numpy as np

import sys
sys.path.append('..')

In [52]:
def get_data(path='../data/all_features.pkl', scale=True):
    """Load the pickled features and labels, optionally min-max scaled.

    Parameters
    ----------
    path : str or path-like, default '../data/all_features.pkl'
        Pickle file whose payload is indexable, with the feature rows at
        index 0 and the labels at index 1.
    scale : bool, default True
        When True, the features are passed through ``MinMaxScaler`` fitted
        on *all* rows. Pass False to receive the features unscaled, e.g. to
        fit a scaler on a training split only.

    Returns
    -------
    x : numpy.ndarray
        Feature array built from ``data[0]`` (scaled when ``scale`` is True).
    y : numpy.ndarray
        Label array built from ``data[1]``.
    """
    # SECURITY: pickle.load can execute arbitrary code embedded in the file;
    # only point this at pickles produced by a trusted source.
    with open(path, 'rb') as f:
        data = pickle.load(f)

    x = np.array(data[0])
    y = np.array(data[1])

    if scale:
        # The scaler is fitted on every row, so the minimum/maximum of any
        # later held-out split is baked into the returned features.
        x = MinMaxScaler().fit_transform(x)

    return x, y

def test_accuracy(model):
    """Fit ``model`` on an 80/20 hold-out split and print its classification report.

    The data is reloaded through ``get_data()`` on every call and split with
    ``random_state=42``, so the split is reproducible between calls.

    Parameters
    ----------
    model : estimator
        Object exposing ``fit(X, y)`` and ``predict(X)``. It is fitted in
        place on the training portion.

    Returns
    -------
    None
        The report is printed to stdout.
    """
    x, y = get_data()
    x_train, x_test, y_train, y_test = train_test_split(x, y, random_state=42, test_size=0.2)

    # get_data() min-max scales using the minimum/maximum of *all* rows, so
    # the held-out rows would otherwise influence the training features.
    # Re-fit the scaler on the training rows only and apply it to both parts.
    # For every feature that varies within the training rows, min-max scaling
    # is unaffected by the earlier full-data rescaling, so this matches
    # scaling the raw training split directly.
    scaler = MinMaxScaler().fit(x_train)
    x_train = scaler.transform(x_train)
    x_test = scaler.transform(x_test)

    model.fit(x_train, y_train)
    results = model.predict(x_test)

    print(classification_report(y_test, results))

# Load the scaled features and labels once at module level; the grid search
# further down fits on these full arrays.
x, y = get_data()

## Logistic Regression

In [53]:
# Baseline logistic regression, scored on the hold-out split used by
# test_accuracy.
lr = LogisticRegression(
    solver='newton-cg',
    random_state=42,
    n_jobs=-1,
)
test_accuracy(lr)

              precision    recall  f1-score   support

         ang       0.55      0.47      0.51       205
         fru       0.47      0.50      0.48       380
         hap       0.49      0.47      0.48       314
         neu       0.53      0.50      0.51       362
         sad       0.61      0.73      0.67       215

    accuracy                           0.52      1476
   macro avg       0.53      0.53      0.53      1476
weighted avg       0.52      0.52      0.52      1476



In [None]:
# 3-fold grid search over the logistic-regression hyper-parameters.
# C only has an effect when a penalty is applied, so the unpenalised models
# get their own sub-grid without C instead of being refit once per C value.
lr_g = GridSearchCV(
    LogisticRegression(solver='newton-cg', random_state=42, n_jobs=-1),
    param_grid=[
        {
            'penalty': ['l2'],
            'C': np.logspace(-4, 4, 7),
            'multi_class': ['auto', 'ovr', 'multinomial'],
            'class_weight': ['balanced', None],
        },
        {
            'penalty': ['none'],
            'multi_class': ['auto', 'ovr', 'multinomial'],
            'class_weight': ['balanced', None],
        },
    ],
    cv=3,
    return_train_score=False,
    verbose=10,
    n_jobs=-1,
)

lr_g.fit(x, y)


In [None]:
# Show the best score found by the search, then the parameters that
# produced it, one per line.
print(lr_g.best_score_, lr_g.best_params_, sep='\n')

## SVM

In [None]:
# Linear-kernel support vector classifier, scored on the same hold-out split.
# NOTE(review): test_accuracy only calls fit/predict, so probability=True
# looks unused in this cell - confirm whether later code needs predict_proba.
svm = SVC(
    kernel='linear',
    probability=True,
    random_state=42,
)
test_accuracy(svm)

## Random Forest

In [None]:
# Random forest with library-default hyper-parameters apart from the fixed
# seed and n_jobs=-1.
rf = RandomForestClassifier(
    random_state=42,
    n_jobs=-1,
)
test_accuracy(rf)