[Java] Support addition of gradient operations in a graph #20133
Conversation
Thanks a bunch!
 * }</pre>
 */
@Operator
public class AddGradients implements Op, Iterable<Operand<?>> {
Would it make sense for this to be called Gradients instead of AddGradients? (We don't call the other Op implementations AddConstant or AddMatMul.)
 * @param dx
 * @return the partial derivatives {@code dy} with the size of {@code x}
 */
public Output<?>[] addGradients(Output<?>[] y, Output<?>[] x, Output<?>[] dx) {
The common case would be a single y and null for dx, so should we add another method for that common case? Something like:
public Output<?>[] addGradients(Output<?> y, Output<?>... x)
And similarly for the Op implementation.
Sounds very good to me. Would it matter if we do that modification only in the Op class? I see this kind of interface optimization as the responsibility of the Ops API layer, while the core classes focus more on implementation details.
I also worked on another Op that adds gradient nodes to the graph and immediately applies the descent on the input tensors, instead of doing this manually in n steps; should I also go ahead with that one (perhaps in another PR)?
I don't feel terribly strongly, but have a mild preference to include the simple override in the Graph class as well.
Is the second Op you mentioned basically the equivalent of Optimizer.minimize() in Python? We can talk about that in a follow-up, but as you suggested, not in this PR.
 * If {@code dx} is null, the implementation will use dx of {@code OnesLike} for all
 * shapes in {@code y}.
 *
 * @param y
Fill in the documentation for these arguments?
long[] dxHandles = null;
int[] dxIndices = null;

for (int i = 0; i < y.length; ++i) {
Should the body of this function be enclosed in a:
Reference ref = ref();
try {
  // The main body of this method
} finally {
  ref.close();
}
so that it remains thread-safe (i.e., a concurrent call to Graph.close() won't mess things up by rendering the elements of the *Handles arrays invalid)?
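For illustration, a minimal, self-contained sketch of the reference-counting idea behind that pattern, using stand-in Graph and Reference types rather than the real TensorFlow classes: close() blocks until all outstanding references are released, so native handles used inside a ref() block stay valid for the duration of that block.

```java
public class RefGuardSketch {
  static class Graph {
    private final Object lock = new Object();
    private int refCount = 0;
    private boolean closed = false;

    // Holding a Reference keeps the graph (and its native handles) alive.
    class Reference implements AutoCloseable {
      Reference() { synchronized (lock) { refCount++; } }
      @Override public void close() {
        synchronized (lock) { refCount--; lock.notifyAll(); }
      }
    }

    Reference ref() { return new Reference(); }

    // close() waits until no references remain, so a concurrent close()
    // cannot invalidate handles while a ref() block is still running.
    void close() {
      synchronized (lock) {
        while (refCount > 0) {
          try { lock.wait(); } catch (InterruptedException e) { return; }
        }
        closed = true;
      }
    }

    boolean isClosed() { return closed; }
  }

  public static void main(String[] args) {
    Graph g = new Graph();
    try (Graph.Reference ref = g.ref()) {
      // the body of addGradients would use the native handle here
    }
    g.close();
    System.out.println(g.isClosed());
  }
}
```

The real Graph class is assumed to provide ref()/Reference with similar semantics; this stub only demonstrates the synchronization shape.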
// e.g. given xHandles = [x0Handle, x1Handle, ...] and xIndices = [x0Index, x1Index, ...], we obtain
// dy = [dy0Handle, dy1Handle, ..., dy0Index, dy1Index, ...]
long[] dyHandlesAndIndices;
synchronized (nativeHandleLock) {
If we do the ref() block as suggested above, then we don't need this, but would instead use ref.nativeHandle() for the call to addGradients.
@Test
public void addGradientsComputationOpsToGraph() {
  try (Graph g = new Graph()) {
    Output<Integer> a = TestUtil.constant(g, "A", new int[][] {{1},{2}});
Let's use Float instead of Integer, as integer gradients are a bit iffy (in fact, we don't backprop through integer tensors in Python anymore - f637506; yes, yes, I know, it would be great if we were able to share that logic more easily :). But for now, let's at least have the example use a more realistic Float?
Output<Integer> ab = TestUtil.matmul(g, "AxB", a, b, false, false);
Output<Integer> abc = TestUtil.matmul(g, "AxBxC", ab, c, false, false);

Output<?>[] grad = g.addGradients(new Output<?>[] {abc}, new Output<?>[] {b, c}, null);
Can we improve test coverage here, perhaps by adding to SessionTest.java, so that we also test the additional arguments (like dys)?
Thanks for the update, apologies for the delay. A few more minor things, otherwise looks great!
.fetch(grads[0])
.fetch(grads[1])
.run();
To encourage best practices, should we also call .close() on the elements of outputs?
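A small self-contained sketch of that cleanup pattern, using a stand-in AutoCloseable Tensor type to mirror org.tensorflow.Tensor (hypothetical helper name):

```java
import java.util.ArrayList;
import java.util.List;

public class CloseOutputsSketch {
  // Stand-in for org.tensorflow.Tensor, which is AutoCloseable and owns
  // native memory that should be released promptly.
  static class Tensor implements AutoCloseable {
    boolean closed = false;
    @Override public void close() { closed = true; }
  }

  // Close every element; keep going even if one close() fails.
  static void closeAll(List<Tensor> outputs) {
    for (Tensor t : outputs) {
      try { t.close(); } catch (RuntimeException ignored) { }
    }
  }

  public static void main(String[] args) {
    List<Tensor> outputs = new ArrayList<>();
    outputs.add(new Tensor());
    outputs.add(new Tensor());
    closeAll(outputs);
    System.out.println(outputs.get(0).closed && outputs.get(1).closed);
  }
}
```

In the test itself, wrapping the loop in a finally block (or using try-with-resources per tensor) gives the same guarantee.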
Output<?>[] grads = g.addGradients(toArray(y0, y1), toArray(x), null);

List<Tensor<?>> outputs = s.runner()
Same here:
try (Tensor<?> t = s.runner()....get(0)) {
assertEquals(114.0f, t.floatValue(), 0.0f);
}
?
 * @param dx if not null, the partial derivatives of some loss function {@code L} w.r.t. {@code y}
 * @return the partial derivatives {@code dy} with the size of {@code x}
 */
public Output<?>[] addGradients(Output<?>[] y, Output<?>[] x, Output<?>[] dx) {
Is it really ?? Don't all of the tensors have to have the same type? In which case, should this be:
public Output<T>[] addGradients(Output<T>[] y, Output<T>[] x, Output<T>[] dx)
?
I wasn't sure about this; you probably know better than me: is it possible to have a graph with variables of different types? If it is guaranteed that this can never happen, I'll gladly remove those wildcards to enforce type safety.
Hmm... it actually is, since one could have, say, tf.casts on the path between x and y, something like:
import tensorflow as tf
x = tf.placeholder(tf.float64)
y = tf.square(tf.cast(x, tf.float32))
dy = tf.gradients(y, x)[0]
print(x.dtype, y.dtype, dy.dtype)
So apologies, ignore my comment :)
Output<?>[] grads = g.addGradients(toArray(y), toArray(x), toArray(dx));

List<Tensor<?>> outputs = s.runner()
Ditto about try-with-resources for the returned Tensor.
@@ -36,7 +36,7 @@
    .<T>output(0);
}

public static Output<?> addN(Graph g, Output<?>... inputs) {
public static <T> Output<T> addN(Graph g, Output<?>... inputs) {
Should the argument be Output<T> instead of Output<?> also?
Unfortunately, the compiler complains with a warning when using a parameterized type for a varargs parameter.
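For context, a self-contained sketch of the issue: a parameterized varargs parameter triggers an "unchecked generic array creation" warning at call sites, and @SafeVarargs (valid on static, final, or private methods) suppresses it when the method doesn't misuse the implicit array. The stand-in Output type below is illustrative, not the real TensorFlow class.

```java
public class VarargsSketch {
  // Stand-in for org.tensorflow.Output<T>.
  static class Output<T> {
    final T value;
    Output(T value) { this.value = value; }
  }

  // Without @SafeVarargs, calls like first(a, b) would warn about
  // unchecked generic array creation for the Output<T>[] varargs array.
  @SafeVarargs
  static <T> Output<T> first(Output<T>... inputs) {
    return inputs[0];
  }

  public static void main(String[] args) {
    Output<Float> a = new Output<>(1.0f);
    Output<Float> b = new Output<>(2.0f);
    System.out.println(first(a, b).value);
  }
}
```

Keeping the parameter as Output<?>... sidesteps the warning at the cost of a looser signature, which is the trade-off made in the PR.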
// returned array contains both op handles and output indices, in pairs
jlongArray dy_handles_and_indices = env->NewLongArray(nx << 1);
jlong* dy_elems = env->GetLongArrayElements(dy_handles_and_indices, nullptr);
for (int i = 0, j = nx; i < nx; ++i, ++j) {
In practice this probably won't matter at all since dy_elems will be a pretty small array, but did you consider encoding it as [handle0, index0, handle1, index1, ...] instead of [handle0, handle1, ..., index0, index1, ...]? The former would be more cache friendly :)
Ironically, I was using the format you proposed in my previous commit. I switched to the new one lately because:
- it could become the 'standard' for returning more than one value from a JNI binding; e.g. if we had 4 arrays to return, each of a different length, the previous format wouldn't work while this one would
- we might find some optimization later to simply split the array in two instead of copying its elements one by one
- I personally found it more elegant to iterate using two iterators at the same time :)
I don't mind reverting to the previous version if you want.
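For reference, the two layouts under discussion can be sketched in plain Java (hypothetical helper names; the real code does this packing on the JNI side):

```java
import java.util.Arrays;

public class PairEncodingSketch {
  // Split-halves layout: [h0, h1, ..., i0, i1, ...] — generalizes to
  // returning several arrays of different lengths in one buffer.
  static long[] encodeSplit(long[] handles, long[] indices) {
    int n = handles.length;
    long[] out = new long[2 * n];
    System.arraycopy(handles, 0, out, 0, n);
    System.arraycopy(indices, 0, out, n, n);
    return out;
  }

  // Interleaved layout: [h0, i0, h1, i1, ...] — each pair is adjacent,
  // so decoding touches contiguous memory (the cache-friendliness point).
  static long[] encodeInterleaved(long[] handles, long[] indices) {
    int n = handles.length;
    long[] out = new long[2 * n];
    for (int i = 0; i < n; i++) {
      out[2 * i] = handles[i];
      out[2 * i + 1] = indices[i];
    }
    return out;
  }

  public static void main(String[] args) {
    long[] h = {100, 200};
    long[] idx = {0, 1};
    System.out.println(Arrays.toString(encodeSplit(h, idx)));
    System.out.println(Arrays.toString(encodeInterleaved(h, idx)));
  }
}
```

Either layout round-trips the same data; the difference is only in how a decoder walks the buffer.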
It's an internal detail, so we can easily switch it around as we go along. I don't feel very strongly, but am slightly partial to the previous version. Your call.
Ok, so just for the sake of keeping this PR moving smoothly, let's keep it that way for now.
@@ -0,0 +1,153 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Why training/? It seems like gradients could be a top-level operation too? One can use gradients for things other than training :)
No problem. Right now, the Op classification is still obscure to me; I'm taking some guesses here and there, but we will probably need to settle this at some point?
BTW, the Python code was recently updated to generate symbols in finer-grained namespaces (c1ff116#diff-d75a97d0f69d6f87ab79be6e2423d87b). So perhaps we could use the same for Java (without the baggage of backward compatibility; for example, Acos can be just in the math namespace and not in the top-level one).
If you're enthusiastic, I'd be more than happy to see a PR adding Java API defs :)
8965aed to b7baff7
Windows build failure seems to be flaky. Submitting now.
This calls the C API TF_AddGradients method through a new JNI binding for adding gradient nodes to a graph. It also includes an AddGradients wrapper for invoking this operation smoothly while building a graph using the new Java Ops API.