/
keras_support.py
2259 lines (1908 loc) · 85.5 KB
/
keras_support.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""*Experimental* support for running Keras models on the TPU.
To use, wrap your model with the `keras_support.tpu_model` function.
Example usage:
```
image = tf.keras.layers.Input(shape=(28, 28, 3), name='image')
c1 = tf.keras.layers.Conv2D(filters=16, kernel_size=(3, 3))(image)
flattened = tf.keras.layers.Flatten()(c1)
logits = tf.keras.layers.Dense(10, activation='softmax')(flattened)
model = tf.keras.Model(inputs=[image], outputs=[logits])
resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu=tpu_name)
strategy = keras_support.TPUDistributionStrategy(resolver)
model = keras_support.tpu_model(model, strategy=strategy)
# Only TF optimizers are currently supported.
model.compile(optimizer=tf.compat.v1.train.AdamOptimizer(), ...)
# `images` and `labels` should be Numpy arrays. Support for tensor input
# (e.g. datasets) is planned.
model.fit(images, labels)
```
"""
# pylint: disable=protected-access
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
import collections
import contextlib
import re
import sys
import time
import numpy as np
import six
from tensorflow.contrib.cluster_resolver.python.training import tpu_cluster_resolver as tpu_cluster_resolver_lib
from tensorflow.contrib.tpu.python.ops import tpu_ops
from tensorflow.contrib.tpu.python.tpu import keras_tpu_variables
from tensorflow.contrib.tpu.python.tpu import tpu
from tensorflow.contrib.tpu.python.tpu import tpu_function
from tensorflow.contrib.tpu.python.tpu import tpu_optimizer
from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf.tpu import compilation_result_pb2 as tpu_compilation_result
from tensorflow.python import tf2
from tensorflow.python.client import session as tf_session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import context
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import callbacks as cbks
from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.keras import models
from tensorflow.python.keras import optimizers as keras_optimizers
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.engine import training_arrays
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.layers import embeddings
from tensorflow.python.keras.utils.generic_utils import make_batches
from tensorflow.python.keras.utils.generic_utils import slice_arrays
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.deprecation import deprecated
# TODO(b/114775106): temporary shim to optionally initialize the TPU
# This increases the odds our session is initialized, but shouldn't be needed.
# Cached result of `tpu.rewrite` for the probe computation below; reused while
# it still belongs to the current default graph.
_TEST_REWRITE_OP = None


def _maybe_initialize_tpu(session):
  """Initializes the TPU system unless a trivial TPU program already runs.

  A tiny `1 + 1` computation is rewritten for the TPU and executed in
  `session`. If running it raises `FailedPreconditionError`, the TPU is
  treated as not yet initialized and `tpu.initialize_system()` is run instead.

  Args:
    session: The `tf.Session` used to run the probe and, if needed, the
      initialization op.
  """
  global _TEST_REWRITE_OP
  try:
    # Reuse the cached probe to avoid another round of graph rewriting; it is
    # only rebuilt when missing or when the default graph has changed.
    probe = _TEST_REWRITE_OP
    needs_rebuild = (probe is None or
                     probe[0].graph != ops.get_default_graph())
    if needs_rebuild:

      def test_op():
        return constant_op.constant(1) + constant_op.constant(1)

      probe = tpu.rewrite(test_op)
      _TEST_REWRITE_OP = probe
    session.run(probe)
  except errors.FailedPreconditionError:
    session.run(tpu.initialize_system())
@contextlib.contextmanager
def _tpu_session_context():
  """Context manager ensuring the Keras session has an initialized TPU.

  On entry, `_maybe_initialize_tpu` is run against `K.get_session()`. If a
  `FailedPreconditionError` or `AbortedError` escapes either that step or the
  body of the `with` statement, the Keras session is cleared and a generic
  `Exception` carrying the original error text is raised.

  Raises:
    Exception: If connecting to or initializing the TPU fails.
  """
  try:
    keras_session = K.get_session()
    _maybe_initialize_tpu(keras_session)
    yield
  except (errors.FailedPreconditionError, errors.AbortedError) as err:
    # Drop the (possibly broken) session so a fresh one is created next time.
    K.clear_session()
    message = """
An error occurred connecting or initializing your TPU.
The session has been reset. re-run keras_to_tpu_model to create a new session.
""" + str(err)
    raise Exception(message)
def setup_tpu_session(cluster_resolver):
  """Connects the Keras backend session to the TPU given by `cluster_resolver`.

  If the current Keras session already targets the resolver's master and is
  marked `_tpu_initialized`, nothing is done. Otherwise, this function:

  1. Creates a new `tf.Session` for the cluster (with
     `isolate_session_state=True`).
  2. Runs `tpu.initialize_system()` in it.
  3. Marks it as initialized.
  4. Installs it with `K.set_session`.

  Args:
    cluster_resolver: Resolver providing `master()` and `cluster_spec()` for
      the TPU cluster.

  Returns:
    None. The session is made available through the Keras backend.
  """
  master = cluster_resolver.master()

  # Reuse the existing session if it is already connected to this TPU.
  # N.B. K.get_session() is a non-trivial operation, and may fail if the remote
  # session has been reset.
  try:
    existing = K.get_session()
    same_target = existing._target == master
    if same_target and getattr(existing, '_tpu_initialized', None):
      return
  except errors.AbortedError:
    # The remote session is gone; fall through and build a fresh one.
    logging.warning('Lost remote session: creating a new session.')

  cluster_spec = cluster_resolver.cluster_spec()
  config = config_pb2.ConfigProto(isolate_session_state=True)
  if cluster_spec:
    config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())

  new_session = tf_session.Session(target=master, config=config)
  new_session.run(tpu.initialize_system())
  # Marker checked above so that later calls can reuse this session.
  new_session._tpu_initialized = True
  # NOTE(review): an earlier note said the session must also be made the TF
  # default because `K.get_session()` may otherwise not return it, but only
  # `K.set_session` is called here -- confirm whether that is still needed.
  K.set_session(new_session)
try:
from scipy.sparse import issparse # pylint: disable=g-import-not-at-top
except ImportError:
issparse = None
def get_tpu_system_metadata(tpu_cluster_resolver):
  """Queries TPU system metadata for the cluster behind a TPUClusterResolver.

  Args:
    tpu_cluster_resolver: Resolver supplying the master address and, when
      available, the cluster spec.

  Returns:
    The metadata returned by `tpu_system_metadata_lib._query_tpu_system_metadata`
    (queried with `query_topology=False`).
  """
  master = tpu_cluster_resolver.master()
  spec = tpu_cluster_resolver.cluster_spec()
  cluster_def = None if not spec else spec.as_cluster_def()
  # pylint: disable=protected-access
  return tpu_system_metadata_lib._query_tpu_system_metadata(
      master, cluster_def=cluster_def, query_topology=False)
class TPUDistributionStrategy(object):
  """The strategy to run Keras model on TPU."""

  def __init__(self, tpu_cluster_resolver=None, using_single_core=False):
    """Construct a TPUDistributionStrategy.

    Args:
      tpu_cluster_resolver: Any instance of `TPUClusterResolver`. If None, one
        is created with '' as the master address.
      using_single_core: Bool. Debugging option that may be removed once model
        replication is mature. If `False` (default), the core count is taken
        from the TPU system metadata, typically all available cores. If
        `True`, the model is forced onto a single core, i.e. no replication.

    Raises:
      RuntimeError: If TF 2.x behavior is enabled.
      Exception: No TPU Found on the given worker.
    """
    deprecation_message = (
        'Keras support is now deprecated in support of TPU Strategy. '
        'Please follow the distribution strategy guide on tensorflow.org '
        'to migrate to the 2.0 supported version.')
    if tf2.enabled():
      raise RuntimeError(deprecation_message)
    logging.warning(deprecation_message)

    resolver = tpu_cluster_resolver
    if resolver is None:
      resolver = tpu_cluster_resolver_lib.TPUClusterResolver('')
    metadata = get_tpu_system_metadata(resolver)
    self._tpu_metadata = metadata
    self._tpu_cluster_resolver = resolver
    self._num_cores = metadata.num_cores if not using_single_core else 1

    # The first device whose name contains 'TPU:0' identifies the worker job
    # used for enqueue/dequeue operations.
    job_pattern = re.compile('/job:([^/]+)')
    first_tpu = next(
        (device for device in metadata.devices if 'TPU:0' in device.name),
        None)
    if first_tpu is None:
      raise Exception('No TPU found on given worker.')
    self._worker_name = job_pattern.search(first_tpu.name).group(1)

  def _make_assignment_for_model(self, cpu_model):
    """Makes a `TPUAssignment` for the passed in `cpu_model`.

    Stateful models are not supported by replication; they are degraded to a
    single core with a warning.
    """
    cores = self._num_cores
    if cores > 1 and cpu_model.stateful:
      logging.warning(
          'Model replication does not currently support stateful models. '
          'Degrading to a single core.')
      cores = 1
    return TPUAssignment(worker_name=self._worker_name, num_cores=cores)
class TPUAssignment(object):
  """Holds the TPU resources assigned to one concrete model.

  Instances are created by `TPUDistributionStrategy`, which may adjust the
  number of cores for the model (for example, degrading stateful models to a
  single core).
  """

  def __init__(self, worker_name, num_cores):
    self._worker_name, self._num_cores = worker_name, num_cores

  @property
  def worker_name(self):
    """Name of the TPU worker job that hosts the model."""
    return self._worker_name

  @property
  def num_towers(self):
    """Number of TPU cores (replicas) the model is replicated over."""
    # TODO(xiejw): Support automatically assign num_cores based on inputs.
    return self._num_cores
class TPUEmbedding(embeddings.Embedding):
  """TPU compatible embedding layer.

  The default Keras layer is not TPU compatible. This layer is a drop-in
  replacement: it has the same behavior and will work on CPU and GPU devices.
  The lookup is computed as a one-hot encoding of the (int32) indices
  contracted with the embedding matrix.
  """

  def build(self, input_shape):
    # NOTE(review): `input_shape[0]` is normally the batch dimension while the
    # error text mentions `input_length`; confirm which dimension is intended.
    if input_shape[0] is None:
      raise ValueError(
          'TPUEmbeddings must have a fixed input_length or input shape.')
    return super(TPUEmbedding, self).build(input_shape)

  def call(self, inputs):
    indices = inputs
    if K.dtype(indices) != 'int32':
      indices = math_ops.cast(indices, 'int32')
    one_hot = array_ops.one_hot(indices, self.input_dim)
    # Contract the trailing one-hot axis with the first axis of the embedding
    # matrix, which selects the corresponding embedding rows.
    return math_ops.tensordot(one_hot, self.embeddings, 1)
def _cross_replica_concat(tensor, core_id, num_cores, name):
  """Concatenate `tensor` across cores.

  Each core places its local `tensor` into its own slot of a
  `[num_cores * batch, ...]` buffer (all other slots are zero), and a
  `cross_replica_sum` then combines the slots from every core.

  Args:
    tensor: The tensor to be concatenated. Must be bfloat16, float32 or int32.
    core_id: Tensor indicating the current TPU core.
    num_cores: Python int. The total number of TPU cores in the system.
    name: The string name to print for debugging.

  Returns:
    The same concatenated Tensor on each core, with the dtype of `tensor`.

  Raises:
    TypeError: If `tensor` has an unsupported dtype.
  """
  input_dtype = tensor.dtype
  if input_dtype not in [dtypes.bfloat16, dtypes.float32, dtypes.int32]:
    raise TypeError('For model replication, only (bfloat16, float32 and int32) '
                    'is supported for model outputs and targets. Got {} for '
                    '{}.'.format(input_dtype, name))

  batch_size = tensor.shape[0]
  # One-hot mask over cores, reshaped to [num_cores, 1, ..., 1] so it
  # broadcasts against `tensor` and zeroes every slot except this core's.
  mask = math_ops.cast(
      math_ops.equal(np.arange(num_cores, dtype=np.int32), core_id),
      dtypes.float32)
  mask = array_ops.reshape(mask, [num_cores] + [1] * tensor.shape.ndims)
  # The reduction runs in float32. NOTE(review): int32 magnitudes above 2**24
  # are not exactly representable in float32.
  result = mask * math_ops.cast(tensor, dtypes.float32)
  # Fold the core axis into the batch axis: [num_cores * batch, ...].
  local_tensor_with_holes = array_ops.reshape(result,
                                              [-1] + result.shape.as_list()[2:])
  concat_tensor = tpu_ops.cross_replica_sum(local_tensor_with_holes)
  concat_tensor.set_shape((num_cores * batch_size,) + tuple(tensor.shape[1:]))

  # Compare the dtype, not the tensor: `concat_tensor != input_dtype` never
  # inspects the dtype, so the old check re-cast unconditionally.
  if concat_tensor.dtype != input_dtype:
    concat_tensor = math_ops.cast(concat_tensor, input_dtype)
  return concat_tensor
class KerasCrossShardOptimizer(keras_optimizers.Optimizer):
  """An optimizer that averages gradients across TPU shards.

  Wraps a Keras optimizer. Gradients are summed across replicas with
  `cross_replica_sum` and divided by the number of shards before the wrapped
  optimizer applies them. Unknown attributes are forwarded to the wrapped
  optimizer.
  """

  def __init__(self, opt, name='KerasCrossShardOptimizer'):
    """Construct a new cross-shard optimizer.

    Args:
      opt: An existing `Optimizer` to encapsulate.
      name: Optional name prefix for the operations created when applying
        gradients. Defaults to "KerasCrossShardOptimizer".
    """
    super(KerasCrossShardOptimizer, self).__init__()
    self._name = name
    self._opt = opt
    logging.info('KerasCrossShard: %s %s', self._opt, self._opt.weights)

  def get_updates(self, loss, params):
    """Returns the wrapped optimizer's updates using cross-shard gradients."""
    # Route the wrapped optimizer's gradient computation through this class so
    # gradients are averaged across shards before being applied. Note this
    # permanently replaces `get_gradients` on the wrapped optimizer instance.
    self._opt.get_gradients = self.get_gradients
    return self._opt.get_updates(loss, params)

  def get_gradients(self, loss, params):
    """Returns gradients summed across replicas and divided by shard count."""
    num_shards = tpu_function.get_tpu_context().number_of_shards
    grads = super(KerasCrossShardOptimizer, self).get_gradients(loss, params)
    return [tpu_ops.cross_replica_sum(grad) / num_shards for grad in grads]

  def get_weights(self):
    return self._opt.get_weights()

  def get_config(self):
    return self._opt.get_config()

  # Defer remaining operations to the underlying optimizer
  def __getattr__(self, key):
    # `__getattr__` only runs when normal lookup fails. If `_opt` itself is
    # missing (e.g. on an instance built without `__init__`, as copy/pickle
    # do), forwarding would recurse forever; raise AttributeError instead.
    if key == '_opt':
      raise AttributeError(key)
    return getattr(self._opt, key)
class TPUModelOp(
    collections.namedtuple(
        'TPUModelOp',
        'compile_op execute_op infeed_tensors infeed_op outfeed_op')):
  """Named tuple grouping the ops of one compiled TPU model instance.

  `infeed_tensors` holds one entry per shard: the tensors that
  `TPUInfeedInstance.make_feed_dict` feeds input values into. The remaining
  fields hold the compile, execute, infeed and outfeed ops for the model.
  """
def _valid_name(tensor_name):
"""Return a valid tensor name (strips '/', ':', etc)."""
return re.sub('[^a-zA-Z0-9_-]+', '', tensor_name)
def _replicated_optimizer(opt):
  """Wraps `opt` so that gradients are aggregated across TPU shards.

  The optimizer is always wrapped, even on a single core. This ensures Keras
  properly tracks and initializes optimizer variables.

  Args:
    opt: A Keras optimizer. For a `TFOptimizer`, its underlying TF optimizer
      is the one that gets wrapped.

  Returns:
    A `tpu_optimizer.CrossShardOptimizer` for `TFOptimizer` inputs, otherwise
    a `KerasCrossShardOptimizer`.
  """
  if not isinstance(opt, keras_optimizers.TFOptimizer):
    return KerasCrossShardOptimizer(opt)
  return tpu_optimizer.CrossShardOptimizer(opt.optimizer)
def _clone_optimizer(optimizer, config=None, worker_name=None):
  """Returns a cloned optimizer with the provided optimizer.config or config.

  Args:
    optimizer: The optimizer to clone. Anything that is not a Keras
      `Optimizer` is returned unchanged.
    config: Optional config used instead of `optimizer.get_config()`. Not used
      for `TFOptimizer`, which is re-wrapped around the same TF optimizer.
    worker_name: Optional worker job name. When given, the clone is built
      under that job's CPU device; otherwise under `/device:CPU:0`.

  Returns:
    The original object, a new `TFOptimizer`, or a new instance of the
    optimizer's class built from the config.
  """
  if not isinstance(optimizer, keras_optimizers.Optimizer):
    # In the first call to tpu_model(model), Keras may not have wrapped the TF
    # optimizer in the TFOptimizer helper, e.g., the given model isn't compiled
    # or optimizer isn't set, and later generated tpu_model compiles with a TF
    # optimizer.
    return optimizer
  if isinstance(optimizer, keras_optimizers.TFOptimizer):
    return keras_optimizers.TFOptimizer(optimizer.optimizer)

  clone_config = optimizer.get_config() if config is None else config
  logging.info('Cloning %s %s', optimizer.__class__.__name__, clone_config)
  job_prefix = '/job:%s' % worker_name if worker_name else ''
  # Explicitly put optimizer parameter variables on TPU worker.
  with ops.device('%s/device:CPU:0' % job_prefix):
    return optimizer.__class__.from_config(clone_config)
class TPURewriteContext(object):
  """Prepare the environment for a Keras model during `tpu.rewrite`.

  This overrides the default placeholder behaviour to instead refer to a preset
  input mapping. Placeholders are unsupported in TPU compiled code, and must
  be replaced with explicit inputs or values from the infeed queue.

  Instead of explicitly threading inputs all the way through the Keras codebase,
  we override the behavior of the placeholder while compiling and inject the
  Tensors from the infeed in place of the placeholder.

  Similarly, as we compile a new sub-graph for each unique shape and execution
  mode, we need to override the behavior of an embedded `name_scope` call in
  the base Keras layer code. This allows us to re-use the same weights across
  many compiles and share a single session/graph.

  While the context is active, these module attributes are replaced:

  - `array_ops.placeholder`
  - `ops.name_scope`
  - `base_layer_utils.make_variable`
  - `random_ops.random_normal`
  - `gen_linalg_ops.qr`

  `__exit__` restores the originals saved by `__enter__`. Because these are
  module-level assignments, any code using those modules sees the overrides
  while the context is active. `__enter__` returns None.
  """

  def __init__(self, input_map):
    """Stores the mapping used to replace placeholders.

    Args:
      input_map: Dict from placeholder name to the tensor that should be
        returned instead of creating a placeholder with that name.
    """
    self._input_map = input_map
    # Originals are saved in __enter__. NOTE(review): only these two are
    # initialized here; `_default_make_variable`, `_default_random_normal` and
    # `_default_qr` exist only after __enter__, so __exit__ without a prior
    # __enter__ raises AttributeError.
    self._default_placeholder = None
    self._default_name_scope = None

  def __enter__(self):
    """Saves the original module attributes and installs the overrides."""

    def _placeholder(dtype, shape=None, name=None):  # pylint: disable=unused-argument
      # Return the mapped tensor when `name` is in the input map; otherwise
      # fall back to the original `array_ops.placeholder`.
      logging.info('Remapping placeholder for %s', name)
      if name in self._input_map:
        return self._input_map[name]
      else:
        logging.info('Default: %s', name)
        return self._default_placeholder(dtype, shape, name)

    def _name_scope(name, default_name=None, values=None):
      # Look at the immediate caller's frame: when its local `self` is a Keras
      # `base_layer.Layer` and `name` is given, open a variable scope with
      # AUTO_REUSE so repeated compiles share that layer's variables.
      # This relies on `_name_scope` being called directly by that code;
      # adding an intermediate wrapper would change `f_back`.
      caller_frame = sys._getframe().f_back
      caller_obj = caller_frame.f_locals.get('self')
      if (caller_obj is not None and
          isinstance(caller_obj, base_layer.Layer) and name is not None):
        return variable_scope.variable_scope(
            name, default_name, values, reuse=variable_scope.AUTO_REUSE)

      return self._default_name_scope(name, default_name, values)

    # Save every original so __exit__ can restore it.
    self._default_placeholder = array_ops.placeholder
    self._default_name_scope = ops.name_scope
    self._default_make_variable = base_layer_utils.make_variable
    self._default_random_normal = random_ops.random_normal
    self._default_qr = gen_linalg_ops.qr

    array_ops.placeholder = _placeholder

    # Replace random_ops.random_normal with a dummy function because
    # `random_normal` isn't yet implemented on the TPU. Because these
    # initialized values are overwritten by the CPU values, this is okay.
    def random_normal(shape,
                      mean=0.0,
                      stddev=1.0,
                      dtype=dtypes.float32,
                      seed=None,
                      name=None):
      # Returns zeros of the requested shape and dtype; mean, stddev and seed
      # are ignored.
      del mean
      del stddev
      del seed
      return array_ops.zeros(shape, dtype=dtype, name=name)

    random_ops.random_normal = random_normal

    # Replace gen_linalg_ops.qr because QR decomposition is not yet implemented.
    # TODO(saeta): Remove qr override once we confirm the qr implementation is
    # ok.
    # pylint: disable=redefined-builtin
    def qr(input, full_matrices=False, name=None):
      """Dummy implementation of qr decomposition.

      Returns a pair of zero tensors, created with `array_ops.zeros`' default
      dtype rather than `input.dtype`:

      - The first has `input`'s shape.
      - The second is [p, p], or [b, p, p] for rank-3 input with
        b = input_shape[0], where p = min of the last two dimensions.

      Any other rank raises ValueError.

      NOTE(review): the local names `q` and `r` look swapped. For an [m, k]
      input, a reduced QR has Q of shape [m, p] and R of shape [p, k]. The
      pair returned here matches that (Q, R) order only when m >= k. Check
      the callers before renaming or reordering.
      """
      del full_matrices  # TODO(saeta): Properly handle the full matrix case.
      input_shape = input.shape
      if len(input_shape) < 2:
        raise ValueError('Invalid shape passed to qr: %s' % input_shape)
      p = min(input_shape[-1], input_shape[-2])
      if len(input_shape) == 2:
        q = array_ops.zeros((p, p), name=name)
        r = array_ops.zeros(input_shape, name=name)
        return (r, q)
      elif len(input_shape) == 3:
        n = input_shape[0]
        q = array_ops.zeros((n, p, p), name=name)
        r = array_ops.zeros(input_shape, name=name)
        return (r, q)
      else:
        raise ValueError('Invalid shape passed to qr: %s' % input_shape)

    gen_linalg_ops.qr = qr

    ops.name_scope = _name_scope
    # Presumably routes Keras variable creation through `get_variable` so the
    # AUTO_REUSE variable scopes opened by `_name_scope` take effect -- confirm.
    base_layer_utils.make_variable = variable_scope.get_variable
    logging.info('Overriding default placeholder.')
    return

  def __exit__(self, exc_type, exc_val, exc_tb):
    """Restores the originals saved in __enter__; exceptions propagate."""
    array_ops.placeholder = self._default_placeholder
    ops.name_scope = self._default_name_scope
    base_layer_utils.make_variable = self._default_make_variable
    random_ops.random_normal = self._default_random_normal
    gen_linalg_ops.qr = self._default_qr
class SizedInfeed(
    collections.namedtuple('SizedInfeed',
                           'sharded_infeed_tensors infeed_ops')):
  """Represents an instantiation of the infeed ops for a concrete input shape.

  Fields:
    sharded_infeed_tensors: One entry per shard holding the tensors that feed
      that shard's infeed. For NumPy input these are placeholders that must be
      fed via feed_dict; for datasets they are the iterator outputs.
    infeed_ops: The set of ops that will be run to drive infeed for a single
      step.
  """
@six.add_metaclass(abc.ABCMeta)
class TPUInfeedInstance(object):
  """TPUInfeedInstance represents the logic to manage feeding in a single step.

  See the comments on the `TPUInfeedManager` for a description for how infeed
  is managed.

  This is an abstract base: like `TPUInfeedManager`, it uses `abc.ABCMeta` so
  that `@abc.abstractmethod` is actually enforced on subclasses.
  """

  @abc.abstractmethod
  def make_input_specs(self, input_tensors):
    """Constructs the infeed_specs for the given Infeed instance.

    Args:
      input_tensors: The inputs to the model.

    Returns:
      A list of `tf.TensorSpec`s describing the shape and dtype of each infeed
      input for a single shard.
    """
    pass

  @abc.abstractmethod
  def make_feed_dict(self, tpu_model_op):
    """Constructs a feed_dict for this instance, given the tpu_model_op.

    Args:
      tpu_model_op: A `TPUModelOp` representing the TPU Model for this
        instance's input spec.

    Returns:
      A dictionary to use as the feed_dict of a `session.run` call.
    """
    pass
@six.add_metaclass(abc.ABCMeta)
class TPUInfeedManager(object):
  """Abstract interface for feeding data into a TPU computation.

  Because there are multiple data sources (e.g. in-memory NumPy arrays,
  `tf.data.Dataset`s), the different logic is hidden behind a single
  interface:

  (1) A `TPUFunction` is called with a set of inputs. Based on the inputs,
  `TPUFunction` retrieves the corresponding `TPUInfeedManager` (or constructs a
  new one if required).

  (2) The `TPUFunction` calls `make_infeed_instance` on the `TPUInfeedManager`
  which returns a `TPUInfeedInstance`.

  (3) The `TPUFunction` checks in the shape cache for a pre-compiled instance of
  the model based on the returned `input_specs` from `TPUInfeedInstance`.

  (4) [Optional.] If the model has not already been instantiated for the given
  input spec, the `TPUFunction` compiles the model for the input spec (using the
  `TPUInfeedManager`).

  (5) The `TPUInfeedInstance` constructs the session.run's feed_dict given the
  compiled model instance corresponding to its shape.
  """

  @abc.abstractmethod
  def make_infeed_instance(self, inputs):
    """Given a single step's input, construct a `TPUInfeedInstance`.

    Args:
      inputs: The inputs to a given step.

    Returns:
      A subclass of `TPUInfeedInstance`.
    """

  @abc.abstractmethod
  def build_infeed_from_input_specs(self, input_specs, execution_mode):
    """For a given input specification (size, type), construct the infeed ops.

    This is called only once for a given input specification and builds the
    graph ops. It does not have a pointer to the actual infeed data.

    Args:
      input_specs: List of `tf.TensorSpec`s describing one shard's inputs, as
        produced by `TPUInfeedInstance.make_input_specs`.
      execution_mode: Identifier of the execution mode; the implementations in
        this module embed it in the names of the infeed enqueue ops.

    Returns:
      A `SizedInfeed` instance.
    """
class TPUNumpyInfeedManager(TPUInfeedManager):
  """TPU Infeed manager for Numpy inputs."""

  class NumpyInfeedInstance(TPUInfeedInstance):
    """Infeed instance for Numpy inputs."""

    def __init__(self, sharded_inputs):
      self._sharded_inputs = sharded_inputs

    def make_input_specs(self, input_tensors):
      # Input specs drive the generation of infeed enqueue/dequeue ops. The
      # shape comes from the input array but the dtype from the model: a
      # float64 array fed to a float32 input must still yield a float32
      # infeed. All replicas share type and shape, so the first shard is
      # representative.
      first_shard = self._sharded_inputs[0]
      return [
          tensor_spec.TensorSpec(array.shape, tensor.dtype,
                                 _valid_name(tensor.name))
          for tensor, array in zip(input_tensors, first_shard)
      ]

    def make_feed_dict(self, tpu_model_op):
      # Pair each shard's infeed placeholders with that shard's arrays.
      feed = {}
      for shard_tensors, shard_values in zip(tpu_model_op.infeed_tensors,
                                             self._sharded_inputs):
        feed.update(zip(shard_tensors, shard_values))
      return feed

  def __init__(self, tpu_assignment):
    self._tpu_assignment = tpu_assignment

  def _split_tensors(self, inputs):
    """Split input data across shards.

    Each input is sliced along the batch axis into `num_towers` equal pieces.

    Args:
      inputs: List of Numpy arrays to run on the TPU.

    Returns:
      List of lists containing the input to feed to each TPU shard.
    """
    num_towers = self._tpu_assignment.num_towers
    if num_towers == 1:
      return [inputs]

    batch_size = inputs[0].shape[0]
    assert batch_size % num_towers == 0, (
        'batch_size must be divisible by the number of TPU cores in use (%s '
        'vs %s)' % (batch_size, num_towers))
    shard_size = batch_size // num_towers
    return [[x[tower * shard_size:(tower + 1) * shard_size] for x in inputs]
            for tower in range(num_towers)]

  def make_infeed_instance(self, inputs):
    return self.NumpyInfeedInstance(self._split_tensors(inputs))

  def build_infeed_from_input_specs(self, input_specs, execution_mode):
    worker_device = '/job:%s/device:CPU:0' % self._tpu_assignment.worker_name
    infeed_ops = []
    sharded_placeholders = []
    for shard_id in range(self._tpu_assignment.num_towers):
      with ops.device(worker_device):
        # One placeholder per input, created under this shard's TPU device.
        with ops.device('/device:TPU:%d' % shard_id):
          placeholders = [
              array_ops.placeholder(
                  dtype=spec.dtype,
                  shape=spec.shape,
                  name='infeed-enqueue-%s-%d' % (spec.name, shard_id))
              for spec in input_specs
          ]
        sharded_placeholders.append(placeholders)
        infeed_ops.append(
            tpu_ops.infeed_enqueue_tuple(
                placeholders, [spec.shape for spec in input_specs],
                name='infeed-enqueue-%s-%d' % (execution_mode, shard_id),
                device_ordinal=shard_id))
    return SizedInfeed(
        infeed_ops=infeed_ops, sharded_infeed_tensors=sharded_placeholders)
class TPUDatasetInfeedManager(TPUInfeedManager):
"""Manages infeed for a `tf.data.Dataset` into a TPU computation.
"""
class DatasetInfeedInstance(TPUInfeedInstance):
"""An instance of the TPU infeed."""
def __init__(self, input_specs):
self._input_specs = input_specs
def make_input_specs(self, input_tensors):
# TODO(saeta): Do error checking here!
return self._input_specs
def make_feed_dict(self, tpu_model_op):
# TODO(saeta): Verify tpu_model_op is as expected!
return {}
# pylint: disable=redefined-outer-name
def __init__(self, dataset, tpu_assignment, mode):
"""Constructs a TPUDatasetInfeedManager.
Args:
dataset: A `tf.data.Dataset` to infeed.
tpu_assignment: The `TPUAssignment` used to configure the
Keras TPU model.
mode: ModeKeys enum.
"""
self._verify_dataset_shape(dataset)
self._dataset = dataset
self._tpu_assignment = tpu_assignment
dataset_output_shapes = dataset_ops.get_legacy_output_shapes(dataset)
dummy_x_shape = dataset_output_shapes[0].as_list()
dummy_x_shape[0] *= tpu_assignment.num_towers
dummy_y_shape = dataset_output_shapes[1].as_list()
dummy_y_shape[0] *= tpu_assignment.num_towers
self._iterator = dataset_ops.make_initializable_iterator(dataset)
K.get_session().run(self._iterator.initializer)
self._get_next_ops = []
ctrl_deps = []
for i in range(tpu_assignment.num_towers):
with ops.control_dependencies(ctrl_deps): # Ensure deterministic
# TODO(saeta): Ensure correct placement!
get_next_op = self._iterator.get_next()
self._get_next_ops.append(get_next_op)
ctrl_deps.extend(get_next_op)
# Use dummy numpy inputs for the rest of Keras' shape checking. We
# intercept them when building the model.
dataset_output_types = dataset_ops.get_legacy_output_types(dataset)
self._dummy_x = np.zeros(
dummy_x_shape, dtype=dataset_output_types[0].as_numpy_dtype)
self._dummy_y = np.zeros(
dummy_y_shape, dtype=dataset_output_types[1].as_numpy_dtype)
input_specs = []
iterator_output_shapes = dataset_ops.get_legacy_output_shapes(
self._iterator)
iterator_output_types = dataset_ops.get_legacy_output_types(self._iterator)
if isinstance(iterator_output_shapes, tuple):
assert isinstance(iterator_output_types, tuple)
assert len(iterator_output_shapes) == len(iterator_output_types)
for i in range(len(iterator_output_shapes)):
spec = tensor_spec.TensorSpec(iterator_output_shapes[i],
iterator_output_types[i])
input_specs.append(spec)
elif isinstance(iterator_output_shapes, tensor_shape.TensorShape):
spec = tensor_spec.TensorSpec(iterator_output_shapes,
iterator_output_types)
input_specs.append(spec)
# Pre-process the inputs and get_next_ops before caching.
input_specs, self._get_next_ops = (
_inject_tpu_inputs_for_dataset(
tpu_assignment, mode, input_specs, self._get_next_ops))
self._infeed_instance = self.DatasetInfeedInstance(input_specs)
def _verify_dataset_shape(self, dataset):
"""Verifies a dataset is of an appropriate shape for TPUs."""
dataset_output_shapes = dataset_ops.get_legacy_output_shapes(dataset)
dataset_output_classes = dataset_ops.get_legacy_output_classes(dataset)
if not isinstance(dataset, dataset_ops.DatasetV2):
raise ValueError('The function passed as the `x` parameter did not '
'return a `tf.data.Dataset`.')
if not isinstance(dataset_output_classes, tuple):
raise ValueError('The dataset must return a tuple of tf.Tensors, '
'instead it returns: %s' % dataset_output_classes)
if len(dataset_output_classes) != 2:
raise ValueError('The dataset must return a 2-element tuple, got '
'%s output classes instead.' % (dataset_output_classes,))
for i, cls in enumerate(dataset_output_classes):
if cls != ops.Tensor:
raise ValueError('The dataset returned a non-Tensor type (%s) at '
'index %d.' % (cls, i))
for i, shape in enumerate(dataset_output_shapes):
if not shape:
raise ValueError('The dataset returns a scalar tensor in '
'tuple index %d. Did you forget to batch? '
'(Output shapes: %s).' % (i, dataset_output_shapes))
for j, dim in enumerate(shape):
if dim.value is None:
if j == 0:
hint = (' Hint: did you use `ds.batch(BATCH_SIZE, '
'drop_remainder=True)`?')
else:
hint = ''
raise ValueError(
'The Keras-TPU integration for `tf.data` '
'currently requires static shapes. The provided '
'dataset only has a partially defined shape. '
'(Dimension %d of output tensor %d is not statically known '
'for output shapes: %s.%s)' % (j, i, dataset_output_shapes, hint))
@property
def dummy_x(self):
  """Read-only accessor for the stored `_dummy_x` attribute."""
  stored = self._dummy_x
  return stored
@property
def dummy_y(self):
  """Read-only accessor for the stored `_dummy_y` attribute."""
  stored = self._dummy_y
  return stored
def make_infeed_instance(self, inputs):
  """Returns the infeed instance stored in `self._infeed_instance`.

  Args:
    inputs: currently ignored; the same cached instance is returned for every
      call.

  Returns:
    The object held in `self._infeed_instance`.
  """
  # TODO(saeta): Verify inputs is as expected.
  del inputs  # Unused until the verification above is implemented.
  return self._infeed_instance
def build_infeed_from_input_specs(self, input_specs, execution_mode):
  """Creates one infeed-enqueue op per TPU tower from the cached get_next ops.

  Args:
    input_specs: sequence of specs; each spec's `.shape` is passed to
      `infeed_enqueue_tuple` as the shape of the corresponding tensor.
    execution_mode: value interpolated into each enqueue op's name.

  Returns:
    A `SizedInfeed` whose `infeed_ops` are the enqueue ops (one per tower, in
    tower order) and whose `sharded_infeed_tensors` are `self._get_next_ops`.
  """
  per_tower_tensors = self._get_next_ops
  assignment = self._tpu_assignment
  tower_count = assignment.num_towers
  # Exactly one group of tensors is expected per tower.
  assert len(per_tower_tensors) == tower_count

  enqueue_ops = []
  for ordinal in range(tower_count):
    # Each enqueue op is placed on the worker's CPU and targets the TPU core
    # given by `device_ordinal`.
    host_cpu = '/job:%s/device:CPU:0' % assignment.worker_name
    with ops.device(host_cpu):
      enqueue_op = tpu_ops.infeed_enqueue_tuple(
          per_tower_tensors[ordinal],
          [spec.shape for spec in input_specs],
          name='infeed-enqueue-%s-%d' % (execution_mode, ordinal),
          device_ordinal=ordinal)
    enqueue_ops.append(enqueue_op)
  return SizedInfeed(
      infeed_ops=enqueue_ops, sharded_infeed_tensors=per_tower_tensors)
def _inject_tpu_inputs_for_dataset(tpu_assignment, mode,
                                   input_specs, get_next_ops):
  """Prepends a per-core core-id tensor to the set of dataset inputs.

  The core id is used during compilation to identify the current TPU core and
  enable concatenation operations across cores. It is only injected in TRAIN
  and EVAL modes; in any other mode the arguments are returned unchanged.

  Args:
    tpu_assignment: object exposing `num_towers`, the number of TPU cores.
    mode: a `model_fn_lib.ModeKeys` value.
    input_specs: sequence of `TensorSpec`s describing one core's inputs. The
      first dimension of the first spec is used as the per-core batch size, so
      it must be statically known.
    get_next_ops: sequence with one entry per core (`num_towers` entries), each
      a sequence of that core's input tensors.

  Returns:
    A `(input_specs, get_next_ops)` tuple. In TRAIN/EVAL mode both are new
    lists. An int32 `[per_core_batch_size]` spec is prepended to the specs. A
    constant filled with the core's index is prepended to each core's tensors.
    The arguments themselves are not modified.
  """
  if mode not in [model_fn_lib.ModeKeys.TRAIN, model_fn_lib.ModeKeys.EVAL]:
    return input_specs, get_next_ops

  # Dataset inputs operate on a per-core basis.
  per_core_batch_size = input_specs[0].shape.as_list()[0]
  assert len(get_next_ops) == tpu_assignment.num_towers

  # Build new lists instead of writing into `get_next_ops`, so the caller's
  # sequence is left untouched (and tuples are accepted as well as lists).
  new_get_next_ops = []
  for core_id, core_tensors in enumerate(get_next_ops):
    # Every element of this vector is the core's index.
    core_id_constant = constant_op.constant(
        np.full([per_core_batch_size], core_id, dtype=np.int32),
        dtype=dtypes.int32,
        name='core_id_constant')
    new_get_next_ops.append([core_id_constant] + list(core_tensors))

  # The matching spec goes at the head as well.
  new_input_specs = (
      [tensor_spec.TensorSpec([per_core_batch_size], dtypes.int32)] +
      list(input_specs))
  return new_input_specs, new_get_next_ops
def _inject_tpu_inputs_for_infeed(tpu_assignment, mode,
                                  core_id_place_holder, input_tensors, inputs):
  """Prepends the core-id placeholder and core-id values to the inputs.

  The core id is used during compilation to identify the current TPU core and
  enable concatenation operations across cores. It is only injected in TRAIN
  and EVAL modes; in any other mode the arguments are returned unchanged.

  Args:
    tpu_assignment: object exposing `num_towers`, the number of TPU cores.
    mode: a `model_fn_lib.ModeKeys` value.
    core_id_place_holder: tensor prepended to `input_tensors`.
    input_tensors: sequence of the model's input tensors.
    inputs: sequence of input arrays, parallel to `input_tensors`. The first
      dimension of `inputs[0]` is taken as the global batch size.

  Returns:
    A `(input_tensors, inputs)` tuple. In TRAIN/EVAL mode both are new lists,
    with the placeholder and an int32 core-id array prepended respectively.
  """
  if mode not in [model_fn_lib.ModeKeys.TRAIN, model_fn_lib.ModeKeys.EVAL]:
    return input_tensors, inputs

  # Put the placeholder at the head of the input tensors. `list(...)` lets
  # tuples be passed as well as lists.
  input_tensors = [core_id_place_holder] + list(input_tensors)

  # Now fill the core id. For `num_cores` = 2, `batch_size` = 8, we fill the
  # core id inputs as [0, 0, 0, 0, 1, 1, 1, 1], so each core sees its core id
  # (duplicated). int32 matches the core-id placeholder that TPUFunction
  # creates and the core-id constants used on the dataset path.
  num_cores = tpu_assignment.num_towers
  per_core_batch_size = inputs[0].shape[0] // num_cores
  core_ids = np.arange(num_cores, dtype=np.int32).repeat(per_core_batch_size)
  inputs = [core_ids] + list(inputs)
  return input_tensors, inputs
def _read_tpu_coreid_from_infeed(mode, infeed_tensors):
  """Splits the injected core id off the front of the dequeued infeed tensors.

  Args:
    mode: a `model_fn_lib.ModeKeys` value.
    infeed_tensors: sequence of dequeued tensors. In TRAIN/EVAL mode the first
      entry is expected to be the injected core-id vector.

  Returns:
    `(core_id, rest)`. In TRAIN/EVAL mode `core_id` is element 0 of the first
    tensor and `rest` holds the remaining tensors. In other modes the result
    is `(None, infeed_tensors)`.

  Raises:
    RuntimeError: in TRAIN/EVAL mode, if fewer than two tensors are present.
  """
  train_or_eval = mode in [model_fn_lib.ModeKeys.TRAIN,
                           model_fn_lib.ModeKeys.EVAL]
  if not train_or_eval:
    return None, infeed_tensors

  tensor_count = len(infeed_tensors)
  if tensor_count < 2:
    raise RuntimeError(
        'The infeed tensors on TPU core has only {} tensors. '
        'This is not expected. Please report a bug.\nTensors: {}'.format(
            tensor_count, infeed_tensors))

  # The injected core-id vector repeats the same id, so its first element
  # serves as the scalar id.
  scalar_core_id = infeed_tensors[0][0]
  remaining = infeed_tensors[1:]
  return scalar_core_id, remaining
class TPUFunction(object):
"""K.function compatible interface for invoking a TPU compiled function.
Recompilation is triggered on-demand for each set of new inputs shapes: the
results are cached for future execution. We expect most computations will
be dominated by a standard batch-size, followed by a straggler batch for
the end of training or evaluation.
All `inputs` and `outputs` will be loaded via the infeed and outfeed queues
instead of being injected as `feed_dict` items or fetches.
"""
def __init__(self, model, execution_mode, tpu_assignment):
  """Wraps `model` for execution on TPU in the given mode.

  Args:
    model: the Keras model; stored as `self.model`.
    execution_mode: a `model_fn_lib.ModeKeys` value (train, eval or predict).
    tpu_assignment: object describing the TPU placement; later code reads
      its `num_towers` and `worker_name` attributes.
  """
  self.model = model
  self.execution_mode = execution_mode
  self._tpu_assignment = tpu_assignment
  # Compiled computations are cached here and reused for repeated input
  # shapes.
  self._compilation_cache = {}
  # Clones of the model and optimizer; created later, during specialization.
  self._cloned_model = self._cloned_optimizer = None
  # A single int32 placeholder for the TPU core ID, built once so the graph
  # is not modified for every batch.
  self._core_id_place_holder = array_ops.placeholder(
      dtype=dtypes.int32, shape=[1], name='core_id')
def _specialize_model(self, input_specs, infeed_manager):
"""Specialize `self.model` (a Keras model) for the given input shapes."""
# Re-create our input and output layers inside our subgraph. They will be
# attached to the true computation when we clone our model in `tpu_fn`.
K.set_learning_phase(self.execution_mode == model_fn_lib.ModeKeys.TRAIN)
# functools.partial and callable objects are not supported by tpu.rewrite
def _model_fn():
"""Compute fit/eval/predict for the TPU."""
is_training = self.execution_mode == model_fn_lib.ModeKeys.TRAIN
is_test = self.execution_mode == model_fn_lib.ModeKeys.EVAL
is_predict = self.execution_mode == model_fn_lib.ModeKeys.PREDICT
# During train/eval, we infeed our features as well as labels.
if is_training or is_test:
infeed_layers = self.model._input_layers + self.model._output_layers
else:
infeed_layers = self.model._input_layers
# Generate our infeed operation to read features & labels.
infeed_tensors = tpu_ops.infeed_dequeue_tuple(
dtypes=[spec.dtype for spec in input_specs],
shapes=[spec.shape for spec in input_specs],
name='infeed-%s' % self.execution_mode)
core_id, infeed_tensors = (
_read_tpu_coreid_from_infeed(
mode=self.execution_mode, infeed_tensors=infeed_tensors))
assert len(infeed_tensors) == len(infeed_layers), (
'Infeed inputs did not match model: %s vs %s' % (infeed_layers,
infeed_tensors))
tpu_targets = []
tpu_input_map = {}
# Sort infeed outputs into inputs and labels for calling our Keras model.
for tensor, layer in zip(infeed_tensors, infeed_layers):
if layer in self.model._input_layers:
tpu_input_map[layer.name] = tensor
if layer in self.model._output_layers:
tpu_targets.append(tensor)
# Clone our CPU model, running within the TPU device context.
#
# We use the id of the original model as a key to avoid weight collisions
# (if a user re-runs the same model multiple times, in e.g. Colab).
with TPURewriteContext(tpu_input_map):
with variable_scope.variable_scope('tpu_%s' % id(self.model)):
with keras_tpu_variables.replicated_scope(
self._tpu_assignment.num_towers):
if not self._cloned_optimizer:
self._cloned_optimizer = _clone_optimizer(
self.model.cpu_optimizer,
worker_name=self._tpu_assignment.worker_name)
self._cloned_model = models.clone_model(self.model)
# When running on more than one core, concatenate outputs at the end
# of processing. In backprop stage, the gradients will be
# calculated according to the local inputs as gradient of
# cross-replica-concat being zero for any outputs other than those