Callbacks API #16925

Conversation
This is super awesome. Are you thinking about working further on this? Working on logging and callbacks could be a cool topic for the MLH fellowship thing. As a first pass we could do 'only' logging though?
Yes, I would still like to work on this but I don't have much availability at the moment. I think working on callbacks as part of MLH would be nice, and in particular logging would indeed be a good start. Using callbacks with a logging handler would IMO be better than having conditional print statements. I marked this as WIP, but basically it's a minimal working implementation, where the API should be sufficient for logging. Then it would need to be applied to all estimators and replace our current logging approach. If I had to change anything here, it might be to make the callback method names a bit closer to Keras callbacks. In the end I'm not sure a SLEP is ideal for this; maybe introducing it as a private feature and incrementally improving it, as we did for estimator tags, might be better, as it's hard to plan all of it ahead. None of the additions should have an impact on users at present. Even a superficial review would be much appreciated, cc @NicolasHug @thomasjpfan @glemaitre
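To illustrate the logging angle, a rough sketch (not code from this PR; it assumes the `on_fit_begin`/`on_iter_end` hook names adopted later in this thread, and the logger name is made up) of a callback that forwards iteration data to the standard `logging` module:

```python
import logging

logger = logging.getLogger("sklearn.callbacks")  # hypothetical logger name


class LoggingCallback:
    """Minimal sketch: forward per-iteration information to the logging module."""

    def on_fit_begin(self, estimator, X, y):
        logger.info("fit started for %s on X with shape %s",
                    type(estimator).__name__, getattr(X, "shape", None))

    def on_iter_end(self, **kwargs):
        # kwargs could contain e.g. n_iter, coef, intercept depending on the solver
        logger.info("iteration %s finished", kwargs.get("n_iter"))
```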
Nice stuff!
Made a few comments but I mostly have questions at this point
```toml
[tool.black]
line-length = 79
```
sneaky :p
Yeah, I don't see the point in manually formatting code anymore for new files. It shouldn't hurt even if we are not using it everywhere.
```python
def _check_callback_params(**kwargs):
```
Shouldn't we let each callback independently validate its data?
My question might not make sense but I don't see this being used anywhere except in the tests
> Shouldn't we let each callback independently validate its data? My question might not make sense but I don't see this being used anywhere except in the tests
Yes, absolutely, each callback validates its data. But we also need to enforce in tests that callbacks do follow the documented API, for instance that no undocumented parameters are passed, which requires this function. Third-party callbacks could also use this validation function, similarly to how we expose `check_array`.
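For illustration, such a validation helper could check the kwargs keys against the documented list; a sketch only, where the set of allowed keys below is an assumption rather than the PR's actual list:

```python
# Hypothetical sketch of a parameter-validation helper for callbacks.
# The allowed keys below are illustrative, not the PR's documented list.
ALLOWED_CALLBACK_PARAMS = {"n_iter", "coef", "intercept"}


def _check_callback_params(**kwargs):
    """Raise if callbacks are evaluated with undocumented keyword arguments."""
    unknown = set(kwargs) - ALLOWED_CALLBACK_PARAMS
    if unknown:
        raise ValueError(
            f"Unsupported callback parameters: {sorted(unknown)}; "
            f"allowed parameters are {sorted(ALLOWED_CALLBACK_PARAMS)}."
        )
```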
sklearn/base.py (outdated)
```
In the case of meta-estimators, callbacks are also set recursively
for all child estimators.
```
Thoughts on doing this vs letting users set callbacks on sub-estimator instances? What about e.g. early stopping when we ultimately support this?
I added a `deep=True` option to allow disabling recursion for meta-estimators, which can certainly be useful in some cases. In most cases though, I don't see users manually setting callbacks for each individual estimator in a complex pipeline.
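To make the recursion concrete, a recursive setter could look roughly like this; a sketch assuming a private `_set_callbacks` method on estimators and the `deep` flag mentioned above, not the PR's exact implementation:

```python
def _set_callbacks(self, callbacks, deep=True):
    """Sketch of a recursive callback setter (not the PR's exact code).

    Stores the callbacks on the estimator and, when deep=True, also on any
    sub-estimator found among the parameters (anything exposing ``fit``).
    """
    if not isinstance(callbacks, (list, tuple)):
        callbacks = [callbacks]
    self._callbacks = list(callbacks)
    if deep:
        for _, value in self.get_params(deep=True).items():
            if hasattr(value, "fit"):
                value._callbacks = list(callbacks)
    return self
```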
```cython
if callbacks is not None:
    with gil:
        _eval_callbacks(callbacks, n_iter=n_iter, coef=weights_array,
                        intercept=intercept_array)
```
I remember a strange behavior during the Paris sprint last year (I think it was with @pierreglaser and @tomMoral?) where the GIL was acquired in a condition like this, and even when the condition was always False the code was significantly slower. Might be something to keep in mind.
It's clearly an issue with parallel code (cython/cython#3554), and I'm not sure how to handle parallel code with callbacks so far.

However, acquiring the GIL occasionally in long-running loops is beneficial, as users can't interrupt the calculation with Ctrl+C otherwise. So acquiring the GIL at the end of each epoch would actually solve a bug here: #9136 (comment).

Will switch to acquiring the GIL at the end of each epoch even if callbacks is None.
sklearn/linear_model/_ridge.py (outdated)
```diff
@@ -103,6 +104,8 @@ def _mv(x):
     # old scipy
     coefs[i], info = sp_linalg.cg(C, y_column, maxiter=max_iter,
                                   tol=tol)
+    if callbacks is not None:
```
Maybe you can remove this since the `is not None` check is also done in `_eval_callbacks`.
```python
def _eval_callbacks(self, **kwargs):
    """Call callbacks, e.g. in each iteration of an iterative solver"""
    from ._callbacks import _eval_callbacks
```
why lazy import?
```python
if callbacks is None:
    return
```
Shouldn't we rely on `callbacks` being an empty list instead of special-casing `None`? Or maybe you are anticipating a future where `callbacks=None` would be a default argument to estimators?
Yes, we switched to `callbacks=None` when missing.
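For context, a module-level helper handling the `None` case in one place could look like this minimal sketch (assuming the `on_iter_end` hook name; not the PR's actual code):

```python
def _eval_callbacks(callbacks, **kwargs):
    """Sketch of a module-level helper: call on_iter_end on each callback.

    Treating callbacks=None as "nothing registered" keeps the call sites in
    estimators and solvers trivial.
    """
    if callbacks is None:
        return
    for callback in callbacks:
        callback.on_iter_end(**kwargs)
```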
```python
assert callback.n_calls == 0
estimator.fit(X, y)
if callback.n_fit_calls == 0:
    pytest.skip("callbacks not implemented")
```
otherwise assert it's equal to 1?
sklearn/tests/test_callbacks.py (outdated)
```python
def check_has_callback(est, callback):
    assert getattr(est, "_callbacks", None) is not None
```
Is this different from `hasattr`? Or can the attribute exist and be None for some reason?
Yes, reworded more clearly as `assert hasattr(est, "_callbacks") and est._callbacks is not None`, since `None` could be equivalent to `[]`; just to be sure that this is not happening.
sklearn/decomposition/_nmf.py (outdated)
```diff
@@ -1335,7 +1345,7 @@ def transform(self, X):
     beta_loss=self.beta_loss, tol=self.tol, max_iter=self.max_iter,
     alpha=self.alpha, l1_ratio=self.l1_ratio, regularization='both',
     random_state=self.random_state, verbose=self.verbose,
-    shuffle=self.shuffle)
+    shuffle=self.shuffle, callbacks=getattr(self, '_callbacks', []))
```
I feel like we should either decide that

- no callbacks means an empty list
- no callbacks means None and having callbacks means a non-empty list

but it seems that the code is mixing both right now? If we ultimately plan on having `callbacks=None` as a default param then the latter would be more appropriate?
Yes, you are right, let's go with `callbacks=None` everywhere, and just let the eval function handle it.
Actually, here we would still need to provide the default option to `getattr`: `getattr(self, '_callbacks', None)`, and that's not really more readable than `getattr(self, '_callbacks', [])`, so both would work.
Thanks @NicolasHug! I think I have addressed your comments, please let me know if you have other questions. I have updated the required callback method names, inspired by Keras:

```python
class MyCallback(BaseCallback):
    def on_fit_begin(self, estimator, X, y):
        ...

    def on_iter_end(self, **kwargs):
        ...
```

which I think is more explicit than the earlier names.
The plan for early stopping so far is sketched in the Early stopping section of the description below. For a complex pipeline with callbacks set recursively, we clearly need to do this only for estimators where it makes sense, but I think it should be doable. In any case this can probably be done in a follow-up PR. For now I put in the documentation that the return value of callback methods is ignored.
BTW, regarding logging and #17439 (cc @thomasjpfan, @adrinjalali), there is an example of a logging callback in the sklearn-callbacks repo. For instance, this example:

```python
from sklearn.compose import make_column_transformer
from sklearn.datasets import make_classification
from sklearn.impute import SimpleImputer
from sklearn.linear_model import SGDClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import MinMaxScaler, StandardScaler

from sklearn_callbacks import DebugCallback

X, y = make_classification(n_samples=10000, n_features=100, random_state=0)
pipe = make_pipeline(
    SimpleImputer(),
    make_column_transformer(
        (StandardScaler(), slice(0, 80)), (MinMaxScaler(), slice(80, 90)),
    ),
    SGDClassifier(max_iter=20),
)

pbar = DebugCallback()
pipe._set_callbacks(pbar)
pipe.fit(X, y)
```

would produce debug output for each step of the pipeline. This can clearly be improved, e.g. linked with our existing verbose/logging output.
This can get tricky when we have logging in our Cython or C code. Having designed callbacks in skorch and reviewed callbacks in fastai, I think it is hard to come up with a complete list of items to pass to the callback. fastai can pass the caller object directly to the callback, and the callback can directly change the state of the model. For our case, passing metrics should be good enough. For logging, we also commonly log the elapsed time and may have interesting formatting. It also logs the number of leaves, etc. I would not think …
Thanks for the feedback @thomasjpfan!

In the current PR we can pass the estimator object to the callback, but we can't indeed do that from C or Cython code since it's not available there. We could take …

For elapsed time we would also need to add the …

Yes, it's one of the challenges that models are really heterogeneous. I guess we would also add a … The thing is that if we don't use callbacks for logging, it means that we would have to do very similar things twice, once for logging and once for callbacks.
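As an illustration of the elapsed-time point, a callback could record a timestamp in `on_fit_begin` and report the delta at each iteration; a minimal sketch (assuming the hook names used in this PR and an `n_iter` kwarg):

```python
import time


class ElapsedTimeCallback:
    """Sketch: report wall-clock time elapsed since fit() started."""

    def on_fit_begin(self, estimator, X, y):
        self._start = time.perf_counter()

    def on_iter_end(self, **kwargs):
        elapsed = time.perf_counter() - self._start
        print(f"iteration {kwargs.get('n_iter')}: {elapsed:.3f}s elapsed")
```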
We could pass a specified set of objects/data (targeted to logging), as well as mutable …
How would this interact with multiprocessing? I'm not sure I understand how you would implement grid-search logging with multiprocessing as callbacks.
> How would this interact with multiprocessing? I'm not sure I understand how you would implement grid-search logging with multiprocessing as callbacks.

I don't see what the problem is if the callback is picklable and sends its output to a shared resource.
@jnothman I guess it isn't clear to me how you would write to a shared resource with joblib.
@amueller Yes, it's not that straightforward. One could send callback parameters to a shared queue from different processes. But then one needs an additional thread to process these callback events, and start/stop this thread at some point. For instance, see the example below:

```python
from multiprocessing import Manager
from threading import Thread, Lock
import queue
from joblib import delayed, Parallel
from time import sleep


class Callback():
    def __init__(self, m):
        self.q = m.Queue()

    def on_iter_end(self, x):
        res = x**2
        self.q.put(f'callback value {x}')
        return res

    @staticmethod
    def process_callback_events(q, should_exit):
        """Process all events in the queue"""
        while True:
            try:
                print(q.get_nowait())
            except queue.Empty:
                print('queue empty')
                if should_exit.locked():
                    return
                else:
                    sleep(1)

    def start_processing_callback(self):
        thread_should_exit = Lock()
        thread = Thread(target=self.process_callback_events,
                        args=(self.q, thread_should_exit))
        thread.start()
        return thread_should_exit


m = Manager()
cbk = Callback(m)
callback_processing_stop = cbk.start_processing_callback()

res = Parallel(n_jobs=2)(
    delayed(cbk.on_iter_end)(x) for x in range(10)
)
callback_processing_stop.acquire(blocking=False)
print(res)
```

which would print the callback messages sent from the worker processes (interleaved with "queue empty" messages from the monitoring thread), followed by the list of results.
But then I'm sure there are plenty of edge cases with different joblib backends that would need handling (e.g. nested parallel blocks). And this would require re-designing the callbacks API somewhat.
Do you think this is worth the complexity? It seems like a can of worms, but it also would open a lot of doors (does that qualify as mixed metaphors?). Is it plausible / feasible to make the actual usage mostly hidden from the user? I assume that could work if joblib maybe gets some special hook? I'd rather not litter the sklearn code with callback locking code. And is there an easy way for the callback to determine the backend? We could also go to logging first if it's "easy enough" and then try to go to callbacks later once we figure out how to do it and whether it's worth it?
As far as I can tell, the approach of logging in the case of multiprocessing would work very similarly with a queue.
In terms of complexity, as long as it's isolated in the optional Callback object it might not be so bad. The above code can certainly be improved as well. The issue is indeed more the need to start/stop the callback/log monitoring thread around each parallel section, which is not ideal. We should talk with the joblib devs to see if this could be made more user friendly (both here and for logging).
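For the logging variant, the standard library's `logging.handlers.QueueHandler` / `QueueListener` already implement this queue-plus-monitoring-thread pattern; a minimal sketch of how it could be combined with joblib (the logger name is an assumption, and a manager queue is used so that it can be pickled to the workers):

```python
import logging
import logging.handlers
from multiprocessing import Manager

from joblib import Parallel, delayed


def work(i, log_queue):
    # each worker forwards its log records to the shared queue
    logger = logging.getLogger("sklearn.callbacks")  # hypothetical logger name
    if not logger.handlers:
        logger.addHandler(logging.handlers.QueueHandler(log_queue))
    logger.setLevel(logging.INFO)
    logger.info("finished task %d", i)
    return i ** 2


if __name__ == "__main__":
    m = Manager()
    log_queue = m.Queue()
    # a listener thread in the main process forwards records to real handlers
    listener = logging.handlers.QueueListener(log_queue, logging.StreamHandler())
    listener.start()
    results = Parallel(n_jobs=2)(delayed(work)(i, log_queue) for i in range(5))
    listener.stop()
    print(results)
```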
You are in closer physical proximity to the joblib people ;) It looks like there's no way around modifying each call to … Thinking about it a bit more, does that happen very often? You could do a bunch in this PR without it. It's not only GridSearchCV, right? I guess OVO and OVR and VotingClassifier are candidates? In other words, does it make sense to do a first solution where we don't handle the parallel case? I guess users will probably be miffed if we introduce an interface but it's not supported in GridSearchCV.
With @thomasjpfan we discussed that maybe it could be simpler to start with implementing just … In general I'm all for small incremental PRs. Having this in the private API even without parallel support could already be quite useful. It's more a question of how confident we are that we will be able to add support for joblib later if needed without massively changing the API, and that the chosen API is generally reasonable.
Hi folks, just chiming in here to note that a callbacks API would be very useful for projects that collect / aggregate metadata for model training routines. In particular, the MLflow project would benefit greatly from a callbacks API in the context of its upcoming scikit-learn "autologging" feature (discussed here: mlflow/mlflow#2050): callbacks would enable MLflow to patch in custom hooks that store per-epoch metrics for a given scikit-learn training session in a centralized tracking service. We would also be very excited to see this capability introduced as a private API in an upcoming scikit-learn release, even if it does not apply to all classifiers or to parallel execution environments (joblib, Dask, etc.), as alluded to in #16925 (comment).
Hi folks, I also briefly wanted to propose another use of a future callbacks API that would be very useful for projects that train ML models under a time budget: stopping the training if the time is up. This would be very handy in the Auto-sklearn project. For each evaluation of the ML model we give a time limit and end the process if the time limit is hit. With such a callback we could safely shut down the process ourselves and still use the partially trained model. I'm not sure if this use case will be considered (it's some form of early stopping, I guess), but I just wanted to bring up another potential use case that wasn't discussed yet.
Hello folks, I love and strongly support the idea of callbacks in sklearn. Thanks for the progress so far. Is this still a WIP for other algorithms, such as GaussianProcessClassifier?
+1 for moving forward with this PR with a private-only callback registration API for now, possibly without specific support for the parallel case (as long as the callbacks themselves are picklable). We already override the delayed function in scikit-learn, so @rth feel free to prototype something with it to make it possible to implement additional logic on the loky workers for callbacks that need access to shared resources. I think we would need to prototype a specific use case to see if we really need to change things in joblib or not.
Superseded by #22000
This is a first iteration of an API for callbacks, which could be used e.g. for monitoring progress of calculations and convergence, as well as, potentially, early stopping. The goal of this PR is to experiment with callbacks, which would likely serve as a basis for a SLEP.
As proposed in the PR documentation, to implement the callback API (in its current iteration), an estimator should:

- in `fit`, either explicitly call `self._fit_callbacks(X, y)` or use `self._validate_data(X, y)`, which makes a `self._fit_callbacks` call internally;
- call `self._eval_callbacks(n_iter=.., **kwargs)` at each iteration, where the `kwargs` keys must be part of the supported callback arguments (cf. list below). The question is whether we can meaningfully standardize the parameters passed as kwargs; just passing `locals()` won't do.

User-defined callbacks must extend the `sklearn._callbacks.BaseCallback` abstract base class. A sketch of how this could fit together follows below.
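A rough sketch of the intended integration pattern; this is only meaningful against this PR's branch, since `_fit_callbacks`/`_eval_callbacks` are private additions there, and `_solver_step` is a made-up placeholder for the actual optimization update:

```python
from sklearn.base import BaseEstimator


class MyIterativeEstimator(BaseEstimator):
    """Sketch of the proposed integration pattern (against this PR's branch)."""

    def __init__(self, max_iter=100):
        self.max_iter = max_iter

    def fit(self, X, y):
        # notify callbacks that fitting starts (or use self._validate_data,
        # which would call _fit_callbacks internally per the PR description)
        self._fit_callbacks(X, y)
        for n_iter in range(self.max_iter):
            coef, intercept = self._solver_step(X, y)  # placeholder update
            self._eval_callbacks(n_iter=n_iter, coef=coef, intercept=intercept)
        return self
```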
For instance, some callbacks based on this PR are implemented in the sklearn-callbacks package (see its readme for detailed examples):
Progress bars #7574 #78 #10973

Determining which callback originates from which estimator is actually non-trivial (and I haven't even started dealing with parallel computing). Currently I'm re-building a separate approximate computational graph for pipelines etc. Anyway, once it's done (in a separate package), this could be used to animate model training on a graph (similar to what dask.diagnostics does) or, say, an HTML repr of pipelines by @thomasjpfan via some jupyter widget.
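As a tiny illustration of the progress-bar use case (a sketch only; the callbacks in the sklearn-callbacks package are more elaborate), a callback could simply count `on_iter_end` calls:

```python
class ProgressCallback:
    """Sketch: coarse progress reporting based on iteration counts."""

    def on_fit_begin(self, estimator, X, y):
        self._name = type(estimator).__name__
        self._n_iter_seen = 0

    def on_iter_end(self, **kwargs):
        self._n_iter_seen += 1
        print(f"{self._name}: iteration {self._n_iter_seen} done")
```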
Monitoring convergence #14338 #8994 (comment)

Having callbacks is also quite useful to monitor model convergence, e.g. by tracking an objective or validation metric at each iteration. For now this only works for a small subset of linear models. One reason is that in the iteration loop, the solver must provide enough information to be able to reconstruct the model, which is not always the case. For instance, for linear models it would be params + coef + intercept, but even then the definition of coef varies significantly across linear models (and whether we fit_intercept or not, etc.).
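For instance, a simple convergence monitor that sidesteps full model reconstruction could just track the change in the coefficients passed via `on_iter_end`; a sketch, assuming a `coef` kwarg as in the SGD snippet above:

```python
import numpy as np


class ConvergenceMonitor:
    """Sketch: track how much the coefficients move between iterations."""

    def __init__(self):
        self.coef_changes = []

    def on_fit_begin(self, estimator, X, y):
        self._prev_coef = None

    def on_iter_end(self, **kwargs):
        coef = kwargs.get("coef")
        if coef is None:
            return
        if self._prev_coef is not None:
            self.coef_changes.append(float(np.linalg.norm(coef - self._prev_coef)))
        self._prev_coef = np.asarray(coef).copy()
```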
Early stopping #10973

The idea is that callbacks should be able to interrupt training, e.g. because some evaluation metric no longer decreases on a validation set (cf. the figure above), or due to other user-defined reasons. This part is not yet included in this PR.
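One possible shape for such a callback, purely hypothetical since this part is not in the PR and callback return values are currently ignored: track a monitored loss and expose a flag that a solver could consult.

```python
class EarlyStoppingCallback:
    """Hypothetical sketch: flag when a monitored loss stops improving.

    Since callback return values are currently ignored, the solver would have
    to consult ``should_stop`` itself; this is not part of this PR.
    """

    def __init__(self, patience=5, tol=1e-4):
        self.patience = patience
        self.tol = tol

    def on_fit_begin(self, estimator, X, y):
        self.best_loss = float("inf")
        self.n_bad_iters = 0
        self.should_stop = False

    def on_iter_end(self, **kwargs):
        loss = kwargs.get("loss")  # assumed kwarg, not documented in this PR
        if loss is None:
            return
        if loss < self.best_loss - self.tol:
            self.best_loss = loss
            self.n_bad_iters = 0
        else:
            self.n_bad_iters += 1
        self.should_stop = self.n_bad_iters >= self.patience
```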
TODO