In [23]:
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV, train_test_split, RandomizedSearchCV, ParameterGrid
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import classification_report
from sklearn.svm import SVC
from sklearn.model_selection import cross_validate, cross_val_score, StratifiedKFold

import pickle
import numpy as np

import sys
sys.path.append('..')

from process_dataset import speech_features

In [24]:
def print_scores(scores):
    """Print the mean of every metric collected during cross-validation.

    ``scores`` is indexed with the keys 'test_accuracy', 'test_f1_macro',
    'test_precision_macro' and 'test_recall_macro'; the values under each key
    are averaged with ``np.mean`` and written on their own line.
    """
    # (label shown to the reader, key looked up in ``scores``), in print order.
    report_rows = (
        ('Accuracy: ', 'test_accuracy'),
        ('F1 Macro: ', 'test_f1_macro'),
        ('Precision Macro: ', 'test_precision_macro'),
        ('Recall Macro: ', 'test_recall_macro'),
    )
    for label, key in report_rows:
        print(label, np.mean(scores[key]))

def get_data(path='../data/speech_features.pkl'):
    """Load the pickled features/labels and return min-max scaled features.

    Args:
        path: Location of the pickle file. The default is the path this
            notebook has always read, so existing ``get_data()`` calls behave
            exactly as before.

    Returns:
        A tuple ``(x, y)`` of NumPy arrays. ``x`` is element 0 of the pickled
        object after ``MinMaxScaler().fit_transform``; ``y`` is element 1,
        converted with ``np.array`` and otherwise untouched.
    """
    # SECURITY: pickle.load can execute arbitrary code while deserialising;
    # only point this at files produced by a trusted pipeline.
    with open(path, 'rb') as f:
        data = pickle.load(f)

    x = np.array(data[0])
    y = np.array(data[1])

    # NOTE(review): the scaler is fit on every row before any train/test or
    # cross-validation split happens, so held-out rows contribute to the
    # min/max used for scaling. Consider fitting the scaler inside the split
    # (e.g. via a Pipeline) if leakage-free estimates are needed.
    x = MinMaxScaler().fit_transform(x)

    return x, y

# Module-level copy of the scaled features and labels; the randomized and grid
# search cells further down fit directly on these two names.
x, y = get_data()

def cross_validate_model(model):
    """Score ``model`` with 5-fold cross-validation and print the mean metrics.

    The dataset is reloaded through ``get_data()`` on each call, the model is
    evaluated on accuracy plus macro-averaged F1, precision and recall, and
    the per-fold results are handed to ``print_scores``. Nothing is returned.
    """
    # Each scorer is registered under its own name, which is what
    # ``print_scores`` later looks up as 'test_<name>'.
    metric_names = ('accuracy', 'f1_macro', 'precision_macro', 'recall_macro')
    scoring = {name: name for name in metric_names}

    features, labels = get_data()
    fold_results = cross_validate(
        model, features, labels, cv=5, scoring=scoring, n_jobs=-1
    )
    print_scores(fold_results)

def check_accuracy(model):
    """Fit ``model`` on a hold-out split and print its classification report.

    The data from ``get_data()`` is divided 80/20 with a fixed
    ``random_state`` of 42; the model is trained on the larger part and the
    report is computed on the predictions for the smaller part. Nothing is
    returned, and ``model`` is left fitted on the training part.
    """
    features, labels = get_data()
    train_x, test_x, train_y, test_y = train_test_split(
        features, labels, test_size=0.2, random_state=42
    )

    model.fit(train_x, train_y)
    predictions = model.predict(test_x)

    print(classification_report(test_y, predictions))

## Logistic Regression

In [25]:
# Earlier configurations, kept for reference:
# lr = LogisticRegression(penalty='l1', solver='liblinear')
# lr = LogisticRegression(C=1, class_weight='balanced', random_state=42, solver='sag')

# Active configuration: the same settings the grid search output further down
# in this notebook reports as its best estimator.
lr = LogisticRegression(
    solver='liblinear',
    C=0.75,
    class_weight='balanced',
    random_state=42,
)
# cross_validate_model(lr)
check_accuracy(lr)

              precision    recall  f1-score   support

         ang       0.63      0.71      0.67       208
         hap       0.64      0.58      0.61       317
         neu       0.72      0.59      0.65       369
         sad       0.59      0.79      0.68       213

    accuracy                           0.65      1107
   macro avg       0.65      0.67      0.65      1107
weighted avg       0.66      0.65      0.65      1107



### Random Search

In [8]:
# Candidate values sampled by the randomized search in the next cell.
# NOTE(review): some solver/penalty pairings listed here look unsupported by
# LogisticRegression (e.g. liblinear + elasticnet) — confirm how failed fits
# are scored before trusting the ranking.
params = dict(
    solver=['liblinear', 'saga', 'sag', 'newton-cg'],
    penalty=['l1', 'l2', 'elasticnet', 'none'],
    C=[0.001, 0.01, 0.1, 1, 10, 30, 50],
    fit_intercept=[True, False],
    class_weight=['balanced', None],
    multi_class=['auto', 'ovr', 'multinomial'],
)

In [None]:
# Draw 50 configurations from `params` and score each with 5-fold CV on the
# module-level x, y; both the estimator and the sampler are seeded with 42.
lr_g = RandomizedSearchCV(
    estimator=LogisticRegression(random_state=42),
    param_distributions=params,
    n_iter=50,
    cv=5,
    n_jobs=-1,
    random_state=42,
    verbose=5,
)

lr_g.fit(x, y)

In [11]:
# Report the winning parameters, their mean CV score and the refitted
# estimator, one per line.
print(
    lr_g.best_params_,
    lr_g.best_score_,
    lr_g.best_estimator_,
    sep='\n',
)

{'solver': 'sag', 'penalty': 'l2', 'multi_class': 'auto', 'fit_intercept': True, 'class_weight': 'balanced', 'C': 1}
0.627367516592586
LogisticRegression(C=1, class_weight='balanced', random_state=42, solver='sag')


### Grid Search

In [20]:
# Narrower grid around the configuration the randomized search favoured.
params = dict(
    penalty=['l2', 'none', 'l1'],
    solver=['sag', 'liblinear'],
    C=[0.75, 1, 3, 5],
)

# Settings that are held fixed go on the base estimator; only the keys in
# `params` are varied by the search.
lr_g = GridSearchCV(
    estimator=LogisticRegression(
        class_weight='balanced',
        fit_intercept=True,
        multi_class='auto',
        random_state=42,
    ),
    param_grid=params,
    cv=5,
    n_jobs=-1,
    verbose=5,
    return_train_score=False,
)

# Size of the full grid, printed so the cost of the fit is known up front.
pg = ParameterGrid(params)
print(len(pg), 'combinations per fold')


24 combinations per fold


In [None]:
# Run the grid search configured in the previous cell on the module-level
# x, y (the scaled features and labels returned by get_data()).
lr_g.fit(x, y)


In [22]:
# Report the best mean CV score, the parameters that produced it and the
# refitted estimator, one per line.
print(
    lr_g.best_score_,
    lr_g.best_params_,
    lr_g.best_estimator_,
    sep='\n',
)

0.6351429584217482
{'C': 0.75, 'penalty': 'l2', 'solver': 'liblinear'}
LogisticRegression(C=0.75, class_weight='balanced', random_state=42,
                   solver='liblinear')


## SVM

In [None]:
# Linear-kernel SVM, evaluated on the same fixed hold-out split as the
# logistic regression above.
svm = SVC(kernel='linear', probability=True, random_state=42)
# The hold-out evaluation helper in this notebook is `check_accuracy`; the
# previous call to `test_accuracy` referred to a name that is never defined
# and raised NameError on a fresh kernel.
check_accuracy(svm)

## Random Forest

In [None]:
# Random forest with default hyper-parameters (seeded for repeatability),
# evaluated on the same fixed hold-out split as the other models.
rf = RandomForestClassifier(random_state=42)
# The hold-out evaluation helper in this notebook is `check_accuracy`; the
# previous call to `test_accuracy` referred to a name that is never defined
# and raised NameError on a fresh kernel.
check_accuracy(rf)