Load from saved model support #68

Merged: 5 commits, Mar 11, 2017
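
This change adds Session::from_saved_model to src/session.rs, which restores a SavedModel exported from Python directly into a Graph and Session; the two new example files below exercise it end to end. As a quick orientation, here is a minimal sketch of calling the new API (the model directory and tag names are the ones used by the bundled example, and the load() helper is only for illustration):

extern crate tensorflow;

use std::error::Error;
use tensorflow::{Graph, Session, SessionOptions};

// Minimal sketch: restore the SavedModel written by regression_savedmodel.py.
// The tag set must match what was passed to add_meta_graph_and_variables on
// the Python side; here that is ["train", "serve"].
fn load() -> Result<(Graph, Session), Box<Error>> {
    let mut graph = Graph::new();
    let session = Session::from_saved_model(&SessionOptions::new(),
                                            &["train", "serve"],
                                            &mut graph,
                                            "examples/saved-regression-model")?;
    // `graph` now contains the exported ops ("x", "y", "train", ...) and
    // `session` holds the restored variable values, ready to run them.
    Ok((graph, session))
}

fn main() {
    match load() {
        Ok(_) => println!("Loaded saved model."),
        Err(e) => println!("{}", e),
    }
}
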
Changes from all commits
43 changes: 43 additions & 0 deletions examples/regression_savedmodel.py
@@ -0,0 +1,43 @@
import tensorflow as tf
from tensorflow.python.saved_model.builder import SavedModelBuilder
from tensorflow.python.saved_model.signature_def_utils import build_signature_def
from tensorflow.python.saved_model.signature_constants import REGRESS_METHOD_NAME
from tensorflow.python.saved_model.tag_constants import TRAINING, SERVING
from tensorflow.python.saved_model.utils import build_tensor_info

x = tf.placeholder(tf.float32, name='x')
y = tf.placeholder(tf.float32, name='y')

w = tf.Variable(tf.random_uniform([1], -1.0, 1.0), name='w')
b = tf.Variable(tf.zeros([1]), name='b')
y_hat = w * x + b

loss = tf.reduce_mean(tf.square(y_hat - y))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss, name='train')

init = tf.variables_initializer(tf.global_variables(), name='init')

directory = 'examples/saved-regression-model'
builder = SavedModelBuilder(directory)

with tf.Session(graph=tf.get_default_graph()) as sess:
    sess.run(init)

    signature_inputs = {
        "x": build_tensor_info(x),
        "y": build_tensor_info(y)
    }
    signature_outputs = {
        "out": build_tensor_info(y_hat)
    }
    signature_def = build_signature_def(
        signature_inputs, signature_outputs,
        REGRESS_METHOD_NAME)
    builder.add_meta_graph_and_variables(
        sess, [TRAINING, SERVING],
        signature_def_map={
            REGRESS_METHOD_NAME: signature_def
        },
        assets_collection=tf.get_collection(tf.GraphKeys.ASSET_FILEPATHS))
    builder.save(as_text=False)
100 changes: 100 additions & 0 deletions examples/regression_savedmodel.rs
@@ -0,0 +1,100 @@
extern crate random;
extern crate tensorflow;

use random::Source;
use std::error::Error;
use std::result::Result;
use std::path::Path;
use std::process::exit;
use tensorflow::Code;
use tensorflow::Graph;
use tensorflow::Session;
use tensorflow::SessionOptions;
use tensorflow::Status;
use tensorflow::StepWithGraph;
use tensorflow::Tensor;

fn main() {
    // Putting the main code in another function serves two purposes:
    // 1. We can use the try! macro.
    // 2. We can call exit safely, which does not run any destructors.
    exit(match run() {
        Ok(_) => 0,
        Err(e) => {
            println!("{}", e);
            1
        }
    })
}

fn run() -> Result<(), Box<Error>> {
    let export_dir = "examples/saved-regression-model"; // y = w * x + b
    if !Path::new(export_dir).exists() {
        return Err(Box::new(Status::new_set(Code::NotFound,
                                            &format!("Run 'python regression_savedmodel.py' to generate \
                                                      {} and try again.",
                                                     export_dir))
            .unwrap()));
    }

    // Generate some test data.
    let w = 0.1;
    let b = 0.3;
    let num_points = 100;
    let steps = 201;
    let mut rand = random::default();
    let mut x = Tensor::new(&[num_points as u64]);
    let mut y = Tensor::new(&[num_points as u64]);
    for i in 0..num_points {
        x[i] = (2.0 * rand.read::<f64>() - 1.0) as f32;
        y[i] = w * x[i] + b;
    }

    // Load the saved model exported by regression_savedmodel.py.
    let mut graph = Graph::new();
    let mut session = Session::from_saved_model(&SessionOptions::new(),
                                                &["train", "serve"],
                                                &mut graph,
                                                export_dir)?;
    let op_x = graph.operation_by_name_required("x")?;
    let op_y = graph.operation_by_name_required("y")?;
    let op_train = graph.operation_by_name_required("train")?;
    let op_w = graph.operation_by_name_required("w")?;
    let op_b = graph.operation_by_name_required("b")?;

    // Train the model (e.g. for fine tuning).
    let mut train_step = StepWithGraph::new();
    train_step.add_input(&op_x, 0, &x);
    train_step.add_input(&op_y, 0, &y);
    train_step.add_target(&op_train);
    for _ in 0..steps {
        try!(session.run(&mut train_step));
    }

    // Grab the data out of the session.
    let mut output_step = StepWithGraph::new();
    let w_ix = output_step.request_output(&op_w, 0);
    let b_ix = output_step.request_output(&op_b, 0);
    try!(session.run(&mut output_step));

    // Check our results.
    let w_hat: f32 = try!(output_step.take_output(w_ix)).data()[0];
    let b_hat: f32 = try!(output_step.take_output(b_ix)).data()[0];
    println!("Checking w: expected {}, got {}. {}",
             w,
             w_hat,
             if (w - w_hat).abs() < 1e-3 {
                 "Success!"
             } else {
                 "FAIL"
             });
    println!("Checking b: expected {}, got {}. {}",
             b,
             b_hat,
             if (b - b_hat).abs() < 1e-3 {
                 "Success!"
             } else {
                 "FAIL"
             });
    Ok(())
}
71 changes: 56 additions & 15 deletions src/session.rs
@@ -1,6 +1,8 @@
 use tf;
-use libc::c_int;
+use libc::{c_char, c_int};
+use std::ffi::CString;
 use std::marker;
+use std::path::Path;
 use std::ptr;
 use super::Code;
 use super::DataType;
@@ -32,6 +34,45 @@ impl Session {
         }
     }
 
+    /// Loads a session from an exported model.
+    pub fn from_saved_model<P: AsRef<Path>, Tag: AsRef<str>, Tags: IntoIterator<Item = Tag>>
+        (options: &SessionOptions,
+         tags: Tags,
+         graph: &mut Graph,
+         export_dir: P)
+         -> Result<Self> {
+        let mut status = Status::new();
+
+        let export_dir_cstr =
+            try!(export_dir.as_ref()
+                .to_str()
+                .and_then(|s| CString::new(s.as_bytes()).ok())
+                .ok_or_else(|| invalid_arg!("Invalid export directory path")));
+
+        let tags_cstr: Vec<_> = try!(tags.into_iter()
+            .map(|t| CString::new(t.as_ref()))
+            .collect::<::std::result::Result<_, _>>()
+            .map_err(|_| invalid_arg!("Invalid tag name")));
+        // keeping tags_cstr to retain strings in memory
+        let tags_ptr: Vec<*const c_char> = tags_cstr.iter().map(|t| t.as_ptr()).collect();
+
+        let inner = unsafe {
+            tf::TF_LoadSessionFromSavedModel(options.inner,
+                                             ptr::null(),
+                                             export_dir_cstr.as_ptr(),
+                                             tags_ptr.as_ptr(),
+                                             tags_ptr.len() as c_int,
+                                             graph.inner(),
+                                             ptr::null_mut(),
+                                             status.inner())
+        };
+        if inner.is_null() {
+            Err(status)
+        } else {
+            Ok(Session { inner: inner })
+        }
+    }
+
     /// Closes the session.
     pub fn close(&mut self) -> Result<()> {
         let mut status = Status::new();
@@ -143,19 +184,19 @@ impl<'l> StepWithGraph<'l> {
                                     index: c_int,
                                     tensor: &'l Tensor<T>) {
         self.input_ports.push(tf::TF_Output {
-            oper: operation.inner(),
-            index: index,
-        });
+                                  oper: operation.inner(),
+                                  index: index,
+                              });
         self.input_tensors.push(tensor.inner);
     }
 
     /// Requests that an output is fetched from the graph after running this step.
     /// Returns an index that you can then use to fetch this output from the step after running it.
     pub fn request_output(&mut self, operation: &Operation, index: c_int) -> OutputToken {
         self.output_ports.push(tf::TF_Output {
-            oper: operation.inner(),
-            index: index,
-        });
+                                   oper: operation.inner(),
+                                   index: index,
+                               });
         self.output_tensors.push(ptr::null_mut());
         OutputToken { index: self.output_tensors.len() - 1 }
     }
@@ -172,13 +213,13 @@ impl<'l> StepWithGraph<'l> {
                                                  {}",
                                                 output_idx,
                                                 self.output_tensors.len()))
-                .unwrap());
+                       .unwrap());
         }
         if self.output_tensors[output_idx].is_null() {
             return Err(Status::new_set(Code::Unavailable,
                                        "Output not available. Either it was already taken, or \
                                         this step has not been sucessfully run yet.")
-                .unwrap());
+                       .unwrap());
         }
         let actual_data_type = self.output_data_type(output_idx).unwrap();
         if actual_data_type != T::data_type() {
@@ -260,13 +301,13 @@ mod tests {
         let y = {
             let mut nd = g.new_operation("Mul", "y").unwrap();
             nd.add_input(Output {
-                operation: &two,
-                index: 0,
-            });
+                             operation: &two,
+                             index: 0,
+                         });
             nd.add_input(Output {
-                operation: &x,
-                index: 0,
-            });
+                             operation: &x,
+                             index: 0,
+                         });
             nd.finish().unwrap()
         };
         let options = SessionOptions::new();