# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Linear Estimators."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import six
from tensorflow.python.estimator import estimator
from tensorflow.python.estimator.canned import head as head_lib
from tensorflow.python.estimator.canned import optimizers
from tensorflow.python.feature_column import feature_column as feature_column_lib
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.losses import losses
from tensorflow.python.summary import summary
from tensorflow.python.training import ftrl
from tensorflow.python.util.tf_export import tf_export
# The default learning rate of 0.2 is a historical artifact of the initial
# implementation, but seems a reasonable choice.
_LEARNING_RATE = 0.2
def _get_default_optimizer(feature_columns):
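  # The learning rate shrinks with the number of feature columns: with 100
  # columns, for example, it is min(0.2, 1/sqrt(100)) = 0.1.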
learning_rate = min(_LEARNING_RATE, 1.0 / math.sqrt(len(feature_columns)))
return ftrl.FtrlOptimizer(learning_rate=learning_rate)
def _compute_fraction_of_zero(cols_to_vars):
"""Given a linear cols_to_vars dict, compute the fraction of zero weights.
Args:
cols_to_vars: A dictionary mapping FeatureColumns to lists of tf.Variables
like one returned from feature_column_lib.linear_model.
Returns:
The fraction of zeros (sparsity) in the linear model.
"""
all_weight_vars = []
for var_or_var_list in cols_to_vars.values():
    # Skip empty lists associated with columns that created no Variables.
if var_or_var_list:
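      # A column may map to several variable shards when a partitioner is in
      # use, so flatten each shard before concatenating.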
all_weight_vars += [
array_ops.reshape(var, [-1]) for var in var_or_var_list
]
return nn.zero_fraction(array_ops.concat(all_weight_vars, axis=0))
def _linear_logit_fn_builder(units, feature_columns):
"""Function builder for a linear logit_fn.
Args:
units: An int indicating the dimension of the logit layer.
feature_columns: An iterable containing all the feature columns used by
the model.
Returns:
A logit_fn (see below).
"""
def linear_logit_fn(features):
"""Linear model logit_fn.
Args:
features: This is the first item returned from the `input_fn`
passed to `train`, `evaluate`, and `predict`. This should be a
single `Tensor` or `dict` of same.
Returns:
A `Tensor` representing the logits.
"""
cols_to_vars = {}
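    # `linear_model` populates `cols_to_vars` as a side effect, mapping each
    # feature column (plus a 'bias' key) to the list of variables it created.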
logits = feature_column_lib.linear_model(
features=features,
feature_columns=feature_columns,
units=units,
cols_to_vars=cols_to_vars)
bias = cols_to_vars.pop('bias')
if units > 1:
summary.histogram('bias', bias)
else:
# If units == 1, the bias value is a length-1 list of a scalar Tensor,
# so we should provide a scalar summary.
summary.scalar('bias', bias[0][0])
summary.scalar('fraction_of_zero_weights',
_compute_fraction_of_zero(cols_to_vars))
return logits
return linear_logit_fn
def _linear_model_fn(features, labels, mode, head, feature_columns, optimizer,
partitioner, config):
"""A model_fn for linear models that use a gradient-based optimizer.
Args:
features: dict of `Tensor`.
labels: `Tensor` of shape `[batch_size, logits_dimension]`.
mode: Defines whether this is training, evaluation or prediction.
See `ModeKeys`.
head: A `Head` instance.
feature_columns: An iterable containing all the feature columns used by
the model.
    optimizer: string, `Optimizer` object, or callable that defines the
      optimizer to use for training. If `None`, an FTRL optimizer is used.
partitioner: Partitioner for variables.
config: `RunConfig` object to configure the runtime settings.
Returns:
An `EstimatorSpec` instance.
Raises:
ValueError: mode or params are invalid, or features has the wrong type.
"""
if not isinstance(features, dict):
raise ValueError('features should be a dictionary of `Tensor`s. '
'Given type: {}'.format(type(features)))
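  # Accepts a string name (e.g. 'Ftrl'), an `Optimizer` instance, or a
  # callable; the default learning rate is only applied when the optimizer
  # has to be constructed from a string name.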
optimizer = optimizers.get_optimizer_instance(
optimizer or _get_default_optimizer(feature_columns),
learning_rate=_LEARNING_RATE)
num_ps_replicas = config.num_ps_replicas if config else 0
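  # By default, partition variables across the available parameter servers,
  # keeping each slice at least 64MB (64 << 20 bytes).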
partitioner = partitioner or (
partitioned_variables.min_max_variable_partitioner(
max_partitions=num_ps_replicas,
min_slice_size=64 << 20))
with variable_scope.variable_scope(
'linear',
values=tuple(six.itervalues(features)),
partitioner=partitioner):
logit_fn = _linear_logit_fn_builder(
units=head.logits_dimension, feature_columns=feature_columns)
logits = logit_fn(features=features)
return head.create_estimator_spec(
features=features,
mode=mode,
labels=labels,
optimizer=optimizer,
logits=logits)
@tf_export('estimator.LinearClassifier')
class LinearClassifier(estimator.Estimator):
"""Linear classifier model.
  Train a linear model to classify instances into one of multiple possible
  classes. When the number of possible classes is 2, this is binary
  classification.
Example:
```python
categorical_column_a = categorical_column_with_hash_bucket(...)
categorical_column_b = categorical_column_with_hash_bucket(...)
categorical_feature_a_x_categorical_feature_b = crossed_column(...)
# Estimator using the default optimizer.
estimator = LinearClassifier(
feature_columns=[categorical_column_a,
categorical_feature_a_x_categorical_feature_b])
# Or estimator using the FTRL optimizer with regularization.
estimator = LinearClassifier(
feature_columns=[categorical_column_a,
categorical_feature_a_x_categorical_feature_b],
optimizer=tf.train.FtrlOptimizer(
learning_rate=0.1,
l1_regularization_strength=0.001
))
# Or estimator with warm-starting from a previous checkpoint.
estimator = LinearClassifier(
feature_columns=[categorical_column_a,
categorical_feature_a_x_categorical_feature_b],
warm_start_from="/path/to/checkpoint/dir")
# Input builders
  def input_fn_train(): # returns x, y (where y represents label's class index).
    ...
  def input_fn_eval(): # returns x, y (where y represents label's class index).
    ...
estimator.train(input_fn=input_fn_train)
estimator.evaluate(input_fn=input_fn_eval)
estimator.predict(input_fn=input_fn_predict)
```
  Input of `train` and `evaluate` should have the following features,
  otherwise there will be a `KeyError`:
* if `weight_column` is not `None`, a feature with
`key=weight_column` whose value is a `Tensor`.
* for each `column` in `feature_columns`:
- if `column` is a `SparseColumn`, a feature with `key=column.name`
whose `value` is a `SparseTensor`.
- if `column` is a `WeightedSparseColumn`, two features: the first with
`key` the id column name, the second with `key` the weight column name.
Both features' `value` must be a `SparseTensor`.
- if `column` is a `RealValuedColumn`, a feature with `key=column.name`
whose `value` is a `Tensor`.
  Loss is calculated by using softmax cross entropy (sigmoid cross entropy
  for binary classification).
@compatibility(eager)
Estimators are not compatible with eager execution.
@end_compatibility
"""
def __init__(self,
feature_columns,
model_dir=None,
n_classes=2,
weight_column=None,
label_vocabulary=None,
optimizer='Ftrl',
config=None,
partitioner=None,
warm_start_from=None,
loss_reduction=losses.Reduction.SUM):
"""Construct a `LinearClassifier` estimator object.
Args:
feature_columns: An iterable containing all the feature columns used by
the model. All items in the set should be instances of classes derived
from `FeatureColumn`.
      model_dir: Directory to save model parameters, graph, etc. This can
        also be used to load checkpoints from the directory into an estimator
        to continue training a previously saved model.
      n_classes: Number of label classes. Defaults to 2 (binary
        classification). Note that class labels are integers representing
        the class index (i.e. values from 0 to n_classes-1). For arbitrary
        label values (e.g. string labels), convert to class indices first.
      weight_column: A string or a `_NumericColumn` created by
        `tf.feature_column.numeric_column` defining the feature column
        representing weights. It is used to down-weight or boost examples
        during training, and is multiplied by the loss of the example. If it
        is a string, it is used as a key to fetch the weight tensor from the
        `features`. If it is a `_NumericColumn`, the raw tensor is fetched by
        key `weight_column.key`, then `weight_column.normalizer_fn` is
        applied on it to get the weight tensor.
      label_vocabulary: A list of strings representing possible label values.
        If given, labels must be of string type and each must be a value in
        `label_vocabulary`. If it is not given, labels are assumed to be
        already encoded as integers or floats within [0, 1] for `n_classes=2`,
        or as integer values in {0, 1, ..., n_classes-1} for `n_classes` > 2.
        An error will be raised if the vocabulary is not provided and the
        labels are strings.
      optimizer: An instance of `tf.Optimizer` used to train the model.
        Defaults to the FTRL optimizer.
config: `RunConfig` object to configure the runtime settings.
partitioner: Optional. Partitioner for input layer.
warm_start_from: A string filepath to a checkpoint to warm-start from, or
a `WarmStartSettings` object to fully configure warm-starting. If the
string filepath is provided instead of a `WarmStartSettings`, then all
weights and biases are warm-started, and it is assumed that vocabularies
and Tensor names are unchanged.
      loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes
        how to reduce training loss over the batch. Defaults to `SUM`.
Returns:
A `LinearClassifier` estimator.
Raises:
ValueError: if n_classes < 2.
"""
if n_classes == 2:
head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( # pylint: disable=protected-access
weight_column=weight_column,
label_vocabulary=label_vocabulary,
loss_reduction=loss_reduction)
else:
head = head_lib._multi_class_head_with_softmax_cross_entropy_loss( # pylint: disable=protected-access
n_classes, weight_column=weight_column,
label_vocabulary=label_vocabulary,
loss_reduction=loss_reduction)
def _model_fn(features, labels, mode, config):
"""Call the defined shared _linear_model_fn."""
return _linear_model_fn(
features=features,
labels=labels,
mode=mode,
head=head,
feature_columns=tuple(feature_columns or []),
optimizer=optimizer,
partitioner=partitioner,
config=config)
super(LinearClassifier, self).__init__(
model_fn=_model_fn,
model_dir=model_dir,
config=config,
warm_start_from=warm_start_from)
@tf_export('estimator.LinearRegressor')
class LinearRegressor(estimator.Estimator):
"""An estimator for TensorFlow Linear regression problems.
Train a linear regression model to predict label value given observation of
feature values.
Example:
```python
categorical_column_a = categorical_column_with_hash_bucket(...)
categorical_column_b = categorical_column_with_hash_bucket(...)
categorical_feature_a_x_categorical_feature_b = crossed_column(...)
estimator = LinearRegressor(
feature_columns=[categorical_column_a,
categorical_feature_a_x_categorical_feature_b])
# Or estimator with warm-starting from a previous checkpoint.
estimator = LinearRegressor(
feature_columns=[categorical_column_a,
categorical_feature_a_x_categorical_feature_b],
warm_start_from="/path/to/checkpoint/dir")
# Input builders
  def input_fn_train(): # returns x, y
    ...
  def input_fn_eval(): # returns x, y
    ...
estimator.train(input_fn=input_fn_train)
estimator.evaluate(input_fn=input_fn_eval)
estimator.predict(input_fn=input_fn_predict)
```
  Input of `train` and `evaluate` should have the following features,
  otherwise there will be a `KeyError`:
* if `weight_column` is not `None`:
key=weight_column, value=a `Tensor`
* for column in `feature_columns`:
- if isinstance(column, `SparseColumn`):
key=column.name, value=a `SparseTensor`
- if isinstance(column, `WeightedSparseColumn`):
{key=id column name, value=a `SparseTensor`,
key=weight column name, value=a `SparseTensor`}
- if isinstance(column, `RealValuedColumn`):
key=column.name, value=a `Tensor`
Loss is calculated by using mean squared error.
@compatibility(eager)
Estimators are not compatible with eager execution.
@end_compatibility
"""
def __init__(self,
feature_columns,
model_dir=None,
label_dimension=1,
weight_column=None,
optimizer='Ftrl',
config=None,
partitioner=None,
warm_start_from=None,
loss_reduction=losses.Reduction.SUM):
"""Initializes a `LinearRegressor` instance.
Args:
feature_columns: An iterable containing all the feature columns used by
the model. All items in the set should be instances of classes derived
from `FeatureColumn`.
      model_dir: Directory to save model parameters, graph, etc. This can
        also be used to load checkpoints from the directory into an estimator
        to continue training a previously saved model.
label_dimension: Number of regression targets per example. This is the
size of the last dimension of the labels and logits `Tensor` objects
(typically, these have shape `[batch_size, label_dimension]`).
      weight_column: A string or a `_NumericColumn` created by
        `tf.feature_column.numeric_column` defining the feature column
        representing weights. It is used to down-weight or boost examples
        during training, and is multiplied by the loss of the example. If it
        is a string, it is used as a key to fetch the weight tensor from the
        `features`. If it is a `_NumericColumn`, the raw tensor is fetched by
        key `weight_column.key`, then `weight_column.normalizer_fn` is
        applied on it to get the weight tensor.
      optimizer: An instance of `tf.Optimizer` used to train the model.
        Defaults to the FTRL optimizer.
config: `RunConfig` object to configure the runtime settings.
partitioner: Optional. Partitioner for input layer.
warm_start_from: A string filepath to a checkpoint to warm-start from, or
a `WarmStartSettings` object to fully configure warm-starting. If the
string filepath is provided instead of a `WarmStartSettings`, then all
weights and biases are warm-started, and it is assumed that vocabularies
and Tensor names are unchanged.
      loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes
        how to reduce training loss over the batch. Defaults to `SUM`.
"""
head = head_lib._regression_head_with_mean_squared_error_loss( # pylint: disable=protected-access
label_dimension=label_dimension, weight_column=weight_column,
loss_reduction=loss_reduction)
def _model_fn(features, labels, mode, config):
"""Call the defined shared _linear_model_fn."""
return _linear_model_fn(
features=features,
labels=labels,
mode=mode,
head=head,
feature_columns=tuple(feature_columns or []),
optimizer=optimizer,
partitioner=partitioner,
config=config)
super(LinearRegressor, self).__init__(
model_fn=_model_fn,
model_dir=model_dir,
config=config,
warm_start_from=warm_start_from)