-
Notifications
You must be signed in to change notification settings - Fork 23
/
trainers.py
316 lines (256 loc) · 13.8 KB
/
trainers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
import warnings
from functools import partial
import numpy as np
import torch
from scipy import stats
from tqdm import tqdm
from mlutils import measures
from mlutils.measures import *
from mlutils.training import early_stopping, MultipleObjectiveTracker, eval_state, cycle_datasets, Exhauster, LongCycler
from ..utility.nn_helpers import set_random_seed
from ..utility import metrics
from ..utility.metrics import corr_stop, poisson_stop
def early_stop_trainer(model, seed, stop_function='corr_stop',
                       loss_function='PoissonLoss', epoch=0, interval=1, patience=10, max_iter=75,
                       maximize=True, tolerance=0.001, device='cuda', restore_best=True,
                       lr_init=0.005, lr_decay_factor=0.3, min_lr=0.0001, optim_batch_step=True,
                       verbose=True, lr_decay_steps=3, dataloaders=None, **kwargs):
    """
    Trains a model with early stopping, tracking several objectives along the way.

    Args:
        model: PyTorch nn module. Must accept ``model(inputs, data_key=...)`` and
            expose ``regularizer(data_key)`` as well as ``readout``/``core``
            regularizers (used by the tracker).
        seed: random seed for reproducibility
        stop_function: early-stopping criterion; one of
            'corr_stop', 'gamma_stop', 'exp_stop', 'poisson_stop'
        loss_function: name of a loss class defined in mlutils.measures that can
            be constructed as ``Loss(per_neuron=True, avg=False)``, e.g.
            'PoissonLoss' or 'GammaLoss'
        epoch: starting epoch (useful when resuming training)
        interval, patience, max_iter, maximize, tolerance, restore_best,
            lr_decay_steps: early-stopping / LR-scheduler hyperparameters
        device: device the model resides on, e.g. 'cpu' or 'cuda:2'
        lr_init, lr_decay_factor, min_lr: learning-rate schedule parameters
        optim_batch_step: if True, accumulate gradients over one full cycle of
            all datasets before each optimizer step
        dataloaders: dict with keys 'train', 'val' and 'test'; each value maps
            data_key -> DataLoader. Falls back to kwargs['train'/'val'/'test']
            when None (backward compatibility with the old call signature).

    Returns:
        score: average test correlation of the trained model
        output: dict with the full log of every tracked objective
        model_state: the full state_dict() of the trained model
    """
    train = dataloaders["train"] if dataloaders else kwargs["train"]
    val = dataloaders["val"] if dataloaders else kwargs["val"]
    test = dataloaders["test"] if dataloaders else kwargs["test"]

    # --- begin of helper function definitions

    def model_predictions(loader, model, data_key):
        """
        Computes model predictions for the dataset selected by ``data_key``.

        Returns:
            target: ground truth, i.e. neuronal firing rates of the neurons
            output: responses as predicted by the network
        """
        target, output = torch.empty(0), torch.empty(0)
        for images, responses in loader[data_key]:
            output = torch.cat((output, model(images.to(device), data_key=data_key).detach().cpu()), dim=0)
            target = torch.cat((target, responses.detach().cpu()), dim=0)
        return target.numpy(), output.numpy()

    # --- all early stopping conditions

    def corr_stop(model, loader=None, avg=True):
        """
        Returns either the average correlation over all neurons or the
        per-neuron correlations. Called by early stopping and by the final
        model-performance evaluation.
        """
        loader = val if loader is None else loader
        n_neurons, correlations_sum = 0, 0
        if not avg:
            all_correlations = np.array([])
        for data_key in loader:
            with eval_state(model):
                target, output = model_predictions(loader, model, data_key)
            ret = corr(target, output, axis=0)
            if np.any(np.isnan(ret)):
                warnings.warn('{}% NaNs '.format(np.isnan(ret).mean() * 100))
            ret[np.isnan(ret)] = 0  # treat NaN correlations (e.g. silent neurons) as 0
            if not avg:
                all_correlations = np.append(all_correlations, ret)
            else:
                n_neurons += output.shape[1]
                correlations_sum += ret.sum()
        return correlations_sum / n_neurons if avg else all_correlations

    def gamma_stop(model):
        """Mean gamma negative log-likelihood (in bits) on the validation sets."""
        # BUGFIX: model_predictions requires a data_key; the previous version
        # called it without one and crashed. Iterate over all validation sets.
        losses = np.array([])
        with eval_state(model):
            for data_key in val:
                target, output = model_predictions(val, model, data_key)
                ret = -stats.gamma.logpdf(target + 1e-7, output + 0.5).mean(axis=1) / np.log(2)
                if np.any(np.isnan(ret)):
                    warnings.warn(' {}% NaNs '.format(np.isnan(ret).mean() * 100))
                ret[np.isnan(ret)] = 0
                losses = np.append(losses, ret)
        return losses.mean()

    def exp_stop(model, bias=1e-12, target_bias=1e-7):
        """Mean exponential negative log-likelihood (in bits) on the validation sets."""
        # BUGFIX: model_predictions requires a data_key; iterate over all
        # validation sets (the previous version passed none and crashed).
        losses = np.array([])
        with eval_state(model):
            for data_key in val:
                target, output = model_predictions(val, model, data_key)
                target = target + target_bias  # small biases avoid log(0)/division by 0
                output = output + bias
                ret = (target / output + np.log(output)).mean(axis=1) / np.log(2)
                if np.any(np.isnan(ret)):
                    warnings.warn(' {}% NaNs '.format(np.isnan(ret).mean() * 100))
                ret[np.isnan(ret)] = 0
                losses = np.append(losses, ret)
        return losses.mean()

    def poisson_stop(model, loader=None, avg=False):
        """Summed (or neuron-averaged) Poisson loss on the given loader."""
        poisson_losses = np.array([])
        loader = val if loader is None else loader
        n_neurons = 0
        for data_key in loader:
            with eval_state(model):
                target, output = model_predictions(loader, model, data_key)
            ret = output - target * np.log(output + 1e-12)
            if np.any(np.isnan(ret)):
                warnings.warn(' {}% NaNs '.format(np.isnan(ret).mean() * 100))
            poisson_losses = np.append(poisson_losses, np.nansum(ret, 0))
            n_neurons += output.shape[1]
        return poisson_losses.sum() / n_neurons if avg else poisson_losses.sum()

    def readout_regularizer_stop(model):
        """Total readout regularizer value summed over all validation datasets."""
        ret = 0
        with eval_state(model):
            for data_key in val:
                ret += model.readout.regularizer(data_key).detach().cpu().numpy()
        return ret

    def core_regularizer_stop(model):
        """Core regularizer value (0 when the core has no regularizer)."""
        with eval_state(model):
            if model.core.regularizer():
                return model.core.regularizer().detach().cpu().numpy()
            else:
                return 0

    def full_objective(model, data_key, inputs, targets, **kwargs):
        """
        Computes the training loss for the model and prespecified criterion.

        The loss is summed over neurons and batches and scaled by
        sqrt(dataset size / batch size) to account for batch noise.

        Args:
            inputs: i.e. images
            targets: neuronal responses that the model should predict

        Returns: scaled training loss plus the model regularizer
        """
        m = len(train[data_key].dataset)
        k = inputs.shape[0]
        return np.sqrt(m / k) * criterion(model(inputs.to(device), data_key=data_key, **kwargs),
                                          targets.to(device)).sum() \
            + model.regularizer(data_key)

    def run(model, full_objective, optimizer, scheduler, stop_closure, train_loader,
            epoch, interval, patience, max_iter, maximize, tolerance,
            restore_best, tracker, optim_step_count, lr_decay_steps):
        for epoch, val_obj in early_stopping(model, stop_closure,
                                             interval=interval, patience=patience,
                                             start=epoch, max_iter=max_iter, maximize=maximize,
                                             tolerance=tolerance, restore_best=restore_best,
                                             tracker=tracker, scheduler=scheduler,
                                             lr_decay_steps=lr_decay_steps):
            optimizer.zero_grad()
            # report the current entry of every tracked objective
            if verbose:
                for key in tracker.log.keys():
                    print(key, tracker.log[key][-1])
            # main training loop over batches, cycling through all datasets
            for batch_no, (data_key, data) in tqdm(enumerate(LongCycler(train_loader)),
                                                   desc='Epoch {}'.format(epoch)):
                loss = full_objective(model, data_key, *data)
                # BUGFIX: backward() must run before the optimizer step so the
                # current batch contributes its gradient to the step it
                # triggers; the previous version stepped first and
                # backpropagated afterwards.
                loss.backward()
                if (batch_no + 1) % optim_step_count == 0:
                    optimizer.step()
                    optimizer.zero_grad()
        # end of training
        return model, epoch

    # --- model setup
    set_random_seed(seed)
    model.to(device)
    model.train()

    # Build the training criterion by name from mlutils.measures instead of
    # eval() — identical result for valid names, no arbitrary code execution.
    # Only Poisson-style losses are expected to accept these extra arguments.
    criterion = getattr(measures, loss_function)(per_neuron=True, avg=False)

    # Resolve the stopping criterion from the local helpers via an explicit
    # dispatch table (previously eval(), which is unsafe).
    stop_closures = {'corr_stop': corr_stop, 'gamma_stop': gamma_stop,
                     'exp_stop': exp_stop, 'poisson_stop': poisson_stop}
    stop_closure = stop_closures[stop_function]

    tracker = MultipleObjectiveTracker(correlation=partial(corr_stop, model),
                                       poisson_loss=partial(poisson_stop, model),
                                       poisson_loss_val=partial(poisson_stop, model, val),
                                       readout_l1=partial(readout_regularizer_stop, model),
                                       core_regularizer=partial(core_regularizer_stop, model))

    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(trainable_params, lr=lr_init)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           mode='max' if maximize else 'min',
                                                           factor=lr_decay_factor,
                                                           patience=patience,
                                                           threshold=tolerance,
                                                           min_lr=min_lr,
                                                           verbose=verbose,
                                                           threshold_mode='abs',
                                                           )

    # accumulate gradients over one full pass through all datasets if requested
    optim_step_count = len(train.keys()) if optim_batch_step else 1

    model, epoch = run(model=model,
                       full_objective=full_objective,
                       optimizer=optimizer,
                       scheduler=scheduler,
                       stop_closure=stop_closure,
                       train_loader=train,
                       epoch=epoch,
                       interval=interval,
                       patience=patience,
                       max_iter=max_iter,
                       lr_decay_steps=lr_decay_steps,
                       maximize=maximize,
                       tolerance=tolerance,
                       restore_best=restore_best,
                       tracker=tracker,
                       optim_step_count=optim_step_count)

    model.eval()
    tracker.finalize()

    # compute average test correlations as the score
    avg_corr = corr_stop(model, test, avg=True)

    # return the whole tracker output as a dict
    output = {k: v for k, v in tracker.log.items()}
    return avg_corr, output, model.state_dict()
def standard_early_stop_trainer(model, trainloaders, valloaders, testloaders,
                                loss_function='PoissonLoss', stop_function='corr_stop',
                                maximize=True, init_lr=0.005, device='cuda'):
    """
    Simple early-stopping trainer with fixed schedule hyperparameters.

    Args:
        model: PyTorch nn module; must accept ``model(inputs, data_key)`` and
            expose ``regularizer(data_key)``
        trainloaders, valloaders, testloaders: dicts mapping
            data_key -> DataLoader
        loss_function: name of a loss class in mlutils.measures, constructed
            as ``Loss(per_neuron=False, avg=True)``
        stop_function: name of a metric function in ``..utility.metrics`` used
            as the early-stopping objective
        maximize: whether the stopping objective should be maximized
        init_lr: initial Adam learning rate
        device: device string forwarded to the metric functions

    Returns:
        avg_test_corr: average test correlation of the trained model
        output: dict with the full log of every tracked objective
        model_state: the full state_dict() of the trained model
    """

    def full_objective(model, data_key, inputs, targets):
        """Criterion on one batch plus the model's regularizer for that dataset."""
        return criterion(model(inputs, data_key), targets) + model.regularizer(data_key)

    model.train()
    criterion = getattr(measures, loss_function)(per_neuron=False, avg=True)
    stop_closure = partial(getattr(metrics, stop_function), model, valloaders, device=device)

    n_iterations = len(LongCycler(trainloaders))
    print("Training with learning rate {}".format(init_lr))

    optimizer = torch.optim.Adam(model.parameters(), lr=init_lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max' if maximize else 'min',
                                                           factor=0.3, patience=5, threshold=0.001,
                                                           min_lr=0.0001, verbose=True,
                                                           threshold_mode='abs')

    # number of iterations over which gradients are accumulated:
    # one full cycle through all datasets
    optim_step_count = len(trainloaders.keys())

    # trackers recorded at every early-stopping evaluation
    tracker = MultipleObjectiveTracker(
        correlation=partial(corr_stop, model, valloaders, device=device),
        poisson_loss=partial(poisson_stop, model, valloaders, device=device))

    # train over epochs
    for epoch, val_obj in early_stopping(model, stop_closure, interval=1, patience=5,
                                         start=0, max_iter=100,
                                         # BUGFIX: honor the maximize argument; it used to
                                         # be hard-coded to True regardless of the parameter
                                         maximize=maximize,
                                         tolerance=1e-6, restore_best=True, tracker=tracker,
                                         scheduler=scheduler, lr_decay_steps=3):
        optimizer.zero_grad()
        # train over batches, cycling through all datasets
        for batch_no, (data_key, data) in tqdm(enumerate(LongCycler(trainloaders)), total=n_iterations,
                                               desc="Epoch {}".format(epoch), disable=False):
            loss = full_objective(model, data_key, *data)
            loss.backward()
            if (batch_no + 1) % optim_step_count == 0:
                optimizer.step()
                optimizer.zero_grad()

    # compute the average test correlation as the score
    # (the previous version also ran a full validation pass whose result was
    # discarded; that wasted work is removed)
    avg_test_corr = corr_stop(model, testloaders, device=device)

    # return the whole tracker output as a dict
    output = {k: v for k, v in tracker.log.items()}
    return avg_test_corr, output, model.state_dict()