Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 28 additions & 10 deletions torch/nn/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1267,6 +1267,16 @@ def binary_cross_entropy_with_logits(input, target, weight=None, size_average=Tr
return loss.sum()


def _pointwise_loss(lambd, lambd_optimized, input, target, size_average=True, reduce=True):
if target.requires_grad:
d = lambd(input, target)
if not reduce:
return d
return torch.mean(d) if size_average else torch.sum(d)
else:
return lambd_optimized(input, target, size_average, reduce)


smooth_l1_loss = _add_docstr(torch._C._nn.smooth_l1_loss, r"""
smooth_l1_loss(input, target, size_average=True) -> Variable

Expand All @@ -1276,21 +1286,29 @@ def binary_cross_entropy_with_logits(input, target, weight=None, size_average=Tr
See :class:`~torch.nn.SmoothL1Loss` for details.
""")

l1_loss = _add_docstr(torch._C._nn.l1_loss, r"""
l1_loss(input, target, size_average=True, reduce=True) -> Variable

Function that takes the mean element-wise absolute value difference.
def l1_loss(input, target, size_average=True, reduce=True):
"""
l1_loss(input, target, size_average=True, reduce=True) -> Variable

See :class:`~torch.nn.L1Loss` for details.
""")
Function that takes the mean element-wise absolute value difference.

mse_loss = _add_docstr(torch._C._nn.mse_loss, r"""
mse_loss(input, target, size_average=True, reduce=True) -> Variable
See :class:`~torch.nn.L1Loss` for details.
"""
return _pointwise_loss(lambda a, b: torch.abs(a - b), torch._C._nn.l1_loss,
input, target, size_average, reduce)

Measures the element-wise mean squared error.

See :class:`~torch.nn.MSELoss` for details.
""")
def mse_loss(input, target, size_average=True, reduce=True):
"""
mse_loss(input, target, size_average=True, reduce=True) -> Variable

Measures the element-wise mean squared error.

See :class:`~torch.nn.MSELoss` for details.
"""
return _pointwise_loss(lambda a, b: (a - b) ** 2, torch._C._nn.mse_loss,
input, target, size_average, reduce)


def margin_ranking_loss(input1, input2, target, margin=0, size_average=True):
Expand Down