-
Notifications
You must be signed in to change notification settings - Fork 4
/
t2t_trainer.py
300 lines (265 loc) · 16.2 KB
/
t2t_trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
'''
This is a customized trainer for T5-like model training.
In this class, the training loop is customized for more flexibility and control.
'''
import math
import os
import sys
import warnings
import tensorflow as tf
from tqdm import tqdm
from sklearn.metrics import accuracy_score, classification_report
import numpy as np
from keras import backend as K
from ttt.utils import add_filehandler_for_logger, get_existing_cks
from tensorboardX import SummaryWriter
# for translation evaluation from: https://github.com/mjpost/sacrebleu
# which is also used in the original T5 paper
import sacrebleu
from .utils import write_args_enhance, save_ck, dictionize_t2t_dataset, set_seed
class T2TTrainer():
    """Custom training loop for T5-like text-to-text models under a ``tf.distribute`` strategy.

    The trainer owns the learning-rate schedule, periodic evaluation (accuracy or
    BLEU), patience-based early stopping, checkpointing through ``save_ck`` and
    optional TensorBoard logging.
    """

    # Schedulers that ramp the learning rate up during warmup. The historical
    # misspelling "warmupcostant" is still accepted so existing configs keep working.
    _WARMUP_SCHEDULERS = ("warmuplinear", "warmupconstant", "warmupcostant")
    # Schedulers that decay the learning rate linearly once warmup is over.
    _DECAY_SCHEDULERS = ("warmuplinear", "constantlinear")
    # Shuffle buffer used only when the training-set size cannot be determined.
    _FALLBACK_SHUFFLE_BUFFER = 10000

    def __init__(self, args, logger):
        """Read the training options from ``args`` and set up early-stopping state.

        Args:
            args: namespace-like object. Reads ``eval_on``, ``patience``,
                ``scheduler`` and ``learning_rate`` (falling back to ``lr``);
                optionally ``use_tb`` and ``output_folder``. ``args.best`` is
                written with the initial best score.
            logger: logger used for all progress messages.
        """
        self.eval_on = args.eval_on
        assert self.eval_on in ["acc",
                                "bleu"], "now t2t training only supports --eval_on acc, bleu, only works when --do_eval=True"
        self.patience = args.patience
        self.wait = 0
        self.logger = logger
        self.args = args

        arg_values = vars(self.args)
        self.use_tb = arg_values.get('use_tb', False)
        self._tb_writer = None
        if self.use_tb:
            self._tb_writer = SummaryWriter(log_dir=arg_values.get('output_folder', "runs"))

        self.scheduler = args.scheduler
        if "learning_rate" in arg_values:
            self.lr_to_reach = args.learning_rate
        else:
            self.lr_to_reach = args.lr

        # Lower-is-better metrics start at +inf, higher-is-better ones at -inf.
        # math.inf is used because the np.Inf alias no longer exists in NumPy 2.
        lower_is_better = self.args.eval_on in ("loss", "perplexity")
        self.args.best = math.inf if lower_is_better else -math.inf
        self.best = self.args.best

    @staticmethod
    def _lr_at_step(scheduler, global_step, warmup_steps, total_steps, peak_lr):
        """Return the learning rate to use for ``global_step`` (0-based), or ``None``.

        ``None`` means the scheduler never changes the optimizer's learning rate
        ("constant" and any unrecognised name). The value is computed in closed
        form, so it does not divide by zero when ``warmup_steps`` is 0 and never
        overshoots ``peak_lr``.
        """
        warms_up = scheduler in T2TTrainer._WARMUP_SCHEDULERS
        decays = scheduler in T2TTrainer._DECAY_SCHEDULERS
        if not (warms_up or decays):
            return None
        if global_step < warmup_steps:
            # warmup_steps > 0 is guaranteed here because global_step >= 0
            if warms_up:
                return peak_lr * (global_step + 1) / warmup_steps
            return peak_lr
        if decays:
            remaining = max(total_steps - global_step, 0)
            decay_span = max(total_steps - warmup_steps, 1)
            return peak_lr * remaining / decay_span
        return peak_lr

    @staticmethod
    def _count_examples(dataset):
        """Return ``dataset.num_examples`` when present, else its tf.data cardinality."""
        if hasattr(dataset, "num_examples"):
            return dataset.num_examples
        return tf.data.experimental.cardinality(dataset).numpy()

    def _report_external_eval(self, eval_dict, tag, steps):
        """Log the result of a user supplied ``evaluate_fn``.

        Scalars found under ``eval_dict["eval_scores"]`` are written to
        TensorBoard when a writer exists. Returns True when
        ``eval_dict["is_early_stop"]`` is truthy, i.e. training should stop.
        """
        eval_dict = eval_dict or {}
        if self._tb_writer:
            for key, value in eval_dict.get("eval_scores", {}).items():
                self._tb_writer.add_scalar(f"eval_{key}_{tag}", value, steps)
        if eval_dict.get("is_early_stop"):
            where = tag.replace("_", " ")
            self.logger.info(f"run out of patience at {where} = {steps}, early stop")
            return True
        return False

    def train(self, model, strategy, tokenizer, inputs=None, train_dataset=None, eval_dataset=None, evaluate_fn=None, verbose=False):
        """Train ``model`` with a custom loop and optionally evaluate along the way.

        Args:
            model: model called with ``inputs``, ``attention_mask``,
                ``decoder_attention_mask`` and ``labels``; must also provide
                ``generate`` when the built-in evaluation is used.
            strategy: ``tf.distribute`` strategy used to distribute the train step.
            tokenizer: used to decode predictions/labels and passed to ``save_ck``.
            inputs: deprecated. Dict (or tuple converted by
                ``dictionize_t2t_dataset``) with ``x_train``/``y_train`` and, when
                ``args.do_eval`` is set, ``x_eval``/``y_eval``.
            train_dataset: unbatched dataset of ``(x, y)`` dict pairs; required
                when ``inputs`` is not given.
            eval_dataset: unbatched evaluation dataset; required when
                ``args.do_eval`` is set and ``inputs`` is not given.
            evaluate_fn: optional callable replacing the built-in evaluation. It
                may return a dict with ``eval_scores`` and ``is_early_stop``.
            verbose: when True the model summary is written to the logger.

        Training ends after ``args.num_epochs_train`` epochs or as soon as early
        stopping triggers; in both cases the method returns normally.
        """
        if inputs is None:
            assert train_dataset is not None, "you have to pass either inputs or train_dataset"
        else:
            warnings.warn(
                "Passing `inputs` as a keyword argument is deprecated. Use train_dataset and eval_dataset instead.",
                FutureWarning,
            )
            if isinstance(inputs, tuple):
                inputs = dictionize_t2t_dataset(*inputs)

        if inputs is not None:
            x_train, y_train = inputs["x_train"], inputs["y_train"]
            num_train_examples = len(inputs["y_train"]["target_input_ids"])
            train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
        else:
            num_train_examples = self._count_examples(train_dataset)

        self.logger.info(f"set random seed for everything with {self.args.seed}")
        set_seed(self.args.seed)

        global_batch_size = self.args.per_device_train_batch_size * strategy.num_replicas_in_sync
        # The seed belongs in `seed=`; it used to be passed as the buffer size, which
        # gave a tiny shuffle window and is rejected by tf.data when the seed is 0.
        if num_train_examples > 0:
            shuffle_buffer = num_train_examples
        else:
            shuffle_buffer = self._FALLBACK_SHUFFLE_BUFFER
        train_dataset = train_dataset.shuffle(buffer_size=shuffle_buffer, seed=self.args.seed).batch(global_batch_size)
        train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
        # The evaluation dataset is deliberately NOT distributed: the original author
        # noted exceptions on TPUs when it was wrapped with experimental_distribute_dataset.

        train_length = math.ceil(num_train_examples / global_batch_size)
        self.steps_per_epoch = train_length

        eval_steps = None
        if self.args.do_eval:
            if inputs is not None:
                assert "x_eval" in inputs and "y_eval" in inputs, "do_eval=True, and no validation data is found"
                x_val, y_val = inputs["x_eval"], inputs["y_eval"]
                eval_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
                eval_num_examples = len(inputs["y_eval"]["target_input_ids"])
            else:
                assert eval_dataset is not None, "do_eval=True, and no validation data is found"
                # count before batching so the number refers to examples, not batches
                eval_num_examples = self._count_examples(eval_dataset)
            eval_dataset = eval_dataset.batch(self.args.eval_batch_size)
            eval_steps = math.ceil(eval_num_examples / self.args.eval_batch_size)

        if verbose:
            # summary() prints and returns None, so route its lines to the logger
            model.summary(print_fn=self.logger.info)

        # Alternative argument names are mapped onto the ones used below.
        arg_values = vars(self.args)
        if "num_train_epochs" in arg_values:
            self.args.num_epochs_train = self.args.num_train_epochs
        if "log_and_save_steps" in arg_values:
            self.args.log_steps = self.args.log_and_save_steps

        # these are used for non-constant lr scheduler
        self.total_steps = self.steps_per_epoch * self.args.num_epochs_train
        if "warmup_steps_or_ratio" in arg_values:
            if 0 < self.args.warmup_steps_or_ratio <= 1:
                self.args.warmup_steps = int(self.total_steps * self.args.warmup_steps_or_ratio)
            else:
                self.args.warmup_steps = self.args.warmup_steps_or_ratio
        else:
            self.args.warmup_steps = int(self.total_steps * self.args.warmup_ratio)
        self.warmup_steps = self.args.warmup_steps

        write_args_enhance(self.args, logger=self.logger)

        with strategy.scope():
            # "constant*" schedulers start at the target rate, the others warm up from 0
            initial_lr = self.lr_to_reach if self.scheduler.startswith("constant") else 0.0
            optimizer = tf.keras.optimizers.Adam(learning_rate=initial_lr)
            loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
                from_logits=True, reduction=tf.keras.losses.Reduction.NONE
            )
            # `run` replaced `experimental_run_v2`; fall back for older TensorFlow
            run_on_replicas = getattr(strategy, "run", None) or strategy.experimental_run_v2

            def compute_loss(labels, predictions):
                per_example_loss = loss_fn(labels, predictions)
                return tf.nn.compute_average_loss(per_example_loss, global_batch_size=global_batch_size)

            def train_step(x_train, y_train):
                with tf.GradientTape() as tape:
                    # Fix for https://github.com/wangcongcong123/ttt/issues/2 (changed after commit `a07c58e`):
                    # the loss computed inside transformers' TFT5ForConditionalGeneration is already
                    # averaged, which failed on TPU. The loss is therefore recomputed here from the
                    # returned logits instead of using the internally calculated one.
                    outputs = model(inputs=x_train["source_input_ids"],
                                    attention_mask=x_train["source_attention_mask"],
                                    decoder_attention_mask=x_train["target_attention_mask"],
                                    labels=y_train["target_input_ids"], training=True, return_dict=True)
                    logits = outputs.logits
                    loss = compute_loss(
                        tf.reshape(y_train["target_input_ids"], (-1, y_train["target_input_ids"].shape[-1])),
                        tf.reshape(logits, (-1, logits.shape[-1])))
                gradients = tape.gradient(loss, model.trainable_variables)
                optimizer.apply_gradients(zip(gradients, model.trainable_variables))
                return loss

            @tf.function
            def distributed_train_step(x_train, y_train):
                per_replica_losses = run_on_replicas(train_step, args=(x_train, y_train,))
                return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)

            def evaluate(steps, tag="epoch"):
                """Score the model on the eval set, checkpoint it and update patience.

                Returns True when patience is exhausted and training should stop.
                """
                assert tag in ["epoch", "global_step"]
                gts = []
                preds = []
                for x_eval, y_eval in tqdm(eval_dataset, total=eval_steps, desc="evaluating..."):
                    predictions = model.generate(input_ids=x_eval["source_input_ids"],
                                                 attention_mask=x_eval["source_attention_mask"],
                                                 max_length=self.args.max_tgt_length)
                    # labels are decoded as-is (no -100 replacement is involved since no
                    # loss is computed here)
                    # NOTE(review): decode() is called with its defaults for both sides -
                    # confirm special tokens do not make predictions and references differ.
                    preds.extend(tokenizer.decode(ids) for ids in predictions)
                    gts.extend(tokenizer.decode(ids) for ids in y_eval["target_input_ids"])

                if self.eval_on == "bleu":
                    eval_score = sacrebleu.corpus_bleu(preds, [gts]).score
                else:
                    eval_score = accuracy_score(gts, preds)
                    self.logger.info(f"val_cls_report: {classification_report(gts, preds, digits=4)}")

                if self.use_tb:
                    self._tb_writer.add_scalar(f"val_{self.eval_on}_{tag}", eval_score, steps)
                self.logger.info("\n")
                self.logger.info(f"*******eval at {tag} = {steps} on validation dataset*********")
                self.logger.info(f"val_{self.eval_on}: {eval_score}")

                if self.eval_on not in ("acc", "bleu"):
                    raise ValueError("not support yet")
                if eval_score >= self.best:
                    self.wait = 0
                    self.best = eval_score
                    self.logger.info(
                        f"so far the best check point at {tag}={steps} based on eval_on {self.eval_on}")
                else:
                    self.wait += 1

                self.logger.info(f"best so far({self.eval_on}): {self.best}")
                self.logger.info(f"early stop count: {self.wait}/{self.patience}")
                # One checkpoint per evaluation. The best-score branch used to issue this
                # exact call a second time.
                # NOTE(review): that extra call also passed best_ck=False - confirm whether
                # best_ck=True was intended for the best checkpoint.
                save_ck(self.args, self.logger, model, tokenizer=tokenizer, steps=steps,
                        tag=tag, best_ck=False, from_tf=True)

                if self.wait >= self.patience:
                    self.logger.info("run out of patience, early stop")
                    return True
                return False

            def update_lr(global_step):
                # global_step is the 0-based index of the step about to be executed
                new_lr = self._lr_at_step(self.scheduler, global_step, self.warmup_steps,
                                          self.total_steps, self.lr_to_reach)
                if new_lr is not None:
                    K.set_value(optimizer.learning_rate, new_lr)

            global_step = 0
            early_exit = False
            interval_loss = 0.0
            interval_count = 0
            try:
                for epoch in tqdm(range(self.args.num_epochs_train), desc="epochs"):
                    self.logger.info(f"start training at epoch = {epoch}")
                    self.logger.info(f"global train batch size = {global_batch_size}")
                    self.logger.info(f"using learning rate scheduler: {self.scheduler}")
                    self.logger.info(
                        f"num_train_examples: {num_train_examples}, total_steps: {self.total_steps}, steps_per_epoch: {self.steps_per_epoch}")
                    if self.scheduler != "constant":
                        self.logger.info(f"warmup_steps:{self.warmup_steps}")

                    pbar = tqdm(enumerate(train_dist_dataset), total=train_length)
                    for step, (x_train, y_train) in pbar:
                        # learning rate scheduler
                        update_lr(global_step)
                        loss = distributed_train_step(x_train, y_train)
                        step_loss = loss.numpy()
                        interval_loss += step_loss
                        interval_count += 1
                        global_step += 1
                        pbar.set_description(f"training - epoch {epoch + 1}/{self.args.num_epochs_train} iter {step}: train loss {step_loss:.5f}. lr {optimizer.learning_rate.numpy():e}")

                        if self.args.log_steps != -1 and global_step % self.args.log_steps == 0:
                            if self.use_tb:
                                self._tb_writer.add_scalar("train_loss_global_step", interval_loss / interval_count,
                                                           global_step)
                                self._tb_writer.add_scalar("train_lr_global_step", optimizer.learning_rate.numpy(),
                                                           global_step)
                            if self.args.do_eval:
                                if evaluate_fn is not None and eval_dataset is not None:
                                    eval_dict = evaluate_fn(self.args, self.logger, model, tokenizer, eval_dataset, steps=global_step, tag="global_step", eval_length=eval_steps)
                                    early_exit = self._report_external_eval(eval_dict, "global_step", global_step)
                                else:
                                    early_exit = evaluate(global_step, tag="global_step")
                                if early_exit:
                                    break
                            self.logger.info(f"train loss at global_step {global_step}: {interval_loss / interval_count}")
                            interval_loss = 0.0
                            interval_count = 0
                    if early_exit:
                        break

                    # interval_count is 0 when the last step of the epoch was a logging
                    # step (the interval was just reset), so guard the division.
                    train_loss = interval_loss / interval_count if interval_count else math.nan
                    interval_loss = 0.0
                    interval_count = 0

                    if self.args.log_steps == -1:
                        # epoch-level evaluation and logging
                        if self.args.do_eval:
                            if evaluate_fn is not None and eval_dataset is not None:
                                eval_dict = evaluate_fn(self.args, self.logger, model, tokenizer, eval_dataset, steps=epoch + 1, tag="epoch", eval_length=eval_steps)
                                early_exit = self._report_external_eval(eval_dict, "epoch", epoch + 1)
                            else:
                                early_exit = evaluate(epoch + 1, tag="epoch")
                            if early_exit:
                                break
                        if self.use_tb:
                            self._tb_writer.add_scalar("train_loss_epoch", train_loss,
                                                       global_step)
                            self._tb_writer.add_scalar("train_lr_epoch", optimizer.learning_rate.numpy(), global_step)
                        self.logger.info(f"train loss at end of epoch {epoch + 1}: {train_loss}")

                    if not self.args.do_eval:
                        # without evaluation, the checkpoint at the end of each epoch must be saved here
                        save_ck(self.args, self.logger, model, tokenizer=tokenizer, steps=epoch + 1,
                                tag="epoch", best_ck=False, from_tf=True)
            finally:
                # close the TensorBoard writer on normal end, early stop and errors alike
                if self._tb_writer:
                    self._tb_writer.close()