-
Notifications
You must be signed in to change notification settings - Fork 839
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feat/more losses #845
Merged
Merged
Feat/more losses #845
Changes from all commits
Commits
Show all changes
15 commits
Select commit
Hold shift + click to select a range
21da3a6
added some loss functions
hrzn 9b57c22
Merge branch 'master' into feat/more-losses
hrzn 3126c2e
remove M3 loss which seems bogus
hrzn 070ecec
add unit tests for losses
hrzn 609d9fc
correct unit test
hrzn ed315d8
correct unit test
hrzn c8e00ba
correct unit test
hrzn d71fa1b
added NINF case
hrzn 34ed4d8
add optional denominator computation in MAPE loss
hrzn 33f01bd
better MAPE/MAE split
hrzn f56ce1a
simplify MAPE loss
hrzn d9af40c
simplify loss tests
hrzn f2982ea
remove a print statement
hrzn 50d4bb2
Update darts/utils/losses.py
hrzn f9dc3f1
Merge branch 'master' into feat/more-losses
hrzn File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import torch | ||
|
||
from darts.tests.base_test_class import DartsBaseTestClass | ||
from darts.utils.losses import MAELoss, MapeLoss, SmapeLoss | ||
|
||
|
||
class LossesTestCase(DartsBaseTestClass): | ||
x = torch.tensor([1.1, 2.2, 0.6345, -1.436]) | ||
y = torch.tensor([1.5, 0.5]) | ||
|
||
def helper_test_loss(self, exp_loss_val, exp_w_grad, loss_fn): | ||
W = torch.tensor([[0.1, -0.2, 0.3, -0.4], [-0.8, 0.7, -0.6, 0.5]]) | ||
W.requires_grad = True | ||
y_hat = W @ self.x | ||
lval = loss_fn(y_hat, self.y) | ||
lval.backward() | ||
|
||
self.assertTrue(torch.allclose(lval, exp_loss_val, atol=1e-3)) | ||
self.assertTrue(torch.allclose(W.grad, exp_w_grad, atol=1e-3)) | ||
|
||
def test_smape_loss(self): | ||
exp_val = torch.tensor(0.7753) | ||
exp_grad = torch.tensor( | ||
[[-0.2843, -0.5685, -0.1640, 0.3711], [-0.5859, -1.1718, -0.3380, 0.7649]] | ||
) | ||
self.helper_test_loss(exp_val, exp_grad, SmapeLoss()) | ||
|
||
def test_mape_loss(self): | ||
exp_val = torch.tensor(1.2937) | ||
exp_grad = torch.tensor( | ||
[[-0.3667, -0.7333, -0.2115, 0.4787], [-1.1000, -2.2000, -0.6345, 1.4360]] | ||
) | ||
self.helper_test_loss(exp_val, exp_grad, MapeLoss()) | ||
|
||
def test_mae_loss(self): | ||
exp_val = torch.tensor(1.0020) | ||
exp_grad = torch.tensor( | ||
[[-0.5500, -1.1000, -0.3173, 0.7180], [-0.5500, -1.1000, -0.3173, 0.7180]] | ||
) | ||
self.helper_test_loss(exp_val, exp_grad, MAELoss()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
""" | ||
PyTorch Loss Functions | ||
---------------------- | ||
""" | ||
# Inspiration: https://github.com/ElementAI/N-BEATS/blob/master/common/torch/losses.py | ||
|
||
import numpy as np | ||
import torch | ||
import torch.nn as nn | ||
|
||
|
||
def _divide_no_nan(a, b): | ||
""" | ||
a/b where the resulted NaN or Inf are replaced by 0. | ||
""" | ||
result = a / b | ||
result[result != result] = 0.0 | ||
result[result == np.inf] = 0.0 | ||
hrzn marked this conversation as resolved.
Show resolved
Hide resolved
|
||
result[result == np.NINF] = 0.0 | ||
return result | ||
|
||
|
||
class SmapeLoss(nn.Module): | ||
def __init__(self, block_denom_grad: bool = True): | ||
""" | ||
sMAPE loss as defined in https://robjhyndman.com/hyndsight/smape/ (Chen and Yang 2004) | ||
|
||
Given a time series of actual values :math:`y_t` and a time series of predicted values :math:`\\hat{y}_t` | ||
both of length :math:`T`, it is computed as | ||
|
||
.. math:: | ||
\\frac{1}{T} | ||
\\sum_{t=1}^{T}{\\frac{\\left| y_t - \\hat{y}_t \\right|} | ||
{\\left| y_t \\right| + \\left| \\hat{y}_t \\right|} }. | ||
|
||
The results of divisions yielding NaN or Inf are replaced by 0. Note that we drop the coefficient of | ||
200 usually used for computing sMAPE values, as it impacts only the magnitude of the gradients | ||
and not their direction. | ||
|
||
Parameters | ||
---------- | ||
block_denom_grad | ||
Whether to stop the gradient in the denomitator | ||
""" | ||
super().__init__() | ||
self.block_denom_grad = block_denom_grad | ||
|
||
def forward(self, inpt, tgt): | ||
num = torch.abs(tgt - inpt) | ||
denom = torch.abs(tgt) + torch.abs(inpt) | ||
if self.block_denom_grad: | ||
denom = denom.detach() | ||
return torch.mean(_divide_no_nan(num, denom)) | ||
hrzn marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
class MapeLoss(nn.Module): | ||
def __init__(self): | ||
""" | ||
MAPE loss as defined in: https://en.wikipedia.org/wiki/Mean_absolute_percentage_error. | ||
|
||
Given a time series of actual values :math:`y_t` and a time series of predicted values :math:`\\hat{y}_t` | ||
both of length :math:`T`, it is computed as | ||
|
||
.. math:: | ||
\\frac{1}{T} | ||
\\sum_{t=1}^{T}{\\frac{\\left| y_t - \\hat{y}_t \\right|}{y_t}}. | ||
|
||
The results of divisions yielding NaN or Inf are replaced by 0. Note that we drop the coefficient of | ||
100 usually used for computing MAPE values, as it impacts only the magnitude of the gradients | ||
and not their direction. | ||
""" | ||
super().__init__() | ||
|
||
def forward(self, inpt, tgt): | ||
return torch.mean(torch.abs(_divide_no_nan(tgt - inpt, tgt))) | ||
|
||
|
||
class MAELoss(nn.Module): | ||
def __init__(self): | ||
""" | ||
MAE loss as defined in: https://en.wikipedia.org/wiki/Mean_absolute_error. | ||
|
||
Given a time series of actual values :math:`y_t` and a time series of predicted values :math:`\\hat{y}_t` | ||
both of length :math:`T`, it is computed as | ||
|
||
.. math:: | ||
\\frac{1}{T} | ||
\\sum_{t=1}^{T}{\\left| y_t - \\hat{y}_t \\right|}. | ||
|
||
Note that this is the same as torch.nn.L1Loss. | ||
""" | ||
super().__init__() | ||
|
||
def forward(self, inpt, tgt): | ||
return torch.mean(torch.abs(tgt - inpt)) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice tests +1