In [1]:
%load_ext autoreload

%autoreload 2
%matplotlib inline
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.cluster import KMeans
from sklearn.svm import SVC
from sklearn.decomposition import PCA
from sklearn import metrics

# from mlxtend.plotting import plot_decision_regions
from sklearn import preprocessing
from sklearn.linear_model import LogisticRegression
from ast import literal_eval

import warnings
import numpy as np
from collections import OrderedDict

from lob_data_utils import lob, db_result, model, roc_results
from lob_data_utils.svm_calculation import lob_svm
import os


sns.set_style('whitegrid')
warnings.filterwarnings('ignore')

In [2]:
# Experiment configuration shared by the cells below.
data_length = 10000  # passed as `length=` to lob.load_prepared_data and embedded in result file names
r = 1.0  # `r` value embedded in the SVM result file names
s = 1.0  # `s` value embedded in the SVM result file names
# Stock identifiers are the keys of roc_results.result_cv_10000.
stocks = [*roc_results.result_cv_10000.keys()]

In [3]:
def get_mean_scores(scores: dict) -> dict:
    """Collapse every entry of ``scores`` to the mean of its values.

    :param scores: mapping from a score name to a sequence of numbers
    :return: new dict with the same keys (same order), each mapped to
        ``np.mean`` of the corresponding values
    """
    return {name: np.mean(values) for name, values in scores.items()}

def get_score_for_clf(clf, df_test):
    """Score ``clf`` on ``df_test`` through ``model.test_model``.

    The feature matrix is the single column ``queue_imbalance`` (selected with a
    list so it stays a DataFrame); the labels are the raw values of
    ``mid_price_indicator``.

    :param clf: classifier handed to ``model.test_model``
    :param df_test: frame with ``queue_imbalance`` and ``mid_price_indicator`` columns
    :return: whatever ``model.test_model`` returns
    """
    return model.test_model(
        clf,
        df_test[['queue_imbalance']],
        df_test['mid_price_indicator'].values,
    )

def get_logistic_regression(stock, data_length):
    """Build the logistic-regression baseline result row for one stock.

    Loads the prepared train/test frames for ``stock``, runs
    ``model.validate_model`` on the ``queue_imbalance`` feature, averages the
    returned scores and combines them with the test scores.

    :param stock: stock identifier passed to ``lob.load_prepared_data``
    :param data_length: forwarded as ``length=`` to ``lob.load_prepared_data``
    :return: dict with the mean validation scores, ``'stock'``, ``'kernel'``
        (always ``'logistic'``) and the entries returned by ``get_score_for_clf``
    """
    train_df, test_df = lob.load_prepared_data(
        stock, data_dir='../gaussian_filter/data', cv=False, length=data_length)
    classifier = LogisticRegression()
    train_features = train_df[['queue_imbalance']]

    validation_scores = model.validate_model(
        classifier, train_features, train_df['mid_price_indicator'])

    # Later updates override earlier keys, mirroring the original dict merges.
    summary = dict(get_mean_scores(validation_scores))
    summary.update({'stock': stock, 'kernel': 'logistic'})
    # NOTE(review): the same classifier object is scored on the test frame;
    # this presumes validate_model leaves it fitted -- confirm.
    summary.update(get_score_for_clf(classifier, test_df))
    return summary

In [4]:
# Load the per-stock SVM result files for the current (data_length, r, s) setting.
# Stocks that have no result file on disk are skipped.
svm_result_frames = []
for stock in stocks:
    #pd.read_csv('svm_features_{}_len{}_r{}_s{}.csv'.format(stock, data_length, r, s))
    filename = 'svm_pca_gdf_{}_len{}_r{}_s{}.csv'.format(stock, data_length, r, s)
    if os.path.exists(filename):
        svm_result_frames.append(pd.read_csv(filename))
# DataFrame.append was removed in pandas 2.0 and re-copied the frame on every
# iteration; concatenate once instead. Row labels of each file are kept, as before.
# With no files at all we fall back to an empty frame (pd.concat([]) would raise).
df_res = pd.concat(svm_result_frames) if svm_result_frames else pd.DataFrame()
#df_res.drop(columns=['Unnamed: 0'], inplace=True)
columns = ['C', 'f1', 'features', 'gamma', 'kappa',
           'matthews', 'roc_auc', 'stock',
           'test_f1', 'test_kappa', 'test_matthews', 'test_roc_auc']
# Show, for every stock, the row with the highest `matthews` value.
df_res[columns].sort_values(by='matthews', ascending=False).groupby('stock').head(1)

Unnamed: 0,C,f1,features,gamma,kappa,matthews,roc_auc,stock,test_f1,test_kappa,test_matthews,test_roc_auc
35,100.0,0.604039,gdf_0-50_que,0.001,0.188148,0.189623,0.593941,11946,0.631235,0.205205,0.205644,0.602332
28,10.0,0.593088,pca_gdf_que4,0.001,0.180103,0.181585,0.590329,3879,0.566337,0.123816,0.123915,0.561932
38,100.0,0.600282,pca_gdf_que1,1.0,0.169434,0.171004,0.584663,3035,0.575083,0.100511,0.100688,0.550177
24,1.0,0.626286,gdf_24-26_que_prev,1.0,0.163578,0.165601,0.581485,4320,0.630831,0.160028,0.161926,0.579402
17,0.1,0.601108,pca_gdf_que4,1.0,0.157448,0.164031,0.578763,1956,0.601738,0.129149,0.131675,0.564618
35,100.0,0.569319,pca_gdf_que_prev9,0.001,0.162968,0.16401,0.581638,7858,0.532731,0.164913,0.166582,0.581953
30,10.0,0.587315,pca_gdf_que_prev5,0.1,0.157832,0.159166,0.579142,9761,0.603878,0.139908,0.141347,0.56983
28,10.0,0.570876,pca_gdf_que2,0.001,0.153234,0.158249,0.576935,12417,0.594021,0.1585,0.159298,0.579349
36,100.0,0.625132,gdf_20_30_que,0.01,0.149926,0.157989,0.575127,13061,0.642194,0.091467,0.098884,0.544805
23,1.0,0.584832,gdf_20_30_que,0.1,0.149201,0.155399,0.574786,12255,0.643533,0.094336,0.111492,0.547105


In [5]:
# Logistic-regression baseline: one result dict per stock, gathered into a frame.
log_res = [get_logistic_regression(stock, data_length) for stock in stocks]
df_log_res = pd.DataFrame(log_res)
# `np.int` was only an alias of the builtin `int` and was removed in NumPy 1.24,
# so cast with `int` directly. The integer stock id is used both as a column
# (merge key) and as the row index.
log_stock_ids = df_log_res['stock'].values.astype(int)
df_log_res['stock'] = log_stock_ids
df_log_res.index = log_stock_ids

In [6]:
# Keep, for every stock, the SVM row with the highest `test_matthews`.
# NOTE(review): this picks the configuration using the test metric, whereas the
# overview cell above ranks by `matthews`; confirm this is intended.
# `.copy()` makes the result an independent frame so the assignments below do not
# write into a slice of df_res.
df_gdf_best = (
    df_res[columns]
    .sort_values(by='test_matthews', ascending=False)
    .groupby('stock')
    .head(1)
    .copy()
)
# `np.int` was removed in NumPy 1.24; the builtin `int` is what it aliased.
gdf_stock_ids = df_gdf_best['stock'].values.astype(int)
df_gdf_best['stock'] = gdf_stock_ids
df_gdf_best.index = gdf_stock_ids

In [7]:
# Inner-join the selected SVM rows with the logistic results on `stock`;
# columns present in both frames get the "_svm" (left) / "_log" (right) suffix.
df_all = df_gdf_best.merge(
    df_log_res,
    how='inner',
    on='stock',
    suffixes=['_svm', '_log'],
)

In [8]:
# Column order for the side-by-side SVM ("_svm") vs. logistic ("_log") view.
all_columns = [
    'features',
    'matthews_svm', 'matthews_log', 'test_matthews_svm', 'test_matthews_log',
    'roc_auc_svm', 'roc_auc_log', 'test_roc_auc_svm', 'test_roc_auc_log',
    'stock',
    'f1_svm', 'f1_log', 'test_f1_svm', 'test_f1_log',
]
df_all.loc[:, all_columns]

Unnamed: 0,features,matthews_svm,matthews_log,test_matthews_svm,test_matthews_log,roc_auc_svm,roc_auc_log,test_roc_auc_svm,test_roc_auc_log,stock,f1_svm,f1_log,test_f1_svm,test_f1_log
0,gdf_0-50_que,0.189623,0.186824,0.205644,0.203627,0.593941,0.592373,0.602332,0.601087,11946,0.604039,0.595737,0.631235,0.634056
1,gdf_24-26_que_prev,0.159977,0.156657,0.175396,0.163789,0.574495,0.577752,0.586201,0.581178,4320,0.57292,0.610483,0.634821,0.620721
2,gdf_23-27_que,0.132019,0.142499,0.174273,0.162155,0.562168,0.570582,0.584643,0.580693,10508,0.617642,0.582458,0.622403,0.595089
3,pca_gdf_que_prev9,0.134107,0.146384,0.173381,0.171235,0.567367,0.572852,0.583913,0.584512,7858,0.491723,0.551977,0.517343,0.539665
4,que,0.141919,0.137272,0.171835,0.168692,0.57081,0.567489,0.585412,0.583011,3161,0.582392,0.579932,0.610824,0.621412
5,pca_gdf_que1,0.113993,0.125634,0.166685,0.162023,0.556786,0.56234,0.583233,0.58089,2651,0.501255,0.553881,0.599229,0.59779
6,gdf_0-50_que_prev,0.12822,0.129177,0.164562,0.150086,0.562229,0.562399,0.582281,0.573388,3022,0.606264,0.609023,0.61467,0.635929
7,pca_gdf_que_prev7,0.125957,0.130301,0.162579,0.146515,0.560368,0.564461,0.5808,0.573201,1113,0.49068,0.534083,0.561001,0.591563
8,pca_gdf_que1,0.046888,0.138016,0.160484,0.138375,0.521126,0.568685,0.577134,0.568627,1431,0.656601,0.585431,0.627641,0.593882
9,pca_gdf_que2,0.155395,0.150172,0.160438,0.154446,0.575317,0.573432,0.579881,0.576707,12417,0.573145,0.591476,0.595479,0.596958


In [9]:
# (stocks where the SVM `matthews` exceeds the logistic one, total stocks)
len(df_all.loc[df_all['matthews_svm'] > df_all['matthews_log'], all_columns]), len(df_all)

(18, 53)

In [10]:
# (stocks where the SVM `roc_auc` exceeds the logistic one, total stocks)
len(df_all.loc[df_all['roc_auc_svm'] > df_all['roc_auc_log'], all_columns]), len(df_all)

(15, 53)

In [11]:
# Stocks where the SVM scores below logistic regression on `test_matthews`.
df_all.loc[df_all['test_matthews_svm'] < df_all['test_matthews_log'], all_columns]

Unnamed: 0,features,matthews_svm,matthews_log,test_matthews_svm,test_matthews_log,roc_auc_svm,roc_auc_log,test_roc_auc_svm,test_roc_auc_log,stock,f1_svm,f1_log,test_f1_svm,test_f1_log
12,pca_gdf_que_prev10,0.148575,0.145601,0.150493,0.153016,0.572621,0.572127,0.575244,0.576301,11869,0.591002,0.582713,0.579495,0.561531
27,pca_gdf_que_prev2,0.123121,0.130253,0.127163,0.13486,0.559384,0.564948,0.563468,0.567422,9086,0.522086,0.554918,0.539195,0.567149
49,pca_gdf_que2,0.092872,0.132624,0.085228,0.086034,0.544369,0.563571,0.536509,0.53853,13003,0.325912,0.479875,0.377632,0.409548
50,gdf_24-26_que_prev,0.059728,0.111922,0.08167,0.086004,0.529633,0.554976,0.540834,0.543004,9063,0.503067,0.525106,0.54008,0.545274
51,gdf_20_30_que,0.093081,0.117283,0.072547,0.074043,0.545689,0.557431,0.536262,0.537,9058,0.475583,0.53941,0.530157,0.528993


In [12]:
# Stocks where the SVM scores below logistic regression on `test_roc_auc`.
df_all.loc[df_all['test_roc_auc_svm'] < df_all['test_roc_auc_log'], all_columns]

Unnamed: 0,features,matthews_svm,matthews_log,test_matthews_svm,test_matthews_log,roc_auc_svm,roc_auc_log,test_roc_auc_svm,test_roc_auc_log,stock,f1_svm,f1_log,test_f1_svm,test_f1_log
3,pca_gdf_que_prev9,0.134107,0.146384,0.173381,0.171235,0.567367,0.572852,0.583913,0.584512,7858,0.491723,0.551977,0.517343,0.539665
12,pca_gdf_que_prev10,0.148575,0.145601,0.150493,0.153016,0.572621,0.572127,0.575244,0.576301,11869,0.591002,0.582713,0.579495,0.561531
15,pca_gdf_que_prev5,0.093566,0.12563,0.145583,0.143575,0.544434,0.561209,0.569429,0.570249,1907,0.534906,0.599786,0.646465,0.63052
27,pca_gdf_que_prev2,0.123121,0.130253,0.127163,0.13486,0.559384,0.564948,0.563468,0.567422,9086,0.522086,0.554918,0.539195,0.567149
40,pca_gdf_que10,0.104658,0.113968,0.106215,0.106202,0.551626,0.556876,0.553005,0.553061,9269,0.521679,0.546408,0.542199,0.567554
41,gdf_20_30_que_prev,0.021517,0.132153,0.105638,0.094105,0.510059,0.565317,0.539038,0.547029,12456,0.666336,0.596622,0.664442,0.561743
49,pca_gdf_que2,0.092872,0.132624,0.085228,0.086034,0.544369,0.563571,0.536509,0.53853,13003,0.325912,0.479875,0.377632,0.409548
50,gdf_24-26_que_prev,0.059728,0.111922,0.08167,0.086004,0.529633,0.554976,0.540834,0.543004,9063,0.503067,0.525106,0.54008,0.545274
51,gdf_20_30_que,0.093081,0.117283,0.072547,0.074043,0.545689,0.557431,0.536262,0.537,9058,0.475583,0.53941,0.530157,0.528993
