Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Load from saved model support #68

Merged
merged 5 commits into from
Mar 11, 2017
Merged

Conversation

Enet4
Copy link
Contributor

@Enet4 Enet4 commented Mar 9, 2017

This PR exposes TensorFlow's native capability of loading saved model bundles from a directory.

  • Added TF_LoadSessionFromSavedModel function to tensorflow-sys;
  • Added function Session::from_saved_model, which provides a safe generic API with the minimum arguments required. In the future, one might consider adding alternative functions that would let users specify run_options and retrieve meta_graph_def.

This feature will hopefully make loading pre-trained models in Rust more accessible. Please let me know if you would like a complete example or an integration test. Some of the examples found in this repository should be easily adjusted to test this feature.

Example of use:

let mut graph = Graph::new();
println!("Loading session...");
let sess = Session::from_saved_model(
    &SessionOptions::new(), &["serve"],
    &mut graph,
    "test-model").unwrap();

println!("Loaded model with {} operations.", graph.operation_iter().count());

- add TF_LoadSessionFromSavedModel function to tensorflow-sys
- add function in Session that provides safe loading from a saved model
src/session.rs Outdated
use std::ptr;
use std::result::Result as StdResult;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hm I find this confusing, but if @adamcrume is good with it

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, I just needed a hint on line 56 to pluck the result out, and this approach sounded elegant to me. We could consider something else though...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In other places we're naming std::result::Result directly without importing it, like so:

pub type Result<T> = std::result::Result<T, Status>;

I'd prefer that, just for consistency. On a side note, I regret adding that type alias. I assumed it was an accepted pattern because of std::fmt::Result, std::io::Result, and std::thread::Result, but having multiple types with the same name (just in different modules) causes no end of headaches.

src/session.rs Outdated
let tags_ptr: Vec<*const c_char> = tags_cstr.iter().map(|t| t.as_ptr()).collect();

let inner = unsafe {
tf::TF_LoadSessionFromSavedModel(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

formatting looks a bit off here, can you run it through rustfmt please?

@daschl
Copy link
Contributor

daschl commented Mar 9, 2017

@Enet4 I'd love to see at least an example on this, since we are lacking them anyways and it also helps other people onboard more quickly

@Enet4
Copy link
Contributor Author

Enet4 commented Mar 9, 2017

@daschl All right, I have formatted the code (my bad!) and added an example based on regression, which was already available. Still, I'm open to file renames or other tweaks.

@daschl
Copy link
Contributor

daschl commented Mar 9, 2017

@Enet4 very cool, thanks! Of course @adamcrume has the final say on this ;)

Copy link
Contributor

@adamcrume adamcrume left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mostly looks good, with a few tweaks. Please run the code through rustfmt. Also, thanks for the example code; we can always use more examples and more tests.

src/session.rs Outdated
use std::ptr;
use std::result::Result as StdResult;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In other places we're naming std::result::Result directly without importing it, like so:

pub type Result<T> = std::result::Result<T, Status>;

I'd prefer that, just for consistency. On a side note, I regret adding that type alias. I assumed it was an accepted pattern because of std::fmt::Result, std::io::Result, and std::thread::Result, but having multiple types with the same name (just in different modules) causes no end of headaches.

src/session.rs Outdated
.to_str()
.and_then(|s| CString::new(s.as_bytes()).ok())
.ok_or_else(|| {
Status::new_set(Code::InvalidArgument, "Invalid export directory path").unwrap()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can use the invalid_arg! macro for this.

src/session.rs Outdated
.map(|t| CString::new(t.as_ref()))
.collect::<StdResult<_, _>>()
.map_err(|_| {
Status::new_set(Code::InvalidArgument, "Invalid tag name").unwrap()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

invalid_arg! again.

src/session.rs Outdated
let inner = unsafe {
tf::TF_LoadSessionFromSavedModel(options.inner,
ptr::null(),
export_dir_cstr.to_bytes_with_nul().as_ptr() as
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The return value of to_bytes_with_nul doesn't live long enough. You can just use export_dir_cstr.as_ptr(), since it already guarantees a null terminator.

- use invalid_arg!
- remove std result alias
- fix getting char pointer to export_dir
@Enet4
Copy link
Contributor Author

Enet4 commented Mar 10, 2017

I still wonder how I got the tag conversion right and then screwed up on the export_dir string... 🤔 Nevertheless, the changes were made. :)

@adamcrume adamcrume merged commit 8bd62df into tensorflow:master Mar 11, 2017
@adamcrume
Copy link
Contributor

Thanks!

@Enet4 Enet4 deleted the export-saved-model branch March 13, 2017 18:18
@jhseu
Copy link

jhseu commented Mar 15, 2017

Note that we're preferring to add a SavedModelBundle wrapper for the return value in other languages. You need the MetaGraphDef to extract out the signatures in the SavedModel.

@Enet4
Copy link
Contributor Author

Enet4 commented Mar 15, 2017

@jhseu Admittedly, I knew that at least the Java bindings would be doing it with a bundle (tensorflow/tensorflow#7134), but I had found no reason to replicate that design here, at the time. But given that the meta-graph is still unreachable, I agree that we should still seek to improve this saved model API. Also, a proper MetaGraph abstraction would probably be nicer than just retrieving a byte buffer.

@Enet4 Enet4 mentioned this pull request Jul 3, 2017
ramon-garcia pushed a commit to ramon-garcia/tensorflow-rust that referenced this pull request May 20, 2023
* added new row indexer for parquet data frame

* updated all tests and code to use DateTimeOffset

* added logical JSON type

* added new dataset handling of rows through pivoting

* Update PlainValuesReader.cs

* built more single responsibility around ParquetReader type to ensure efficient deallocation of resources using IDisposable

* updated reader to look at nulls

* added branches to set type IList as either nullable or non-nullable and done this against the required attribute on the column header

* moved BigDecimal to own file
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants