
Experimental API for saving and loading models #152

Merged: 11 commits merged on Apr 18, 2019

Conversation

@colah (Contributor) commented Mar 27, 2019

The biggest pain point of using lucid seems to be importing models. Most users wish to visualize their own models, and need to get them into a format lucid can use. Unfortunately, this presently involves several steps, which can be unintuitive:

  1. Save a graph
  2. Convert it into a frozen graph
  3. Write a Lucid modelzoo.Model class, filling in values like image_value_range

This PR proposes an alternate API where there is only one, clearly defined step to preparing your model for use in Lucid. This import path is only for TensorFlow users.

We assume that the user can construct an inference graph of their model. This should be easy for anyone training models (because they will have one for tracking accuracy) and for anyone using a model for inference.
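
For concreteness, here is a minimal sketch (not taken from the PR; the layer sizes and names are illustrative assumptions) of what such an inference graph might look like in TensorFlow 1.x:

import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    # The placeholder's name becomes `input_name`; its shape (minus the batch
    # dimension) becomes `image_shape`.
    images = tf.placeholder(tf.float32, [None, 224, 224, 3], name="input")
    flat = tf.reshape(images, [-1, 224 * 224 * 3])
    logits = tf.layers.dense(flat, 1000)
    with tf.name_scope("prob"):
        probs = tf.nn.softmax(logits, name="Softmax")  # op name: 'prob/Softmax'

with tf.Session(graph=graph) as sess:
    sess.run(tf.global_variables_initializer())
    # ... restore trained weights here, then call save_model() as shown below.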

At this point, the user simply calls the save_model() function.

# Run this code with inference graph in default graph and session
save_model(
    save_path    = 'gs://.../test.pb',  # Local paths are also fine!
    input_name   = 'input',
    output_names = ['prob/Softmax'],
    image_shape  = [224, 224, 3],
    image_value_range = [0,1]
  )

If the user successfully does this with the correct arguments, they should be done. All metadata is baked into the saved model (more on this later), meaning it never needs to be specified again.

To use a model in lucid, the user simply does:

model = load_model('gs://.../test.pb')

And they're ready to go!
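
For example, a loaded model should plug straight into lucid's existing rendering API. A hedged usage sketch (the layer name is an assumption about the saved model above):

from lucid.optvis import render

model = load_model('gs://.../test.pb')
# Visualize channel 0 of the output layer of the example model.
images = render.render_vis(model, "prob/Softmax:0")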

Suggesting save code for users

The above import path is still unnecessarily painful, since we can often infer many of the arguments to save_model() with high confidence.

I considered just making save_model() take optional arguments and infer any unspecified ones where possible. However, this struck me as an API that would lead to user confusion: we can't always infer these arguments, we aren't completely certain when we do, and some arguments simply can't be inferred at all.

Instead, I went with a suggest_save_code() function, which is simply invoked:

# With inference graph in default graph
suggest_save_code()

If our heuristics can determine the arguments, the function will print out something like the following:

>>> suggest_save_code()

# Inferred: input_name = input  (only Placeholder)
# Inferred: image_shape = [224, 224, 3]
# Inferred: output_names = ['prob/Softmax']  (Softmax ops)

# Sanity check all inferred values before using this code!
save_model(
    save_path    = 'gs://save/model.pb', # TODO: replace
    input_name   = 'input',
    output_names = ['prob/Softmax'],
    image_shape  = [224, 224, 3],
    image_value_range =                  # TODO (eg. [0, 1], [0, 255], [-117, 138] )
    # WARNING: Incorrect `image_value_range` is the most common cause of feature 
    #     visualization bugs! It will fail silently with incorrect visualizations!
  )

In other cases, when the arguments can't be inferred, the output will be something like this, giving the user an empty template to fill in.

>>> suggest_save_code()

# Sanity check all inferred values before using this code!
save_model(
    save_path    = 'gs://save/model.pb', # TODO: replace
    input_name   =   ,                   # TODO (eg. 'input' )
    output_names = [ ],                  # TODO (eg. ['logits'] )
    image_shape  =   ,                   # TODO (eg. [224, 224, 3] )
    image_value_range =                  # TODO (eg. [0, 1], [0, 255], [-117, 138] )
    # WARNING: Incorrect `image_value_range` is the most common cause of feature 
    #     visualization bugs! It will fail silently with incorrect visualizations!
  )
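
The heuristics behind these suggestions could be quite simple. A rough sketch (my assumptions about the approach, not the PR's actual implementation) that scans the default graph for Placeholder and Softmax ops:

import tensorflow as tf

def infer_save_args():
    graph_def = tf.get_default_graph().as_graph_def()
    placeholders = [n for n in graph_def.node if n.op == "Placeholder"]
    output_names = [n.name for n in graph_def.node if n.op == "Softmax"]

    input_name, image_shape = None, None
    if len(placeholders) == 1:  # only unambiguous if there is exactly one
        input_name = placeholders[0].name
        dims = placeholders[0].attr["shape"].shape.dim
        if len(dims) == 4:      # assume NHWC and drop the batch dimension
            image_shape = [d.size for d in dims[1:]]
    return input_name, image_shape, output_names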

Input Range

My biggest concern with this API -- and any other I can think of -- is misspecified image_value_ranges. Unlike other errors, this will not cause visualization to fail. Instead, it will cause bad or incorrect visualizations to be produced, failing silently.

There isn't any reasonable way for us to catch these errors. (My best thought would be to test accuracy on ImageNet for different common ranges.)

At the moment, trying to warn users that this is a common silent failure mode seems like the best bet.
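
One possible sanity check along the lines suggested above: run the frozen graph on a natural image at a few candidate ranges and see which one produces a confident, sensible prediction. A hedged sketch (the node names and helper are assumptions, not part of this PR):

import numpy as np
import tensorflow as tf

def check_value_ranges(graphdef_path, img_0_to_1,
                       candidate_ranges=((0, 1), (0, 255), (-1, 1))):
    """img_0_to_1: a natural image of shape image_shape, scaled to [0, 1]."""
    graph_def = tf.GraphDef()
    with tf.gfile.GFile(graphdef_path, "rb") as f:
        graph_def.ParseFromString(f.read())
    for low, high in candidate_ranges:
        scaled = img_0_to_1 * (high - low) + low  # rescale into the candidate range
        with tf.Graph().as_default(), tf.Session() as sess:
            tf.import_graph_def(graph_def, name="")
            probs = sess.run("prob/Softmax:0", {"input:0": scaled[None]})
        print((low, high), "top class:", int(np.argmax(probs)),
              "p =", float(np.max(probs)))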

Storing Metadata

This API depends on us being able to save metadata along with the graph. This is critical to making it into "one save step" instead of two separate steps of saving and then later specifying metadata on import.

For this code, I do something kind of evil to accomplish this. I inject a tf.constant op into the graph named lucid_meta_json, containing a JSON blob of metadata. On import, we can detect and extract this node to recover the metadata. This is a completely legal TensorFlow graph! But not really being used as intended...

In the future, we might be able to switch over to SavedModel and use their actual supported mechanisms for specifying metadata.
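
A minimal sketch of the trick described above (the node name follows the PR's description; the helper names and extraction logic are illustrative assumptions):

import json
import tensorflow as tf

METADATA_NODE = "lucid_meta_json"

def inject_metadata(metadata):
    # Bake a JSON blob into the current default graph as a constant op.
    tf.constant(json.dumps(metadata), name=METADATA_NODE)

def extract_metadata(graph_def):
    # Recover the JSON blob from a loaded GraphDef, if present.
    for node in graph_def.node:
        if node.name == METADATA_NODE:
            raw = node.attr["value"].tensor.string_val[0]
            return json.loads(raw.decode("utf-8"))
    return None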

@coveralls commented Mar 28, 2019

Pull Request Test Coverage Report for Build 489

  • 146 of 170 (85.88%) changed or added relevant lines in 7 files are covered.
  • 2 unchanged lines in 1 file lost coverage.
  • Overall coverage increased (+0.9%) to 77.715%

Changes missing coverage (covered lines / changed or added lines, %):
  • lucid/modelzoo/util.py: 52 of 53 (98.11%)
  • lucid/misc/io/saving.py: 3 of 6 (50.0%)
  • lucid/modelzoo/vision_base.py: 82 of 102 (80.39%)
Files with coverage reduction (new missed lines, %):
  • lucid/modelzoo/vision_base.py: 2 (70.0%)
Totals:
  • Change from base Build 465: +0.9%
  • Covered lines: 1639
  • Relevant lines: 2109

💛 - Coveralls

@jacobhilton (Contributor) commented

As a novice lucid user I tried this out and saving my model was pretty straightforward. I've not yet tried doing anything with the saved model, but nothing went obviously wrong.

I did find the suggest_save_code interface to be a little strange. More natural to me would be a function called something like save_model_try_infer, which is a wrapper around save_model that does its best to infer any unspecified arguments, raising an exception if there is something that it can't infer.

A minor issue I had is that json.dumps doesn't support numpy dtypes (try json.dumps(np.array([1])[0]) for example), which is a bit of a shame but probably not worth worrying about.
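
One possible fix (an assumption, not necessarily what was merged) is a JSON encoder that converts numpy scalars and arrays to plain Python types before serialization:

import json
import numpy as np

class NumpyEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, np.generic):   # numpy scalar, e.g. np.int64(1)
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super(NumpyEncoder, self).default(obj)

json.dumps({"image_shape": np.array([224, 224, 3])}, cls=NumpyEncoder)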

@colah (Contributor, Author) commented Apr 5, 2019

Thanks for the feedback, Jacob!

I did find the suggest_save_code interface to be a little strange. More natural to me would be a function called something like save_model_try_infer, which is a wrapper around save_model that does its best to infer any unspecified arguments, raising an exception if there is something that it can't infer.

Yeah, it's totally weird.

I think my main concern about the save_model_try_infer API is that our heuristics may be wrong. That makes me kind of want to get explicit user sign-off on inferred things, if that makes sense? But maybe that's naive of me, and users will just OK whatever it suggests without looking.

A minor issue I had is that json.dumps doesn't support numpy dtypes (try json.dumps(np.array([1])[0]) for example), which is a bit of a shame but probably not worth worrying about.

Thanks for raising this! Will fix. :)

@jacobhilton (Contributor) commented

Another way of doing it that avoids generating code would be to have a suggest_save_kwargs function that returns a dict that you can pass as kwargs to save_model. I can see the benefits of the current interface though.
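
A short sketch of what that interface could look like (suggest_save_kwargs is hypothetical and does not exist in this PR):

kwargs = suggest_save_kwargs()          # e.g. {'input_name': 'input', 'image_shape': [224, 224, 3], ...}
kwargs['image_value_range'] = (0, 1)    # fill in anything that couldn't be inferred
save_model(save_path='gs://.../test.pb', **kwargs)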

@hegman12 commented Apr 5, 2019

@colah For me, this API looks like a very good improvement over the old method of loading. Thanks for this! Here are my thoughts :)

  1. From a usability point of view:

This API looks much easier to use than the previous method of importing custom model! Thanks for making this so simple!

import lucid

# <code to train your model>
# ...

# save model
lucid.save_model(...)

Then

#load model
lucid.load_model(...)
  2. From a use-case point of view:
  • I have access to the code to train the model

This is very simple and just save_model and load_model will take care of most of the work! Makes life easier!

  • I got the model from the internet as an inference graph (.pb). This model is not in the model zoo.

If I understand correctly, we have to follow the old method to import this model, or write a script that loads the model and re-saves it with save_model (see the sketch after this list), because this graph will not have the lucid_meta_json node and so cannot be used with the load_model API. I think this could be an improvement point in the future, to make model loading even easier. My feeling is that load_model may not be the right place to provide this functionality and that a new API may be required.

  • I need to visualize a model which is in the model zoo.
    This was always easy to do! Just use the existing API to import and load.
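
Regarding the second use case above, a hedged sketch of the "write a script" path: load an existing frozen graph into the default graph and re-save it with metadata via save_model(). The paths and node names are illustrative assumptions, and the value range must be known from the model's documentation.

import tensorflow as tf

graph_def = tf.GraphDef()
with tf.gfile.GFile("downloaded_model.pb", "rb") as f:
    graph_def.ParseFromString(f.read())

with tf.Graph().as_default(), tf.Session() as sess:
    tf.import_graph_def(graph_def, name="")
    save_model(
        save_path='converted_model.pb',
        input_name='input',
        output_names=['prob/Softmax'],
        image_shape=[224, 224, 3],
        image_value_range=[0, 1],  # must come from the model's documentation
    )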

Overall, this change is very encouraging to try out my own model and visualize with lucid.

@jacobhilton (Contributor) commented

I had an issue where my input was originally a uint8 placeholder; I talked to @colah about this.

@ludwigschubert (Contributor) commented

@colah I took a shot at how I'd integrate this into our existing API.
You can look at the tests I wrote for usage examples, but the gist is:

path = ...
inferred = Model.suggest_save_args()  # throws if one arg can't be inferred
Model.save(path, **inferred, image_value_range=(0,1))
reloaded_model = Model.load(path)

This unifies the APIs for loading to the following static methods:

Model.load(graphdef_url)
Model.load_from_metadata(graphdef_url, metadata)
Model.load_from_manifest(manifest_url)

I still need to fix Python 2 and remove some commented-out code, so please don't merge yet. :-)

Ludwig Schubert added 2 commits April 12, 2019 18:17
@ludwigschubert (Contributor) commented

@jacobhilton @colah what was the suggested solution in your case? I'll happily write a test for this case, I imagine that might be common. :-)

@ludwigschubert (Contributor) commented Apr 15, 2019

@colah as of 523d4dd, Model.suggest_save_args() no longer throws. Instead, we write to stdout as in your original proposal. Output:

Inferred: input_name = input (because it was the only Placeholder in the graph_def)
Inferred: image_shape = [16, 16, 3]
Inferred: output_names = ['Softmax']  (because those are all the Softmax ops)
# Please sanity check all inferred values before using this code!
Model.save(
    input_name='input',
    image_shape=[16, 16, 3],
    output_names=['Softmax'],
    image_value_range=_,   # TODO (eg. '[-1, 1], [0, 1], [0, 255], or [-117, 138]')
  )
