Skip to content
Browse files

Add Model.suggest_save_args heuristic and test for uint8 placeholders

  • Loading branch information...
ludwigschubert committed Apr 18, 2019
1 parent 523d4dd commit 6edbda69aa385359a1fc5e3d89fd57ab39e3ce14
@@ -21,10 +21,15 @@

import logging
import warnings

del logging

# silence unnecessarily loud TF warnings
warnings.filterwarnings("ignore", category=DeprecationWarning, module="tensorflow")
warnings.filterwarnings("ignore", module="tensorflow.core.platform.cpu_feature_guard")

# Lucid uses a fixed random seed for reproducability. Use to seed sources of randomness.
seed = 0

@@ -54,7 +54,7 @@ class InceptionV1(Model):
dataset = 'ImageNet'
image_shape = [224, 224, 3]
image_value_range = (-117, 255-117)
input_name = 'input:0'
input_name = 'input'

def post_import(self, scope):
@@ -22,6 +22,8 @@
from google.protobuf.message import DecodeError
import logging
import warnings
from collections import defaultdict
from itertools import chain

# create logger with module name, e.g.
log = logging.getLogger(__name__)
@@ -111,3 +113,35 @@ def extract_metadata(graph_def):
return json.loads(meta_tensor.string_val[0])
return None

# TODO: merge with pretty_graph's Graph class. Until then, only use this internally
class GraphDefHelper(object):
"""Allows constant time lookups of graphdef nodes by common properties."""

def __init__(self, graph_def):
self.graph_def = graph_def
self.by_op = defaultdict(list)
self.by_name = dict()
self.by_input = defaultdict(list)
for node in graph_def.node:
assert not in self.by_name # names should be unique I guess?
self.by_name[] = node
for input_name in node.input:

def neighborhood(self, node, degree=4):
"""Am I really handcoding graph traversal please no"""
assert self.by_name[] == node
already_visited = frontier = set([])
for _ in range(degree):
neighbor_names = set()
for node_name in frontier:
outgoing = set( for n in self.by_input[node_name])
incoming = set(self.by_name[node_name].input)
neighbor_names |= incoming | outgoing
frontier = neighbor_names - already_visited
already_visited |= neighbor_names
return [self.by_name[name] for name in already_visited]
@@ -18,6 +18,7 @@
from os import path
import warnings
import logging
from itertools import chain

import tensorflow as tf
import numpy as np
@@ -204,30 +205,44 @@ def get_layer(self, name):

def suggest_save_args(graph_def=None):
# TODO: Check with uint8 placeholders
if graph_def is None:
graph_def = tf.get_default_graph().as_graph_def()

gdhelper = model_util.GraphDefHelper(graph_def)
inferred_info = dict.fromkeys(("input_name", "image_shape", "output_names", "image_value_range"))

nodes_of_op = lambda s: [ for n in graph_def.node if n.op == s]
node_by_name = lambda s: [n for n in graph_def.node if == s][0]
node_shape = lambda n: [dim.size for dim in n.attr['shape'].shape.dim]

potential_input_nodes = nodes_of_op("Placeholder")
output_nodes = nodes_of_op("Softmax")
node_shape = lambda n: [dim.size for dim in n.attr['shape'].shape.dim]
potential_input_nodes = gdhelper.by_op["Placeholder"]
output_nodes = [ for node in gdhelper.by_op["Softmax"]]

if len(potential_input_nodes) == 1:
input_name = potential_input_nodes[0]
print("Inferred: input_name = {} (because it was the only Placeholder in the graph_def)".format(input_name))
inferred_info["input_name"] = input_name
input_node = potential_input_nodes[0]
input_dtype = tf.dtypes.as_dtype(input_node.attr['dtype'].type)
if input_dtype.is_floating:
input_name =
print("Inferred: input_name = {} (because it was the only Placeholder in the graph_def)".format(input_name))
inferred_info["input_name"] = input_name
print("Warning: found a single Placeholder, but its dtype is {}. Lucid's parameterizations can only replace float dtypes. We're now scanning to see if you maybe divide this placeholder by 255 to get a float later in the graph...".format(str(input_node.attr['dtype']).strip()))
neighborhood = gdhelper.neighborhood(input_node, degree=5)
divs = [n for n in neighborhood if n.op == "RealDiv"]
consts = [n for n in neighborhood if n.op == "Const"]
magic_number_present = any(255 in c.attr['value'].tensor.int_val for c in consts)
if divs and magic_number_present:
if len(divs) == 1:
input_name = divs[0].name
print("Guessed: input_name = {} (because it's the only division by 255 near the only placeholder)".format(input_name))
inferred_info["input_name"] = input_name
image_value_range = (0,1)
print("Guessed: image_value_range = {} (because you're dividing by 255 near the only placeholder)".format(image_value_range))
inferred_info["image_value_range"] = (0,1)
warnings.warn("Could not infer input_name because there were multiple division ops near your the only placeholder. Candidates include: {}".format([ for n in divs]))
warnings.warn("Could not infer input_name.")
warnings.warn("Could not infer input_name because there were multiple or no Placeholders.")

if inferred_info["input_name"] is not None:
input_node = node_by_name(inferred_info["input_name"])
input_node = gdhelper.by_name[inferred_info["input_name"]]
shape = node_shape(input_node)
if len(shape) in [3,4]:
if len(shape) in (3,4):
if len(shape) == 4:
shape = shape[1:]
if -1 not in shape:
@@ -279,7 +294,10 @@ def save(save_url, input_name, output_names, image_shape, image_value_range):
def load(graphdef_url):
graph_def = load(graphdef_url)
metadata = model_util.extract_metadata(graph_def)
return Model.load_from_metadata(graphdef_url, metadata)
if metadata:
return Model.load_from_metadata(graphdef_url, metadata)
raise ValueError("Model.load was called on a GraphDef ({}) that does not contain Lucid's metadata node. Model.load only works for models saved via For the graphdef you're trying to load, you will need to provide custom metadata; see Model.load_from_metadata()".format(graphdef_url))

def load_from_metadata(model_url, metadata):
BIN +0 Bytes (100%) tests/fixtures/graphdef.pb
Binary file not shown.
BIN +0 Bytes (100%) tests/fixtures/minigraph.pb
Binary file not shown.
@@ -1,9 +1,11 @@
import pytest
import tensorflow as tf

from lucid.modelzoo.vision_base import Model
from lucid.modelzoo.vision_models import AlexNet, InceptionV1, InceptionV3_slim, ResnetV1_50_slim

def test_suggest_save_args(capsys, minimodel):
def test_suggest_save_args_happy_path(capsys, minimodel):
path = "./tests/fixtures/minigraph.pb"

with tf.Graph().as_default() as graph, tf.Session() as sess:
@@ -25,3 +27,35 @@ def test_suggest_save_args(capsys, minimodel):
assert "0.100" in repr(loaded_model.graph_def)

def test_suggest_save_args_int_input(capsys, minimodel):
with tf.Graph().as_default() as graph, tf.Session() as sess:
image_t = tf.placeholder(tf.uint8, shape=(32, 32, 3), name="input")
input_t = tf.math.divide(image_t, tf.constant(255, dtype=tf.uint8), name="divide")
_ = minimodel(input_t)

# ask for suggested arguments
inferred = Model.suggest_save_args()
captured = capsys.readouterr().out # captures stdout
assert "DT_UINT8" in captured
assert inferred["input_name"] == "divide"

@pytest.mark.parametrize("model_class", [AlexNet, InceptionV1, InceptionV3_slim, ResnetV1_50_slim])
def test_suggest_save_args_existing_graphs(capsys, model_class):
graph_def = model_class().graph_def

if model_class == InceptionV1: # has flexible input shape, can't be inferred
with pytest.warns(UserWarning):
inferred = Model.suggest_save_args(graph_def)
inferred = Model.suggest_save_args(graph_def)

assert model_class.input_name == inferred["input_name"]

if model_class != InceptionV1:
assert model_class.image_shape == inferred["image_shape"]

layer_names = [ for layer in model_class.layers]
for output_name in list(inferred["output_names"]):
assert output_name in layer_names

0 comments on commit 6edbda6

Please sign in to comment.
You can’t perform that action at this time.