-
Notifications
You must be signed in to change notification settings - Fork 427
Eager API support #326
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Eager API support #326
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, I haven't been able to review everything, yet, but I figured I'd respond with the comments I had. This is a very large change, and it'd be easier to review if broken up into smaller PRs.
src/eager/mod.rs
Outdated
status.inner(), | ||
); | ||
} | ||
if status.is_ok() { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This if/else can be simplified to just status.into_result()
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed as you pointed out.
/// }; | ||
/// ``` | ||
#[derive(Debug)] | ||
pub struct TensorHandle<'a> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there any reason not to parameterize this based on the tensor type, i.e. TensorHandle<'a, T: TensorType>
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If TensorHandle were parameterized, all return types from raw operations also needed to be parameterized, but it was difficult to write such a generic code for me.
In more detail, the output tensor type depends on input type for raw operation functions, which makes difficult to parameterize at compile time.
|
||
/// Return a number of dimensions. | ||
/// This function will block till the operation that produces the TensorHandle has completed. | ||
pub fn num_dims(&self) -> Result<i32> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These should use u64
instead of i64
for consistency with Tensor
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think I can leave the return type as is, since the following code uses the same type (while it is referred to as c_int
). This code is in graph.rs.
/// Returns the number of dimensions of the Tensor referenced by `output`.
///
/// If the number of dimensions in the shape is unknown, returns -1.
///
/// Returns an error if:
///
/// * `output` is not in `graph`.
pub fn num_dims<I: Into<Output>>(&self, output: I) -> Result<c_int> {
src/eager/mod.rs
Outdated
unsafe { DataType::from_c(tf::TFE_TensorHandleDataType(self.inner)) } | ||
} | ||
|
||
/// Return a number of dimensions. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: "the" number of dimensions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed.
src/eager/mod.rs
Outdated
} | ||
} | ||
|
||
/// Return a number of elements |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: "the" number of elements, and missing a period.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed.
src/eager/mod.rs
Outdated
|
||
let h = raw_ops::add(&ctx, x, y).unwrap(); | ||
let z: Result<Tensor<i32>> = h.resolve(); | ||
assert!(z.is_ok()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The unwrap()
would fail the test anyway if it weren't ok
, so this assertion isn't really needed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed.
src/eager/mod.rs
Outdated
let z: Tensor<i32> = h.resolve().unwrap(); | ||
assert_eq!(z[0], 8i32); | ||
|
||
let h1 = z.clone().to_handle(&ctx).unwrap(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is z
being cloned before calling to_handle
? to_handle
doesn't consume the tensor, unless I'm misunderstanding something.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is because I first write the ToHandle trait to consume the input Tensor, but switched not to do so later on, I think... I'll check it out.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The current implementation does not require the clone
as you pointed out. The clone
s were removed.
src/eager/mod.rs
Outdated
let z: Tensor<String> = h.resolve().unwrap(); | ||
assert_eq!(z.len(), 1); | ||
assert_eq!(z[0].len(), 32); | ||
assert_eq!(z[0], "This a sample text for unittest.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These three assertions could be simplified to something like:
assert_eq!(&z[..], &["This is a sample text for unittest.".to_string()]);
src/eager/mod.rs
Outdated
|
||
let devices = ctx.device_list().unwrap(); | ||
assert!(devices.len() > 0); | ||
for d in devices.iter() { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for d in &devices {
would be more idiomatic.
src/eager/mod.rs
Outdated
let t = Tensor::from(0); | ||
let h = t.to_handle(&ctx).unwrap(); | ||
let v: Result<Tensor<f32>> = h.resolve(); | ||
assert!(v.is_err()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is redundant with the assert!(false)
below.
Thank you for your so much feedbacks. I had to say that I am very sorry that I did not add an explanation in this pull request. After checking your comments, I will break down this large PR into small pieces. |
I'd like to mention this PR in the contributor summit. Are you fine with that? |
Of course it's fine! Thank you. |
I think I've fixed most of the things that can be easily fixed, so I'm going to divide the PR into smaller parts based on the content here. |
Closes since this is progressing in the following smaller PRs: |
resolves #321