diff --git a/torch/nn/modules/loss.py b/torch/nn/modules/loss.py index e96a38bc0691b..19873ea04833b 100644 --- a/torch/nn/modules/loss.py +++ b/torch/nn/modules/loss.py @@ -217,7 +217,15 @@ class HingeEmbeddingLoss(_Loss): The `margin` has a default value of `1`, or can be set in the constructor. """ - pass + + def __init__(self, margin=1.0, size_average=True): + super(HingeEmbeddingLoss, self).__init__() + self.margin = margin + self.size_average = size_average + + def forward(self, input, target): + return self._backend.HingeEmbeddingLoss(self.margin, + self.size_average)(input, target) class MultiLabelMarginLoss(_Loss):