Skip to content

Commit

Permalink
Merge pull request #2 from Nzteb/mse
Browse files Browse the repository at this point in the history
Add se loss
  • Loading branch information
rgemulla committed May 14, 2020
2 parents 40383af + dfd0aac commit 455a5e9
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 0 deletions.
4 changes: 4 additions & 0 deletions kge/config-default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,10 @@ train:
# true label. Computed once for each positive and once for each negative
# triple in the batch. See loss_arg for parameters.
#
# Squared error (se, all training types): Calculate squared error between
# the score of a triple and its true value (0, 1). Computed once for each
# positive and once for each negative triple in the batch.
#
# Generally, the loss values are averaged over the batch elements (e.g.,
# positive triple for 1vsAll and negative_sampling, sp- or po-pair for
# KvsAll). If multiple loss values arise for each batch element (e.g., when
Expand Down
13 changes: 13 additions & 0 deletions kge/util/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def create(config: Config):
"ce",
"kl",
"soft_margin",
"se",
],
)
if config.get("train.loss") == "bce":
Expand Down Expand Up @@ -81,6 +82,8 @@ def create(config: Config):
return MarginRankingKgeLoss(config, margin=margin)
elif config.get("train.loss") == "soft_margin":
return SoftMarginKgeLoss(config)
elif config.get("train.loss") == "se":
return SEKgeLoss(config)
else:
raise ValueError(
"invalid value train.loss={}".format(config.get("train.loss"))
Expand Down Expand Up @@ -259,3 +262,13 @@ def __call__(self, scores, labels, **kwargs):
)
else:
raise ValueError("train.type for margin ranking.")


class SEKgeLoss(KgeLoss):
def __init__(self, config, reduction="sum", **kwargs):
super().__init__(config)
self._loss = torch.nn.MSELoss(reduction=reduction, **kwargs)

def __call__(self, scores, labels, **kwargs):
labels = self._labels_as_matrix(scores, labels)
return self._loss(scores, labels)

0 comments on commit 455a5e9

Please sign in to comment.