-
Notifications
You must be signed in to change notification settings - Fork 88
/
loss_contrast.py
234 lines (167 loc) · 8.68 KB
/
loss_contrast.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from abc import ABC
import torch
import torch.nn as nn
import torch.nn.functional as F
from lib.loss.loss_helper import FSAuxCELoss, FSAuxRMILoss, FSCELoss
from lib.utils.tools.logger import Logger as Log
class PixelContrastLoss(nn.Module, ABC):
    """Supervised pixel-wise contrastive loss with hard/easy anchor sampling.

    For each image and each ground-truth class that covers more than
    ``max_views`` pixels, ``n_view`` pixel embeddings are sampled, mixing
    pixels the prediction got wrong ("hard") and pixels it got right ("easy").
    A supervised contrastive loss is then computed over all sampled views:
    views that share a class label are positives, all other views are
    negatives.

    Config keys read: ``contrast.temperature``, ``contrast.base_temperature``,
    ``contrast.max_samples``, ``contrast.max_views`` and, when present,
    ``loss.params.ce_ignore_index``.
    """

    def __init__(self, configer):
        super(PixelContrastLoss, self).__init__()

        self.configer = configer
        self.temperature = self.configer.get('contrast', 'temperature')
        self.base_temperature = self.configer.get('contrast', 'base_temperature')

        # Pixels with this label are never used as anchors.
        self.ignore_label = -1
        if self.configer.exists('loss', 'params') and 'ce_ignore_index' in self.configer.get('loss', 'params'):
            self.ignore_label = self.configer.get('loss', 'params')['ce_ignore_index']

        self.max_samples = self.configer.get('contrast', 'max_samples')
        self.max_views = self.configer.get('contrast', 'max_views')

    def _hard_anchor_sampling(self, X, y_hat, y):
        """Sample ``n_view`` embeddings for every eligible (image, class) pair.

        Args:
            X: embeddings of shape (B, N, D), i.e. N pixels per image.
            y_hat: ground-truth labels of shape (B, N) (``forward`` passes the
                labels here).
            y: predicted labels of shape (B, N) (``forward`` passes the
                predictions here).

        Returns:
            ``(X_, y_)`` with shapes (K, n_view, D) and (K,). K is the number
            of eligible (image, class) pairs, and ``y_`` holds their class ids
            as floats. Returns ``(None, None)`` when no pair is eligible.

        Raises:
            RuntimeError: if neither hard nor easy pixels can fill half of the
                views. The ``max_views`` eligibility filter makes this
                unreachable.
        """
        batch_size, feat_dim = X.shape[0], X.shape[-1]

        # A class is eligible in an image when it is not the ignore label and
        # covers more than max_views pixels. Every such pixel is either hard
        # or easy, so num_hard + num_easy > max_views >= n_view below.
        classes = []
        total_classes = 0
        for ii in range(batch_size):
            this_y = y_hat[ii]
            this_classes = [x for x in torch.unique(this_y)
                            if x != self.ignore_label and (this_y == x).nonzero().shape[0] > self.max_views]
            classes.append(this_classes)
            total_classes += len(this_classes)

        if total_classes == 0:
            return None, None

        # Split the global sample budget evenly across pairs, capped at max_views.
        n_view = min(self.max_samples // total_classes, self.max_views)

        # Allocate on X's device instead of hard-coding .cuda(), so the loss
        # also runs on CPU and on whichever GPU holds X.
        X_ = torch.zeros((total_classes, n_view, feat_dim), dtype=torch.float, device=X.device)
        y_ = torch.zeros(total_classes, dtype=torch.float, device=X.device)

        X_ptr = 0
        for ii in range(batch_size):
            this_y_hat = y_hat[ii]
            this_y = y[ii]
            this_classes = classes[ii]

            for cls_id in this_classes:
                # hard: labelled cls_id but predicted as something else;
                # easy: labelled cls_id and predicted correctly.
                hard_indices = ((this_y_hat == cls_id) & (this_y != cls_id)).nonzero()
                easy_indices = ((this_y_hat == cls_id) & (this_y == cls_id)).nonzero()

                num_hard = hard_indices.shape[0]
                num_easy = easy_indices.shape[0]

                # Aim for a 50/50 hard/easy split. When one side is short,
                # take all of it and fill the remaining views from the other.
                if num_hard >= n_view / 2 and num_easy >= n_view / 2:
                    num_hard_keep = n_view // 2
                    num_easy_keep = n_view - num_hard_keep
                elif num_hard >= n_view / 2:
                    num_easy_keep = num_easy
                    num_hard_keep = n_view - num_easy_keep
                elif num_easy >= n_view / 2:
                    num_hard_keep = num_hard
                    num_easy_keep = n_view - num_hard_keep
                else:
                    Log.info('this shoud be never touched! {} {} {}'.format(num_hard, num_easy, n_view))
                    raise RuntimeError(
                        'hard anchor sampling failed: num_hard={}, num_easy={}, n_view={}'.format(
                            num_hard, num_easy, n_view))

                perm = torch.randperm(num_hard)
                hard_indices = hard_indices[perm[:num_hard_keep]]
                perm = torch.randperm(num_easy)
                easy_indices = easy_indices[perm[:num_easy_keep]]
                indices = torch.cat((hard_indices, easy_indices), dim=0)

                # nonzero() yields (n, 1) indices; squeeze that singleton axis.
                X_[X_ptr, :, :] = X[ii, indices, :].squeeze(1)
                y_[X_ptr] = cls_id
                X_ptr += 1

        return X_, y_

    def _contrastive(self, feats_, labels_):
        """Compute the supervised contrastive loss over sampled views.

        Args:
            feats_: (A, V, D) embeddings, i.e. V views for each of A anchors.
            labels_: (A,) class id of each anchor.

        Returns:
            Scalar loss averaged over all A * V views.
        """
        anchor_num, n_view = feats_.shape[0], feats_.shape[1]

        labels_ = labels_.contiguous().view(-1, 1)
        # mask[i, j] == 1 when anchors i and j share a class.
        mask = torch.eq(labels_, torch.transpose(labels_, 0, 1)).float().to(feats_.device)

        # Flatten view-major: row k * A + i is view k of anchor i.
        contrast_count = n_view
        contrast_feature = torch.cat(torch.unbind(feats_, dim=1), dim=0)

        anchor_feature = contrast_feature
        anchor_count = contrast_count

        anchor_dot_contrast = torch.div(torch.matmul(anchor_feature, torch.transpose(contrast_feature, 0, 1)),
                                        self.temperature)
        # Subtract the (detached) row max for numerical stability. The
        # log-softmax-style expression below is invariant to this shift.
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # Tile the anchor-level mask to the (A*V, A*V) view-level layout.
        mask = mask.repeat(anchor_count, contrast_count)
        neg_mask = 1 - mask

        # Remove each view's similarity with itself from the positives.
        logits_mask = torch.ones_like(mask).scatter_(
            1,
            torch.arange(anchor_num * anchor_count, device=mask.device).view(-1, 1),
            0)
        mask = mask * logits_mask

        # For each positive pair the denominator is exp(positive logit) plus
        # the summed exp-logits of all negatives of that row.
        neg_logits = torch.exp(logits) * neg_mask
        neg_logits = neg_logits.sum(1, keepdim=True)

        exp_logits = torch.exp(logits)

        log_prob = logits - torch.log(exp_logits + neg_logits)

        mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)

        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.mean()

        return loss

    def forward(self, feats, labels=None, predict=None):
        """Compute the pixel contrastive loss.

        Args:
            feats: embeddings of shape (B, C, H, W).
            labels: integer ground-truth map of shape (B, H0, W0). It is
                resized to (H, W) with nearest-neighbour interpolation.
            predict: predicted labels. They are reshaped to (B, -1) and
                compared element-wise with the resized labels, so each image
                must hold H * W predictions.

        Returns:
            Scalar loss tensor. When no (image, class) pair has enough pixels
            to sample anchors from, returns a zero that is still connected to
            ``feats`` in the autograd graph.
        """
        labels = labels.unsqueeze(1).float().clone()
        labels = torch.nn.functional.interpolate(labels,
                                                 (feats.shape[2], feats.shape[3]), mode='nearest')
        labels = labels.squeeze(1).long()
        assert labels.shape[-1] == feats.shape[-1], '{} {}'.format(labels.shape, feats.shape)

        batch_size = feats.shape[0]

        labels = labels.contiguous().view(batch_size, -1)
        predict = predict.contiguous().view(batch_size, -1)
        feats = feats.permute(0, 2, 3, 1)
        feats = feats.contiguous().view(feats.shape[0], -1, feats.shape[-1])

        feats_, labels_ = self._hard_anchor_sampling(feats, labels, predict)
        if feats_ is None:
            # Previously this fell through to _contrastive(None, None) and
            # crashed. Multiply before summing so a large fp16 sum cannot
            # overflow to inf (inf * 0 == nan). The graph link keeps every
            # embedding parameter in the backward pass (presumably what
            # distributed training needs, as for the callers' 0 * loss trick).
            return (feats * 0).sum()

        loss = self._contrastive(feats_, labels_)
        return loss
class ContrastCELoss(nn.Module, ABC):
    """Segmentation loss plus an optionally weighted pixel-contrast term.

    ``preds`` must be a dict with a ``'seg'`` entry (segmentation logits) and
    an ``'embed'`` entry (pixel embeddings fed to ``PixelContrastLoss``).
    """

    def __init__(self, configer=None):
        super(ContrastCELoss, self).__init__()

        self.configer = configer

        # ignore_index is only logged in this class. PixelContrastLoss reads
        # the same config key on its own.
        ignore_index = -1
        if self.configer.exists('loss', 'params') and 'ce_ignore_index' in self.configer.get('loss', 'params'):
            ignore_index = self.configer.get('loss', 'params')['ce_ignore_index']
        Log.info('ignore_index: {}'.format(ignore_index))

        self.loss_weight = self.configer.get('contrast', 'loss_weight')
        self.use_rmi = self.configer.get('contrast', 'use_rmi')

        if self.use_rmi:
            # NOTE(review): ContrastAuxCELoss calls FSAuxRMILoss with an
            # [aux, main] list, but forward() below passes a single tensor.
            # Confirm this criterion accepts that input.
            self.seg_criterion = FSAuxRMILoss(configer=configer)
        else:
            self.seg_criterion = FSCELoss(configer=configer)

        self.contrast_criterion = PixelContrastLoss(configer=configer)

    def forward(self, preds, target, with_embed=False):
        """Combine the segmentation loss with the contrastive loss.

        Args:
            preds: dict holding ``'seg'`` logits (4-D, classes on dim 1) and
                ``'embed'`` embeddings.
            target: ground-truth label map of shape (B, H, W).
            with_embed: when truthy, add ``loss_weight * contrast loss``.
                Otherwise the contrast loss is added multiplied by 0.

        Returns:
            The combined loss (presumably a scalar tensor, depending on the
            criteria).
        """
        h, w = target.size(1), target.size(2)

        assert "seg" in preds
        assert "embed" in preds

        seg = preds['seg']
        embedding = preds['embed']

        # The segmentation loss is computed at target resolution...
        pred = F.interpolate(input=seg, size=(h, w), mode='bilinear', align_corners=True)
        loss = self.seg_criterion(pred, target)

        # ...while the hard-sample predictions stay at the logits' resolution.
        _, predict = torch.max(seg, 1)
        loss_contrast = self.contrast_criterion(embedding, target, predict)

        # Truthiness rather than `is True`, so flags such as 1 or numpy bools
        # also enable the contrast term.
        if with_embed:
            return loss + self.loss_weight * loss_contrast

        return loss + 0 * loss_contrast  # just a trick to avoid errors in distributed training
class ContrastAuxCELoss(nn.Module, ABC):
    """Main + auxiliary segmentation loss plus an optionally weighted pixel-contrast term.

    ``preds`` must be a dict with ``'seg'`` (main logits), ``'seg_aux'``
    (auxiliary logits) and ``'embed'`` (pixel embeddings fed to
    ``PixelContrastLoss``) entries.
    """

    def __init__(self, configer=None):
        super(ContrastAuxCELoss, self).__init__()

        self.configer = configer

        # ignore_index is only logged in this class. PixelContrastLoss reads
        # the same config key on its own.
        ignore_index = -1
        if self.configer.exists('loss', 'params') and 'ce_ignore_index' in self.configer.get('loss', 'params'):
            ignore_index = self.configer.get('loss', 'params')['ce_ignore_index']
        Log.info('ignore_index: {}'.format(ignore_index))

        self.loss_weight = self.configer.get('contrast', 'loss_weight')
        self.use_rmi = self.configer.get('contrast', 'use_rmi')

        # Both criteria receive an [aux, main] pair in forward().
        if self.use_rmi:
            self.seg_criterion = FSAuxRMILoss(configer=configer)
        else:
            self.seg_criterion = FSAuxCELoss(configer=configer)

        self.contrast_criterion = PixelContrastLoss(configer=configer)

    def forward(self, preds, target, with_embed=False):
        """Combine the main/auxiliary segmentation losses with the contrastive loss.

        Args:
            preds: dict holding ``'seg'`` and ``'seg_aux'`` logits (4-D,
                classes on dim 1) and ``'embed'`` embeddings.
            target: ground-truth label map of shape (B, H, W).
            with_embed: when truthy, add ``loss_weight * contrast loss``.
                Otherwise the contrast loss is added multiplied by 0.

        Returns:
            The combined loss (presumably a scalar tensor, depending on the
            criteria).
        """
        h, w = target.size(1), target.size(2)

        assert "seg" in preds
        assert "seg_aux" in preds
        assert "embed" in preds

        seg = preds['seg']
        seg_aux = preds['seg_aux']
        embedding = preds['embed']

        # Both heads are upsampled to target resolution for the segmentation loss.
        pred = F.interpolate(input=seg, size=(h, w), mode='bilinear', align_corners=True)
        pred_aux = F.interpolate(input=seg_aux, size=(h, w), mode='bilinear', align_corners=True)
        loss = self.seg_criterion([pred_aux, pred], target)

        # Hard-sample predictions come from the main head at the logits' resolution.
        _, predict = torch.max(seg, 1)
        loss_contrast = self.contrast_criterion(embedding, target, predict)

        # Truthiness rather than `is True`, so flags such as 1 or numpy bools
        # also enable the contrast term.
        if with_embed:
            return loss + self.loss_weight * loss_contrast

        return loss + 0 * loss_contrast  # just a trick to avoid errors in distributed training