Branch 124956736 #2886

Merged (13 commits, Jun 15, 2016)
43 changes: 43 additions & 0 deletions tensorflow/contrib/ios_examples/README.md
@@ -0,0 +1,43 @@
# TensorFlow iOS Examples

This folder contains examples of how to build applications for iOS devices using TensorFlow.

## Building the Examples

- You'll need Xcode 7.3 or later, with the command-line tools installed.

- Follow the instructions at [tensorflow/contrib/makefile](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/makefile) to compile a static library containing the core TensorFlow code.

- Download [Inception v1](https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip), and extract the label and graph files into the data folders inside both the simple and camera examples.

- Load the Xcode project inside the `simple` subfolder, and press Command-R to build and run it on the simulator or your connected device.

- You should see a single-screen app with a "Run Model" button. Tap that, and you should see some debug output appear below indicating that the example Grace Hopper image has been analyzed, with a military uniform recognized.

- Once that has run successfully, make sure you have a real device connected and open up the Xcode project in the `camera` subfolder. Once you build and run that, you should get a live camera view that you can point at objects to get real-time recognition results.

## Troubleshooting

If you're hitting problems, here's a checklist of common things to investigate:

- Make sure that you've run the `download_dependencies.sh` and `compile_ios_protobuf.sh` scripts before you run `compile_ios_tensorflow.sh`.

- Check that you have Xcode 7.3 or later.

- If there are Eigen errors, look inside the build settings of your Xcode project. In the Search Paths section, you'll see an Eigen include directory that changes with each version of the framework. You may need to update this to match the version in your tensorflow/contrib/makefile/downloads folder.

- If there's a complaint that no Session has been registered, that means the C++ global constructors that TensorFlow relies on for registration haven't been linked in properly. You'll have to make sure your project uses `-force_load`, as described below.

## Creating Your Own App

You'll need to update various settings in your app to link against TensorFlow. You can view them in the example projects, but here's a full rundown:

- The `compile_ios_tensorflow.sh` script builds a universal static library in tensorflow/contrib/makefile/gen/lib/libtensorflow-core.a. You'll need to add this to your linking build stage, and in Search Paths add tensorflow/contrib/makefile/gen/lib to the Library Search Paths setting.

- You'll also need to add libprotobuf.a and libprotobuf-lite.a from tensorflow/contrib/makefile/gen/protobuf_ios/lib to your Build Stages and Library Search Paths.

- The Header Search Paths setting needs to contain the root folder of tensorflow, tensorflow/contrib/makefile/downloads/protobuf/src, tensorflow/contrib/makefile/downloads, tensorflow/contrib/makefile/downloads/eigen-eigen-<current Eigen hash>, and tensorflow/contrib/makefile/gen/proto.

- In the Linking section, add `-force_load` followed by the path to the TensorFlow static library to the Other Linker Flags setting. This ensures that the global C++ objects that are used to register important classes inside the library are not stripped out. To the linker they can appear unused, because no other code references the variables, but in fact their constructors have the important side effect of registering the class.

- The library doesn't currently support bitcode, so you'll need to disable that in your project settings.
116 changes: 86 additions & 30 deletions tensorflow/contrib/layers/python/layers/feature_column.py
@@ -261,9 +261,14 @@ def to_weighted_sum(self,
weight_collections=None,
trainable=True):
return _create_embedding_lookup(
input_tensor, self.length, num_outputs,
_add_variable_collection(weight_collections), 0., self.combiner,
trainable, self.name + "_weights")
input_tensor=input_tensor,
vocab_size=self.length,
dimension=num_outputs,
weight_collections=_add_variable_collection(weight_collections),
initializer=init_ops.zeros_initializer,
combiner=self.combiner,
trainable=trainable,
name=self.name + "_weights")


class _SparseColumnIntegerized(_SparseColumn):
@@ -441,7 +446,7 @@ def sparse_column_with_keys(column_name,

class _EmbeddingColumn(_FeatureColumn, collections.namedtuple(
"_EmbeddingColumn",
["sparse_id_column", "dimension", "combiner", "stddev"])):
["sparse_id_column", "dimension", "combiner", "initializer"])):
"""Represents an embedding column.

Args:
@@ -455,15 +460,27 @@ class _EmbeddingColumn(_FeatureColumn, collections.namedtuple(
* "mean": do l1 normalization on features in the column
* "sqrtn": do l2 normalization on features in the column
For more information: `tf.embedding_lookup_sparse`.
stddev: the standard deviation to be used in embedding initialization.
Default is 1/sqrt(sparse_id_column.length).
initializer: A variable initializer function to be used in embedding
variable initialization. If not specified, defaults to
`tf.truncated_normal_initializer` with mean 0.0 and standard deviation
1/sqrt(sparse_id_column.length).
"""

def __new__(cls, sparse_id_column, dimension, combiner="mean", stddev=None):
if stddev is None:
def __new__(cls,
sparse_id_column,
dimension,
combiner="mean",
initializer=None):
if initializer is not None and not callable(initializer):
raise ValueError("initializer must be callable if specified.")
if initializer is None:
stddev = 1 / math.sqrt(sparse_id_column.length)
# TODO(b/25671353): Better initial value?
initializer = init_ops.truncated_normal_initializer(mean=0.0,
stddev=stddev)
return super(_EmbeddingColumn, cls).__new__(cls, sparse_id_column,
dimension, combiner, stddev)
dimension, combiner,
initializer)
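For reference, a minimal sketch of what the default above amounts to, using a hypothetical vocabulary size (in `_EmbeddingColumn` it is `sparse_id_column.length`):

```python
import math

from tensorflow.python.ops import init_ops

# Illustrative vocabulary size only; the real value comes from the
# sparse id column.
vocab_size = 10000
# Equivalent to the default built in __new__ above: truncated normal
# with mean 0.0 and stddev 1/sqrt(vocab_size).
default_initializer = init_ops.truncated_normal_initializer(
    mean=0.0, stddev=1 / math.sqrt(vocab_size))
```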

@property
def name(self):
@@ -481,7 +498,22 @@ def config(self):
@property
def key(self):
"""Returns a string which will be used as a key when we do sorting."""
return "{}".format(self)
fields_values = []
# pylint: disable=protected-access
for k, v in self._asdict().items():
if k == "initializer":
# Excludes initializer from the key since we don't support allowing
# users to specify different initializers for the same embedding column.
# Special treatment is needed since the default str form of a
# function contains its address, which could introduce non-determinism
# in sorting.
continue
fields_values.append("{}={}".format(k, v))
# pylint: enable=protected-access

# This is effectively the same format as str(self), except with our special
# treatment.
return "_EmbeddingColumn(%s)" % ", ".join(fields_values)

def insert_transformed_feature(self, columns_to_tensors):
self.sparse_id_column.insert_transformed_feature(columns_to_tensors)
@@ -492,9 +524,14 @@ def to_dnn_input_layer(self,
weight_collections=None,
trainable=True):
output, _ = _create_embedding_lookup(
input_tensor, self.length, self.dimension,
_add_variable_collection(weight_collections), self.stddev,
self.combiner, trainable, self.name + "_weights")
input_tensor=input_tensor,
vocab_size=self.length,
dimension=self.dimension,
weight_collections=_add_variable_collection(weight_collections),
initializer=self.initializer,
combiner=self.combiner,
trainable=trainable,
name=self.name + "_weights")
return output

# pylint: disable=unused-argument
@@ -507,7 +544,10 @@ def to_weighted_sum(self,
"Please use sparse_column.".format(self))


def embedding_column(sparse_id_column, dimension, combiner="mean", stddev=None):
def embedding_column(sparse_id_column,
dimension,
combiner="mean",
initializer=None):
"""Creates an _EmbeddingColumn.

Args:
@@ -521,13 +561,15 @@ def embedding_column(sparse_id_column, dimension, combiner="mean", stddev=None):
* "mean": do l1 normalization
* "sqrtn": do l2 normalization
For more information: `tf.embedding_lookup_sparse`.
stddev: the standard deviation to be used in embedding initialization.
Default is 1/sqrt(sparse_id_column.length).
initializer: A variable initializer function to be used in embedding
variable initialization. If not specified, defaults to
`tf.truncated_normal_initializer` with mean 0.0 and standard deviation
1/sqrt(sparse_id_column.length).

Returns:
An _EmbeddingColumn.
"""
return _EmbeddingColumn(sparse_id_column, dimension, combiner, stddev)
return _EmbeddingColumn(sparse_id_column, dimension, combiner, initializer)
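A minimal usage sketch, mirroring the unit tests added in this PR (the "wire" column name and the sizes are illustrative):

```python
import tensorflow as tf

# A hashed sparse column embedded into 10 dimensions. Omitting the
# initializer argument falls back to a truncated normal with stddev
# 1/sqrt(hash_bucket_size); here an explicit constant initializer is
# passed instead.
hashed = tf.contrib.layers.sparse_column_with_hash_bucket("wire", 10)
embedded = tf.contrib.layers.embedding_column(
    hashed, dimension=10, initializer=tf.constant_initializer(133.7))
```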


class _RealValuedColumn(_FeatureColumn, collections.namedtuple(
@@ -788,9 +830,14 @@ def to_weighted_sum(self,
vocab_size = self.length * self.source_column.dimension

return _create_embedding_lookup(
sparse_id_values, vocab_size, num_outputs,
_add_variable_collection(weight_collections), 0., "sum",
trainable, self.name + "_weights")
input_tensor=sparse_id_values,
vocab_size=vocab_size,
dimension=num_outputs,
weight_collections=_add_variable_collection(weight_collections),
initializer=init_ops.zeros_initializer,
combiner="sum",
trainable=trainable,
name=self.name + "_weights")


def bucketized_column(source_column, boundaries):
@@ -944,9 +991,14 @@ def to_weighted_sum(self,
weight_collections=None,
trainable=True):
return _create_embedding_lookup(
input_tensor, self.length, num_outputs,
_add_variable_collection(weight_collections), -1, self.combiner,
trainable, self.name + "_weights")
input_tensor=input_tensor,
vocab_size=self.length,
dimension=num_outputs,
weight_collections=_add_variable_collection(weight_collections),
initializer=init_ops.zeros_initializer,
combiner=self.combiner,
trainable=trainable,
name=self.name + "_weights")


def crossed_column(columns, hash_bucket_size, combiner="sum"):
@@ -1120,16 +1172,17 @@ def partitioner(vocab_size, embed_dim):


def _create_embedding_lookup(input_tensor, vocab_size, dimension,
weight_collections, stddev, combiner, trainable,
name):
weight_collections, initializer, combiner,
trainable, name):
"""Creates embedding variable and does a lookup.

Args:
input_tensor: A tensor which should contain sparse id to look up.
vocab_size: An integer specifying the vocabulary size.
dimension: An integer specifying the embedding vector dimension.
weight_collections: List of graph collections to which weights are added.
stddev: the standard deviation to be used in embedding initialization.
initializer: A variable initializer function to be used in embedding
variable initialization.
combiner: A string specifying how to reduce if the sparse column is
multivalent. Currently "mean", "sqrtn" and "sum" are supported:
* "sum": do not normalize features in the column
@@ -1142,14 +1195,17 @@ def _create_embedding_lookup(input_tensor, vocab_size, dimension,

Returns:
A Tensor with shape [batch_size, dimension] and embedding Variable.

Raises:
ValueError: If initializer is None or not callable.
"""
slicing = _max_size_embedding_partitioner()(vocab_size, dimension)
logging.info("Slicing=%s for name=%s, vocab_size=%d, embed_dim=%d",
str(slicing), name, vocab_size, dimension)
if stddev > 0:
initializer = init_ops.truncated_normal_initializer(stddev=stddev)
else:
initializer = init_ops.zeros_initializer
if not initializer:
raise ValueError("initializer must be defined.")
if not callable(initializer):
raise ValueError("initializer must be callable.")
embeddings = partitioned_variables.create_partitioned_variables(
shape=[vocab_size, dimension],
slicing=slicing,
23 changes: 21 additions & 2 deletions tensorflow/contrib/layers/python/layers/feature_column_ops.py
@@ -86,14 +86,15 @@ def input_from_feature_columns(columns_to_tensors,
Raises:
ValueError: if FeatureColumn cannot be consumed by a neural network.
"""

check_feature_columns(feature_columns)
with variable_scope.variable_op_scope(columns_to_tensors.values(), name,
'input_from_feature_columns'):
output_tensors = []
transformer = _Transformer(columns_to_tensors)
if weight_collections:
weight_collections = list(set(list(weight_collections) +
[ops.GraphKeys.VARIABLES]))

for column in sorted(set(feature_columns), key=lambda x: x.key):
transformed_tensor = transformer.transform(column)
output_tensors.append(column.to_dnn_input_layer(
@@ -162,6 +163,7 @@ def weighted_sum_from_feature_columns(columns_to_tensors,
Raises:
ValueError: if FeatureColumn cannot be used for linear predictions.
"""
check_feature_columns(feature_columns)
with variable_scope.variable_op_scope(columns_to_tensors.values(), name,
'weighted_sum_from_feature_columns'):
output_tensors = []
@@ -225,7 +227,7 @@ def parse_feature_columns_from_examples(serialized,
Returns:
A `dict` mapping FeatureColumn to `Tensor` and `SparseTensor` values.
"""

check_feature_columns(feature_columns)
columns_to_tensors = parsing_ops.parse_example(
serialized=serialized,
features=fc.create_feature_spec_for_parsing(feature_columns),
@@ -278,6 +280,23 @@ def infer_real_valued_columns(features):
return feature_columns


def check_feature_columns(feature_columns):
"""Checks the validity of the set of FeatureColumns.

Args:
feature_columns: A set of instances or subclasses of FeatureColumn.

Raises:
ValueError: If there are duplicate feature column keys.
"""
seen_keys = set()
for f in feature_columns:
key = f.key
if key in seen_keys:
raise ValueError('Duplicate feature column key found: %s' % key)
seen_keys.add(key)
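A sketch of the failure mode this guards against, mirroring the new `testEmbeddingColumnWithMultipleInitializers` unit test: two embedding columns built from the same sparse column differ only in their initializer, which is excluded from the key, so they collide.

```python
import tensorflow as tf

hashed = tf.contrib.layers.sparse_column_with_hash_bucket("wire", 10)
features = {"wire": tf.SparseTensor(values=["omar", "stringer", "marlo"],
                                    indices=[[0, 0], [1, 0], [1, 1]],
                                    shape=[2, 2])}
col_a = tf.contrib.layers.embedding_column(
    hashed, 10,
    initializer=tf.truncated_normal_initializer(mean=42.0, stddev=1337.0))
col_b = tf.contrib.layers.embedding_column(
    hashed, 10,
    initializer=tf.truncated_normal_initializer(mean=1337.0, stddev=42.0))

# Both columns share the same key, so the entry check raises
# ValueError: Duplicate feature column key found: ...
tf.contrib.layers.input_from_feature_columns(features, [col_a, col_b])
```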


class _Transformer(object):
"""Handles all the transformations defined by FeatureColumn if needed.

@@ -18,6 +18,7 @@
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf

from tensorflow.contrib.layers.python.layers import feature_column_ops
@@ -256,6 +257,48 @@ def testEmbeddingColumn(self):
tf.initialize_all_variables().run()
self.assertAllEqual(output.eval().shape, [2, 10])

def testEmbeddingColumnWithInitializer(self):
hashed_sparse = tf.contrib.layers.sparse_column_with_hash_bucket("wire", 10)
wire_tensor = tf.SparseTensor(values=["omar", "stringer", "marlo"],
indices=[[0, 0], [1, 0], [1, 1]],
shape=[2, 2])
features = {"wire": wire_tensor}
init_value = 133.7
embeded_sparse = tf.contrib.layers.embedding_column(
hashed_sparse,
10, initializer=tf.constant_initializer(init_value))
output = tf.contrib.layers.input_from_feature_columns(features,
[embeded_sparse])

with self.test_session():
tf.initialize_all_variables().run()
output_eval = output.eval()
self.assertAllEqual(output_eval.shape, [2, 10])
self.assertAllClose(output_eval, np.tile(init_value, [2, 10]))

def testEmbeddingColumnWithMultipleInitializers(self):
hashed_sparse = tf.contrib.layers.sparse_column_with_hash_bucket("wire", 10)
wire_tensor = tf.SparseTensor(values=["omar", "stringer", "marlo"],
indices=[[0, 0], [1, 0], [1, 1]],
shape=[2, 2])
features = {"wire": wire_tensor}
embedded_sparse = tf.contrib.layers.embedding_column(
hashed_sparse,
10,
initializer=tf.truncated_normal_initializer(mean=42,
stddev=1337))
embedded_sparse_alternate = tf.contrib.layers.embedding_column(
hashed_sparse,
10,
initializer=tf.truncated_normal_initializer(mean=1337,
stddev=42))

# Makes sure that trying to use different initializers with the same
# embedding column explicitly fails.
with self.assertRaises(ValueError):
tf.contrib.layers.input_from_feature_columns(
features, [embedded_sparse, embedded_sparse_alternate])

def testSparseColumn(self):
hashed_sparse = tf.contrib.layers.sparse_column_with_hash_bucket("wire", 10)
wire_tensor = tf.SparseTensor(values=["omar", "stringer", "marlo"],
@@ -294,7 +337,9 @@ def testAllColumns(self):
indices=[[0, 0], [1, 0], [2, 0]],
shape=[3, 1])
}
embeded_sparse = tf.contrib.layers.embedding_column(hashed_sparse, 10)
embeded_sparse = tf.contrib.layers.embedding_column(
hashed_sparse,
10, initializer=tf.constant_initializer(133.7))
output = tf.contrib.layers.input_from_feature_columns(
features, [real_valued, bucket, embeded_sparse])
with self.test_session():