Skip to content
This repository has been archived by the owner on Apr 10, 2024. It is now read-only.

Commit

Permalink
Merge pull request #152 from tensorflow/save_experiment
Browse files Browse the repository at this point in the history
Experimental API for saving and loading models
  • Loading branch information
ludwigschubert committed Apr 18, 2019
2 parents b223f73 + 3503a5b commit d1a1e2e
Show file tree
Hide file tree
Showing 21 changed files with 396 additions and 46 deletions.
6 changes: 6 additions & 0 deletions CONTRIBUTING.md
Expand Up @@ -54,6 +54,12 @@ via the `-e` flag: `tox -e py27`.
After adding dependencies to `setup.py`, run tox with the `--recreate` flag to
update the environments' dependencies.

If you prefer to run tests directly with pytest, ensure you manually install the test dependencies. Within the lucid repository, run:

```
pip install .[test]
```

#### During Development

If you'd like to develop using [TDD](https://en.wikipedia.org/wiki/Test-driven_development), we recommend calling the tests you're currently working on [using `pytest` directly](https://docs.pytest.org/en/latest/usage.html), e.g. `python -m pytest tests/path/to/your/test.py`. Please don't forget to run all tests using `tox` before submitting a PR, though!
Expand Down
5 changes: 5 additions & 0 deletions lucid/__init__.py
Expand Up @@ -21,10 +21,15 @@
"""

import logging
import warnings

logging.basicConfig(level=logging.WARN)
del logging

# silence unnecessarily loud TF warnings
warnings.filterwarnings("ignore", category=DeprecationWarning, module="tensorflow")
warnings.filterwarnings("ignore", module="tensorflow.core.platform.cpu_feature_guard")

# Lucid uses a fixed random seed for reproducability. Use to seed sources of randomness.
seed = 0

Expand Down
14 changes: 12 additions & 2 deletions lucid/misc/io/loading.py
Expand Up @@ -32,6 +32,7 @@
from google.protobuf.message import DecodeError

from lucid.misc.io.reading import read_handle
from lucid import modelzoo


# create logger with module name, e.g. lucid.misc.io.reading
Expand Down Expand Up @@ -90,8 +91,17 @@ def _load_text(handle, split=False, encoding="utf-8"):

def _load_graphdef_protobuf(handle, **kwargs):
"""Load GraphDef from a binary proto file."""
del kwargs
return tf.GraphDef.FromString(handle.read())
# as_graph_def
graph_def = tf.GraphDef.FromString(handle.read())

# check if this is a lucid-saved model
# metadata = modelzoo.util.extract_metadata(graph_def)
# if metadata is not None:
# url = handle.name
# return modelzoo.vision_base.Model.load_from_metadata(url, metadata)

# else return a normal graph_def
return graph_def


loaders = {
Expand Down
9 changes: 9 additions & 0 deletions lucid/misc/io/saving.py
Expand Up @@ -112,6 +112,14 @@ def save_txt(object, handle, **kwargs):
handle.write(line)


def save_pb(object, handle, **kwargs):
try:
handle.write(object.SerializeToString())
except AttributeError as e:
warnings.warn("`save_protobuf` failed for object {}. Re-raising original exception.".format(object))
raise e


savers = {
".png": save_img,
".jpg": save_img,
Expand All @@ -120,6 +128,7 @@ def save_txt(object, handle, **kwargs):
".npz": save_npz,
".json": save_json,
".txt": save_txt,
".pb": save_pb,
}


Expand Down
3 changes: 2 additions & 1 deletion lucid/misc/io/writing.py
Expand Up @@ -30,7 +30,8 @@
def _supports_make_dirs(path):
"""Whether this path implies a storage system that supports and requires
intermediate directories to be created explicitly."""
return not path.startswith("/bigstore")
prefixes = ["/bigstore", "gs://"]
return not any(path.startswith(prefix) for prefix in prefixes)


def _supports_binary_writing(path):
Expand Down
2 changes: 1 addition & 1 deletion lucid/modelzoo/other_models/InceptionV1.py
Expand Up @@ -54,7 +54,7 @@ class InceptionV1(Model):
dataset = 'ImageNet'
image_shape = [224, 224, 3]
image_value_range = (-117, 255-117)
input_name = 'input:0'
input_name = 'input'

def post_import(self, scope):
_populate_inception_bottlenecks(scope)
Expand Down
94 changes: 94 additions & 0 deletions lucid/modelzoo/util.py
Expand Up @@ -18,13 +18,18 @@
from __future__ import absolute_import, division, print_function

import tensorflow as tf
import json
from google.protobuf.message import DecodeError
import logging
import warnings
from collections import defaultdict
from itertools import chain

# create logger with module name, e.g. lucid.misc.io.reading
log = logging.getLogger(__name__)

from lucid.misc.io import load
from lucid.misc.io.saving import NumpyJSONEncoder


def load_text_labels(labels_path):
Expand All @@ -51,3 +56,92 @@ def forget_xy(t):
"""
shape = (t.shape[0], None, None, t.shape[3])
return tf.placeholder_with_default(t, shape)


def frozen_default_graph_def(input_node_names, output_node_names):
"""Return frozen and simplified graph_def of default graph."""

sess = tf.get_default_session()
input_graph_def = tf.get_default_graph().as_graph_def()

pruned_graph = tf.graph_util.remove_training_nodes(
input_graph_def, protected_nodes=(output_node_names + input_node_names)
)
pruned_graph = tf.graph_util.extract_sub_graph(pruned_graph, output_node_names)

# remove explicit device assignments
for node in pruned_graph.node:
node.device = ""

all_variable_names = [v.op.name for v in tf.global_variables()]
output_graph_def = tf.graph_util.convert_variables_to_constants(
sess=sess,
input_graph_def=pruned_graph,
output_node_names=output_node_names,
variable_names_whitelist=all_variable_names,
)

return output_graph_def


metadata_node_name = "lucid_metadata_json"

def infuse_metadata(graph_def, info):
"""Embed meta data as a string constant in a TF graph.
This function takes info, converts it into json, and embeds
it in graph_def as a constant op called `__lucid_metadata_json`.
"""
temp_graph = tf.Graph()
with temp_graph.as_default():
tf.constant(json.dumps(info, cls=NumpyJSONEncoder), name=metadata_node_name)
meta_node = temp_graph.as_graph_def().node[0]
graph_def.node.extend([meta_node])


def extract_metadata(graph_def):
"""Attempt to extract meta data hidden in graph_def.
Looks for a `__lucid_metadata_json` constant string op.
If present, extract it's content and convert it from json to python.
If not, returns None.
"""
meta_matches = [n for n in graph_def.node if n.name==metadata_node_name]
if meta_matches:
assert len(meta_matches) == 1, "found more than 1 lucid metadata node!"
meta_tensor = meta_matches[0].attr['value'].tensor
return json.loads(meta_tensor.string_val[0])
else:
return None


# TODO: merge with pretty_graph's Graph class. Until then, only use this internally
class GraphDefHelper(object):
"""Allows constant time lookups of graphdef nodes by common properties."""

def __init__(self, graph_def):
self.graph_def = graph_def
self.by_op = defaultdict(list)
self.by_name = dict()
self.by_input = defaultdict(list)
for node in graph_def.node:
self.by_op[node.op].append(node)
assert node.name not in self.by_name # names should be unique I guess?
self.by_name[node.name] = node
for input_name in node.input:
self.by_input[input_name].append(node)


def neighborhood(self, node, degree=4):
"""Am I really handcoding graph traversal please no"""
assert self.by_name[node.name] == node
already_visited = frontier = set([node.name])
for _ in range(degree):
neighbor_names = set()
for node_name in frontier:
outgoing = set(n.name for n in self.by_input[node_name])
incoming = set(self.by_name[node_name].input)
neighbor_names |= incoming | outgoing
frontier = neighbor_names - already_visited
already_visited |= neighbor_names
return [self.by_name[name] for name in already_visited]

0 comments on commit d1a1e2e

Please sign in to comment.