-
Notifications
You must be signed in to change notification settings - Fork 58
/
model.py
164 lines (137 loc) · 7.23 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import math, time
from itertools import chain
import torch
import torch.nn.functional as F
from torch import nn
from base import BaseModel
from utils.helpers import set_trainable
from utils.losses import *
from models.decoders import *
from models.encoder import Encoder
from utils.losses import CE_loss
class CCT(BaseModel):
    """Segmentation model made of a shared encoder, a main decoder and,
    in semi-supervised mode, a set of auxiliary decoders.

    The losses are computed inside ``forward`` (see the note there), so in
    training mode the module returns losses rather than raw predictions only.

    Exactly one of ``conf['supervised']`` / ``conf['semi']`` must be truthy;
    it selects ``self.mode`` (``'supervised'`` or ``'semi'``).
    """

    def __init__(self, num_classes, conf, sup_loss=None, cons_w_unsup=None, ignore_index=None, testing=False,
                 pretrained=True, use_weak_lables=False, weakly_loss_w=0.4):
        """Build the encoder, the main decoder and (semi mode) the auxiliary decoders.

        Args:
            num_classes: number of output classes passed to every decoder.
            conf: mapping holding the model configuration. Keys always read:
                ``supervised``, ``semi``, ``un_loss`` (``"KL"``, ``"MSE"`` or
                ``"JS"``), ``supervised_w``, ``softmax_temp``, ``sup_loss``,
                ``aux_constraint``, ``aux_constraint_w``, ``confidence_th``,
                ``confidence_masking``. In semi mode the per-decoder counts
                (``vat``, ``drop``, ``cutout``, ``context_masking``,
                ``object_masking``, ``feature_drop``, ``feature_noise``) and
                their settings (``xi``, ``eps``, ``drop_rate``, ``spatial``,
                ``erase``, ``uniform_range``) are read as well.
            sup_loss: callable used as the supervised loss; its call
                signature depends on ``conf['sup_loss']`` (see ``forward``).
            cons_w_unsup: callable invoked as
                ``cons_w_unsup(epoch=..., curr_iter=...)`` to obtain the
                weight of the unsupervised loss; its ``final_w`` attribute is
                read when weak labels are used.
            ignore_index: label value forwarded to the loss functions.
            testing: when True, ``sup_loss``, ``cons_w_unsup`` and
                ``ignore_index`` may be left as ``None``.
            pretrained: forwarded to ``Encoder``.
            use_weak_lables: enable the extra weakly-supervised loss term
                (parameter name kept as-is for backward compatibility).
            weakly_loss_w: weight of the weakly-supervised loss term.

        Raises:
            AssertionError: if training arguments are missing while
                ``testing`` is False, or if not exactly one mode is selected.
            ValueError: if ``conf['un_loss']`` is not a known loss name.
        """
        if not testing:
            assert (ignore_index is not None) and (sup_loss is not None) and (cons_w_unsup is not None)

        super().__init__()
        assert int(conf['supervised']) + int(conf['semi']) == 1, 'one mode only'
        self.mode = 'supervised' if conf['supervised'] else 'semi'

        # Supervised and unsupervised losses
        self.ignore_index = ignore_index
        if conf['un_loss'] == "KL":
            self.unsuper_loss = softmax_kl_loss
        elif conf['un_loss'] == "MSE":
            self.unsuper_loss = softmax_mse_loss
        elif conf['un_loss'] == "JS":
            self.unsuper_loss = softmax_js_loss
        else:
            # The rejected key is the *unsupervised* loss.
            raise ValueError(f"Invalid unsupervised loss {conf['un_loss']}")

        self.unsup_loss_w = cons_w_unsup
        self.sup_loss_w = conf['supervised_w']
        self.softmax_temp = conf['softmax_temp']
        self.sup_loss = sup_loss
        self.sup_type = conf['sup_loss']

        # Use weak labels
        self.use_weak_lables = use_weak_lables
        self.weakly_loss_w = weakly_loss_w

        # Pair-wise loss between auxiliary predictions
        self.aux_constraint = conf['aux_constraint']
        self.aux_constraint_w = conf['aux_constraint_w']

        # Confidence masking of the unsupervised targets
        self.confidence_th = conf['confidence_th']
        self.confidence_masking = conf['confidence_masking']

        # Create the model
        self.encoder = Encoder(pretrained=pretrained)

        # The main decoder
        upscale = 8
        num_out_ch = 2048
        decoder_in_ch = num_out_ch // 4
        self.main_decoder = MainDecoder(upscale, decoder_in_ch, num_classes=num_classes)

        # The auxiliary decoders (self.mode is only ever 'supervised' or 'semi')
        if self.mode == 'semi':
            vat_decoder = [VATDecoder(upscale, decoder_in_ch, num_classes, xi=conf['xi'],
                                      eps=conf['eps']) for _ in range(conf['vat'])]
            drop_decoder = [DropOutDecoder(upscale, decoder_in_ch, num_classes,
                                           drop_rate=conf['drop_rate'], spatial_dropout=conf['spatial'])
                            for _ in range(conf['drop'])]
            cut_decoder = [CutOutDecoder(upscale, decoder_in_ch, num_classes, erase=conf['erase'])
                           for _ in range(conf['cutout'])]
            context_m_decoder = [ContextMaskingDecoder(upscale, decoder_in_ch, num_classes)
                                 for _ in range(conf['context_masking'])]
            object_masking = [ObjectMaskingDecoder(upscale, decoder_in_ch, num_classes)
                              for _ in range(conf['object_masking'])]
            feature_drop = [FeatureDropDecoder(upscale, decoder_in_ch, num_classes)
                            for _ in range(conf['feature_drop'])]
            feature_noise = [FeatureNoiseDecoder(upscale, decoder_in_ch, num_classes,
                                                 uniform_range=conf['uniform_range'])
                             for _ in range(conf['feature_noise'])]

            self.aux_decoders = nn.ModuleList([*vat_decoder, *drop_decoder, *cut_decoder,
                                               *context_m_decoder, *object_masking,
                                               *feature_drop, *feature_noise])

    def forward(self, x_l=None, target_l=None, x_ul=None, target_ul=None, curr_iter=None, epoch=None):
        """Run the model; in training mode also compute the losses.

        Args:
            x_l: labeled input batch; dimensions 2 and 3 are used as the
                spatial size (presumably (N, C, H, W) -- confirm with caller).
            target_l: targets for ``x_l``, passed to the supervised loss.
            x_ul: unlabeled input batch (semi mode only).
            target_ul: weak labels for ``x_ul``; only used when
                ``use_weak_lables`` is set.
            curr_iter: current iteration, forwarded to the loss weighting.
            epoch: current epoch, forwarded to the loss weighting.

        Returns:
            In eval mode: the main decoder output for ``x_l`` (not resized).
            In training mode: ``(total_loss, curr_losses, outputs)`` where
            ``curr_losses`` maps loss names to values and ``outputs`` maps
            ``'sup_pred'`` (and ``'unsup_pred'`` in semi mode) to predictions
            resized to the spatial size of ``x_l``.

        Raises:
            ValueError: in semi mode when no auxiliary decoder is configured.
        """
        if not self.training:
            return self.main_decoder(self.encoder(x_l))

        # We compute the losses in the forward pass to avoid problems encountered in multi-gpu

        # Forward pass the labeled examples
        input_size = (x_l.size(2), x_l.size(3))
        output_l = self.main_decoder(self.encoder(x_l))
        # Compare spatial sizes only: the channel axis holds class scores and
        # normally differs from the image channels. Resizing to an identical
        # size with bilinear/align_corners=True leaves the values unchanged,
        # so skipping it is purely a saving.
        if tuple(output_l.shape[2:]) != input_size:
            output_l = F.interpolate(output_l, size=input_size, mode='bilinear', align_corners=True)

        # Supervised loss (call signature depends on the configured loss type)
        if self.sup_type == 'CE':
            loss_sup = self.sup_loss(output_l, target_l, ignore_index=self.ignore_index,
                                     temperature=self.softmax_temp) * self.sup_loss_w
        elif self.sup_type == 'FL':
            loss_sup = self.sup_loss(output_l, target_l) * self.sup_loss_w
        else:
            loss_sup = self.sup_loss(output_l, target_l, curr_iter=curr_iter, epoch=epoch,
                                     ignore_index=self.ignore_index) * self.sup_loss_w

        # If supervised mode only, return
        if self.mode == 'supervised':
            curr_losses = {'loss_sup': loss_sup}
            outputs = {'sup_pred': output_l}
            total_loss = loss_sup
            return total_loss, curr_losses, outputs

        # Semi-supervised mode (the only other value self.mode is given in __init__)

        # Get main prediction
        x_ul = self.encoder(x_ul)
        output_ul = self.main_decoder(x_ul)

        # Get auxiliary predictions; the main prediction is detached so it
        # only acts as a target for the auxiliary decoders.
        outputs_ul = [aux_decoder(x_ul, output_ul.detach()) for aux_decoder in self.aux_decoders]
        if not outputs_ul:
            raise ValueError("Semi-supervised mode requires at least one auxiliary decoder")
        targets = F.softmax(output_ul.detach(), dim=1)

        # Compute unsupervised loss, averaged over the auxiliary decoders
        loss_unsup = sum(
            self.unsuper_loss(inputs=u, targets=targets,
                              conf_mask=self.confidence_masking, threshold=self.confidence_th,
                              use_softmax=False)
            for u in outputs_ul
        ) / len(outputs_ul)
        curr_losses = {'loss_sup': loss_sup}

        if tuple(output_ul.shape[2:]) != input_size:
            output_ul = F.interpolate(output_ul, size=input_size, mode='bilinear', align_corners=True)
        outputs = {'sup_pred': output_l, 'unsup_pred': output_ul}

        # Weight the unsupervised loss
        weight_u = self.unsup_loss_w(epoch=epoch, curr_iter=curr_iter)
        loss_unsup = loss_unsup * weight_u
        curr_losses['loss_unsup'] = loss_unsup
        total_loss = loss_unsup + loss_sup

        # In case we're using weak labels, add the weak loss term with a weight (self.weakly_loss_w)
        if self.use_weak_lables:
            weight_w = (weight_u / self.unsup_loss_w.final_w) * self.weakly_loss_w
            loss_weakly = sum(CE_loss(outp, target_ul, ignore_index=self.ignore_index)
                              for outp in outputs_ul) / len(outputs_ul)
            loss_weakly = loss_weakly * weight_w
            curr_losses['loss_weakly'] = loss_weakly
            total_loss = total_loss + loss_weakly

        # Pair-wise loss
        if self.aux_constraint:
            pair_wise = pair_wise_loss(outputs_ul) * self.aux_constraint_w
            curr_losses['pair_wise'] = pair_wise
            # Add to the returned total so the term is actually optimised.
            # Previously it was added in place to loss_unsup *after*
            # total_loss had been formed, which left total_loss unchanged
            # and inflated the logged 'loss_unsup' value.
            total_loss = total_loss + pair_wise

        return total_loss, curr_losses, outputs

    def get_backbone_params(self):
        """Return the backbone parameters as exposed by the encoder."""
        return self.encoder.get_backbone_params()

    def get_other_params(self):
        """Return an iterator over the non-backbone parameters.

        Chains the encoder's module parameters and the main decoder's
        parameters, plus the auxiliary decoders' parameters in semi mode.
        """
        param_groups = [self.encoder.get_module_params(), self.main_decoder.parameters()]
        if self.mode == 'semi':
            param_groups.append(self.aux_decoders.parameters())
        return chain(*param_groups)