-
Notifications
You must be signed in to change notification settings - Fork 153
/
loss.py
49 lines (38 loc) · 1.64 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from __future__ import division
import numpy as np
import torch
from torch.nn import functional as F
def class_balanced_cross_entropy_loss(output, label, size_average=True, batch_average=True, void_pixels=None):
    """Define the class balanced cross entropy loss to train the network.

    A per-pixel binary cross entropy is computed from the raw network output
    (treated as logits) and summed separately over positive-label and
    negative-label pixels. Each sum is weighted by the fraction of pixels
    belonging to the *other* class, so the rarer class receives the larger
    weight::

        loss = num_neg / num_total * loss_pos + num_pos / num_total * loss_neg

    Args:
        output: Output of the network (logits; no sigmoid is applied here).
        label: Ground truth label with the same size as ``output``. Entries
            >= 0.5 are positive, all others negative.
        size_average: if True, divide the loss by the number of elements in
            ``label``. Takes precedence over ``batch_average``.
        batch_average: if True (and ``size_average`` is False), divide the
            loss by ``label.size(0)``.
        void_pixels: optional tensor marking pixels to ignore. Pixels whose
            value is > 0.5 contribute neither to the loss sums nor to the
            class counts used for the balancing weights.

    Returns:
        Tensor that evaluates the loss.

    Raises:
        ValueError: if ``output`` and ``label`` differ in size.
    """
    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if output.size() != label.size():
        raise ValueError('output and label must have the same size, got {} and {}'.format(
            tuple(output.size()), tuple(label.size())))

    labels = torch.ge(label, 0.5).float()

    # Per-pixel log-likelihood: log(sigmoid(x)) where y=1 and log(1 - sigmoid(x))
    # where y=0, written in the stable form x*(y - [x>=0]) - log(1 + exp(-|x|)).
    # The exponent ``output - 2*output*[x>=0]`` equals -|x|; it is kept in this
    # form (rather than -abs(x)) so the gradient at x == 0 stays sigmoid(0) - y.
    # log1p keeps precision when the exponential is tiny.
    output_gt_zero = torch.ge(output, 0).float()
    loss_val = torch.mul(output, labels - output_gt_zero) - torch.log1p(
        torch.exp(output - 2 * torch.mul(output, output_gt_zero)))
    loss_pix = -loss_val

    # Masks selecting which pixels enter each class sum.
    pos_mask = labels
    neg_mask = 1.0 - labels
    if void_pixels is not None:
        # 1 for pixels that take part in the loss, 0 for ignored (void) ones.
        valid = torch.le(void_pixels, 0.5).float()
        pos_mask = torch.mul(valid, pos_mask)
        neg_mask = torch.mul(valid, neg_mask)

    # Class counts are taken over the same (non-void) pixels as the loss sums,
    # so the two balancing weights always add up to 1.
    num_labels_pos = torch.sum(pos_mask)
    num_labels_neg = torch.sum(neg_mask)
    # Clamp avoids 0/0 = NaN when no pixel is valid; both sums are 0 then.
    num_total = torch.clamp(num_labels_pos + num_labels_neg, min=1.0)

    loss_pos = torch.sum(torch.mul(pos_mask, loss_pix))
    loss_neg = torch.sum(torch.mul(neg_mask, loss_pix))

    final_loss = num_labels_neg / num_total * loss_pos + num_labels_pos / num_total * loss_neg

    # Out-of-place division: safer than ``/=`` on a tensor in the autograd graph.
    if size_average:
        final_loss = final_loss / label.numel()
    elif batch_average:
        final_loss = final_loss / label.size(0)

    return final_loss