-
Notifications
You must be signed in to change notification settings - Fork 0
/
training_loop.py
301 lines (269 loc) · 13.6 KB
/
training_loop.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
import copy
import os
import torch
import torch.nn as nn
import random
import numpy as np
import json
import consts
import models
from data_loader import ImagenetteDataset, Rescale, RandomCrop, ToTensor
from torch.utils.data import DataLoader
from torchvision import transforms
from augment import augment
import argparse
import wandb
# Per-channel normalization applied as the final step of the dataset transform
# pipeline (the values match the commonly used ImageNet RGB mean/std).
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
def _str_to_bool(value):
    """
    Convert a command-line boolean such as ``True``, ``false``, ``1`` or ``no`` to a bool.

    ``argparse`` with ``type=bool`` applies ``bool()`` to the raw string, so every
    non-empty value (including ``"False"``) is parsed as ``True``. This converter
    interprets the text instead.

    :param value: The raw command-line string (an actual bool is returned unchanged).
    :return: The boolean the text stands for.
    :raises argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'Expected a boolean value, got {value!r}.')


# Command-line flags controlling unsupervised MoCo pre-training.
parser = argparse.ArgumentParser(
    description='Process flags for unsupervised pre-training with MoCo.')
parser.add_argument('--pre_training_debug',
                    type=_str_to_bool,
                    default=False,
                    required=False,
                    help="Whether or not to run pre-training in debug mode. In debug mode, the model learns over "
                         "a subset of the original dataset.")
parser.add_argument('--seed',
                    type=int,
                    default=1,
                    required=False,
                    help='The seed used for random sampling.')
parser.add_argument('--pretraining_epochs',
                    type=int,
                    default=2,
                    required=False,
                    help='The number of epochs used during pre-training.')
parser.add_argument('--pretraining_learning_rate',
                    type=float,
                    default=1e-2,
                    help='The initial learning rate used during pre-training.')
# This value is handed to the SGD optimizer; the key-encoder momentum is `--m`.
parser.add_argument('--pretraining_momentum',
                    type=float,
                    default=0.9,
                    help='The momentum of the SGD optimizer used to train the encoder during pre-training.')
parser.add_argument('--pretraining_batch_size',
                    type=int,
                    default=64,
                    help='The mini-batch size used during pre-training with MoCo. Keys and queries are generated '
                         'from the entries of a mini-batch.')
parser.add_argument('--mul_for_num_of_keys',
                    type=int,
                    default=2,
                    help="The number of keys is a multiple of batch size times this value.")
parser.add_argument('--encoder_output_dim',
                    type=int,
                    default=64,
                    help='The encoder\'s output dim')
parser.add_argument('--temperature',
                    type=float,
                    default=0.07,
                    help='The temperature used in the Contrastive loss')
parser.add_argument('--m',
                    type=float,
                    default=0.999,
                    help='The momentum used to update the key\'s encoder parameters')
parser.add_argument('--use_imagenet_pretrained_encoder',
                    type=_str_to_bool,
                    default=True,
                    help="Whether or not the MoCo encoder should be previously pre-trained on Imagenet.")
# Sample run from server command line:
# srun python3 training_loop.py --pre_training_debug False --seed 2 --pretraining_epochs 100 \
# --pretraining_learning_rate 0.001 --mul_for_num_of_keys 2 --pretraining_batch_size 256
# Train function
def pre_train(encoder,
              m_encoder,
              device,
              train_loader,
              epochs=3,
              lr=0.001,
              pretraining_momentum=0.9,
              t=0.07,
              m=0.999,
              number_of_keys=3):
    """
    Pre-train `encoder` with a MoCo-style contrastive objective.

    For every mini-batch two augmented views are created. `encoder` embeds the first view (queries) and
    `m_encoder` embeds the second (keys). Each query is scored against its own key (positive) and against a
    FIFO queue of earlier keys (negatives). Only `encoder` is optimized; `m_encoder` tracks it through a
    momentum update after every optimizer step.

    :param encoder: The MoCo encoder being trained (the script passes a `models.Encoder`). It must expose
    `final_num_of_features`, the length of one flattened embedding.
    :param m_encoder: A copy of the encoder model utilized as the key (momentum) encoder.
    :param device: The `torch.device` on which MoCo pre-training occurs.
    :param train_loader: An iterable yielding mini-batches of images as tensors.
    :param epochs: The number of iterations over the entire dataset during pre-training.
    :param lr: The learning rate of the primary encoder model.
    :param pretraining_momentum: The momentum for the optimizer used while training MoCo's encoder during pre-training.
    :param t: The pre-training temperature used as part of contrastive loss in MoCo's pre-training.
    :param m: The momentum for weight transfer from the `encoder` to the `m_encoder`. The lower `m`, the higher the
    weight transfer from `encoder` to `m_encoder`.
    :param number_of_keys: The number of keys used as part of the contrastive loss objective during pre-training. The
    higher the number of keys, the more images the encoder must distinguish between when matching a query to its
    corresponding key.
    :return: The pre-trained `encoder`.
    """
    wandb.watch(encoder)

    # FIFO queue of negative keys kept as a single [C, K] tensor (one column per key), seeded with random
    # vectors. Drawing one vector per key keeps the random stream identical to a per-key list initialisation.
    queue = torch.stack(
        [torch.rand(encoder.final_num_of_features) for _ in range(number_of_keys)],
        dim=1,
    ).double().to(device)

    # NOTE(review): the MoCo paper uses a softmax cross-entropy (InfoNCE) loss; binary cross-entropy over
    # one-hot targets is kept here as written -- confirm this is intentional.
    loss_fn = nn.BCEWithLogitsLoss()
    # The optimization is done only to the encoder weights and not to the momentum encoder
    optimizer = torch.optim.SGD(encoder.parameters(), lr=lr, momentum=pretraining_momentum)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[120, 160], gamma=0.1)

    for epoch in range(epochs):
        print(f'start epoch {epoch}')
        epoch_loss = []
        epoch_acc = []
        for batch_index, minibatch in enumerate(train_loader):
            minibatch = minibatch.double()
            optimizer.zero_grad()

            # Two independently augmented views of the same images.
            x_q = augment(images=minibatch,
                          jitter_prob=0.8,
                          horizontal_flip_prob=0.5,
                          grayscale_conversion_prob=0.2,
                          gaussian_blur_prob=0.5)
            x_k = augment(images=minibatch,
                          jitter_prob=0.8,
                          horizontal_flip_prob=0.5,
                          grayscale_conversion_prob=0.2,
                          gaussian_blur_prob=0.5)

            q = encoder.forward(x_q, device=device)
            with torch.no_grad():  # No gradients to keys
                k = m_encoder.forward(x_k, device=device)
            q = torch.flatten(q, start_dim=1).to(device)  # Queries, shape [N, C]
            k = torch.flatten(k.detach(), start_dim=1).to(device)  # Keys, shape [N, C]

            # Positive logits: each query against its own key, shape [N, 1]
            l_pos = (q.unsqueeze(dim=1) @ k.unsqueeze(dim=2)).squeeze(dim=2)
            # Negative logits: each query against every queued key, shape [N, K]
            l_neg = q @ queue
            logits = torch.cat((l_pos, l_neg), dim=1)  # Logits have shape [N, K + 1]

            # The positive key always sits in column 0.
            labels = torch.zeros(logits.shape[0], dtype=torch.int64, device=device)
            one_hot_labels = torch.nn.functional.one_hot(labels, num_classes=logits.shape[1])
            loss = loss_fn(logits / t, one_hot_labels.double())

            # Track plain floats so the logged history does not hold on to autograd state.
            loss_value = loss.item()
            preds = torch.argmax(input=logits, dim=1)
            accuracy = (preds == labels).sum().item() / logits.shape[0]
            epoch_loss.append(loss_value)
            epoch_acc.append(accuracy)
            wandb.log({consts.MINI_BATCH_LOSS: loss_value,
                       consts.MINI_BATCH_ACCURACY: accuracy})
            if batch_index % 5 == 0:
                print(f'\t{consts.MINI_BATCH_INDEX} = {batch_index},\t'
                      f'{consts.MINI_BATCH_LOSS} = {loss_value},'
                      f'\t{consts.MINI_BATCH_ACCURACY} = {accuracy}')

            # SGD update query network
            loss.backward()
            optimizer.step()  # Update only encoder params and not m_encoder params

            with torch.no_grad():
                # Momentum update key network: the state-dict tensors share storage with the model,
                # so the in-place copy updates `m_encoder` itself.
                encoder_state = encoder.state_dict()
                for name, m_param in m_encoder.state_dict().items():
                    m_param.copy_(m * m_param + (1 - m) * encoder_state[name])

                # Enqueue this batch's keys and drop the oldest so the queue keeps `number_of_keys` columns.
                queue = torch.cat((queue, k.t()), dim=1)[:, -number_of_keys:]

        # The scheduler milestones are expressed in epochs, so the schedule advances once per epoch.
        scheduler.step()

        mean_loss = sum(epoch_loss) / len(epoch_loss)
        mean_acc = sum(epoch_acc) / len(epoch_acc)
        print(f'{consts.EPOCH_INDEX} #:{epoch},\t'
              f'{consts.EPOCH_LOSS}: {mean_loss},\t'
              f'{consts.EPOCH_ACCURACY}: {mean_acc}')
        wandb.log({consts.EPOCH_LOSS: mean_loss,
                   consts.EPOCH_ACCURACY: mean_acc})
    print('Finished pre-training!')
    return encoder
def set_seed(seed=42):
    """
    Seed the Python, NumPy and PyTorch (CPU and CUDA) random number generators and
    switch cuDNN to deterministic, non-benchmarking mode.

    :param seed: The integer seed used for random number generation.
    """
    seeders = (random.seed,
               np.random.seed,
               torch.manual_seed,
               torch.cuda.manual_seed_all)
    for seed_fn in seeders:
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
# Script entry point: parse flags, run MoCo pre-training and save the frozen encoder with its configuration.
if __name__ == '__main__':
    args = parser.parse_args()
    config_args = {consts.PRETRAINING_EPOCHS: args.pretraining_epochs,
                   consts.PRETRAINING_LEARNING_RATE: args.pretraining_learning_rate,
                   consts.PRETRAINING_MOMENTUM: args.pretraining_momentum,
                   consts.PRETRAINING_BATCH_SIZE: args.pretraining_batch_size,
                   consts.MUL_FOR_NUM_KEYS: args.mul_for_num_of_keys,
                   consts.ENCODER_OUTPUT_DIM: args.encoder_output_dim,
                   consts.TEMPERATURE: args.temperature,
                   consts.PRETRAINING_M: args.m,
                   consts.SEED: args.seed}
    print(f'config_args: {config_args}')
    wandb.init(project="semi_supervised_cv", entity="zbamberger", config=config_args)
    # wandb.init(project="semi_supervised_cv", entity="noambenmoshe", config=config_args)
    config = wandb.config

    # The key queue holds a whole number of mini-batches, so it is always a multiple of the batch size;
    # what still has to be checked is that it is not empty.
    number_of_keys = config[consts.MUL_FOR_NUM_KEYS] * config[consts.PRETRAINING_BATCH_SIZE]
    if number_of_keys <= 0:
        raise ValueError(f'The number of keys must be positive, got {number_of_keys}.')
    print(config)
    set_seed(args.seed)

    imagenette_dataset = ImagenetteDataset(csv_file=consts.csv_filename,
                                           root_dir=consts.image_dir,
                                           transform=transforms.Compose([
                                               Rescale(256),
                                               RandomCrop(224),
                                               ToTensor(),
                                               normalize
                                           ]),
                                           debug=args.pre_training_debug)
    train_loader = DataLoader(imagenette_dataset,
                              batch_size=config[consts.PRETRAINING_BATCH_SIZE],
                              shuffle=True)

    encoder = models.Encoder(args.encoder_output_dim,
                             pretrained=args.use_imagenet_pretrained_encoder).double()
    m_encoder = models.Encoder(args.encoder_output_dim,
                               pretrained=args.use_imagenet_pretrained_encoder).double()

    # Train model
    device = torch.device(consts.CUDA if torch.cuda.is_available() else consts.CPU)
    encoder.to(device)
    m_encoder.to(device)

    # Initialize parameters in both encoders to be the same.
    for param, m_param in zip(encoder.parameters(), m_encoder.parameters()):
        m_param.data.copy_(param.data)

    encoder = pre_train(encoder,
                        m_encoder,
                        device=device,
                        train_loader=train_loader,
                        epochs=config[consts.PRETRAINING_EPOCHS],
                        lr=config[consts.PRETRAINING_LEARNING_RATE],
                        pretraining_momentum=config[consts.PRETRAINING_MOMENTUM],
                        number_of_keys=number_of_keys,
                        t=config[consts.TEMPERATURE],
                        m=config[consts.PRETRAINING_M])

    # Freeze the encoder
    encoder.requires_grad_(False)

    # Save model state
    config_dict = {key: config[key] for key in config.keys()}
    os.makedirs(consts.SAVED_ENCODERS_DIR, exist_ok=True)

    # Every component must be a string for str.join; an int here would raise TypeError
    # only after training has finished, losing the trained model.
    main_name = "_".join(["number_of_keys",
                          str(number_of_keys),
                          consts.RESNET_50,
                          str(config[consts.PRETRAINING_EPOCHS]),
                          consts.EPOCHS,
                          str(config[consts.PRETRAINING_LEARNING_RATE]).replace(".", "_"),
                          consts.PRETRAINING_LEARNING_RATE,
                          str(config[consts.PRETRAINING_BATCH_SIZE]),
                          consts.PRETRAINING_BATCH_SIZE,
                          str(config[consts.PRETRAINING_M]).replace(".", "_"),
                          consts.PRETRAINING_M])
    file_name = main_name + consts.MODEL_FILE_ENCODING
    model_path = os.path.join(consts.SAVED_ENCODERS_DIR, file_name)
    torch.save(encoder.state_dict(), model_path)

    config_path = os.path.join(consts.SAVED_ENCODERS_DIR, main_name + consts.MODEL_CONFIGURATION_FILE_ENCODING)
    with open(config_path, 'w') as fp:
        json.dump(config_dict, fp, indent=4)
    print(f'Saved pre-trained model to {model_path}')
    print(f'Saved model configuration to {config_path}')