Experimental API for saving and loading models #152
Pull Request Test Coverage Report for Build 489 (Coveralls)
As a novice lucid user I tried this out, and saving my model was pretty straightforward. I've not yet tried doing anything with the saved model, but nothing went obviously wrong. A minor issue I had is that …
Thanks for the feedback, Jacob!
Yeah, it's totally weird. I think my main concern about the
Thanks for raising this! Will fix. :)
Another way of doing it that avoids generating code would be to have a …
@colah For me, this API looks like a very good improvement over the old method of loading. Thanks for this! Here are my thoughts :)
This API looks much easier to use than the previous method of importing custom models! Thanks for making this so simple!
Then
This is very simple and just
If I understand correctly, we have to follow the old method to import this model, or write a script to load this model and save it with `save_model()`.
Overall, this change makes it very encouraging to try out my own model and visualize it with lucid.
I had an issue where my input was originally a uint8 placeholder - talked to @colah about this.
@colah I took a shot at how I'd integrate this into our existing API:

```python
path = ...
inferred = Model.suggest_save_args()  # throws if one arg can't be inferred
Model.save(path, **inferred, image_value_range=(0, 1))
reloaded_model = Model.load(path)
```

This unifies the APIs for loading to the following static methods:
I'll still fix Python 2 and remove some commented out code, so please don't merge yet. :-)
Also runs slow tests for coverage again.
@jacobhilton @colah what was the suggested solution in your case? I'll happily write a test for this case, I imagine it might be common. :-)
@colah from 523d4dd
The biggest pain point of using lucid seems to be importing models. Most users wish to visualize their own models, and need to get them into a format lucid can use. Unfortunately, this presently involves several steps, which can be unintuitive:
- Subclassing the `modelzoo.Model` class, filling in values like `image_value_range`
This PR proposes an alternate API where there is only one, clearly defined step to preparing your model for use in Lucid. This import path is only for TensorFlow users.
We assume that the user can construct an inference graph of their model. This should be easy for anyone training models (because they will have one for tracking accuracy) and for anyone using a model for inference.
At this point, the user simply calls the `save_model()` function. If the user successfully does this with the correct arguments, they should be done. All metadata is baked into the saved model (more on this later), meaning it never needs to be specified again.
To use a model in lucid, the user simply does:
And they're ready to go!
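As a rough illustration of the "metadata baked in at save time" idea, here is a minimal pure-Python mock. The real API operates on TensorFlow graphs; all names and fields below are illustrative, not lucid's actual implementation:

```python
import json
import tempfile

# Mock of the one-step workflow: everything the loader needs is baked
# into the saved artifact, so nothing has to be re-specified on load.
# (Illustrative only -- the real lucid code serializes a TF graph.)

def save_model(path, input_name, image_shape, image_value_range):
    metadata = {
        "input_name": input_name,
        "image_shape": image_shape,
        "image_value_range": image_value_range,
    }
    with open(path, "w") as f:
        json.dump(metadata, f)

def load_model(path):
    # One-step load: all metadata comes back out of the artifact.
    with open(path) as f:
        return json.load(f)

with tempfile.NamedTemporaryFile(mode="r", suffix=".json") as tmp:
    save_model(tmp.name, input_name="input",
               image_shape=[224, 224, 3], image_value_range=[0, 1])
    model = load_model(tmp.name)
```

The point of the design is that `load_model` takes only a path: the user never repeats `image_value_range` or the input name at load time.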
Suggesting save code for users
The above import path is still unnecessarily painful, since we can often infer many of the arguments to `save_model()` with high confidence.

I considered just making `save_model()` have optional arguments and trying to infer unspecified ones where possible. However, this struck me as an API that would lead to user confusion: we can't always infer these arguments, we aren't completely certain when we can, and some arguments just can't be inferred.

Instead, I went with a `suggest_save_code()` function. If our heuristics can determine the arguments, invoking it will print out something like the following:
In other cases, when the arguments can't be inferred, the output will be something like this, giving the user an empty template to fill in.
Input Range
My biggest concern with this API -- and any other I can think of -- is a misspecified `image_value_range`. Unlike other errors, this will not cause visualization to fail. Instead, it will cause bad or incorrect visualizations to be produced, failing silently.

There isn't any reasonable way for us to catch these errors. (My best thought would be to test accuracy on ImageNet for different common ranges.)
At the moment, trying to warn users that this is a common silent failure mode seems like the best bet.
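To see why this failure is silent, consider a toy version of the preprocessing step (a sketch; the function name and ranges are illustrative, not lucid's code):

```python
# Why a wrong image_value_range fails silently: the rescaling below
# always succeeds, so nothing raises -- the model just silently
# receives mis-scaled pixels.

def to_model_range(pixels_01, image_value_range):
    """Map visualization pixels in [0, 1] into the range the model expects."""
    lo, hi = image_value_range
    return [lo + p * (hi - lo) for p in pixels_01]

pixels = [0.0, 0.5, 1.0]

# Correct declaration for a model trained on [-1, 1] inputs:
good = to_model_range(pixels, (-1, 1))   # [-1.0, 0.0, 1.0]

# Misspecified as (0, 255): no exception, but every input the model
# sees is now wildly out of its trained distribution.
bad = to_model_range(pixels, (0, 255))   # [0.0, 127.5, 255.0]
```

Both calls run without error; only the visualizations downstream reveal (or quietly absorb) the mistake.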
Storing Metadata
This API depends on us being able to save metadata along with the graph. This is critical to making it into "one save step" instead of two separate steps of saving and then later specifying metadata on import.
For this code, I do something kind of evil to accomplish this: I inject a `tf.constant` op named `lucid_meta_json` into the graph, containing a JSON blob of metadata. On import, we can detect and extract this node to recover the metadata. This is a completely legal TensorFlow graph! But not really being used as intended...

In the future, we might be able to switch over to `SavedModel` and use its actual supported mechanisms for specifying metadata.
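The inject-and-extract round trip can be sketched on a GraphDef-like structure of plain dicts (the real code injects an actual `tf.constant` op into a TensorFlow graph; the node representation here is an illustrative stand-in):

```python
import json

# Sketch of the metadata trick: hide a JSON blob inside a constant-style
# node with a well-known name, then scan for that node on import.
# (Plain dicts here -- the real code works on a TensorFlow GraphDef.)

META_NODE_NAME = "lucid_meta_json"

def inject_metadata(graph_nodes, metadata):
    """Append a constant-style node whose value is a JSON blob of metadata."""
    graph_nodes.append({
        "name": META_NODE_NAME,
        "op": "Const",
        "value": json.dumps(metadata),
    })

def extract_metadata(graph_nodes):
    """On import, find the marker node and recover the metadata."""
    for node in graph_nodes:
        if node["name"] == META_NODE_NAME:
            return json.loads(node["value"])
    return None

graph = [{"name": "input", "op": "Placeholder", "value": None}]
inject_metadata(graph, {"image_value_range": [0, 1]})
recovered = extract_metadata(graph)
```

The extra node is inert from the graph's point of view, which is what makes the result "a completely legal TensorFlow graph" while still carrying the metadata.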