diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 8b323c5e..905581a0 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -54,6 +54,12 @@ via the `-e` flag: `tox -e py27`.
 After adding dependencies to `setup.py`, run tox with the `--recreate` flag to
 update the environments' dependencies.
 
+If you prefer to run tests directly with pytest, ensure you manually install the test dependencies. Within the lucid repository, run:
+
+```
+pip install .[test]
+```
+
 #### During Development
 
 If you'd like to develop using [TDD](https://en.wikipedia.org/wiki/Test-driven_development), we recommend calling the tests you're currently working on [using `pytest` directly](https://docs.pytest.org/en/latest/usage.html), e.g. `python -m pytest tests/path/to/your/test.py`. Please don't forget to run all tests using `tox` before submitting a PR, though!
diff --git a/lucid/__init__.py b/lucid/__init__.py
index 2a3a685d..c8696a7a 100644
--- a/lucid/__init__.py
+++ b/lucid/__init__.py
@@ -21,10 +21,15 @@
 """
 
 import logging
+import warnings
 
 logging.basicConfig(level=logging.WARN)
 del logging
 
+# silence unnecessarily loud TF warnings
+warnings.filterwarnings("ignore", category=DeprecationWarning, module="tensorflow")
+warnings.filterwarnings("ignore", module="tensorflow.core.platform.cpu_feature_guard")
+
 # Lucid uses a fixed random seed for reproducability. Use to seed sources of randomness.
 seed = 0
 
diff --git a/lucid/misc/io/loading.py b/lucid/misc/io/loading.py
index 26dd214a..f9f799b4 100644
--- a/lucid/misc/io/loading.py
+++ b/lucid/misc/io/loading.py
@@ -32,6 +32,7 @@
 from google.protobuf.message import DecodeError
 
 from lucid.misc.io.reading import read_handle
+from lucid import modelzoo
 
 
 # create logger with module name, e.g. lucid.misc.io.reading
@@ -90,8 +91,17 @@ def _load_text(handle, split=False, encoding="utf-8"):
 
 def _load_graphdef_protobuf(handle, **kwargs):
   """Load GraphDef from a binary proto file."""
-  del kwargs
-  return tf.GraphDef.FromString(handle.read())
+  # as_graph_def
+  graph_def = tf.GraphDef.FromString(handle.read())
+
+  # check if this is a lucid-saved model
+  # metadata = modelzoo.util.extract_metadata(graph_def)
+  # if metadata is not None:
+  #   url = handle.name
+  #   return modelzoo.vision_base.Model.load_from_metadata(url, metadata)
+
+  # else return a normal graph_def
+  return graph_def
 
 
 loaders = {
diff --git a/lucid/misc/io/saving.py b/lucid/misc/io/saving.py
index b3bf61fd..fb9fe227 100644
--- a/lucid/misc/io/saving.py
+++ b/lucid/misc/io/saving.py
@@ -112,6 +112,14 @@ def save_txt(object, handle, **kwargs):
       handle.write(line)
 
 
+def save_pb(object, handle, **kwargs):
+  try:
+    handle.write(object.SerializeToString())
+  except AttributeError as e:
+    warnings.warn("`save_protobuf` failed for object {}. Re-raising original exception.".format(object))
+    raise e
+
+
 savers = {
   ".png": save_img,
   ".jpg": save_img,
@@ -120,6 +128,7 @@
   ".npz": save_npz,
   ".json": save_json,
   ".txt": save_txt,
+  ".pb": save_pb,
 }
 
diff --git a/lucid/misc/io/writing.py b/lucid/misc/io/writing.py
index 75fd29ab..6ac9bab5 100644
--- a/lucid/misc/io/writing.py
+++ b/lucid/misc/io/writing.py
@@ -30,7 +30,8 @@
 def _supports_make_dirs(path):
   """Whether this path implies a storage system that supports and requires
   intermediate directories to be created explicitly."""
-  return not path.startswith("/bigstore")
+  prefixes = ["/bigstore", "gs://"]
+  return not any(path.startswith(prefix) for prefix in prefixes)
 
 
 def _supports_binary_writing(path):
diff --git a/lucid/modelzoo/other_models/InceptionV1.py b/lucid/modelzoo/other_models/InceptionV1.py
index 20cc8613..c416c605 100644
--- a/lucid/modelzoo/other_models/InceptionV1.py
+++ b/lucid/modelzoo/other_models/InceptionV1.py
@@ -54,7 +54,7 @@ class InceptionV1(Model):
   dataset = 'ImageNet'
   image_shape = [224, 224, 3]
   image_value_range = (-117, 255-117)
-  input_name = 'input:0'
+  input_name = 'input'
 
   def post_import(self, scope):
     _populate_inception_bottlenecks(scope)
diff --git a/lucid/modelzoo/util.py b/lucid/modelzoo/util.py
index b85fa3e1..86cd91f3 100644
--- a/lucid/modelzoo/util.py
+++ b/lucid/modelzoo/util.py
@@ -18,13 +18,18 @@
 from __future__ import absolute_import, division, print_function
 
 import tensorflow as tf
+import json
 from google.protobuf.message import DecodeError
 import logging
+import warnings
+from collections import defaultdict
+from itertools import chain
 
 # create logger with module name, e.g. lucid.misc.io.reading
 log = logging.getLogger(__name__)
 
 from lucid.misc.io import load
+from lucid.misc.io.saving import NumpyJSONEncoder
 
 
 def load_text_labels(labels_path):
@@ -51,3 +56,92 @@ def forget_xy(t):
   """
   shape = (t.shape[0], None, None, t.shape[3])
   return tf.placeholder_with_default(t, shape)
+
+
+def frozen_default_graph_def(input_node_names, output_node_names):
+  """Return frozen and simplified graph_def of default graph."""
+
+  sess = tf.get_default_session()
+  input_graph_def = tf.get_default_graph().as_graph_def()
+
+  pruned_graph = tf.graph_util.remove_training_nodes(
+      input_graph_def, protected_nodes=(output_node_names + input_node_names)
+  )
+  pruned_graph = tf.graph_util.extract_sub_graph(pruned_graph, output_node_names)
+
+  # remove explicit device assignments
+  for node in pruned_graph.node:
+    node.device = ""
+
+  all_variable_names = [v.op.name for v in tf.global_variables()]
+  output_graph_def = tf.graph_util.convert_variables_to_constants(
+      sess=sess,
+      input_graph_def=pruned_graph,
+      output_node_names=output_node_names,
+      variable_names_whitelist=all_variable_names,
+  )
+
+  return output_graph_def
+
+
+metadata_node_name = "lucid_metadata_json"
+
+def infuse_metadata(graph_def, info):
+  """Embed metadata as a string constant in a TF graph.
+
+  This function takes info, converts it into json, and embeds
+  it in graph_def as a constant op called `lucid_metadata_json`.
+  """
+  temp_graph = tf.Graph()
+  with temp_graph.as_default():
+    tf.constant(json.dumps(info, cls=NumpyJSONEncoder), name=metadata_node_name)
+  meta_node = temp_graph.as_graph_def().node[0]
+  graph_def.node.extend([meta_node])
+
+
+def extract_metadata(graph_def):
+  """Attempt to extract metadata hidden in graph_def.
+
+  Looks for a `lucid_metadata_json` constant string op.
+  If present, extracts its content and converts it from json to python.
+  If not, returns None.
+  """
+  meta_matches = [n for n in graph_def.node if n.name==metadata_node_name]
+  if meta_matches:
+    assert len(meta_matches) == 1, "found more than 1 lucid metadata node!"
+    meta_tensor = meta_matches[0].attr['value'].tensor
+    return json.loads(meta_tensor.string_val[0])
+  else:
+    return None
+
+
+# TODO: merge with pretty_graph's Graph class. Until then, only use this internally
+class GraphDefHelper(object):
+  """Allows constant time lookups of graphdef nodes by common properties."""
+
+  def __init__(self, graph_def):
+    self.graph_def = graph_def
+    self.by_op = defaultdict(list)
+    self.by_name = dict()
+    self.by_input = defaultdict(list)
+    for node in graph_def.node:
+      self.by_op[node.op].append(node)
+      assert node.name not in self.by_name  # names should be unique I guess?
+      self.by_name[node.name] = node
+      for input_name in node.input:
+        self.by_input[input_name].append(node)
+
+
+  def neighborhood(self, node, degree=4):
+    """Am I really handcoding graph traversal please no"""
+    assert self.by_name[node.name] == node
+    already_visited = frontier = set([node.name])
+    for _ in range(degree):
+      neighbor_names = set()
+      for node_name in frontier:
+        outgoing = set(n.name for n in self.by_input[node_name])
+        incoming = set(self.by_name[node_name].input)
+        neighbor_names |= incoming | outgoing
+      frontier = neighbor_names - already_visited
+      already_visited |= neighbor_names
+    return [self.by_name[name] for name in already_visited]
diff --git a/lucid/modelzoo/vision_base.py b/lucid/modelzoo/vision_base.py
index 9061253a..a4af82c0 100644
--- a/lucid/modelzoo/vision_base.py
+++ b/lucid/modelzoo/vision_base.py
@@ -18,13 +18,14 @@
 from os import path
 import warnings
 import logging
+from itertools import chain
 
 import tensorflow as tf
 import numpy as np
 
-from lucid.modelzoo.util import load_graphdef, forget_xy
+from lucid.modelzoo import util as model_util
 from lucid.modelzoo.aligned_activations import get_aligned_activations as _get_aligned_activations
-from lucid.misc.io import load
+from lucid.misc.io import load, save
 import lucid.misc.io.showing as showing
 
 # ImageNet classes correspond to WordNet Synsets.
@@ -94,7 +95,7 @@ def name(cls):
 
 
 class Model(with_metaclass(ModelPropertiesMetaClass, object)):
-  """Model allows importing pre-trained models."""
+  """Model allows using pre-trained models."""
 
   model_path = None
   labels_path = None
@@ -107,6 +108,14 @@ class Model(with_metaclass(ModelPropertiesMetaClass, object)):
   _synsets = None
   _graph_def = None
 
+  # Avoid pickling the in-memory graph_def.
+  _blacklist = ['_graph_def']
+  def __getstate__(self):
+    return {k: v for k, v in self.__dict__.items() if k not in self._blacklist}
+
+  def __setstate__(self, state):
+    self.__dict__.update(state)
+
   @property
   def labels(self):
     if not hasattr(self, 'labels_path') or self.labels_path is None:
@@ -136,7 +145,7 @@ def name(self):
   @property
   def graph_def(self):
     if not self._graph_def:
-      self._graph_def = load_graphdef(self.model_path)
+      self._graph_def = model_util.load_graphdef(self.model_path)
     return self._graph_def
 
   def load_graphdef(self):
@@ -157,7 +166,7 @@ def create_input(self, t_input=None, forget_xy_shape=True):
     if len(t_prep_input.shape) == 3:
       t_prep_input = tf.expand_dims(t_prep_input, 0)
     if forget_xy_shape:
-      t_prep_input = forget_xy(t_prep_input)
+      t_prep_input = model_util.forget_xy(t_prep_input)
     if hasattr(self, "is_BGR") and self.is_BGR is True:
       t_prep_input = tf.reverse(t_prep_input, [-1])
     lo, hi = self.image_value_range
@@ -194,45 +203,154 @@ def get_layer(self, name):
     layer_names = str([l.name for l in self.layers])
     raise KeyError(key_error_message.format(name, layer_names))
 
+  @staticmethod
+  def suggest_save_args(graph_def=None):
+    if graph_def is None:
+      graph_def = tf.get_default_graph().as_graph_def()
+    gdhelper = model_util.GraphDefHelper(graph_def)
+    inferred_info = dict.fromkeys(("input_name", "image_shape", "output_names", "image_value_range"))
+    node_shape = lambda n: [dim.size for dim in n.attr['shape'].shape.dim]
+    potential_input_nodes = gdhelper.by_op["Placeholder"]
+    output_nodes = [node.name for node in gdhelper.by_op["Softmax"]]
+
+    if len(potential_input_nodes) == 1:
+      input_node = potential_input_nodes[0]
+      input_dtype = tf.dtypes.as_dtype(input_node.attr['dtype'].type)
+      if input_dtype.is_floating:
+        input_name = input_node.name
+        print("Inferred: input_name = {} (because it was the only Placeholder in the graph_def)".format(input_name))
+        inferred_info["input_name"] = input_name
+      else:
+        print("Warning: found a single Placeholder, but its dtype is {}. Lucid's parameterizations can only replace float dtypes. We're now scanning to see if you maybe divide this placeholder by 255 to get a float later in the graph...".format(str(input_node.attr['dtype']).strip()))
+        neighborhood = gdhelper.neighborhood(input_node, degree=5)
+        divs = [n for n in neighborhood if n.op == "RealDiv"]
+        consts = [n for n in neighborhood if n.op == "Const"]
+        magic_number_present = any(255 in c.attr['value'].tensor.int_val for c in consts)
+        if divs and magic_number_present:
+          if len(divs) == 1:
+            input_name = divs[0].name
+            print("Guessed: input_name = {} (because it's the only division by 255 near the only placeholder)".format(input_name))
+            inferred_info["input_name"] = input_name
+            image_value_range = (0,1)
+            print("Guessed: image_value_range = {} (because you're dividing by 255 near the only placeholder)".format(image_value_range))
+            inferred_info["image_value_range"] = (0,1)
+          else:
+            warnings.warn("Could not infer input_name because there were multiple division ops near the only placeholder. Candidates include: {}".format([n.name for n in divs]))
+    else:
+      warnings.warn("Could not infer input_name because there were multiple or no Placeholders.")
+
+    if inferred_info["input_name"] is not None:
+      input_node = gdhelper.by_name[inferred_info["input_name"]]
+      shape = node_shape(input_node)
+      if len(shape) in (3,4):
+        if len(shape) == 4:
+          shape = shape[1:]
+        if -1 not in shape:
+          print("Inferred: image_shape = {}".format(shape))
+          inferred_info["image_shape"] = shape
+    if inferred_info["image_shape"] is None:
+      warnings.warn("Could not infer image_shape.")
+
+    if output_nodes:
+      print("Inferred: output_names = {} (because those are all the Softmax ops)".format(output_nodes))
+      inferred_info["output_names"] = output_nodes
+    else:
+      warnings.warn("Could not infer output_names.")
+
+    report = []
+    report.append("# Please sanity check all inferred values before using this code.")
+    report.append("# Incorrect `image_value_range` is the most common cause of feature visualization bugs! Most methods will fail silently with incorrect visualizations!")
+    report.append("Model.save(")
+
+    suggestions = {
+        "input_name" : 'input',
+        "image_shape" : [224, 224, 3],
+        "output_names": ['logits'],
+        "image_value_range": "[-1, 1], [0, 1], [0, 255], or [-117, 138]"
+    }
+    for key, value in inferred_info.items():
+      if value is not None:
+        report.append("    {}={!r},".format(key, value))
+      else:
+        report.append("    {}=_,  # TODO (eg. {!r})".format(key, suggestions[key]))
+    report.append("  )")
+
+    print("\n".join(report))
+    return inferred_info
+
+
+  @staticmethod
+  def save(save_url, input_name, output_names, image_shape, image_value_range):
+    metadata = {
+      "input_name" : input_name,
+      "image_shape" : image_shape,
+      "image_value_range": image_value_range,
+    }
+
+    graph_def = model_util.frozen_default_graph_def([input_name], output_names)
+    model_util.infuse_metadata(graph_def, metadata)
+    save(graph_def, save_url)
+
+  @staticmethod
+  def load(graphdef_url):
+    graph_def = load(graphdef_url)
+    metadata = model_util.extract_metadata(graph_def)
+    if metadata:
+      return Model.load_from_metadata(graphdef_url, metadata)
+    else:
+      raise ValueError("Model.load was called on a GraphDef ({}) that does not contain Lucid's metadata node. Model.load only works for models saved via Model.save. For the graphdef you're trying to load, you will need to provide custom metadata; see Model.load_from_metadata()".format(graphdef_url))
+
+  @staticmethod
+  def load_from_metadata(model_url, metadata):
+    class DynamicModel(Model):
+      model_path = model_url
+      input_name = metadata["input_name"]
+      image_shape = metadata["image_shape"]
+      image_value_range = metadata["image_value_range"]
+    return DynamicModel()
+
+  @staticmethod
+  def load_from_manifest(manifest_url):
+    try:
+      manifest = load(manifest_url)
+    except Exception as e:
+      raise ValueError("Could not find manifest.json file in dir {}. Error: {}".format(manifest_url, e))
 
-class SerializedModel(Model):
-  """Allows importing various types of serialized models from a directory.
+    if manifest.get('type', 'frozen') == 'frozen':
+      manifest_folder = path.dirname(manifest_url)
+      return FrozenGraphModel(manifest_folder, manifest)
+    else:
+      raise NotImplementedError("SerializedModel Manifest type '{}' has not been implemented!".format(manifest.get('type')))
 
-  (Currently only supports frozen graph models and relies on manifest.json file.
-  In the future we may want to support automatically detecting the type and
-  support loading more ways of saving models: tf.SavedModel, metagraphs, etc.)
- """ + +class SerializedModel(Model): @classmethod def from_directory(cls, model_path, manifest_path=None): - + warnings.warn("SerializedModel is deprecated. Please use Model.load_from_manifest instead.", DeprecationWarning) if manifest_path is None: manifest_path = path.join(model_path, 'manifest.json') - - try: - manifest = load(manifest_path) - except Exception as e: - raise ValueError("Could not find manifest.json file in dir {}. Error: {}".format(model_path, e)) - - if manifest.get('type', 'frozen') == 'frozen': - return FrozenGraphModel(model_path, manifest) - else: # TODO: add tf.SavedModel support, etc - raise NotImplementedError("SerializedModel Manifest type '{}' has not been implemented!".format(manifest.get('type'))) + return Model.load_from_manifest(manifest_path) class FrozenGraphModel(SerializedModel): + _mandatory_properties = ['model_path', 'image_value_range', 'input_name', 'image_shape'] + def __init__(self, model_directory, manifest): self.manifest = manifest + + for mandatory_key in self._mandatory_properties: + # TODO: consider if we can tell you the path of the faulty manifest here + assert mandatory_key in manifest.keys(), "Mandatory property '{}' was not defined in json manifest.".format(mandatory_key) + for key, value in manifest.items(): + setattr(self, key, value) + model_path = manifest.get('model_path', 'graph.pb') if model_path.startswith("./"): # TODO: can we be less specific here? self.model_path = path.join(model_directory, model_path[2:]) else: self.model_path = model_path - self.labels_path = manifest.get('labels_path', None) - self.image_value_range = manifest.get('image_value_range') - self.image_shape = manifest.get('image_shape') - self.input_name = manifest.get('input_name') layers_or_layer_names = manifest.get('layers') if len(layers_or_layer_names) > 0: diff --git a/tests/conftest.py b/tests/conftest.py index 3bb66bdd..588234be 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,14 +1,18 @@ import pytest - -from lucid.modelzoo.vision_models import InceptionV1 +import tensorflow as tf @pytest.fixture -def inceptionv1(): - model = InceptionV1() - model.load_graphdef() - return model - +def minimodel(): + def inner(input=None, shape=(16,16,3)): + """Constructs a tiny graph containing one each of a typical input + (tf.placegholder), variable and typical output (softmax) nodes.""" + if input is None: + input = tf.placeholder(tf.float32, shape=shape, name="input") + w = tf.Variable(0.1, name="variable") + logits = tf.reduce_mean(w*input, name="output", axis=(0,1)) + return tf.nn.softmax(logits) + return inner # Add support for a slow tests marker: diff --git a/tests/fixtures/graphdef.pb b/tests/fixtures/graphdef.pb new file mode 100644 index 00000000..0b67e3e9 Binary files /dev/null and b/tests/fixtures/graphdef.pb differ diff --git a/tests/fixtures/minigraph.pb b/tests/fixtures/minigraph.pb new file mode 100644 index 00000000..e0285652 Binary files /dev/null and b/tests/fixtures/minigraph.pb differ diff --git a/tests/misc/io/test_loading.py b/tests/misc/io/test_loading.py index d033e7e1..3102e7a2 100644 --- a/tests/misc/io/test_loading.py +++ b/tests/misc/io/test_loading.py @@ -73,3 +73,9 @@ def test_load_json_with_file_handle(): with io.open(path, 'r') as handle: dictionary = load(handle) assert "key" in dictionary + + +def test_load_protobuf(): + path = "./tests/fixtures/graphdef.pb" + graphdef = load(path) + assert "int_val: 42" in repr(graphdef) diff --git a/tests/misc/io/test_saving.py b/tests/misc/io/test_saving.py index 
7a90aa55..3c25f7a9 100644 --- a/tests/misc/io/test_saving.py +++ b/tests/misc/io/test_saving.py @@ -6,6 +6,7 @@ from lucid.misc.io.saving import save import os.path import io +import tensorflow as tf dictionary = {"key": "value"} @@ -104,3 +105,13 @@ def test_save_named_handle(): def test_unknown_extension(): with pytest.raises(ValueError): save({}, "test.unknown") + + +def test_save_protobuf(): + path = "./tests/fixtures/graphdef.pb" + _remove(path) + with tf.Graph().as_default() as graph: + a = tf.Variable(42) + graphdef = a.graph.as_graph_def() + save(graphdef, path) + assert os.path.isfile(path) diff --git a/tests/modelzoo/test_InceptionV1.py b/tests/modelzoo/test_InceptionV1.py index ccac0ab8..b283a35f 100644 --- a/tests/modelzoo/test_InceptionV1.py +++ b/tests/modelzoo/test_InceptionV1.py @@ -21,14 +21,12 @@ @pytest.mark.slow def test_InceptionV1_model_download(): model = InceptionV1() - model.load_graphdef() assert model.graph_def is not None @pytest.mark.slow def test_InceptionV1_graph_import(): model = InceptionV1() - model.load_graphdef() model.import_graph() nodes = tf.get_default_graph().as_graph_def().node node_names = set(node.name for node in nodes) diff --git a/tests/modelzoo/test_saveload.py b/tests/modelzoo/test_saveload.py new file mode 100644 index 00000000..4c45ba96 --- /dev/null +++ b/tests/modelzoo/test_saveload.py @@ -0,0 +1,20 @@ +import pytest +import tensorflow as tf + +from lucid.modelzoo.vision_base import Model +from lucid.modelzoo.vision_models import AlexNet + + +shape = (16,16,3) + +def test_Model_save(minimodel): + with tf.Session().as_default() as sess: + _ = minimodel() + sess.run(tf.global_variables_initializer()) + path = "./tests/fixtures/minigraph.pb" + Model.save(path, "input", ["output"], shape, [0,1]) + +def test_Model_load(): + path = "./tests/fixtures/minigraph.pb" + model = Model.load(path) + assert all(str(shape[i]) in repr(model.graph_def) for i in range(len(shape))) diff --git a/tests/modelzoo/test_vision_base.py b/tests/modelzoo/test_vision_base.py new file mode 100644 index 00000000..7b93e48b --- /dev/null +++ b/tests/modelzoo/test_vision_base.py @@ -0,0 +1,61 @@ +import pytest +import tensorflow as tf + +from lucid.modelzoo.vision_base import Model +from lucid.modelzoo.vision_models import AlexNet, InceptionV1, InceptionV3_slim, ResnetV1_50_slim + + +def test_suggest_save_args_happy_path(capsys, minimodel): + path = "./tests/fixtures/minigraph.pb" + + with tf.Graph().as_default() as graph, tf.Session() as sess: + _ = minimodel() + sess.run(tf.global_variables_initializer()) + + # ask for suggested arguments + inferred = Model.suggest_save_args() + # they should be both printed... 
+        captured = capsys.readouterr().out  # captures stdout
+        names = ["input_name", "image_shape", "output_names"]
+        assert all(name in captured for name in names)
+        # ...and returned
+
+        # check that these inferred values work
+        inferred.update(image_value_range=(0,1))
+        Model.save(path, **inferred)
+        loaded_model = Model.load(path)
+        assert "0.100" in repr(loaded_model.graph_def)
+
+
+def test_suggest_save_args_int_input(capsys, minimodel):
+    with tf.Graph().as_default() as graph, tf.Session() as sess:
+        image_t = tf.placeholder(tf.uint8, shape=(32, 32, 3), name="input")
+        input_t = tf.math.divide(image_t, tf.constant(255, dtype=tf.uint8), name="divide")
+        _ = minimodel(input_t)
+        sess.run(tf.global_variables_initializer())
+
+        # ask for suggested arguments
+        inferred = Model.suggest_save_args()
+        captured = capsys.readouterr().out  # captures stdout
+        assert "DT_UINT8" in captured
+        assert inferred["input_name"] == "divide"
+
+
+@pytest.mark.parametrize("model_class", [AlexNet, InceptionV1, InceptionV3_slim, ResnetV1_50_slim])
+def test_suggest_save_args_existing_graphs(capsys, model_class):
+    graph_def = model_class().graph_def
+
+    if model_class == InceptionV1:  # has flexible input shape, can't be inferred
+        with pytest.warns(UserWarning):
+            inferred = Model.suggest_save_args(graph_def)
+    else:
+        inferred = Model.suggest_save_args(graph_def)
+
+    assert model_class.input_name == inferred["input_name"]
+
+    if model_class != InceptionV1:
+        assert model_class.image_shape == inferred["image_shape"]
+
+    layer_names = [layer.name for layer in model_class.layers]
+    for output_name in list(inferred["output_names"]):
+        assert output_name in layer_names
diff --git a/tests/modelzoo/test_vision_models.py b/tests/modelzoo/test_vision_models.py
index fa6caeb4..f55ceced 100644
--- a/tests/modelzoo/test_vision_models.py
+++ b/tests/modelzoo/test_vision_models.py
@@ -45,6 +45,7 @@ def test_consistent_namespaces():
         assert difference in ('Model', 'Layer') or difference.startswith("__")
 
 
+@pytest.mark.slow
 @pytest.mark.parametrize("name,model_class", models_map.items())
 def test_model_properties(name, model_class):
     assert hasattr(model_class, "model_path")
@@ -73,7 +74,6 @@ def test_model_layers_shapes(model_class):
     name = model_class.__name__
     scope = "TestLucidModelzoo"
    model = model_class()
-    model.load_graphdef()
     with tf.Graph().as_default() as graph:
         model.import_graph(scope=scope)
         for layer in model.layers:
diff --git a/tests/optvis/test_integration.py b/tests/optvis/test_integration.py
index 8a5b6154..07e9b7de 100644
--- a/tests/optvis/test_integration.py
+++ b/tests/optvis/test_integration.py
@@ -4,12 +4,14 @@
 import tensorflow as tf
 
 from lucid.optvis import objectives, param, render, transform
+from lucid.modelzoo.vision_models import InceptionV1
 
 
 @pytest.mark.slow
 @pytest.mark.parametrize("decorrelate", [True, False])
 @pytest.mark.parametrize("fft", [True, False])
-def test_integration(decorrelate, fft, inceptionv1):
+def test_integration(decorrelate, fft):
+    inceptionv1 = InceptionV1()
     obj = objectives.neuron("mixed3a_pre_relu", 0)
     param_f = lambda: param.image(16, decorrelate=decorrelate, fft=fft)
     rendering = render.render_vis(
diff --git a/tests/optvis/test_objectives.py b/tests/optvis/test_objectives.py
index 1408dd15..87155d1a 100644
--- a/tests/optvis/test_objectives.py
+++ b/tests/optvis/test_objectives.py
@@ -5,12 +5,20 @@
 import tensorflow as tf
 import numpy as np
 from lucid.optvis import objectives, param, render, transform
+from lucid.modelzoo.vision_models import InceptionV1
 
 np.random.seed(42)
 
 NUM_STEPS = 3
 
+
+@pytest.fixture
+def inceptionv1():
+    return InceptionV1()
+
+
+
 def assert_gradient_ascent(objective, model, batch=None, alpha=False, shape=None):
     with tf.Graph().as_default() as graph, tf.Session() as sess:
         shape = shape or [1, 32, 32, 3]
diff --git a/tests/recipes/activation_atlas.py b/tests/recipes/activation_atlas.py
index 4d98dc59..74dfa55f 100644
--- a/tests/recipes/activation_atlas.py
+++ b/tests/recipes/activation_atlas.py
@@ -12,7 +12,6 @@
 @pytest.mark.skip(reason="takes too long to complete on CI")
 def test_activation_atlas():
     model = AlexNet()
-    model.load_graphdef()
     layer = model.layers[1]
     atlas = activation_atlas(model, layer, number_activations=subset)
     save(atlas, "tests/recipes/results/activation_atlas/atlas.jpg")
@@ -21,11 +20,9 @@ def test_aligned_activation_atlas():
     model1 = AlexNet()
-    model1.load_graphdef()
     layer1 = model1.layers[1]
 
     model2 = InceptionV1()
-    model2.load_graphdef()
     layer2 = model2.layers[8]  # mixed4d
 
     atlasses = aligned_activation_atlas(
diff --git a/tox.ini b/tox.ini
index 3e70b4dc..20b1c56c 100644
--- a/tox.ini
+++ b/tox.ini
@@ -5,10 +5,10 @@ envlist = py{27,36}
 deps =
     tensorflow
     .[test]
-commands = coverage run --source lucid --omit lucid/scratch/*,lucid/recipes/*,lucid/misc/gl/* -m py.test {posargs}
+commands = coverage run --source lucid --omit lucid/scratch/*,lucid/recipes/*,lucid/misc/gl/* -m py.test --run-slow {posargs}
 
 [pytest]
-addopts = --verbose --run-slow
+addopts = --verbose
 testpaths = ./tests/
 
 [flake8]
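
For orientation, a minimal sketch (not part of the diff) of how the save/load round-trip introduced above is exercised, mirroring `tests/modelzoo/test_saveload.py`; the graph, file path, and value range below are illustrative assumptions:

```python
# Illustrative usage of the Model.save / Model.load API added in this diff
# (mirrors tests/modelzoo/test_saveload.py; names and paths are examples).
import tensorflow as tf
from lucid.modelzoo.vision_base import Model

with tf.Graph().as_default(), tf.Session() as sess:
    # A tiny stand-in graph: float placeholder input -> variable -> softmax output.
    x = tf.placeholder(tf.float32, shape=(16, 16, 3), name="input")
    w = tf.Variable(0.1, name="variable")
    logits = tf.reduce_mean(w * x, name="output", axis=(0, 1))
    tf.nn.softmax(logits)
    sess.run(tf.global_variables_initializer())

    # Optionally let Lucid infer the arguments and print a report.
    Model.suggest_save_args()

    # Freeze the default graph and embed Lucid's metadata into the GraphDef.
    Model.save(
        "minigraph.pb",              # example path
        input_name="input",
        output_names=["output"],
        image_shape=[16, 16, 3],
        image_value_range=(0, 1),    # assumed range; use your model's real one
    )

# The saved GraphDef round-trips back into a usable Model instance.
model = Model.load("minigraph.pb")
```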