
Conversation

@duncanriach duncanriach commented May 14, 2021

This PR adds and tests deterministic forward and backward operation of tf.nn.softmax_cross_entropy_with_logits when running on a GPU.

Note that there are changes and enhancements to the existing tests that may be obscured by the restructuring of the test files.

Thanks to @reedwm for providing support and guidance on this PR, including looking into the arithmetic equivalence of the forward and backward operation of the python-level solution.

Note that a naive implementation of softmax followed by cross-entropy is not as numerically stable as the version implemented here (and in the existing Eigen-based/C++-level implementation), in which the log in the cross-entropy function is moved back into the softmax, turning it into a log-softmax and turning the cross-entropy into a dot product. log-softmax does not demand as large a dynamic range as softmax.
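For illustration, here is a minimal NumPy sketch of the two formulations (my own sketch, not code from this PR); the log-softmax form shifts by the max and stays in log space, so it avoids the overflow that the naive softmax-then-log form hits for large logits:

import numpy as np

def naive_xent(logits, labels):
  # Naive formulation: softmax followed by cross-entropy; exp() overflows
  # for large logits, poisoning the result.
  softmax = np.exp(logits) / np.sum(np.exp(logits), axis=-1, keepdims=True)
  return -np.sum(labels * np.log(softmax), axis=-1)

def stable_xent(logits, labels):
  # The log is folded into the softmax (log-softmax), and the cross-entropy
  # becomes a dot product with the labels.
  shifted = logits - np.max(logits, axis=-1, keepdims=True)
  log_softmax = shifted - np.log(np.sum(np.exp(shifted), axis=-1, keepdims=True))
  return -np.sum(labels * log_softmax, axis=-1)

logits = np.array([[1000., 0., -1000.]])
labels = np.array([[0., 1., 0.]])
print(naive_xent(logits, labels))   # [nan] (exp overflow)
print(stable_xent(logits, labels))  # [1000.]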

Note that the following tests do not pass on this deterministic implementation (and have been disabled):

  • Backprop to logits when there is only a single class (the forward path passes). See testSingleClass.
  • Backprop to logits when labels are broadcast (the forward path passes). See testLabelsBroadcast.

I have not yet been able to determine the reason for this, and I don't know if it's because the existing functionality is incorrect or if the new, deterministic functionality is incorrect. For the single class case, for example, it seems to me that the correct gradients should all be zero (which is what the new, deterministic implementation provides). It seems as though the above two use cases (single class and broadcast labels) would rarely be used; it's not obvious to me what the applications of these use cases would be, and these functionalities are also not documented. I have added TODO comments for me to look into this more deeply. @reedwm, feel free to explore.

UPDATE: After further investigation, it turns out that the gradients differ between the nondeterministic and deterministic implementations only when the labels vector is not a valid probability distribution, as required (but not enforced) by the API. See this comment for more information.

This PR is related to RFC: Enabling Determinism in TensorFlow. For status and history of GPU-determinism for this op, see here.

cc @sanjoy @nluehr

@google-ml-butler google-ml-butler bot added the size:L (CL Change Size: Large) label May 14, 2021
@google-cla google-cla bot added the cla: yes label May 14, 2021
@duncanriach duncanriach force-pushed the softmax-xent-gpu-determinism branch from e6c4756 to abfd464 on May 14, 2021 04:56
@gbaned gbaned self-assigned this May 14, 2021
@gbaned gbaned requested a review from sanjoy May 14, 2021 06:54
"2-dimensional, or broadcasted to be "
"2-dimensional"));

if (std::is_same<Device, GPUDevice>::value) {
Contributor

Technically, this op is still nondeterministic. It's still likely used in practice with the non-Python APIs (e.g. with the C API or C++ API). So I would keep the error message, but add a comment stating the Python API does not use this op when determinism is enabled.

I don't think we ever have to bother fixing the op though. C++ users can always manually use the nonfused version as well.

Contributor Author

Good point. The next commit will put the exception-throwing back (with your suggested comment), as well as the tests (but running on the gen_nn_ops version).
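For context, a hedged sketch of the kind of Python-level routing being referred to (this is not the actual nn_ops code; the TF_DETERMINISTIC_OPS check and the fallback math follow the approach described in the PR description):

import os
import tensorflow as tf

def softmax_xent(labels, logits):
  if os.getenv('TF_DETERMINISTIC_OPS', '0') in ('1', 'true'):
    # Deterministic path: log-softmax plus a dot product with the labels,
    # i.e. the Python-level solution described above.
    log_probs = tf.nn.log_softmax(logits, axis=-1)
    return -tf.reduce_sum(labels * log_probs, axis=-1)
  # Otherwise call the fused kernel, which remains nondeterministic on GPU.
  return tf.raw_ops.SoftmaxCrossEntropyWithLogits(
      features=logits, labels=labels)[0]  # (loss, backprop) -> loss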

deterministic_ops = os.getenv('TF_DETERMINISTIC_OPS', '0')
return deterministic_ops == '1' or deterministic_ops == 'true'

def _coreOp(self, features, labels):
Contributor

I think instead of having a _coreOp function, we should just test the nn_ops version instead of the gen_nn_ops version for most tests. This avoids the awkwardness of having to create a Python implementation that has the same interface as the C++ op in XentDeterministicTest.

For test methods that are specific to the gen_nn_ops version, like testRankTooLarge, you can define a XentOpTest class subclassing XentOpTestBase and put them there. That way, you don't have to override methods in XentDeterministicTest to do nothing.

Once we have a flag to enable determinism instead of the environment variable, we can have just a single test file and a testWithDeterminismOnAndOff test decorator, which will simplify things. But we're not quite there yet.

Contributor Author

> I think instead of having a _coreOp function, we should just test the nn_ops version instead of the gen_nn_ops version for most tests. This avoids the awkwardness of having to create a Python implementation that has the same interface as the C++ op in XentDeterministicTest.

Sounds good. I was trying to minimize changes and be done quicker. What you propose is a good technical debt minimization step.

> For test methods that are specific to the gen_nn_ops version, like testRankTooLarge, you can define a XentOpTest class subclassing XentOpTestBase and put them there. That way, you don't have to override methods in XentDeterministicTest to do nothing.

Yep. Sounds good.

> Once we have a flag to enable determinism instead of the environment variable, we can have just a single test file and a testWithDeterminismOnAndOff test decorator, which will simplify things. But we're not quite there yet.

Absolutely. The (temporary) environment variable implementation makes organization of test files and classes unusually cumbersome. Looking forward to the ensuing tidiness.
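For reference, a rough sketch of the test-class split being discussed (only XentOpTestBase, XentOpTest, XentDeterministicTest, and testRankTooLarge are names from the discussion; the method placement and bodies are illustrative assumptions):

from tensorflow.python.platform import test


class XentOpTestBase(test.TestCase):
  """Shared tests that exercise the op through the public nn_ops wrapper."""

  def testLabelsBroadcast(self):
    ...  # would call nn_ops.softmax_cross_entropy_with_logits


class XentOpTest(XentOpTestBase):
  """Tests specific to the raw gen_nn_ops op."""

  def testRankTooLarge(self):
    ...  # would call gen_nn_ops.softmax_cross_entropy_with_logits


class XentDeterministicTest(XentOpTestBase):
  """Re-runs the shared tests with TF_DETERMINISTIC_OPS set, overriding only
  the cases the deterministic path does not yet support."""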

def _testXentWrapper(self, np_features, np_labels, dim=-1, use_gpu=False):
np_loss, _ = self._npXent(np_features, np_labels, dim=dim)
with self.cached_session(), test_util.device(use_gpu):
# Even in eager mode, the above line will be able to pin ops to CPU.
Contributor

What does this mean? What is wrong with passing use_gpu to cached_session, as was done before?

Contributor Author
@duncanriach duncanriach May 17, 2021

Maybe I'm mistaken, but from my experience of writing, running, and checking tests (and from looking at the test_util code), my understanding is that if you pass use_gpu=False into self.session() or self.cached_session() and run it on a GPU in eager mode, the ops will not be pinned to the CPU. This is because self.session() and self.cached_session() are only relevant in graph mode (they do nothing in eager mode). When we write test code that can be run in either graph or eager mode, there needs to be both a self.session() or self.cached_session() context and a separate, nested test_util.device() context.

The existing code is currently only running in eager mode, and, the way it's written, each case will be run twice on GPU if there is a GPU present (this is clearly not the intention of the test writer). The change in the PR makes it run each case on both CPU and GPU if there is a GPU present, but does not make it less compatible with graph mode (in case we want to run this test in graph mode).

Contributor Author

The next commit will remove the self.cached_session() and the comment on the next line. None of that is needed.

self.assertAllEqual(result_a, result_b)

def _testBackward(self, labels_not_logits=False):
for use_gpu in [False, True]:
Contributor

Our CI runs tests both with and without a GPU, so explicitly running on the CPU is not necessary.

Contributor Author
@duncanriach duncanriach May 17, 2021

I'm aware that all the tests get run on machines with and without GPUs and that, therefore, the tests can be written so that each case will run on a GPU if a GPU is present and on a CPU if a GPU is not present. The reason I write tests like this, and (I presume) why others do (including the existing op test), is that it simplifies the development process: I can develop for both CPU and GPU on one machine with one build-test container.

Should I change my development practice to open both a GPU container and a non-GPU container on my machine, with the source code mounted into both, and then execute the build/run in both containers in parallel? The development procedure is already very cumbersome and slow; this would double the amount of build/re-build time and things to keep track of.

Contributor Author

The next commit will remove this loop. It makes sense to remove it after development, when it is no longer needed.

Contributor Author

The next commit also removes all the extra cycles that were in the pre-existing test code to run test cases on both the CPU and GPU on a machine with a GPU (and which ran the same cases twice on machines without GPUs).

Contributor

You can run with CUDA_VISIBLE_DEVICES= (setting the env var to an empty string) if you want to test without a GPU. Admittedly, this is a pain since you cannot run both with and without the GPU in a single bazel test command.

/CC @sanjoy: would it be feasible to have the non-GPU versions of the tests run with CUDA_VISIBLE_DEVICES= automatically?

labels, logits = self._generateInputs(dtype, seed=456)
output_shape = labels.shape[0]

def gradients(seed=789):
Contributor

No reason for seed to have a default value here

Contributor Author
@duncanriach duncanriach May 17, 2021

Agreed. It will be removed by the next commit.

seed = 456 + trial
labels_grad_a, logits_grad_a = gradients(seed=seed)
labels_grad_b, logits_grad_b = gradients(seed=seed)
if labels_not_logits:
Contributor

Instead of having this parameter, I would just assert both labels and logits.

Contributor Author

Right. That's how this test started out. It evolved so that each op-path is covered independently (with code re-use but less efficient use of compute cycles). For one thing, this made it easy for me to confirm (without modification to the tests) that the tests will catch nondeterminism in each path independently. With these tests, for example, I disabled the determinism solution and confirmed that both testBackwardLogits and testBackwardLabels failed. I believe that tests quit on the first failed assert, which makes it impossible to fully validate their intention without modifying them (albeit just by adding or removing a #).

I'm considering simplifying this.

Contributor Author

The next commit will remove this complexity.

Contributor

If you want to test multiple things independently (allowing the second to run even if the first fails), consider using self.subTest. I don't think it's necessary in this case though.
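For reference, a minimal self-contained sketch of that pattern using plain unittest (which tf.test.TestCase builds on); the class name and data here are dummies:

import unittest


class GradientDeterminismTest(unittest.TestCase):

  def test_backward(self):
    grads_a = {'labels': [0.1, 0.2], 'logits': [0.3, 0.4]}
    grads_b = {'labels': [0.1, 0.2], 'logits': [0.3, 0.4]}
    for name in ('labels', 'logits'):
      # Each subTest is reported separately, so a 'labels' failure does not
      # prevent the 'logits' check from running.
      with self.subTest(tensor=name):
        self.assertEqual(grads_a[name], grads_b[name])


if __name__ == '__main__':
  unittest.main()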

@duncanriach duncanriach requested a review from reedwm May 19, 2021 00:46
@gbaned gbaned requested review from reedwm and removed request for reedwm May 20, 2021 13:53
@sanjoy sanjoy removed their request for review May 26, 2021 06:36
class XentOpDeterminismExceptionsTest(test.TestCase):
"""Test d9m-unimplemented exceptions from SoftmaxCrossEntropyWithLogits.
Test that tf.errors.UnimplementedError is thrown or not thrown, as
Contributor

This issue isn't introduced by this PR, but it seems this test only tests that UnimplementedError is thrown. It does not test that UnimplementedError is not thrown, contrary to the comment "...or not thrown...".

Contributor Author

Nice catch. This comment will be adjusted by the next commit.


def _opDeterminismEnabled(self):
deterministic_ops = os.getenv('TF_DETERMINISTIC_OPS', '0')
return deterministic_ops == '1' or deterministic_ops == 'true'
Contributor

Change to deterministic_ops in ('1', 'true')

Contributor Author

Lovely simplification. Thank you. This will be included in the next commit.
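A sketch of the resulting helper (shown module-level here, without self, just to keep the snippet self-contained):

import os

def _opDeterminismEnabled():
  # TF_DETERMINISTIC_OPS is the (temporary) environment variable; a proper
  # API flag is expected to replace it later.
  return os.getenv('TF_DETERMINISTIC_OPS', '0') in ('1', 'true')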

@test_util.run_in_graph_and_eager_modes
def testForward(self):
with self.session(), test_util.force_cpu():
with self.cached_session(), test_util.device(use_gpu=True):
Contributor

I think test_util.device is unnecessary, since the test will use the GPU by default if available.

Contributor Author

I agree, and I've thought about this. use_gpu=True explicitly pins the ops to the GPU, if there is a GPU. I understand that the expected, default behavior is for the ops to auto-pin to the GPU (if there is one), but I've wondered if we should rely on that. Considering again ...

Contributor Author

The next commit will remove almost all instances of use_gpu=True, in xent_op_test_base.py, xent_op_test.py, and xent_op_deterministic_test.py.

TODO(duncanriach): Identify the source of the difference in gradients for
this case.
"""
self._testSingleClass(test_backprop=False)
Contributor

How incorrect are the gradients here and in the next test? I am slightly worried there is a correctness problem, although the single class case seems fairly unimportant.

Contributor Author

Agreed, it is concerning, and the gradients are very different (if I remember correctly). I was already intending to change these tests so that they do check expected values for both deterministic and nondeterministic cases. The next commit will include that, so we'll have documentation in the test of the difference between the gradients for deterministic vs nondeterministic.

Contributor Author
@duncanriach duncanriach May 28, 2021

After further investigation, it turns out that the gradients differ between the nondeterministic and deterministic implementations only when the labels vector is not a valid probability distribution, as required (but not enforced) by the API. Below are the docstring notes from xent_op_deterministic_test.py in its current state (after the most recent commit).

testSingleClass:

The most recent commit also adds the fourth minibatch item: labels=[1.], logits=[1.].

The deterministic implementation does not produce the gradients expected by
the original test (for the nondeterministic functionality) when the labels
vector is not a valid probability distribution.

labels: [[-1.], [0.], [1.], [1.]]
logits: [[1.], [-1.], [0.], [1.]]

               nondeterministic               deterministic
dloss/dlogits: [[2.0], [1.0], [0.0], [0.0]]   [[0.0], [0.0], [0.0], [0.0]]

Note that only the second two label vectors are valid probability
distributions (as required by the API) and that the gradient matches for
those cases.

testLabelsBroadcast:

The most recent commit also adds the third minibatch item: labels=[0.25], logits=[1., 2., 3., 4.].

The deterministic implementation does not produce the gradients expected by
the original test (for the nondeterministic functionality) when the labels
vector (after broadcasting) is not a valid probability distribution.

labels: [[0.], [2.], [0.25]]
logits: [[1., 1., 1., 1.],
         [1., 2., 3., 4.],
         [1., 2., 3., 4.]]

dloss/dlogits (nondeterministic):
    [[ 0.25 ,  0.25 ,  0.25 ,  0.25 ],
     [-1.968, -1.913, -1.763, -1.355],
     [-0.218, -0.163, -0.013,  0.394]]

dloss/dlogits (deterministic):
    [[ 0.   ,  0.   ,  0.   ,  0.   ],
     [-1.743, -1.303, -0.105,  3.150],
     [-0.218, -0.163, -0.013,  0.394]]

Note that neither of the first two broadcast label vectors is a valid
probability distribution (as required by the API) and that these are the
cases that yield different gradients for nondeterministic vs deterministic
implementations.
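The pattern above is consistent with the two backprop formulas in play. A NumPy sketch (my own reconstruction, not code from this PR), assuming the fused kernel backprops softmax(logits) - labels while differentiating the log-softmax formulation gives softmax(logits) * sum(labels) - labels; it reproduces the second broadcast row above (up to rounding), and the two expressions coincide whenever the labels sum to 1:

import numpy as np

def softmax(x):
  e = np.exp(x - np.max(x, axis=-1, keepdims=True))
  return e / np.sum(e, axis=-1, keepdims=True)

logits = np.array([[1., 2., 3., 4.]])
labels = np.full((1, 4), 2.0)  # the broadcast label row [2.]; sums to 8, not 1

s = softmax(logits)
grad_fused = s - labels  # assumed nondeterministic (fused-kernel) formula
grad_logsoftmax = s * np.sum(labels, axis=-1, keepdims=True) - labels  # deterministic formula
print(np.round(grad_fused, 3))       # approx. [[-1.968 -1.913 -1.763 -1.356]]
print(np.round(grad_logsoftmax, 3))  # approx. [[-1.743 -1.303 -0.105  3.151]]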

Contributor Author

I propose that we remove these illegal label cases from the tests, replacing them with legal cases, thereby making the tests run the same way on both the original implementation and the deterministic implementation.

Contributor

I am slightly worried some people might accidentally rely on the current behavior when given illegal label cases. Also in practice, the labels might not add up exactly to 1 due to rounding errors. I think we should keep it as is for now, with the TODO to investigate this. I'll later try to think if there is a valid interpretation of softmax-cross-entropy when the labels do not add to 1.

@duncanriach duncanriach requested a review from reedwm May 28, 2021 00:10
@duncanriach duncanriach force-pushed the softmax-xent-gpu-determinism branch from 5240ac7 to 05e70a4 on May 28, 2021 01:34
@duncanriach duncanriach force-pushed the softmax-xent-gpu-determinism branch from 05e70a4 to 8fb4285 on May 28, 2021 05:16
@google-ml-butler google-ml-butler bot added the kokoro:force-run and ready to pull labels Jun 2, 2021
@reedwm reedwm commented Jun 2, 2021

Sorry for long periods between reviews, I had taken some time off.

@kokoro-team kokoro-team removed the kokoro:force-run label Jun 2, 2021
@duncanriach
Contributor Author

> Sorry for long periods between reviews, I had taken some time off.

No problem, @reedwm. I hope you had a relaxing and enjoyable break.

@google-ml-butler google-ml-butler bot removed the ready to pull label Jun 2, 2021
@duncanriach duncanriach requested a review from reedwm June 2, 2021 23:01
@google-ml-butler google-ml-butler bot added the kokoro:force-run and ready to pull labels Jun 2, 2021
@kokoro-team kokoro-team removed the kokoro:force-run label Jun 2, 2021
@pkanwar23 pkanwar23 added and removed the ready to pull label Jun 4, 2021
@copybara-service copybara-service bot merged commit 8c2b12a into tensorflow:master Jun 4, 2021
@duncanriach duncanriach deleted the softmax-xent-gpu-determinism branch June 7, 2021 22:29