# estimators.py (1904 lines, 1659 loc, 76.3 KB)
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TF Lattice canned estimators implement typical monotonic model architectures.
You can use TFL canned estimators to easily construct commonly used monotonic
model architectures. To construct a TFL canned estimator, construct a model
configuration from `tfl.configs` and pass it to the canned estimator
constructor. To use automated quantile calculation, canned estimators also
require passing a *feature_analysis_input_fn* which is similar to the one used
for training, but with a single epoch or a subset of the data. To create a
Crystals ensemble model using `tfl.configs.CalibratedLatticeEnsembleConfig`, you
will also need to provide a *prefitting_input_fn* to the estimator constructor.
```python
feature_columns = ...
model_config = tfl.configs.CalibratedLatticeConfig(...)
feature_analysis_input_fn = create_input_fn(num_epochs=1, ...)
train_input_fn = create_input_fn(num_epochs=100, ...)
estimator = tfl.estimators.CannedClassifier(
feature_columns=feature_columns,
model_config=model_config,
feature_analysis_input_fn=feature_analysis_input_fn)
estimator.train(input_fn=train_input_fn)
```
Supported models are defined in `tfl.configs`. Each model architecture can be
used for:
* **Classification** using `tfl.estimators.CannedClassifier` with standard
classification head (softmax cross-entropy loss).
* **Regression** using `tfl.estimators.CannedRegressor` with standard
regression head (squared loss).
* **Custom head** using `tfl.estimators.CannedEstimator` with any custom head
and loss.
This module also provides `tfl.estimators.get_model_graph` as a mechanism to
extract abstract model graphs and layer parameters from saved models. The
resulting graph (not a TF graph) can be used by the `tfl.visualization` module
for plotting and other visualization and analysis.
```python
model_graph = estimators.get_model_graph(saved_model_path)
visualization.plot_feature_calibrator(model_graph, "feature_name")
visualization.plot_all_calibrators(model_graph)
visualization.draw_model_graph(model_graph)
```
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import copy
import json
import os
import re
import time
from . import categorical_calibration_layer
from . import configs
from . import kronecker_factored_lattice_layer as kfll
from . import lattice_layer
from . import linear_layer
from . import model_info
from . import premade
from . import premade_lib
from . import pwl_calibration_layer
from . import rtl_layer
from absl import logging
import numpy as np
import six
import tensorflow as tf
from tensorflow.python.feature_column import feature_column as fc # pylint: disable=g-direct-tensorflow-import
from tensorflow.python.feature_column import feature_column_v2 as fc2 # pylint: disable=g-direct-tensorflow-import
from tensorflow.python.keras.utils import losses_utils # pylint: disable=g-direct-tensorflow-import
from tensorflow.python.training import training_util # pylint: disable=g-direct-tensorflow-import
from tensorflow_estimator.python.estimator import estimator as estimator_lib
from tensorflow_estimator.python.estimator.canned import optimizers
from tensorflow_estimator.python.estimator.head import binary_class_head
from tensorflow_estimator.python.estimator.head import multi_class_head
from tensorflow_estimator.python.estimator.head import regression_head
# TODO: support multi dim inputs.
# TODO: support multi dim output.
# TODO: add linear layer regularizers.
# TODO: add examples in docs.
# TODO: make _REPEATED_PAIR_DISCOUNT_IN_CRYSTALS_SCORE config param
# Feed and fetch names for the model.
FEATURES_SCOPE = 'features'
OUTPUT_NAME = 'output'
# File to store and load feature keypoints.
_KEYPOINTS_FILE = 'keypoints.json'
# File to store and load lattice ensemble structure.
_ENSEMBLE_STRUCTURE_FILE = 'ensemble_structure.json'
# Name for label keypoints in keypoints file.
_LABEL_FEATURE_NAME = '__label__'
# Polling interval and maximum wait time for workers waiting for files.
_MAX_WAIT_TIME = 2400
_POLL_INTERVAL_SECS = 10
class WaitTimeOutError(Exception):
"""Timeout error when waiting for a file."""
pass
def _poll_for_file(filename):
"""Waits and polls for a file until it exists."""
start = time.time()
while not tf.io.gfile.exists(filename):
time.sleep(_POLL_INTERVAL_SECS)
if time.time() - start > _MAX_WAIT_TIME:
raise WaitTimeOutError('Waiting for file {} timed out'.format(filename))
def transform_features(features, feature_columns=None):
"""Parses the input features using the given feature columns.
This function can be used to parse input features when constructing a custom
estimator. When using this function, you will not need to wrap categorical
features with dense feature embeddings, and the resulting tensors will not be
concatenated, making it easier to use the features in the calibration layers.
Args:
features: A dict from feature names to tensors.
feature_columns: A list of FeatureColumn objects to be used for parsing. If
not provided, the input features are assumed to be already parsed.
Returns:
collections.OrderedDict mapping feature names to parsed tensors.
"""
with tf.name_scope('transform'):
if feature_columns:
parsed_features = collections.OrderedDict()
for feature_column in feature_columns:
# pylint: disable=protected-access
if (isinstance(feature_column, fc._DenseColumn) or
isinstance(feature_column, fc2.DenseColumn)):
parsed_features[
feature_column.name] = feature_column._transform_feature(features)
elif (isinstance(feature_column, fc._CategoricalColumn) or
isinstance(feature_column, fc2.CategoricalColumn)):
if feature_column.num_oov_buckets:
# If oov buckets are used, missing values are assigned to the last
# oov bucket.
default_value = feature_column.num_buckets - 1
else:
default_value = feature_column.default_value
parsed_features[feature_column.name] = tf.reshape(
tf.sparse.to_dense(
sp_input=feature_column._transform_feature(features),
default_value=default_value),
shape=[-1, 1])
else:
raise ValueError(
'Unsupported feature_column: {}'.format(feature_column))
# pylint: enable=protected-access
else:
parsed_features = collections.OrderedDict(features)
for name, tensor in parsed_features.items():
if len(tensor.shape) == 1:
parsed_features[name] = tf.expand_dims(tensor, 1)
elif len(tensor.shape) > 2 or tensor.shape[1] != 1:
raise ValueError('Only 1-d inputs are supported: {}'.format(tensor))
with tf.name_scope(FEATURES_SCOPE):
for name, tensor in parsed_features.items():
parsed_features[name] = tf.identity(parsed_features[name], name=name)
return parsed_features
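The shape normalization at the end of `transform_features` can be illustrated standalone. The sketch below is not part of this module (`normalize_feature_shape` is a hypothetical helper name); it mirrors the rank-1/single-column rule using numpy arrays as tensor stand-ins:

```python
import numpy as np

def normalize_feature_shape(name, values):
  """Mirrors the column-shape check in transform_features: rank-1 inputs
  gain a trailing unit dimension; anything wider than one column is
  rejected."""
  values = np.asarray(values)
  if values.ndim == 1:
    return values.reshape(-1, 1)  # add the column dimension
  if values.ndim > 2 or values.shape[1] != 1:
    raise ValueError('Only 1-d inputs are supported: {}'.format(name))
  return values

print(normalize_feature_shape('age', [21.0, 35.0, 42.0]).shape)  # (3, 1)
```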
def _materialize_locally(tensors, max_elements=1e6):
"""Materialize the given tensors locally, during initialization.
Assumes non-distributed environment (uses SingularMonitoredSession).
Args:
tensors: A dict of name to feed tensors to be materialized.
max_elements: Data is read and accumulated from tensors until end-of-input is
reached or at least max_elements elements have been collected.
Returns:
Materialized tensors as dict.
"""
# tf.compat.v1.train.SingularMonitoredSession silently catches
# tf.errors.OutOfRangeError, and we want to expose it to detect end of the
# data from the given feed tensors.
with tf.compat.v1.train.SingularMonitoredSession() as sess:
splits = []
count = 0
try:
while count < max_elements:
materialized_tensors = sess.run(tensors)
values = list(materialized_tensors.values())
if not values:
break
count += len(values[0])
splits.append(materialized_tensors)
except (tf.errors.OutOfRangeError, StopIteration):
pass
concatenated_tensors = {}
for k in tensors:
concatenated_tensors[k] = np.concatenate(
[split[k] for split in splits if split[k].size > 0])
return concatenated_tensors
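The accumulation loop inside `_materialize_locally` can be shown without a TF session. This framework-free sketch (names here are illustrative, not from the library) pulls dict-of-array batches from a plain iterator until `max_elements` rows are collected or the iterator is exhausted, then concatenates per key:

```python
import numpy as np

def accumulate_batches(batch_iter, max_elements=1e6):
  """Accumulate batches (dicts of name -> ndarray) and concatenate."""
  splits = []
  count = 0
  for batch in batch_iter:
    values = list(batch.values())
    if not values:
      break
    count += len(values[0])  # rows in this batch
    splits.append(batch)
    if count >= max_elements:
      break
  return {k: np.concatenate([s[k] for s in splits]) for k in splits[0]}

batches = iter([{'x': np.array([1, 2])}, {'x': np.array([3, 4])}])
print(accumulate_batches(batches)['x'])  # [1 2 3 4]
```

Unlike the real implementation, this sketch checks the element count after appending a batch rather than before fetching the next one, but the stopping behavior is equivalent in spirit.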
def _finalize_keypoints(model_config, config, feature_columns,
feature_analysis_input_fn, logits_output):
"""Calculates and sets keypoints for input and output calibration.
Input and label keypoints are calculated, stored in a file and also set in the
model_config to be used for model construction.
Args:
model_config: Model config to be updated.
config: A `tf.RunConfig` to indicate if worker is chief.
feature_columns: A list of FeatureColumn's to use for feature parsing.
feature_analysis_input_fn: An input_fn used to collect feature statistics.
logits_output: A boolean indicating if model outputs logits.
Raises:
ValueError: If keypoints mode is invalid.
"""
if not feature_analysis_input_fn:
return
keypoints_filename = os.path.join(config.model_dir, _KEYPOINTS_FILE)
if ((config is None or config.is_chief) and
not tf.io.gfile.exists(keypoints_filename)):
with tf.Graph().as_default():
features, label = feature_analysis_input_fn()
features = transform_features(features, feature_columns)
features[_LABEL_FEATURE_NAME] = label
features = _materialize_locally(features)
feature_keypoints = {}
for feature_name, feature_values in six.iteritems(features):
feature_values = feature_values.flatten()
if feature_name == _LABEL_FEATURE_NAME:
# Default feature_values to [0, ... n_class-1] if string label.
if label.dtype == tf.string:
feature_values = np.arange(len(set(feature_values)))
num_keypoints = model_config.output_calibration_num_keypoints
keypoints = model_config.output_initialization
clip_min = model_config.output_min
clip_max = model_config.output_max
default_value = None
else:
feature_config = model_config.feature_config_by_name(feature_name)
if feature_config.num_buckets:
# Skip categorical features.
continue
num_keypoints = feature_config.pwl_calibration_num_keypoints
keypoints = feature_config.pwl_calibration_input_keypoints
clip_min = feature_config.pwl_calibration_clip_min
clip_max = feature_config.pwl_calibration_clip_max
default_value = feature_config.default_value
# Remove default values before calculating stats.
feature_values = feature_values[feature_values != default_value]
if np.isnan(feature_values).any():
raise ValueError(
'NaN values were observed for numeric feature `{}`. '
'Consider replacing the values in transform or input_fn.'.format(
feature_name))
# Before calculating keypoints, clip values as requested.
# Add min and max to the value list to make sure min/max in values match
# the requested range.
if clip_min is not None:
feature_values = np.maximum(feature_values, clip_min)
feature_values = np.append(feature_values, clip_min)
if clip_max is not None:
feature_values = np.minimum(feature_values, clip_max)
feature_values = np.append(feature_values, clip_max)
# Remove duplicate values before calculating stats.
feature_values = np.unique(feature_values)
if isinstance(keypoints, str):
if keypoints == 'quantiles':
if (feature_name != _LABEL_FEATURE_NAME and
feature_values.size < num_keypoints):
logging.info(
'Not enough unique values observed for feature `%s` to '
'construct %d keypoints for pwl calibration. Using %d unique '
'values as keypoints.', feature_name, num_keypoints,
feature_values.size)
num_keypoints = feature_values.size
quantiles = np.quantile(
feature_values,
np.linspace(0., 1., num_keypoints),
interpolation='nearest')
feature_keypoints[feature_name] = [float(x) for x in quantiles]
elif keypoints == 'uniform':
linspace = np.linspace(
np.min(feature_values), np.max(feature_values), num_keypoints)
feature_keypoints[feature_name] = [float(x) for x in linspace]
else:
raise ValueError(
'Invalid keypoint generation mode: {}'.format(keypoints))
else:
# Keypoints are explicitly provided in the config.
feature_keypoints[feature_name] = [float(x) for x in keypoints]
# Save keypoints to file as the chief worker.
tmp_keypoints_filename = keypoints_filename + 'tmp'
with tf.io.gfile.GFile(tmp_keypoints_filename, 'w') as keypoints_file:
keypoints_file.write(json.dumps(feature_keypoints, indent=2))
tf.io.gfile.rename(tmp_keypoints_filename, keypoints_filename)
else:
# Non-chief workers read the keypoints from file.
_poll_for_file(keypoints_filename)
with tf.io.gfile.GFile(keypoints_filename) as keypoints_file:
feature_keypoints = json.loads(keypoints_file.read())
if _LABEL_FEATURE_NAME in feature_keypoints:
output_init = feature_keypoints.pop(_LABEL_FEATURE_NAME)
if logits_output and isinstance(model_config.output_initialization, str):
# If model is expected to produce logits, initialize linearly in the
# range [-2, 2], ignoring the label distribution.
model_config.output_initialization = [
float(x) for x in np.linspace(
-2, 2, model_config.output_calibration_num_keypoints)
]
else:
model_config.output_initialization = output_init
for feature_name, keypoints in feature_keypoints.items():
model_config.feature_config_by_name(
feature_name).pwl_calibration_input_keypoints = keypoints
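The 'quantiles' and 'uniform' keypoint branches above reduce to a few lines of numpy. A standalone sketch (`init_keypoints` is a made-up name; clipping, default-value removal, and the 'nearest' interpolation mode used above are omitted here for brevity):

```python
import numpy as np

def init_keypoints(values, num_keypoints, mode='quantiles'):
  """Compute PWL calibration input keypoints from observed feature values."""
  values = np.unique(np.asarray(values, dtype=float))  # dedupe, as above
  if mode == 'quantiles':
    # Fall back to fewer keypoints when there are not enough unique values.
    num_keypoints = min(num_keypoints, values.size)
    return np.quantile(values, np.linspace(0., 1., num_keypoints)).tolist()
  if mode == 'uniform':
    return np.linspace(values.min(), values.max(), num_keypoints).tolist()
  raise ValueError('Invalid keypoint generation mode: {}'.format(mode))

print(init_keypoints(range(101), num_keypoints=5, mode='uniform'))
# [0.0, 25.0, 50.0, 75.0, 100.0]
```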
def _fix_ensemble_for_2d_constraints(model_config, feature_names):
"""Fixes 2d constraint violations by adding missing features to some lattices.
Some 2d shape constraints require lattices from ensemble to either contain
both constrained features or none of them, e.g. trapezoid trust constraint
requires a lattice that has the "conditional" feature to include the "main"
feature.
Args:
model_config: Model config to be updated.
feature_names: List of feature names.
"""
must_include_features = collections.defaultdict(set)
for feature_name in feature_names:
feature_config = model_config.feature_config_by_name(feature_name)
for trust_config in feature_config.reflects_trust_in or []:
if trust_config.trust_type == 'trapezoid':
must_include_features[feature_name].add(trust_config.feature_name)
for dominance_config in feature_config.dominates or []:
must_include_features[dominance_config.feature_name].add(feature_name)
fixed_lattices = []
for idx, lattice in enumerate(model_config.lattices):
fixed_lattice = set()
for feature_name in lattice:
fixed_lattice.add(feature_name)
fixed_lattice.update(must_include_features[feature_name])
assert len(lattice) <= len(fixed_lattice)
fixed_lattices.append(list(fixed_lattice))
if len(lattice) < len(fixed_lattice):
logging.info(
'Fixed 2d constraint violations in lattices[%d]. Lattice rank '
'increased from %d to %d.', idx, len(lattice), len(fixed_lattice))
model_config.lattices = fixed_lattices
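The repair rule above can be sketched in isolation: any lattice containing a constrained feature is expanded to also contain the features it must appear with. The helper name and the feature names below are made-up examples, not part of the library:

```python
def fix_lattices(lattices, must_include):
  """Expand each lattice with the features its members must co-occur with."""
  fixed = []
  for lattice in lattices:
    members = set(lattice)
    for name in lattice:
      members.update(must_include.get(name, set()))
    fixed.append(sorted(members))
  return fixed

# Hypothetical trapezoid-trust constraint: lattices containing 'rating'
# must also include 'num_reviews'.
print(fix_lattices([['rating', 'price'], ['price', 'distance']],
                   {'rating': {'num_reviews'}}))
# [['num_reviews', 'price', 'rating'], ['distance', 'price']]
```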
def _set_crystals_lattice_ensemble(model_config, feature_names, label_dimension,
feature_columns, head, prefitting_input_fn,
prefitting_optimizer, prefitting_steps,
config, dtype):
"""Sets the lattice ensemble in model_config using the crystals algorithm."""
if prefitting_input_fn is None:
raise ValueError('prefitting_input_fn must be set for crystals models')
# Get prefitting model config.
prefitting_model_config = premade_lib.construct_prefitting_model_config(
model_config, feature_names)
def prefitting_model_fn(features, labels, mode, config):
return _calibrated_lattice_ensemble_model_fn(
features=features,
labels=labels,
label_dimension=label_dimension,
feature_columns=feature_columns,
mode=mode,
head=head,
model_config=prefitting_model_config,
optimizer=prefitting_optimizer,
config=config,
dtype=dtype)
config = tf.estimator.RunConfig(
keep_checkpoint_max=1,
save_summary_steps=0,
save_checkpoints_steps=10000000,
tf_random_seed=config.tf_random_seed if config is not None else 42)
logging.info('Creating the prefitting estimator.')
prefitting_estimator = tf.estimator.Estimator(
model_fn=prefitting_model_fn, config=config)
logging.info('Training the prefitting estimator.')
prefitting_estimator.train(
input_fn=prefitting_input_fn, steps=prefitting_steps)
premade_lib.set_crystals_lattice_ensemble(
model_config=model_config,
prefitting_model_config=prefitting_model_config,
prefitting_model=prefitting_estimator,
feature_names=feature_names)
logging.info('Finished training the prefitting estimator.')
# Cleanup model_dir since we might be reusing it for the main estimator.
# Note that other workers are blocked until model structure file is
# generated by the chief worker, so modifying files here should be safe.
remove_list = [
os.path.join(prefitting_estimator.model_dir, 'graph.pbtxt'),
os.path.join(prefitting_estimator.model_dir, 'checkpoint'),
]
remove_list.extend(
tf.io.gfile.glob(prefitting_estimator.latest_checkpoint() + '*'))
for file_path in remove_list:
tf.io.gfile.remove(file_path)
def _finalize_model_structure(model_config, label_dimension, feature_columns,
head, prefitting_input_fn, prefitting_optimizer,
prefitting_steps, model_dir, config,
warm_start_from, dtype):
"""Sets up the lattice ensemble in model_config with requested algorithm."""
if (not isinstance(model_config, configs.CalibratedLatticeEnsembleConfig) or
isinstance(model_config.lattices, list)):
return
# TODO: If warmstarting, look for the previous ensemble file.
if warm_start_from:
raise ValueError('Warm starting lattice ensembles without explicitly '
'defined lattices is not supported yet.')
if feature_columns:
feature_names = [feature_column.name for feature_column in feature_columns]
else:
feature_names = [
feature_config.name for feature_config in model_config.feature_configs
]
if model_config.lattice_rank > len(feature_names):
raise ValueError(
'lattice_rank {} cannot be larger than the number of features: {}'
.format(model_config.lattice_rank, feature_names))
if model_config.num_lattices * model_config.lattice_rank < len(feature_names):
raise ValueError(
'Model with {}x{}d lattices is not large enough for all features: {}'
.format(model_config.num_lattices, model_config.lattice_rank,
feature_names))
ensemble_structure_filename = os.path.join(model_dir,
_ENSEMBLE_STRUCTURE_FILE)
if ((config is None or config.is_chief) and
not tf.io.gfile.exists(ensemble_structure_filename)):
if model_config.lattices not in ['random', 'crystals', 'rtl_layer']:
raise ValueError('Unsupported ensemble structure: {}'.format(
model_config.lattices))
if model_config.lattices == 'random':
premade_lib.set_random_lattice_ensemble(model_config, feature_names)
elif model_config.lattices == 'crystals':
_set_crystals_lattice_ensemble(
feature_names=feature_names,
label_dimension=label_dimension,
feature_columns=feature_columns,
head=head,
model_config=model_config,
prefitting_input_fn=prefitting_input_fn,
prefitting_optimizer=prefitting_optimizer,
prefitting_steps=prefitting_steps,
config=config,
dtype=dtype)
if (model_config.fix_ensemble_for_2d_constraints and
model_config.lattices != 'rtl_layer'):
# Note that we currently only support monotonicity and bound constraints
# for RTL.
_fix_ensemble_for_2d_constraints(model_config, feature_names)
# Save lattices to file as the chief worker.
tmp_ensemble_structure_filename = ensemble_structure_filename + 'tmp'
with tf.io.gfile.GFile(tmp_ensemble_structure_filename,
'w') as ensemble_structure_file:
ensemble_structure_file.write(json.dumps(model_config.lattices, indent=2))
tf.io.gfile.rename(tmp_ensemble_structure_filename,
ensemble_structure_filename)
else:
# Non-chief workers read the lattices from file.
_poll_for_file(ensemble_structure_filename)
with tf.io.gfile.GFile(
ensemble_structure_filename) as ensemble_structure_file:
model_config.lattices = json.loads(ensemble_structure_file.read())
logging.info('Finalized model structure: %s', str(model_config.lattices))
def _verify_config(model_config, feature_columns):
"""Verifies that the config is setup correctly and ready for model_fn."""
if feature_columns:
feature_configs = [
model_config.feature_config_by_name(feature_column.name)
for feature_column in feature_columns
]
else:
feature_configs = model_config.feature_configs or []
for feature_config in feature_configs:
if not feature_config.num_buckets:
if (not np.iterable(feature_config.pwl_calibration_input_keypoints) or
any(not isinstance(x, float)
for x in feature_config.pwl_calibration_input_keypoints)):
raise ValueError(
'Input keypoints are invalid for feature {}: {}'.format(
feature_config.name,
feature_config.pwl_calibration_input_keypoints))
if (not np.iterable(model_config.output_initialization) or any(
not isinstance(x, float) for x in model_config.output_initialization)):
raise ValueError('Output initialization is invalid: {}'.format(
model_config.output_initialization))
def _update_by_feature_columns(model_config, feature_columns):
"""Updates a model config with the given feature columns."""
for feature_column in feature_columns or []:
feature_config = model_config.feature_config_by_name(feature_column.name)
# pylint: disable=protected-access
if (isinstance(feature_column, fc._DenseColumn) or
isinstance(feature_column, fc2.DenseColumn)):
feature_config.default_value = feature_column.default_value
elif (isinstance(feature_column, fc._VocabularyListCategoricalColumn) or
isinstance(feature_column, fc2.VocabularyListCategoricalColumn)):
feature_config.vocabulary_list = feature_column.vocabulary_list
feature_config.num_buckets = feature_column.num_buckets
if feature_column.num_oov_buckets:
# A positive num_oov_buckets can not be specified with default_value.
# See tf.feature_column.categorical_column_with_vocabulary_list.
feature_config.default_value = None
else:
# We add a bucket at the end for the default_value, since num_buckets
# does not include the default value (but includes oov buckets).
feature_config.default_value = feature_column.default_value
feature_config.num_buckets += 1
else:
raise ValueError('Unsupported feature_column: {}'.format(feature_column))
# pylint: enable=protected-access
# Change categorical monotonicities to indices.
premade_lib.set_categorical_monotonicities(model_config.feature_configs)
def _calibrated_lattice_ensemble_model_fn(features, labels, label_dimension,
feature_columns, mode, head,
model_config, optimizer, config,
dtype):
"""Calibrated Lattice Ensemble Model."""
del config
if label_dimension != 1:
raise ValueError('Only 1-dimensional output is supported.')
# Get input tensors and corresponding feature configs.
transformed_features = transform_features(features, feature_columns)
feature_names = list(transformed_features.keys())
input_tensors = [
transformed_features[feature_name] for feature_name in feature_names
]
# Reconstruct feature_config in order of feature_names
feature_configs = [
model_config.feature_config_by_name(feature_name)
for feature_name in feature_names
]
del model_config.feature_configs[:]
model_config.feature_configs.extend(feature_configs)
training = (mode == tf.estimator.ModeKeys.TRAIN)
model = premade.CalibratedLatticeEnsemble(
model_config=model_config, dtype=dtype)
logits = tf.identity(
model(input_tensors, training=training), name=OUTPUT_NAME)
if training:
optimizer = optimizers.get_optimizer_instance_v2(optimizer)
optimizer.iterations = training_util.get_or_create_global_step()
else:
optimizer = None
return head.create_estimator_spec(
features=features,
mode=mode,
labels=labels,
optimizer=optimizer,
logits=logits,
trainable_variables=model.trainable_variables,
update_ops=model.updates,
regularization_losses=model.losses or None)
def _calibrated_lattice_model_fn(features, labels, label_dimension,
feature_columns, mode, head, model_config,
optimizer, config, dtype):
"""Calibrated Lattice Model."""
del config
if label_dimension != 1:
raise ValueError('Only 1-dimensional output is supported.')
# Get input tensors and corresponding feature configs.
transformed_features = transform_features(features, feature_columns)
feature_names = list(transformed_features.keys())
input_tensors = [
transformed_features[feature_name] for feature_name in feature_names
]
# Reconstruct feature_config in order of feature_names
feature_configs = [
model_config.feature_config_by_name(feature_name)
for feature_name in feature_names
]
del model_config.feature_configs[:]
model_config.feature_configs.extend(feature_configs)
training = (mode == tf.estimator.ModeKeys.TRAIN)
model = premade.CalibratedLattice(model_config=model_config, dtype=dtype)
logits = tf.identity(
model(input_tensors, training=training), name=OUTPUT_NAME)
if training:
optimizer = optimizers.get_optimizer_instance_v2(optimizer)
optimizer.iterations = training_util.get_or_create_global_step()
else:
optimizer = None
return head.create_estimator_spec(
features=features,
mode=mode,
labels=labels,
optimizer=optimizer,
logits=logits,
trainable_variables=model.trainable_variables,
update_ops=model.updates,
regularization_losses=model.losses or None)
def _calibrated_linear_model_fn(features, labels, label_dimension,
feature_columns, mode, head, model_config,
optimizer, config, dtype):
"""Calibrated Linear Model."""
del config
if label_dimension != 1:
raise ValueError('Only 1-dimensional output is supported.')
# Get input tensors and corresponding feature configs.
transformed_features = transform_features(features, feature_columns)
feature_names = list(transformed_features.keys())
input_tensors = [
transformed_features[feature_name] for feature_name in feature_names
]
# Reconstruct feature_config in order of feature_names
feature_configs = [
model_config.feature_config_by_name(feature_name)
for feature_name in feature_names
]
del model_config.feature_configs[:]
model_config.feature_configs.extend(feature_configs)
training = (mode == tf.estimator.ModeKeys.TRAIN)
model = premade.CalibratedLinear(model_config=model_config, dtype=dtype)
logits = tf.identity(
model(input_tensors, training=training), name=OUTPUT_NAME)
if training:
optimizer = optimizers.get_optimizer_instance_v2(optimizer)
optimizer.iterations = training_util.get_or_create_global_step()
else:
optimizer = None
return head.create_estimator_spec(
features=features,
mode=mode,
labels=labels,
optimizer=optimizer,
logits=logits,
trainable_variables=model.trainable_variables,
update_ops=model.updates,
regularization_losses=model.losses or None)
def _get_model_fn(label_dimension, feature_columns, head, model_config,
optimizer, dtype):
"""Returns the model_fn for the given model_config."""
if isinstance(model_config, configs.CalibratedLatticeConfig):
def calibrated_lattice_model_fn(features, labels, mode, config):
return _calibrated_lattice_model_fn(
features=features,
labels=labels,
label_dimension=label_dimension,
feature_columns=feature_columns,
mode=mode,
head=head,
model_config=model_config,
optimizer=optimizer,
config=config,
dtype=dtype)
return calibrated_lattice_model_fn
elif isinstance(model_config, configs.CalibratedLinearConfig):
def calibrated_linear_model_fn(features, labels, mode, config):
return _calibrated_linear_model_fn(
features=features,
labels=labels,
label_dimension=label_dimension,
feature_columns=feature_columns,
mode=mode,
head=head,
model_config=model_config,
optimizer=optimizer,
config=config,
dtype=dtype)
return calibrated_linear_model_fn
if isinstance(model_config, configs.CalibratedLatticeEnsembleConfig):
def calibrated_lattice_ensemble_model_fn(features, labels, mode, config):
return _calibrated_lattice_ensemble_model_fn(
features=features,
labels=labels,
label_dimension=label_dimension,
feature_columns=feature_columns,
mode=mode,
head=head,
model_config=model_config,
optimizer=optimizer,
config=config,
dtype=dtype)
return calibrated_lattice_ensemble_model_fn
else:
raise ValueError('Unsupported model type: {}'.format(type(model_config)))
class CannedEstimator(estimator_lib.EstimatorV2):
"""An estimator for TensorFlow lattice models.
Creates an estimator with a custom head for the model architecutre specified
by the `model_config`, which should be one of those defined in `tfl.configs`.
Calculation of feature quantiles for input keypoint initialization is done
using `feature_analysis_input_fn`. If this auxiliary input fn is not provided,
all keypoint values should be explicitly provided via the `model_config`.
Example:
```python
model_config = tfl.configs.CalibratedLatticeConfig(...)
feature_analysis_input_fn = create_input_fn(num_epochs=1, ...)
train_input_fn = create_input_fn(num_epochs=100, ...)
head = ...
estimator = tfl.estimators.CannedEstimator(
feature_columns=feature_columns,
model_config=model_config,
feature_analysis_input_fn=feature_analysis_input_fn
head=head)
estimator.train(input_fn=train_input_fn)
```
"""
def __init__(self,
head,
model_config,
feature_columns,
feature_analysis_input_fn=None,
prefitting_input_fn=None,
model_dir=None,
label_dimension=1,
optimizer='Adagrad',
prefitting_optimizer='Adagrad',
prefitting_steps=None,
config=None,
warm_start_from=None,
dtype=tf.float32):
"""Initializes a `CannedEstimator` instance.
Args:
head: A `_Head` instance constructed with a method such as
`tf.contrib.estimator.multi_label_head`.
      model_config: Model configuration object describing model architecture.
Should be one of the model configs in `tfl.configs`.
feature_columns: An iterable containing all the feature columns used by
the model.
      feature_analysis_input_fn: An input_fn used to calculate statistics about
        features and labels in order to set up calibration keypoints and
        values.
      prefitting_input_fn: An input_fn used in the pre-fitting stage to
        estimate non-linear feature interactions. Required for crystals models.
        Prefitting typically uses the same dataset as the main training, but
        with fewer epochs.
      model_dir: Directory to save model parameters, graph, etc. This can also
        be used to load checkpoints from the directory into an estimator to
        continue training a previously saved model.
label_dimension: Number of regression targets per example. This is the
size of the last dimension of the labels and logits `Tensor` objects
(typically, these have shape `[batch_size, label_dimension]`).
optimizer: An instance of `tf.Optimizer` used to train the model. Can also
be a string (one of 'Adagrad', 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or
callable. Defaults to Adagrad optimizer.
prefitting_optimizer: An instance of `tf.Optimizer` used to train the
model during the pre-fitting stage. Can also be a string (one of
'Adagrad', 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or callable. Defaults to
Adagrad optimizer.
      prefitting_steps: Number of steps for which to train the model during
        the prefitting stage. If None, train forever or train until
        prefitting_input_fn raises a tf.errors.OutOfRangeError or a
        StopIteration exception.
config: `RunConfig` object to configure the runtime settings.
warm_start_from: A string filepath to a checkpoint to warm-start from, or
a `WarmStartSettings` object to fully configure warm-starting. If the
string filepath is provided instead of a `WarmStartSettings`, then all
weights are warm-started, and it is assumed that vocabularies and Tensor
names are unchanged.
dtype: dtype of layers used in the model.
"""
config = estimator_lib.maybe_overwrite_model_dir_and_session_config(
config, model_dir)
model_dir = config.model_dir
model_config = copy.deepcopy(model_config)
_update_by_feature_columns(model_config, feature_columns)
_finalize_keypoints(
model_config=model_config,
config=config,
feature_columns=feature_columns,
feature_analysis_input_fn=feature_analysis_input_fn,
logits_output=True)
_verify_config(model_config, feature_columns)
_finalize_model_structure(
label_dimension=label_dimension,
feature_columns=feature_columns,
head=head,
model_config=model_config,
prefitting_input_fn=prefitting_input_fn,
prefitting_optimizer=prefitting_optimizer,
prefitting_steps=prefitting_steps,
model_dir=model_dir,
config=config,
warm_start_from=warm_start_from,
dtype=dtype)
model_fn = _get_model_fn(
label_dimension=label_dimension,
feature_columns=feature_columns,
head=head,
model_config=model_config,
optimizer=optimizer,
dtype=dtype)
super(CannedEstimator, self).__init__(
model_fn=model_fn,
model_dir=model_dir,
config=config,
warm_start_from=warm_start_from)


class CannedClassifier(estimator_lib.EstimatorV2):
"""Canned classifier for TensorFlow lattice models.
Creates a classifier for the model architecutre specified by the
`model_config`, which should be one of those defined in `tfl.configs`.
Calclulation of feature quantiles for input keypoint initialization is done
using `feature_analysis_input_fn`. If this auxiliary input fn is not provided,
all keypoint values should be explicitly provided via the `model_config`.
Training loss is softmax cross-entropy as defined for the default
TF classificaiton head.
Example:
```python
model_config = tfl.configs.CalibratedLatticeConfig(...)
feature_analysis_input_fn = create_input_fn(num_epochs=1, ...)
train_input_fn = create_input_fn(num_epochs=100, ...)
estimator = tfl.estimators.CannedClassifier(
feature_columns=feature_columns,
model_config=model_config,
feature_analysis_input_fn=feature_analysis_input_fn)
estimator.train(input_fn=train_input_fn)
```
"""
def __init__(self,
model_config,
feature_columns,
feature_analysis_input_fn=None,
prefitting_input_fn=None,
model_dir=None,
n_classes=2,
weight_column=None,
label_vocabulary=None,
optimizer='Adagrad',
prefitting_optimizer='Adagrad',
prefitting_steps=None,
config=None,
warm_start_from=None,
loss_reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
loss_fn=None,
dtype=tf.float32):
"""Initializes a `CannedClassifier` instance.
Args:
      model_config: Model configuration object describing model architecture.
Should be one of the model configs in `tfl.configs`.
feature_columns: An iterable containing all the feature columns used by
the model.
      feature_analysis_input_fn: An input_fn used to calculate statistics about
        features and labels in order to set up calibration keypoints and
        values.
      prefitting_input_fn: An input_fn used in the pre-fitting stage to
        estimate non-linear feature interactions. Required for crystals models.
        Prefitting typically uses the same dataset as the main training, but
        with fewer epochs.
      model_dir: Directory to save model parameters, graph, etc. This can also
        be used to load checkpoints from the directory into an estimator to
        continue training a previously saved model.
n_classes: Number of label classes. Defaults to 2, namely binary
classification. Must be > 1.
weight_column: A string or a `_NumericColumn` created by
`tf.feature_column.numeric_column` defining feature column representing
weights. It is used to down weight or boost examples during training. It
will be multiplied by the loss of the example. If it is a string, it is
used as a key to fetch weight tensor from the `features`. If it is a
`_NumericColumn`, raw tensor is fetched by key `weight_column.key`, then
weight_column.normalizer_fn is applied on it to get weight tensor.
      label_vocabulary: A list of strings representing possible label values.
        If given, labels must be of string type and have a value in
        `label_vocabulary`. If it is not given, labels are assumed to be
        already encoded as an integer or float within [0, 1] for `n_classes=2`,
        or as integer values in {0, 1,..., n_classes-1} for `n_classes>2`.
        An error will be raised if the vocabulary is not provided and labels
        are strings.
optimizer: An instance of `tf.Optimizer` used to train the model. Can also
be a string (one of 'Adagrad', 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or
callable. Defaults to Adagrad optimizer.
prefitting_optimizer: An instance of `tf.Optimizer` used to train the
model during the pre-fitting stage. Can also be a string (one of
'Adagrad', 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or callable. Defaults to
Adagrad optimizer.
      prefitting_steps: Number of steps for which to train the model during
        the prefitting stage. If None, train forever or train until
        prefitting_input_fn raises a tf.errors.OutOfRangeError or a
        StopIteration exception.
config: `RunConfig` object to configure the runtime settings.
warm_start_from: A string filepath to a checkpoint to warm-start from, or
a `WarmStartSettings` object to fully configure warm-starting. If the
string filepath is provided instead of a `WarmStartSettings`, then all
weights are warm-started, and it is assumed that vocabularies and Tensor
names are unchanged.
loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how
to reduce training loss over batch. Defaults to `SUM_OVER_BATCH_SIZE`.
loss_fn: Optional loss function.
dtype: dtype of layers used in the model.
"""
config = estimator_lib.maybe_overwrite_model_dir_and_session_config(
config, model_dir)
model_dir = config.model_dir
if n_classes == 2:
head = binary_class_head.BinaryClassHead(
weight_column=weight_column,
label_vocabulary=label_vocabulary,
loss_reduction=loss_reduction,