[WIP] [tf.learn] API restructure #2551

Closed · wants to merge 12 commits
13 changes: 13 additions & 0 deletions tensorflow/contrib/learn/BUILD
@@ -196,6 +196,19 @@ py_test(
],
)

py_test(
name = "classifier_test",
size = "small",
srcs = ["python/learn/estimators/classifier_test.py"],
srcs_version = "PY2AND3",
deps = [
":learn",
"//tensorflow:tensorflow_py",
"//tensorflow/python:framework_test_lib",
],
)


py_test(
name = "dnn_linear_combined_test",
size = "medium",
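
For reference, the new test target above can be run on its own with the standard Bazel invocation for this BUILD path:

bazel test //tensorflow/contrib/learn:classifier_test
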
2 changes: 2 additions & 0 deletions tensorflow/contrib/learn/python/learn/estimators/__init__.py
@@ -23,6 +23,7 @@
from tensorflow.contrib.learn.python.learn.estimators.autoencoder import TensorFlowDNNAutoencoder
from tensorflow.contrib.learn.python.learn.estimators.base import TensorFlowBaseTransformer
from tensorflow.contrib.learn.python.learn.estimators.base import TensorFlowEstimator
from tensorflow.contrib.learn.python.learn.estimators.classifier import Classifier
from tensorflow.contrib.learn.python.learn.estimators.dnn import DNNClassifier
from tensorflow.contrib.learn.python.learn.estimators.dnn import DNNRegressor
from tensorflow.contrib.learn.python.learn.estimators.dnn import TensorFlowDNNClassifier
@@ -41,3 +42,4 @@
from tensorflow.contrib.learn.python.learn.estimators.rnn import TensorFlowRNNClassifier
from tensorflow.contrib.learn.python.learn.estimators.rnn import TensorFlowRNNRegressor
from tensorflow.contrib.learn.python.learn.estimators.run_config import RunConfig
from tensorflow.contrib.learn.python.learn.estimators._sklearn import NotFittedError
7 changes: 3 additions & 4 deletions tensorflow/contrib/learn/python/learn/estimators/_sklearn.py
@@ -129,7 +129,7 @@ class _TransformerMixin():
"""Mixin class for all transformer estimators."""


-class _NotFittedError(ValueError, AttributeError):
+class NotFittedError(ValueError, AttributeError):
"""Exception class to raise if estimator is used before fitting.

This class inherits from both ValueError and AttributeError to help with
@@ -175,7 +175,7 @@ def _train_test_split(*args, **options):
train_size = 0.75
elif train_size is None:
train_size = 1 - test_size
-  train_size *= args[0].shape[0]
+  train_size = int(train_size * args[0].shape[0])

np.random.seed(random_state)
indices = np.random.permutation(args[0].shape[0])
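
Why the added int() cast matters: NumPy slice bounds must be integers, so scaling the sample count by a float fraction needs an explicit cast. An illustrative sketch (not from this diff):

import numpy as np

X = np.arange(10).reshape(5, 2)
train_size = 0.75 * X.shape[0]   # 3.75, a float
# X[:train_size] errors on modern NumPy (older versions only warned).
X_train = X[:int(train_size)]    # OK: the first 3 rows
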
@@ -199,14 +199,13 @@ def _train_test_split(*args, **options):
try:
from sklearn.utils.validation import NotFittedError
except ImportError:
-  NotFittedError = _NotFittedError
+  pass
else:
# Naive implementations of sklearn classes and functions.
BaseEstimator = _BaseEstimator
ClassifierMixin = _ClassifierMixin
RegressorMixin = _RegressorMixin
TransformerMixin = _TransformerMixin
-  NotFittedError = _NotFittedError
accuracy_score = _accuracy_score
log_loss = None
mean_squared_error = _mean_squared_error
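
With _NotFittedError renamed to NotFittedError and re-exported from the package, callers can catch it without knowing whether scikit-learn is installed. A minimal sketch of the fallback pattern this file now uses (the class body here is abbreviated):

class NotFittedError(ValueError, AttributeError):
  """Raised if an estimator is used before fitting."""

try:
  # When scikit-learn is available, its canonical exception class shadows
  # the local one, so callers also catch sklearn-raised errors.
  from sklearn.utils.validation import NotFittedError
except ImportError:
  pass  # Keep the local definition above.
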
169 changes: 120 additions & 49 deletions tensorflow/contrib/learn/python/learn/estimators/base.py
@@ -22,17 +22,24 @@

import json
import os
import types

import six
from six import string_types

import numpy as np

from tensorflow.contrib import framework as contrib_framework
from tensorflow.contrib import layers
from tensorflow.contrib.learn.python.learn.estimators import _sklearn
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.estimators._sklearn import NotFittedError
from tensorflow.contrib.learn.python.learn.io.data_feeder import setup_train_data_feeder
from tensorflow.contrib.learn.python.learn.utils import checkpoints

from tensorflow.python.framework import ops
from tensorflow.python.ops import constant_op
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging


def _write_with_backup(filename, content):
@@ -54,28 +61,6 @@ def _copy_dir(dir_in, dir_out):
gfile.Copy(name_in, name_out, overwrite=True)


-def _new_tf_model_fn(model_fn, class_weight):
-  """Backward compatibility way of adding class weight and IS_TRAINING.
-
-  TODO(ipolosukhin): Remove this function after new layers are available.
-  Specifically:
-  * dropout and batch norm should work via update ops.
-  * class weights should be retrieved from weights column or hparams.
-
-  Args:
-    model_fn: Core model function.
-    class_weight: Class weight.
-  Returns:
-    Model function.
-  """
-  def _model_fn(features, targets, mode):
-    ops.get_default_graph().add_to_collection('IS_TRAINING', mode == 'train')
-    if class_weight is not None:
-      constant_op.constant(class_weight, name='class_weight')
-    return model_fn(features, targets)
-  return _model_fn


class TensorFlowEstimator(estimator.Estimator):
"""Base class for all TensorFlow estimators.

@@ -122,12 +107,17 @@ def __init__(self,
continue_training=False,
config=None,
verbose=1):
self.class_weight = class_weight
self.learning_rate = learning_rate
self.clip_gradients = clip_gradients
if isinstance(optimizer, six.string_types):
if optimizer not in layers.OPTIMIZER_CLS_NAMES:
raise ValueError(
'Optimizer name should be one of [%s], you provided %s.' %
(', '.join(layers.OPTIMIZER_CLS_NAMES), optimizer))
self.optimizer = optimizer
super(TensorFlowEstimator, self).__init__(
-        model_fn=_new_tf_model_fn(model_fn, class_weight),
-        classification=n_classes > 1,
-        learning_rate=learning_rate,
-        optimizer=optimizer,
-        clip_gradients=clip_gradients,
+        model_fn=self._get_model_fn(model_fn),
config=config)
self.n_classes = n_classes
self.batch_size = batch_size
@@ -275,27 +265,6 @@ def get_tensor(self, name):
"""
return self._graph.get_tensor_by_name(name)

-  def get_tensor_value(self, name):
-    """Returns value of the tensor given by name.
-
-    Args:
-      name: string, name of the tensor.
-
-    Returns:
-      Numpy array - value of the tensor.
-    """
-    if name.endswith(':0'):
-      name = name[:-2]
-    return checkpoints.load_variable(self.model_dir, name)
-
-  def get_variable_names(self):
-    """Returns list of all variable names in this model.
-
-    Returns:
-      List of names.
-    """
-    return [name for name, _ in checkpoints.list_variables(self.model_dir)]
-
def save(self, path):
"""Saves checkpoints and graph to given path.

@@ -383,6 +352,41 @@ def restore(cls, path, config=None):
result._restore(path)
return result

def _get_model_fn(self, model_fn):
"""Backward compatibility way of adding class weight and IS_TRAINING.

TODO(ipolosukhin): Remove this function after new layers are available.
Specifically:
* dropout and batch norm should work via update ops.
* class weights should be retrieved from weights column or hparams.

Args:
model_fn: Core model function.
Returns:
Model function.
"""
def _model_fn(features, targets, mode):
ops.get_default_graph().add_to_collection('IS_TRAINING', mode == 'train')
if self.class_weight is not None:
constant_op.constant(self.class_weight, name='class_weight')
predictions, loss = model_fn(features, targets)
if isinstance(self.learning_rate, types.FunctionType):
learning_rate = self.learning_rate(contrib_framework.get_global_step())
else:
learning_rate = self.learning_rate
if isinstance(self.optimizer, types.FunctionType):
optimizer = self.optimizer(learning_rate)
else:
optimizer = self.optimizer
train_op = layers.optimize_loss(
loss,
contrib_framework.get_global_step(),
learning_rate=learning_rate,
optimizer=optimizer,
clip_gradients=self.clip_gradients)
return predictions, loss, train_op
return _model_fn
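
Because _get_model_fn resolves learning_rate and optimizer at graph-build time, both may be plain values or callables. A hypothetical sketch of the callable form (my_model, the decay constants, and the optimizer choice are illustrative, not from this diff):

import tensorflow as tf
from tensorflow.contrib.learn.python.learn.estimators.base import TensorFlowEstimator

def my_model(features, targets):
  # Assumed core model fn: (features, targets) -> (predictions, loss).
  predictions = tf.contrib.layers.fully_connected(
      features, 2, activation_fn=None)
  loss = tf.reduce_mean(tf.square(predictions - targets))
  return predictions, loss

def lr_schedule(global_step):
  # Called with the global step tensor; decays the rate every 1000 steps.
  return tf.train.exponential_decay(0.1, global_step,
                                    decay_steps=1000, decay_rate=0.5)

est = TensorFlowEstimator(
    model_fn=my_model, n_classes=2,
    learning_rate=lr_schedule,  # a callable, resolved per training graph
    optimizer=lambda lr: tf.train.GradientDescentOptimizer(lr))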


class TensorFlowBaseTransformer(TensorFlowEstimator, _sklearn.TransformerMixin):
"""TensorFlow Base Transformer class."""
Expand All @@ -400,3 +404,70 @@ def fit(self, X, y=None, monitor=None, logdir=None):
def fit_transform(self, X, y=None, monitor=None, logdir=None):
"""Fit transformer and transform X using trained transformer."""
return self.fit(X, y, monitor=None, logdir=None).transform(X)


class DeprecatedMixin(object):
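  """Mixin to provide the deprecated interface on top of new estimators."""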

def __init__(self, *args, **kwargs):
this_class = type(self).__name__
alternative_class = this_class[len('TensorFlow'):]
logging.warn(
"%s class is deprecated. Please consider using %s as an alternative.",
this_class, alternative_class)
# Handle deprecated arguments.
self.__deprecated_n_classes = kwargs.get('n_classes', 0)
if self.__deprecated_n_classes < 1 and 'n_classes' in kwargs:
kwargs.pop('n_classes')
self.batch_size = kwargs.pop('batch_size', 32)
self.steps = kwargs.pop('steps', 200)
if 'optimizer' in kwargs or 'learning_rate' in kwargs:
self.learning_rate = kwargs.pop('learning_rate', 0.1)
self.optimizer = kwargs.pop('optimizer', 'Adagrad')
if 'class_weight' in kwargs:
      raise ValueError('Sorry, we switched the interface for providing class '
                       'weights. Please use the weight column instead, which '
                       'provides more granular (per-example) control.')
kwargs.pop('clip_gradients', 5.0)
if 'continue_training' in kwargs:
      logging.info('continue_training argument in %s is now ignored.',
                   this_class)
kwargs.pop('continue_training', False)
super(DeprecatedMixin, self).__init__(*args, **kwargs)

def fit(self, x, y, steps=None, batch_size=None, monitors=None, logdir=None):
if logdir is not None:
self._model_dir = logdir
return super(DeprecatedMixin, self).fit(x=x, y=y, steps=steps or self.steps,
batch_size=batch_size or self.batch_size, monitors=monitors)

def predict(self, x=None, input_fn=None, batch_size=None, outputs=None,
axis=1):
if x is not None:
predict_data_feeder = setup_train_data_feeder(
x, None, n_classes=None,
batch_size=batch_size or self.batch_size,
shuffle=False, epochs=1)
result = super(DeprecatedMixin, self)._infer_model(
input_fn=predict_data_feeder.input_builder,
feed_fn=predict_data_feeder.get_feed_dict_fn(),
outputs=outputs)
else:
result = super(DeprecatedMixin, self)._infer_model(
input_fn=input_fn, outputs=outputs)
if self.__deprecated_n_classes > 1 and axis is not None:
return np.argmax(result, axis)
return result

def predict_proba(self, x=None, input_fn=None, batch_size=None, outputs=None):
return self.predict(x=x, input_fn=input_fn, batch_size=batch_size,
outputs=outputs, axis=None)

def save(self, path):
"""Saves checkpoints and graph to given path.

Args:
path: Folder to save model to.
"""
# Copy model dir into new path.
_copy_dir(self.model_dir, path)
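
For context, DeprecatedMixin is presumably consumed by listing it first in a deprecated alias class's bases, so its shims run ahead of the new estimator's methods. A hypothetical sketch (the alias/estimator pairing is illustrative):

from tensorflow.contrib.learn.python.learn.estimators.base import DeprecatedMixin
from tensorflow.contrib.learn.python.learn.estimators.dnn import DNNClassifier

class TensorFlowDNNClassifier(DeprecatedMixin, DNNClassifier):
  pass

# Instantiation logs "TensorFlowDNNClassifier class is deprecated..." and maps
# old-style kwargs (batch_size, steps, learning_rate, ...) before delegating.
clf = TensorFlowDNNClassifier(hidden_units=[10, 20], n_classes=3)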

72 changes: 72 additions & 0 deletions tensorflow/contrib/learn/python/learn/estimators/classifier.py
@@ -0,0 +1,72 @@
# pylint: disable=g-bad-file-header
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Classifier class."""

from tensorflow.contrib import metrics as metrics_lib
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn


def get_classifier_metrics(n_classes):
return {('accuracy', 'classes'): metrics_lib.streaming_accuracy}


class Classifier(estimator.Estimator):
  """Single output classifier Estimator.

  Given a logits-generating function, provides `classes` and `probabilities`
  heads and functions to work with them.
  """

CLASS_OUTPUT = 'classes'
PROBABILITY_OUTPUT = 'probabilities'

  def __init__(self, model_fn, n_classes, model_dir=None, config=None):
    """Constructor for Classifier.

    Args:
      model_fn: (features, targets, mode) -> (logits, loss, train_op).
      n_classes: Number of classes.
      model_dir: Directory to save model parameters and graph.
      config: Configuration object (optional).
    """
    self._n_classes = n_classes
    self._logits_fn = model_fn
    super(Classifier, self).__init__(model_fn=self._classifier_model,
                                     model_dir=model_dir, config=config)

def evaluate(self, x=None, y=None, input_fn=None, batch_size=None,
steps=None, metrics=None):
metrics = metrics or get_classifier_metrics(self._n_classes)
return super(Classifier, self).evaluate(
x=x, y=y, input_fn=input_fn, batch_size=batch_size, steps=steps, metrics=metrics)

def predict(self, x=None, input_fn=None, batch_size=None):
return super(Classifier, self).predict(
x=x, input_fn=input_fn, batch_size=batch_size,
outputs=[self.CLASS_OUTPUT])[self.CLASS_OUTPUT]

def predict_proba(self, x=None, input_fn=None, batch_size=None):
return super(Classifier, self).predict(
x=x, input_fn=input_fn, batch_size=batch_size,
outputs=[self.PROBABILITY_OUTPUT])[self.PROBABILITY_OUTPUT]

def _classifier_model(self, features, targets, mode):
logits, loss, train_op = self._logits_fn(features, targets, mode)
return {
'classes': math_ops.argmax(logits, len(logits.get_shape()) - 1),
'probabilities': nn.softmax(logits)
}, loss, train_op
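
Taken together, a hypothetical end-to-end use of the new Classifier, written against the TF 0.x-era APIs this PR already relies on (the model, data, and hyperparameters are illustrative, not part of the diff):

import numpy as np
import tensorflow as tf
from tensorflow.contrib import framework as contrib_framework
from tensorflow.contrib import layers
from tensorflow.contrib.learn.python.learn.estimators.classifier import Classifier

def logistic_model(features, targets, mode):
  # Classifier expects (features, targets, mode) -> (logits, loss, train_op).
  target_one_hot = tf.one_hot(targets, 3, 1.0, 0.0)
  logits = layers.fully_connected(features, 3, activation_fn=None)
  loss = tf.reduce_mean(
      tf.nn.softmax_cross_entropy_with_logits(logits, target_one_hot))
  train_op = layers.optimize_loss(
      loss, contrib_framework.get_global_step(),
      learning_rate=0.1, optimizer='Adagrad')
  return logits, loss, train_op

x = np.random.rand(100, 4).astype(np.float32)
y = np.random.randint(0, 3, size=100)

clf = Classifier(model_fn=logistic_model, n_classes=3)
clf.fit(x, y, steps=100)
classes = clf.predict(x)         # ids from the 'classes' (argmax) head
probs = clf.predict_proba(x)     # from the 'probabilities' (softmax) head
scores = clf.evaluate(x=x, y=y)  # streaming accuracy over the 'classes' head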