In [1]:
import xgboost as xgb

In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from sklearn.linear_model import LinearRegression
from sklearn.model_selection import cross_val_score, train_test_split, GridSearchCV
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
from sklearn.metrics import log_loss, make_scorer, confusion_matrix, mean_squared_error
from pandas.plotting import scatter_matrix
import pickle



In [33]:
nlp_df = pd.read_csv('../data/df_with_nlp.csv', index_col=0)
X = nlp_df
y = pd.read_csv("../data/work-balance-stars.csv", header=None, index_col=0).values

In [34]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=0)

### Linear Regression

In [47]:
lr = LinearRegression()
lr.fit(X_train, y_train)

LinearRegression(copy_X=True, fit_intercept=True, n_jobs=1, normalize=False)

In [48]:
y_pred = lr.predict(X_test)

### Random Forest

In [24]:
random_forest_grid = {'max_depth': [3, None],
                      'max_features': ['sqrt', 'log2', None],
                      'min_samples_split': [2, 4],
                      'min_samples_leaf': [1, 2, 4],
                      'bootstrap': [True, False],
                      'n_estimators': [100, 200, 500, 1000],
                      'random_state': [1]}

rf_gridsearch = GridSearchCV(RandomForestRegressor(),
                             random_forest_grid,
                             n_jobs=-1,
                             verbose=True,
                             scoring='neg_mean_squared_error')
rf_gridsearch.fit(X_train, y_train)
print( "best parameters:", rf_gridsearch.best_params_ )

best_rf_model = rf_gridsearch.best_estimator_


Fitting 3 folds for each of 288 candidates, totalling 864 fits


[Parallel(n_jobs=-1)]: Done  42 tasks      | elapsed:   28.6s
[Parallel(n_jobs=-1)]: Done 192 tasks      | elapsed:  2.8min
[Parallel(n_jobs=-1)]: Done 442 tasks      | elapsed: 19.5min
[Parallel(n_jobs=-1)]: Done 792 tasks      | elapsed: 32.3min
[Parallel(n_jobs=-1)]: Done 864 out of 864 | elapsed: 39.9min finished


best parameters: {'bootstrap': True, 'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 4, 'min_samples_split': 2, 'n_estimators': 1000, 'random_state': 1}


In [27]:
best_params = {'bootstrap': True, 'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 4, 'min_samples_split': 2, 'n_estimators': 1000, 'random_state': 1}
rf = RandomForestRegressor(**best_params)
rf.fit(X_train, y_train)

RandomForestRegressor(bootstrap=True, criterion='mse', max_depth=None,
           max_features='sqrt', max_leaf_nodes=None,
           min_impurity_decrease=0.0, min_impurity_split=None,
           min_samples_leaf=4, min_samples_split=2,
           min_weight_fraction_leaf=0.0, n_estimators=1000, n_jobs=1,
           oob_score=False, random_state=1, verbose=0, warm_start=False)

In [28]:
y_pred = rf.predict(X_test)

In [33]:
mse = mean_squared_error(y_test, y_pred)
rmse = mse ** 0.5
print(rmse)

0.93692045248394


### Gradient Boosted Regressor

In [33]:
xgb_model = xgb.XGBRegressor(learning_rate=0.01, reg_lambda=1, n_estimators=1000)
xgb_model.fit(X_train, y_train, verbose=1)

XGBRegressor(base_score=0.5, booster='gbtree', colsample_bylevel=1,
       colsample_bytree=1, gamma=0, learning_rate=0.01, max_delta_step=0,
       max_depth=3, min_child_weight=1, missing=None, n_estimators=1000,
       n_jobs=1, nthread=None, objective='reg:linear', random_state=0,
       reg_alpha=0, reg_lambda=1, scale_pos_weight=1, seed=None,
       silent=True, subsample=1)

In [34]:
y_pred = xgb_model.predict(X_test)

In [27]:
mse = mean_squared_error(y_test, y_pred)
rmse = mse ** 0.5
print(rmse)
print(mse)

0.9115343258296804
0.8308948271657699


In [35]:
mse = mean_squared_error(y_test, y_pred)
rmse = mse ** 0.5
print(rmse)
print(mse)

0.8723116190562487
0.7609275607405339


In [55]:
with open('gradient_boosting_regressor.pkl', 'rb') as f:
    gbr = pickle.load(f)

In [6]:
param_grid = {'max_depth': [3],
              'learning_rate': [0.01],
              'n_estimators': [100, 200, 500, 1000],
              'subsample': [0.5]
              }

gbr_gridsearch = GridSearchCV(xgb.XGBRegressor(),
                             param_grid,
                             n_jobs=-1,
                             verbose=1,
                             scoring='neg_mean_squared_error',
                             cv=3)
gbr_gridsearch.fit(X_train, y_train)
print( "best parameters:", gbr_gridsearch.best_params_ )

best_gbr_model = gbr_gridsearch.best_estimator_


Fitting 3 folds for each of 4 candidates, totalling 12 fits


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 4 concurrent workers.
[Parallel(n_jobs=-1)]: Done  12 out of  12 | elapsed: 22.3min finished


best parameters: {'learning_rate': 0.001, 'max_depth': 3, 'n_estimators': 1000, 'subsample': 0.5}


In [32]:
best_y_preds = best_gbr_model.predict(X_test)

In [23]:
for i in range(len(y_test)):
    print(y_test[i], best_y_preds[i])

NameError: name 'best_y_preds' is not defined

In [11]:
mse = mean_squared_error(y_test, best_y_preds)
rmse = mse ** 0.5
print(mse)
print(rmse)

1.8738464762673295
1.368885121647295


In [33]:
with open('models/best_gradient_boosting_regressor.pkl', 'wb') as f:
    # Write the model to a file.
    pickle.dump(best_gbr_model, f)

In [30]:
with open('models/gradient_boosting_regressor.pkl', 'wb') as f:
    # Write the model to a file.
    pickle.dump(xgb_model, f)

In [73]:
y_pred = gbr.predict(X_test)

In [75]:
mse = mean_squared_error(y_test, y_pred)
rmse = mse ** 0.5
print(mse)
print(rmse)

0.8921436054582954
0.9445335385566229


### RNN

In [1]:
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import LSTM

Using Theano backend.


In [7]:
timesteps=1

In [40]:
X_train = X_train[['culture-values-stars', 'career-opportunities-stars', 'comp-benefit-stars', 'senior-management-stars', 
                   'helpful-count', 'is_current_employee', 'year', 'quarter', 'amazon_earnings_this_quarter','timesteps']]
X_test = X_test[['culture-values-stars', 'career-opportunities-stars', 'comp-benefit-stars', 'senior-management-stars', 
                   'helpful-count', 'is_current_employee', 'year', 'quarter', 'amazon_earnings_this_quarter','timesteps']]

In [41]:
X_train = X_train.values.reshape(X_train.values.shape[0], timesteps, X_train.values.shape[1])
X_test = X_test.values.reshape(X_test.values.shape[0], timesteps, X_test.values.shape[1])

In [50]:
neurons=50
model = Sequential()
model.add(LSTM(neurons, input_shape=(X_train.shape[1], X_train.shape[2])))
model.add(layers.Dense(1, activation='relu'))

model.add(layers.Dense(1)
model.compile(loss='mean_squared_error', optimizer='adam')

In [54]:
model.fit(X_train, y_train, epochs=5, batch_size=72, validation_data=(X_test, y_test), verbose=2, shuffle=False)

Train on 17706 samples, validate on 8722 samples
Epoch 1/5
 - 417s - loss: 1.6047 - val_loss: 1.6099
Epoch 2/5
 - 405s - loss: 1.6047 - val_loss: 1.6099
Epoch 3/5
 - 408s - loss: 1.6047 - val_loss: 1.6099
Epoch 4/5
 - 468s - loss: 1.6047 - val_loss: 1.6099
Epoch 5/5
 - 530s - loss: 1.6047 - val_loss: 1.6099


<keras.callbacks.History at 0x1819a7d0668>

In [55]:
y_pred_rnn = model.predict(X_test)

In [56]:
mse_rnn = mean_squared_error(y_test, y_pred_rnn)
rmse_rnn = mse_rnn ** 0.5
print(mse_rnn)
print(rmse_rnn)

1.609875584791323
1.2688087266374404


In [58]:
for i in range(len(y_test)):
    print(y_test[i], y_pred_rnn[i])

[5.] [3.0139868]
[1.] [3.0139868]
[2.] [3.0139868]
[3.] [3.0139868]
[3.] [3.0139868]
[5.] [3.0139868]
[2.] [3.0139868]
[4.] [3.0139868]
[5.] [3.0139868]
[4.] [3.0139868]
[4.] [3.0139868]
[1.] [3.0139868]
[4.] [3.0139868]
[2.] [3.0139868]
[2.] [3.0139868]
[4.] [3.0139868]
[5.] [3.0139868]
[3.01421128] [3.0139868]
[3.] [3.0139868]
[2.] [3.0139868]
[1.] [3.0139868]
[2.] [3.0139868]
[4.] [3.0139868]
[4.] [3.0139868]
[5.] [3.0139868]
[3.] [3.0139868]
[5.] [3.0139868]
[5.] [3.0139868]
[3.] [3.0139868]
[2.] [3.0139868]
[4.] [3.0139868]
[1.] [3.0139868]
[2.] [3.0139868]
[3.] [3.0139868]
[1.] [3.0139868]
[5.] [3.0139868]
[3.] [3.0139868]
[3.] [3.0139868]
[3.] [3.0139868]
[3.] [3.0139868]
[3.] [3.0139868]
[3.] [3.0139868]
[4.] [3.0139868]
[4.] [3.0139868]
[2.] [3.0139868]
[5.] [3.0139868]
[2.] [3.0139868]
[2.] [3.0139868]
[3.] [3.0139868]
[2.] [3.0139868]
[3.] [3.0139868]
[3.] [3.0139868]
[2.] [3.0139868]
[5.] [3.0139868]
[5.] [3.0139868]
[1.] [3.0139868]
[1.] [3.0139868]
[5.] [3.0139868]
[3.] [

[4.] [3.0139868]
[3.] [3.0139868]
[5.] [3.0139868]
[4.] [3.0139868]
[4.] [3.0139868]
[5.] [3.0139868]
[3.01421128] [3.0139868]
[1.] [3.0139868]
[1.] [3.0139868]
[5.] [3.0139868]
[4.] [3.0139868]
[5.] [3.0139868]
[1.] [3.0139868]
[3.01421128] [3.0139868]
[3.] [3.0139868]
[1.] [3.0139868]
[1.] [3.0139868]
[2.] [3.0139868]
[1.] [3.0139868]
[2.] [3.0139868]
[3.01421128] [3.0139868]
[2.] [3.0139868]
[4.] [3.0139868]
[5.] [3.0139868]
[3.] [3.0139868]
[3.] [3.0139868]
[4.] [3.0139868]
[1.] [3.0139868]
[4.] [3.0139868]
[2.] [3.0139868]
[3.] [3.0139868]
[5.] [3.0139868]
[3.] [3.0139868]
[2.] [3.0139868]
[4.] [3.0139868]
[4.] [3.0139868]
[1.] [3.0139868]
[4.] [3.0139868]
[5.] [3.0139868]
[2.] [3.0139868]
[3.] [3.0139868]
[4.] [3.0139868]
[2.] [3.0139868]
[5.] [3.0139868]
[3.] [3.0139868]
[3.01421128] [3.0139868]
[4.] [3.0139868]
[3.01421128] [3.0139868]
[5.] [3.0139868]
[5.] [3.0139868]
[3.01421128] [3.0139868]
[3.01421128] [3.0139868]
[1.] [3.0139868]
[3.] [3.0139868]
[4.] [3.0139868]
[3.5] [3.

[4.] [3.0139868]
[3.01421128] [3.0139868]
[4.] [3.0139868]
[3.01421128] [3.0139868]
[2.] [3.0139868]
[5.] [3.0139868]
[2.] [3.0139868]
[3.01421128] [3.0139868]
[3.] [3.0139868]
[3.] [3.0139868]
[5.] [3.0139868]
[3.01421128] [3.0139868]
[3.5] [3.0139868]
[3.01421128] [3.0139868]
[2.] [3.0139868]
[5.] [3.0139868]
[3.] [3.0139868]
[4.] [3.0139868]
[3.] [3.0139868]
[3.] [3.0139868]
[1.] [3.0139868]
[4.] [3.0139868]
[5.] [3.0139868]
[1.] [3.0139868]
[2.] [3.0139868]
[5.] [3.0139868]
[3.01421128] [3.0139868]
[3.] [3.0139868]
[5.] [3.0139868]
[3.] [3.0139868]
[3.] [3.0139868]
[1.] [3.0139868]
[3.] [3.0139868]
[4.] [3.0139868]
[2.] [3.0139868]
[4.] [3.0139868]
[4.] [3.0139868]
[1.] [3.0139868]
[1.] [3.0139868]
[3.] [3.0139868]
[5.] [3.0139868]
[1.] [3.0139868]
[3.] [3.0139868]
[5.] [3.0139868]
[3.] [3.0139868]
[2.] [3.0139868]
[2.] [3.0139868]
[5.] [3.0139868]
[4.] [3.0139868]
[4.] [3.0139868]
[2.] [3.0139868]
[4.] [3.0139868]
[1.] [3.0139868]
[4.] [3.0139868]
[2.] [3.0139868]
[2.] [3.0139868]

[1.] [3.0139868]
[3.] [3.0139868]
[3.] [3.0139868]
[4.] [3.0139868]
[1.] [3.0139868]
[4.] [3.0139868]
[5.] [3.0139868]
[2.] [3.0139868]
[1.] [3.0139868]
[4.] [3.0139868]
[1.] [3.0139868]
[1.] [3.0139868]
[3.] [3.0139868]
[3.] [3.0139868]
[3.] [3.0139868]
[3.] [3.0139868]
[5.] [3.0139868]
[3.01421128] [3.0139868]
[1.] [3.0139868]
[5.] [3.0139868]
[2.] [3.0139868]
[3.] [3.0139868]
[5.] [3.0139868]
[3.01421128] [3.0139868]
[3.01421128] [3.0139868]
[4.] [3.0139868]
[2.] [3.0139868]
[3.01421128] [3.0139868]
[3.] [3.0139868]
[1.] [3.0139868]
[3.] [3.0139868]
[3.] [3.0139868]
[4.] [3.0139868]
[5.] [3.0139868]
[3.5] [3.0139868]
[2.] [3.0139868]
[2.] [3.0139868]
[3.] [3.0139868]
[5.] [3.0139868]
[4.] [3.0139868]
[1.] [3.0139868]
[2.] [3.0139868]
[2.] [3.0139868]
[2.] [3.0139868]
[3.] [3.0139868]
[5.] [3.0139868]
[3.] [3.0139868]
[3.] [3.0139868]
[2.] [3.0139868]
[1.] [3.0139868]
[4.] [3.0139868]
[5.] [3.0139868]
[3.] [3.0139868]
[4.] [3.0139868]
[3.] [3.0139868]
[2.5] [3.0139868]
[1.] [3.013986

[1.] [3.0139868]
[4.] [3.0139868]
[3.] [3.0139868]
[5.] [3.0139868]
[3.] [3.0139868]
[2.] [3.0139868]
[3.01421128] [3.0139868]
[3.] [3.0139868]
[4.] [3.0139868]
[4.] [3.0139868]
[4.] [3.0139868]
[3.] [3.0139868]
[4.] [3.0139868]
[3.01421128] [3.0139868]
[3.01421128] [3.0139868]
[2.] [3.0139868]
[5.] [3.0139868]
[3.01421128] [3.0139868]
[2.] [3.0139868]
[3.] [3.0139868]
[1.] [3.0139868]
[1.] [3.0139868]
[1.] [3.0139868]
[3.01421128] [3.0139868]
[3.] [3.0139868]
[2.] [3.0139868]
[1.] [3.0139868]
[5.] [3.0139868]
[4.] [3.0139868]
[3.] [3.0139868]
[1.] [3.0139868]
[4.] [3.0139868]
[2.] [3.0139868]
[2.] [3.0139868]
[4.] [3.0139868]
[2.] [3.0139868]
[1.] [3.0139868]
[3.] [3.0139868]
[2.] [3.0139868]
[4.] [3.0139868]
[5.] [3.0139868]
[5.] [3.0139868]
[5.] [3.0139868]
[1.] [3.0139868]
[2.] [3.0139868]
[3.] [3.0139868]
[4.] [3.0139868]
[5.] [3.0139868]
[2.] [3.0139868]
[3.] [3.0139868]
[1.] [3.0139868]
[1.] [3.0139868]
[3.01421128] [3.0139868]
[4.] [3.0139868]
[4.] [3.0139868]
[3.] [3.0139868]


[1.] [3.0139868]
[5.] [3.0139868]
[1.] [3.0139868]
[1.] [3.0139868]
[5.] [3.0139868]
[1.] [3.0139868]
[3.01421128] [3.0139868]
[2.] [3.0139868]
[3.] [3.0139868]
[3.] [3.0139868]
[4.] [3.0139868]
[3.] [3.0139868]
[3.01421128] [3.0139868]
[2.] [3.0139868]
[2.] [3.0139868]
[5.] [3.0139868]
[2.] [3.0139868]
[4.] [3.0139868]
[3.01421128] [3.0139868]
[4.] [3.0139868]
[2.] [3.0139868]
[3.] [3.0139868]
[3.01421128] [3.0139868]
[4.] [3.0139868]
[4.] [3.0139868]
[2.] [3.0139868]
[2.] [3.0139868]
[3.] [3.0139868]
[5.] [3.0139868]
[1.] [3.0139868]
[4.] [3.0139868]
[5.] [3.0139868]
[1.] [3.0139868]
[3.01421128] [3.0139868]
[3.] [3.0139868]
[3.] [3.0139868]
[4.] [3.0139868]
[5.] [3.0139868]
[4.] [3.0139868]
[1.] [3.0139868]
[3.] [3.0139868]
[5.] [3.0139868]
[3.] [3.0139868]
[3.01421128] [3.0139868]
[5.] [3.0139868]
[4.] [3.0139868]
[4.] [3.0139868]
[5.] [3.0139868]
[1.] [3.0139868]
[3.] [3.0139868]
[3.01421128] [3.0139868]
[5.] [3.0139868]
[3.01421128] [3.0139868]
[4.] [3.0139868]
[4.] [3.0139868]
[

[4.] [3.0139868]
[1.] [3.0139868]
[3.] [3.0139868]
[3.] [3.0139868]
[5.] [3.0139868]
[1.] [3.0139868]
[2.] [3.0139868]
[4.] [3.0139868]
[3.] [3.0139868]
[5.] [3.0139868]
[3.01421128] [3.0139868]
[4.] [3.0139868]
[3.01421128] [3.0139868]
[3.] [3.0139868]
[3.01421128] [3.0139868]
[3.] [3.0139868]
[1.] [3.0139868]
[4.] [3.0139868]
[4.] [3.0139868]
[3.] [3.0139868]
[3.] [3.0139868]
[3.] [3.0139868]
[4.] [3.0139868]
[2.] [3.0139868]
[3.] [3.0139868]
[1.] [3.0139868]
[2.] [3.0139868]
[3.] [3.0139868]
[4.] [3.0139868]
[2.] [3.0139868]
[3.] [3.0139868]
[2.] [3.0139868]
[4.] [3.0139868]
[3.01421128] [3.0139868]
[3.] [3.0139868]
[3.] [3.0139868]
[5.] [3.0139868]
[3.] [3.0139868]
[5.] [3.0139868]
[3.] [3.0139868]
[3.01421128] [3.0139868]
[2.] [3.0139868]
[5.] [3.0139868]
[2.] [3.0139868]
[5.] [3.0139868]
[4.] [3.0139868]
[2.] [3.0139868]
[5.] [3.0139868]
[3.] [3.0139868]
[1.5] [3.0139868]
[1.] [3.0139868]
[3.] [3.0139868]
[5.] [3.0139868]
[4.] [3.0139868]
[3.01421128] [3.0139868]
[2.] [3.0139868]

[1.] [3.0139868]
[4.] [3.0139868]
[1.] [3.0139868]
[4.] [3.0139868]
[2.] [3.0139868]
[3.] [3.0139868]
[1.] [3.0139868]
[4.] [3.0139868]
[5.] [3.0139868]
[1.] [3.0139868]
[3.01421128] [3.0139868]
[1.] [3.0139868]
[5.] [3.0139868]
[4.] [3.0139868]
[2.] [3.0139868]
[3.] [3.0139868]
[4.] [3.0139868]
[2.] [3.0139868]
[5.] [3.0139868]
[4.] [3.0139868]
[3.01421128] [3.0139868]
[4.] [3.0139868]
[3.] [3.0139868]
[1.] [3.0139868]
[4.] [3.0139868]
[5.] [3.0139868]
[3.] [3.0139868]
[4.] [3.0139868]
[5.] [3.0139868]
[3.] [3.0139868]
[3.] [3.0139868]
[1.] [3.0139868]
[2.] [3.0139868]
[3.01421128] [3.0139868]
[3.] [3.0139868]
[1.] [3.0139868]
[2.] [3.0139868]
[5.] [3.0139868]
[1.] [3.0139868]
[4.] [3.0139868]
[3.01421128] [3.0139868]
[1.] [3.0139868]
[5.] [3.0139868]
[4.] [3.0139868]
[4.] [3.0139868]
[3.01421128] [3.0139868]
[1.] [3.0139868]
[5.] [3.0139868]
[4.] [3.0139868]
[1.] [3.0139868]
[4.] [3.0139868]
[4.] [3.0139868]
[5.] [3.0139868]
[4.] [3.0139868]
[1.] [3.0139868]
[1.] [3.0139868]
[3.01421

[1.] [3.0139868]
[5.] [3.0139868]
[3.01421128] [3.0139868]
[3.] [3.0139868]
[4.] [3.0139868]
[3.] [3.0139868]
[1.] [3.0139868]
[4.] [3.0139868]
[3.] [3.0139868]
[3.] [3.0139868]
[3.] [3.0139868]
[3.] [3.0139868]
[3.] [3.0139868]
[1.] [3.0139868]
[5.] [3.0139868]
[4.] [3.0139868]
[3.] [3.0139868]
[3.] [3.0139868]
[1.] [3.0139868]
[3.] [3.0139868]
[1.] [3.0139868]
[3.01421128] [3.0139868]
[2.] [3.0139868]
[4.] [3.0139868]
[4.] [3.0139868]
[3.] [3.0139868]
[4.] [3.0139868]
[1.] [3.0139868]
[4.] [3.0139868]
[1.] [3.0139868]
[3.] [3.0139868]
[2.] [3.0139868]
[4.] [3.0139868]
[2.] [3.0139868]
[3.01421128] [3.0139868]
[3.01421128] [3.0139868]
[5.] [3.0139868]
[1.] [3.0139868]
[5.] [3.0139868]
[5.] [3.0139868]
[3.01421128] [3.0139868]
[4.] [3.0139868]
[3.] [3.0139868]
[4.] [3.0139868]
[3.01421128] [3.0139868]
[2.] [3.0139868]
[1.] [3.0139868]
[1.] [3.0139868]
[2.] [3.0139868]
[2.] [3.0139868]
[3.01421128] [3.0139868]
[4.] [3.0139868]
[3.] [3.0139868]
[3.01421128] [3.0139868]
[3.] [3.0139868]
[

[1.] [3.0139868]
[1.] [3.0139868]
[5.] [3.0139868]
[4.] [3.0139868]
[4.] [3.0139868]
[1.] [3.0139868]
[1.] [3.0139868]
[2.] [3.0139868]
[5.] [3.0139868]
[2.] [3.0139868]
[5.] [3.0139868]
[1.] [3.0139868]
[3.] [3.0139868]
[2.] [3.0139868]
[1.] [3.0139868]
[5.] [3.0139868]
[2.] [3.0139868]
[2.] [3.0139868]
[3.] [3.0139868]
[1.] [3.0139868]
[3.01421128] [3.0139868]
[5.] [3.0139868]
[3.] [3.0139868]
[5.] [3.0139868]
[1.] [3.0139868]
[3.01421128] [3.0139868]
[3.] [3.0139868]
[3.] [3.0139868]
[2.] [3.0139868]
[4.] [3.0139868]
[3.01421128] [3.0139868]
[3.01421128] [3.0139868]
[4.] [3.0139868]
[1.] [3.0139868]
[3.01421128] [3.0139868]
[1.] [3.0139868]
[3.] [3.0139868]
[4.] [3.0139868]
[5.] [3.0139868]
[3.01421128] [3.0139868]
[5.] [3.0139868]
[3.01421128] [3.0139868]
[3.01421128] [3.0139868]
[3.5] [3.0139868]
[1.] [3.0139868]
[4.] [3.0139868]
[5.] [3.0139868]
[3.] [3.0139868]
[4.] [3.0139868]
[3.] [3.0139868]
[3.01421128] [3.0139868]
[2.] [3.0139868]
[5.] [3.0139868]
[2.] [3.0139868]
[1.] [3.0

[2.] [3.0139868]
[5.] [3.0139868]
[5.] [3.0139868]
[4.] [3.0139868]
[1.] [3.0139868]
[3.] [3.0139868]
[3.] [3.0139868]
[3.] [3.0139868]
[2.] [3.0139868]
[5.] [3.0139868]
[3.] [3.0139868]
[3.] [3.0139868]
[2.] [3.0139868]
[1.] [3.0139868]
[3.] [3.0139868]
[3.01421128] [3.0139868]
[2.5] [3.0139868]
[1.] [3.0139868]
[4.] [3.0139868]
[1.5] [3.0139868]
[5.] [3.0139868]
[3.01421128] [3.0139868]
[2.] [3.0139868]
[3.01421128] [3.0139868]
[3.] [3.0139868]
[3.01421128] [3.0139868]
[3.] [3.0139868]
[1.] [3.0139868]
[3.01421128] [3.0139868]
[2.] [3.0139868]
[2.] [3.0139868]
[5.] [3.0139868]
[2.] [3.0139868]
[3.] [3.0139868]
[1.] [3.0139868]
[5.] [3.0139868]
[3.] [3.0139868]
[4.] [3.0139868]
[3.01421128] [3.0139868]
[3.] [3.0139868]
[4.] [3.0139868]
[1.] [3.0139868]
[1.] [3.0139868]
[2.] [3.0139868]
[3.] [3.0139868]
[1.] [3.0139868]
[3.01421128] [3.0139868]
[3.01421128] [3.0139868]
[3.] [3.0139868]
[3.01421128] [3.0139868]
[3.] [3.0139868]
[3.] [3.0139868]
[3.] [3.0139868]
[2.] [3.0139868]
[3.] [3.

[4.] [3.0139868]
[4.] [3.0139868]
[3.01421128] [3.0139868]
[3.01421128] [3.0139868]
[3.01421128] [3.0139868]
[1.] [3.0139868]
[5.] [3.0139868]
[4.] [3.0139868]
[3.] [3.0139868]
[3.01421128] [3.0139868]
[5.] [3.0139868]
[2.] [3.0139868]
[5.] [3.0139868]
[1.] [3.0139868]
[1.] [3.0139868]
[4.] [3.0139868]
[5.] [3.0139868]
[3.] [3.0139868]
[3.] [3.0139868]
[5.] [3.0139868]
[2.] [3.0139868]
[3.] [3.0139868]
[3.01421128] [3.0139868]
[3.] [3.0139868]
[1.] [3.0139868]
[5.] [3.0139868]
[4.] [3.0139868]
[3.01421128] [3.0139868]
[1.] [3.0139868]
[1.] [3.0139868]
[1.] [3.0139868]
[2.] [3.0139868]
[4.] [3.0139868]
[5.] [3.0139868]
[3.] [3.0139868]
[5.] [3.0139868]
[3.] [3.0139868]
[3.] [3.0139868]
[5.] [3.0139868]
[1.] [3.0139868]
[3.] [3.0139868]
[4.] [3.0139868]
[1.] [3.0139868]
[3.] [3.0139868]
[3.01421128] [3.0139868]
[4.] [3.0139868]
[1.] [3.0139868]
[4.] [3.0139868]
[4.5] [3.0139868]
[1.] [3.0139868]
[2.] [3.0139868]
[1.] [3.0139868]
[2.] [3.0139868]
[4.] [3.0139868]
[4.] [3.0139868]
[3.] [3.

[3.] [3.0139868]
[2.] [3.0139868]
[4.] [3.0139868]
[4.] [3.0139868]
[4.] [3.0139868]
[3.] [3.0139868]
[3.] [3.0139868]
[3.] [3.0139868]
[4.] [3.0139868]
[1.] [3.0139868]
[2.] [3.0139868]
[1.] [3.0139868]
[1.] [3.0139868]
[3.01421128] [3.0139868]
[4.] [3.0139868]
[3.] [3.0139868]
[2.] [3.0139868]
[4.] [3.0139868]
[1.] [3.0139868]
[1.] [3.0139868]
[3.] [3.0139868]
[3.] [3.0139868]
[3.] [3.0139868]
[2.] [3.0139868]
[1.] [3.0139868]
[5.] [3.0139868]
[1.] [3.0139868]
[4.] [3.0139868]
[3.] [3.0139868]
[1.] [3.0139868]
[3.01421128] [3.0139868]
[3.] [3.0139868]
[3.] [3.0139868]
[4.] [3.0139868]
[2.] [3.0139868]
[5.] [3.0139868]
[4.] [3.0139868]
[5.] [3.0139868]
[1.] [3.0139868]
[5.] [3.0139868]
[1.] [3.0139868]
[2.] [3.0139868]
[5.] [3.0139868]
[3.] [3.0139868]
[4.] [3.0139868]
[2.] [3.0139868]
[4.] [3.0139868]
[4.] [3.0139868]
[1.] [3.0139868]
[5.] [3.0139868]
[1.] [3.0139868]
[4.] [3.0139868]
[4.] [3.0139868]
[2.5] [3.0139868]
[4.] [3.0139868]
[1.] [3.0139868]
[4.] [3.0139868]
[2.] [3.013986

In [57]:
with open('models/rnn_model.pkl', 'wb') as f:
    # Write the model to a file.
    pickle.dump(model, f)

In [53]:
mse_rnn = mean_squared_error(y_test, y_pred_rnn)
rmse_rnn = mse_rnn ** 0.5
print(mse_rnn)
print(rmse_rnn)

1.6099188849161883
1.2688257898215138


### Later

In [None]:
# with open('model.pkl', 'rb') as f:
#     model = pickle.load(f)

In [None]:
_ = scatter_matrix(amazon_df, alpha=0.2, figsize=(10, 10), diagonal='kde')

In [None]:


fig, ax = plt.subplots(figsize=(14, 3))

plot_univariate_smooth(
    ax, 
    non_zero_bal_df["limit"].values.reshape(-1, 1),
    non_zero_bal_df['balance'],
    bootstrap=100)

ax.set_title("Univariate Effect of Credit Limit on Non-zero Bank Balance")
ax.set_ylabel("Non-zero Balance")
ax.set_xlabel("Limit")

In [None]:
limit_fit = Pipeline([
    ('limit', ColumnSelector(name='limit')),
    ('limit_spline', LinearSpline(knots=[2500, 6000, 7000]))
])