
Add support for validation sets #85

Closed
untom opened this issue Jan 19, 2016 · 6 comments

@untom commented Jan 19, 2016

It would be nice if skflow had some support for validation sets, to be used for early stopping and for monitoring the validation-set loss during training. This could be realized fairly easily by adding a fraction_validationset argument to TensorFlowEstimator. Within fit, the given training set could then be split into two parts.
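A minimal sketch of that split, assuming the hypothetical fraction_validationset parameter proposed above (the helper below is illustrative, not existing skflow API):

    import numpy as np

    def split_for_validation(X, y, fraction_validationset=0.1, seed=42):
        # Shuffle row indices, then hold out the first n_valid rows as
        # the validation set; the remainder stays the training set.
        rng = np.random.RandomState(seed)
        idx = rng.permutation(len(X))
        n_valid = int(len(X) * fraction_validationset)
        valid_idx, train_idx = idx[:n_valid], idx[n_valid:]
        return X[train_idx], y[train_idx], X[valid_idx], y[valid_idx]

Inside fit, the estimator would then train on the first pair and track loss on the second.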

@ilblackdragon (Contributor)

The validation set support is a good idea. One concern I have is how to fit this into the sklearn interface when a user wants to pass a specific validation set. Usually you split your dataset three ways - train, validation, and test - and use the validation set for hyperparameter search.
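For reference, the usual three-way split can be done with two calls to scikit-learn's train_test_split (shown with the modern sklearn.model_selection import path; the random data is a stand-in):

    import numpy as np
    from sklearn.model_selection import train_test_split

    X, y = np.random.rand(1000, 10), np.random.randint(0, 2, size=1000)

    # First carve off a test set, then split the remainder into train
    # and validation; the validation set drives hyperparameter search.
    X_rest, X_test, y_rest, y_test = train_test_split(
        X, y, test_size=0.2, random_state=0)
    X_train, X_valid, y_train, y_valid = train_test_split(
        X_rest, y_rest, test_size=0.25, random_state=0)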

@makseq (Contributor) commented Feb 3, 2016

I implemented it as a delegate validation function passed through fit(..., cross_valid_fn=my_func), where, for example, my_func() returns calc_rmse(test_data). It's very useful because the user can adjust my_func() however they want.
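A rough sketch of that delegate pattern (cross_valid_fn and calc_rmse are names from the comment above, not existing skflow API):

    import numpy as np

    def calc_rmse(estimator, X_valid, y_valid):
        # User-supplied validation metric: RMSE on a held-out set.
        preds = np.ravel(estimator.predict(X_valid))
        return np.sqrt(np.mean((preds - y_valid) ** 2))

    # Hypothetical usage: fit() would invoke the delegate periodically
    # during training, so the user decides what gets measured and how.
    # estimator.fit(X_train, y_train,
    #               cross_valid_fn=lambda: calc_rmse(estimator, X_valid, y_valid))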

@dansbecker (Contributor)

I was going to take a look at this (though I'm still learning my way around skflow).

I see an early_stopping_rounds argument in the estimators (TensorFlowEstimator and its derived classes), and the argument is passed to the TensorFlowTrainer, which appears to implement early stopping logic.

Is that existing early stopping logic different from what's suggested in this issue? I'll pursue this further once I understand how the two differ.

@ilblackdragon (Contributor)

@dansbecker Thanks for taking a look! Right now early stopping is done on the training loss - e.g. if training has converged, the model stops before the requested number of steps.

On the other hand, using a validation set is another option. So far, in the examples, we have been implementing it this way: https://github.com/tensorflow/skflow/blob/master/examples/resnet.py#L148
So one way this issue could be addressed is by adding something like
skflow.train(estimator, X_train, y_train, X_valid, y_valid, metric?) that runs this loop and also stops training if the validation metric stops improving.
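A rough sketch of what such a helper could look like (skflow.train does not exist yet; this assumes repeated fit() calls continue training from the current weights, as in the resnet example above):

    import numpy as np

    def train(estimator, X_train, y_train, X_valid, y_valid,
              metric, max_rounds=100, patience=5):
        # metric(y_true, y_pred) returns a score where higher is better.
        # Stop once the validation score has not improved for `patience`
        # consecutive rounds.
        best_score, stale_rounds = -np.inf, 0
        for _ in range(max_rounds):
            estimator.fit(X_train, y_train)
            score = metric(y_valid, estimator.predict(X_valid))
            if score > best_score:
                best_score, stale_rounds = score, 0
            else:
                stale_rounds += 1
                if stale_rounds >= patience:
                    break
        return estimator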

@dansbecker (Contributor)

Thanks @ilblackdragon. That makes sense.

On the issue of using the same data for early stopping and for hyperparameter search: Three options come to mind.

  1. Let the user pass in two sets of data to the fit method. One set for training the network, the second (which @untom calls validation in this issue) for determining when to stop training. As I think you mentioned, this raises the question of whether the same validation data can also be used for hyperparameter search.

    It feels reasonable if I think of the number of steps as a hyperparameter, in which case we'd expect the same data to determine the number of steps as is used to tune the other hyperparameters.

    However, I think it is the least consistent of the three with the sklearn interface.

    Incidentally, this is the approach that keras uses.

  2. Let the user specify an argument for what fraction of the training data is used only to determine early stopping (and not used to set weights). In this approach, the user doesn't specify a separate data set to be used for early stopping.

    This option may confuse some users, who expect all data in the training set to always be used to determine network weights. However, it's more consistent with the sklearn interface than the first option, and I think most users will find it the easiest to use.

  3. Let the user create a monitor object that tells the network when to stop. The user specifies the relevant data when creating that object, and the monitor is an optional argument to the fit method. This is how I interpret what @makseq described above (see the sketch after this list), and this post describes its use with sklearn's GradientBoostingClassifier.
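To make option 3 concrete, here is a rough sketch of such a monitor object (the class and the monitor= keyword are illustrative, not existing skflow API):

    class ValidationMonitor(object):
        # Holds the validation data the estimator should evaluate on.
        # The estimator would call update() with the validation loss
        # after each evaluation round and stop training once
        # should_stop() returns True.
        def __init__(self, X_valid, y_valid, patience=3):
            self.X_valid, self.y_valid = X_valid, y_valid
            self.patience = patience
            self.best_loss = float('inf')
            self.stale_rounds = 0

        def update(self, loss):
            if loss < self.best_loss:
                self.best_loss, self.stale_rounds = loss, 0
            else:
                self.stale_rounds += 1

        def should_stop(self):
            return self.stale_rounds >= self.patience

    # Hypothetical usage:
    # estimator.fit(X_train, y_train,
    #               monitor=ValidationMonitor(X_valid, y_valid))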

Thoughts?

@dansbecker (Contributor)

@ilblackdragon: The three options above are in addition to the one you mentioned:
skflow.train(estimator, X_train, y_train, X_valid, y_valid, metric?)

My inclination is towards either your suggestion, or option 2 or 3 from the note above.
