In [69]:
import numpy as np
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
from sklearn import svm
from sklearn import tree
from sklearn.model_selection import StratifiedKFold, train_test_split, GridSearchCV, cross_val_score
from sklearn.preprocessing import StandardScaler, LabelEncoder, OneHotEncoder, MinMaxScaler
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
from sklearn import linear_model
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.neighbors import KNeighborsClassifier
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from sklearn import tree

In [52]:
#dataset 1 - 3 algorithms - 5 trials
adult_data = pd.read_csv('adult.data', header=None)
adult_data.columns = ['age', 'workclass', 'fnlwgt', 'education', 'education-num', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'capital-gain', 'capital-loss', 'hours-per-week', 'native-country', 'salary']
adult_data = adult_data.dropna()
adult_data[:10]

Unnamed: 0,age,workclass,fnlwgt,education,education-num,marital-status,occupation,relationship,race,sex,capital-gain,capital-loss,hours-per-week,native-country,salary
0,39,State-gov,77516,Bachelors,13,Never-married,Adm-clerical,Not-in-family,White,Male,2174,0,40,United-States,<=50K
1,50,Self-emp-not-inc,83311,Bachelors,13,Married-civ-spouse,Exec-managerial,Husband,White,Male,0,0,13,United-States,<=50K
2,38,Private,215646,HS-grad,9,Divorced,Handlers-cleaners,Not-in-family,White,Male,0,0,40,United-States,<=50K
3,53,Private,234721,11th,7,Married-civ-spouse,Handlers-cleaners,Husband,Black,Male,0,0,40,United-States,<=50K
4,28,Private,338409,Bachelors,13,Married-civ-spouse,Prof-specialty,Wife,Black,Female,0,0,40,Cuba,<=50K
5,37,Private,284582,Masters,14,Married-civ-spouse,Exec-managerial,Wife,White,Female,0,0,40,United-States,<=50K
6,49,Private,160187,9th,5,Married-spouse-absent,Other-service,Not-in-family,Black,Female,0,0,16,Jamaica,<=50K
7,52,Self-emp-not-inc,209642,HS-grad,9,Married-civ-spouse,Exec-managerial,Husband,White,Male,0,0,45,United-States,>50K
8,31,Private,45781,Masters,14,Never-married,Prof-specialty,Not-in-family,White,Female,14084,0,50,United-States,>50K
9,42,Private,159449,Bachelors,13,Married-civ-spouse,Exec-managerial,Husband,White,Male,5178,0,40,United-States,>50K


In [3]:
def summerize_data(df):
    for column in df.columns:
        print(column)
        if df.dtypes[column] == np.object:
            print(df[column].value_counts())
        else:
            print(df[column].describe())
        print('\n')

In [53]:
summerize_data(adult_data)

age
count    32561.000000
mean        38.581647
std         13.640433
min         17.000000
25%         28.000000
50%         37.000000
75%         48.000000
max         90.000000
Name: age, dtype: float64


workclass
 Private             22696
 Self-emp-not-inc     2541
 Local-gov            2093
 ?                    1836
 State-gov            1298
 Self-emp-inc         1116
 Federal-gov           960
 Without-pay            14
 Never-worked            7
Name: workclass, dtype: int64


fnlwgt
count    3.256100e+04
mean     1.897784e+05
std      1.055500e+05
min      1.228500e+04
25%      1.178270e+05
50%      1.783560e+05
75%      2.370510e+05
max      1.484705e+06
Name: fnlwgt, dtype: float64


education
 HS-grad         10501
 Some-college     7291
 Bachelors        5355
 Masters          1723
 Assoc-voc        1382
 11th             1175
 Assoc-acdm       1067
 10th              933
 7th-8th           646
 Prof-school       576
 9th               514
 12th              433
 Doctor

In [5]:
def encode(df):
    result = df.copy()
    encoders = {}
    for column in result.columns:
        if result.dtypes[column] == np.object:
            encoders[column] = LabelEncoder()
            result[column] = encoders[column].fit_transform(result[column])
    return result, encoders

In [54]:
adult_encoded, encoders = encode(adult_data)
adult_salary = adult_encoded['salary']

In [64]:
numeric_subset = adult_data.select_dtypes('number')
categorical_subset = adult_data.select_dtypes('object')
categorical_subset = pd.get_dummies(categorical_subset[categorical_subset.columns.drop('salary')])
categorical_subset[:10]

Unnamed: 0,workclass_ ?,workclass_ Federal-gov,workclass_ Local-gov,workclass_ Never-worked,workclass_ Private,workclass_ Self-emp-inc,workclass_ Self-emp-not-inc,workclass_ State-gov,workclass_ Without-pay,education_ 10th,...,native-country_ Portugal,native-country_ Puerto-Rico,native-country_ Scotland,native-country_ South,native-country_ Taiwan,native-country_ Thailand,native-country_ Trinadad&Tobago,native-country_ United-States,native-country_ Vietnam,native-country_ Yugoslavia
0,0,0,0,0,0,0,0,1,0,0,...,0,0,0,0,0,0,0,1,0,0
1,0,0,0,0,0,0,1,0,0,0,...,0,0,0,0,0,0,0,1,0,0
2,0,0,0,0,1,0,0,0,0,0,...,0,0,0,0,0,0,0,1,0,0
3,0,0,0,0,1,0,0,0,0,0,...,0,0,0,0,0,0,0,1,0,0
4,0,0,0,0,1,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
5,0,0,0,0,1,0,0,0,0,0,...,0,0,0,0,0,0,0,1,0,0
6,0,0,0,0,1,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
7,0,0,0,0,0,0,1,0,0,0,...,0,0,0,0,0,0,0,1,0,0
8,0,0,0,0,1,0,0,0,0,0,...,0,0,0,0,0,0,0,1,0,0
9,0,0,0,0,1,0,0,0,0,0,...,0,0,0,0,0,0,0,1,0,0


In [60]:
adult_onehot = pd.concat([numeric_subset, categorical_subset], axis = 1)
adult_onehot[:10]

Unnamed: 0,age,fnlwgt,education-num,capital-gain,capital-loss,hours-per-week,workclass_ ?,workclass_ Federal-gov,workclass_ Local-gov,workclass_ Never-worked,...,native-country_ Portugal,native-country_ Puerto-Rico,native-country_ Scotland,native-country_ South,native-country_ Taiwan,native-country_ Thailand,native-country_ Trinadad&Tobago,native-country_ United-States,native-country_ Vietnam,native-country_ Yugoslavia
0,39,77516,13,2174,0,40,0,0,0,0,...,0,0,0,0,0,0,0,1,0,0
1,50,83311,13,0,0,13,0,0,0,0,...,0,0,0,0,0,0,0,1,0,0
2,38,215646,9,0,0,40,0,0,0,0,...,0,0,0,0,0,0,0,1,0,0
3,53,234721,7,0,0,40,0,0,0,0,...,0,0,0,0,0,0,0,1,0,0
4,28,338409,13,0,0,40,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
5,37,284582,14,0,0,40,0,0,0,0,...,0,0,0,0,0,0,0,1,0,0
6,49,160187,5,0,0,16,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
7,52,209642,9,0,0,45,0,0,0,0,...,0,0,0,0,0,0,0,1,0,0
8,31,45781,14,14084,0,50,0,0,0,0,...,0,0,0,0,0,0,0,1,0,0
9,42,159449,13,5178,0,40,0,0,0,0,...,0,0,0,0,0,0,0,1,0,0


In [56]:
adult_onehot = adult_onehot.replace({np.inf: np.nan, -np.inf: np.nan})
adult_onehot = adult_onehot.dropna()
adult_onehot[:10]

Unnamed: 0,age,fnlwgt,education-num,capital-gain,capital-loss,hours-per-week,workclass_ ?,workclass_ Federal-gov,workclass_ Local-gov,workclass_ Never-worked,...,native-country_ Portugal,native-country_ Puerto-Rico,native-country_ Scotland,native-country_ South,native-country_ Taiwan,native-country_ Thailand,native-country_ Trinadad&Tobago,native-country_ United-States,native-country_ Vietnam,native-country_ Yugoslavia
0,39,77516,13,2174,0,40,0,0,0,0,...,0,0,0,0,0,0,0,1,0,0
1,50,83311,13,0,0,13,0,0,0,0,...,0,0,0,0,0,0,0,1,0,0
2,38,215646,9,0,0,40,0,0,0,0,...,0,0,0,0,0,0,0,1,0,0
3,53,234721,7,0,0,40,0,0,0,0,...,0,0,0,0,0,0,0,1,0,0
4,28,338409,13,0,0,40,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
5,37,284582,14,0,0,40,0,0,0,0,...,0,0,0,0,0,0,0,1,0,0
6,49,160187,5,0,0,16,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
7,52,209642,9,0,0,45,0,0,0,0,...,0,0,0,0,0,0,0,1,0,0
8,31,45781,14,14084,0,50,0,0,0,0,...,0,0,0,0,0,0,0,1,0,0
9,42,159449,13,5178,0,40,0,0,0,0,...,0,0,0,0,0,0,0,1,0,0


In [101]:
#Random Forest
for i in range(5):
    print('trial', i+1)
    adult_rf_list = []
    
    for j in range(5):
        X_train, X_test, y_train, y_test = train_test_split(adult_onehot, adult_salary, train_size=5000)
        
        for k in range(5):
            adult_rf = RandomForestClassifier(criterion='entropy')
            param_grid = [{'n_estimators': [1024], 'max_features': [1, 2, 4, 6, 8, 12, 16, 20]}]
            #n_list = [1.0, 2.0 , 4.0 , 6.0, 8.0, 12.0, 16.0, 20.0]
            rf_grid = GridSearchCV(adult_rf, param_grid, cv=StratifiedKFold(n_splits=5), scoring=['accuracy', 'roc_auc_ovr', 'f1_micro'], refit=False,
                   verbose=0, return_train_score=True)
            rf_grid.fit(X_train, y_train)
            print('mean test accuracy:', rf_grid.cv_results_['mean_test_accuracy'])
            print('mean test roc auo ovr:', rf_grid.cv_results_['mean_test_roc_auc_ovr'])
            print('mean test f1 micro:', rf_grid.cv_results_['mean_test_f1_micro'])
            rf_results = pd.DataFrame(rf_grid.cv_results_['params'])
            rf_results['score_acc'] = rf_grid.cv_results_['mean_test_accuracy']
            adult_rf_list.append(rf_grid.cv_results_['mean_test_accuracy'])
            rf_cols = rf_results.columns.to_series().str.split('__').apply(lambda x: x[-1])
            rf_results.columns = rf_cols
            rf_results = rf_results.sort_values(by=['score_acc'], ascending=False, ignore_index=True)
            print(rf_results)

trial 1
mean test accuracy: [0.8292 0.8298 0.8356 0.8386 0.8404 0.8438 0.8444 0.8444]
mean test roc auo ovr: [0.87844948 0.88138782 0.8859192  0.8888955  0.89152628 0.89475086
 0.89713573 0.89834653]
mean test f1 micro: [0.8292 0.8298 0.8356 0.8386 0.8404 0.8438 0.8444 0.8444]
   max_features  n_estimators  score_acc
0            16          1024     0.8444
1            20          1024     0.8444
2            12          1024     0.8438
3             8          1024     0.8404
4             6          1024     0.8386
5             4          1024     0.8356
6             2          1024     0.8298
7             1          1024     0.8292
mean test accuracy: [0.8296 0.8304 0.8366 0.8384 0.842  0.845  0.8438 0.8456]
mean test roc auo ovr: [0.87783077 0.88122949 0.88555515 0.88890934 0.89154485 0.89518608
 0.89675142 0.89748326]
mean test f1 micro: [0.8296 0.8304 0.8366 0.8384 0.842  0.845  0.8438 0.8456]
   max_features  n_estimators  score_acc
0            20          1024     0.8456
1

mean test accuracy: [0.8426 0.8444 0.8496 0.8536 0.8564 0.8574 0.8586 0.8596]
mean test roc auo ovr: [0.89062263 0.89264212 0.89616435 0.89906635 0.9002688  0.90263308
 0.90381236 0.90381685]
mean test f1 micro: [0.8426 0.8444 0.8496 0.8536 0.8564 0.8574 0.8586 0.8596]
   max_features  n_estimators  score_acc
0            20          1024     0.8596
1            16          1024     0.8586
2            12          1024     0.8574
3             8          1024     0.8564
4             6          1024     0.8536
5             4          1024     0.8496
6             2          1024     0.8444
7             1          1024     0.8426
mean test accuracy: [0.8408 0.8432 0.8504 0.8544 0.8562 0.8584 0.8588 0.8594]
mean test roc auo ovr: [0.89070743 0.8929298  0.89650928 0.8990533  0.90070837 0.90249322
 0.90307294 0.90390584]
mean test f1 micro: [0.8408 0.8432 0.8504 0.8544 0.8562 0.8584 0.8588 0.8594]
   max_features  n_estimators  score_acc
0            20          1024     0.8594
1        

mean test accuracy: [0.8378 0.8426 0.8472 0.8498 0.8518 0.8536 0.8544 0.8564]
mean test roc auo ovr: [0.89055948 0.89306252 0.89744666 0.90017091 0.90240933 0.90554396
 0.9064235  0.90774043]
mean test f1 micro: [0.8378 0.8426 0.8472 0.8498 0.8518 0.8536 0.8544 0.8564]
   max_features  n_estimators  score_acc
0            20          1024     0.8564
1            16          1024     0.8544
2            12          1024     0.8536
3             8          1024     0.8518
4             6          1024     0.8498
5             4          1024     0.8472
6             2          1024     0.8426
7             1          1024     0.8378
mean test accuracy: [0.838  0.8402 0.8468 0.8478 0.8522 0.854  0.855  0.854 ]
mean test roc auo ovr: [0.8904484  0.89326567 0.8972167  0.90025113 0.90237322 0.9048899
 0.90638352 0.90763355]
mean test f1 micro: [0.838  0.8402 0.8468 0.8478 0.8522 0.854  0.855  0.854 ]
   max_features  n_estimators  score_acc
0            16          1024     0.8550
1         

mean test accuracy: [0.8344 0.8354 0.8388 0.8412 0.846  0.848  0.849  0.8494]
mean test roc auo ovr: [0.88378265 0.88653992 0.89048836 0.89330093 0.89526796 0.89858044
 0.89992756 0.900746  ]
mean test f1 micro: [0.8344 0.8354 0.8388 0.8412 0.846  0.848  0.849  0.8494]
   max_features  n_estimators  score_acc
0            20          1024     0.8494
1            16          1024     0.8490
2            12          1024     0.8480
3             8          1024     0.8460
4             6          1024     0.8412
5             4          1024     0.8388
6             2          1024     0.8354
7             1          1024     0.8344
mean test accuracy: [0.8388 0.8398 0.8426 0.8464 0.8504 0.8506 0.8548 0.8564]
mean test roc auo ovr: [0.89002357 0.89226206 0.89691667 0.90013158 0.90239638 0.90477577
 0.90658717 0.90760088]
mean test f1 micro: [0.8388 0.8398 0.8426 0.8464 0.8504 0.8506 0.8548 0.8564]
   max_features  n_estimators  score_acc
0            20          1024     0.8564
1        

mean test accuracy: [0.8398 0.84   0.8426 0.846  0.8464 0.8498 0.8516 0.8518]
mean test roc auo ovr: [0.88307883 0.88554506 0.89035937 0.89318122 0.8959085  0.89825555
 0.89984021 0.90020033]
mean test f1 micro: [0.8398 0.84   0.8426 0.846  0.8464 0.8498 0.8516 0.8518]
   max_features  n_estimators  score_acc
0            20          1024     0.8518
1            16          1024     0.8516
2            12          1024     0.8498
3             8          1024     0.8464
4             6          1024     0.8460
5             4          1024     0.8426
6             2          1024     0.8400
7             1          1024     0.8398
mean test accuracy: [0.8382 0.8416 0.8452 0.845  0.8488 0.8492 0.852  0.8516]
mean test roc auo ovr: [0.88288573 0.88648155 0.88982391 0.8935032  0.89548476 0.89849843
 0.89935048 0.90064232]
mean test f1 micro: [0.8382 0.8416 0.8452 0.845  0.8488 0.8492 0.852  0.8516]
   max_features  n_estimators  score_acc
0            16          1024     0.8520
1        

mean test accuracy: [0.8396 0.8426 0.845  0.8466 0.8482 0.8478 0.848  0.847 ]
mean test roc auo ovr: [0.89460408 0.89688793 0.90008302 0.90237152 0.90396199 0.90601191
 0.90661422 0.90760939]
mean test f1 micro: [0.8396 0.8426 0.845  0.8466 0.8482 0.8478 0.848  0.847 ]
   max_features  n_estimators  score_acc
0             8          1024     0.8482
1            16          1024     0.8480
2            12          1024     0.8478
3            20          1024     0.8470
4             6          1024     0.8466
5             4          1024     0.8450
6             2          1024     0.8426
7             1          1024     0.8396
mean test accuracy: [0.8418 0.8414 0.8402 0.8474 0.8468 0.8472 0.8464 0.846 ]
mean test roc auo ovr: [0.89430313 0.89638702 0.89974399 0.9020487  0.90369763 0.90553374
 0.9068133  0.90705607]
mean test f1 micro: [0.8418 0.8414 0.8402 0.8474 0.8468 0.8472 0.8464 0.846 ]
   max_features  n_estimators  score_acc
0             6          1024     0.8474
1        

mean test accuracy: [0.845  0.8456 0.8496 0.851  0.8524 0.8544 0.8538 0.8546]
mean test roc auo ovr: [0.89127829 0.89240863 0.89573512 0.89754232 0.89871588 0.90006454
 0.9002969  0.90100238]
mean test f1 micro: [0.845  0.8456 0.8496 0.851  0.8524 0.8544 0.8538 0.8546]
   max_features  n_estimators  score_acc
0            20          1024     0.8546
1            12          1024     0.8544
2            16          1024     0.8538
3             8          1024     0.8524
4             6          1024     0.8510
5             4          1024     0.8496
6             2          1024     0.8456
7             1          1024     0.8450
mean test accuracy: [0.844  0.847  0.8506 0.8516 0.8522 0.8538 0.8546 0.8542]
mean test roc auo ovr: [0.89099486 0.89270777 0.89543878 0.89802274 0.89892185 0.90052645
 0.90108208 0.90080033]
mean test f1 micro: [0.844  0.847  0.8506 0.8516 0.8522 0.8538 0.8546 0.8542]
   max_features  n_estimators  score_acc
0            16          1024     0.8546
1        

mean test accuracy: [0.8344 0.8376 0.842  0.8438 0.8474 0.8498 0.8512 0.8538]
mean test roc auo ovr: [0.89105298 0.89322992 0.89809679 0.90084286 0.90194709 0.90423334
 0.90528675 0.90583803]
mean test f1 micro: [0.8344 0.8376 0.842  0.8438 0.8474 0.8498 0.8512 0.8538]
   max_features  n_estimators  score_acc
0            20          1024     0.8538
1            16          1024     0.8512
2            12          1024     0.8498
3             8          1024     0.8474
4             6          1024     0.8438
5             4          1024     0.8420
6             2          1024     0.8376
7             1          1024     0.8344
mean test accuracy: [0.8336 0.8374 0.8408 0.8436 0.846  0.851  0.8518 0.8536]
mean test roc auo ovr: [0.89089717 0.89414603 0.89726103 0.90010545 0.90220312 0.90421468
 0.90502475 0.90598688]
mean test f1 micro: [0.8336 0.8374 0.8408 0.8436 0.846  0.851  0.8518 0.8536]
   max_features  n_estimators  score_acc
0            20          1024     0.8536
1        

mean test accuracy: [0.8332 0.8362 0.8402 0.8448 0.846  0.8498 0.8522 0.8524]
mean test roc auo ovr: [0.88718217 0.88991652 0.89454959 0.89800833 0.8998545  0.90319898
 0.90465694 0.90570763]
mean test f1 micro: [0.8332 0.8362 0.8402 0.8448 0.846  0.8498 0.8522 0.8524]
   max_features  n_estimators  score_acc
0            20          1024     0.8524
1            16          1024     0.8522
2            12          1024     0.8498
3             8          1024     0.8460
4             6          1024     0.8448
5             4          1024     0.8402
6             2          1024     0.8362
7             1          1024     0.8332
mean test accuracy: [0.844  0.846  0.8492 0.8512 0.8528 0.8564 0.8578 0.8584]
mean test roc auo ovr: [0.90011726 0.90211317 0.90531347 0.90737955 0.90936276 0.91140566
 0.91276741 0.91274795]
mean test f1 micro: [0.844  0.846  0.8492 0.8512 0.8528 0.8564 0.8578 0.8584]
   max_features  n_estimators  score_acc
0            20          1024     0.8584
1        

mean test accuracy: [0.8456 0.848  0.8534 0.8574 0.8588 0.8614 0.8596 0.8612]
mean test roc auo ovr: [0.89967529 0.90238865 0.90576962 0.90787937 0.91048905 0.91261544
 0.91335615 0.91377156]
mean test f1 micro: [0.8456 0.848  0.8534 0.8574 0.8588 0.8614 0.8596 0.8612]
   max_features  n_estimators  score_acc
0            12          1024     0.8614
1            20          1024     0.8612
2            16          1024     0.8596
3             8          1024     0.8588
4             6          1024     0.8574
5             4          1024     0.8534
6             2          1024     0.8480
7             1          1024     0.8456
mean test accuracy: [0.8468 0.8496 0.8546 0.8564 0.86   0.8618 0.86   0.8586]
mean test roc auo ovr: [0.90034383 0.90257566 0.90575108 0.90871929 0.91024809 0.91258058
 0.9132379  0.9139604 ]
mean test f1 micro: [0.8468 0.8496 0.8546 0.8564 0.86   0.8618 0.86   0.8586]
   max_features  n_estimators  score_acc
0            12          1024     0.8618
1        

In [102]:
adult_rf_list

[array([0.8362, 0.8364, 0.8412, 0.8434, 0.8474, 0.8492, 0.851 , 0.8508]),
 array([0.8348, 0.8358, 0.8406, 0.8442, 0.846 , 0.8484, 0.8512, 0.8532]),
 array([0.8334, 0.8364, 0.8412, 0.8434, 0.846 , 0.8494, 0.8518, 0.8516]),
 array([0.8328, 0.8362, 0.839 , 0.843 , 0.845 , 0.849 , 0.851 , 0.852 ]),
 array([0.8332, 0.8362, 0.8402, 0.8448, 0.846 , 0.8498, 0.8522, 0.8524]),
 array([0.844 , 0.846 , 0.8492, 0.8512, 0.8528, 0.8564, 0.8578, 0.8584]),
 array([0.8428, 0.8462, 0.848 , 0.8526, 0.854 , 0.8548, 0.8586, 0.857 ]),
 array([0.8444, 0.8452, 0.8476, 0.8514, 0.8524, 0.8566, 0.8584, 0.8574]),
 array([0.8426, 0.847 , 0.8478, 0.8528, 0.854 , 0.8568, 0.8574, 0.8606]),
 array([0.8426, 0.8466, 0.8504, 0.85  , 0.8528, 0.8554, 0.8558, 0.8598]),
 array([0.8386, 0.842 , 0.8432, 0.848 , 0.8502, 0.851 , 0.8538, 0.8546]),
 array([0.8412, 0.8394, 0.8452, 0.8484, 0.8498, 0.853 , 0.8528, 0.8548]),
 array([0.8412, 0.84  , 0.8448, 0.8474, 0.8498, 0.851 , 0.8516, 0.8534]),
 array([0.8392, 0.8432, 0.8452, 0.8466

In [103]:
rf_grid.cv_results_

{'mean_fit_time': array([3.97007027, 3.99928036, 4.29126406, 4.64345846, 4.95584526,
        5.19524164, 5.75541806, 6.37459497]),
 'std_fit_time': array([0.10157881, 0.03121001, 0.01540944, 0.02960937, 0.10398074,
        0.08410731, 0.03145586, 0.04873347]),
 'mean_score_time': array([0.60018702, 0.57977624, 0.5595901 , 0.53158212, 0.5159801 ,
        0.4949842 , 0.49018989, 0.47978258]),
 'std_score_time': array([0.0022238 , 0.00248428, 0.01937412, 0.00195784, 0.0055873 ,
        0.00244898, 0.0224591 , 0.01914582]),
 'param_max_features': masked_array(data=[1, 2, 4, 6, 8, 12, 16, 20],
              mask=[False, False, False, False, False, False, False, False],
        fill_value='?',
             dtype=object),
 'param_n_estimators': masked_array(data=[1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024],
              mask=[False, False, False, False, False, False, False, False],
        fill_value='?',
             dtype=object),
 'params': [{'max_features': 1, 'n_estimators': 1024},


In [104]:
rf_grid.cv_results_['rank_test_accuracy']

array([8, 7, 6, 5, 4, 3, 2, 1])

In [105]:
rf_grid.cv_results_['params'][ np.argmin(rf_grid.cv_results_['rank_test_accuracy'])]

{'max_features': 20, 'n_estimators': 1024}

In [106]:
rf_grid.cv_results_['params'][ np.argmin(rf_grid.cv_results_['rank_test_roc_auc_ovr'])]

{'max_features': 20, 'n_estimators': 1024}

In [107]:
rf_results = pd.DataFrame(rf_grid.cv_results_['params'])
rf_results['score_acc'] = rf_grid.cv_results_['mean_test_accuracy']
rf_cols = rf_results.columns.to_series().str.split('__').apply(lambda x: x[-1])
rf_results.columns = rf_cols
rf_results = rf_results.sort_values(by=['score_acc'], ascending=False, ignore_index=True)
rf_results

Unnamed: 0,max_features,n_estimators,score_acc
0,20,1024,0.8488
1,16,1024,0.8466
2,12,1024,0.8446
3,8,1024,0.8438
4,6,1024,0.8396
5,4,1024,0.836
6,2,1024,0.8318
7,1,1024,0.8294


In [109]:
X_train, X_test, y_train, y_test = train_test_split(adult_onehot, adult_salary, train_size=5000)
X_train = pd.DataFrame(X_train, columns=X_train.columns)
rf_best = RandomForestClassifier(n_estimators=1024, criterion='entropy', max_features=20)
rf_best.fit(X_train, y_train)
pred_rf = rf_best.predict(X_test)
print('Random Forest:', classification_report(y_test, pred_rf))

Random Forest:               precision    recall  f1-score   support

           0       0.88      0.93      0.91     20918
           1       0.74      0.61      0.67      6643

    accuracy                           0.85     27561
   macro avg       0.81      0.77      0.79     27561
weighted avg       0.85      0.85      0.85     27561



In [111]:
#Logistic Regression
for i in range(5):
    print('trial', i+1)
    adult_logreg_list = []
    
    for j in range(5):
        X_train, X_test, y_train, y_test = train_test_split(adult_onehot, adult_salary, train_size=5000)
        sc = StandardScaler()
        X_train = pd.DataFrame(sc.fit_transform(X_train), columns=X_train.columns)
        X_test = sc.transform(X_test)
        
        for k in range(5):
            pipe = Pipeline([('std', StandardScaler()),
                 ('classifier', linear_model.LogisticRegression())])
            search_space = [{'classifier': [linear_model.LogisticRegression(max_iter=5000)],
                 'classifier__solver': ['saga'],
                 'classifier__penalty': ['l1', 'l2'],
                 'classifier__C': np.logspace(-4, 4, 9)},
                {'classifier': [linear_model.LogisticRegression(max_iter=5000)],
                 'classifier__solver': ['lbfgs'],
                 'classifier__penalty': ['l2'],
                 'classifier__C': np.logspace(-4, 4, 9)},
                {'classifier': [linear_model.LogisticRegression(max_iter=5000)],
                 'classifier__solver': ['lbfgs','saga'],
                 'classifier__penalty': ['none']}]
            logreg_grid = GridSearchCV(pipe, search_space, cv=StratifiedKFold(n_splits=5), scoring=['accuracy', 'roc_auc_ovr', 'f1_micro'], refit=False,
                   verbose=0, return_train_score=True)
            logreg_grid.fit(X_train, y_train)
            print('mean test accuracy:', logreg_grid.cv_results_['mean_test_accuracy'])
            print('mean test roc auo ovr:', logreg_grid.cv_results_['mean_test_roc_auc_ovr'])
            print('mean test f1 micro:', logreg_grid.cv_results_['mean_test_f1_micro'])
            logreg_results = pd.DataFrame(logreg_grid.cv_results_['params'])
            logreg_results['score_acc'] = logreg_grid.cv_results_['mean_test_accuracy']
            adult_logreg_list.append(logreg_grid.cv_results_['mean_test_accuracy'])
            logreg_cols = logreg_results.columns.to_series().str.split('__').apply(lambda x: x[-1])
            logreg_results.columns = logreg_cols
            logreg_results = logreg_results.sort_values(by=['score_acc'], ascending=False, ignore_index=True)
            print(logreg_results)

trial 1
mean test accuracy: [0.7682 0.7682 0.7682 0.824  0.831  0.8402 0.8482 0.8448 0.8444 0.8448
 0.844  0.844  0.844  0.844  0.844  0.844  0.844  0.844  0.7682 0.8238
 0.8398 0.8448 0.8446 0.8438 0.844  0.844  0.844  0.844  0.844 ]
mean test roc auo ovr: [0.5        0.8792428  0.5        0.88566947 0.88138146 0.89022848
 0.89792798 0.89102023 0.89236737 0.88931565 0.88921864 0.88852004
 0.88851887 0.88842453 0.88844028 0.8884122  0.88834907 0.88841669
 0.87924055 0.88514482 0.8897173  0.8910651  0.88871878 0.88612102
 0.88561888 0.88485572 0.88445473 0.88446034 0.88842454]
mean test f1 micro: [0.7682 0.7682 0.7682 0.824  0.831  0.8402 0.8482 0.8448 0.8444 0.8448
 0.844  0.844  0.844  0.844  0.844  0.844  0.844  0.844  0.7682 0.8238
 0.8398 0.8448 0.8446 0.8438 0.844  0.844  0.844  0.844  0.844 ]
                           classifier           C penalty solver  score_acc
0   LogisticRegression(max_iter=5000)      0.1000      l1   saga     0.8482
1   LogisticRegression(max_iter=5000) 

mean test accuracy: [0.7682 0.7682 0.7682 0.824  0.831  0.8402 0.8482 0.8448 0.8444 0.8448
 0.844  0.844  0.844  0.844  0.844  0.844  0.8442 0.844  0.7682 0.8238
 0.8398 0.8448 0.8446 0.8438 0.844  0.844  0.844  0.844  0.844 ]
mean test roc auo ovr: [0.5        0.87924168 0.5        0.88566723 0.88138033 0.89022848
 0.89792574 0.89101798 0.89236737 0.88928984 0.88920403 0.88851894
 0.88851779 0.88844701 0.88843577 0.88840205 0.88840883 0.88842455
 0.87924055 0.88514482 0.8897173  0.8910651  0.88871878 0.88612102
 0.88561888 0.88485572 0.88445473 0.88446034 0.88841558]
mean test f1 micro: [0.7682 0.7682 0.7682 0.824  0.831  0.8402 0.8482 0.8448 0.8444 0.8448
 0.844  0.844  0.844  0.844  0.844  0.844  0.8442 0.844  0.7682 0.8238
 0.8398 0.8448 0.8446 0.8438 0.844  0.844  0.844  0.844  0.844 ]
                           classifier           C penalty solver  score_acc
0   LogisticRegression(max_iter=5000)      0.1000      l1   saga     0.8482
1   LogisticRegression(max_iter=5000)      0.1

mean test accuracy: [0.7622 0.762  0.7622 0.8298 0.8372 0.8436 0.8488 0.8434 0.846  0.8452
 0.8452 0.8452 0.8452 0.8452 0.8452 0.8452 0.8452 0.8452 0.762  0.8302
 0.8436 0.8438 0.8444 0.8458 0.8448 0.845  0.845  0.845  0.8452]
mean test roc auo ovr: [0.5        0.88047929 0.5        0.88893764 0.8873278  0.89643952
 0.90297331 0.89872278 0.89895715 0.89746365 0.89728714 0.8970213
 0.89699815 0.89695404 0.89695182 0.89695955 0.89697058 0.89695624
 0.88047708 0.88946147 0.89643622 0.89918594 0.89605855 0.8949224
 0.89307734 0.89240791 0.89202191 0.89195684 0.89685036]
mean test f1 micro: [0.7622 0.762  0.7622 0.8298 0.8372 0.8436 0.8488 0.8434 0.846  0.8452
 0.8452 0.8452 0.8452 0.8452 0.8452 0.8452 0.8452 0.8452 0.762  0.8302
 0.8436 0.8438 0.8444 0.8458 0.8448 0.845  0.845  0.845  0.8452]
                           classifier           C penalty solver  score_acc
0   LogisticRegression(max_iter=5000)      0.1000      l1   saga     0.8488
1   LogisticRegression(max_iter=5000)      1.000

mean test accuracy: [0.7622 0.762  0.7622 0.8298 0.8372 0.8436 0.8488 0.8434 0.846  0.8452
 0.8452 0.8452 0.8452 0.8452 0.8452 0.8452 0.8452 0.8452 0.762  0.8302
 0.8436 0.8438 0.8444 0.8458 0.8448 0.845  0.845  0.845  0.8452]
mean test roc auo ovr: [0.5        0.8804793  0.5        0.88893763 0.8873278  0.89643842
 0.90297331 0.89874595 0.89896046 0.89746585 0.89731471 0.89702133
 0.89699595 0.89696617 0.89695183 0.89694852 0.89694742 0.8969397
 0.88047708 0.88946147 0.89643622 0.89918594 0.89605855 0.8949224
 0.89307734 0.89240791 0.89202191 0.89195684 0.89694962]
mean test f1 micro: [0.7622 0.762  0.7622 0.8298 0.8372 0.8436 0.8488 0.8434 0.846  0.8452
 0.8452 0.8452 0.8452 0.8452 0.8452 0.8452 0.8452 0.8452 0.762  0.8302
 0.8436 0.8438 0.8444 0.8458 0.8448 0.845  0.845  0.845  0.8452]
                           classifier           C penalty solver  score_acc
0   LogisticRegression(max_iter=5000)      0.1000      l1   saga     0.8488
1   LogisticRegression(max_iter=5000)      1.000

mean test accuracy: [0.7548 0.7546 0.7548 0.828  0.83   0.84   0.8494 0.844  0.8452 0.845
 0.8454 0.846  0.8458 0.8458 0.8458 0.8458 0.8458 0.8458 0.7542 0.8276
 0.84   0.844  0.845  0.8454 0.8454 0.8454 0.8454 0.8454 0.8458]
mean test roc auo ovr: [0.5        0.88315689 0.5        0.89019808 0.8875869  0.89663114
 0.90483333 0.89964496 0.90049483 0.89903121 0.89904401 0.8987631
 0.89875766 0.89873821 0.89873929 0.89873929 0.89873929 0.89873281
 0.88262638 0.88966433 0.89662682 0.89963776 0.89824619 0.89797743
 0.8976521  0.89759593 0.8975711  0.89756356 0.89872205]
mean test f1 micro: [0.7548 0.7546 0.7548 0.828  0.83   0.84   0.8494 0.844  0.8452 0.845
 0.8454 0.846  0.8458 0.8458 0.8458 0.8458 0.8458 0.8458 0.7542 0.8276
 0.84   0.844  0.845  0.8454 0.8454 0.8454 0.8454 0.8454 0.8458]
                           classifier           C penalty solver  score_acc
0   LogisticRegression(max_iter=5000)      0.1000      l1   saga     0.8494
1   LogisticRegression(max_iter=5000)     10.0000

mean test accuracy: [0.7504 0.75   0.7504 0.8308 0.8348 0.8426 0.8496 0.8474 0.8472 0.8472
 0.847  0.8472 0.8472 0.8472 0.8472 0.8472 0.8472 0.8472 0.7502 0.8308
 0.843  0.8474 0.847  0.8474 0.847  0.847  0.847  0.847  0.8472]
mean test roc auo ovr: [0.5        0.88461589 0.5        0.89228168 0.89049533 0.89848168
 0.90501313 0.90059937 0.90094246 0.89993723 0.89999665 0.89966781
 0.8995567  0.89949807 0.89958126 0.89948209 0.89949914 0.89956312
 0.88487723 0.89228274 0.89954481 0.90058335 0.89964044 0.89935927
 0.89738569 0.89667088 0.89638599 0.89634226 0.89958125]
mean test f1 micro: [0.7504 0.75   0.7504 0.8308 0.8348 0.8426 0.8496 0.8474 0.8472 0.8472
 0.847  0.8472 0.8472 0.8472 0.8472 0.8472 0.8472 0.8472 0.7502 0.8308
 0.843  0.8474 0.847  0.8474 0.847  0.847  0.847  0.847  0.8472]
                           classifier           C penalty solver  score_acc
0   LogisticRegression(max_iter=5000)      0.1000      l1   saga     0.8496
1   LogisticRegression(max_iter=5000)     10.0

mean test accuracy: [0.7504 0.75   0.7504 0.8308 0.8348 0.8426 0.8496 0.8474 0.8472 0.8472
 0.847  0.8472 0.8472 0.8472 0.8472 0.8472 0.8472 0.8472 0.7502 0.8308
 0.843  0.8474 0.847  0.8474 0.847  0.847  0.847  0.847  0.8472]
mean test roc auo ovr: [0.5        0.88461269 0.5        0.89228167 0.89049479 0.89848381
 0.9050174  0.90059937 0.90093819 0.89995005 0.89999772 0.89957489
 0.89952555 0.89946394 0.89956954 0.89907191 0.89956314 0.89950861
 0.88487723 0.89228274 0.89954481 0.90058335 0.89964044 0.89935927
 0.89738569 0.89667088 0.89638599 0.89634226 0.89959942]
mean test f1 micro: [0.7504 0.75   0.7504 0.8308 0.8348 0.8426 0.8496 0.8474 0.8472 0.8472
 0.847  0.8472 0.8472 0.8472 0.8472 0.8472 0.8472 0.8472 0.7502 0.8308
 0.843  0.8474 0.847  0.8474 0.847  0.847  0.847  0.847  0.8472]
                           classifier           C penalty solver  score_acc
0   LogisticRegression(max_iter=5000)      0.1000      l1   saga     0.8496
1   LogisticRegression(max_iter=5000)     10.0

mean test accuracy: [0.7582 0.758  0.7582 0.8296 0.8328 0.842  0.8524 0.8486 0.8498 0.8488
 0.8488 0.8486 0.8486 0.8486 0.8486 0.8486 0.8486 0.8486 0.7582 0.8294
 0.8422 0.8488 0.8486 0.8482 0.8478 0.848  0.848  0.8478 0.8486]
mean test roc auo ovr: [0.5        0.87764815 0.5        0.8857365  0.8789529  0.89194573
 0.89910859 0.89572288 0.89665682 0.89545606 0.8952945  0.89505212
 0.89503356 0.89500842 0.89501389 0.89501061 0.89500952 0.89501497
 0.87790655 0.88573541 0.89193808 0.89599765 0.89521054 0.89426082
 0.89326292 0.89291288 0.89261301 0.89253005 0.89501062]
mean test f1 micro: [0.7582 0.758  0.7582 0.8296 0.8328 0.842  0.8524 0.8486 0.8498 0.8488
 0.8488 0.8486 0.8486 0.8486 0.8486 0.8486 0.8486 0.8486 0.7582 0.8294
 0.8422 0.8488 0.8486 0.8482 0.8478 0.848  0.848  0.8478 0.8486]
                           classifier           C penalty solver  score_acc
0   LogisticRegression(max_iter=5000)      0.1000      l1   saga     0.8524
1   LogisticRegression(max_iter=5000)      1.0

mean test accuracy: [0.7582 0.758  0.7582 0.8296 0.8328 0.8422 0.8524 0.8486 0.8498 0.8488
 0.8488 0.8486 0.8486 0.8486 0.8486 0.8486 0.8486 0.8486 0.7582 0.8294
 0.8422 0.8488 0.8486 0.8482 0.8478 0.848  0.848  0.8478 0.8486]
mean test roc auo ovr: [0.5        0.87764379 0.5        0.88573977 0.87895072 0.89194245
 0.89911077 0.8957207  0.89665901 0.89546042 0.89529887 0.8950543
 0.89503465 0.89500298 0.89500844 0.89500952 0.89501497 0.89500951
 0.87790655 0.88573541 0.89193808 0.89599765 0.89521054 0.89426082
 0.89326292 0.89291288 0.89261301 0.89253005 0.89500843]
mean test f1 micro: [0.7582 0.758  0.7582 0.8296 0.8328 0.8422 0.8524 0.8486 0.8498 0.8488
 0.8488 0.8486 0.8486 0.8486 0.8486 0.8486 0.8486 0.8486 0.7582 0.8294
 0.8422 0.8488 0.8486 0.8482 0.8478 0.848  0.848  0.8478 0.8486]
                           classifier           C penalty solver  score_acc
0   LogisticRegression(max_iter=5000)      0.1000      l1   saga     0.8524
1   LogisticRegression(max_iter=5000)      1.00

mean test accuracy: [0.7524 0.7524 0.7524 0.8186 0.8236 0.8302 0.8362 0.835  0.835  0.835
 0.8348 0.8348 0.8348 0.8348 0.8348 0.8348 0.8348 0.8348 0.7524 0.8184
 0.8302 0.835  0.8348 0.8348 0.8348 0.8348 0.8348 0.8348 0.8348]
mean test roc auo ovr: [0.5        0.87664359 0.5        0.88437992 0.88213432 0.89054379
 0.89598594 0.89161672 0.89162228 0.88922788 0.88888296 0.88831649
 0.88823485 0.88817155 0.88815117 0.88770716 0.88815436 0.88814148
 0.87609231 0.88357342 0.89053734 0.89160705 0.88766754 0.88487759
 0.88387447 0.88347857 0.88347643 0.88344427 0.88815652]
mean test f1 micro: [0.7524 0.7524 0.7524 0.8186 0.8236 0.8302 0.8362 0.835  0.835  0.835
 0.8348 0.8348 0.8348 0.8348 0.8348 0.8348 0.8348 0.8348 0.7524 0.8184
 0.8302 0.835  0.8348 0.8348 0.8348 0.8348 0.8348 0.8348 0.8348]
                           classifier           C penalty solver  score_acc
0   LogisticRegression(max_iter=5000)      0.1000      l1   saga     0.8362
1   LogisticRegression(max_iter=5000)      0.100

mean test accuracy: [0.7648 0.765  0.7648 0.827  0.8312 0.8428 0.8494 0.8476 0.8454 0.8462
 0.846  0.8458 0.846  0.846  0.846  0.846  0.846  0.846  0.7654 0.8268
 0.8432 0.8476 0.846  0.8458 0.8456 0.8456 0.8456 0.8456 0.846 ]
mean test roc auo ovr: [0.5        0.88116376 0.5        0.88798072 0.88133089 0.89326142
 0.89957394 0.89519911 0.89548065 0.89370287 0.89348727 0.89308961
 0.8930685  0.89304296 0.89303183 0.89302184 0.89301294 0.89302296
 0.88168664 0.8879785  0.89379541 0.89520696 0.89348756 0.89193004
 0.89121571 0.89085756 0.89053051 0.89043596 0.89301961]
mean test f1 micro: [0.7648 0.765  0.7648 0.827  0.8312 0.8428 0.8494 0.8476 0.8454 0.8462
 0.846  0.8458 0.846  0.846  0.846  0.846  0.846  0.846  0.7654 0.8268
 0.8432 0.8476 0.846  0.8458 0.8456 0.8456 0.8456 0.8456 0.846 ]
                           classifier           C penalty solver  score_acc
0   LogisticRegression(max_iter=5000)      0.1000      l1   saga     0.8494
1   LogisticRegression(max_iter=5000)      0.1

mean test accuracy: [0.7648 0.765  0.7648 0.827  0.8312 0.8428 0.8494 0.8476 0.8454 0.8462
 0.846  0.8458 0.846  0.846  0.846  0.846  0.846  0.846  0.7654 0.8268
 0.8432 0.8476 0.846  0.8458 0.8456 0.8456 0.8456 0.8456 0.846 ]
mean test roc auo ovr: [0.5        0.88116264 0.5        0.88798294 0.88133256 0.89326142
 0.89957505 0.89519577 0.89547954 0.89369286 0.89348282 0.89309629
 0.89306071 0.89302965 0.89300294 0.89302629 0.89302074 0.89302072
 0.88168664 0.8879785  0.89379541 0.89520696 0.89348756 0.89193004
 0.89121571 0.89085756 0.89053051 0.89043596 0.89300739]
mean test f1 micro: [0.7648 0.765  0.7648 0.827  0.8312 0.8428 0.8494 0.8476 0.8454 0.8462
 0.846  0.8458 0.846  0.846  0.846  0.846  0.846  0.846  0.7654 0.8268
 0.8432 0.8476 0.846  0.8458 0.8456 0.8456 0.8456 0.8456 0.846 ]
                           classifier           C penalty solver  score_acc
0   LogisticRegression(max_iter=5000)      0.1000      l1   saga     0.8494
1   LogisticRegression(max_iter=5000)      0.1

mean test accuracy: [0.7588 0.759  0.7588 0.8294 0.8314 0.8514 0.8548 0.8534 0.8534 0.8532
 0.8528 0.8524 0.8524 0.8524 0.8524 0.8524 0.8524 0.8524 0.759  0.8294
 0.8514 0.8532 0.8526 0.8522 0.8522 0.8522 0.8522 0.8522 0.8526]
mean test roc auo ovr: [0.5        0.88320036 0.5        0.89217275 0.88654722 0.89959521
 0.90571757 0.90203911 0.90275023 0.90111639 0.90101259 0.90064349
 0.90063585 0.90058016 0.90058344 0.90052111 0.90058561 0.90040302
 0.8831949  0.8921651  0.89959632 0.90206318 0.90051813 0.89891163
 0.89785674 0.89721763 0.89705048 0.8970352  0.90063374]
mean test f1 micro: [0.7588 0.759  0.7588 0.8294 0.8314 0.8514 0.8548 0.8534 0.8534 0.8532
 0.8528 0.8524 0.8524 0.8524 0.8524 0.8524 0.8524 0.8524 0.759  0.8294
 0.8514 0.8532 0.8526 0.8522 0.8522 0.8522 0.8522 0.8522 0.8526]
                           classifier           C penalty solver  score_acc
0   LogisticRegression(max_iter=5000)      0.1000      l1   saga     0.8548
1   LogisticRegression(max_iter=5000)      0.1

mean test accuracy: [0.7588 0.759  0.7588 0.8294 0.8314 0.8514 0.8548 0.8534 0.8534 0.8532
 0.8528 0.8524 0.8524 0.8524 0.8524 0.8524 0.8524 0.8524 0.759  0.8294
 0.8514 0.8532 0.8526 0.8522 0.8522 0.8522 0.8522 0.8522 0.8524]
mean test roc auo ovr: [0.5        0.88319599 0.5        0.89217494 0.88654777 0.89959849
 0.90571647 0.90205004 0.90274913 0.90110872 0.90102023 0.90061508
 0.90062711 0.90059108 0.90059326 0.90058562 0.90058671 0.9005222
 0.8831949  0.8921651  0.89959632 0.90206318 0.90051813 0.89891163
 0.89785674 0.89721763 0.89705048 0.8970352  0.90058889]
mean test f1 micro: [0.7588 0.759  0.7588 0.8294 0.8314 0.8514 0.8548 0.8534 0.8534 0.8532
 0.8528 0.8524 0.8524 0.8524 0.8524 0.8524 0.8524 0.8524 0.759  0.8294
 0.8514 0.8532 0.8526 0.8522 0.8522 0.8522 0.8522 0.8522 0.8524]
                           classifier           C penalty solver  score_acc
0   LogisticRegression(max_iter=5000)      0.1000      l1   saga     0.8548
1   LogisticRegression(max_iter=5000)      0.10

mean test accuracy: [0.7558 0.7572 0.7558 0.8198 0.829  0.8378 0.8408 0.8378 0.8384 0.838
 0.838  0.838  0.838  0.838  0.838  0.838  0.838  0.838  0.7572 0.82
 0.8378 0.8378 0.8382 0.838  0.838  0.838  0.838  0.838  0.838 ]
mean test roc auo ovr: [0.5        0.8806613  0.5        0.8888874  0.88658243 0.8951953
 0.89967264 0.89681335 0.89736918 0.89558152 0.89554901 0.89508203
 0.89506361 0.89501486 0.89500619 0.89500835 0.89500511 0.89500728
 0.88065805 0.88887765 0.89519638 0.8968231  0.89535075 0.89389143
 0.89298382 0.89264596 0.89261353 0.89261355 0.89500511]
mean test f1 micro: [0.7558 0.7572 0.7558 0.8198 0.829  0.8378 0.8408 0.8378 0.8384 0.838
 0.838  0.838  0.838  0.838  0.838  0.838  0.838  0.838  0.7572 0.82
 0.8378 0.8378 0.8382 0.838  0.838  0.838  0.838  0.838  0.838 ]
                           classifier           C penalty solver  score_acc
0   LogisticRegression(max_iter=5000)      0.1000      l1   saga     0.8408
1   LogisticRegression(max_iter=5000)      1.0000    

mean test accuracy: [0.763  0.763  0.763  0.8256 0.8356 0.842  0.8474 0.8452 0.8466 0.846
 0.8464 0.8462 0.8464 0.8464 0.8464 0.8464 0.8464 0.8464 0.7632 0.8256
 0.8422 0.8454 0.8458 0.8458 0.846  0.8458 0.8458 0.8458 0.8464]
mean test roc auo ovr: [0.5        0.88072731 0.5        0.88846492 0.8873744  0.89441523
 0.90256261 0.89688715 0.89769232 0.89655645 0.89651221 0.89625451
 0.89625783 0.89623682 0.89622465 0.89622244 0.89623239 0.89622797
 0.88214521 0.88846271 0.89467514 0.89855722 0.8961406  0.89509874
 0.89536418 0.89530224 0.89529671 0.89525358 0.89623129]
mean test f1 micro: [0.763  0.763  0.763  0.8256 0.8356 0.842  0.8474 0.8452 0.8466 0.846
 0.8464 0.8462 0.8464 0.8464 0.8464 0.8464 0.8464 0.8464 0.7632 0.8256
 0.8422 0.8454 0.8458 0.8458 0.846  0.8458 0.8458 0.8458 0.8464]
                           classifier           C penalty solver  score_acc
0   LogisticRegression(max_iter=5000)      0.1000      l1   saga     0.8474
1   LogisticRegression(max_iter=5000)      1.000

mean test accuracy: [0.763  0.763  0.763  0.8256 0.8354 0.842  0.8474 0.8452 0.8466 0.846
 0.8464 0.8462 0.8464 0.8464 0.8464 0.8464 0.8464 0.8464 0.7632 0.8256
 0.8422 0.8454 0.8458 0.8458 0.846  0.8458 0.8458 0.8458 0.8464]
mean test roc auo ovr: [0.5        0.8807262  0.5        0.88846492 0.88737551 0.89441633
 0.90256261 0.89688825 0.89769232 0.89655977 0.89651    0.89625562
 0.89625894 0.89624124 0.89623018 0.89623571 0.8962158  0.89623571
 0.88214521 0.88846271 0.89467514 0.89855722 0.8961406  0.89509874
 0.89536418 0.89530224 0.89529671 0.89525358 0.89623129]
mean test f1 micro: [0.763  0.763  0.763  0.8256 0.8354 0.842  0.8474 0.8452 0.8466 0.846
 0.8464 0.8462 0.8464 0.8464 0.8464 0.8464 0.8464 0.8464 0.7632 0.8256
 0.8422 0.8454 0.8458 0.8458 0.846  0.8458 0.8458 0.8458 0.8464]
                           classifier           C penalty solver  score_acc
0   LogisticRegression(max_iter=5000)      0.1000      l1   saga     0.8474
1   LogisticRegression(max_iter=5000)      1.000

mean test accuracy: [0.7484 0.7492 0.7484 0.8198 0.8238 0.8366 0.8456 0.8446 0.844  0.8442
 0.8442 0.844  0.8442 0.8442 0.8442 0.8442 0.8442 0.8442 0.7492 0.8198
 0.8366 0.8446 0.8444 0.8446 0.8448 0.8446 0.8446 0.8446 0.8442]
mean test roc auo ovr: [0.5        0.87903729 0.5        0.88791482 0.88631429 0.89571511
 0.9025177  0.89851792 0.89914053 0.89788074 0.89784128 0.89762992
 0.89765326 0.8976214  0.89763626 0.89761929 0.89762034 0.89762671
 0.87903517 0.88791376 0.89570873 0.89853386 0.8976184  0.89674221
 0.8962865  0.8960655  0.89590276 0.89558254 0.89762034]
mean test f1 micro: [0.7484 0.7492 0.7484 0.8198 0.8238 0.8366 0.8456 0.8446 0.844  0.8442
 0.8442 0.844  0.8442 0.8442 0.8442 0.8442 0.8442 0.8442 0.7492 0.8198
 0.8366 0.8446 0.8444 0.8446 0.8448 0.8446 0.8446 0.8446 0.8442]
                           classifier           C penalty solver  score_acc
0   LogisticRegression(max_iter=5000)      0.1000      l1   saga     0.8456
1   LogisticRegression(max_iter=5000)    100.0

mean test accuracy: [0.7484 0.7492 0.7484 0.82   0.8238 0.8366 0.8456 0.8446 0.844  0.8442
 0.8442 0.844  0.8442 0.8442 0.8442 0.8442 0.8442 0.8442 0.7492 0.8198
 0.8366 0.8446 0.8444 0.8446 0.8448 0.8446 0.8446 0.8446 0.8442]
mean test roc auo ovr: [0.5        0.87903517 0.5        0.88791588 0.88631323 0.89571511
 0.90251982 0.89851792 0.8991384  0.89788392 0.89783703 0.89764478
 0.89765007 0.89763201 0.89763096 0.89762034 0.89761716 0.8976214
 0.87903517 0.88791376 0.89570873 0.89853386 0.8976184  0.89674221
 0.8962865  0.8960655  0.89590276 0.89558254 0.89762034]
mean test f1 micro: [0.7484 0.7492 0.7484 0.82   0.8238 0.8366 0.8456 0.8446 0.844  0.8442
 0.8442 0.844  0.8442 0.8442 0.8442 0.8442 0.8442 0.8442 0.7492 0.8198
 0.8366 0.8446 0.8444 0.8446 0.8448 0.8446 0.8446 0.8446 0.8442]
                           classifier           C penalty solver  score_acc
0   LogisticRegression(max_iter=5000)      0.1000      l1   saga     0.8456
1   LogisticRegression(max_iter=5000)    100.00

mean test accuracy: [0.7604 0.7604 0.7604 0.8316 0.837  0.8478 0.8528 0.851  0.8516 0.8492
 0.85   0.8492 0.8492 0.8492 0.8492 0.8492 0.8492 0.8492 0.7604 0.8316
 0.8478 0.8506 0.8492 0.8488 0.8486 0.8486 0.8486 0.8486 0.8492]
mean test roc auo ovr: [0.5        0.88513341 0.5        0.89211872 0.8907113  0.89886034
 0.90705871 0.90088774 0.90178015 0.89913511 0.89899485 0.89832624
 0.89830874 0.89824068 0.89824837 0.89823191 0.8982264  0.89821869
 0.8851356  0.89211981 0.89887352 0.90090644 0.89855004 0.89613754
 0.89458715 0.89376516 0.8935655  0.89348972 0.8979383 ]
mean test f1 micro: [0.7604 0.7604 0.7604 0.8316 0.837  0.8478 0.8528 0.851  0.8516 0.8492
 0.85   0.8492 0.8492 0.8492 0.8492 0.8492 0.8492 0.8492 0.7604 0.8316
 0.8478 0.8506 0.8492 0.8488 0.8486 0.8486 0.8486 0.8486 0.8492]
                           classifier           C penalty solver  score_acc
0   LogisticRegression(max_iter=5000)      0.1000      l1   saga     0.8528
1   LogisticRegression(max_iter=5000)      1.0

mean test accuracy: [0.7636 0.7638 0.7636 0.8326 0.8376 0.848  0.854  0.8548 0.8556 0.8554
 0.8552 0.8552 0.8552 0.8552 0.855  0.855  0.855  0.855  0.7638 0.8326
 0.848  0.855  0.8554 0.8548 0.8548 0.8548 0.8548 0.8548 0.8552]
mean test roc auo ovr: [0.5        0.88749434 0.5        0.89510487 0.89087241 0.90290958
 0.91034247 0.90644191 0.90756325 0.90545725 0.90551808 0.90494722
 0.90497381 0.90490177 0.9049062  0.90491174 0.90491729 0.90490509
 0.88749545 0.89510821 0.90292843 0.90643081 0.90519428 0.90353667
 0.90252436 0.9020132  0.90188796 0.90188905 0.90489513]
mean test f1 micro: [0.7636 0.7638 0.7636 0.8326 0.8376 0.848  0.854  0.8548 0.8556 0.8554
 0.8552 0.8552 0.8552 0.8552 0.855  0.855  0.855  0.855  0.7638 0.8326
 0.848  0.855  0.8554 0.8548 0.8548 0.8548 0.8548 0.8548 0.8552]
                           classifier           C penalty solver  score_acc
0   LogisticRegression(max_iter=5000)      1.0000      l1   saga     0.8556
1   LogisticRegression(max_iter=5000)      1.0

mean test accuracy: [0.7636 0.7638 0.7636 0.8326 0.8376 0.848  0.854  0.8548 0.8556 0.8554
 0.8552 0.8552 0.8552 0.8552 0.855  0.8552 0.855  0.8552 0.7638 0.8326
 0.848  0.855  0.8554 0.8548 0.8548 0.8548 0.8548 0.8548 0.855 ]
mean test roc auo ovr: [0.5        0.88749434 0.5        0.89510266 0.89087518 0.90291069
 0.91034136 0.90643859 0.90756213 0.9054994  0.90551586 0.90497162
 0.90496494 0.90491065 0.9049062  0.90488291 0.90491841 0.90491619
 0.88749545 0.89510821 0.90292843 0.90643081 0.90519428 0.90353667
 0.90252436 0.9020132  0.90188796 0.90188905 0.90490841]
mean test f1 micro: [0.7636 0.7638 0.7636 0.8326 0.8376 0.848  0.854  0.8548 0.8556 0.8554
 0.8552 0.8552 0.8552 0.8552 0.855  0.8552 0.855  0.8552 0.7638 0.8326
 0.848  0.855  0.8554 0.8548 0.8548 0.8548 0.8548 0.8548 0.855 ]
                           classifier           C penalty solver  score_acc
0   LogisticRegression(max_iter=5000)      1.0000      l1   saga     0.8556
1   LogisticRegression(max_iter=5000)      1.0

mean test accuracy: [0.7646 0.7646 0.7646 0.832  0.8368 0.8464 0.8528 0.8506 0.85   0.8502
 0.85   0.8494 0.8496 0.8496 0.8496 0.8496 0.8496 0.8496 0.7646 0.832
 0.8466 0.851  0.85   0.8494 0.8496 0.8496 0.8496 0.8496 0.8496]
mean test roc auo ovr: [0.5        0.88598174 0.5        0.89355934 0.88875595 0.89972335
 0.90605849 0.90251293 0.90311326 0.90139818 0.90131709 0.90101598
 0.90100043 0.90095825 0.90096265 0.900956   0.90096599 0.900956
 0.88597951 0.89354933 0.89972003 0.90253632 0.90110504 0.89969064
 0.89890813 0.89855019 0.89845016 0.89844906 0.90096044]
mean test f1 micro: [0.7646 0.7646 0.7646 0.832  0.8368 0.8464 0.8528 0.8506 0.85   0.8502
 0.85   0.8494 0.8496 0.8496 0.8496 0.8496 0.8496 0.8496 0.7646 0.832
 0.8466 0.851  0.85   0.8494 0.8496 0.8496 0.8496 0.8496 0.8496]
                           classifier           C penalty solver  score_acc
0   LogisticRegression(max_iter=5000)      0.1000      l1   saga     0.8528
1   LogisticRegression(max_iter=5000)      0.1000 

mean test accuracy: [0.7646 0.7646 0.7646 0.832  0.8368 0.8464 0.8528 0.8506 0.85   0.8502
 0.85   0.8494 0.8496 0.8496 0.8496 0.8496 0.8496 0.8496 0.7646 0.832
 0.8466 0.851  0.85   0.8494 0.8496 0.8496 0.8496 0.8496 0.8496]
mean test roc auo ovr: [0.5        0.88597729 0.5        0.89355823 0.88875428 0.89972001
 0.9060585  0.90251182 0.90311215 0.90140484 0.90131933 0.90101932
 0.90099821 0.90097488 0.90096932 0.90096043 0.90096153 0.90095931
 0.88597951 0.89354933 0.89972003 0.90253632 0.90110504 0.89969064
 0.89890813 0.89855019 0.89845016 0.89844906 0.90095709]
mean test f1 micro: [0.7646 0.7646 0.7646 0.832  0.8368 0.8464 0.8528 0.8506 0.85   0.8502
 0.85   0.8494 0.8496 0.8496 0.8496 0.8496 0.8496 0.8496 0.7646 0.832
 0.8466 0.851  0.85   0.8494 0.8496 0.8496 0.8496 0.8496 0.8496]
                           classifier           C penalty solver  score_acc
0   LogisticRegression(max_iter=5000)      0.1000      l1   saga     0.8528
1   LogisticRegression(max_iter=5000)      0.100

mean test accuracy: [0.7658 0.766  0.7658 0.8236 0.8342 0.8468 0.853  0.8508 0.8514 0.8506
 0.8512 0.851  0.851  0.851  0.851  0.851  0.851  0.851  0.766  0.8234
 0.8466 0.8508 0.8508 0.851  0.8512 0.8512 0.8512 0.8512 0.851 ]
mean test roc auo ovr: [0.5        0.87737672 0.5        0.88582877 0.88306133 0.8941787
 0.90077167 0.89600571 0.89790354 0.8945157  0.89443102 0.89385979
 0.89383189 0.89372704 0.89375828 0.89374378 0.8936534  0.89376163
 0.87737114 0.88615695 0.89510607 0.89603247 0.89430487 0.89339522
 0.89215616 0.89165962 0.89150676 0.89144204 0.89372146]
mean test f1 micro: [0.7658 0.766  0.7658 0.8236 0.8342 0.8468 0.853  0.8508 0.8514 0.8506
 0.8512 0.851  0.851  0.851  0.851  0.851  0.851  0.851  0.766  0.8234
 0.8466 0.8508 0.8508 0.851  0.8512 0.8512 0.8512 0.8512 0.851 ]
                           classifier           C penalty solver  score_acc
0   LogisticRegression(max_iter=5000)      0.1000      l1   saga     0.8530
1   LogisticRegression(max_iter=5000)      1.00

mean test accuracy: [0.75   0.75   0.75   0.8272 0.8354 0.8508 0.8552 0.8538 0.8552 0.854
 0.8546 0.8548 0.8548 0.8548 0.8548 0.8548 0.8548 0.8548 0.75   0.827
 0.8508 0.8536 0.8544 0.8546 0.8546 0.8546 0.8544 0.8544 0.8548]
mean test roc auo ovr: [0.5        0.8829696  0.5        0.891232   0.88998293 0.89959147
 0.90737013 0.9033856  0.90414933 0.90256747 0.90244907 0.9022208
 0.90221013 0.90219093 0.9021824  0.90217707 0.90218027 0.90217707
 0.88296747 0.89122773 0.89958613 0.90339413 0.90199573 0.90026133
 0.89942187 0.89914347 0.89880107 0.8988384  0.90218133]
mean test f1 micro: [0.75   0.75   0.75   0.8272 0.8354 0.8508 0.8552 0.8538 0.8552 0.854
 0.8546 0.8548 0.8548 0.8548 0.8548 0.8548 0.8548 0.8548 0.75   0.827
 0.8508 0.8536 0.8544 0.8546 0.8546 0.8546 0.8544 0.8544 0.8548]
                           classifier           C penalty solver  score_acc
0   LogisticRegression(max_iter=5000)      0.1000      l1   saga     0.8552
1   LogisticRegression(max_iter=5000)      1.0000  

mean test accuracy: [0.75   0.75   0.75   0.8272 0.8354 0.8508 0.8552 0.8538 0.8552 0.854
 0.8546 0.8548 0.8548 0.8548 0.8548 0.8548 0.8548 0.8548 0.75   0.827
 0.8508 0.8536 0.8544 0.8546 0.8546 0.8546 0.8544 0.8544 0.8548]
mean test roc auo ovr: [0.5        0.88297173 0.5        0.891232   0.889984   0.89958933
 0.90737013 0.9033856  0.90414933 0.9025664  0.90244907 0.9022208
 0.90220907 0.90218773 0.90218133 0.9021792  0.90218347 0.90218027
 0.88296747 0.89122773 0.89958613 0.90339413 0.90199573 0.90026133
 0.89942187 0.89914347 0.89880107 0.8988384  0.902176  ]
mean test f1 micro: [0.75   0.75   0.75   0.8272 0.8354 0.8508 0.8552 0.8538 0.8552 0.854
 0.8546 0.8548 0.8548 0.8548 0.8548 0.8548 0.8548 0.8548 0.75   0.827
 0.8508 0.8536 0.8544 0.8546 0.8546 0.8546 0.8544 0.8544 0.8548]
                           classifier           C penalty solver  score_acc
0   LogisticRegression(max_iter=5000)      0.1000      l1   saga     0.8552
1   LogisticRegression(max_iter=5000)      1.0000  

mean test accuracy: [0.7572 0.7574 0.7572 0.8218 0.827  0.8386 0.8458 0.844  0.845  0.8446
 0.8448 0.8448 0.8448 0.8448 0.8448 0.8448 0.8448 0.8448 0.7574 0.822
 0.8386 0.844  0.8446 0.8448 0.845  0.8452 0.8452 0.8452 0.8448]
mean test roc auo ovr: [0.5        0.87787724 0.5        0.88459208 0.88363187 0.89191926
 0.89933731 0.89466469 0.89513632 0.8929701  0.89273652 0.89236207
 0.89231861 0.8922707  0.89228813 0.89226745 0.89226963 0.89227725
 0.87787289 0.88459425 0.89191816 0.89464072 0.89242906 0.89033232
 0.88912333 0.88863806 0.8885195  0.88848582 0.89225433]
mean test f1 micro: [0.7572 0.7574 0.7572 0.8218 0.827  0.8386 0.8458 0.844  0.845  0.8446
 0.8448 0.8448 0.8448 0.8448 0.8448 0.8448 0.8448 0.8448 0.7574 0.822
 0.8386 0.844  0.8446 0.8448 0.845  0.8452 0.8452 0.8452 0.8448]
                           classifier           C penalty solver  score_acc
0   LogisticRegression(max_iter=5000)      0.1000      l1   saga     0.8458
1   LogisticRegression(max_iter=5000)         Na

mean test accuracy: [0.7572 0.7574 0.7572 0.8218 0.827  0.8386 0.8458 0.844  0.845  0.8446
 0.8448 0.8448 0.8448 0.8448 0.8448 0.8448 0.8448 0.8448 0.7574 0.822
 0.8386 0.844  0.8446 0.8448 0.845  0.8452 0.8452 0.8452 0.8448]
mean test roc auo ovr: [0.5        0.87787833 0.5        0.88459425 0.88363241 0.89191926
 0.8993384  0.89467013 0.89513741 0.8929701  0.89272454 0.89236862
 0.89232514 0.89227941 0.89228596 0.89228377 0.89227832 0.89227835
 0.87787289 0.88459425 0.89191816 0.89464072 0.89242906 0.89033232
 0.88912333 0.88863806 0.8885195  0.88848582 0.89227071]
mean test f1 micro: [0.7572 0.7574 0.7572 0.8218 0.827  0.8386 0.8458 0.844  0.845  0.8446
 0.8448 0.8448 0.8448 0.8448 0.8448 0.8448 0.8448 0.8448 0.7574 0.822
 0.8386 0.844  0.8446 0.8448 0.845  0.8452 0.8452 0.8452 0.8448]
                           classifier           C penalty solver  score_acc
0   LogisticRegression(max_iter=5000)      0.1000      l1   saga     0.8458
1   LogisticRegression(max_iter=5000)         Na

mean test accuracy: [0.7598 0.7596 0.7598 0.819  0.8214 0.835  0.8412 0.841  0.8416 0.8418
 0.8414 0.8414 0.8416 0.8418 0.8416 0.8416 0.8418 0.8418 0.7594 0.8188
 0.8346 0.841  0.8418 0.8416 0.8414 0.8414 0.8414 0.8414 0.8418]
mean test roc auo ovr: [0.5        0.87424334 0.5        0.88353302 0.8790017  0.89235364
 0.89990757 0.89721661 0.89788112 0.89686777 0.8967555  0.89652973
 0.89650234 0.89651106 0.89650016 0.89649795 0.89652757 0.89650564
 0.87454637 0.88300162 0.8918343  0.89723306 0.89640186 0.89576776
 0.89460777 0.89403174 0.8936093  0.89354379 0.89649139]
mean test f1 micro: [0.7598 0.7596 0.7598 0.819  0.8214 0.835  0.8412 0.841  0.8416 0.8418
 0.8414 0.8414 0.8416 0.8418 0.8416 0.8416 0.8418 0.8418 0.7594 0.8188
 0.8346 0.841  0.8418 0.8416 0.8414 0.8414 0.8414 0.8414 0.8418]
                           classifier           C penalty solver  score_acc
0   LogisticRegression(max_iter=5000)         NaN    none   saga     0.8418
1   LogisticRegression(max_iter=5000)      1.0

mean test accuracy: [0.7522 0.753  0.7522 0.8208 0.8278 0.8368 0.8442 0.8454 0.8448 0.8456
 0.8444 0.8452 0.845  0.8452 0.8452 0.8452 0.8452 0.8452 0.7528 0.8208
 0.8368 0.8452 0.8454 0.845  0.8456 0.8456 0.8456 0.8456 0.8452]
mean test roc auo ovr: [0.5        0.88269855 0.5        0.89023228 0.88601576 0.89811628
 0.90344752 0.90147777 0.90228255 0.90122147 0.90114307 0.90099302
 0.90095327 0.90093181 0.90093074 0.90091786 0.90093613 0.90092324
 0.88188776 0.88941827 0.89812271 0.90069809 0.90102725 0.89933127
 0.8996259  0.89928433 0.89912953 0.8990962  0.90092539]
mean test f1 micro: [0.7522 0.753  0.7522 0.8208 0.8278 0.8368 0.8442 0.8454 0.8448 0.8456
 0.8444 0.8452 0.845  0.8452 0.8452 0.8452 0.8452 0.8452 0.7528 0.8208
 0.8368 0.8452 0.8454 0.845  0.8456 0.8456 0.8456 0.8456 0.8452]
                           classifier           C penalty solver  score_acc
0   LogisticRegression(max_iter=5000)         NaN    none  lbfgs     0.8456
1   LogisticRegression(max_iter=5000)  10000.0

mean test accuracy: [0.7522 0.753  0.7522 0.8208 0.8278 0.8368 0.8442 0.8454 0.8448 0.8456
 0.8444 0.8452 0.8452 0.8452 0.8452 0.8452 0.8452 0.8452 0.7528 0.8208
 0.8368 0.8452 0.8454 0.845  0.8456 0.8456 0.8456 0.8456 0.8452]
mean test roc auo ovr: [0.5        0.88269855 0.5        0.89023549 0.88601254 0.89811842
 0.90344645 0.90147777 0.90228363 0.90122148 0.90113771 0.90098872
 0.90095649 0.90092324 0.90092322 0.90092753 0.90097699 0.90093615
 0.88188776 0.88941827 0.89812271 0.90069809 0.90102725 0.89933127
 0.8996259  0.89928433 0.89912953 0.8990962  0.9009125 ]
mean test f1 micro: [0.7522 0.753  0.7522 0.8208 0.8278 0.8368 0.8442 0.8454 0.8448 0.8456
 0.8444 0.8452 0.8452 0.8452 0.8452 0.8452 0.8452 0.8452 0.7528 0.8208
 0.8368 0.8452 0.8454 0.845  0.8456 0.8456 0.8456 0.8456 0.8452]
                           classifier           C penalty solver  score_acc
0   LogisticRegression(max_iter=5000)         NaN    none  lbfgs     0.8456
1   LogisticRegression(max_iter=5000)  10000.0

mean test accuracy: [0.7556 0.7556 0.7556 0.826  0.8342 0.8436 0.8502 0.848  0.8484 0.8482
 0.8482 0.8484 0.8482 0.8482 0.8482 0.8482 0.8482 0.8482 0.7558 0.8262
 0.8436 0.848  0.8484 0.848  0.848  0.848  0.848  0.848  0.8482]
mean test roc auo ovr: [0.5        0.88670184 0.5        0.89357438 0.88961651 0.8981934
 0.90558155 0.89938444 0.90097561 0.89811316 0.89812069 0.8973854
 0.89735832 0.89727819 0.89726303 0.8972511  0.89725978 0.89724464
 0.88696458 0.89383278 0.89819991 0.89941047 0.89819291 0.89598842
 0.89496629 0.89449138 0.89436808 0.89436374 0.89726085]
mean test f1 micro: [0.7556 0.7556 0.7556 0.826  0.8342 0.8436 0.8502 0.848  0.8484 0.8482
 0.8482 0.8484 0.8482 0.8482 0.8482 0.8482 0.8482 0.8482 0.7558 0.8262
 0.8436 0.848  0.8484 0.848  0.848  0.848  0.848  0.848  0.8482]
                           classifier           C penalty solver  score_acc
0   LogisticRegression(max_iter=5000)      0.1000      l1   saga     0.8502
1   LogisticRegression(max_iter=5000)     10.000

mean test accuracy: [0.7556 0.7556 0.7556 0.826  0.8342 0.8436 0.8502 0.848  0.8484 0.8482
 0.8482 0.8484 0.8482 0.848  0.848  0.8482 0.8482 0.8482 0.7558 0.8262
 0.8436 0.848  0.8484 0.848  0.848  0.848  0.848  0.848  0.8482]
mean test roc auo ovr: [0.5        0.88669859 0.5        0.89357438 0.88961597 0.8981934
 0.90558047 0.89938444 0.90097561 0.89813369 0.89812394 0.89738866
 0.89736375 0.8972576  0.89723919 0.89725437 0.89725762 0.89725978
 0.88696458 0.89383278 0.89819991 0.89941047 0.89819291 0.89598842
 0.89496629 0.89449138 0.89436808 0.89436374 0.89726087]
mean test f1 micro: [0.7556 0.7556 0.7556 0.826  0.8342 0.8436 0.8502 0.848  0.8484 0.8482
 0.8482 0.8484 0.8482 0.848  0.848  0.8482 0.8482 0.8482 0.7558 0.8262
 0.8436 0.848  0.8484 0.848  0.848  0.848  0.848  0.848  0.8482]
                           classifier           C penalty solver  score_acc
0   LogisticRegression(max_iter=5000)      0.1000      l1   saga     0.8502
1   LogisticRegression(max_iter=5000)      1.00

mean test accuracy: [0.7502 0.7506 0.7502 0.8208 0.8254 0.8372 0.8426 0.8432 0.844  0.845
 0.8446 0.8444 0.8444 0.8444 0.8444 0.8444 0.8444 0.8444 0.7506 0.821
 0.8372 0.8432 0.8448 0.8444 0.8446 0.8446 0.8446 0.8446 0.8444]
mean test roc auo ovr: [0.5        0.87947549 0.5        0.88780461 0.88375423 0.89512345
 0.90190092 0.89930148 0.90041523 0.89963467 0.89971678 0.89956726
 0.89959395 0.89957687 0.89958007 0.89957047 0.89957367 0.89956297
 0.87948082 0.88779927 0.89513519 0.89937295 0.89966344 0.89927261
 0.89895996 0.89876046 0.8986762  0.89867513 0.8995694 ]
mean test f1 micro: [0.7502 0.7506 0.7502 0.8208 0.8254 0.8372 0.8426 0.8432 0.844  0.845
 0.8446 0.8444 0.8444 0.8444 0.8444 0.8444 0.8444 0.8444 0.7506 0.821
 0.8372 0.8432 0.8448 0.8444 0.8446 0.8446 0.8446 0.8446 0.8444]
                           classifier           C penalty solver  score_acc
0   LogisticRegression(max_iter=5000)      1.0000      l2   saga     0.8450
1   LogisticRegression(max_iter=5000)      1.0000 

mean test accuracy: [0.7616 0.7616 0.7616 0.8292 0.8358 0.8414 0.8474 0.8444 0.8454 0.844
 0.8442 0.8438 0.8438 0.8438 0.8438 0.8438 0.8438 0.8438 0.7616 0.8288
 0.8412 0.8444 0.8438 0.8434 0.844  0.844  0.844  0.8442 0.8438]
mean test roc auo ovr: [0.5        0.87647948 0.5        0.88492295 0.88340904 0.89220364
 0.89877535 0.89418377 0.89522548 0.89311799 0.89317491 0.89266298
 0.89259913 0.89255606 0.89262222 0.8926255  0.8926167  0.8926145
 0.87647507 0.88492294 0.89221686 0.89417822 0.89264797 0.89103482
 0.89023583 0.88993225 0.88988277 0.88986408 0.89232969]
mean test f1 micro: [0.7616 0.7616 0.7616 0.8292 0.8358 0.8414 0.8474 0.8444 0.8454 0.844
 0.8442 0.8438 0.8438 0.8438 0.8438 0.8438 0.8438 0.8438 0.7616 0.8288
 0.8412 0.8444 0.8438 0.8434 0.844  0.844  0.844  0.8442 0.8438]
                           classifier           C penalty solver  score_acc
0   LogisticRegression(max_iter=5000)      0.1000      l1   saga     0.8474
1   LogisticRegression(max_iter=5000)      1.0000

mean test accuracy: [0.7616 0.7616 0.7616 0.8292 0.8358 0.8414 0.8474 0.8444 0.8454 0.844
 0.8442 0.8438 0.8438 0.8438 0.8438 0.8438 0.8438 0.8438 0.7616 0.8288
 0.8412 0.8444 0.8438 0.8434 0.844  0.844  0.844  0.8442 0.8438]
mean test roc auo ovr: [0.5        0.87648169 0.5        0.88492516 0.88341069 0.89220805
 0.89878196 0.89418486 0.89522439 0.89314116 0.89317161 0.89254389
 0.89267397 0.89260788 0.89247816 0.8926178  0.8926134  0.89251758
 0.87647507 0.88492294 0.89221686 0.89417822 0.89264797 0.89103482
 0.89023583 0.88993225 0.88988277 0.88986408 0.89258039]
mean test f1 micro: [0.7616 0.7616 0.7616 0.8292 0.8358 0.8414 0.8474 0.8444 0.8454 0.844
 0.8442 0.8438 0.8438 0.8438 0.8438 0.8438 0.8438 0.8438 0.7616 0.8288
 0.8412 0.8444 0.8438 0.8434 0.844  0.844  0.844  0.8442 0.8438]
                           classifier           C penalty solver  score_acc
0   LogisticRegression(max_iter=5000)      0.1000      l1   saga     0.8474
1   LogisticRegression(max_iter=5000)      1.000

mean test accuracy: [0.7508 0.7514 0.7508 0.8208 0.8246 0.838  0.8436 0.8404 0.8416 0.841
 0.8414 0.8416 0.8416 0.8416 0.8416 0.8416 0.8416 0.8416 0.7512 0.8206
 0.838  0.84   0.8408 0.841  0.841  0.841  0.841  0.841  0.8416]
mean test roc auo ovr: [0.5        0.87580521 0.5        0.88371752 0.87961451 0.88965493
 0.8975849  0.89276942 0.89403053 0.89259468 0.89265876 0.8924718
 0.89248143 0.89239817 0.89245471 0.89245791 0.89246434 0.89244511
 0.875002   0.88291859 0.889656   0.89197158 0.89146983 0.89049536
 0.89082399 0.8906636  0.89063258 0.89063258 0.89245365]
mean test f1 micro: [0.7508 0.7514 0.7508 0.8208 0.8246 0.838  0.8436 0.8404 0.8416 0.841
 0.8414 0.8416 0.8416 0.8416 0.8416 0.8416 0.8416 0.8416 0.7512 0.8206
 0.838  0.84   0.8408 0.841  0.841  0.841  0.841  0.841  0.8416]
                           classifier           C penalty solver  score_acc
0   LogisticRegression(max_iter=5000)      0.1000      l1   saga     0.8436
1   LogisticRegression(max_iter=5000)   1000.0000

mean test accuracy: [0.7508 0.7514 0.7508 0.8208 0.8246 0.838  0.8436 0.8404 0.8416 0.841
 0.8414 0.8416 0.8416 0.8416 0.8416 0.8416 0.8416 0.8416 0.7512 0.8206
 0.838  0.84   0.8408 0.841  0.841  0.841  0.841  0.841  0.8416]
mean test roc auo ovr: [0.5        0.87580414 0.5        0.88371539 0.87961451 0.88965599
 0.8975849  0.89277049 0.8940316  0.89260857 0.89265662 0.89247928
 0.8924825  0.89245792 0.89245793 0.89244938 0.89245579 0.89246003
 0.875002   0.88291859 0.889656   0.89197158 0.89146983 0.89049536
 0.89082399 0.8906636  0.89063258 0.89063258 0.89245578]
mean test f1 micro: [0.7508 0.7514 0.7508 0.8208 0.8246 0.838  0.8436 0.8404 0.8416 0.841
 0.8414 0.8416 0.8416 0.8416 0.8416 0.8416 0.8416 0.8416 0.7512 0.8206
 0.838  0.84   0.8408 0.841  0.841  0.841  0.841  0.841  0.8416]
                           classifier           C penalty solver  score_acc
0   LogisticRegression(max_iter=5000)      0.1000      l1   saga     0.8436
1   LogisticRegression(max_iter=5000)   1000.000

mean test accuracy: [0.7654 0.7656 0.7654 0.8302 0.8342 0.8478 0.8556 0.8556 0.853  0.8534
 0.8536 0.8534 0.8536 0.8536 0.8536 0.8536 0.8536 0.8536 0.7656 0.8298
 0.8474 0.8556 0.8532 0.853  0.8536 0.8536 0.8536 0.8536 0.8536]
mean test roc auo ovr: [0.5        0.88586367 0.5        0.89516487 0.88878604 0.9010906
 0.90966069 0.90588708 0.90771489 0.90587797 0.90598706 0.90547221
 0.90545549 0.90538084 0.90537639 0.90536192 0.90536971 0.90537192
 0.88586367 0.89516376 0.90056883 0.90530237 0.90580989 0.90267552
 0.90412044 0.90395097 0.90383938 0.90378023 0.9053697 ]
mean test f1 micro: [0.7654 0.7656 0.7654 0.8302 0.8342 0.8478 0.8556 0.8556 0.853  0.8534
 0.8536 0.8534 0.8536 0.8536 0.8536 0.8536 0.8536 0.8536 0.7656 0.8298
 0.8474 0.8556 0.8532 0.853  0.8536 0.8536 0.8536 0.8536 0.8536]
                           classifier           C penalty solver  score_acc
0   LogisticRegression(max_iter=5000)      0.1000      l1   saga     0.8556
1   LogisticRegression(max_iter=5000)      0.10

mean test accuracy: [0.7602 0.7602 0.7602 0.825  0.832  0.835  0.8444 0.8404 0.8408 0.8406
 0.8406 0.8404 0.8404 0.8404 0.8404 0.8404 0.8404 0.8404 0.76   0.825
 0.8346 0.8402 0.8406 0.8406 0.8404 0.8406 0.8406 0.8406 0.8404]
mean test roc auo ovr: [0.5        0.87551548 0.5        0.88254979 0.88182367 0.88874054
 0.89583849 0.89088088 0.89207231 0.88967627 0.88967742 0.88922851
 0.88922631 0.88917911 0.88917691 0.88916813 0.88917803 0.88916703
 0.87525197 0.88254212 0.88821752 0.89062504 0.88914426 0.88727971
 0.88628653 0.88589696 0.8857237  0.88571162 0.88916814]
mean test f1 micro: [0.7602 0.7602 0.7602 0.825  0.832  0.835  0.8444 0.8404 0.8408 0.8406
 0.8406 0.8404 0.8404 0.8404 0.8404 0.8404 0.8404 0.8404 0.76   0.825
 0.8346 0.8402 0.8406 0.8406 0.8404 0.8406 0.8406 0.8406 0.8404]
                           classifier           C penalty solver  score_acc
0   LogisticRegression(max_iter=5000)      0.1000      l1   saga     0.8444
1   LogisticRegression(max_iter=5000)      1.000

mean test accuracy: [0.7602 0.7602 0.7602 0.825  0.832  0.835  0.8444 0.8404 0.8408 0.8406
 0.8406 0.8404 0.8404 0.8404 0.8404 0.8404 0.8404 0.8404 0.76   0.825
 0.8346 0.8402 0.8406 0.8406 0.8404 0.8406 0.8406 0.8406 0.8404]
mean test roc auo ovr: [0.5        0.87551329 0.5        0.88255089 0.88182367 0.88874164
 0.8958374  0.89087978 0.8920734  0.88967407 0.88966973 0.8892428
 0.88922084 0.88917801 0.88917144 0.88916594 0.88917363 0.88916704
 0.87525197 0.88254212 0.88821752 0.89062504 0.88914426 0.88727971
 0.88628653 0.88589696 0.8857237  0.88571162 0.88917692]
mean test f1 micro: [0.7602 0.7602 0.7602 0.825  0.832  0.835  0.8444 0.8404 0.8408 0.8406
 0.8406 0.8404 0.8404 0.8404 0.8404 0.8404 0.8404 0.8404 0.76   0.825
 0.8346 0.8402 0.8406 0.8406 0.8404 0.8406 0.8406 0.8406 0.8404]
                           classifier           C penalty solver  score_acc
0   LogisticRegression(max_iter=5000)      0.1000      l1   saga     0.8444
1   LogisticRegression(max_iter=5000)      1.0000

In [119]:
adult_logreg_list

[array([0.7502, 0.7506, 0.7502, 0.8208, 0.8254, 0.8372, 0.8426, 0.8432,
        0.844 , 0.845 , 0.8446, 0.8444, 0.8444, 0.8444, 0.8444, 0.8444,
        0.8444, 0.8444, 0.7506, 0.821 , 0.8372, 0.8432, 0.8448, 0.8444,
        0.8446, 0.8446, 0.8446, 0.8446, 0.8444]),
 array([0.7502, 0.7506, 0.7502, 0.8208, 0.8254, 0.8372, 0.8426, 0.8432,
        0.844 , 0.845 , 0.8446, 0.8444, 0.8444, 0.8444, 0.8444, 0.8444,
        0.8444, 0.8444, 0.7506, 0.821 , 0.8372, 0.8432, 0.8448, 0.8444,
        0.8446, 0.8446, 0.8446, 0.8446, 0.8444]),
 array([0.7502, 0.7506, 0.7502, 0.8208, 0.8254, 0.8372, 0.8426, 0.8432,
        0.844 , 0.845 , 0.8446, 0.8444, 0.8444, 0.8444, 0.8444, 0.8444,
        0.8444, 0.8444, 0.7506, 0.821 , 0.8372, 0.8432, 0.8448, 0.8444,
        0.8446, 0.8446, 0.8446, 0.8446, 0.8444]),
 array([0.7502, 0.7506, 0.7502, 0.8208, 0.8254, 0.8372, 0.8426, 0.8432,
        0.844 , 0.8448, 0.8446, 0.8444, 0.8444, 0.8444, 0.8444, 0.8444,
        0.8444, 0.8444, 0.7506, 0.821 , 0.8372, 0.8432, 0.

In [112]:
logreg_grid.cv_results_

{'mean_fit_time': array([ 0.0272016 ,  0.19839382,  0.05079908,  0.34379282,  0.87417903,
         0.94077625,  2.93411336,  3.49889336, 14.82353082,  2.88011231,
         3.29389906,  2.89690971,  3.36209679,  2.94850655,  3.43409433,
         2.9655036 ,  3.43489194,  2.94571066,  0.02860184,  0.02907195,
         0.03499618,  0.05094495,  0.07038355,  0.10626822,  0.08104458,
         0.08522997,  0.08938751,  0.09069343,  2.60451736]),
 'std_fit_time': array([2.39840167e-03, 2.86239196e-02, 2.40563392e-03, 1.16208133e-01,
        1.10575993e-01, 5.86170024e-01, 8.97229732e-01, 2.64280574e+00,
        1.13411823e+01, 1.17979523e+00, 7.20912251e-01, 6.27423425e-01,
        6.87207564e-01, 6.09401858e-01, 6.97364705e-01, 5.92709278e-01,
        7.05964889e-01, 6.13816826e-01, 4.86643319e-04, 9.82944323e-04,
        2.00602119e-03, 3.58633356e-03, 1.27282459e-02, 3.79086079e-02,
        1.09741849e-02, 1.51864563e-02, 1.24687836e-02, 1.91512260e-02,
        5.38126118e-01]),
 'mean_sco

In [113]:
logreg_grid.cv_results_['rank_test_accuracy']

array([26, 26, 26, 24, 23, 21,  1, 10,  2,  6,  6, 10, 10, 10, 10, 10, 10,
       10, 29, 24, 22, 20,  6,  6, 10,  3,  3,  3, 10])

In [114]:
logreg_grid.cv_results_['params'][ np.argmin(logreg_grid.cv_results_['rank_test_accuracy'])]

{'classifier': LogisticRegression(max_iter=5000),
 'classifier__C': 0.1,
 'classifier__penalty': 'l1',
 'classifier__solver': 'saga'}

In [115]:
logreg_grid.cv_results_['params'][ np.argmin(logreg_grid.cv_results_['rank_test_roc_auc_ovr'])]

{'classifier': LogisticRegression(max_iter=5000),
 'classifier__C': 0.1,
 'classifier__penalty': 'l1',
 'classifier__solver': 'saga'}

In [116]:
logreg_results = pd.DataFrame(logreg_grid.cv_results_['params'])
logreg_results['score_acc'] = logreg_grid.cv_results_['mean_test_accuracy']
logreg_cols = logreg_results.columns.to_series().str.split('__').apply(lambda x: x[-1])
logreg_results.columns = logreg_cols
logreg_results = logreg_results.sort_values(by=['score_acc'], ascending=False, ignore_index=True)
logreg_results

Unnamed: 0,classifier,C,penalty,solver,score_acc
0,LogisticRegression(max_iter=5000),0.1,l1,saga,0.8444
1,LogisticRegression(max_iter=5000),1.0,l1,saga,0.8408
2,LogisticRegression(max_iter=5000),,none,lbfgs,0.8406
3,LogisticRegression(max_iter=5000),10000.0,l2,lbfgs,0.8406
4,LogisticRegression(max_iter=5000),1000.0,l2,lbfgs,0.8406
5,LogisticRegression(max_iter=5000),10.0,l2,lbfgs,0.8406
6,LogisticRegression(max_iter=5000),1.0,l2,saga,0.8406
7,LogisticRegression(max_iter=5000),10.0,l1,saga,0.8406
8,LogisticRegression(max_iter=5000),1.0,l2,lbfgs,0.8406
9,LogisticRegression(max_iter=5000),1000.0,l1,saga,0.8404


In [117]:
X_train, X_test, y_train, y_test = train_test_split(adult_onehot, adult_salary, train_size=5000)
sc = StandardScaler()
X_train = pd.DataFrame(sc.fit_transform(X_train), columns=X_train.columns)
X_test = sc.transform(X_test)
logreg_best = linear_model.LogisticRegression(penalty='l1', C=0.1, solver='saga', max_iter=5000)
logreg_best.fit(X_train, y_train)
pred_logreg = logreg_best.predict(X_test)
print('Logistic Regression:', classification_report(y_test, pred_logreg))

Logistic Regression:               precision    recall  f1-score   support

           0       0.88      0.94      0.91     20941
           1       0.74      0.58      0.65      6620

    accuracy                           0.85     27561
   macro avg       0.81      0.76      0.78     27561
weighted avg       0.84      0.85      0.85     27561



In [120]:
# KNN
for i in range(5):
    print('trial', i+1)
    adult_knn_list = []
    
    for j in range(5):
        X_train, X_test, y_train, y_test = train_test_split(adult_onehot, adult_salary, train_size=5000)
        sc = StandardScaler()
        X_train = pd.DataFrame(sc.fit_transform(X_train), columns=X_train.columns)
        X_test = sc.transform(X_test)
        
        for k in range(5):
            adult_knn = KNeighborsClassifier(weights='distance')
            k_list = [1, 5 , 9 , 13, 17, 21, 25, 29, 33, 37, 41, 45, 49, 53, 57, 61, 65, 69, 73, 77, 81, 85, 89, 93, 97, 101]
            knn_grid = GridSearchCV(adult_knn, {'n_neighbors':k_list}, cv=StratifiedKFold(n_splits=5), scoring=['accuracy', 'roc_auc_ovr', 'f1_micro'], refit=False,
                   verbose=0, return_train_score=True)
            knn_grid.fit(X_train, y_train)
            print('mean test accuracy:', knn_grid.cv_results_['mean_test_accuracy'])
            print('mean test roc auo ovr:', knn_grid.cv_results_['mean_test_roc_auc_ovr'])
            print('mean test f1 micro:', knn_grid.cv_results_['mean_test_f1_micro'])
            knn_results = pd.DataFrame(knn_grid.cv_results_['params'])
            knn_results['score_acc'] = knn_grid.cv_results_['mean_test_accuracy']
            adult_knn_list.append(knn_grid.cv_results_['mean_test_accuracy'])
            knn_cols = knn_results.columns.to_series().str.split('__').apply(lambda x: x[-1])
            knn_results.columns = knn_cols
            knn_results = knn_results.sort_values(by=['score_acc'], ascending=False, ignore_index=True)
            print(knn_results)

trial 1
mean test accuracy: [0.773  0.7988 0.8066 0.8088 0.8102 0.8124 0.8162 0.8136 0.817  0.8184
 0.8198 0.8218 0.8242 0.8238 0.8248 0.827  0.8258 0.8262 0.8272 0.8282
 0.828  0.8278 0.8284 0.8286 0.8278 0.8254]
mean test roc auo ovr: [0.7023692  0.82761777 0.84735191 0.85315679 0.85808026 0.86177443
 0.86504225 0.86783221 0.86972259 0.87166564 0.87327082 0.8747299
 0.87588092 0.87659347 0.8772851  0.87734147 0.87792907 0.87835702
 0.87911517 0.87936473 0.87996349 0.88029356 0.8810508  0.88093409
 0.88149764 0.88163411]
mean test f1 micro: [0.773  0.7988 0.8066 0.8088 0.8102 0.8124 0.8162 0.8136 0.817  0.8184
 0.8198 0.8218 0.8242 0.8238 0.8248 0.827  0.8258 0.8262 0.8272 0.8282
 0.828  0.8278 0.8284 0.8286 0.8278 0.8254]
    n_neighbors  score_acc
0            93     0.8286
1            89     0.8284
2            77     0.8282
3            81     0.8280
4            97     0.8278
5            85     0.8278
6            73     0.8272
7            61     0.8270
8            69     0.8

mean test accuracy: [0.777  0.8096 0.8198 0.8246 0.8226 0.8232 0.8262 0.83   0.8326 0.8318
 0.832  0.8302 0.8332 0.8322 0.833  0.834  0.8338 0.8338 0.8324 0.8348
 0.834  0.8338 0.8326 0.8358 0.8336 0.8336]
mean test roc auo ovr: [0.69175682 0.8184085  0.84177552 0.85243441 0.85747804 0.86044795
 0.86252118 0.8647565  0.8662624  0.86799233 0.86994807 0.8700049
 0.87067517 0.87135659 0.87146685 0.87140444 0.87151511 0.87179892
 0.87204646 0.87224868 0.87191282 0.8722515  0.87241119 0.87259781
 0.87288726 0.8732764 ]
mean test f1 micro: [0.777  0.8096 0.8198 0.8246 0.8226 0.8232 0.8262 0.83   0.8326 0.8318
 0.832  0.8302 0.8332 0.8322 0.833  0.834  0.8338 0.8338 0.8324 0.8348
 0.834  0.8338 0.8326 0.8358 0.8336 0.8336]
    n_neighbors  score_acc
0            93     0.8358
1            77     0.8348
2            81     0.8340
3            61     0.8340
4            85     0.8338
5            69     0.8338
6            65     0.8338
7           101     0.8336
8            97     0.8336
9   

mean test accuracy: [0.776  0.8072 0.814  0.8182 0.8222 0.8236 0.8196 0.8212 0.8254 0.8258
 0.8274 0.8296 0.827  0.8254 0.8262 0.8246 0.8248 0.8266 0.826  0.8244
 0.8244 0.8232 0.8236 0.8248 0.8264 0.8274]
mean test roc auo ovr: [0.69031858 0.81571523 0.83684435 0.84816376 0.85452826 0.8581178
 0.86144411 0.8635057  0.86491863 0.8665815  0.86824604 0.86913969
 0.86984754 0.8699139  0.8702186  0.87076386 0.87101437 0.87103539
 0.87110894 0.87111557 0.87159834 0.8721801  0.87263356 0.87289292
 0.87337514 0.87366878]
mean test f1 micro: [0.776  0.8072 0.814  0.8182 0.8222 0.8236 0.8196 0.8212 0.8254 0.8258
 0.8274 0.8296 0.827  0.8254 0.8262 0.8246 0.8248 0.8266 0.826  0.8244
 0.8244 0.8232 0.8236 0.8248 0.8264 0.8274]
    n_neighbors  score_acc
0            45     0.8296
1            41     0.8274
2           101     0.8274
3            49     0.8270
4            69     0.8266
5            97     0.8264
6            57     0.8262
7            73     0.8260
8            37     0.8258
9   

mean test accuracy: [0.7762 0.805  0.8088 0.8146 0.8194 0.8174 0.8204 0.825  0.8262 0.8246
 0.8238 0.8234 0.822  0.8224 0.8224 0.823  0.8228 0.8234 0.8248 0.8272
 0.8258 0.826  0.8266 0.8264 0.8262 0.8268]
mean test roc auo ovr: [0.69517656 0.81864949 0.84218819 0.85291003 0.85894255 0.86305991
 0.86661692 0.86907082 0.87059432 0.87199693 0.87332919 0.8738435
 0.87424616 0.87478203 0.8748417  0.87560368 0.875429   0.87579511
 0.87639415 0.8766337  0.87699888 0.87756152 0.87795892 0.87826565
 0.878552   0.87900847]
mean test f1 micro: [0.7762 0.805  0.8088 0.8146 0.8194 0.8174 0.8204 0.825  0.8262 0.8246
 0.8238 0.8234 0.822  0.8224 0.8224 0.823  0.8228 0.8234 0.8248 0.8272
 0.8258 0.826  0.8266 0.8264 0.8262 0.8268]
    n_neighbors  score_acc
0            77     0.8272
1           101     0.8268
2            89     0.8266
3            93     0.8264
4            97     0.8262
5            33     0.8262
6            85     0.8260
7            81     0.8258
8            29     0.8250
9   

mean test accuracy: [0.7748 0.805  0.8164 0.8158 0.8174 0.8192 0.8208 0.8216 0.8244 0.8252
 0.8254 0.8256 0.8248 0.8246 0.8234 0.823  0.8226 0.822  0.8228 0.823
 0.8232 0.8246 0.8244 0.8228 0.824  0.8242]
mean test roc auo ovr: [0.69147946 0.81747344 0.84420757 0.85420711 0.86113721 0.86447915
 0.86660428 0.86972265 0.8714252  0.87249008 0.87341013 0.87438471
 0.87491045 0.87582471 0.87613394 0.876576   0.87683641 0.87680566
 0.87691067 0.8769784  0.87651241 0.87679089 0.87722913 0.8770763
 0.87753044 0.87783031]
mean test f1 micro: [0.7748 0.805  0.8164 0.8158 0.8174 0.8192 0.8208 0.8216 0.8244 0.8252
 0.8254 0.8256 0.8248 0.8246 0.8234 0.823  0.8226 0.822  0.8228 0.823
 0.8232 0.8246 0.8244 0.8228 0.824  0.8242]
    n_neighbors  score_acc
0            45     0.8256
1            41     0.8254
2            37     0.8252
3            49     0.8248
4            53     0.8246
5            85     0.8246
6            33     0.8244
7            89     0.8244
8           101     0.8242
9     

mean test accuracy: [0.7818 0.8024 0.8146 0.817  0.8156 0.8158 0.8162 0.82   0.8226 0.8234
 0.8234 0.8244 0.826  0.8272 0.8276 0.8276 0.8302 0.8286 0.8308 0.8312
 0.829  0.831  0.8304 0.8298 0.8324 0.8308]
mean test roc auo ovr: [0.70496029 0.8269343  0.84715877 0.85286245 0.85904556 0.86298747
 0.86495803 0.86686067 0.86821198 0.86948043 0.87028442 0.8709009
 0.87078794 0.87155509 0.87197935 0.87254911 0.87371843 0.87426624
 0.87453578 0.87499587 0.87474721 0.87522683 0.87554687 0.87564238
 0.87580144 0.87587709]
mean test f1 micro: [0.7818 0.8024 0.8146 0.817  0.8156 0.8158 0.8162 0.82   0.8226 0.8234
 0.8234 0.8244 0.826  0.8272 0.8276 0.8276 0.8302 0.8286 0.8308 0.8312
 0.829  0.831  0.8304 0.8298 0.8324 0.8308]
    n_neighbors  score_acc
0            97     0.8324
1            77     0.8312
2            85     0.8310
3            73     0.8308
4           101     0.8308
5            89     0.8304
6            65     0.8302
7            93     0.8298
8            81     0.8290
9   

mean test accuracy: [0.7958 0.8162 0.8258 0.8318 0.829  0.8314 0.832  0.8354 0.8374 0.8382
 0.8396 0.8402 0.841  0.8416 0.842  0.8396 0.8406 0.841  0.8382 0.8378
 0.8396 0.8404 0.8396 0.8402 0.8396 0.8396]
mean test roc auo ovr: [0.72157348 0.83338158 0.85510213 0.86396344 0.8681904  0.87103332
 0.87332233 0.87725401 0.87947914 0.88093879 0.88182923 0.88185035
 0.88241592 0.88280557 0.883052   0.88314794 0.88378735 0.8839244
 0.88423388 0.88467679 0.88562029 0.8863535  0.8861068  0.88723287
 0.88775466 0.88799922]
mean test f1 micro: [0.7958 0.8162 0.8258 0.8318 0.829  0.8314 0.832  0.8354 0.8374 0.8382
 0.8396 0.8402 0.841  0.8416 0.842  0.8396 0.8406 0.841  0.8382 0.8378
 0.8396 0.8404 0.8396 0.8402 0.8396 0.8396]
    n_neighbors  score_acc
0            57     0.8420
1            53     0.8416
2            49     0.8410
3            69     0.8410
4            65     0.8406
5            85     0.8404
6            93     0.8402
7            45     0.8402
8            89     0.8396
9   

mean test accuracy: [0.759  0.788  0.8002 0.8042 0.8084 0.809  0.8138 0.815  0.8144 0.8154
 0.8152 0.8158 0.8166 0.8184 0.8182 0.8162 0.8166 0.8164 0.815  0.8136
 0.8146 0.815  0.8154 0.8168 0.8166 0.817 ]
mean test roc auo ovr: [0.67450119 0.81068493 0.83056379 0.84053457 0.84677435 0.85123531
 0.85611494 0.85885322 0.86013596 0.86228766 0.8642671  0.86469286
 0.86677206 0.86792996 0.86792887 0.86780021 0.8679812  0.86832683
 0.86846639 0.86874659 0.86866809 0.86828213 0.86850673 0.86875641
 0.86919471 0.8692852 ]
mean test f1 micro: [0.759  0.788  0.8002 0.8042 0.8084 0.809  0.8138 0.815  0.8144 0.8154
 0.8152 0.8158 0.8166 0.8184 0.8182 0.8162 0.8166 0.8164 0.815  0.8136
 0.8146 0.815  0.8154 0.8168 0.8166 0.817 ]
    n_neighbors  score_acc
0            53     0.8184
1            57     0.8182
2           101     0.8170
3            93     0.8168
4            97     0.8166
5            65     0.8166
6            49     0.8166
7            69     0.8164
8            61     0.8162
9  

mean test accuracy: [0.768  0.8034 0.8142 0.813  0.8178 0.8208 0.8208 0.8222 0.8246 0.8256
 0.8234 0.8222 0.824  0.824  0.825  0.8254 0.8254 0.8256 0.8244 0.8262
 0.8236 0.8216 0.821  0.82   0.8202 0.8208]
mean test roc auo ovr: [0.69188279 0.81231269 0.83421727 0.84386437 0.85041763 0.85416787
 0.85806282 0.86064928 0.86325152 0.86477337 0.86643972 0.86623287
 0.86704646 0.86790336 0.86908916 0.8692964  0.8697256  0.87000179
 0.87026658 0.87114612 0.87105827 0.87107035 0.87157337 0.87179954
 0.87251587 0.87277031]
mean test f1 micro: [0.768  0.8034 0.8142 0.813  0.8178 0.8208 0.8208 0.8222 0.8246 0.8256
 0.8234 0.8222 0.824  0.824  0.825  0.8254 0.8254 0.8256 0.8244 0.8262
 0.8236 0.8216 0.821  0.82   0.8202 0.8208]
    n_neighbors  score_acc
0            77     0.8262
1            37     0.8256
2            69     0.8256
3            61     0.8254
4            65     0.8254
5            57     0.8250
6            33     0.8246
7            73     0.8244
8            53     0.8240
9  

mean test accuracy: [0.759  0.7814 0.795  0.802  0.8078 0.8074 0.807  0.8076 0.808  0.8084
 0.8118 0.8102 0.8096 0.8096 0.812  0.812  0.8126 0.8142 0.815  0.816
 0.8158 0.8162 0.816  0.8148 0.815  0.8156]
mean test roc auo ovr: [0.68974371 0.80669499 0.83135839 0.84227966 0.84990392 0.85438615
 0.85628956 0.85976738 0.860479   0.8615181  0.86294387 0.86319328
 0.8639939  0.86539417 0.86623118 0.86610739 0.86656452 0.86669534
 0.86721496 0.86724349 0.86715321 0.86739415 0.86778023 0.86818298
 0.86842117 0.8690283 ]
mean test f1 micro: [0.759  0.7814 0.795  0.802  0.8078 0.8074 0.807  0.8076 0.808  0.8084
 0.8118 0.8102 0.8096 0.8096 0.812  0.812  0.8126 0.8142 0.815  0.816
 0.8158 0.8162 0.816  0.8148 0.815  0.8156]
    n_neighbors  score_acc
0            85     0.8162
1            89     0.8160
2            77     0.8160
3            81     0.8158
4           101     0.8156
5            97     0.8150
6            73     0.8150
7            93     0.8148
8            69     0.8142
9    

mean test accuracy: [0.773  0.8002 0.8096 0.813  0.817  0.8184 0.8192 0.8232 0.8238 0.8244
 0.8266 0.8278 0.828  0.829  0.8288 0.828  0.8284 0.8292 0.8298 0.8284
 0.8266 0.8252 0.8274 0.827  0.8274 0.828 ]
mean test roc auo ovr: [0.68224419 0.80685557 0.83057164 0.84123021 0.84786237 0.85059365
 0.85360364 0.85697364 0.85896279 0.86165038 0.86286035 0.86447346
 0.86501374 0.86618334 0.866258   0.86667275 0.86706538 0.8678241
 0.86818466 0.8683572  0.86823277 0.86804917 0.86823941 0.86839314
 0.86868513 0.86924034]
mean test f1 micro: [0.773  0.8002 0.8096 0.813  0.817  0.8184 0.8192 0.8232 0.8238 0.8244
 0.8266 0.8278 0.828  0.829  0.8288 0.828  0.8284 0.8292 0.8298 0.8284
 0.8266 0.8252 0.8274 0.827  0.8274 0.828 ]
    n_neighbors  score_acc
0            73     0.8298
1            69     0.8292
2            53     0.8290
3            57     0.8288
4            77     0.8284
5            65     0.8284
6            61     0.8280
7            49     0.8280
8           101     0.8280
9   

mean test accuracy: [0.7774 0.799  0.8076 0.8186 0.8196 0.8176 0.8198 0.8202 0.8218 0.8194
 0.8258 0.825  0.824  0.8244 0.8242 0.8258 0.828  0.8294 0.8294 0.8302
 0.8328 0.8316 0.8312 0.8318 0.832  0.8336]
mean test roc auo ovr: [0.69934738 0.82094194 0.83975098 0.8528159  0.85689246 0.8592698
 0.8613361  0.86273213 0.86374516 0.86502038 0.86615053 0.86714318
 0.86790286 0.86850575 0.86928715 0.8699418  0.86977957 0.87054025
 0.87117691 0.87173533 0.87219011 0.87220497 0.8724287  0.87285588
 0.87323038 0.87371159]
mean test f1 micro: [0.7774 0.799  0.8076 0.8186 0.8196 0.8176 0.8198 0.8202 0.8218 0.8194
 0.8258 0.825  0.824  0.8244 0.8242 0.8258 0.828  0.8294 0.8294 0.8302
 0.8328 0.8316 0.8312 0.8318 0.832  0.8336]
    n_neighbors  score_acc
0           101     0.8336
1            81     0.8328
2            97     0.8320
3            93     0.8318
4            85     0.8316
5            89     0.8312
6            77     0.8302
7            73     0.8294
8            69     0.8294
9   

mean test accuracy: [0.7884 0.8168 0.8238 0.8272 0.8268 0.8294 0.8318 0.8302 0.8302 0.8328
 0.8344 0.8332 0.835  0.8348 0.8338 0.8326 0.833  0.833  0.834  0.8352
 0.8352 0.8348 0.835  0.8352 0.8318 0.8304]
mean test roc auo ovr: [0.71183537 0.83737152 0.85452121 0.86210263 0.86804038 0.87119926
 0.87439121 0.87553832 0.87702858 0.87810738 0.87993863 0.88132644
 0.88144841 0.88175145 0.88198673 0.88199107 0.88269581 0.88295602
 0.88307204 0.88279664 0.88301782 0.88306011 0.88293542 0.88276195
 0.88291699 0.8827858 ]
mean test f1 micro: [0.7884 0.8168 0.8238 0.8272 0.8268 0.8294 0.8318 0.8302 0.8302 0.8328
 0.8344 0.8332 0.835  0.8348 0.8338 0.8326 0.833  0.833  0.834  0.8352
 0.8352 0.8348 0.835  0.8352 0.8318 0.8304]
    n_neighbors  score_acc
0            77     0.8352
1            81     0.8352
2            93     0.8352
3            49     0.8350
4            89     0.8350
5            53     0.8348
6            85     0.8348
7            41     0.8344
8            73     0.8340
9  

mean test accuracy: [0.7766 0.8038 0.8116 0.812  0.816  0.8146 0.8186 0.8198 0.8214 0.822
 0.8246 0.8226 0.8242 0.825  0.8266 0.8256 0.8256 0.8268 0.8266 0.8268
 0.8266 0.8264 0.8254 0.826  0.8262 0.8266]
mean test roc auo ovr: [0.70469837 0.82344734 0.84628006 0.85711591 0.86357235 0.86762166
 0.87024875 0.87191908 0.87415405 0.8754248  0.8760786  0.87773774
 0.87817075 0.87901541 0.88004602 0.88022825 0.88035251 0.88113748
 0.88128    0.88145751 0.88167926 0.88186175 0.88211491 0.88263762
 0.88270336 0.88288521]
mean test f1 micro: [0.7766 0.8038 0.8116 0.812  0.816  0.8146 0.8186 0.8198 0.8214 0.822
 0.8246 0.8226 0.8242 0.825  0.8266 0.8256 0.8256 0.8268 0.8266 0.8268
 0.8266 0.8264 0.8254 0.826  0.8262 0.8266]
    n_neighbors  score_acc
0            69     0.8268
1            77     0.8268
2           101     0.8266
3            57     0.8266
4            81     0.8266
5            73     0.8266
6            85     0.8264
7            97     0.8262
8            93     0.8260
9    

mean test accuracy: [0.775  0.807  0.817  0.8208 0.824  0.8264 0.8296 0.8304 0.8314 0.8292
 0.8328 0.8324 0.8332 0.8336 0.8328 0.8336 0.834  0.834  0.8344 0.833
 0.8332 0.8302 0.8288 0.83   0.8296 0.83  ]
mean test roc auo ovr: [0.69144596 0.81335613 0.84205433 0.85257506 0.85835583 0.86122422
 0.86436422 0.86610164 0.86764772 0.86907339 0.87095926 0.87211058
 0.87297049 0.87386394 0.87449183 0.87469746 0.87446324 0.87472221
 0.87536824 0.87530336 0.87541442 0.87590761 0.87606541 0.87591421
 0.87674003 0.87660697]
mean test f1 micro: [0.775  0.807  0.817  0.8208 0.824  0.8264 0.8296 0.8304 0.8314 0.8292
 0.8328 0.8324 0.8332 0.8336 0.8328 0.8336 0.834  0.834  0.8344 0.833
 0.8332 0.8302 0.8288 0.83   0.8296 0.83  ]
    n_neighbors  score_acc
0            73     0.8344
1            69     0.8340
2            65     0.8340
3            53     0.8336
4            61     0.8336
5            49     0.8332
6            81     0.8332
7            77     0.8330
8            41     0.8328
9    

mean test accuracy: [0.783  0.8122 0.8152 0.8208 0.8228 0.8296 0.829  0.8288 0.8318 0.8316
 0.8326 0.8328 0.8322 0.8308 0.831  0.8306 0.8316 0.8306 0.831  0.8322
 0.8334 0.8314 0.8326 0.8324 0.8332 0.8332]
mean test roc auo ovr: [0.6963265  0.8273406  0.84746682 0.85466458 0.86110398 0.86437449
 0.86740329 0.86997743 0.8722099  0.87350514 0.87402493 0.87428265
 0.87539375 0.87560824 0.87658504 0.87706962 0.87741028 0.87723427
 0.87733566 0.87782629 0.87859915 0.87897591 0.8796255  0.88020116
 0.88080289 0.88115638]
mean test f1 micro: [0.783  0.8122 0.8152 0.8208 0.8228 0.8296 0.829  0.8288 0.8318 0.8316
 0.8326 0.8328 0.8322 0.8308 0.831  0.8306 0.8316 0.8306 0.831  0.8322
 0.8334 0.8314 0.8326 0.8324 0.8332 0.8332]
    n_neighbors  score_acc
0            81     0.8334
1           101     0.8332
2            97     0.8332
3            45     0.8328
4            41     0.8326
5            89     0.8326
6            93     0.8324
7            77     0.8322
8            49     0.8322
9  

mean test accuracy: [0.7874 0.815  0.822  0.8298 0.8324 0.8328 0.8314 0.8338 0.8316 0.833
 0.832  0.8328 0.8326 0.8324 0.8328 0.8328 0.8336 0.8322 0.8328 0.8332
 0.8322 0.8328 0.8326 0.8304 0.8298 0.8314]
mean test roc auo ovr: [0.71024    0.83164434 0.84958781 0.8599451  0.86472895 0.86881425
 0.87223894 0.87421981 0.87674451 0.87770689 0.87795915 0.87876348
 0.87982382 0.88043834 0.88068293 0.88051602 0.88052506 0.88041962
 0.88073799 0.88148349 0.88162042 0.88188602 0.88170223 0.88179918
 0.88172866 0.88156013]
mean test f1 micro: [0.7874 0.815  0.822  0.8298 0.8324 0.8328 0.8314 0.8338 0.8316 0.833
 0.832  0.8328 0.8326 0.8324 0.8328 0.8328 0.8336 0.8322 0.8328 0.8332
 0.8322 0.8328 0.8326 0.8304 0.8298 0.8314]
    n_neighbors  score_acc
0            29     0.8338
1            65     0.8336
2            77     0.8332
3            37     0.8330
4            85     0.8328
5            21     0.8328
6            61     0.8328
7            57     0.8328
8            73     0.8328
9    

mean test accuracy: [0.7946 0.817  0.8216 0.8252 0.8284 0.828  0.829  0.8328 0.8328 0.833
 0.8344 0.8354 0.8372 0.8374 0.8384 0.8386 0.8388 0.8388 0.8418 0.8404
 0.84   0.8374 0.8364 0.8368 0.8358 0.8348]
mean test roc auo ovr: [0.70696136 0.81594193 0.84003989 0.8489802  0.85574066 0.860041
 0.86188425 0.86352467 0.86520748 0.86703932 0.86734456 0.86818161
 0.86903975 0.86976358 0.87011901 0.87044835 0.87123331 0.87184975
 0.87157793 0.87156603 0.87165971 0.87203149 0.8721533  0.87244609
 0.87198669 0.87269629]
mean test f1 micro: [0.7946 0.817  0.8216 0.8252 0.8284 0.828  0.829  0.8328 0.8328 0.833
 0.8344 0.8354 0.8372 0.8374 0.8384 0.8386 0.8388 0.8388 0.8418 0.8404
 0.84   0.8374 0.8364 0.8368 0.8358 0.8348]
    n_neighbors  score_acc
0            73     0.8418
1            77     0.8404
2            81     0.8400
3            65     0.8388
4            69     0.8388
5            61     0.8386
6            57     0.8384
7            53     0.8374
8            85     0.8374
9      

mean test accuracy: [0.78   0.8114 0.8194 0.8226 0.8282 0.8282 0.8286 0.8314 0.8322 0.8282
 0.8284 0.8272 0.8286 0.828  0.8278 0.83   0.8284 0.828  0.8296 0.8288
 0.8288 0.8284 0.8282 0.828  0.8266 0.8282]
mean test roc auo ovr: [0.70070565 0.82532979 0.8413709  0.84975636 0.85573397 0.85831656
 0.86120827 0.86327233 0.86468029 0.86640044 0.86806542 0.86952473
 0.87033887 0.87117363 0.87211939 0.87252017 0.87278536 0.87300564
 0.87341937 0.87354769 0.87344422 0.87390611 0.87410134 0.87437551
 0.87481177 0.87516159]
mean test f1 micro: [0.78   0.8114 0.8194 0.8226 0.8282 0.8282 0.8286 0.8314 0.8322 0.8282
 0.8284 0.8272 0.8286 0.828  0.8278 0.83   0.8284 0.828  0.8296 0.8288
 0.8288 0.8284 0.8282 0.828  0.8266 0.8282]
    n_neighbors  score_acc
0            33     0.8322
1            29     0.8314
2            61     0.8300
3            73     0.8296
4            77     0.8288
5            81     0.8288
6            25     0.8286
7            49     0.8286
8            41     0.8284
9  

mean test accuracy: [0.7848 0.8104 0.8186 0.8232 0.8224 0.8274 0.825  0.8292 0.8324 0.8326
 0.833  0.8314 0.8318 0.8304 0.8306 0.8302 0.8322 0.8336 0.8316 0.8312
 0.8312 0.8298 0.8286 0.83   0.8314 0.8302]
mean test roc auo ovr: [0.69957585 0.82002975 0.84204644 0.85278852 0.85875541 0.86348027
 0.86814761 0.86919831 0.87099944 0.87223098 0.87357588 0.87543618
 0.87633868 0.87698569 0.87747234 0.87798331 0.87835493 0.87805962
 0.87793575 0.878074   0.87876304 0.8791225  0.87931383 0.87959697
 0.88028269 0.8804309 ]
mean test f1 micro: [0.7848 0.8104 0.8186 0.8232 0.8224 0.8274 0.825  0.8292 0.8324 0.8326
 0.833  0.8314 0.8318 0.8304 0.8306 0.8302 0.8322 0.8336 0.8316 0.8312
 0.8312 0.8298 0.8286 0.83   0.8314 0.8302]
    n_neighbors  score_acc
0            69     0.8336
1            41     0.8330
2            37     0.8326
3            33     0.8324
4            65     0.8322
5            49     0.8318
6            73     0.8316
7            97     0.8314
8            45     0.8314
9  

mean test accuracy: [0.7834 0.8084 0.814  0.8248 0.826  0.8282 0.8258 0.8296 0.8318 0.8302
 0.829  0.8306 0.8306 0.831  0.8318 0.8328 0.8342 0.8318 0.8328 0.8328
 0.8322 0.831  0.8308 0.83   0.829  0.8306]
mean test roc auo ovr: [0.70953947 0.82382018 0.84770724 0.85794298 0.86282237 0.86579441
 0.86805702 0.87020998 0.87218037 0.8737489  0.87531963 0.87623849
 0.87713487 0.878625   0.87969298 0.8802193  0.88095833 0.88130154
 0.88135746 0.88136952 0.88159978 0.8819364  0.88199013 0.88177412
 0.88206469 0.88229605]
mean test f1 micro: [0.7834 0.8084 0.814  0.8248 0.826  0.8282 0.8258 0.8296 0.8318 0.8302
 0.829  0.8306 0.8306 0.831  0.8318 0.8328 0.8342 0.8318 0.8328 0.8328
 0.8322 0.831  0.8308 0.83   0.829  0.8306]
    n_neighbors  score_acc
0            65     0.8342
1            77     0.8328
2            73     0.8328
3            61     0.8328
4            81     0.8322
5            57     0.8318
6            69     0.8318
7            33     0.8318
8            53     0.8310
9  

In [121]:
adult_knn_list

[array([0.7946, 0.817 , 0.8216, 0.8252, 0.8284, 0.828 , 0.829 , 0.8328,
        0.8328, 0.833 , 0.8344, 0.8354, 0.8372, 0.8374, 0.8384, 0.8386,
        0.8388, 0.8388, 0.8418, 0.8404, 0.84  , 0.8374, 0.8364, 0.8368,
        0.8358, 0.8348]),
 array([0.7946, 0.817 , 0.8216, 0.8252, 0.8284, 0.828 , 0.829 , 0.8328,
        0.8328, 0.833 , 0.8344, 0.8354, 0.8372, 0.8374, 0.8384, 0.8386,
        0.8388, 0.8388, 0.8418, 0.8404, 0.84  , 0.8374, 0.8364, 0.8368,
        0.8358, 0.8348]),
 array([0.7946, 0.817 , 0.8216, 0.8252, 0.8284, 0.828 , 0.829 , 0.8328,
        0.8328, 0.833 , 0.8344, 0.8354, 0.8372, 0.8374, 0.8384, 0.8386,
        0.8388, 0.8388, 0.8418, 0.8404, 0.84  , 0.8374, 0.8364, 0.8368,
        0.8358, 0.8348]),
 array([0.7946, 0.817 , 0.8216, 0.8252, 0.8284, 0.828 , 0.829 , 0.8328,
        0.8328, 0.833 , 0.8344, 0.8354, 0.8372, 0.8374, 0.8384, 0.8386,
        0.8388, 0.8388, 0.8418, 0.8404, 0.84  , 0.8374, 0.8364, 0.8368,
        0.8358, 0.8348]),
 array([0.7946, 0.817 , 0.8216, 

In [122]:
knn_grid.cv_results_

{'mean_fit_time': array([0.00580626, 0.00540075, 0.00539799, 0.00538783, 0.00559816,
        0.00559702, 0.00499182, 0.00540023, 0.005199  , 0.00539684,
        0.0049983 , 0.00519714, 0.00540328, 0.0053937 , 0.00560389,
        0.00539632, 0.00579691, 0.00520358, 0.00559864, 0.00519695,
        0.00540686, 0.00560017, 0.0053947 , 0.00579896, 0.00580039,
        0.00519996]),
 'std_fit_time': array([4.03121230e-04, 4.89999789e-04, 4.91683888e-04, 4.93860348e-04,
        4.88030674e-04, 4.87595259e-04, 7.38620604e-06, 4.85389461e-04,
        4.00523469e-04, 4.92442710e-04, 1.05354038e-05, 4.01844669e-04,
        4.93709512e-04, 4.88611604e-04, 4.91799628e-04, 4.87573733e-04,
        1.16570387e-03, 3.98027478e-04, 8.00540795e-04, 3.94392086e-04,
        4.90400787e-04, 4.89940803e-04, 4.95166725e-04, 3.99518109e-04,
        4.00257111e-04, 3.99685162e-04]),
 'mean_score_time': array([0.13479357, 0.19199357, 0.19859672, 0.19359665, 0.19899616,
        0.19539394, 0.19539647, 0.19578385, 

In [123]:
knn_grid.cv_results_['rank_test_accuracy']

array([26, 25, 24, 23, 21, 20, 22, 17,  6, 15, 18, 13, 13,  9,  6,  2,  1,
        6,  2,  2,  5,  9, 11, 16, 18, 12])

In [124]:
knn_grid.cv_results_['params'][ np.argmin(knn_grid.cv_results_['rank_test_accuracy'])]

{'n_neighbors': 65}

In [125]:
knn_grid.cv_results_['params'][ np.argmin(knn_grid.cv_results_['rank_test_roc_auc_ovr'])]

{'n_neighbors': 101}

In [126]:
knn_results = pd.DataFrame(knn_grid.cv_results_['params'])
knn_results['score_acc'] = knn_grid.cv_results_['mean_test_accuracy']
knn_cols = knn_results.columns.to_series().str.split('__').apply(lambda x: x[-1])
knn_results.columns = knn_cols
knn_results = knn_results.sort_values(by=['score_acc'], ascending=False, ignore_index=True)
knn_results

Unnamed: 0,n_neighbors,score_acc
0,65,0.8342
1,77,0.8328
2,73,0.8328
3,61,0.8328
4,81,0.8322
5,57,0.8318
6,69,0.8318
7,33,0.8318
8,53,0.831
9,85,0.831


In [127]:
X_train, X_test, y_train, y_test = train_test_split(adult_onehot, adult_salary, train_size=5000)
X_train = pd.DataFrame(X_train, columns=X_train.columns)
knn_best = KNeighborsClassifier(weights='distance', n_neighbors=101)
knn_best.fit(X_train, y_train)
pred_knn = knn_best.predict(X_test)
print('KNN:', classification_report(y_test, pred_knn))

KNN:               precision    recall  f1-score   support

           0       0.77      0.97      0.86     20942
           1       0.54      0.10      0.16      6619

    accuracy                           0.76     27561
   macro avg       0.66      0.54      0.51     27561
weighted avg       0.72      0.76      0.69     27561

