Skip to content
This repository was archived by the owner on Apr 10, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions lucid/misc/io/showing.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,3 +267,48 @@ def textured_mesh(mesh, texture, background='0xffffff'):
background = background,
)
_display_html(code)


def _strip_consts(graph_def, max_const_size=32):
"""Strip large constant values from graph_def.

This is mostly a utility function for graph(), and also originates here:
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/tutorials/deepdream/deepdream.ipynb
"""
strip_def = tf.GraphDef()
for n0 in graph_def.node:
n = strip_def.node.add()
n.MergeFrom(n0)
if n.op == 'Const':
tensor = n.attr['value'].tensor
size = len(tensor.tensor_content)
if size > max_const_size:
tensor.tensor_content = tf.compat.as_bytes("<stripped %d bytes>"%size)
return strip_def


def graph(graph_def, max_const_size=32):
"""Visualize a TensorFlow graph.

This function was originally found in this notebook (also Apache licensed):
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/tutorials/deepdream/deepdream.ipynb
"""
if hasattr(graph_def, 'as_graph_def'):
graph_def = graph_def.as_graph_def()
strip_def = _strip_consts(graph_def, max_const_size=max_const_size)
code = """
<script>
function load() {{
document.getElementById("{id}").pbtxt = {data};
}}
</script>
<link rel="import" href="https://tensorboard.appspot.com/tf-graph-basic.build.html" onload=load()>
<div style="height:600px">
<tf-graph-basic id="{id}"></tf-graph-basic>
</div>
""".format(data=repr(str(strip_def)), id='graph'+str(np.random.rand()))

iframe = """
<iframe seamless style="width:100%; height:620px; border: none;" srcdoc="{}"></iframe>
""".format(code.replace('"', '&quot;'))
_display_html(iframe)
8 changes: 8 additions & 0 deletions lucid/modelzoo/vision_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import tensorflow as tf
from lucid.modelzoo.util import load_text_labels, load_graphdef, forget_xy
from lucid.misc.io import load
import lucid.misc.io.showing as showing

class Model(object):
"""Base pretrained model importer."""
Expand Down Expand Up @@ -55,6 +56,8 @@ def create_input(self, t_input=None, forget_xy_shape=True):

def import_graph(self, t_input=None, scope='import', forget_xy_shape=True):
"""Import model GraphDef into the current graph."""
if self.graph_def is None:
raise Exception("Model.import_graph(): Must load graph def before importing it.")
graph = tf.get_default_graph()
assert graph.unique_name(scope, False) == scope, (
'Scope "%s" already exists. Provide explicit scope names when '
Expand All @@ -63,6 +66,11 @@ def import_graph(self, t_input=None, scope='import', forget_xy_shape=True):
tf.import_graph_def(
self.graph_def, {self.input_name: t_prep_input}, name=scope)
self.post_import(scope)

def show_graph(self):
if self.graph_def is None:
raise Exception("Model.show_graph(): Must load graph def before showing it.")
showing.graph(self.graph_def)


class SerializedModel(Model):
Expand Down