Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make sure HPO default lookup is based on class for losses #111

Merged
merged 3 commits into from
Oct 19, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions docs/source/reference/losses.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,3 @@ Loss Functions
:no-heading:
:headings: --
:skip: Loss

HPO Defaults
------------
.. autodata:: losses_hpo_defaults
11 changes: 9 additions & 2 deletions src/pykeen/hpo/hpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from .samplers import get_sampler_cls
from ..datasets.base import DataSet
from ..evaluation import Evaluator, get_evaluator_cls
from ..losses import Loss, _LOSS_SUFFIX, get_loss_cls, losses_hpo_defaults
from ..losses import Loss, _LOSS_SUFFIX, get_loss_cls
from ..models import get_model_cls
from ..models.base import Model
from ..optimizers import Optimizer, get_optimizer_cls, optimizers_hpo_defaults
Expand Down Expand Up @@ -132,11 +132,18 @@ def __call__(self, trial: Trial) -> Optional[float]:
kwargs=self.model_kwargs,
kwargs_ranges=self.model_kwargs_ranges,
)

try:
loss_default_kwargs_ranges = self.loss.hpo_default
except AttributeError:
logger.warning('using a loss function with no hpo_default field: %s', self.loss)
loss_default_kwargs_ranges = {}

# 3. Loss
_loss_kwargs = _get_kwargs(
trial=trial,
prefix='loss',
default_kwargs_ranges=losses_hpo_defaults[self.loss],
default_kwargs_ranges=loss_default_kwargs_ranges,
kwargs=self.loss_kwargs,
kwargs_ranges=self.loss_kwargs_ranges,
)
Expand Down
33 changes: 14 additions & 19 deletions src/pykeen/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

"""Loss functions integrated in PyKEEN."""

from typing import Any, Mapping, Optional, Set, Type, Union
from typing import Any, ClassVar, Mapping, Optional, Set, Type, Union

import torch
from torch import nn
Expand All @@ -19,7 +19,6 @@
'MarginRankingLoss',
'MSELoss',
'BCEWithLogitsLoss',
'losses_hpo_defaults',
'get_loss_cls',
]

Expand All @@ -32,7 +31,10 @@
class Loss(nn.Module):
"""A loss function."""

synonyms: Optional[Set[str]] = None
synonyms: ClassVar[Optional[Set[str]]] = None

#: The default strategy for optimizing the model's hyper-parameters
hpo_default: ClassVar[Mapping[str, Any]] = {}


class PointwiseLoss(Loss):
Expand Down Expand Up @@ -85,6 +87,10 @@ class MarginRankingLoss(PairwiseLoss, nn.MarginRankingLoss):

synonyms = {"Pairwise Hinge Loss"}

hpo_default = dict(
margin=dict(type=int, low=0, high=3, q=1),
)


class SoftplusLoss(PointwiseLoss):
"""A loss function for the softplus."""
Expand Down Expand Up @@ -154,6 +160,11 @@ class NSSALoss(SetwiseLoss):

synonyms = {'Self-Adversarial Negative Sampling Loss', 'Negative Sampling Self-Adversarial Loss'}

hpo_default = dict(
margin=dict(type=int, low=3, high=30, q=3),
adversarial_temperature=dict(type=float, low=0.5, high=1.0),
)

def __init__(self, margin: float = 9.0, adversarial_temperature: float = 1.0, reduction: str = 'mean') -> None:
"""Initialize the NSSA loss.

Expand Down Expand Up @@ -222,22 +233,6 @@ def forward(
}


#: HPO Defaults for losses
losses_hpo_defaults: Mapping[Type[Loss], Mapping[str, Any]] = {
MarginRankingLoss: dict(
margin=dict(type=int, low=0, high=3, q=1),
),
NSSALoss: dict(
margin=dict(type=int, low=3, high=30, q=3),
adversarial_temperature=dict(type=float, low=0.5, high=1.0),
),
}
# Add empty dictionaries as defaults for all remaining losses
for cls in _LOSSES:
if cls not in losses_hpo_defaults:
losses_hpo_defaults[cls] = {}


def get_loss_cls(query: Union[None, str, Type[Loss]]) -> Type[Loss]:
"""Get the loss class."""
return get_cls(
Expand Down