-
Notifications
You must be signed in to change notification settings - Fork 74k
/
image_preprocessing.py
1308 lines (1125 loc) · 49.1 KB
/
image_preprocessing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras image preprocessing layers."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.engine.input_spec import InputSpec
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import image_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import stateful_random_ops
from tensorflow.python.ops import stateless_random_ops
from tensorflow.python.util.tf_export import keras_export
# Short alias for the resize-method enum used throughout this module.
ResizeMethod = image_ops.ResizeMethod

# Lookup table from the interpolation names listed in the `Resizing` docstring
# to the corresponding `ResizeMethod` enum members.
_RESIZE_METHODS = dict(
    bilinear=ResizeMethod.BILINEAR,
    nearest=ResizeMethod.NEAREST_NEIGHBOR,
    bicubic=ResizeMethod.BICUBIC,
    area=ResizeMethod.AREA,
    lanczos3=ResizeMethod.LANCZOS3,
    lanczos5=ResizeMethod.LANCZOS5,
    gaussian=ResizeMethod.GAUSSIAN,
    mitchellcubic=ResizeMethod.MITCHELLCUBIC,
)

# Indices of the height and width dimensions in channels-last (NHWC) tensors.
H_AXIS, W_AXIS = 1, 2
def check_fill_mode_and_interpolation(fill_mode, interpolation):
  """Validates the `fill_mode` / `interpolation` arguments of image layers.

  Args:
    fill_mode: String; one of 'reflect', 'wrap' or 'constant'.
    interpolation: String; one of 'nearest' or 'bilinear'.

  Raises:
    NotImplementedError: if `fill_mode` is unsupported (checked first), or
      else if `interpolation` is unsupported.
  """
  # (value, accepted values, error template) -- evaluated in order so an
  # invalid `fill_mode` is reported before an invalid `interpolation`.
  validations = (
      (fill_mode, frozenset(('reflect', 'wrap', 'constant')),
       'Unknown `fill_mode` {}. Only `reflect`, `wrap` and '
       '`constant` are supported.'),
      (interpolation, frozenset(('nearest', 'bilinear')),
       'Unknown `interpolation` {}. Only `nearest` and '
       '`bilinear` are supported.'),
  )
  for value, accepted, template in validations:
    if value not in accepted:
      raise NotImplementedError(template.format(value))
@keras_export('keras.layers.experimental.preprocessing.Resizing')
class Resizing(Layer):
  """Image resizing layer.

  Resize the batched image input to target height and width. The input should
  be a 4-D tensor in the format of NHWC.

  Arguments:
    height: Integer, the height of the output shape.
    width: Integer, the width of the output shape.
    interpolation: String, the interpolation method. Defaults to `bilinear`.
      Supports `bilinear`, `nearest`, `bicubic`, `area`, `lanczos3`,
      `lanczos5`, `gaussian`, `mitchellcubic`.
    name: A string, the name of the layer.
  """

  def __init__(self,
               height,
               width,
               interpolation='bilinear',
               name=None,
               **kwargs):
    # Run the base-class constructor first (the order `RandomFlip` uses). The
    # Keras base `Layer.__init__` initialises `input_spec` itself, so a spec
    # assigned before it can be overwritten and the 4-D check silently lost.
    super(Resizing, self).__init__(name=name, **kwargs)
    self.target_height = height
    self.target_width = width
    self.interpolation = interpolation
    self._interpolation_method = get_interpolation(interpolation)
    self.input_spec = InputSpec(ndim=4)

  def call(self, inputs):
    """Resizes `inputs` to `(target_height, target_width)`."""
    outputs = image_ops.resize_images_v2(
        images=inputs,
        size=[self.target_height, self.target_width],
        method=self._interpolation_method)
    return outputs

  def compute_output_shape(self, input_shape):
    input_shape = tensor_shape.TensorShape(input_shape).as_list()
    # Batch and channel dimensions pass through; spatial dims are replaced.
    return tensor_shape.TensorShape(
        [input_shape[0], self.target_height, self.target_width, input_shape[3]])

  def get_config(self):
    config = {
        'height': self.target_height,
        'width': self.target_width,
        'interpolation': self.interpolation,
    }
    base_config = super(Resizing, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
@keras_export('keras.layers.experimental.preprocessing.CenterCrop')
class CenterCrop(Layer):
  """Crop the central portion of the images to target height and width.

  Input shape:
    4D tensor with shape:
    `(samples, height, width, channels)`, data_format='channels_last'.

  Output shape:
    4D tensor with shape:
    `(samples, target_height, target_width, channels)`.

  The target size must not exceed the input size; no padding is applied. If
  the difference between the input and target height (or width) is odd, the
  crop offset is rounded down, so the extra row (or column) is removed from
  the bottom (or right) edge.

  Arguments:
    height: Integer, the height of the output shape.
    width: Integer, the width of the output shape.
    name: A string, the name of the layer.
  """

  def __init__(self, height, width, name=None, **kwargs):
    # Base constructor first: the Keras base `Layer.__init__` initialises
    # `input_spec`, so assigning the spec beforehand can lose it.
    super(CenterCrop, self).__init__(name=name, **kwargs)
    self.target_height = height
    self.target_width = width
    self.input_spec = InputSpec(ndim=4)

  def call(self, inputs):
    inputs_shape = array_ops.shape(inputs)
    img_hd = inputs_shape[H_AXIS]
    img_wd = inputs_shape[W_AXIS]
    img_hd_diff = img_hd - self.target_height
    img_wd_diff = img_wd - self.target_width
    # Assert that the requested crop fits inside the input.
    checks = [
        check_ops.assert_non_negative(
            img_hd_diff,
            message='The crop height {} should not be greater than input '
            'height.'.format(self.target_height)),
        check_ops.assert_non_negative(
            img_wd_diff,
            message='The crop width {} should not be greater than input '
            'width.'.format(self.target_width)),
    ]
    with ops.control_dependencies(checks):
      # The differences are non-negative here, so the int cast floors.
      bbox_h_start = math_ops.cast(img_hd_diff / 2, dtypes.int32)
      bbox_w_start = math_ops.cast(img_wd_diff / 2, dtypes.int32)
      bbox_begin = array_ops.stack([0, bbox_h_start, bbox_w_start, 0])
      # -1 keeps the full batch and channel dimensions.
      bbox_size = array_ops.stack(
          [-1, self.target_height, self.target_width, -1])
      outputs = array_ops.slice(inputs, bbox_begin, bbox_size)
      return outputs

  def compute_output_shape(self, input_shape):
    input_shape = tensor_shape.TensorShape(input_shape).as_list()
    return tensor_shape.TensorShape(
        [input_shape[0], self.target_height, self.target_width, input_shape[3]])

  def get_config(self):
    config = {
        'height': self.target_height,
        'width': self.target_width,
    }
    base_config = super(CenterCrop, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
@keras_export('keras.layers.experimental.preprocessing.RandomCrop')
class RandomCrop(Layer):
  """Randomly crop the images to target height and width.

  This layer will crop all the images in the same batch to the same cropping
  location.

  By default, random cropping is only applied during training. At inference
  time, the images are first resized, preserving their aspect ratio, so that
  they just cover the target size, and are then center cropped. If you need to
  apply random cropping at inference time, set `training` to True when calling
  the layer.

  Input shape:
    4D tensor with shape:
    `(samples, height, width, channels)`, data_format='channels_last'.

  Output shape:
    4D tensor with shape:
    `(samples, target_height, target_width, channels)`.

  Arguments:
    height: Integer, the height of the output shape.
    width: Integer, the width of the output shape.
    seed: Integer. Used to create a random seed.
    name: A string, the name of the layer.
  """

  def __init__(self, height, width, seed=None, name=None, **kwargs):
    # Base constructor first: the Keras base `Layer.__init__` initialises
    # `input_spec`, so assigning the spec beforehand can lose it.
    super(RandomCrop, self).__init__(name=name, **kwargs)
    self.height = height
    self.width = width
    self.seed = seed
    self._rng = make_generator(self.seed)
    self.input_spec = InputSpec(ndim=4)

  def call(self, inputs, training=True):
    if training is None:
      training = K.learning_phase()

    def random_cropped_inputs():
      """Cropped inputs with stateless random ops."""
      input_shape = array_ops.shape(inputs)
      crop_size = array_ops.stack(
          [input_shape[0], self.height, self.width, input_shape[3]])
      check = control_flow_ops.Assert(
          math_ops.reduce_all(input_shape >= crop_size),
          [self.height, self.width])
      input_shape = control_flow_ops.with_dependencies([check], input_shape)
      # Number of valid start positions per axis; batch and channel axes have
      # exactly one (offset 0) because crop_size matches them.
      limit = input_shape - crop_size + 1
      offset = stateless_random_ops.stateless_random_uniform(
          array_ops.shape(input_shape),
          dtype=crop_size.dtype,
          maxval=crop_size.dtype.max,
          seed=self._rng.make_seeds()[:, 0]) % limit
      return array_ops.slice(inputs, offset, crop_size)

    # TODO(b/143885775): Share logic with Resize and CenterCrop.
    def resize_and_center_cropped_inputs():
      """Deterministically resize to cover the target, then center crop."""
      input_shape = array_ops.shape(inputs)
      input_height_t = input_shape[H_AXIS]
      input_width_t = input_shape[W_AXIS]
      # True when the input is relatively taller than the target: fix the
      # width to the target and scale the height, otherwise the reverse.
      ratio_cond = (input_height_t / input_width_t > (self.height / self.width))
      # pylint: disable=g-long-lambda
      resized_height = tf_utils.smart_cond(
          ratio_cond,
          lambda: math_ops.cast(self.width * input_height_t / input_width_t,
                                input_height_t.dtype), lambda: self.height)
      resized_width = tf_utils.smart_cond(
          ratio_cond, lambda: self.width,
          lambda: math_ops.cast(self.height * input_width_t / input_height_t,
                                input_width_t.dtype))
      # pylint: enable=g-long-lambda
      resized_inputs = image_ops.resize_images_v2(
          images=inputs, size=array_ops.stack([resized_height, resized_width]))

      img_hd_diff = resized_height - self.height
      img_wd_diff = resized_width - self.width
      bbox_h_start = math_ops.cast(img_hd_diff / 2, dtypes.int32)
      bbox_w_start = math_ops.cast(img_wd_diff / 2, dtypes.int32)
      bbox_begin = array_ops.stack([0, bbox_h_start, bbox_w_start, 0])
      bbox_size = array_ops.stack([-1, self.height, self.width, -1])
      outputs = array_ops.slice(resized_inputs, bbox_begin, bbox_size)
      return outputs

    output = tf_utils.smart_cond(training, random_cropped_inputs,
                                 resize_and_center_cropped_inputs)
    # Restore the static shape information lost through cond/slice.
    original_shape = inputs.shape.as_list()
    batch_size, num_channels = original_shape[0], original_shape[3]
    output_shape = [batch_size] + [self.height, self.width] + [num_channels]
    output.set_shape(output_shape)
    return output

  def compute_output_shape(self, input_shape):
    input_shape = tensor_shape.TensorShape(input_shape).as_list()
    return tensor_shape.TensorShape(
        [input_shape[0], self.height, self.width, input_shape[3]])

  def get_config(self):
    config = {
        'height': self.height,
        'width': self.width,
        'seed': self.seed,
    }
    base_config = super(RandomCrop, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
@keras_export('keras.layers.experimental.preprocessing.Rescaling')
class Rescaling(Layer):
  """Computes `inputs * scale + offset` element-wise.

  Typical uses:
    1. Map inputs from `[0, 255]` to `[0, 1]`: pass `scale=1./255`.
    2. Map inputs from `[0, 255]` to `[-1, 1]`: pass `scale=1./127.5` and
       `offset=-1`.

  The same transformation is applied in training and in inference.

  Input shape:
    Arbitrary.

  Output shape:
    Same as input.

  Arguments:
    scale: Float, multiplier applied to the inputs.
    offset: Float, value added after scaling.
    name: A string, the name of the layer.
  """

  def __init__(self, scale, offset=0., name=None, **kwargs):
    self.scale = scale
    self.offset = offset
    super(Rescaling, self).__init__(name=name, **kwargs)

  def call(self, inputs):
    # Everything is computed in the layer's compute dtype.
    target_dtype = self._compute_dtype
    scale_t = math_ops.cast(self.scale, target_dtype)
    offset_t = math_ops.cast(self.offset, target_dtype)
    scaled = math_ops.cast(inputs, target_dtype) * scale_t
    return scaled + offset_t

  def compute_output_shape(self, input_shape):
    # Element-wise operation: shape is unchanged.
    return input_shape

  def get_config(self):
    merged = dict(super(Rescaling, self).get_config())
    merged.update({'scale': self.scale, 'offset': self.offset})
    return merged
# Accepted values for the `mode` argument of `RandomFlip`.
HORIZONTAL, VERTICAL, HORIZONTAL_AND_VERTICAL = (
    'horizontal', 'vertical', 'horizontal_and_vertical')
@keras_export('keras.layers.experimental.preprocessing.RandomFlip')
class RandomFlip(Layer):
  """Randomly flips images left-right and/or top-bottom.

  Which flips are applied is controlled by `mode`. When not training, the
  layer returns its input unchanged; call it with `training=True` to flip.

  Input shape:
    4D tensor with shape:
    `(samples, height, width, channels)`, data_format='channels_last'.

  Output shape:
    4D tensor with shape:
    `(samples, height, width, channels)`, data_format='channels_last'.

  Arguments:
    mode: One of "horizontal" (left-right flip), "vertical" (top-bottom flip)
      or "horizontal_and_vertical" (both, the default).
    seed: Integer. Used to create a random seed.
    name: A string, the name of the layer.

  Raises:
    ValueError: if `mode` is not one of the values above.
  """

  def __init__(self,
               mode=HORIZONTAL_AND_VERTICAL,
               seed=None,
               name=None,
               **kwargs):
    super(RandomFlip, self).__init__(name=name, **kwargs)
    self.mode = mode
    # Rows are (mode, flip left-right?, flip up-down?).
    for known_mode, flip_lr, flip_ud in ((HORIZONTAL, True, False),
                                         (VERTICAL, False, True),
                                         (HORIZONTAL_AND_VERTICAL, True, True)):
      if mode == known_mode:
        self.horizontal = flip_lr
        self.vertical = flip_ud
        break
    else:
      raise ValueError('RandomFlip layer {name} received an unknown mode '
                       'argument {arg}'.format(name=name, arg=mode))
    self.seed = seed
    self._rng = make_generator(self.seed)
    self.input_spec = InputSpec(ndim=4)

  def call(self, inputs, training=True):
    if training is None:
      training = K.learning_phase()

    def flip_randomly():
      # NOTE(review): both flips receive the same op-level `self.seed`;
      # confirm the horizontal and vertical draws are meant to be independent.
      result = inputs
      if self.horizontal:
        result = image_ops.random_flip_left_right(result, self.seed)
      if self.vertical:
        result = image_ops.random_flip_up_down(result, self.seed)
      return result

    def passthrough():
      return inputs

    flipped = tf_utils.smart_cond(training, flip_randomly, passthrough)
    flipped.set_shape(inputs.shape)
    return flipped

  def compute_output_shape(self, input_shape):
    return input_shape

  def get_config(self):
    config = {
        'mode': self.mode,
        'seed': self.seed,
    }
    base_config = super(RandomFlip, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
# TODO(tanzheny): Add examples, here and everywhere.
@keras_export('keras.layers.experimental.preprocessing.RandomTranslation')
class RandomTranslation(Layer):
  """Randomly translate each image during training.

  Arguments:
    height_factor: a float representing a fraction of the image height, or a
      tuple of size 2 representing lower and upper bound for shifting
      vertically. A negative value means shifting image up, while a positive
      value means shifting image down. When represented as a single positive
      float, this value is used for both the upper and lower bound. For
      instance, `height_factor=(-0.2, 0.3)` results in an output shifted by a
      random amount in the range [-20%, +30%].
      `height_factor=0.2` results in an output shifted by a random amount in
      the range [-20%, +20%].
    width_factor: a float representing a fraction of the image width, or a
      tuple of size 2 representing lower and upper bound for shifting
      horizontally. A negative value means shifting image left, while a
      positive value means shifting image right. When represented as a single
      positive float, this value is used for both the upper and lower bound.
      For instance, `width_factor=(-0.2, 0.3)` results in an output shifted by
      a random amount in the range [-20%, +30%].
      `width_factor=0.2` results in an output shifted left or right by a
      random amount in the range [-20%, +20%].
    fill_mode: Points outside the boundaries of the input are filled according
      to the given mode (one of `{'constant', 'reflect', 'wrap'}`).
      - *reflect*: `(d c b a | a b c d | d c b a)`
        The input is extended by reflecting about the edge of the last pixel.
      - *constant*: `(k k k k | a b c d | k k k k)`
        The input is extended by filling all values beyond the edge with the
        same constant value k = 0.
      - *wrap*: `(a b c d | a b c d | a b c d)`
        The input is extended by wrapping around to the opposite edge.
    interpolation: Interpolation mode. Supported values: "nearest", "bilinear".
    seed: Integer. Used to create a random seed.
    name: A string, the name of the layer.

  Input shape:
    4D tensor with shape: `(samples, height, width, channels)`,
      data_format='channels_last'.

  Output shape:
    4D tensor with shape: `(samples, height, width, channels)`,
      data_format='channels_last'.

  Raises:
    ValueError: if either bound of a factor is outside [-1, 1], or its upper
      bound is less than its lower bound.
    NotImplementedError: if `fill_mode` or `interpolation` is unsupported.
  """

  def __init__(self,
               height_factor,
               width_factor,
               fill_mode='reflect',
               interpolation='bilinear',
               seed=None,
               name=None,
               **kwargs):
    # Base constructor first: the Keras base `Layer.__init__` initialises
    # `input_spec`, so assigning the spec beforehand can lose it.
    super(RandomTranslation, self).__init__(name=name, **kwargs)
    self.height_factor = height_factor
    if isinstance(height_factor, (tuple, list)):
      self.height_lower = height_factor[0]
      self.height_upper = height_factor[1]
    else:
      self.height_lower = -height_factor
      self.height_upper = height_factor
    if self.height_upper < self.height_lower:
      raise ValueError('`height_factor` cannot have upper bound less than '
                       'lower bound, got {}'.format(height_factor))
    if abs(self.height_lower) > 1. or abs(self.height_upper) > 1.:
      raise ValueError('`height_factor` must have values between [-1, 1], '
                       'got {}'.format(height_factor))

    self.width_factor = width_factor
    if isinstance(width_factor, (tuple, list)):
      self.width_lower = width_factor[0]
      self.width_upper = width_factor[1]
    else:
      self.width_lower = -width_factor
      self.width_upper = width_factor
    if self.width_upper < self.width_lower:
      raise ValueError('`width_factor` cannot have upper bound less than '
                       'lower bound, got {}'.format(width_factor))
    if abs(self.width_lower) > 1. or abs(self.width_upper) > 1.:
      raise ValueError('`width_factor` must have values between [-1, 1], '
                       'got {}'.format(width_factor))

    check_fill_mode_and_interpolation(fill_mode, interpolation)

    self.fill_mode = fill_mode
    self.interpolation = interpolation
    self.seed = seed
    self._rng = make_generator(self.seed)
    self.input_spec = InputSpec(ndim=4)

  def call(self, inputs, training=True):
    if training is None:
      training = K.learning_phase()

    def random_translated_inputs():
      """Translated inputs with random ops."""
      inputs_shape = array_ops.shape(inputs)
      batch_size = inputs_shape[0]
      img_hd = math_ops.cast(inputs_shape[H_AXIS], dtypes.float32)
      img_wd = math_ops.cast(inputs_shape[W_AXIS], dtypes.float32)
      # One offset per image, sampled as a fraction and scaled to pixels.
      # The height offset is drawn before the width offset.
      height_translate = self._rng.uniform(
          shape=[batch_size, 1],
          minval=self.height_lower,
          maxval=self.height_upper,
          dtype=dtypes.float32) * img_hd
      width_translate = self._rng.uniform(
          shape=[batch_size, 1],
          minval=self.width_lower,
          maxval=self.width_upper,
          dtype=dtypes.float32) * img_wd
      # `get_translation_matrix` expects rows of [dx, dy]; both are float32.
      translations = array_ops.concat([width_translate, height_translate],
                                      axis=1)
      return transform(
          inputs,
          get_translation_matrix(translations),
          interpolation=self.interpolation,
          fill_mode=self.fill_mode)

    output = tf_utils.smart_cond(training, random_translated_inputs,
                                 lambda: inputs)
    output.set_shape(inputs.shape)
    return output

  def compute_output_shape(self, input_shape):
    return input_shape

  def get_config(self):
    config = {
        'height_factor': self.height_factor,
        'width_factor': self.width_factor,
        'fill_mode': self.fill_mode,
        'interpolation': self.interpolation,
        'seed': self.seed,
    }
    base_config = super(RandomTranslation, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
def get_translation_matrix(translations, name=None):
  """Builds one projective transform per translation.

  Args:
    translations: A matrix whose rows are `[dx, dy]` pixel offsets, one row
      per image in the batch.
    name: The name of the op.

  Returns:
    A float32 tensor of shape (num_images, 8) suitable for `transform`.
  """
  with ops.name_scope(name, 'translation_matrix'):
    count = array_ops.shape(translations)[0]
    ones = array_ops.ones((count, 1), dtypes.float32)
    zeros = array_ops.zeros((count, 1), dtypes.float32)
    dx = translations[:, 0, None]
    dy = translations[:, 1, None]
    # Row layout [a0 a1 a2 b0 b1 b2 c0 c1] of
    #   [[1 0 -dx]
    #    [0 1 -dy]
    #    [0 0  1 ]]
    # with the final 1 implicit. The offsets are negated because transforms
    # map output points back to input points.
    return array_ops.concat(
        values=[
            ones, zeros, -dx,
            zeros, ones, -dy,
            array_ops.zeros((count, 2), dtypes.float32),
        ],
        axis=1)
def transform(images,
              transforms,
              fill_mode='reflect',
              interpolation='bilinear',
              output_shape=None,
              name=None):
  """Applies the given transform(s) to the image(s).

  Args:
    images: A 4-D tensor of shape (num_images, num_rows, num_columns,
      num_channels) (NHWC).
    transforms: Projective transform matrix/matrices. A vector of length 8
      (applied as a single 1 x 8 transform) or a tensor of size N x 8. If one
      row of transforms is [a0, a1, a2, b0, b1, b2, c0, c1], then it maps the
      *output* point `(x, y)` to a transformed *input* point
      `(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`, where
      `k = c0 x + c1 y + 1`. The transforms are *inverted* compared to the
      transform mapping input points to output points. Note that gradients are
      not backpropagated into transformation parameters.
    fill_mode: Points outside the boundaries of the input are filled according
      to the given mode (one of `{'constant', 'reflect', 'wrap'}`).
    interpolation: Interpolation mode. Supported values: "nearest", "bilinear".
    output_shape: Output dimension after the transform, [height, width]. If
      None, output is the same size as input image.
    name: The name of the op.

  ## Fill mode.
  Behavior for each valid value is as follows:

  reflect (d c b a | a b c d | d c b a)
  The input is extended by reflecting about the edge of the last pixel.

  constant (k k k k | a b c d | k k k k)
  The input is extended by filling all values beyond the edge with the same
  constant value k = 0.

  wrap (a b c d | a b c d | a b c d)
  The input is extended by wrapping around to the opposite edge.

  Input shape:
    4D tensor with shape: `(samples, height, width, channels)`,
      data_format='channels_last'.

  Output shape:
    4D tensor with shape: `(samples, height, width, channels)`,
      data_format='channels_last'.

  Returns:
    Image(s) of the same type as `images`, with spatial size `output_shape`
    (the input's spatial size by default) and the given transform(s) applied.
    Output points that map outside the input image are filled according to
    `fill_mode`.

  Raises:
    ValueError: If `output_shape` is not a 1-D tensor of 2 elements.
  """
  with ops.name_scope(name, 'transform'):
    if output_shape is None:
      # Spatial (H, W) dimensions of an NHWC batch.
      output_shape = array_ops.shape(images)[1:3]
      if not context.executing_eagerly():
        # Prefer a static value when it is known at graph-build time.
        output_shape_value = tensor_util.constant_value(output_shape)
        if output_shape_value is not None:
          output_shape = output_shape_value

    output_shape = ops.convert_to_tensor_v2(
        output_shape, dtypes.int32, name='output_shape')

    if not output_shape.get_shape().is_compatible_with([2]):
      raise ValueError('output_shape must be a 1-D Tensor of 2 elements: '
                       'new_height, new_width, instead got '
                       '{}'.format(output_shape))

    # The documented length-8 vector form is promoted to a single-row
    # (1, 8) matrix; N x 8 inputs pass through unchanged. The float32 hint
    # keeps plain Python int lists converting the same way as before.
    transforms = ops.convert_to_tensor_v2(
        transforms, dtype_hint=dtypes.float32, name='transforms')
    if transforms.shape.rank == 1:
      transforms = transforms[None]

    return image_ops.image_projective_transform_v2(
        images,
        output_shape=output_shape,
        transforms=transforms,
        fill_mode=fill_mode.upper(),
        interpolation=interpolation.upper())
def get_rotation_matrix(angles, image_height, image_width, name=None):
  """Returns projective transform(s) for the given angle(s).

  Args:
    angles: A scalar angle in radians, or (for batches of images) a vector
      with an angle to rotate each image in the batch. The rank must be
      statically known (the shape is not `TensorShape(None)`).
    image_height: Height of the image(s) to be transformed.
    image_width: Width of the image(s) to be transformed.
    name: The name of the op.

  Returns:
    A tensor of shape (num_images, 8); a scalar `angles` yields shape (1, 8).
    Projective transforms which can be given to operation
    `image_projective_transform_v2`. If one row of transforms is
    [a0, a1, a2, b0, b1, b2, c0, c1], then it maps the *output* point
    `(x, y)` to a transformed *input* point
    `(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`,
    where `k = c0 x + c1 y + 1`.
  """
  with ops.name_scope(name, 'rotation_matrix'):
    angles = ops.convert_to_tensor_v2(angles, name='angles')
    if angles.shape.rank == 0:
      # Give a scalar angle a batch dimension so the `[:, None]` indexing and
      # `shape(angles)[0]` below work; previously a scalar failed here.
      angles = angles[None]
    cos_angles = math_ops.cos(angles)
    sin_angles = math_ops.sin(angles)
    # Offsets chosen so the image centre ((W-1)/2, (H-1)/2) maps onto itself,
    # i.e. the rotation is about the centre of the image.
    x_offset = ((image_width - 1) - (cos_angles * (image_width - 1) -
                                     sin_angles * (image_height - 1))) / 2.0
    y_offset = ((image_height - 1) - (sin_angles * (image_width - 1) +
                                      cos_angles * (image_height - 1))) / 2.0
    num_angles = array_ops.shape(angles)[0]
    return array_ops.concat(
        values=[
            cos_angles[:, None],
            -sin_angles[:, None],
            x_offset[:, None],
            sin_angles[:, None],
            cos_angles[:, None],
            y_offset[:, None],
            array_ops.zeros((num_angles, 2), dtypes.float32),
        ],
        axis=1)
@keras_export('keras.layers.experimental.preprocessing.RandomRotation')
class RandomRotation(Layer):
  """Randomly rotate each image.

  By default, random rotations are only applied during training.
  At inference time, the layer does nothing. If you need to apply random
  rotations at inference time, set `training` to True when calling the layer.

  Arguments:
    factor: a float represented as fraction of 2pi, or a tuple of size
      2 representing lower and upper bound for rotating clockwise and
      counter-clockwise. A positive value means rotating counter clock-wise,
      while a negative value means clock-wise. When represented as a single
      float, this value is used for both the upper and lower bound. For
      instance, `factor=(-0.2, 0.3)` results in an output
      rotation by a random amount in the range `[-20% * 2pi, 30% * 2pi]`.
      `factor=0.2` results in an output rotating by a random amount in the
      range `[-20% * 2pi, 20% * 2pi]`.
    fill_mode: Points outside the boundaries of the input are filled according
      to the given mode (one of `{'constant', 'reflect', 'wrap'}`).
      - *reflect*: `(d c b a | a b c d | d c b a)`
        The input is extended by reflecting about the edge of the last pixel.
      - *constant*: `(k k k k | a b c d | k k k k)`
        The input is extended by filling all values beyond the edge with the
        same constant value k = 0.
      - *wrap*: `(a b c d | a b c d | a b c d)`
        The input is extended by wrapping around to the opposite edge.
    interpolation: Interpolation mode. Supported values: "nearest", "bilinear".
    seed: Integer. Used to create a random seed.
    name: A string, the name of the layer.

  Input shape:
    4D tensor with shape: `(samples, height, width, channels)`,
      data_format='channels_last'.

  Output shape:
    4D tensor with shape: `(samples, height, width, channels)`,
      data_format='channels_last'.

  Raises:
    ValueError: if the upper bound of `factor` is less than its lower bound.
    NotImplementedError: if `fill_mode` or `interpolation` is unsupported.
  """

  def __init__(self,
               factor,
               fill_mode='reflect',
               interpolation='bilinear',
               seed=None,
               name=None,
               **kwargs):
    # Base constructor first: the Keras base `Layer.__init__` initialises
    # `input_spec`, so assigning the spec beforehand can lose it.
    super(RandomRotation, self).__init__(name=name, **kwargs)
    self.factor = factor
    if isinstance(factor, (tuple, list)):
      self.lower = factor[0]
      self.upper = factor[1]
    else:
      self.lower = -factor
      self.upper = factor
    if self.upper < self.lower:
      # Previously reported as "negative values", which is misleading for an
      # inverted tuple such as (0.3, 0.1).
      raise ValueError('`factor` cannot have upper bound less than '
                       'lower bound, got {}'.format(factor))
    check_fill_mode_and_interpolation(fill_mode, interpolation)
    self.fill_mode = fill_mode
    self.interpolation = interpolation
    self.seed = seed
    self._rng = make_generator(self.seed)
    self.input_spec = InputSpec(ndim=4)

  def call(self, inputs, training=True):
    if training is None:
      training = K.learning_phase()

    def random_rotated_inputs():
      """Rotated inputs with random ops."""
      inputs_shape = array_ops.shape(inputs)
      batch_size = inputs_shape[0]
      img_hd = math_ops.cast(inputs_shape[H_AXIS], dtypes.float32)
      img_wd = math_ops.cast(inputs_shape[W_AXIS], dtypes.float32)
      # `factor` is a fraction of a full turn; convert to radians.
      min_angle = self.lower * 2. * np.pi
      max_angle = self.upper * 2. * np.pi
      angles = self._rng.uniform(
          shape=[batch_size], minval=min_angle, maxval=max_angle)
      return transform(
          inputs,
          get_rotation_matrix(angles, img_hd, img_wd),
          fill_mode=self.fill_mode,
          interpolation=self.interpolation)

    output = tf_utils.smart_cond(training, random_rotated_inputs,
                                 lambda: inputs)
    output.set_shape(inputs.shape)
    return output

  def compute_output_shape(self, input_shape):
    return input_shape

  def get_config(self):
    config = {
        'factor': self.factor,
        'fill_mode': self.fill_mode,
        'interpolation': self.interpolation,
        'seed': self.seed,
    }
    base_config = super(RandomRotation, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
@keras_export('keras.layers.experimental.preprocessing.RandomZoom')
class RandomZoom(Layer):
  """Randomly zoom each image during training.

  Arguments:
    height_factor: a float represented as fraction of value, or a tuple of
      size 2 representing lower and upper bound for zooming vertically.
      When represented as a single float `f`, the bounds are `[-f, f]`, so
      each image is zoomed in or out by up to `f`. A positive value means
      zooming out, while a negative value means zooming in. Both bounds must
      lie in `[-1, 1]`.
      For instance, `height_factor=(0.2, 0.3)` results in an output zoomed
      out by a random amount in the range [+20%, +30%].
      `height_factor=(-0.3, -0.2)` results in an output zoomed in by a random
      amount in the range [+20%, +30%].
    width_factor: a float represented as fraction of value, or a tuple of
      size 2 representing lower and upper bound for zooming horizontally.
      When represented as a single float `f`, the bounds are `[-f, f]`. Both
      bounds must be greater than or equal to -1.
      For instance, `width_factor=(0.2, 0.3)` results in an output zooming
      out between 20% to 30%.
      `width_factor=(-0.3, -0.2)` results in an output zooming in between 20%
      to 30%. Defaults to `None`, i.e., the random zoom drawn for the height
      is reused for the width, preserving the aspect ratio.
    fill_mode: Points outside the boundaries of the input are filled according
      to the given mode (one of `{'constant', 'reflect', 'wrap'}`).
      - *reflect*: `(d c b a | a b c d | d c b a)`
        The input is extended by reflecting about the edge of the last pixel.
      - *constant*: `(k k k k | a b c d | k k k k)`
        The input is extended by filling all values beyond the edge with the
        same constant value k = 0.
      - *wrap*: `(a b c d | a b c d | a b c d)`
        The input is extended by wrapping around to the opposite edge.
    interpolation: Interpolation mode. Supported values: "nearest", "bilinear".
    seed: Integer. Used to create a random seed.
    name: A string, the name of the layer.

  Example:

  >>> input_img = np.random.random((32, 224, 224, 3))
  >>> layer = tf.keras.layers.experimental.preprocessing.RandomZoom(.5, .2)
  >>> out_img = layer(input_img)
  >>> out_img.shape
  TensorShape([32, 224, 224, 3])

  Input shape:
    4D tensor with shape:
    `(samples, height, width, channels)`, data_format='channels_last'.

  Output shape:
    4D tensor with shape:
    `(samples, height, width, channels)`, data_format='channels_last'.

  Raises:
    ValueError: if a `height_factor` bound lies outside `[-1, 1]`, or a
      `width_factor` bound is less than -1.
  """

  # TODO(b/156526279): Add `fill_value` argument.
  def __init__(self,
               height_factor,
               width_factor=None,
               fill_mode='reflect',
               interpolation='bilinear',
               seed=None,
               name=None,
               **kwargs):
    """Parses and validates the zoom bounds and stores the transform options.

    See the class docstring for the meaning of each argument.

    Raises:
      ValueError: if a `height_factor` bound lies outside `[-1, 1]`, or a
        `width_factor` bound is less than -1.
    """
    self.height_factor = height_factor
    if isinstance(height_factor, (tuple, list)):
      self.height_lower = height_factor[0]
      self.height_upper = height_factor[1]
    else:
      # A scalar describes a range centred on zero.
      self.height_lower = -height_factor
      self.height_upper = height_factor

    if abs(self.height_lower) > 1. or abs(self.height_upper) > 1.:
      raise ValueError('`height_factor` must have values between [-1, 1], '
                       'got {}'.format(height_factor))

    self.width_factor = width_factor
    # `width_lower`/`width_upper` only exist when `width_factor` is given;
    # `call` reads them only in that case.
    if width_factor is not None:
      if isinstance(width_factor, (tuple, list)):
        self.width_lower = width_factor[0]
        self.width_upper = width_factor[1]
      else:
        self.width_lower = -width_factor  # pylint: disable=invalid-unary-operand-type
        self.width_upper = width_factor

      # Note: unlike the height check, there is no upper limit on width.
      if self.width_lower < -1. or self.width_upper < -1.:
        raise ValueError('`width_factor` must have values larger than -1, '
                         'got {}'.format(width_factor))

    check_fill_mode_and_interpolation(fill_mode, interpolation)

    self.fill_mode = fill_mode
    self.interpolation = interpolation
    self.seed = seed
    self._rng = make_generator(self.seed)
    # Inputs are declared as rank-4 batches of images.
    self.input_spec = InputSpec(ndim=4)
    super(RandomZoom, self).__init__(name=name, **kwargs)

  def call(self, inputs, training=True):
    """Zooms every image in the batch by its own random factor in training.

    Args:
      inputs: The batch of images; `H_AXIS` and `W_AXIS` index its height
        and width dimensions.
      training: Whether to apply the random zoom. `None` falls back to
        `K.learning_phase()`.

    Returns:
      The zoomed batch when `training` is true, otherwise `inputs` itself.
      In both cases the static shape of `inputs` is re-applied to the result.
    """
    if training is None:
      training = K.learning_phase()

    def random_zoomed_inputs():
      """Zoomed inputs with random ops."""
      inputs_shape = array_ops.shape(inputs)
      batch_size = inputs_shape[0]
      img_hd = math_ops.cast(inputs_shape[H_AXIS], dtypes.float32)
      img_wd = math_ops.cast(inputs_shape[W_AXIS], dtypes.float32)
      # Scale factors are 1 + bound, one value per image.
      height_zoom = self._rng.uniform(
          shape=[batch_size, 1],
          minval=1. + self.height_lower,
          maxval=1. + self.height_upper)
      if self.width_factor is not None:
        width_zoom = self._rng.uniform(
            shape=[batch_size, 1],
            minval=1. + self.width_lower,
            maxval=1. + self.width_upper)
      else:
        # Reuse the height zoom so the aspect ratio is preserved.
        width_zoom = height_zoom
      # `get_zoom_matrix` expects each row as [zx, zy], i.e. width first.
      zooms = math_ops.cast(
          array_ops.concat([width_zoom, height_zoom], axis=1),
          dtype=dtypes.float32)
      return transform(
          inputs, get_zoom_matrix(zooms, img_hd, img_wd),
          fill_mode=self.fill_mode,
          interpolation=self.interpolation)

    output = tf_utils.smart_cond(training, random_zoomed_inputs,
                                 lambda: inputs)
    output.set_shape(inputs.shape)
    return output

  def compute_output_shape(self, input_shape):
    """Returns `input_shape` unchanged: zooming does not alter the shape."""
    return input_shape

  def get_config(self):
    """Returns the base `Layer` config extended with this layer's arguments."""
    config = {
        'height_factor': self.height_factor,
        'width_factor': self.width_factor,
        'fill_mode': self.fill_mode,
        'interpolation': self.interpolation,
        'seed': self.seed,
    }
    base_config = super(RandomZoom, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
def get_zoom_matrix(zooms, image_height, image_width, name=None):
"""Returns projective transform(s) for the given zoom(s).
Args:
zooms: A matrix of 2-element lists representing [zx, zy] to zoom
for each image (for a batch of images).
image_height: Height of the image(s) to be transformed.
image_width: Width of the image(s) to be transformed.
name: The name of the op.
Returns:
A tensor of shape (num_images, 8). Projective transforms which can be given
to operation `image_projective_transform_v2`. If one row of transforms is