Skip to content

Commit

Permalink
Fix FocalLoss (#259)
Browse files Browse the repository at this point in the history
* Fix positional arguments for FocalLoss
* Set loss_0 (1 - alpha) coeff in binary_focal_loss
  • Loading branch information
qubvel committed Nov 18, 2019
1 parent ba5cfca commit d71189c
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion segmentation_models/base/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,6 @@ def binary_focal_loss(gt, pr, gamma=2.0, alpha=0.25, **kwargs):
pr = backend.clip(pr, backend.epsilon(), 1.0 - backend.epsilon())

loss_1 = - gt * (alpha * backend.pow((1 - pr), gamma) * backend.log(pr))
loss_0 = - (1 - gt) * (alpha * backend.pow((pr), gamma) * backend.log(1 - pr))
loss_0 = - (1 - gt) * ((1 - alpha) * backend.pow((pr), gamma) * backend.log(1 - pr))
loss = backend.mean(loss_0 + loss_1)
return loss
6 changes: 3 additions & 3 deletions segmentation_models/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,8 @@ def __call__(self, gt, pr):
return F.categorical_focal_loss(
gt,
pr,
self.alpha,
self.gamma,
alpha=self.alpha,
gamma=self.gamma,
class_indexes=self.class_indexes,
**self.submodules
)
Expand Down Expand Up @@ -235,7 +235,7 @@ def __init__(self, alpha=0.25, gamma=2.):
self.gamma = gamma

def __call__(self, gt, pr):
return F.binary_focal_loss(gt, pr, self.alpha, self.gamma, **self.submodules)
return F.binary_focal_loss(gt, pr, alpha=self.alpha, gamma=self.gamma, **self.submodules)


# aliases
Expand Down

0 comments on commit d71189c

Please sign in to comment.