In [1]:
import xgboost as xgb

In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from sklearn.linear_model import LinearRegression
from sklearn.model_selection import cross_val_score, train_test_split, GridSearchCV
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
from sklearn.metrics import log_loss, make_scorer, confusion_matrix, mean_squared_error
from pandas.plotting import scatter_matrix
import pickle



In [3]:
nlp_df = pd.read_csv('../data/df_with_nlp.csv', index_col=0)
X = nlp_df
y = pd.read_csv("../data/work-balance-stars.csv", header=None, index_col=0).values

In [4]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=0)

### Linear Regression

In [47]:
lr = LinearRegression()
lr.fit(X_train, y_train)

LinearRegression(copy_X=True, fit_intercept=True, n_jobs=1, normalize=False)

In [48]:
y_pred = lr.predict(X_test)

### Random Forest

In [24]:
random_forest_grid = {'max_depth': [3, None],
                      'max_features': ['sqrt', 'log2', None],
                      'min_samples_split': [2, 4],
                      'min_samples_leaf': [1, 2, 4],
                      'bootstrap': [True, False],
                      'n_estimators': [100, 200, 500, 1000],
                      'random_state': [1]}

rf_gridsearch = GridSearchCV(RandomForestRegressor(),
                             random_forest_grid,
                             n_jobs=-1,
                             verbose=True,
                             scoring='neg_mean_squared_error')
rf_gridsearch.fit(X_train, y_train)
print( "best parameters:", rf_gridsearch.best_params_ )

best_rf_model = rf_gridsearch.best_estimator_


Fitting 3 folds for each of 288 candidates, totalling 864 fits


[Parallel(n_jobs=-1)]: Done  42 tasks      | elapsed:   28.6s
[Parallel(n_jobs=-1)]: Done 192 tasks      | elapsed:  2.8min
[Parallel(n_jobs=-1)]: Done 442 tasks      | elapsed: 19.5min
[Parallel(n_jobs=-1)]: Done 792 tasks      | elapsed: 32.3min
[Parallel(n_jobs=-1)]: Done 864 out of 864 | elapsed: 39.9min finished


best parameters: {'bootstrap': True, 'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 4, 'min_samples_split': 2, 'n_estimators': 1000, 'random_state': 1}


In [27]:
best_params = {'bootstrap': True, 'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 4, 'min_samples_split': 2, 'n_estimators': 1000, 'random_state': 1}
rf = RandomForestRegressor(**best_params)
rf.fit(X_train, y_train)

RandomForestRegressor(bootstrap=True, criterion='mse', max_depth=None,
           max_features='sqrt', max_leaf_nodes=None,
           min_impurity_decrease=0.0, min_impurity_split=None,
           min_samples_leaf=4, min_samples_split=2,
           min_weight_fraction_leaf=0.0, n_estimators=1000, n_jobs=1,
           oob_score=False, random_state=1, verbose=0, warm_start=False)

In [28]:
y_pred = rf.predict(X_test)

In [33]:
mse = mean_squared_error(y_test, y_pred)
rmse = mse ** 0.5
print(rmse)

0.93692045248394


### Gradient Boosted Regressor

In [29]:
xgb_model = xgb.XGBRegressor(learning_rate=0.001, n_estimators=500)
xgb_model.fit(X_train, y_train, verbose=1)

XGBRegressor(base_score=0.5, booster='gbtree', colsample_bylevel=1,
       colsample_bytree=1, gamma=0, learning_rate=0.001, max_delta_step=0,
       max_depth=3, min_child_weight=1, missing=None, n_estimators=500,
       n_jobs=1, nthread=None, objective='reg:linear', random_state=0,
       reg_alpha=0, reg_lambda=1, scale_pos_weight=1, seed=None,
       silent=True, subsample=1)

In [30]:
y_pred = xgb_model.predict(X_test)

In [22]:
mse = mean_squared_error(y_test, y_pred)
rmse = mse ** 0.5
print(rmse)
print(mse)

0.9114942727682114
0.8308218092892504


In [31]:
mse = mean_squared_error(y_test, y_pred)
rmse = mse ** 0.5
print(rmse)
print(mse)

1.8751634744708616
3.516238055989634


In [55]:
with open('gradient_boosting_regressor.pkl', 'rb') as f:
    gbr = pickle.load(f)

In [6]:
param_grid = {'max_depth': [3],
              'learning_rate': [0.001],
              'n_estimators': [100, 200, 500, 1000],
              'subsample': [0.5]
              }

gbr_gridsearch = GridSearchCV(xgb.XGBRegressor(),
                             param_grid,
                             n_jobs=-1,
                             verbose=1,
                             scoring='neg_mean_squared_error',
                             cv=3)
gbr_gridsearch.fit(X_train, y_train)
print( "best parameters:", gbr_gridsearch.best_params_ )

best_gbr_model = gbr_gridsearch.best_estimator_


Fitting 3 folds for each of 4 candidates, totalling 12 fits


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 4 concurrent workers.
[Parallel(n_jobs=-1)]: Done  12 out of  12 | elapsed: 22.3min finished


best parameters: {'learning_rate': 0.001, 'max_depth': 3, 'n_estimators': 1000, 'subsample': 0.5}


In [32]:
best_y_preds = best_gbr_model.predict(X_test)

In [13]:
for i in range(len(y_test)):
    print(y_test[i], best_y_preds[i])

[5.] 2.9516249
[1.] 1.1488051
[2.] 1.4640856
[3.] 2.5122452
[3.] 2.0222082
[5.] 2.7473016
[2.] 1.9760691
[4.] 2.1103764
[5.] 2.7684991
[4.] 2.7671347
[4.] 1.5686439
[1.] 1.508537
[4.] 2.2262857
[2.] 1.9613713
[2.] 2.367868
[4.] 2.2679985
[5.] 2.528296
[3.01421128] 2.1060553
[3.] 2.7370663
[2.] 1.6304716
[1.] 2.0101292
[2.] 1.5508903
[4.] 2.7684991
[4.] 2.7411284
[5.] 2.0987933
[3.] 2.0041642
[5.] 2.7682157
[5.] 2.7629485
[3.] 1.7002053
[2.] 2.377018
[4.] 2.3601887
[1.] 1.9771448
[2.] 1.3641977
[3.] 2.0987933
[1.] 1.3457191
[5.] 2.9510858
[3.] 1.9552867
[3.] 2.9507926
[3.] 2.4809613
[3.] 2.1104875
[3.] 1.0696102
[3.] 2.3810003
[4.] 2.4591398
[4.] 2.3412294
[2.] 1.9360484
[5.] 2.951091
[2.] 1.7036954
[2.] 1.5029162
[3.] 2.1129847
[2.] 1.8932636
[3.] 2.5418596
[3.] 2.4708507
[2.] 1.6043792
[5.] 2.9287715
[5.] 2.9519181
[1.] 2.3348825
[1.] 1.5432932
[5.] 2.9516249
[3.] 1.6569396
[1.] 1.128564
[2.] 1.8783702
[4.] 2.3586557
[5.] 2.5303862
[3.] 2.3558822
[4.] 2.4567962
[3.] 1.1322036
[4.] 2.3

[3.] 2.3725295
[1.] 1.2977821
[5.] 2.9507926
[4.] 2.3027673
[1.] 1.5931545
[3.] 1.1100941
[2.] 2.7640436
[3.] 2.0128236
[5.] 2.9516249
[3.01421128] 2.1091537
[4.] 2.3699298
[3.] 2.1214886
[2.] 2.0257454
[1.] 1.1353703
[1.] 1.7454473
[5.] 2.7671444
[3.01421128] 2.098068
[4.] 2.7676973
[3.] 2.3849344
[3.01421128] 2.0997982
[4.] 2.5306795
[2.] 2.6636
[5.] 2.9516249
[2.] 2.4591343
[4.] 1.5506501
[4.] 2.1228898
[3.] 2.4561484
[1.] 1.6280279
[3.] 1.7111233
[3.01421128] 2.0997982
[1.] 1.1159943
[1.] 1.823864
[5.] 2.9516249
[4.] 1.977917
[4.] 2.4511828
[3.] 1.4640856
[2.] 1.7063247
[3.01421128] 2.0997982
[1.] 1.3939397
[4.] 2.9516249
[3.] 1.9093343
[2.] 2.7637503
[3.] 2.3280473
[4.] 2.3679385
[3.01421128] 2.0992794
[1.] 1.7944177
[5.] 2.9510858
[3.] 1.9782305
[4.] 2.9287715
[2.] 1.9719878
[4.] 2.367868
[4.] 2.3699298
[1.] 1.1436868
[1.] 1.6057159
[1.] 1.4190682
[3.01421128] 2.042943
[3.] 1.6102256
[5.] 2.7623858
[4.] 2.073103
[3.01421128] 2.098068
[5.] 2.9519181
[3.] 1.9413984
[4.] 2.3679385
[

[3.] 2.480668
[3.01421128] 2.0997982
[5.] 2.7462835
[3.01421128] 2.0751023
[1.] 1.1132526
[3.] 1.9489028
[3.01421128] 2.0997982
[1.] 1.1181104
[3.] 1.7121722
[3.] 1.9800699
[4.] 2.7632418
[3.01421128] 2.042943
[1.] 1.0632181
[1.] 1.1077735
[5.] 1.6761
[3.] 1.4061131
[3.01421128] 2.0997982
[3.] 2.1071138
[2.] 1.9555712
[3.01421128] 1.9788082
[2.] 1.3647522
[2.] 1.9033846
[3.01421128] 2.0992794
[2.] 2.0131073
[5.] 2.660505
[3.] 2.3801131
[2.] 2.1090696
[3.] 1.6280279
[2.] 1.7352898
[3.01421128] 2.0997982
[3.01421128] 2.0997982
[3.] 1.956946
[3.] 1.6445534
[1.5] 1.7960374
[3.] 2.256295
[4.] 2.3021784
[4.] 2.298873
[5.] 2.9516249
[3.] 1.8825969
[3.] 2.0741465
[3.] 1.9801428
[4.] 2.5284326
[4.] 1.1835201
[1.] 1.079705
[3.] 2.1060076
[1.] 1.1211705
[3.01421128] 2.0997982
[3.] 1.7793727
[4.] 1.9689809
[4.] 2.3578732
[5.] 2.2936358
[5.] 2.2568288
[3.01421128] 2.0997982
[3.] 2.9519181
[1.] 2.3849344
[5.] 2.6593199
[3.01421128] 2.0671148
[4.] 2.3542786
[3.] 2.473333
[2.] 1.9507498
[1.] 1.566615


[5.] 2.1194794
[4.] 2.3699298
[3.] 2.3067598
[3.] 2.3704677
[3.5] 2.2291598
[3.01421128] 2.0997982
[3.01421128] 2.0997982
[1.] 1.1377627
[5.] 2.541038
[3.01421128] 2.0992794
[2.] 1.4623784
[3.] 2.1184764
[4.] 2.377018
[3.] 1.7041085
[1.] 1.9860191
[2.] 1.7111233
[5.] 2.3128552
[3.01421128] 2.0939445
[3.01421128] 2.098068
[3.01421128] 2.0997982
[1.] 2.0957065
[2.] 2.1052997
[4.] 1.7095419
[4.] 2.7640436
[2.] 1.1142837
[3.01421128] 2.098068
[4.] 2.9519181
[2.] 1.1478214
[4.] 2.5418596
[3.] 1.9698272
[3.5] 2.362015
[5.] 1.7111233
[4.] 1.9757075
[1.] 1.8639255
[2.] 1.4983352
[3.01421128] 2.1060553
[5.] 2.5303862
[2.] 2.9282377
[2.] 1.712293
[1.] 1.1128212
[4.] 2.030282
[2.] 2.480668
[2.] 1.5718349
[4.] 2.385222
[1.] 1.0755303
[1.] 1.690418
[3.01421128] 2.0671148
[3.] 1.9618387
[4.] 1.8849208
[3.] 1.9598689
[5.] 1.9733659
[4.] 2.7684991
[1.] 1.1809568
[4.] 1.6645951
[3.] 2.36209
[4.] 2.3849344
[4.] 1.9636302
[4.] 2.00642
[3.01421128] 2.098068
[2.] 2.0737686
[3.] 1.1356559
[4.] 2.2405958
[4.

[1.] 1.1347493
[1.] 1.1809568
[3.01421128] 2.098068
[3.] 1.8229734
[2.] 1.1099834
[1.] 1.4171294
[5.] 2.951091
[4.] 2.9281492
[3.] 2.0112965
[1.] 1.9869887
[4.] 2.9507926
[2.] 2.0164762
[2.] 2.121127
[4.] 2.3423905
[2.] 1.1456122
[1.] 1.1425266
[3.] 1.8108174
[2.] 1.8839802
[4.] 2.7687924
[5.] 2.7621024
[5.] 2.9308772
[5.] 2.951091
[1.] 1.1380949
[2.] 1.5934913
[3.] 1.5500333
[4.] 2.1074228
[5.] 2.9270236
[2.] 1.5022591
[3.] 1.1258118
[1.] 1.4599226
[1.] 1.3437004
[3.01421128] 2.0997982
[4.] 2.377018
[4.] 2.7676973
[3.] 2.480668
[4.] 1.9439365
[3.] 2.4679036
[5.] 2.4722104
[4.5] 2.5188375
[4.] 1.6905456
[4.] 2.742847
[5.] 1.5522958
[4.] 2.25987
[5.] 2.2926798
[5.] 2.071135
[2.] 2.1104875
[3.] 1.7076302
[2.] 1.1214368
[4.] 2.369646
[1.] 1.1089122
[1.] 1.8032719
[1.] 1.1271416
[1.] 1.079705
[4.] 1.6814404
[1.] 1.6848838
[4.] 2.7681139
[3.] 1.1825411
[2.] 1.5064985
[4.] 2.3334978
[4.5] 2.619613
[1.] 1.1797415
[2.] 2.3310044
[5.] 2.9516249
[3.] 2.381288
[3.] 1.7103335
[3.] 2.469963
[5.] 2.

[1.] 1.5938553
[5.] 2.5425978
[3.01421128] 2.0992794
[4.] 2.1796932
[2.] 2.0954196
[3.01421128] 2.0997982
[5.] 1.7373464
[4.] 1.082064
[4.] 2.4917192
[1.] 1.7382276
[4.] 2.2115788
[2.] 1.7104999
[4.] 2.0970252
[4.] 2.5418596
[1.5] 1.8386022
[2.] 2.4963107
[4.] 2.5290513
[3.01421128] 1.9782896
[3.01421128] 2.0997982
[3.] 1.5797672
[1.] 1.707741
[4.] 2.4810205
[3.01421128] 2.0997982
[3.01421128] 2.0997982
[2.] 1.7288626
[3.] 1.406406
[3.] 1.9594841
[4.] 2.122355
[3.] 2.7674139
[1.] 1.5480148
[2.] 2.3285007
[2.] 1.9093343
[2.] 2.0851364
[3.01421128] 2.0997982
[1.] 1.3106587
[1.] 1.1314814
[4.] 2.3682222
[2.] 1.1317091
[3.01421128] 2.0997982
[3.] 2.9241872
[4.] 2.7679906
[2.] 2.4579527
[4.] 2.5418596
[4.] 1.7076302
[5.] 1.1790195
[3.] 2.4781055
[4.] 2.1211262
[4.] 2.3679385
[2.] 2.7623858
[3.01421128] 2.1091537
[3.] 1.744097
[5.] 2.3725295
[1.] 1.7092853
[2.] 1.5500605
[2.] 1.9982069
[4.] 2.4981117
[3.] 1.976142
[4.] 2.474769
[1.] 2.1227164
[3.01421128] 2.0997982
[4.] 2.745593
[1.] 1.21615

[2.5] 1.7474055
[5.] 2.9519181
[3.5] 2.4302492
[4.] 1.9851637
[4.] 2.366279
[1.] 1.8290358
[1.] 1.1729102
[5.] 1.8269364
[4.] 2.4351556
[3.01421128] 2.098068
[4.] 1.9780624
[3.01421128] 1.9950148
[4.] 2.0317888
[2.] 2.0970252
[3.] 1.7056706
[5.] 2.3067598
[5.] 2.7676973
[4.] 2.4686046
[3.] 2.0369716
[3.] 1.8051054
[3.] 1.9561265
[5.] 2.1094317
[5.] 2.951091
[2.] 2.3844197
[3.] 2.3576088
[3.01421128] 2.0997982
[3.] 1.9471401
[3.] 1.5505878
[4.] 1.941622
[3.01421128] 2.0997982
[3.01421128] 2.0997982
[4.] 2.2519066
[5.] 2.122561
[4.] 2.3683968
[1.] 1.1832633
[3.] 1.9636309
[3.] 2.2404737
[3.] 1.698967
[2.] 1.1375059
[3.] 2.1177526
[5.] 2.9287715
[1.] 1.1385591
[5.] 2.7667806
[3.] 2.469963
[2.] 2.3553863
[2.] 1.184987
[4.] 2.7670739
[2.] 1.0632181
[3.] 1.6715095
[3.] 2.369646
[2.] 1.7076457
[3.01421128] 2.098068
[3.01421128] 2.0997982
[3.] 1.8272288
[1.] 1.6635933
[1.] 1.8423989
[3.01421128] 2.0997982
[3.01421128] 2.0997982
[1.] 1.1314814
[3.01421128] 2.0997982
[2.] 1.5717316
[1.] 1.683476

[3.] 1.9773617
[3.01421128] 2.0992794
[3.] 2.0162232
[3.] 2.3658216
[3.01421128] 2.0997982
[1.] 1.1368423
[2.] 1.6136495
[3.] 2.3449745
[5.] 2.9516249
[1.] 1.3196766
[5.] 1.9947596
[4.] 2.4693854
[4.] 2.5068886
[3.] 1.9723078
[3.] 2.0879664
[3.] 1.951234
[1.] 1.1132526
[3.] 1.9352955
[3.] 2.384047
[3.] 1.9719878
[4.] 1.9979542
[1.] 1.079705
[3.01421128] 2.0653849
[4.5] 2.5383396
[1.] 1.1226963
[3.] 2.011964
[4.] 2.1094317
[3.] 1.7095419
[1.] 1.1706048
[1.] 2.1254659
[3.] 2.198392
[3.] 2.0268621
[3.] 1.9720366
[2.] 1.5853121
[4.] 1.9755394
[1.] 1.103924
[3.] 2.0129147
[5.] 1.5543
[2.] 1.5476985
[2.] 2.070773
[3.] 2.9516249
[2.] 2.3452582
[4.] 2.353807
[4.] 1.9711901
[3.01421128] 2.0997982
[4.] 2.5418596
[1.] 1.5073467
[4.] 1.999678
[4.] 2.7677743
[4.] 2.385222
[5.] 1.138227
[1.] 1.1250468
[2.] 1.545647
[2.] 2.1024122
[4.] 1.7114712
[2.] 2.3810003
[5.] 2.9507926
[2.] 1.9686068
[3.] 1.9618387
[4.] 2.369646
[3.] 2.7408915
[3.] 2.5361755
[4.] 2.3716898
[3.01421128] 2.0997982
[5.] 2.9516249


[3.] 2.951091
[3.01421128] 2.0653849
[2.] 1.7111233
[1.] 1.173779
[4.] 2.3810003
[4.] 2.3682222
[2.] 2.367727
[5.] 2.951091
[4.] 2.4608161
[5.] 2.2227626
[2.] 1.7103156
[3.01421128] 2.0997982
[3.] 1.1941679
[4.] 2.1104875
[3.] 1.5479962
[2.] 1.0845368
[4.] 2.098068
[1.] 1.9993982
[1.] 1.5901967
[3.] 2.3348825
[1.] 1.4683099
[4.] 2.7671347
[4.] 1.7124288
[2.] 1.6679564
[1.] 1.3428581
[2.] 1.4065151
[5.] 2.9513843
[5.] 2.9516249
[4.] 2.9516249
[5.] 2.1103756
[4.] 2.298873
[2.] 1.8539927
[2.] 1.5954362
[4.] 2.6636
[2.] 1.980238
[3.] 2.3576088
[5.] 2.9510858
[1.] 1.5529616
[1.] 1.4005997
[3.01421128] 2.0997982
[1.] 1.1622405
[3.] 2.3564339
[1.] 1.0592678
[1.] 1.6260769
[4.] 1.9923223
[3.01421128] 2.1091537
[5.] 2.9519181
[4.] 2.9279392
[1.] 1.1270205
[3.] 2.3314943
[3.01421128] 1.9818767
[5.] 2.9516249
[3.] 2.2568288
[4.] 2.9516249
[4.] 2.9519181
[1.] 1.064894
[1.] 1.1432486
[2.] 1.3120711
[4.] 2.7622938
[3.] 2.9516249
[1.] 1.1232384
[5.] 2.1264706
[5.] 2.9519181
[4.] 2.1050708
[5.] 2.1250

[3.01421128] 2.0997982
[3.01421128] 2.0997982
[1.] 1.5428503
[4.] 1.6713753
[4.] 2.1164885
[3.] 1.9888935
[2.] 1.4626384
[4.] 1.9021738
[3.01421128] 2.7427452
[1.] 1.1311201
[3.] 2.369646
[3.] 1.1421944
[2.] 1.7043651
[3.] 2.2627525
[5.] 2.0223095
[4.] 2.3463833
[4.] 2.385222
[2.] 1.4145522
[2.] 1.5528111
[3.01421128] 2.0653849
[3.01421128] 2.0997982
[3.] 2.5398521
[3.01421128] 2.0997982
[4.] 2.9270236
[2.] 2.1254659
[5.] 1.9594841
[3.] 1.4054112
[3.] 2.1197205
[3.01421128] 2.042943
[1.] 1.1835201
[3.01421128] 2.0992794
[4.] 2.1227171
[2.] 1.4586798
[2.] 2.7469828
[2.] 1.9777001
[2.] 1.1322448
[4.] 1.9762633
[3.01421128] 2.0665965
[2.] 1.4071829
[4.] 1.6602932
[5.] 2.9510858
[1.] 1.133559
[5.] 2.7431304
[5.] 2.9519181
[3.] 1.6696784
[4.] 2.6599903
[3.] 1.9986931
[3.] 1.9099265
[3.] 2.5418596
[4.] 2.4795785
[5.] 1.4556262
[2.] 2.3844197
[4.5] 2.3514915
[3.] 1.7097012
[4.] 2.5295646
[3.01421128] 2.0671148
[1.] 1.5163143
[3.] 2.0917215
[3.] 1.4021603
[5.] 2.742837
[3.] 2.4795785
[5.] 2.95

[5.] 2.9516249
[4.] 2.3699298
[1.] 1.1411455
[3.] 2.0875273
[3.] 1.7111913
[3.] 2.4675813
[2.] 1.5521642
[5.] 2.9519181
[3.] 2.0721905
[3.] 2.7412455
[2.] 1.7129811
[1.] 2.0246148
[3.] 2.7684991
[3.01421128] 2.0997982
[2.5] 2.0295467
[1.] 1.6374618
[4.] 1.7104223
[1.5] 1.9867531
[5.] 2.9516249
[3.01421128] 2.098068
[2.] 1.9618387
[3.01421128] 2.0997982
[3.] 1.8932636
[3.01421128] 2.0997982
[3.] 2.3558822
[1.] 1.4968417
[3.01421128] 2.0997982
[2.] 2.3797865
[2.] 1.5480148
[5.] 1.9719878
[2.] 1.9557358
[3.] 2.1599398
[1.] 1.712994
[5.] 2.1843967
[3.] 2.7420352
[4.] 2.9507926
[3.01421128] 2.0991437
[3.] 2.095419
[4.] 2.5293193
[1.] 1.5923202
[1.] 1.8102037
[2.] 2.3797865
[3.] 2.3679013
[1.] 1.0540767
[3.01421128] 2.0997982
[3.01421128] 2.0997982
[3.] 1.9290763
[3.01421128] 2.0997982
[3.] 1.8169129
[3.] 1.9780624
[3.] 2.7467933
[2.] 1.9697213
[3.] 2.7682157
[5.] 2.9519181
[3.01421128] 1.993803
[5.] 2.5421531
[2.] 1.0632181
[1.] 1.1039906
[1.] 1.5517383
[5.] 2.9516249
[3.] 2.021655
[3.] 2.4

[1.] 1.4005997
[2.] 1.663966
[3.] 1.5631634
[5.] 2.3584814
[5.] 2.9516249
[4.] 2.107039
[5.] 2.6408494
[3.01421128] 2.0997982
[1.] 1.4818873
[5.] 2.9287715
[2.] 2.4523394
[2.] 1.3647522
[1.] 2.9270236
[3.] 1.5068322
[2.] 1.403749
[2.] 1.6317891
[4.] 2.9519181
[1.] 1.3315544
[4.] 1.9786959
[4.] 2.33114
[5.] 2.9287715
[3.] 1.9976516
[5.] 2.951091
[1.] 1.1036673
[3.01421128] 2.0992794
[3.] 2.5421531
[3.] 1.5506335
[2.] 1.6115429
[2.] 2.4675813
[3.] 2.742837
[4.] 2.5239465
[4.] 1.5646325
[5.] 1.7837678
[3.] 2.6636
[3.] 2.4586086
[2.] 1.5506335
[3.] 1.1325016
[5.] 2.4418437
[2.] 1.8049436
[5.] 2.9507926
[3.] 2.0540285
[5.] 2.951091
[1.] 1.1214368
[1.] 1.5923436
[4.] 2.5174496
[2.] 2.381288
[3.] 2.0131073
[4.] 2.3699918
[1.] 1.2738372
[3.01421128] 2.0997982
[3.] 2.3612888
[1.] 1.3990033
[4.] 2.028719
[5.] 2.9516249
[3.01421128] 2.0992794
[5.] 2.9516249
[3.] 1.628578
[1.] 1.1701396
[5.] 2.951091
[2.] 1.6986054
[3.] 2.7671444
[3.] 2.0967503
[3.] 1.9465848
[5.] 2.9519181
[4.] 2.3835325
[3.] 2.4

[3.] 2.1090696
[3.] 2.1211262
[2.] 1.969149
[1.] 1.88447
[5.] 2.9519181
[1.] 2.2510357
[4.] 2.9516249
[3.] 2.5232809
[1.] 1.4020104
[3.01421128] 2.0992794
[3.] 1.8871609
[3.] 1.9618387
[4.] 2.385222
[2.] 1.3659768
[5.] 2.6482863
[4.] 1.9708229
[5.] 2.2959795
[1.] 1.5764121
[5.] 2.9516249
[1.] 1.978271
[2.] 1.1162052
[5.] 2.9519181
[3.] 2.1214838
[4.] 2.767428
[2.] 2.3668346
[4.] 2.4597113
[4.] 2.9287715
[1.] 1.279996
[5.] 2.1151416
[1.] 2.122454
[4.] 2.3000557
[4.] 2.7666793
[2.5] 2.0003242
[4.] 2.7687924
[1.] 1.6039362
[4.] 2.5418596
[2.] 2.2713006
[4.] 2.5398521
[4.] 1.5797672
[5.] 2.951091
[5.] 2.9510858
[1.] 1.9708444
[2.] 2.1074762
[5.] 2.329267
[2.] 1.2535124
[2.] 1.4901788
[2.] 1.7084627
[1.] 1.5082034
[4.] 2.0616465
[1.] 1.2915823
[4.] 2.3521733
[5.] 2.9516249
[4.] 2.370822
[3.] 2.9507926
[4.] 2.1126132
[4.] 1.6381931
[5.] 2.1250026
[2.] 2.9519181
[4.] 2.9516249
[5.] 2.5418596
[1.] 1.4048567
[3.5] 1.9831296
[1.] 1.122953
[4.] 1.5797672
[5.] 2.9516249
[3.] 2.2726936
[3.01421128]

In [11]:
mse = mean_squared_error(y_test, best_y_preds)
rmse = mse ** 0.5
print(mse)
print(rmse)

1.8738464762673295
1.368885121647295


In [33]:
with open('models/best_gradient_boosting_regressor.pkl', 'wb') as f:
    # Write the model to a file.
    pickle.dump(best_gbr_model, f)

In [None]:
with open('model/gradient_boosting_regressor.pkl', 'wb') as f:
    # Write the model to a file.
    pickle.dump(gbr, f)

In [73]:
y_pred = gbr.predict(X_test)

In [75]:
mse = mean_squared_error(y_test, y_pred)
rmse = mse ** 0.5
print(mse)
print(rmse)

0.8921436054582954
0.9445335385566229


### Later

In [None]:
# with open('model.pkl', 'rb') as f:
#     model = pickle.load(f)

In [None]:
_ = scatter_matrix(amazon_df, alpha=0.2, figsize=(10, 10), diagonal='kde')

In [None]:


fig, ax = plt.subplots(figsize=(14, 3))

plot_univariate_smooth(
    ax, 
    non_zero_bal_df["limit"].values.reshape(-1, 1),
    non_zero_bal_df['balance'],
    bootstrap=100)

ax.set_title("Univariate Effect of Credit Limit on Non-zero Bank Balance")
ax.set_ylabel("Non-zero Balance")
ax.set_xlabel("Limit")

In [None]:
limit_fit = Pipeline([
    ('limit', ColumnSelector(name='limit')),
    ('limit_spline', LinearSpline(knots=[2500, 6000, 7000]))
])