-
Notifications
You must be signed in to change notification settings - Fork 74.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Java API to get the size of output list operations #9640
Conversation
Can one of the admins verify this patch? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks very much for the PR! Much appreciated.
A couple of minor things, otherwise looks great.
@@ -70,6 +70,24 @@ public int numOutputs() { | |||
} | |||
} | |||
|
|||
/** | |||
* Given the name of an output producing a tensor list, return the | |||
* size of the list. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd suggest elaborating a bit here since the notion of a "name" might not be obvious. How about something like:
/**
* Returns the size of the list of Tensors produced by this operation.
*
* <p/>An Operation has multiple named outputs, each of which produces either
* a single tensor or a list of tensors. This method returns the size of
* the list of tensors for a specific output of the operation.
*
* @param name identifier of the list of tensors (of which there may be many) produced by this operation.
* @returns the size of the list of Tensors produced by this operation.
* @throws ...
public void outputListLength() { | ||
try (Graph g = new Graph()) { | ||
|
||
checkSplit(g, "t1", new int[] {0, 1}, 1); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It might be better to have a helper function return the integer value and then the assert
call can be in the test, that way the line number rendered on test failures is more useful and you don't need to use the string names to locate the failure. Something like:
public void outputListLength() {
assertEquals(1, split(new int[]{0, 1}, 1));
assertEquals(2, split(new int[]{0, 1}, 2));
assertEquals(3, split(new int[]{0,1,2}, 3));
}
@Test
public void outputListLengthFailsOnInvalidName() {
...
}
private int split(int[] values, int num_split) {
try (Graph g = new Graph()) {
return g.opBuilder("Split", "Split")
.addInput(TestUtil.constant(g, "split_dim", 0))
.addInput(TestUtil.constant(g, "values", values)
.setAttr("num_split", num_split)
.build()
.outputListLength("output");
}
}
great suggestions @asimshankar - thanks, and I've incorporated them wholesale. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR!
@tensorflow-jenkins test this please |
…hAddMemcpy… Imported from GitHub PR openxla/xla#9640 …Node1D It seems that setting up the params structure is inconsistent across rocm versions so fall back to more stable 1D variant. Minor cleanup of "unwrapped" hip runtime calls. Should fix openxla/xla#8692 Copybara import of the project: -- 7de30f1c2fa05941dd2fb59d70e43466145a1ad6 by Dragan Mladjenovic <Dragan.Mladjenovic@amd.com>: [ROCm] Replace usages of hipGraphAddMemcpyNode with hipGraphAddMemcpyNode1D It seems that setting up the params structure is inconsistent across rocm versions so fall back to more stable 1D variant. Minor cleanup of "unwrapped" hip runtime calls. Merging this change closes #9640 PiperOrigin-RevId: 609124252
This opens up access to the TF_OperationOutputListLength C API from java, for operations that return output lists.