Reset Keras models in bias_variance_decomp after each bootstrap round #746

Closed
rasbt opened this issue Nov 5, 2020 · 6 comments · Fixed by #748

Comments


rasbt commented Nov 5, 2020

Following up on #725, I think that the model needs to be reset after each bootstrap round.

Consider the following:

from mlxtend.evaluate import bias_variance_decomp
from mlxtend.data import boston_housing_data
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
import tensorflow as tf


X, y = boston_housing_data()
X_train, X_test, y_train, y_test = train_test_split(X, y,
                                                    test_size=0.3,
                                                    random_state=123,
                                                    shuffle=True)

# Small regression network; tf.keras is used throughout to avoid
# mixing the standalone keras package with tensorflow.keras
model = tf.keras.Sequential([
    tf.keras.layers.Dense(32, activation=tf.nn.relu),
    tf.keras.layers.Dense(1)
])

optimizer = tf.keras.optimizers.Adam()
model.compile(loss='mean_squared_error', optimizer=optimizer)

model.fit(X_train, y_train, epochs=100)

mean_squared_error(model.predict(X_test), y_test)

results in

Epoch 1/100
12/12 [==============================] - 0s 563us/step - loss: 3791.4114
Epoch 2/100
12/12 [==============================] - 0s 557us/step - loss: 540.7360
Epoch 3/100
12/12 [==============================] - 0s 512us/step - loss: 563.6880
Epoch 4/100
12/12 [==============================] - 0s 629us/step - loss: 396.0307
...
Epoch 100/100
12/12 [==============================] - 0s 451us/step - loss: 39.4736
41.76526825141232

Now, running bias_variance_decomp:

avg_expected_loss, avg_bias, avg_var = bias_variance_decomp(
        model, X_train, y_train, X_test, y_test, 
        loss='mse',
        random_seed=123)

yields

12/12 [==============================] - 0s 517us/step - loss: 41.6021
12/12 [==============================] - 0s 518us/step - loss: 40.0839
12/12 [==============================] - 0s 626us/step - loss: 38.7366
12/12 [==============================] - 0s 656us/step - loss: 36.4204
12/12 [==============================] - 0s 624us/step - loss: 38.7534
12/12 [==============================] - 0s 629us/step - loss: 44.3270
...
12/12 [==============================] - 0s 628us/step - loss: 27.2055
12/12 [==============================] - 0s 617us/step - loss: 27.8539

You can see that the loss starts at 41.6, which indicates that the already-fitted model is being reused.

By the way, the results are

avg_expected_loss, avg_bias, avg_var
(32.459421052631576, 29.777634046052633, 2.6817870065789475)

which looks okay. However, the Keras model does not appear to be reset after each round in

https://github.com/rasbt/mlxtend/blob/master/mlxtend/evaluate/bias_variance_decomp.py#L78

I think the loop in that file,

    for i in range(num_rounds):
        X_boot, y_boot = _draw_bootstrap_sample(rng, X_train, y_train)
        if estimator.__class__.__name__ == 'Sequential':
            estimator.fit(X_boot, y_boot)
            pred = estimator.predict(X_test).reshape(1, -1)
        else:
            pred = estimator.fit(X_boot, y_boot).predict(X_test)
        all_pred[i] = pred

needs to be changed to something like

    for i in range(num_rounds):
        X_boot, y_boot = _draw_bootstrap_sample(rng, X_train, y_train)
        if estimator.__class__.__name__ == 'Sequential':
            # reset / reinitialize estimator!!! 
            estimator.fit(X_boot, y_boot)
            pred = estimator.predict(X_test).reshape(1, -1)
        else:
            pred = estimator.fit(X_boot, y_boot).predict(X_test)
        all_pred[i] = pred

What do you think, @hanzigs?
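One way to do the reset (a sketch of the idea, not the final implementation) is to snapshot the model's initial weights once before the loop and restore them at the start of each round via `get_weights()`/`set_weights()`, which real Keras models provide. So that the snippet runs without TensorFlow, `_ToyModel` below is a hypothetical stand-in:

```python
# Sketch: snapshot the initial weights once, then restore them before each
# bootstrap round so every round trains from scratch. _ToyModel is a
# hypothetical stand-in mimicking Keras' get_weights / set_weights / fit.

class _ToyModel:
    def __init__(self):
        self.weights = [0.0]

    def get_weights(self):
        return list(self.weights)

    def set_weights(self, weights):
        self.weights = list(weights)

    def fit(self, X_boot, y_boot):
        # pretend to train: each fit() call moves the weights
        self.weights = [w + 1.0 for w in self.weights]


model = _ToyModel()
init_weights = model.get_weights()   # snapshot before any bootstrap round

num_rounds = 3
for i in range(num_rounds):
    model.set_weights(init_weights)  # reset: this round starts fresh
    model.fit(None, None)
    # after exactly one fit() from the initial state, weights are [1.0]
    assert model.get_weights() == [1.0]
```

With a real `tf.keras` model the same pattern applies, although recompiling (or cloning via `tf.keras.models.clone_model`) may also be needed to reset the optimizer state.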


hanzigs commented Nov 6, 2020

Hi @rasbt,
Thanks for the question, it's a valid one.
I had this doubt before: even for the sklearn estimator, the model is already fitted with the data, and we refit it with the same data.
Either a model reset is required, or we should predict directly from the model.
We could also use

if estimator.__class__.__name__ in ['Sequential', 'Functional']:

to support functional Keras models as well.
Thanks


rasbt commented Nov 9, 2020

Thanks for the comment, @hanzigs. It's probably best to (1) reset the model (this is what users would expect), (2) make a note about it in the documentation, (3) allow fit parameters for Keras to set the number of epochs when refitting the reset models, and (4) make the modification you describe.

I can do that in a separate PR.


hanzigs commented Nov 10, 2020

Hi @rasbt,
Just to mention this: I tried commenting out the fit part and it worked well, so we don't have to refit or reinitialize the model. Do you want to try it this way?

    for i in range(num_rounds):
        #X_boot, y_boot = _draw_bootstrap_sample(rng, X_train, y_train)
        if estimator.__class__.__name__ in ['Sequential','Functional']:
            #estimator.fit(X_boot, y_boot, verbose = 1)
            pred = estimator.predict(X_test).reshape(1,-1)
        else:    
            pred = estimator.predict(X_test)
            #pred = estimator.fit(X_boot, y_boot).predict(X_test)
        all_pred[i] = pred

Thanks


rasbt commented Nov 10, 2020

Technically, this would work, but the problem is that the predictions would then be identical in every round, so the variance estimate would be zero.


hanzigs commented Nov 10, 2020

You are right; then we should reinitialize the model.


rasbt commented Nov 10, 2020

Yeah, I have implemented it in #748, along with a fit_params option to pass the number of epochs on to the Keras fit method. It seems to work well.
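The fit_params idea amounts to forwarding a keyword dictionary to `estimator.fit()` inside the bootstrap loop. A minimal sketch of that forwarding pattern, where `_RecordingEstimator` and `_decomp_loop` are hypothetical stand-ins (the exact signature that landed in #748 may differ):

```python
# Sketch: forward user-supplied fit_params (e.g. epochs, verbose) to
# estimator.fit() in each bootstrap round. _RecordingEstimator is a
# hypothetical stand-in that records what fit() received.

class _RecordingEstimator:
    def __init__(self):
        self.fit_kwargs = None

    def fit(self, X, y, **kwargs):
        self.fit_kwargs = kwargs      # remember the forwarded parameters
        return self

    def predict(self, X):
        return [0.0 for _ in X]


def _decomp_loop(estimator, X_boot, y_boot, X_test,
                 num_rounds, fit_params=None):
    fit_params = dict(fit_params or {})
    preds = []
    for _ in range(num_rounds):
        estimator.fit(X_boot, y_boot, **fit_params)  # epochs etc. forwarded
        preds.append(estimator.predict(X_test))
    return preds


est = _RecordingEstimator()
preds = _decomp_loop(est, [[1.0]], [1.0], [[2.0]], num_rounds=2,
                     fit_params={'epochs': 100, 'verbose': 0})
assert est.fit_kwargs == {'epochs': 100, 'verbose': 0}
assert len(preds) == 2
```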
