This repository has been archived by the owner on Jul 7, 2023. It is now read-only.
/
trainer_utils.py
1288 lines (1116 loc) · 50.8 KB
/
trainer_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for trainer binary."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import operator
import os
import sys
# Dependency imports
import numpy as np
import six
# pylint: disable=redefined-builtin
from six.moves import input
from six.moves import xrange
from six.moves import zip
# pylint: enable=redefined-builtin
from tensor2tensor.data_generators import problem_hparams
from tensor2tensor.models import models # pylint: disable=unused-import
from tensor2tensor.utils import data_reader
from tensor2tensor.utils import expert_utils as eu
from tensor2tensor.utils import metrics
from tensor2tensor.utils import registry
import tensorflow as tf
from tensorflow.contrib.learn.python.learn import learn_runner
from tensorflow.python.ops import init_ops
# Number of samples to draw for an image input (in such cases as captioning)
IMAGE_DECODE_LENGTH = 100

# Command-line flags are registered at import time; FLAGS is the shared handle
# that the functions in this module read their settings from.
flags = tf.flags
FLAGS = flags.FLAGS

flags.DEFINE_bool("registry_help", False,
                  "If True, logs the contents of the registry and exits.")
flags.DEFINE_string("output_dir", "", "Base output directory for run.")
flags.DEFINE_string("model", "", "Which model to use.")
flags.DEFINE_string("hparams_set", "", "Which parameters to use.")
flags.DEFINE_string("hparams_range", "", "Parameters range.")
flags.DEFINE_string(
    "hparams", "",
    """A comma-separated list of `name=value` hyperparameter values. This flag
is used to override hyperparameter settings either when manually selecting
hyperparameters or when using Vizier. If a hyperparameter setting is
specified by this flag then it must be a valid hyperparameter name for the
model.""")
flags.DEFINE_string("problems", "", "Dash separated list of problems to "
                    "solve.")
flags.DEFINE_string("data_dir", "/tmp/data", "Directory with training data.")
flags.DEFINE_integer("train_steps", 250000,
                     "The number of steps to run training for.")
flags.DEFINE_integer("eval_steps", 10, "Number of steps in evaluation.")
flags.DEFINE_integer("keep_checkpoint_max", 20,
                     "How many recent checkpoints to keep.")
flags.DEFINE_bool("experimental_optimize_placement", False,
                  "Optimize ops placement with experimental session options.")

# Distributed training flags
flags.DEFINE_string("master", "", "Address of TensorFlow master.")
flags.DEFINE_string("schedule", "local_run",
                    "Method of tf.contrib.learn.Experiment to run.")
flags.DEFINE_bool("daisy_chain_variables", True,
                  "copy variables around in a daisy chain")
flags.DEFINE_bool("sync", False, "Sync compute on PS.")
flags.DEFINE_string("worker_job", "/job:worker", "name of worker job")
flags.DEFINE_integer("worker_gpu", 1, "How many GPUs to use.")
flags.DEFINE_integer("worker_replicas", 1, "How many workers to use.")
flags.DEFINE_integer("worker_id", 0, "Which worker task are we.")
flags.DEFINE_integer("ps_gpu", 0, "How many GPUs to use per ps.")
flags.DEFINE_string("gpu_order", "", "Optional order for daisy-chaining gpus."
                    " e.g. \"1 3 2 4\"")
flags.DEFINE_string("ps_job", "/job:ps", "name of ps job")
flags.DEFINE_integer("ps_replicas", 0, "How many ps replicas.")

# Decode flags
flags.DEFINE_bool("decode_use_last_position_only", False,
                  "In inference, use last position only for speedup.")
flags.DEFINE_bool("decode_interactive", False,
                  "Interactive local inference mode.")
flags.DEFINE_bool("decode_endless", False, "Run decoding endlessly. Temporary.")
flags.DEFINE_bool("decode_save_images", False, "Save inference input images.")
flags.DEFINE_string("decode_from_file", None, "Path to decode file")
flags.DEFINE_string("decode_to_file", None, "Path to inference output file")
flags.DEFINE_integer("decode_shards", 1, "How many shards to decode.")
flags.DEFINE_integer("decode_problem_id", 0, "Which problem to decode.")
flags.DEFINE_integer("decode_extra_length", 50, "Added decode length.")
# The adjacent help-string pieces are concatenated verbatim, so each piece
# must carry its own separating space.
flags.DEFINE_integer("decode_batch_size", 32, "Batch size for decoding. "
                     "The decodes will be written to <filename>.decodes in "
                     "format result\tinput")
flags.DEFINE_integer("decode_beam_size", 4, "The beam size for beam decoding")
flags.DEFINE_float("decode_alpha", 0.6, "Alpha for length penalty")
flags.DEFINE_bool("decode_return_beams", False,
                  "Whether to return 1 (False) or all (True) beams. The \n "
                  "output file will have the format "
                  "<beam1>\t<beam2>..\t<input>")
def make_experiment_fn(data_dir, model_name, train_steps, eval_steps):
  """Builds the `experiment_fn` callback handed to learn_runner.

  The arguments are captured once and forwarded to `create_experiment` when
  the returned function is called with an output directory.

  Args:
    data_dir: Forwarded to `create_experiment` as `data_dir`.
    model_name: Forwarded to `create_experiment` as `model_name`.
    train_steps: Forwarded to `create_experiment` as `train_steps`.
    eval_steps: Forwarded to `create_experiment` as `eval_steps`.

  Returns:
    A function taking `output_dir` and returning the result of
    `create_experiment`.
  """
  captured = dict(
      data_dir=data_dir,
      model_name=model_name,
      train_steps=train_steps,
      eval_steps=eval_steps)

  def experiment_fn(output_dir):
    return create_experiment(output_dir=output_dir, **captured)

  return experiment_fn
def create_experiment(output_dir, data_dir, model_name, train_steps,
                      eval_steps):
  """Assembles a tf.contrib.learn.Experiment for the requested model.

  Hyperparameters are created from --hparams_set, and evaluation metrics are
  created for the dash-separated problem names in --problems.

  Args:
    output_dir: Directory passed on to `create_experiment_components`.
    data_dir: Directory passed to `create_hparams` and the input functions.
    model_name: Name of the model to build.
    train_steps: Value for the Experiment's `train_steps`.
    eval_steps: Value for the Experiment's `eval_steps`.

  Returns:
    A tf.contrib.learn.Experiment with no train monitors.
  """
  hparams = create_hparams(FLAGS.hparams_set, data_dir)
  estimator, input_fns = create_experiment_components(
      hparams=hparams,
      output_dir=output_dir,
      data_dir=data_dir,
      model_name=model_name)
  train_input_fn = input_fns["train"]
  eval_input_fn = input_fns["eval"]
  problem_names = FLAGS.problems.split("-")
  eval_metrics = metrics.create_evaluation_metrics(problem_names)
  return tf.contrib.learn.Experiment(
      estimator=estimator,
      train_input_fn=train_input_fn,
      eval_input_fn=eval_input_fn,
      eval_metrics=eval_metrics,
      train_steps=train_steps,
      eval_steps=eval_steps,
      train_monitors=[])
def create_experiment_components(hparams, output_dir, data_dir, model_name):
  """Builds the Estimator together with its train and eval input functions.

  Args:
    hparams: Hyperparameters; also attached to the estimator as `hparams`.
    output_dir: Used as the model directory of the Estimator and RunConfig.
    data_dir: Directory the datasets for each mode are looked up in.
    model_name: Name of the model handed to `model_builder`.

  Returns:
    A pair `(estimator, input_fns)`, where `input_fns` is a dict with the keys
    "train" and "eval".
  """
  tf.logging.info("Creating experiment, storing model files in %s", output_dir)

  mode_keys = tf.contrib.learn.ModeKeys
  num_datashards = data_parallelism().n

  def make_input_fn(mode):
    # Both modes share every argument except the mode and its datasets.
    return get_input_fn(
        mode=mode,
        hparams=hparams,
        data_file_patterns=get_datasets_for_mode(data_dir, mode),
        num_datashards=num_datashards)

  input_fns = {}
  input_fns["train"] = make_input_fn(mode_keys.TRAIN)
  input_fns["eval"] = make_input_fn(mode_keys.EVAL)

  model_fn = model_builder(model_name, hparams=hparams)
  run_config = tf.contrib.learn.RunConfig(
      master=FLAGS.master,
      model_dir=output_dir,
      session_config=session_config(),
      keep_checkpoint_max=FLAGS.keep_checkpoint_max)
  estimator = tf.contrib.learn.Estimator(
      model_fn=model_fn, model_dir=output_dir, config=run_config)

  # Store the hparams in the estimator as well
  estimator.hparams = hparams
  return estimator, input_fns
def log_registry():
  """Logs the registry help string and exits if --registry_help is set."""
  if not FLAGS.registry_help:
    return
  tf.logging.info(registry.help_string())
  sys.exit(0)
def create_hparams(params_id, data_dir):
  """Creates the hyperparameters, applying any --hparams overrides.

  The named hyperparameter set is looked up in the registry and instantiated,
  `data_dir` is added to it, and then the `name=value` pairs from the
  --hparams flag (if any) are parsed on top of it. Finally, one
  problem-hparams entry per dash-separated name in --problems is stored in
  `hparams.problems`.

  Args:
    params_id: name of the hyperparameter set, as known to `registry.hparams`.
    data_dir: the directory containing the training data.

  Returns:
    The hyperparameters object with `data_dir` and `problems` set.
  """
  hparams_fn = registry.hparams(params_id)
  hparams = hparams_fn()
  hparams.add_hparam("data_dir", data_dir)

  # Command line flags override any of the preceding hyperparameter values.
  overrides = FLAGS.hparams
  if overrides:
    hparams = hparams.parse(overrides)

  # Add hparams for the problems
  per_problem = []
  for problem_name in FLAGS.problems.split("-"):
    per_problem.append(problem_hparams.problem_hparams(problem_name, hparams))
  hparams.problems = per_problem
  return hparams
def run(data_dir, model, output_dir, train_steps, eval_steps, schedule):
  """Runs an Estimator locally or distributed.

  This function chooses one of two paths to execute:

  1. Running locally if schedule=="local_run".
  2. Distributed training/evaluation otherwise.

  Args:
    data_dir: The directory the data can be found in.
    model: The name of the model to use.
    output_dir: The directory to store outputs in.
    train_steps: The number of steps to run training for.
    eval_steps: The number of steps to run evaluation for.
    schedule: (str) The schedule to run. The value here must
      be the name of one of Experiment's methods.
  """
  experiment_fn = make_experiment_fn(
      data_dir=data_dir,
      model_name=model,
      train_steps=train_steps,
      eval_steps=eval_steps)

  if schedule != "local_run":
    # Perform distributed training/evaluation.
    learn_runner.run(
        experiment_fn=experiment_fn, schedule=schedule, output_dir=output_dir)
    return

  # Run the local demo.
  run_locally(experiment_fn(output_dir))
def validate_flags():
  """Checks that the required flags are set and defaults --output_dir.

  Raises:
    ValueError: if --model, --problems or --schedule is empty, or if neither
      --hparams_set nor --hparams_range is given.
  """
  # Checked in order; the first empty value determines the error raised.
  required = (
      (FLAGS.model, "Must specify a model with --model."),
      (FLAGS.problems, "Must specify a set of problems with --problems."),
      (FLAGS.hparams_set or FLAGS.hparams_range,
       "Must specify either --hparams_set or --hparams_range."),
      (FLAGS.schedule, "Must specify --schedule."),
  )
  for value, message in required:
    if not value:
      raise ValueError(message)

  if FLAGS.output_dir:
    return
  FLAGS.output_dir = "/tmp/tensor2tensor"
  tf.logging.warning("It is strongly recommended to specify --output_dir. "
                     "Using default output_dir=%s.", FLAGS.output_dir)
def session_config():
  """Builds the tf.ConfigProto to use for TensorFlow sessions.

  By default the graph options request L1 optimization with function inlining
  disabled. With --experimental_optimize_placement they are replaced by
  rewrite options that enable the "pruning", "constfold" and "layout"
  optimizers together with shape inference.

  Returns:
    A tf.ConfigProto with soft placement allowed.
  """
  if FLAGS.experimental_optimize_placement:
    rewriter = tf.RewriterConfig(optimize_tensor_layout=True)
    for optimizer_name in ("pruning", "constfold", "layout"):
      rewriter.optimizers.append(optimizer_name)
    graph_options = tf.GraphOptions(
        rewrite_options=rewriter, infer_shapes=True)
  else:
    optimizer_options = tf.OptimizerOptions(
        opt_level=tf.OptimizerOptions.L1, do_function_inlining=False)
    graph_options = tf.GraphOptions(optimizer_options=optimizer_options)

  return tf.ConfigProto(
      allow_soft_placement=True, graph_options=graph_options)
def model_builder(model, hparams):
  """Returns a function to build the model.

  Args:
    model: The name of the model to use.
    hparams: The hyperparameters.

  Returns:
    A function to build the model's graph. This function is called by
    the Estimator object to construct the graph.
  """

  def initializer():
    """Returns the variable initializer named by hparams.initializer.

    Raises:
      ValueError: if hparams.initializer is not one of "orthogonal",
        "uniform", "normal_unit_scaling" or "uniform_unit_scaling".
    """
    if hparams.initializer == "orthogonal":
      return tf.orthogonal_initializer(gain=hparams.initializer_gain)
    elif hparams.initializer == "uniform":
      max_val = 0.1 * hparams.initializer_gain
      return tf.random_uniform_initializer(-max_val, max_val)
    elif hparams.initializer == "normal_unit_scaling":
      return init_ops.variance_scaling_initializer(
          hparams.initializer_gain, mode="fan_avg", distribution="normal")
    elif hparams.initializer == "uniform_unit_scaling":
      return init_ops.variance_scaling_initializer(
          hparams.initializer_gain, mode="fan_avg", distribution="uniform")
    else:
      raise ValueError("Unrecognized initializer: %s" % hparams.initializer)

  def learning_rate_decay():
    """Inverse-decay learning rate until warmup_steps, then decay.

    Returns:
      A tensor the base learning rate is multiplied by.

    Raises:
      ValueError: if hparams.learning_rate_decay_scheme is not recognized.
    """
    # The warmup length is scaled by the number of worker replicas.
    warmup_steps = tf.to_float(
        hparams.learning_rate_warmup_steps * FLAGS.worker_replicas)
    step = tf.to_float(tf.contrib.framework.get_global_step())
    # "noam" and "exp100k" define the whole schedule themselves and do not use
    # the warmup/decay switch below.
    if hparams.learning_rate_decay_scheme == "noam":
      return 5000.0 * hparams.hidden_size**-0.5 * tf.minimum(
          (step + 1) * warmup_steps**-1.5, (step + 1)**-0.5)
    elif hparams.learning_rate_decay_scheme == "exp100k":
      return 0.94**(step // 100000)

    # Warmup multiplier: equals 0.01 at step 0 and grows exponentially to 1.0
    # at step == warmup_steps.
    inv_base = tf.exp(tf.log(0.01) / warmup_steps)
    inv_decay = inv_base**(warmup_steps - step)
    if hparams.learning_rate_decay_scheme == "sqrt":
      decay = _sqrt_decay(step - warmup_steps)
    elif hparams.learning_rate_decay_scheme == "exp10k":
      decay = _exp_decay_after(step - warmup_steps, 0.9995,
                               FLAGS.train_steps - warmup_steps - 10000)
    elif hparams.learning_rate_decay_scheme == "exp50k":
      decay = _exp_decay_after(step - warmup_steps, 0.99995,
                               FLAGS.train_steps - warmup_steps - 50000)
    elif hparams.learning_rate_decay_scheme == "exp500k":
      decay = _exp_decay_after(step - warmup_steps, 0.9999955,
                               FLAGS.train_steps - warmup_steps - 500000)
    elif hparams.learning_rate_decay_scheme == "none":
      decay = tf.constant(1.0)
    else:
      raise ValueError("Unrecognized learning rate decay scheme: %s" %
                       hparams.learning_rate_decay_scheme)
    # The op name below is misspelled ("warump") in the original; it is left
    # unchanged because it is part of the graph's op names.
    return tf.cond(
        step < warmup_steps,
        lambda: inv_decay,
        lambda: decay,
        name="learning_rate_decay_warump_cond")

  def model_fn(features, targets, mode):
    """Creates the prediction, loss, and train ops.

    Args:
      features: A dictionary of tensors keyed by the feature name.
      targets: A tensor representing the labels (targets).
      mode: The execution mode, as defined in tf.contrib.learn.ModeKeys.

    Returns:
      A tuple consisting of the prediction, loss, and train_op.
    """
    # In the interactive and decode-from-file inference modes `features` is
    # converted to a features dictionary by the corresponding helper.
    if mode == tf.contrib.learn.ModeKeys.INFER and FLAGS.decode_interactive:
      features = _interactive_input_tensor_to_features_dict(features, hparams)
    if mode == tf.contrib.learn.ModeKeys.INFER and FLAGS.decode_from_file:
      features = _decode_input_tensor_to_features_dict(features, hparams)
    # A dictionary containing:
    #  - problem_choice: A Tensor containing an integer indicating which problem
    #                    was selected for this run.
    #  - predictions: A Tensor containing the model's output predictions.
    run_info = dict()
    run_info["problem_choice"] = features["problem_choice"]

    if targets is not None:
      features["targets"] = targets

    dp = data_parallelism()

    # Add input statistics for incoming features.
    with tf.name_scope("input_stats"):
      for (k, v) in six.iteritems(features):
        # ndims is None for a tensor of unknown rank; comparing None with an
        # int raises TypeError on Python 3, so such tensors are skipped.
        ndims = v.get_shape().ndims if isinstance(v, tf.Tensor) else None
        if ndims is not None and ndims > 1:
          tf.summary.scalar("%s_batch" % k, tf.shape(v)[0] // dp.n)
          tf.summary.scalar("%s_length" % k, tf.shape(v)[1])
          # Entries equal to 0 are counted as padding.
          nonpadding = tf.to_float(tf.not_equal(v, 0))
          tf.summary.scalar("%s_nonpadding_tokens" % k,
                            tf.reduce_sum(nonpadding))
          tf.summary.scalar("%s_nonpadding_fraction" % k,
                            tf.reduce_mean(nonpadding))

    tf.get_variable_scope().set_initializer(initializer())
    train = mode == tf.contrib.learn.ModeKeys.TRAIN

    # Get multi-problem logits and loss based on features["problem_choice"].
    def nth_model(n):
      """Build the model for the n-th problem, plus some added variables."""
      model_class = registry.model(model)(
          hparams, hparams.problems[n], n, dp, _ps_devices(all_workers=True))
      if mode == tf.contrib.learn.ModeKeys.INFER:
        return model_class.infer(
            features,
            beam_size=FLAGS.decode_beam_size,
            top_beams=(FLAGS.decode_beam_size
                       if FLAGS.decode_return_beams else 1),
            last_position_only=FLAGS.decode_use_last_position_only,
            alpha=FLAGS.decode_alpha,
            decode_length=FLAGS.decode_extra_length)
      # In distributed mode, we build graph for problem=0 and problem=worker_id.
      skipping_is_on = hparams.problem_choice == "distributed" and train
      problem_worker_id = FLAGS.worker_id % len(hparams.problems)
      skip_this_one = n != 0 and n % FLAGS.worker_replicas != problem_worker_id
      # On worker 0 also build graph for problems <= 1.
      # TODO(lukaszkaiser): why is this hack needed for variables init? Repair.
      skip_this_one = skip_this_one and (FLAGS.worker_id != 0 or n > 1)
      sharded_logits, training_loss, extra_loss = model_class.model_fn(
          features, train, skip=(skipping_is_on and skip_this_one))
      # Update the per-problem moving averages (0.9 old + 0.1 new) of the
      # losses. The variables are fetched with reuse=True, so they must already
      # exist in the "losses_avg" scope.
      with tf.variable_scope("losses_avg", reuse=True):
        loss_moving_avg = tf.get_variable("problem_%d/training_loss" % n)
        o1 = loss_moving_avg.assign(loss_moving_avg * 0.9 + training_loss * 0.1)
        loss_moving_avg = tf.get_variable("problem_%d/extra_loss" % n)
        o2 = loss_moving_avg.assign(loss_moving_avg * 0.9 + extra_loss * 0.1)
        loss_moving_avg = tf.get_variable("problem_%d/total_loss" % n)
        total_loss = training_loss + extra_loss
        o3 = loss_moving_avg.assign(loss_moving_avg * 0.9 + total_loss * 0.1)
      with tf.variable_scope("train_stats"):  # Count steps for this problem.
        problem_steps = tf.get_variable(
            "problem_%d_steps" % n, initializer=0, trainable=False)
        o4 = problem_steps.assign_add(1)
      with tf.control_dependencies([o1, o2, o3, o4]):  # Make sure the ops run.
        total_loss = tf.identity(total_loss)
      return [total_loss] + sharded_logits  # Need to flatten for cond later.

    result_list = _cond_on_index(nth_model, features["problem_choice"], 0,
                                 len(hparams.problems) - 1)

    if mode == tf.contrib.learn.ModeKeys.INFER:
      # Beam search in sequence model returns both decodes with the key
      # "outputs" and scores with the key "scores". If the result is a dict, we
      # expect that it will have keys "outputs", a tensor of int32 and
      # "scores", a tensor of floats. This is useful if we want to return
      # scores from estimator.predict
      if not isinstance(result_list, dict):
        ret = {"outputs": result_list}, None, None
      else:
        ret = {
            "outputs": result_list["outputs"],
            "scores": result_list["scores"]
        }, None, None
      # Pass inputs and targets through so callers of predict can log them.
      if "inputs" in features:
        ret[0]["inputs"] = features["inputs"]
      if "infer_targets" in features:
        ret[0]["targets"] = features["infer_targets"]
      return ret

    # nth_model returned the flattened list [total_loss] + sharded_logits.
    sharded_logits, total_loss = result_list[1:], result_list[0]
    if mode == tf.contrib.learn.ModeKeys.EVAL:
      logits = tf.concat(sharded_logits, 0)
      # For evaluation, return the logits layer as our predictions.
      run_info["predictions"] = logits
      return run_info, total_loss, None

    assert mode == tf.contrib.learn.ModeKeys.TRAIN

    # Some training statistics.
    with tf.name_scope("training_stats"):
      learning_rate = hparams.learning_rate * learning_rate_decay()
      learning_rate /= math.sqrt(float(FLAGS.worker_replicas))
      tf.summary.scalar("learning_rate", learning_rate)
      global_step = tf.to_float(tf.contrib.framework.get_global_step())
      for n in xrange(len(hparams.problems)):
        with tf.variable_scope("losses_avg", reuse=True):
          total_loss_var = tf.get_variable("problem_%d/total_loss" % n)
          training_loss_var = tf.get_variable("problem_%d/training_loss" % n)
          extra_loss_var = tf.get_variable("problem_%d/extra_loss" % n)
        tf.summary.scalar("loss_avg_%d/total_loss" % n, total_loss_var)
        tf.summary.scalar("loss_avg_%d/training_loss" % n, training_loss_var)
        tf.summary.scalar("loss_avg_%d/extra_loss" % n, extra_loss_var)
        with tf.variable_scope("train_stats", reuse=True):
          nth_steps = tf.get_variable("problem_%d_steps" % n, dtype=tf.int32)
        tf.summary.scalar("problem_%d_frequency" % n,
                          tf.to_float(nth_steps) / (global_step + 1.0))

    # Log trainable weights and add decay.
    total_size, total_embedding, weight_decay_loss = 0, 0, 0.0
    all_weights = {v.name: v for v in tf.trainable_variables()}
    for v_name in sorted(all_weights):
      v = all_weights[v_name]
      v_size = int(np.prod(np.array(v.shape.as_list())))
      # v.name[:-2] drops the last two characters of the variable name.
      tf.logging.info("Weight %s\tshape %s\tsize %d",
                      v.name[:-2].ljust(80), str(v.shape).ljust(20), v_size)
      if "embedding" in v_name:
        total_embedding += v_size
      total_size += v_size
      if hparams.weight_decay > 0.0 and len(v.shape.as_list()) > 1:
        # Add weight regularization if set and the weight is not a bias (dim>1).
        with tf.device(v._ref().device):  # pylint: disable=protected-access
          v_loss = tf.nn.l2_loss(v) / v_size
        weight_decay_loss += v_loss
      is_body = len(v_name) > 5 and v_name[:5] == "body/"
      if hparams.weight_noise > 0.0 and is_body:
        # Add weight noise if set in hparams.
        with tf.device(v._ref().device):  # pylint: disable=protected-access
          scale = learning_rate * 0.001
          noise = tf.truncated_normal(v.shape) * hparams.weight_noise * scale
          noise_op = v.assign_add(noise)
        # Tie the noise op to the loss so that it runs on every training step.
        with tf.control_dependencies([noise_op]):
          total_loss = tf.identity(total_loss)
    tf.logging.info("Total trainable variables size: %d", total_size)
    tf.logging.info("Total embedding variables size: %d", total_embedding)
    tf.logging.info("Total non-embedding variables size: %d",
                    total_size - total_embedding)
    total_loss += weight_decay_loss * hparams.weight_decay

    # Define the train_op for the TRAIN mode.
    opt = _ConditionalOptimizer(hparams.optimizer, learning_rate, hparams)
    tf.logging.info("Computing gradients for global model_fn.")
    train_op = tf.contrib.layers.optimize_loss(
        name="training",
        loss=total_loss,
        global_step=tf.contrib.framework.get_global_step(),
        learning_rate=learning_rate,
        clip_gradients=hparams.clip_grad_norm or None,
        optimizer=opt,
        colocate_gradients_with_ops=True)

    tf.logging.info("Global model_fn finished.")
    return run_info, total_loss, train_op

  return model_fn
def run_locally(exp):
  """Runs an Experiment locally - trains, evaluates, and decodes.

  Training and evaluation are each skipped when the corresponding step count
  is not positive. Decoding always runs; its source is picked from
  --decode_interactive, then --decode_from_file, else the inference dataset.

  Args:
    exp: Experiment.
  """
  if exp.train_steps > 0:
    tf.logging.info("Performing local training.")
    exp.train()

  if exp.eval_steps > 0:
    tf.logging.info("Performing local evaluation.")
    exp.evaluate(delay_secs=0)  # The returned metrics are not used.

  estimator = exp.estimator
  if FLAGS.decode_interactive:
    decode_interactively(estimator)
    return
  decode_path = FLAGS.decode_from_file
  if decode_path is not None:
    decode_from_file(estimator, decode_path)
    return
  decode_from_dataset(estimator)
def decode_from_dataset(estimator):
  """Decodes the inference dataset of every problem in --problems.

  For each problem, `estimator.predict` is run on the INFER-mode input
  function and the decoded inputs and outputs are logged. When
  --decode_to_file is set, decoded outputs and targets are appended to
  `<decode_to_file>.outputs.<problem>` and `<decode_to_file>.targets.<problem>`.

  Args:
    estimator: Estimator carrying an `hparams` attribute with `problems` and
      `data_dir`, and a `model_dir` (used when saving images).
  """
  hparams = estimator.hparams
  for i, problem in enumerate(FLAGS.problems.split("-")):
    inputs_vocab = hparams.problems[i].vocabulary.get("inputs", None)
    targets_vocab = hparams.problems[i].vocabulary["targets"]
    tf.logging.info("Performing local inference.")
    infer_problems_data = get_datasets_for_mode(hparams.data_dir,
                                                tf.contrib.learn.ModeKeys.INFER)
    infer_input_fn = get_input_fn(
        mode=tf.contrib.learn.ModeKeys.INFER,
        hparams=hparams,
        data_file_patterns=infer_problems_data,
        num_datashards=data_parallelism().n,
        fixed_problem=i)
    result_iter = estimator.predict(
        input_fn=infer_input_fn, as_iterable=FLAGS.decode_endless)

    def log_fn(inputs,
               targets,
               outputs,
               problem,
               j,
               inputs_vocab=inputs_vocab,
               targets_vocab=targets_vocab):
      """Log inference results."""
      if "image" in problem and FLAGS.decode_save_images:
        save_path = os.path.join(estimator.model_dir,
                                 "%s_prediction_%d.jpg" % (problem, j))
        show_and_save_image(inputs / 255., save_path)
      elif inputs_vocab:
        decoded_inputs = inputs_vocab.decode(inputs.flatten())
        tf.logging.info("Inference results INPUT: %s" % decoded_inputs)

      decoded_outputs = targets_vocab.decode(outputs.flatten())
      tf.logging.info("Inference results OUTPUT: %s" % decoded_outputs)
      if FLAGS.decode_to_file:
        decoded_targets = targets_vocab.decode(targets.flatten())
        # Each file is opened in append mode and closed right after the write,
        # so no handle is left open per decoded example.
        output_filepath = FLAGS.decode_to_file + ".outputs." + problem
        with tf.gfile.Open(output_filepath, "a") as output_file:
          output_file.write(decoded_outputs + "\n")
        target_filepath = FLAGS.decode_to_file + ".targets." + problem
        with tf.gfile.Open(target_filepath, "a") as target_file:
          target_file.write(decoded_targets + "\n")

    # The function predict() returns an iterable over the network's
    # predictions from the test input. if FLAGS.decode_endless is set, it will
    # decode over the dev set endlessly, looping over it. We use the returned
    # iterator to log inputs and decodes.
    if FLAGS.decode_endless:
      tf.logging.info("Warning: Decoding endlessly")
      # With as_iterable=True each result is a dict for one example.
      examples = ((result["inputs"], result["targets"], result["outputs"])
                  for result in result_iter)
    else:
      # Otherwise the result is a single dict that is indexed by key and
      # zipped into per-example triples.
      examples = zip(result_iter["inputs"], result_iter["targets"],
                     result_iter["outputs"])

    for j, (inputs, targets, outputs) in enumerate(examples):
      if FLAGS.decode_return_beams:
        output_beams = np.split(outputs, FLAGS.decode_beam_size, axis=0)
        for k, beam in enumerate(output_beams):
          tf.logging.info("BEAM %d:" % k)
          log_fn(inputs, targets, beam, problem, j)
      else:
        log_fn(inputs, targets, outputs, problem, j)
def decode_from_file(estimator, filename):
  """Compute predictions on entries in filename and write them out.

  Inputs are decoded batch by batch and each decoded output is written next to
  its input, tab-separated, into a file whose name is derived from `filename`,
  --model, --hparams_set, --decode_beam_size and --decode_alpha.

  Args:
    estimator: Estimator carrying an `hparams` attribute with `problems`.
    filename: Path of the file with the inputs to decode.
  """
  hparams = estimator.hparams
  problem_id = FLAGS.decode_problem_id
  inputs_vocab = hparams.problems[problem_id].vocabulary["inputs"]
  targets_vocab = hparams.problems[problem_id].vocabulary["targets"]
  tf.logging.info("Performing Decoding from a file.")
  sorted_inputs, sorted_keys = _get_sorted_inputs(filename)
  # Ceiling division: number of batches needed to cover all inputs.
  num_decode_batches = (len(sorted_inputs) - 1) // FLAGS.decode_batch_size + 1
  input_fn = _decode_batch_input_fn(problem_id, num_decode_batches,
                                    sorted_inputs, inputs_vocab)

  # strips everything after the first <EOS> id, which is assumed to be 1
  def _save_until_eos(hyp):  # pylint: disable=missing-docstring
    ret = []
    for token in hyp:
      if token == 1:
        break
      ret.append(token)
    return np.array(ret)

  def log_fn(inputs, outputs):
    """Logs one decoded input/output pair and returns the decoded output."""
    decoded_inputs = inputs_vocab.decode(_save_until_eos(inputs.flatten()))
    tf.logging.info("Inference results INPUT: %s" % decoded_inputs)
    decoded_outputs = targets_vocab.decode(
        _save_until_eos(outputs.flatten()))
    tf.logging.info("Inference results OUTPUT: %s" % decoded_outputs)
    return decoded_outputs

  decodes = []
  for _ in range(num_decode_batches):
    # Each predict call pulls exactly one batch from the generator.
    result_iter = estimator.predict(
        input_fn=input_fn.next if six.PY2 else input_fn.__next__,
        as_iterable=True)
    for result in result_iter:
      if FLAGS.decode_return_beams:
        beam_decodes = []
        output_beams = np.split(
            result["outputs"], FLAGS.decode_beam_size, axis=0)
        for k, beam in enumerate(output_beams):
          tf.logging.info("BEAM %d:" % k)
          beam_decodes.append(log_fn(result["inputs"], beam))
        decodes.append("\t".join(beam_decodes))
      else:
        decodes.append(log_fn(result["inputs"], result["outputs"]))

  # Reversing the decoded inputs and outputs because they were reversed in
  # _decode_batch_input_fn
  sorted_inputs.reverse()
  decodes.reverse()
  # Dumping inputs and outputs to file filename.decodes in
  # format result\tinput in the same order as original inputs
  if FLAGS.decode_shards > 1:
    base_filename = filename + ("%.2d" % FLAGS.worker_id)
  else:
    base_filename = filename
  decode_filename = (base_filename + "." + FLAGS.model + "." + FLAGS.hparams_set
                     + ".beam" + str(FLAGS.decode_beam_size) + ".alpha" +
                     str(FLAGS.decode_alpha) + ".decodes")
  tf.logging.info("Writing decodes into %s" % decode_filename)
  # The context manager closes the file so that all decodes are written out.
  with tf.gfile.Open(decode_filename, "w") as outfile:
    # sorted_keys is indexed by original position, so iterate over positions.
    for index in range(len(sorted_inputs)):
      sorted_index = sorted_keys[index]
      outfile.write("%s\t%s\n" % (decodes[sorted_index],
                                  sorted_inputs[sorted_index]))
def decode_interactively(estimator):
  """Decodes examples produced by the interactive input generator.

  Each example from `_interactive_input_fn` is run through
  `estimator.predict`, and the decoded outputs are logged, one log entry per
  beam (with its score when available) if --decode_return_beams is set.

  Args:
    estimator: Estimator carrying an `hparams` attribute with `problems`.
  """
  hparams = estimator.hparams
  for problem_idx, example in _interactive_input_fn(hparams):
    vocab = hparams.problems[problem_idx].vocabulary["targets"]
    # The default argument binds the current example to the input function.
    predictions = estimator.predict(input_fn=lambda e=example: e)
    for result in predictions:
      if not FLAGS.decode_return_beams:
        tf.logging.info(vocab.decode(result["outputs"].flatten()))
        continue

      beams = np.split(result["outputs"], FLAGS.decode_beam_size, axis=0)
      if "scores" in result:
        scores = np.split(result["scores"], FLAGS.decode_beam_size, axis=0)
      else:
        scores = None
      for k, beam in enumerate(beams):
        tf.logging.info("BEAM %d:" % k)
        if scores is None:
          tf.logging.info(vocab.decode(beam.flatten()))
        else:
          tf.logging.info("%s\tScore:%f" %
                          (vocab.decode(beam.flatten()), scores[k]))
def _decode_batch_input_fn(problem_id, num_decode_batches, sorted_inputs,
                           vocabulary):
  """Generator yielding one padded batch of encoded inputs at a time.

  Note that `sorted_inputs` is reversed IN PLACE before batching; callers that
  need the original order have to reverse it back afterwards.

  Args:
    problem_id: Value placed under "problem_choice" in every batch.
    num_decode_batches: Number of batches to yield.
    sorted_inputs: List of input strings; mutated (reversed) by this function.
    vocabulary: Object whose `encode` turns an input into a list of ids.

  Yields:
    A dict with "inputs", an array of the batch's id lists right-padded with 0
    to the longest one, and "problem_choice".
  """
  tf.logging.info(" batch %d" % num_decode_batches)
  # First reverse all the input sentences so that if you're going to get OOMs,
  # you'll see it in the first batch
  sorted_inputs.reverse()
  for b in range(num_decode_batches):
    tf.logging.info("Decoding batch %d" % b)
    batch_size = FLAGS.decode_batch_size
    batch_inputs = []
    for inputs in sorted_inputs[b * batch_size:(b + 1) * batch_size]:
      input_ids = vocabulary.encode(inputs)
      input_ids.append(1)  # Assuming EOS=1.
      batch_inputs.append(input_ids)
    # An empty batch keeps a length of 0.
    batch_length = max(len(ids) for ids in batch_inputs) if batch_inputs else 0
    final_batch_inputs = [
        ids + [0] * (batch_length - len(ids)) for ids in batch_inputs
    ]
    yield {
        "inputs": np.array(final_batch_inputs),
        "problem_choice": np.array(problem_id)
    }
def get_datasets_for_mode(data_dir, mode):
  """Returns `data_reader.get_datasets` for --problems, `data_dir` and `mode`."""
  problems = FLAGS.problems
  return data_reader.get_datasets(problems, data_dir, mode)
def _cond_on_index(fn, index_tensor, cur_idx, max_idx):
"""Call fn(index_tensor) using tf.cond in [cur_id, max_idx]."""
if cur_idx == max_idx:
return fn(cur_idx)
return tf.cond(
tf.equal(index_tensor, cur_idx), lambda: fn(cur_idx),
lambda: _cond_on_index(fn, index_tensor, cur_idx + 1, max_idx))
def _interactive_input_fn(hparams):
  """Generator turning lines typed at the terminal into decoder inputs.

  Every line read is either `q` (stop), a settings command (`pr=`, `in=`,
  `ou=`, `ns=`, `dl=`, `it=`), or the thing to decode.

  For text input everything is packed into one fixed-size array, so that the
  shape stays the same from one request to the next. The layout is:
    [num_samples, decode_length, len(input ids), <input ids>, <padding>]
  For image input the typed line is used as a path and the image read from it
  is yielded unchanged.

  Args:
    hparams: model hparams

  Yields:
    (problem_id, features) pairs, where features is a dict with the keys
    "inputs" and "problem_choice".

  Raises:
    Exception: when `input_type` is invalid.
  """

  def select_problem(index):
    """Returns (problem hparams, has_input, vocabulary) for a problem index."""
    chosen = hparams.problems[index]
    with_input = "inputs" in chosen.input_modality
    vocab = chosen.vocabulary["inputs" if with_input else "targets"]
    return chosen, with_input, vocab

  # Settings the user can change between requests.
  num_samples = 3
  decode_length = 100
  input_type = "text"
  problem_id = 0
  p_hparams, has_input, vocabulary = select_problem(problem_id)
  # This should be longer than the longest input.
  const_array_size = 10000
  while True:
    decode_label = "source_string" if has_input else "target_prefix"
    prompt = ("INTERACTIVE MODE num_samples=%d decode_length=%d \n"
              " it=<input_type> ('text' or 'image')\n"
              " pr=<problem_num> (set the problem number)\n"
              " in=<input_problem> (set the input problem number)\n"
              " ou=<output_problem> (set the output problem number)\n"
              " ns=<num_samples> (changes number of samples)\n"
              " dl=<decode_length> (changes decode legnth)\n"
              " <%s> (decode)\n"
              " q (quit)\n"
              ">" % (num_samples, decode_length, decode_label))
    input_string = input(prompt)
    if input_string == "q":
      return

    # Settings commands: apply the change and prompt again.
    command, argument = input_string[:3], input_string[3:]
    if command == "pr=":
      problem_id = int(argument)
      p_hparams, has_input, vocabulary = select_problem(problem_id)
      continue
    if command == "in=":
      source = hparams.problems[int(argument)]
      p_hparams.input_modality = source.input_modality
      p_hparams.input_space_id = source.input_space_id
      continue
    if command == "ou=":
      source = hparams.problems[int(argument)]
      p_hparams.target_modality = source.target_modality
      p_hparams.target_space_id = source.target_space_id
      continue
    if command == "ns=":
      num_samples = int(argument)
      continue
    if command == "dl=":
      decode_length = int(argument)
      continue
    if command == "it=":
      input_type = argument
      continue

    # Anything else is the input to decode.
    if input_type == "text":
      input_ids = vocabulary.encode(input_string)
      if has_input:
        input_ids.append(1)  # assume 1 means end-of-source
      packed = [num_samples, decode_length, len(input_ids)] + input_ids
      assert len(packed) < const_array_size
      padding = [0] * (const_array_size - len(packed))
      yield problem_id, {
          "inputs": np.array(packed + padding),
          "problem_choice": np.array(problem_id)
      }
    elif input_type == "image":
      img = read_image(input_string)
      yield problem_id, {
          "inputs": img,
          "problem_choice": np.array(problem_id)
      }
    else:
      raise Exception("Unsupported input type.")
def read_image(path):
  """Loads the image stored at `path` with matplotlib.

  Args:
    path: location of the image, handed to `matplotlib.image.imread`.

  Returns:
    The result of `matplotlib.image.imread(path)`.

  Raises:
    NotImplementedError: if matplotlib cannot be imported.
  """
  try:
    import matplotlib.image as mpl_image  # pylint: disable=g-import-not-at-top
  except ImportError as import_error:
    tf.logging.warning(
        "Reading an image requires matplotlib to be installed: %s",
        import_error)
    raise NotImplementedError("Image reading not implemented.")
  else:
    return mpl_image.imread(path)
def show_and_save_image(img, save_path):
  """Draws `img` with matplotlib and saves the figure to `save_path`.

  Args:
    img: image handed to `matplotlib.pyplot.imshow`.
    save_path: destination handed to `matplotlib.pyplot.savefig`.

  Raises:
    NotImplementedError: if matplotlib cannot be imported.
  """
  try:
    import matplotlib.pyplot as pyplot  # pylint: disable=g-import-not-at-top
  except ImportError as import_error:
    tf.logging.warning("Showing and saving an image requires matplotlib to be "
                       "installed: %s", import_error)
    raise NotImplementedError("Image display and save not implemented.")
  else:
    pyplot.imshow(img)
    pyplot.savefig(save_path)
def _get_sorted_inputs(filename):
  """Reads the inputs from a file and sorts them by number of tokens.

  When `FLAGS.decode_shards` is greater than 1, the file that is read is
  `filename` with the zero-padded, two-digit `FLAGS.worker_id` appended.

  Args:
    filename: path to file with inputs, 1 per line.

  Returns:
    A pair (sorted_inputs, sorted_keys):
      sorted_inputs: list of the stripped lines, ordered by increasing number
        of whitespace-separated tokens; lines with equal counts keep their
        file order.
      sorted_keys: dict mapping each line's index in the file to its position
        in sorted_inputs, so that the original order can be restored.
  """
  tf.logging.info("Getting sorted inputs")
  if FLAGS.decode_shards > 1:
    decode_filename = filename + ("%.2d" % FLAGS.worker_id)
  else:
    decode_filename = filename
  # Read inside a context manager so the file handle is closed once all lines
  # have been consumed.
  with tf.gfile.Open(decode_filename) as input_file:
    inputs = [line.strip() for line in input_file]
  # The lines are already stripped, so splitting them directly gives the token
  # counts used as the sort key.
  input_lens = [(i, len(line.split())) for i, line in enumerate(inputs)]
  sorted_input_lens = sorted(input_lens, key=operator.itemgetter(1))
  # We'll need the keys to rearrange the inputs back into their original order
  sorted_keys = {}
  sorted_inputs = []
  for i, (index, _) in enumerate(sorted_input_lens):
    sorted_inputs.append(inputs[index])
    sorted_keys[index] = i
  return sorted_inputs, sorted_keys
def _interactive_input_tensor_to_features_dict(feature_map, hparams):
  """Builds the decoder features for one interactive request.

  Inputs of rank below 3 are read as the packed text array
    [num_samples, decode_length, len(input ids), <input ids>, <padding>]
  and anything of rank 3 or more is handled as an image.

  Args:
    feature_map: a dictionary with keys `problem_choice` and `inputs`.
    hparams: model hyperparameters

  Returns:
    a features dictionary with the keys "problem_choice", "input_space_id",
    "target_space_id", "decode_length" and "inputs".
  """
  inputs = tf.constant(feature_map["inputs"])
  input_is_image = len(inputs.shape) >= 3

  def features_for_problem(problem_choice, x=inputs):
    """Returns (input space id, target space id, inputs) for one problem."""
    p_hparams = hparams.problems[problem_choice]
    if input_is_image:
      x = tf.image.resize_images(x, [299, 299])
      x = tf.reshape(x, [1, 299, 299, -1])
    else:
      # Unpack the header, then keep only the `length` ids that follow it.
      num_samples = x[0]
      length = x[2]
      x = tf.slice(x, [3], tf.to_int32([length]))
      x = tf.reshape(x, [1, -1, 1, 1])
      # Repeat the single input num_samples times along the batch dimension
      # to get that many random decodes.
      x = tf.tile(x, tf.to_int32([num_samples, 1, 1, 1]))
    x = tf.to_int32(x)
    input_space = tf.constant(p_hparams.input_space_id)
    target_space = tf.constant(p_hparams.target_space_id)
    return (input_space, target_space, x)

  last_problem = len(hparams.problems) - 1
  input_space_id, target_space_id, x = _cond_on_index(
      features_for_problem, feature_map["problem_choice"], 0, last_problem)

  return {
      "problem_choice": tf.constant(feature_map["problem_choice"]),
      "input_space_id": input_space_id,
      "target_space_id": target_space_id,
      "decode_length": IMAGE_DECODE_LENGTH if input_is_image else inputs[1],
      "inputs": x,
  }
def _decode_input_tensor_to_features_dict(feature_map, hparams):
  """Convert a batch of decode inputs to a features dictionary.

  Args:
    feature_map: a dictionary with keys `problem_choice` and `inputs`.
    hparams: model hyperparameters

  Returns:
    a features dictionary, as expected by the decoder. Its "decode_length" is
    the size of dimension 1 of the converted inputs plus 50.
  """
  inputs = tf.constant(feature_map["inputs"])

  def input_fn(problem_choice, x=inputs):  # pylint: disable=missing-docstring
    p_hparams = hparams.problems[problem_choice]
    # Add a third empty dimension
    x = tf.expand_dims(x, axis=[2])
    x = tf.to_int32(x)
    return (tf.constant(p_hparams.input_space_id),
            tf.constant(p_hparams.target_space_id), x)

  input_space_id, target_space_id, x = _cond_on_index(
      input_fn, feature_map["problem_choice"], 0, len(hparams.problems) - 1)

  features = {}
  features["problem_choice"] = feature_map["problem_choice"]
  features["input_space_id"] = input_space_id
  features["target_space_id"] = target_space_id
  # Inputs here are never images, so the decode length is always derived from
  # the input length.
  features["decode_length"] = tf.shape(x)[1] + 50
  features["inputs"] = x
  return features
def get_input_fn(mode,
hparams,
data_file_patterns=None,
num_datashards=None,
fixed_problem=None):
"""Provides input to the graph, either from disk or via a placeholder.
This function produces an input function that will feed data into
the network. There are two modes of operation:
1. If data_file_pattern and all subsequent arguments are None, then
it creates a placeholder for a serialized tf.Example proto.
2. If data_file_pattern is defined, it will read the data from the
files at the given location. Use this mode for training,
evaluation, and testing prediction.
Args:
mode: The execution mode, as defined in tf.contrib.learn.ModeKeys.
hparams: HParams object.
data_file_patterns: The list of file patterns to use to read in data. Set to
`None` if you want to create a placeholder for the input data. The
`problems` flag is a list of problem names joined by the `-` character.
The flag's string is then split along the `-` and each problem gets its
own example queue.
num_datashards: An integer.
fixed_problem: An integer indicating the problem to fetch data for, or None
if the input is to be randomly selected.
Returns:
A function that returns a dictionary of features and the target labels.
"""