
📉☎️ Validation Loss Training Callback #1169

Merged: 53 commits merged from validation-loss into master on Sep 24, 2023

Conversation

@mberr (Member) commented on Nov 19, 2022:

This extracts the parts of the training loop that compute the epoch loss into a function, so it can be reused for computing validation losses.
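
Roughly, the extracted helper has this shape (a sketch with hypothetical names, not the actual pykeen internals):

# Hypothetical sketch of the extracted helper; the real function in
# pykeen.training differs in detail.
def accumulate_epoch_loss(model, batches, process_batch) -> float:
    """Average the per-batch losses over one pass through ``batches``."""
    total = 0.0
    for batch in batches:
        # ``process_batch`` is assumed to return the loss tensor for one batch
        total += process_batch(model, batch).item()
    return total / len(batches)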

Example:

from pykeen.datasets import get_dataset
from pykeen.pipeline import pipeline

dataset = get_dataset(dataset="nations")
pipeline(
    dataset=dataset,
    model="mure",
    training_kwargs=dict(
        callbacks="validation-loss",
        callback_kwargs=dict(triples_factory=dataset.validation),
    ),
    result_tracker="console",
)

mberr and others added 2 commits November 19, 2022 14:50
Also move optimizer & scheduler functionality into a callback
Comment on lines 461 to 463
# TODO: where to get these from?
label_smoothing = 0.0
training_data_loader_kwargs = dict(sampler=None)
@mberr (Member, Author):

These are not attributes of the training loop; they only exist as local variables inside the _train method.

@mberr (Member, Author):

cf. 814d5c2

Comment on lines 469 to 470
# TODO: this should be num_instances rather than num_triples; also for cpu, we may want to reduce this
batch_size=self.triples_factory.num_triples,
@mberr (Member, Author):

This may cause OOM kills on CPU for large datasets. It would be better to derive an upper bound from the training batch size (e.g., 4 * batch_size or similar).

Performance-wise, an overly large initial value only affects the runtime of the first call, since later calls reuse the previously determined value.
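
A minimal sketch of that proposed cap (hypothetical helper name, not the actual pykeen code; the real fix landed in 7d95654):

def initial_eval_batch_size(training_batch_size: int, num_instances: int) -> int:
    """Derive a safe starting batch size for the validation-loss pass."""
    # Cap by a small multiple of the known-good training batch size instead
    # of the full dataset size, so large datasets cannot trigger OOM on CPU;
    # later calls reuse whatever value was previously found to fit.
    return min(4 * training_batch_size, num_instances)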

@mberr (Member, Author):

cf. 7d95654

@@ -38,3 +38,15 @@ def test_batch_size(self):
),
)
assert {c.kwargs.get("batch_size", None) for c in mock_evaluate.call_args_list} != {None}


# TODO: more tests
@mberr (Member, Author):

most callbacks seem to lack tests 😕

@mberr changed the title from "Validation Loss" to "📉☎️ Validation Loss Training Callback" on Sep 22, 2023
@cthoyt enabled auto-merge (squash) on September 23, 2023 00:19
@mberr disabled auto-merge on September 23, 2023 07:52


class ValidationLossTrainingCallback(TrainingCallback):
"""Calculate loss on a development set."""
Reviewer (Member):

Can we get an end-to-end usage example in the docstring, please?


class ValidationLossTrainingCallback(TrainingCallback):
"""
Calculate loss on a development set.
Reviewer (Member):

validation set?

@mberr (Member, Author):

Renamed to "evaluation"; I want to highlight that we do not need a single validation set, but could also use multiple callbacks with different evaluation sets (and potentially also different frequencies, e.g., a small validation every step and a bigger one every n-th step).
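
In code, that multi-callback idea might look like the following sketch (note: the frequency kwarg is an assumption here, mirroring pykeen's evaluation callback; the merged API may differ):

# Sketch only: two loss callbacks with different frequencies; in practice
# one might pass a small and a large validation triples factory instead of
# the same one twice.
from pykeen.datasets import get_dataset
from pykeen.pipeline import pipeline

dataset = get_dataset(dataset="nations")
pipeline(
    dataset=dataset,
    model="mure",
    training_kwargs=dict(
        callbacks=["validation-loss", "validation-loss"],
        callback_kwargs=[
            # cheap check every epoch
            dict(triples_factory=dataset.validation, frequency=1),
            # bigger (more expensive) check every 10th epoch
            dict(triples_factory=dataset.validation, frequency=10),
        ],
    ),
)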

callback_kwargs=dict(triples_factory=dataset.validation),
),
result_tracker="console",
)
Reviewer (Member):

Can you please explain how to get the validation losses after the fact, as a list?

Reviewer (Member):

I guess someone might (1) use this in combination with a result tracker, or (2) want to make their own charts or something.

@mberr (Member, Author):

How about adding that to a notebook (using the # %% style for better diffs under version control)? cf. 49ae07e

@mberr (Member, Author): [attached image: myplot]

Reviewer (Member):

This is great, but I think it's ideal to keep the code examples with the code itself, so people see them when looking through the docs. It's also okay if we duplicate it.

@mberr (Member, Author):

Hm, the notebook is now 45 lines (to be fair, including formatting), and it is really messy to get the code well-formatted into rst...

@mberr (Member, Author):

The short version, without the post-processing / plotting logic, is already inside the docstring.
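
For getting the losses back as a list programmatically, one option is a small in-memory result tracker; a sketch, assuming the callback logs a metric whose name contains "loss" and that pipeline accepts a tracker instance:

from pykeen.datasets import get_dataset
from pykeen.pipeline import pipeline
from pykeen.trackers import ResultTracker

class LossCollector(ResultTracker):
    """Keep every logged loss value together with its step in memory."""

    def __init__(self):
        self.losses = []

    def log_metrics(self, metrics, step=None, prefix=None):
        # keep any metric whose name mentions "loss"; the exact key the
        # callback emits is an assumption here
        for key, value in metrics.items():
            if "loss" in key:
                self.losses.append((step, value))

dataset = get_dataset(dataset="nations")
tracker = LossCollector()
pipeline(
    dataset=dataset,
    model="mure",
    training_kwargs=dict(
        callbacks="validation-loss",
        callback_kwargs=dict(triples_factory=dataset.validation),
    ),
    # passing an instance is assumed to work like the usual string hint
    result_tracker=tracker,
)
# tracker.losses now holds (step, loss) pairs, e.g., for plotting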

@@ -490,6 +491,8 @@ class ValidationLossTrainingCallback(TrainingCallback):
def __init__(
self,
triples_factory: CoreTriplesFactory,
callbacks: TrainingCallbackHint = None,
callback_kwargs: TrainingCallbackKwargsHint = None,
Reviewer (Member):

might as well fix the name of this while we're here

Reviewer (Member):

or not, if it's gonna be a big diff


@mberr enabled auto-merge (squash) on September 23, 2023 21:02
@cthoyt enabled auto-merge (squash) on September 24, 2023 07:02
@cthoyt enabled auto-merge (squash) on September 24, 2023 07:29
@cthoyt merged commit e184f97 into master on Sep 24, 2023
11 checks passed
@cthoyt deleted the validation-loss branch on September 24, 2023 07:50