/
training.py
2777 lines (2526 loc) · 119 KB
/
training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Training-related part of the Keras engine.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import numpy as np
from tensorflow.python import tf2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.distribute import distribute_coordinator as dc
from tensorflow.python.distribute import distribute_coordinator_context as dc_context
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import losses
from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.keras import optimizers
from tensorflow.python.keras.engine import distributed_training_utils
from tensorflow.python.keras.engine import training_arrays
from tensorflow.python.keras.engine import training_distributed
from tensorflow.python.keras.engine import training_eager
from tensorflow.python.keras.engine import training_generator
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.engine.network import Network
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.keras.saving import saving_utils
from tensorflow.python.keras.utils import data_utils
from tensorflow.python.keras.utils.generic_utils import slice_arrays
from tensorflow.python.keras.utils.losses_utils import squeeze_or_expand_dimensions
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import optimizer as tf_optimizer_module
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.mode_keys import ModeKeys
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export
@keras_export('keras.models.Model', 'keras.Model')
class Model(Network):
"""`Model` groups layers into an object with training and inference features.
There are two ways to instantiate a `Model`:
1 - With the "functional API", where you start from `Input`,
you chain layer calls to specify the model's forward pass,
and finally you create your model from inputs and outputs:
```python
import tensorflow as tf
inputs = tf.keras.Input(shape=(3,))
x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs)
outputs = tf.keras.layers.Dense(5, activation=tf.nn.softmax)(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
```
2 - By subclassing the `Model` class: in that case, you should define your
layers in `__init__` and you should implement the model's forward pass
in `call`.
```python
import tensorflow as tf
class MyModel(tf.keras.Model):
def __init__(self):
super(MyModel, self).__init__()
self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
def call(self, inputs):
x = self.dense1(inputs)
return self.dense2(x)
model = MyModel()
```
If you subclass `Model`, you can optionally have
a `training` argument (boolean) in `call`, which you can use to specify
a different behavior in training and inference:
```python
import tensorflow as tf
class MyModel(tf.keras.Model):
def __init__(self):
super(MyModel, self).__init__()
self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
self.dropout = tf.keras.layers.Dropout(0.5)
def call(self, inputs, training=False):
x = self.dense1(inputs)
if training:
x = self.dropout(x, training=training)
return self.dense2(x)
model = MyModel()
```
"""
def __init__(self, *args, **kwargs):
  """Creates the model; compile-related state starts out unset."""
  super(Model, self).__init__(*args, **kwargs)
  # The strategy must exist before `compile` because `predict` can be called
  # on a model that was never compiled.
  self._distribution_strategy = None
  self._distributed_session_is_configured = False
  # True only when a strategy was handed to `compile(distribute=...)` (the
  # deprecated path) instead of building the model under the strategy scope.
  self._compile_distribution = False
  # `None` means "use the default execution mode" (see the property getter).
  self.run_eagerly = None
def get_weights(self):
  """Retrieves the weights of the model.

  If a distribution strategy is attached to the model, the weights are read
  inside that strategy's scope; otherwise they are read directly.

  Returns:
    A flat list of Numpy arrays.
  """
  strategy = self._distribution_strategy
  if not strategy:
    return super(Model, self).get_weights()
  with strategy.scope():
    return super(Model, self).get_weights()
@checkpointable.no_automatic_dependency_tracking
def compile(self,
            optimizer,
            loss=None,
            metrics=None,
            loss_weights=None,
            sample_weight_mode=None,
            weighted_metrics=None,
            target_tensors=None,
            distribute=None,
            **kwargs):
  """Configures the model for training.

  Arguments:
    optimizer: String (name of optimizer) or optimizer instance.
      See `tf.keras.optimizers`.
    loss: String (name of objective function) or objective function.
      See `tf.losses`. If the model has multiple outputs, you can use a
      different loss on each output by passing a dictionary or a list of
      losses. The loss value that will be minimized by the model
      will then be the sum of all individual losses.
    metrics: List of metrics to be evaluated by the model
      during training and testing.
      Typically you will use `metrics=['accuracy']`.
      To specify different metrics for different outputs of a
      multi-output model, you could also pass a dictionary,
      such as `metrics={'output_a': 'accuracy'}`.
    loss_weights: Optional list or dictionary specifying scalar
      coefficients (Python floats) to weight the loss contributions
      of different model outputs.
      The loss value that will be minimized by the model
      will then be the *weighted sum* of all individual losses,
      weighted by the `loss_weights` coefficients.
      If a list, it is expected to have a 1:1 mapping
      to the model's outputs. If a dict, it is expected to map
      output names (strings) to scalar coefficients; outputs missing
      from the dict get a weight of 1.
    sample_weight_mode: If you need to do timestep-wise
      sample weighting (2D weights), set this to `"temporal"`.
      `None` defaults to sample-wise weights (1D).
      If the model has multiple outputs, you can use a different
      `sample_weight_mode` on each output by passing a
      dictionary or a list of modes.
    weighted_metrics: List of metrics to be evaluated and weighted
      by sample_weight or class_weight during training and testing.
    target_tensors: By default, Keras will create placeholders for the
      model's target, which will be fed with the target data during
      training. If instead you would like to use your own
      target tensors (in turn, Keras will not expect external
      Numpy data for these targets at training time), you
      can specify them via the `target_tensors` argument. It can be
      a single tensor (for a single-output model), a list or tuple of
      tensors, or a dict mapping output names to target tensors.
    distribute: NOT SUPPORTED IN TF 2.0, please create and compile the
      model under distribution strategy scope instead of passing it to
      compile.
    **kwargs: Any additional arguments. `run_eagerly` is popped and stored;
      in graph mode the remaining entries are kept in
      `self._function_kwargs` for the lazily built train/test/predict
      functions.

  Raises:
    ValueError: In case of invalid arguments for
      `optimizer`, `loss`, `metrics`, `loss_weights`, `target_tensors` or
      `sample_weight_mode`.
    NotImplementedError: If an argument is not supported together with a
      DistributionStrategy.
    TypeError: If `loss_weights` or `target_tensors` has an unsupported type.
  """
  run_eagerly = kwargs.pop('run_eagerly', None)
  self._run_eagerly = run_eagerly
  optimizer = optimizers.get(optimizer)

  if distribute is not None:
    if tf2.enabled():
      raise ValueError(
          'Distribute argument in compile is not available in TF 2.0 please '
          'create the model under the distribution strategy scope.')
    logging.warning('Distribute argument in compile is deprecated please '
                    'create the model under the distribution strategy scope.')
    self._distribution_strategy = distribute
    self._compile_distribution = True
    self._distributed_session_is_configured = False
  else:
    if distribution_strategy_context.has_strategy():
      # When the user builds the model in the DS scope and cross replica
      # context we want distribution strategy to be set but when building the
      # replica copies of the models internally we should not be compiling
      # with distribution strategy and use the default compilation path.
      if distribution_strategy_context.in_cross_replica_context():
        self._distribution_strategy = (
            distribution_strategy_context.get_strategy())

  # Validate that arguments passed by the user to `compile` are supported by
  # DistributionStrategy.
  if self._distribution_strategy:
    if not isinstance(optimizer,
                      (tf_optimizer_module.Optimizer, optimizers.TFOptimizer,
                       optimizer_v2.OptimizerV2)):
      raise NotImplementedError(
          'optimizer must be an instance of '
          'tf.train.Optimizer, not a %s' % type(optimizer))
    if sample_weight_mode:
      raise NotImplementedError('sample_weight_mode is not supported with '
                                'DistributionStrategy.')
    if weighted_metrics:
      raise NotImplementedError('weighted_metrics is not supported with '
                                'DistributionStrategy.')
    # Check for a tensor before truth-testing: a symbolic tensor may not
    # support Python `bool()`, which would mask this error with another one.
    if tensor_util.is_tensor(target_tensors) or target_tensors:
      raise ValueError('target_tensors is not supported with '
                       'DistributionStrategy.')

  loss = loss or {}
  if self.run_eagerly and not isinstance(
      optimizer, (tf_optimizer_module.Optimizer, optimizers.TFOptimizer,
                  optimizer_v2.OptimizerV2)):
    raise ValueError(
        'When running a model in eager execution, the optimizer must be an '
        'instance of tf.train.Optimizer. Received: '
        '%s' % optimizer)

  self.optimizer = optimizer
  # We've disabled automatic dependency tracking for this method, but do want
  # to add a checkpoint dependency on the optimizer if it's checkpointable.
  if isinstance(self.optimizer, checkpointable.Checkpointable):
    self._track_checkpointable(
        self.optimizer, name='optimizer', overwrite=True)
  self.loss = loss
  self._compile_metrics = metrics or []
  self.loss_weights = loss_weights
  self.sample_weight_mode = sample_weight_mode
  self._compile_weighted_metrics = weighted_metrics
  if self.run_eagerly and target_tensors is not None:
    raise ValueError(
        'target_tensors argument is not supported when '
        'running a model eagerly.')
  self.target_tensors = target_tensors

  # Set DistributionStrategy specific parameters.
  self._distributed_model = None

  # Initialize model metric attributes.
  self._init_metric_attributes()

  if not self.built or not self.inputs or not self.outputs:
    # Model is not compilable because it does not know its number of inputs
    # and outputs, nor their shapes and names. We will compile after the first
    # time the model gets called on training data.
    return
  self._is_compiled = True

  # Prepare loss functions: one entry per output (None = no loss for it).
  if isinstance(loss, dict):
    for name in loss:
      if name not in self.output_names:
        raise ValueError(
            'Unknown entry in loss '
            'dictionary: "' + name + '". '
            'Only expected the following keys: ' + str(self.output_names))
    loss_functions = []
    for name in self.output_names:
      if name not in loss:
        logging.warning(
            'Output "' + name +
            '" missing from loss dictionary. We assume '
            'this was done on purpose. The fit and evaluate APIs will not be '
            'expecting any data to be passed to "' + name + '".')
      loss_functions.append(training_utils.get_loss_function(loss.get(name)))
  elif isinstance(loss, list):
    if len(loss) != len(self.outputs):
      raise ValueError('When passing a list as loss, '
                       'it should have one entry per model outputs. '
                       'The model has ' + str(len(self.outputs)) +
                       ' outputs, but you passed loss=' + str(loss))
    loss_functions = [training_utils.get_loss_function(l) for l in loss]
  else:
    loss_function = training_utils.get_loss_function(loss)
    loss_functions = [loss_function for _ in range(len(self.outputs))]
  self.loss_functions = loss_functions

  # Outputs without a loss function are recorded in both skip lists.
  skip_target_indices = []
  skip_target_weighing_indices = []
  self._feed_outputs = []
  self._feed_output_names = []
  self._feed_output_shapes = []
  self._feed_loss_fns = []
  for i, output_loss_fn in enumerate(loss_functions):
    if output_loss_fn is None:
      skip_target_indices.append(i)
      skip_target_weighing_indices.append(i)

  # Prepare output masks (only needed by the graph path below).
  if not self.run_eagerly:
    masks = [getattr(x, '_keras_mask', None) for x in self.outputs]

  # Prepare loss weights.
  if loss_weights is None:
    loss_weights_list = [1. for _ in range(len(self.outputs))]
  elif isinstance(loss_weights, dict):
    for name in loss_weights:
      if name not in self.output_names:
        raise ValueError(
            'Unknown entry in loss_weights '
            'dictionary: "' + name + '". '
            'Only expected the following keys: ' + str(self.output_names))
    loss_weights_list = [
        loss_weights.get(name, 1.) for name in self.output_names]
  elif isinstance(loss_weights, list):
    if len(loss_weights) != len(self.outputs):
      raise ValueError(
          'When passing a list as loss_weights, '
          'it should have one entry per model output. '
          'The model has ' + str(len(self.outputs)) +
          ' outputs, but you passed loss_weights=' + str(loss_weights))
    loss_weights_list = loss_weights
  else:
    raise TypeError('Could not interpret loss_weights argument: ' +
                    str(loss_weights) + ' - expected a list or a dict.')
  self.loss_weights_list = loss_weights_list

  # Initialization for Eager mode execution.
  if self.run_eagerly:
    # Prepare sample weights.
    self._set_sample_weight_attributes(sample_weight_mode,
                                       skip_target_weighing_indices)
    # Save all metric attributes per output of the model.
    self._cache_output_metric_attributes(metrics, weighted_metrics)

    if target_tensors is not None:
      raise ValueError('target_tensors are not currently supported in Eager '
                       'mode.')
    self.total_loss = None
    for i in range(len(self.outputs)):
      if len(self.outputs) > 1:
        self._compile_metrics_names.append(self.output_names[i] + '_loss')

    # Set metric attributes on model.
    self._set_metric_attributes(
        self.outputs,
        skip_target_indices=skip_target_indices,
    )

    self.targets = []
    for i in range(len(self.outputs)):
      self._feed_output_names.append(self.output_names[i])
    self._collected_trainable_weights = self.trainable_weights
    return

  with K.get_graph().as_default():
    # Prepare targets of model.
    self.targets = []
    self._feed_targets = []

    # Normalize `target_tensors` into a per-output list. A tuple is treated
    # like a list (as the TypeError below has always advertised). Explicit
    # type checks avoid `==`/`in` comparisons against a possible tensor.
    if isinstance(target_tensors, tuple):
      target_tensors = list(target_tensors)
    has_target_tensors = (
        target_tensors is not None and
        not (isinstance(target_tensors, list) and not target_tensors))
    if has_target_tensors:
      if isinstance(target_tensors, list):
        if len(target_tensors) != len(self.outputs):
          raise ValueError(
              'When passing a list as `target_tensors`, '
              'it should have one entry per model output. '
              'The model has %s outputs, but you passed target_tensors=%s' %
              (len(self.outputs), target_tensors))
      elif isinstance(target_tensors, dict):
        for name in target_tensors:
          if name not in self.output_names:
            raise ValueError(
                'Unknown entry in `target_tensors` '
                'dictionary: "' + name + '". '
                'Only expected the following keys: ' + str(self.output_names))
        target_tensors = [
            target_tensors.get(name, None) for name in self.output_names]
      elif tensor_util.is_tensor(target_tensors):
        target_tensors = [target_tensors]
      else:
        raise TypeError('Expected `target_tensors` to be a list or tuple or '
                        'dict or a single tensor, but got: %s' %
                        (target_tensors,))

    for i in range(len(self.outputs)):
      if i in skip_target_indices:
        self.targets.append(None)
      else:
        shape = K.int_shape(self.outputs[i])
        name = self.output_names[i]
        target = target_tensors[i] if has_target_tensors else None
        if target is None or K.is_placeholder(target):
          if target is None:
            # Create a placeholder whose dtype follows the loss (falling back
            # to the output's dtype) so it can be fed with target data.
            target_dtype = losses.LABEL_DTYPES_FOR_LOSSES.get(
                self.loss_functions[i],
                K.dtype(self.outputs[i]))
            target = K.placeholder(
                ndim=len(shape),
                name=name + '_target',
                sparse=K.is_sparse(self.outputs[i]),
                dtype=target_dtype)
          self._feed_targets.append(target)
          self._feed_outputs.append(self.outputs[i])
          self._feed_output_names.append(name)
          self._feed_output_shapes.append(shape)
          self._feed_loss_fns.append(self.loss_functions[i])
        else:
          # A user-provided, non-placeholder target tensor is not fed, so no
          # sample weight is prepared for it either.
          skip_target_weighing_indices.append(i)
        self.targets.append(target)

    # Prepare sample weights.
    self._set_sample_weight_attributes(sample_weight_mode,
                                       skip_target_weighing_indices)
    # Save all metric attributes per output of the model.
    self._cache_output_metric_attributes(metrics, weighted_metrics)

    # Compute total loss.
    total_loss = None
    with K.name_scope('loss'):
      for i in range(len(self.outputs)):
        if i in skip_target_indices:
          continue
        y_true = self.targets[i]
        y_pred = self.outputs[i]
        loss_fn = loss_functions[i]
        sample_weight = self.sample_weights[i]
        mask = masks[i]
        loss_weight = loss_weights_list[i]
        with K.name_scope(self.output_names[i] + '_loss'):
          if isinstance(loss_fn, losses.Loss):
            if mask is not None:
              mask = math_ops.cast(mask, y_pred.dtype)
              # Update weights with mask.
              if sample_weight is None:
                sample_weight = mask
              else:
                # Update dimensions of weights to match with mask if possible.
                mask, _, sample_weight = squeeze_or_expand_dimensions(
                    mask, None, sample_weight)
                sample_weight *= mask
            output_loss = loss_fn(y_true, y_pred, sample_weight=sample_weight)
          else:
            weighted_loss = training_utils.weighted_masked_objective(loss_fn)
            output_loss = weighted_loss(y_true, y_pred, sample_weight, mask)

        if len(self.outputs) > 1:
          # Keep track of the un-aggregated loss result tensor.
          self._compile_metrics_tensors[self.output_names[i] +
                                        '_loss'] = output_loss

          # Keep track of stateful result tensor and function for the loss.
          loss_name = loss_fn.name if isinstance(
              loss_fn, losses.Loss) else loss_fn.__name__
          mean_wrapped_loss = metrics_module.MeanMetricWrapper(
              loss_fn, name=loss_name)
          result_tensor = self._call_metric_fn(mean_wrapped_loss, y_true,
                                               y_pred, sample_weight, mask)
          self._compile_stateful_metrics_tensors[self.output_names[i] +
                                                 '_loss'] = result_tensor
          self._compile_stateful_metric_functions.append(mean_wrapped_loss)

          self._compile_metrics_names.append(self.output_names[i] + '_loss')
        if total_loss is None:
          total_loss = loss_weight * output_loss
        else:
          total_loss += loss_weight * output_loss
      if total_loss is None:
        if not self.losses:
          raise ValueError('The model cannot be compiled '
                           'because it has no loss to optimize.')
        else:
          total_loss = 0.

      # Add regularization penalties
      # and other layer-specific losses.
      for loss_tensor in self.losses:
        total_loss += loss_tensor

    # Set metric attributes on model.
    self._set_metric_attributes(
        self.outputs,
        skip_target_indices=skip_target_indices,
    )

    # Invoke metric functions for all the outputs.
    self._handle_metrics(
        self.outputs,
        masks=masks,
        targets=self.targets,
        skip_target_indices=skip_target_indices,
        sample_weights=self.sample_weights)

    # Prepare gradient updates and state updates.
    self.total_loss = total_loss

    # Functions for train, test and predict will
    # be compiled lazily when required.
    # This saves time when the user is not using all functions.
    self._function_kwargs = kwargs

    self._fit_function = None
    self._eval_function = None
    self.train_function = None
    self.test_function = None
    self.predict_function = None

    # Collected trainable weights, sorted in topological order.
    trainable_weights = self.trainable_weights
    self._collected_trainable_weights = trainable_weights

    # Validate all variables were correctly created in distribution scope.
    if self._distribution_strategy and not self._compile_distribution:
      strategy = self._distribution_strategy
      for v in self.variables:
        if not strategy.extended.variable_created_in_scope(v):
          raise ValueError(
              'Variable (%s) was not created in the distribution strategy '
              'scope of (%s). It is most likely due to not all layers or '
              'the model or optimizer being created outside the distribution '
              'strategy scope. Try to make sure your code looks similar '
              'to the following.\n'
              'with strategy.scope():\n'
              ' model=_create_model()\n'
              ' model.compile(...)'% (v, strategy))
@property
def metrics(self):
  """Returns the model's metrics added using `compile`, `add_metric` APIs.

  Once the model is compiled, the stateful metric functions created by
  `compile` come first; they are followed by the metrics reported by the
  parent class. A new list is returned on every access.
  """
  compiled_metrics = (
      list(self._compile_stateful_metric_functions)
      if self._is_compiled else [])
  return compiled_metrics + super(Model, self).metrics
@property
def metrics_names(self):
  """Returns the model's display labels for all outputs.

  The list holds, in order:
    1. the names recorded by `compile` (which include loss names), only if
       the model is compiled;
    2. the names of the metrics attached to each layer in `self.layers`;
    3. the names of the metrics attached to the model itself.
  """
  names = list(self._compile_metrics_names) if self._is_compiled else []
  for layer in self.layers:
    names.extend(m.name for m in layer._metrics)  # pylint: disable=protected-access
  names.extend(m.name for m in self._metrics)
  return names
@property
def run_eagerly(self):
  """Settable attribute indicating whether the model should run eagerly.

  Running eagerly means that your model will be run step by step,
  like Python code. Your model might run slower, but it should become easier
  for you to debug it by stepping into individual layer calls.

  By default, we will attempt to compile your model to a static graph to
  deliver the best execution performance.

  Returns:
    Boolean, whether the model should run eagerly.

  Raises:
    ValueError: If `run_eagerly=True` was requested while eager execution is
      not enabled, or if the model is dynamic and either eager execution is
      not enabled or `run_eagerly=False` was requested.
  """
  requested = self._run_eagerly
  if requested is True and not context.executing_eagerly():
    raise ValueError('You can only set `run_eagerly=True` if eager execution '
                     'is enabled.')

  if self.dynamic:
    # Dynamic layers can only run eagerly, so graph execution is not an option.
    if not context.executing_eagerly():
      raise ValueError('Your model contains layers that can only be '
                       'successfully run in eager execution (layers '
                       'constructed with `dynamic=True`). '
                       'You must enable eager execution with '
                       '`tf.enable_eager_execution()`.')
    if requested is False:
      # TODO(fchollet): consider using py_func to enable this.
      raise ValueError('Your model contains layers that can only be '
                       'successfully run in eager execution (layers '
                       'constructed with `dynamic=True`). '
                       'You cannot set `run_eagerly=False`.')
    return context.executing_eagerly()

  # Non-dynamic models default to graph execution unless a value was set.
  return False if requested is None else requested
@run_eagerly.setter
def run_eagerly(self, value):
  """Stores the requested execution mode: True, False, or None (default).

  No validation happens here; the getter checks the stored value against
  the current execution context and the model's `dynamic` flag.
  """
  self._run_eagerly = value
def fit(self,
x=None,
y=None,
batch_size=None,
epochs=1,
verbose=1,
callbacks=None,
validation_split=0.,
validation_data=None,
shuffle=True,
class_weight=None,
sample_weight=None,
initial_epoch=0,
steps_per_epoch=None,
validation_steps=None,
validation_freq=1,
max_queue_size=10,
workers=1,
use_multiprocessing=False,
**kwargs):
"""Trains the model for a fixed number of epochs (iterations on a dataset).
Arguments:
x: Input data. It could be:
- A Numpy array (or array-like), or a list of arrays
(in case the model has multiple inputs).
- A TensorFlow tensor, or a list of tensors
(in case the model has multiple inputs).
- A dict mapping input names to the corresponding array/tensors,
if the model has named inputs.
- A `tf.data` dataset or a dataset iterator. Should return a tuple
of either `(inputs, targets)` or
`(inputs, targets, sample_weights)`.
- A generator or `keras.utils.Sequence` returning `(inputs, targets)`
or `(inputs, targets, sample weights)`.
y: Target data. Like the input data `x`,
it could be either Numpy array(s) or TensorFlow tensor(s).
It should be consistent with `x` (you cannot have Numpy inputs and
tensor targets, or inversely). If `x` is a dataset, dataset
iterator, generator, or `keras.utils.Sequence` instance, `y` should
not be specified (since targets will be obtained from `x`).
batch_size: Integer or `None`.
Number of samples per gradient update.
If unspecified, `batch_size` will default to 32.
Do not specify the `batch_size` if your data is in the
form of symbolic tensors, dataset, dataset iterators,
generators, or `keras.utils.Sequence` instances (since they generate
batches).
epochs: Integer. Number of epochs to train the model.
An epoch is an iteration over the entire `x` and `y`
data provided.
Note that in conjunction with `initial_epoch`,
`epochs` is to be understood as "final epoch".
The model is not trained for a number of iterations
given by `epochs`, but merely until the epoch
of index `epochs` is reached.
verbose: Integer. 0, 1, or 2. Verbosity mode.
0 = silent, 1 = progress bar, 2 = one line per epoch.
callbacks: List of `keras.callbacks.Callback` instances.
List of callbacks to apply during training.
See `tf.keras.callbacks`.
validation_split: Float between 0 and 1.
Fraction of the training data to be used as validation data.
The model will set apart this fraction of the training data,
will not train on it, and will evaluate
the loss and any model metrics
on this data at the end of each epoch.
The validation data is selected from the last samples
in the `x` and `y` data provided, before shuffling. This argument is
not supported when `x` is a dataset, dataset iterator, generator or
`keras.utils.Sequence` instance.
validation_data: Data on which to evaluate
the loss and any model metrics at the end of each epoch.
The model will not be trained on this data.
`validation_data` will override `validation_split`.
`validation_data` could be:
- tuple `(x_val, y_val)` of Numpy arrays or tensors
- tuple `(x_val, y_val, val_sample_weights)` of Numpy arrays
- dataset or a dataset iterator
For the first two cases, `batch_size` must be provided.
For the last case, `validation_steps` must be provided.
shuffle: Boolean (whether to shuffle the training data
before each epoch) or str (for 'batch').
'batch' is a special option for dealing with the
limitations of HDF5 data; it shuffles in batch-sized chunks.
Has no effect when `steps_per_epoch` is not `None`.
class_weight: Optional dictionary mapping class indices (integers)
to a weight (float) value, used for weighting the loss function
(during training only).
This can be useful to tell the model to
"pay more attention" to samples from
an under-represented class.
sample_weight: Optional Numpy array of weights for
the training samples, used for weighting the loss function
(during training only). You can either pass a flat (1D)
Numpy array with the same length as the input samples
(1:1 mapping between weights and samples),
or in the case of temporal data,
you can pass a 2D array with shape
`(samples, sequence_length)`,
to apply a different weight to every timestep of every sample.
In this case you should make sure to specify
`sample_weight_mode="temporal"` in `compile()`. This argument is not
supported when `x` is a dataset, dataset iterator, generator, or
`keras.utils.Sequence` instance, instead provide the sample_weights
as the third element of `x`.
initial_epoch: Integer.
Epoch at which to start training
(useful for resuming a previous training run).
steps_per_epoch: Integer or `None`.
Total number of steps (batches of samples)
before declaring one epoch finished and starting the
next epoch. When training with input tensors such as
TensorFlow data tensors, the default `None` is equal to
the number of samples in your dataset divided by
the batch size, or 1 if that cannot be determined.
validation_steps: Only relevant if `validation_data` is provided and
is a dataset or dataset iterator. Total number of steps (batches of
samples) to draw before stopping when performing validation
at the end of every epoch.
validation_freq: Only relevant if validation data is provided. Integer
or `collections.Container` instance (e.g. list, tuple, etc.). If an
integer, specifies how many training epochs to run before a new
validation run is performed, e.g. `validation_freq=2` runs
validation every 2 epochs. If a Container, specifies the epochs on
which to run validation, e.g. `validation_freq=[1, 2, 10]` runs
validation at the end of the 1st, 2nd, and 10th epochs.
max_queue_size: Integer. Used for generator or `keras.utils.Sequence`
input only. Maximum size for the generator queue.
If unspecified, `max_queue_size` will default to 10.
workers: Integer. Used for generator or `keras.utils.Sequence` input
only. Maximum number of processes to spin up
when using process-based threading. If unspecified, `workers`
will default to 1. If 0, will execute the generator on the main
thread.
use_multiprocessing: Boolean. Used for generator or
`keras.utils.Sequence` input only. If `True`, use process-based
threading. If unspecified, `use_multiprocessing` will default to
`False`. Note that because this implementation relies on
multiprocessing, you should not pass non-picklable arguments to
the generator as they can't be passed easily to children processes.
**kwargs: Used for backwards compatibility.
Returns:
A `History` object. Its `History.history` attribute is
a record of training loss values and metrics values
at successive epochs, as well as validation loss values
and validation metrics values (if applicable).
Raises:
RuntimeError: If the model was never compiled.
ValueError: In case of mismatch between the provided input data
and what the model expects.
"""
# Legacy support
if 'nb_epoch' in kwargs:
logging.warning(
'The `nb_epoch` argument in `fit` '
'has been renamed `epochs`.')
epochs = kwargs.pop('nb_epoch')
if kwargs:
raise TypeError('Unrecognized keyword arguments: ' + str(kwargs))
# When the model expects dictionary inputs (i.e. FeatureColumn-based
# models), set run_eagerly to True as there's no support for graph
# functions.
training_utils.set_run_eagerly_for_dict_structure(self, x)
# Case 1: distribution strategy.
if self._distribution_strategy:
if training_utils.should_run_multi_worker():
# Multi-Worker mode runs the Keras training loop on multiple
# servers via the Distribute Coordinator.
def _worker_fn(_):
"""Run training inside the distributed coordinator."""
self._configure_distributed_session()
return training_distributed.fit_distributed(
self,
x=x,
y=y,
batch_size=batch_size,
epochs=epochs,
verbose=verbose,
callbacks=callbacks,
validation_split=validation_split,
validation_data=validation_data,
shuffle=shuffle,
class_weight=class_weight,
sample_weight=sample_weight,
initial_epoch=initial_epoch,
steps_per_epoch=steps_per_epoch,
validation_steps=validation_steps)
# Independent worker only for now.
return dc.run_distribute_coordinator(
_worker_fn,
self._distribution_strategy,
mode=dc.CoordinatorMode.INDEPENDENT_WORKER)
else:
self._configure_distributed_session()
return training_distributed.fit_distributed(
self,
x=x,
y=y,
batch_size=batch_size,
epochs=epochs,
verbose=verbose,
callbacks=callbacks,
validation_split=validation_split,
validation_data=validation_data,
shuffle=shuffle,
class_weight=class_weight,
sample_weight=sample_weight,
initial_epoch=initial_epoch,
steps_per_epoch=steps_per_epoch,
validation_steps=validation_steps)
batch_size = self._validate_or_infer_batch_size(
batch_size, steps_per_epoch, x)
# Case 2: generator-like. Input is Python generator, or Sequence object,
# or a non-distributed Dataset or iterator in eager execution.
if data_utils.is_generator_or_sequence(x):
training_utils.check_generator_arguments(
y, sample_weight, validation_split=validation_split)
return self.fit_generator(
x,
steps_per_epoch=steps_per_epoch,
epochs=epochs,
verbose=verbose,
callbacks=callbacks,
validation_data=validation_data,
validation_steps=validation_steps,
validation_freq=validation_freq,
class_weight=class_weight,
max_queue_size=max_queue_size,
workers=workers,
use_multiprocessing=use_multiprocessing,
shuffle=shuffle,
initial_epoch=initial_epoch)
if training_utils.is_eager_dataset_or_iterator(x):
# Make sure that y, sample_weights, validation_split are not passed.
training_utils.validate_dataset_input(x, y, sample_weight,
validation_split)
if (isinstance(x, (dataset_ops.DatasetV1, dataset_ops.DatasetV2))
and shuffle):
training_utils.verify_dataset_shuffled(x)
return self.fit_generator(
x,
steps_per_epoch=steps_per_epoch,
epochs=epochs,
verbose=verbose,
callbacks=callbacks,
validation_data=validation_data,
validation_steps=validation_steps,
validation_freq=validation_freq,
class_weight=class_weight,
workers=0,
shuffle=shuffle,
initial_epoch=initial_epoch)
# Case 3: Symbolic tensors or Numpy array-like.
# This includes Datasets and iterators in graph mode (since they
# generate symbolic tensors).
x, y, sample_weights = self._standardize_user_data(
x,
y,
sample_weight=sample_weight,
class_weight=class_weight,
batch_size=batch_size,
check_steps=True,
steps_name='steps_per_epoch',
steps=steps_per_epoch,
validation_split=validation_split,
shuffle=shuffle)
# Prepare validation data.
if validation_data:
val_x, val_y, val_sample_weights = self._unpack_validation_data(
validation_data)
val_x, val_y, val_sample_weights = self._standardize_user_data(
val_x,
val_y,
sample_weight=val_sample_weights,
batch_size=batch_size,
steps=validation_steps,
steps_name='validation_steps')
elif validation_split and 0. < validation_split < 1.:
if training_utils.has_symbolic_tensors(x):
raise ValueError('If your data is in the form of symbolic tensors, '
'you cannot use `validation_split`.')
if hasattr(x[0], 'shape'):
split_at = int(x[0].shape[0] * (1. - validation_split))
else:
split_at = int(len(x[0]) * (1. - validation_split))
x, val_x = (slice_arrays(x, 0, split_at), slice_arrays(x, split_at))
y, val_y = (slice_arrays(y, 0, split_at), slice_arrays(y, split_at))
sample_weights, val_sample_weights = (slice_arrays(
sample_weights, 0, split_at), slice_arrays(sample_weights, split_at))
elif validation_steps:
val_x = []
val_y = []
val_sample_weights = []
else:
val_x = None
val_y = None
val_sample_weights = None
if self.run_eagerly:
return training_generator.fit_generator(
self, (x, y, sample_weights),
steps_per_epoch=steps_per_epoch,
batch_size=batch_size,
epochs=epochs,
verbose=verbose,
callbacks=callbacks,
validation_data=validation_data,
validation_steps=validation_steps,
validation_freq=validation_freq,
workers=0,
shuffle=shuffle,
initial_epoch=initial_epoch,
steps_name='steps_per_epoch')
else:
return training_arrays.fit_loop(
self,
x,
y,
sample_weights=sample_weights,
batch_size=batch_size,
epochs=epochs,
verbose=verbose,
callbacks=callbacks,
val_inputs=val_x,
val_targets=val_y,
val_sample_weights=val_sample_weights,
shuffle=shuffle,
initial_epoch=initial_epoch,
steps_per_epoch=steps_per_epoch,
validation_steps=validation_steps,
validation_freq=validation_freq,
steps_name='steps_per_epoch')
def evaluate(self,
x=None,
y=None,
batch_size=None,
verbose=1,
sample_weight=None,
steps=None,
callbacks=None,
max_queue_size=10,
workers=1,
use_multiprocessing=False):
"""Returns the loss value & metrics values for the model in test mode.
Computation is done in batches.
Arguments:
x: Input data. It could be:
- A Numpy array (or array-like), or a list of arrays
(in case the model has multiple inputs).
- A TensorFlow tensor, or a list of tensors
(in case the model has multiple inputs).
- A dict mapping input names to the corresponding array/tensors,
if the model has named inputs.
- A `tf.data` dataset or a dataset iterator.
- A generator or `keras.utils.Sequence` instance.
y: Target data. Like the input data `x`,
it could be either Numpy array(s) or TensorFlow tensor(s).
It should be consistent with `x` (you cannot have Numpy inputs and
tensor targets, or inversely).
If `x` is a dataset, dataset iterator, generator or
`keras.utils.Sequence` instance, `y` should not be specified (since
targets will be obtained from the iterator/dataset).
batch_size: Integer or `None`.