[Java] Support addition of gradient operations in a graph #20133

Merged (4 commits, Jun 28, 2018)

Conversation

karllessard (Contributor)

This calls the C API TF_AddGradients function through a new JNI binding for adding gradient nodes to a graph. It also includes an AddGradients wrapper for invoking this operation smoothly while building a graph with the new Java Ops API.
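A minimal usage sketch of the new Graph-level entry point (op names and values are illustrative only; passing null for dx falls back to OnesLike, as documented in the javadoc below):

import org.tensorflow.Graph;
import org.tensorflow.Output;
import org.tensorflow.Tensor;
import org.tensorflow.Tensors;

public final class GradientsSketch {
  public static void main(String[] args) {
    try (Graph g = new Graph();
         Tensor<Float> c = Tensors.create(3.0f)) {
      // x = const(3.0f); y = x^2
      Output<Float> x = g.opBuilder("Const", "x")
          .setAttr("dtype", c.dataType())
          .setAttr("value", c)
          .build()
          .<Float>output(0);
      Output<Float> y = g.opBuilder("Square", "y")
          .addInput(x)
          .build()
          .<Float>output(0);
      // Adds the gradient sub-graph computing dy/dx; dx == null means OnesLike.
      Output<?>[] grads = g.addGradients(new Output<?>[] {y}, new Output<?>[] {x}, null);
      System.out.println("gradient outputs added: " + grads.length);
    }
  }
}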

@qlzh727 qlzh727 requested a review from asimshankar June 20, 2018 03:24
@qlzh727 qlzh727 self-assigned this Jun 20, 2018
@qlzh727 qlzh727 added the awaiting review label Jun 20, 2018
Contributor

@asimshankar asimshankar left a comment

Thanks a bunch!

* }</pre>
*/
@Operator
public class AddGradients implements Op, Iterable<Operand<?>> {
Contributor

Would it make sense for this to be called Gradients instead of AddGradients? (We don't call the other Op implementations AddConstant or AddMatMul)

* @param dx
* @return the partial derivatives {@code dy} with the size of {@code x}
*/
public Output<?>[] addGradients(Output<?>[] y, Output<?>[] x, Output<?>[] dx) {
Contributor

The common case would be a single y and null for dx, so should we add another method for that common case? Something like:

public Output<?>[] addGradients(Output<?> y, Output<?> ...x)

And similarly for the Op implementation.
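As a rough sketch of what that convenience overload might look like on the Graph side (assuming it simply delegates to the array-based method, with a null dx meaning OnesLike as documented below):

public Output<?>[] addGradients(Output<?> y, Output<?>... x) {
  return addGradients(new Output<?>[] {y}, x, null);
}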

Contributor Author

Sounds very good to me. Would it matter if we made that modification only in the Op class? I tend to see this kind of interface convenience as the responsibility of the Ops API layer, while the core classes focus more on implementation details.

I have also worked on another Op that adds gradient nodes to the graph and immediately applies the descent to the input tensors, instead of doing this manually in n steps. Should I go ahead with that one as well (perhaps in another PR)?

Contributor

I don't feel terribly strongly, but I have a mild preference for including the simple overload in the Graph class as well.

Is the second Op you mentioned basically the equivalent of Optimizer.minimize() in Python? We can talk about that in a follow-up but, as you suggested, not do it in this PR.

* If {@code dx} is null, the implementation will use dx of {@code OnesLike} for all
* shapes in {@code y}.
*
* @param y
Contributor

Fill in the documentation for these arguments?

long[] dxHandles = null;
int[] dxIndices = null;

for (int i = 0; i < y.length; ++i) {
Contributor

Should the body of this function be enclosed in a:

Reference ref = ref();
try {
  // The main body of this method
} finally {
  ref.close();
}

so that it remains thread-safe (i.e., a concurrent call to Graph.close() won't mess things up by rendering the elements of the *Handles arrays invalid)?

// e.g. given xHandles = [x0Handle, x1Handle, ...] and xIndices = [x0Index, x1Index, ..], we obtain
// dy = [dy0Handle, dy1Handle, ..., dy0Index, dy1Index, ...]
long[] dyHandlesAndIndices;
synchronized (nativeHandleLock) {
Contributor

If we do the ref() block as suggested above, then we don't need this but instead would use ref.nativeHandle() for the call to addGradients
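Putting the two suggestions together, a hedged sketch of what the Graph.addGradients body could look like; the Reference/ref() mechanism, the package-private Output.index() and Operation handle accessors, and the exact native addGradients signature are assumptions for illustration, not quoted from the PR:

public Output<?>[] addGradients(Output<?>[] y, Output<?>[] x, Output<?>[] dx) {
  Output<?>[] dy = new Output<?>[x.length];
  Reference ref = ref();
  try {
    long[] yHandles = new long[y.length];
    int[] yIndices = new int[y.length];
    long[] xHandles = new long[x.length];
    int[] xIndices = new int[x.length];
    long[] dxHandles = null;
    int[] dxIndices = null;
    for (int i = 0; i < y.length; ++i) {
      yHandles[i] = y[i].op().getUnsafeNativeHandle();
      yIndices[i] = y[i].index();
    }
    for (int i = 0; i < x.length; ++i) {
      xHandles[i] = x[i].op().getUnsafeNativeHandle();
      xIndices[i] = x[i].index();
    }
    if (dx != null) {
      dxHandles = new long[dx.length];
      dxIndices = new int[dx.length];
      for (int i = 0; i < dx.length; ++i) {
        dxHandles[i] = dx[i].op().getUnsafeNativeHandle();
        dxIndices[i] = dx[i].index();
      }
    }
    // Native call returns [dy0Handle, dy1Handle, ..., dy0Index, dy1Index, ...].
    long[] dyHandlesAndIndices = addGradients(
        ref.nativeHandle(), yHandles, yIndices, xHandles, xIndices, dxHandles, dxIndices);
    // Unpack the packed array into Output objects.
    int ndy = dyHandlesAndIndices.length >> 1;
    for (int i = 0, j = ndy; i < ndy; ++i, ++j) {
      Operation op = new Operation(this, dyHandlesAndIndices[i]);
      dy[i] = new Output<>(op, (int) dyHandlesAndIndices[j]);
    }
  } finally {
    ref.close();
  }
  return dy;
}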

@Test
public void addGradientsComputationOpsToGraph() {
try (Graph g = new Graph()) {
Output<Integer> a = TestUtil.constant(g, "A", new int[][] {{1},{2}});
Contributor

Let's use Float instead of Integer, as integer gradients are a bit iffy (in fact, we don't backprop through integer tensors in Python anymore - f637506; yes, yes, I know, it would be great if we could share that logic more easily :). But for now, let's at least have the example use a more realistic Float?
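For example, the test constants could be switched to floats along these lines (a sketch; it assumes the generic TestUtil.constant and TestUtil.matmul helpers used elsewhere in this test accept float arrays):

Output<Float> a = TestUtil.constant(g, "A", new float[][] {{1.0f}, {2.0f}});
Output<Float> b = TestUtil.constant(g, "B", new float[][] {{3.0f, 4.0f}});
Output<Float> ab = TestUtil.matmul(g, "AxB", a, b, false, false);
Output<?>[] grad = g.addGradients(new Output<?>[] {ab}, new Output<?>[] {a, b}, null);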

Output<Integer> ab = TestUtil.matmul(g, "AxB", a, b, false, false);
Output<Integer> abc = TestUtil.matmul(g, "AxBxC", ab, c, false, false);

Output<?>[] grad = g.addGradients(new Output<?>[] {abc}, new Output<?>[] {b, c}, null);
Contributor

Can we improve test coverage here, perhaps by adding to SessionTest.java, so that we also test the additional arguments (like dys)?

Contributor

@asimshankar asimshankar left a comment

Thanks for the update, apologies for the delay. A few more minor things, otherwise looks great!

.fetch(grads[0])
.fetch(grads[1])
.run();

Contributor

To encourage best practices, should we also call .close() on the elements of outputs?
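For instance (a sketch; outputs is the List<Tensor<?>> returned by run() in this test):

// Release the native resources of every fetched tensor once assertions are done.
for (Tensor<?> t : outputs) {
  t.close();
}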


Output<?>[] grads = g.addGradients(toArray(y0, y1), toArray(x), null);

List<Tensor<?>> outputs = s.runner()
Contributor

Same here:

try (Tensor<?> t = s.runner()....get(0)) {
  assertEquals(114.0f, t.floatValue(), 0.0f);
}

?

* @param dx if not null, the partial derivatives of some loss function {@code L} w.r.t. {@code y}
* @return the partial derivatives {@code dy} with the size of {@code x}
*/
public Output<?>[] addGradients(Output<?>[] y, Output<?>[] x, Output<?>[] dx) {
Contributor

Is it really '?'? Don't all of the tensors have to have the same type? In which case, should this be:

public Output<T>[] addGradients(Output<T>[] y, Output<T>[] x, Output<T>[] dx)

?

Contributor Author

I wasn't sure about this, you probably know better than me: is it possible to have a graph with variables of different types? If it is guaranteed that could never happen, I'll gladly remove those wildcards to enforce type-safety.

Contributor

Hmm... it actually is, since one could have, say, tf.cast ops on the path between x and y; something like:

import tensorflow as tf

x = tf.placeholder(tf.float64)
y = tf.square(tf.cast(x, tf.float32))
dy = tf.gradients(y, x)[0]
print(x.dtype, y.dtype, dy.dtype)

So apologies, ignore my comment :)


Output<?>[] grads = g.addGradients(toArray(y), toArray(x), toArray(dx));

List<Tensor<?>> outputs = s.runner()
Contributor

Ditto about try-with-resources for the returned Tensor.

@@ -36,7 +36,7 @@
.<T>output(0);
}

public static Output<?> addN(Graph g, Output<?>... inputs) {
public static <T> Output<T> addN(Graph g, Output<?>... inputs) {
Contributor

Should the argument be Output<T> instead of Output<?> also?

Contributor Author

Unfortunately, the compiler complains with a warning when using a parameterized type for a varargs parameter.
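For reference, this is the standard heap-pollution warning for generic varargs; a small standalone illustration (class and method names are made up for the example):

import java.util.Arrays;
import java.util.List;

public final class VarargsWarningDemo {
  // Without @SafeVarargs, javac reports:
  //   warning: [unchecked] Possible heap pollution from parameterized vararg type T
  // @SafeVarargs is only allowed on static, final, or private methods, which is
  // one reason keeping Output<?>... (a reifiable type) in TestUtil avoids the warning.
  @SafeVarargs
  static <T> List<T> asList(T... items) {
    return Arrays.asList(items);
  }

  public static void main(String[] args) {
    System.out.println(asList("a", "b", "c"));
  }
}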

// returned array contains both op handles and output indices, in pair
jlongArray dy_handles_and_indices = env->NewLongArray(nx << 1);
jlong* dy_elems = env->GetLongArrayElements(dy_handles_and_indices, nullptr);
for (int i = 0, j = nx; i < nx; ++i, ++j) {
Contributor

In practice this probably won't matter at all since dy_elems will be a pretty small array, but did you consider encoding it as [handle0, index0, handle1, index1, ...] instead of [handle0, handle1, ..., index0, index1, ...]? The former would be more cache friendly :)

Contributor Author

Ironically, I was using the format you proposed in my previous commit. I switched to the new one recently because:

  • it could become the 'standard' for returning more than one value from a JNI binding; e.g. if we had 4 arrays to return, each of a different length, the previous format wouldn't work while this one would
  • we might later find an optimization that simply splits the array in two instead of copying its elements one by one
  • I personally found it more elegant to iterate with the two indices at the same time :)

I don't mind reverting to the previous version if you want.

Contributor

It's an internal detail, so we can easily switch it around as we go along. I don't feel very strongly, but am slightly partial to the previous version. Your call.

Contributor Author

OK, just for the sake of keeping this PR moving smoothly, let's keep it that way for now.

@@ -0,0 +1,153 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Contributor

Why training/? Seems like gradients could be a top level operation too? One can use gradients for things other than training :)

Contributor Author

No problem. Right now the Op classification is still obscure to me; I'm taking some guesses here and there, but we will probably need to settle this at some point?

Contributor

BTW, the Python code was recently updated to generate symbols in finer-grained namespaces (c1ff116#diff-d75a97d0f69d6f87ab79be6e2423d87b), so perhaps we could use the same for Java, without the baggage of backward compatibility (for example, Acos could live just in the math namespace and not in the top-level one).

If you're enthusiastic, I'd be more than happy to see a PR adding Java API defs :)

@asimshankar asimshankar added the awaiting testing (then merge) and kokoro:force-run labels and removed the awaiting review label Jun 28, 2018
The kokoro:force-run label was then re-applied and cleared several times by @qlzh727, @asimshankar, and @kokoro-team on Jun 28, 2018 to re-trigger testing.
@qlzh727 (Member) commented Jun 28, 2018

The Windows build failure seems to be flaky. Submitting now.

@qlzh727 qlzh727 merged commit 9752b11 into tensorflow:master Jun 28, 2018