📉☎️ Validation Loss Training Callback #1169
Conversation
Also move optimizer & scheduler functionality into a callback
src/pykeen/training/callbacks.py
Outdated
# TODO: where to get these from?
label_smoothing = 0.0
training_data_loader_kwargs = dict(sampler=None)
these are not attributes of the training loop, but only present as local variables in the `_train` method
cf. 814d5c2
src/pykeen/training/callbacks.py
Outdated
# TODO: this should be num_instances rather than num_triples; also for cpu, we may want to reduce this
batch_size=self.triples_factory.num_triples,
This may cause OOM kills on CPU for large datasets. It would be better to derive an upper bound from the training batch size (e.g., `4 * batch_size` or similar).
Performance-wise, a too-large initial value will only affect the runtime of the first call, since later calls re-use the previously found value.
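The suggested upper bound could be sketched like this (hypothetical helper name, not part of PyKEEN):

```python
def initial_eval_batch_size(num_instances: int, train_batch_size: int, factor: int = 4) -> int:
    """Cap the initial evaluation batch size at a multiple of the training batch size.

    An overly large start value only slows down the *first* call (the searched
    value is re-used afterwards), but on CPU it can trigger OOM kills for large
    datasets, hence the cap.
    """
    return min(num_instances, factor * train_batch_size)
```

For example, with a training batch size of 256 and a million instances, the search would start at 1024 instead of one million.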
cf. 7d95654
@@ -38,3 +38,15 @@ def test_batch_size(self):
        ),
    )
    assert {c.kwargs.get("batch_size", None) for c in mock_evaluate.call_args_list} != {None}

# TODO: more tests
most callbacks seem to lack tests 😕
src/pykeen/training/callbacks.py
Outdated
class ValidationLossTrainingCallback(TrainingCallback):
    """Calculate loss on a development set."""
can we get an end-to-end example usage in the docstring, please?
src/pykeen/training/callbacks.py
Outdated
class ValidationLossTrainingCallback(TrainingCallback):
    """
    Calculate loss on a development set.
validation set?
renamed to "evaluation"; I want to highlight that we do not need a single validation set, but could also use multiple callbacks with different evaluation sets (and potentially also different frequencies, e.g., to have a small validation each step, and a bigger one every n-th step)
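The multiple-frequencies idea could look roughly like this plain-Python sketch (class name and hook are illustrative, not the actual PyKEEN API):

```python
class PeriodicEvaluationCallback:
    """Toy callback that only evaluates every `frequency`-th epoch."""

    def __init__(self, name: str, frequency: int = 1):
        self.name = name
        self.frequency = frequency
        self.evaluated_epochs = []

    def post_epoch(self, epoch: int) -> None:
        # skip epochs that are not a multiple of the frequency
        if epoch % self.frequency == 0:
            self.evaluated_epochs.append(epoch)


# a small validation set every epoch, a bigger one every 5th epoch
callbacks = [
    PeriodicEvaluationCallback("small", frequency=1),
    PeriodicEvaluationCallback("big", frequency=5),
]
for epoch in range(1, 11):
    for cb in callbacks:
        cb.post_epoch(epoch)
```

Each callback would hold its own evaluation triples factory, so the "small" and "big" sets stay independent.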
        callback_kwargs=dict(triples_factory=dataset.validation),
    ),
    result_tracker="console",
)
can you please explain how to get the validation losses after the fact as a list
I guess someone might (1) use this in combination with a result tracker, or (2) want to make their own charts or something
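For case (1), a minimal in-memory tracker could look like this standalone sketch (hypothetical class; PyKEEN's actual result trackers expose a similar `log_metrics`-style interface, but this does not depend on them):

```python
class InMemoryTracker:
    """Hypothetical tracker that keeps every logged metric in a plain dict,
    so e.g. validation losses can be read back as a list after training."""

    def __init__(self):
        self.history = {}

    def log_metrics(self, metrics, step=None):
        # append (step, value) pairs per metric key
        for key, value in metrics.items():
            self.history.setdefault(key, []).append((step, value))


tracker = InMemoryTracker()
tracker.log_metrics({"validation.loss": 0.9}, step=1)
tracker.log_metrics({"validation.loss": 0.7}, step=2)
losses = [value for _, value in tracker.history["validation.loss"]]
```

After training, `losses` is the plain list someone would feed into their own charts.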
how about adding that to notebooks? (using the `# %%` style for better VCS diffs): 49ae07e
This is great, but I think it's ideal to keep the code examples in with the code itself, so that when people are looking through the docs they see them. It's also okay if we duplicate it
Hm, the notebook is now 45 lines (tbf with formatting), and it is really messy to get the code well-formatted into rst...
The short version without the post-processing / plotting logic is already inside the docstring
src/pykeen/training/callbacks.py
Outdated
@@ -490,6 +491,8 @@ class ValidationLossTrainingCallback(TrainingCallback):
    def __init__(
        self,
        triples_factory: CoreTriplesFactory,
        callbacks: TrainingCallbackHint = None,
        callback_kwargs: TrainingCallbackKwargsHint = None,
might as well fix the name of this while we're here
or not, if it's gonna be a big diff
This extracts the parts of the training loop related to calculating the epoch loss into a function, so it can be re-used for calculating validation losses.
Example:
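A bare-bones version of such a factored-out loss computation might look like this (names are illustrative, not the actual PyKEEN code):

```python
def average_loss(loss_fn, batches):
    """Accumulate a size-weighted average loss over an iterable of batches.

    Extracted into its own function so the same code can compute both the
    training epoch loss and the loss on a validation set.
    """
    total, count = 0.0, 0
    for batch in batches:
        # loss_fn stands in for the model's per-batch loss computation
        total += float(loss_fn(batch)) * len(batch)
        count += len(batch)
    return total / count
```

The training loop and the validation callback would then both call this with their respective data loaders.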