Commit 6d83f89

bug fix for retinanet with 2 classes (fg/bg)
(cherry picked from commit d1cf5e5)
1 parent b6561a1 commit 6d83f89

File tree

2 files changed, +3 -4 lines changed


mmdet/core/anchor/anchor_target.py (+1 -2)

@@ -158,8 +158,7 @@ def anchor_target_single(flat_anchors,
 
 
 def expand_binary_labels(labels, label_weights, label_channels):
-    bin_labels = labels.new_full(
-        (labels.size(0), label_channels), 0, dtype=torch.float32)
+    bin_labels = labels.new_full((labels.size(0), label_channels), 0)
     inds = torch.nonzero(labels >= 1).squeeze()
     if inds.numel() > 0:
         bin_labels[inds, labels[inds] - 1] = 1
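For reference, a minimal sketch of what the changed line now produces in the 2-class (fg/bg) setting this commit targets (a single sigmoid label channel, 0 = background, 1 = foreground; tensor values are illustrative): bin_labels keeps the integer dtype of labels instead of being forced to torch.float32 here, and the cast to the prediction dtype happens once in the loss via the target.type_as(pred) line added in losses.py below.

import torch

# Hypothetical fg/bg labels for four anchors: 0 = background, 1 = foreground.
labels = torch.tensor([0, 1, 0, 1])
label_channels = 1  # sigmoid classification with 2 classes -> one channel

# The changed line: dtype is no longer forced to float32, so bin_labels
# inherits the integer dtype of labels.
bin_labels = labels.new_full((labels.size(0), label_channels), 0)
inds = torch.nonzero(labels >= 1).squeeze()
if inds.numel() > 0:
    bin_labels[inds, labels[inds] - 1] = 1

print(bin_labels.dtype)       # torch.int64
print(bin_labels.squeeze(1))  # tensor([0, 1, 0, 1])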

mmdet/core/loss/losses.py (+2 -2)

@@ -10,8 +10,7 @@ def weighted_nll_loss(pred, label, weight, avg_factor=None):
     return torch.sum(raw * weight)[None] / avg_factor
 
 
-def weighted_cross_entropy(pred, label, weight, avg_factor=None,
-                           reduce=True):
+def weighted_cross_entropy(pred, label, weight, avg_factor=None, reduce=True):
     if avg_factor is None:
         avg_factor = max(torch.sum(weight > 0).float().item(), 1.)
     raw = F.cross_entropy(pred, label, reduction='none')

@@ -36,6 +35,7 @@ def sigmoid_focal_loss(pred,
                        alpha=0.25,
                        reduction='elementwise_mean'):
     pred_sigmoid = pred.sigmoid()
+    target = target.type_as(pred)
     pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
     weight = (alpha * target + (1 - alpha) * (1 - target)) * weight
     weight = weight * pt.pow(gamma)
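To see why the added target = target.type_as(pred) line matters for the 2-class (fg/bg) case, here is a self-contained sketch that mirrors the weighting lines from the hunk above. The targets arrive as integer 0/1 labels, and binary_cross_entropy_with_logits expects a floating-point target, so the cast has to happen somewhere; doing it inside the loss spares callers from pre-casting. The final reduction is not part of the diff and is only an assumed weighted sum for demonstration.

import torch
import torch.nn.functional as F

def sigmoid_focal_loss_sketch(pred, target, weight, gamma=2.0, alpha=0.25):
    # Weighting lines mirrored from the hunk above.
    pred_sigmoid = pred.sigmoid()
    target = target.type_as(pred)  # the fix: integer fg/bg labels -> pred's float dtype
    pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
    weight = (alpha * target + (1 - alpha) * (1 - target)) * weight
    weight = weight * pt.pow(gamma)
    # Assumed reduction (not shown in the diff): weighted elementwise BCE, summed.
    loss = F.binary_cross_entropy_with_logits(pred, target, reduction='none') * weight
    return loss.sum()

# With 2 classes (fg/bg) the classifier has one sigmoid output per anchor and the
# targets are the raw integer labels, which type_as now converts inside the loss.
pred = torch.randn(4, 1)
target = torch.tensor([[0], [1], [0], [1]])  # int64
weight = torch.ones(4, 1)
print(sigmoid_focal_loss_sketch(pred, target, weight))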
