
Fix wrong accuracy fetches when single_session is used. #469

Merged
merged 1 commit into tensorflow:master on May 19, 2020

Conversation

@gyeongchan-yun (Contributor) commented May 14, 2020

Hello,

I'm wondering whether the accuracy op is correct when 'distributed_all_reduce' is used.
To support my doubt, the log is below: the top-5 accuracy is over 1.0!

I ran it on 4 servers with 4 GPUs each (16 GPUs total) with the following configuration.
Model: resnet56
Dataset: cifar10 (synthetic)
Mode: BenchmarkMode.TRAIN
SingleSess: True
Variables: distributed_all_reduce
AllReduce: xring

Step Img/sec total_loss top_1_accuracy top_5_accuracy
1 images/sec: 5.6 +/- 0.0 (jitter = 0.0) 3.286 0.445 1.906
10 images/sec: 52.6 +/- 81.8 (jitter = 29.5) 2.580 0.484 1.945
20 images/sec: 99.3 +/- 42.3 (jitter = 29.1) 2.522 0.500 2.164
30 images/sec: 141.0 +/- 28.7 (jitter = 19.9) 2.519 0.398 1.852
40 images/sec: 178.4 +/- 21.7 (jitter = 17.6) 2.511 0.391 1.992
50 images/sec: 212.3 +/- 17.7 (jitter = 24.9) 2.507 0.375 2.047
60 images/sec: 242.8 +/- 14.8 (jitter = 25.6) 2.509 0.312 1.891
70 images/sec: 270.8 +/- 12.7 (jitter = 27.8) 2.503 0.391 2.094
80 images/sec: 297.0 +/- 11.4 (jitter = 29.8) 2.502 0.414 2.008
90 images/sec: 321.3 +/- 10.4 (jitter = 34.8) 2.503 0.391 1.914
100 images/sec: 343.0 +/- 9.5 (jitter = 38.9) 2.505 0.391 1.883
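A top-5 accuracy can never exceed 1.0, which suggests the correct counts are being summed across every worker while the divisor covers only one worker's batch. A minimal sketch of that failure mode (hypothetical numbers, not taken from this run):

```python
# Hypothetical illustration: 4 workers, per-worker batch of 64 images.
num_workers = 4
batch_size = 64  # per-worker batch size used as the divisor

# Suppose each worker classifies roughly 60 of its 64 images
# correctly in top-5.
per_worker_top5_correct = [60, 58, 61, 59]

# Summing every worker's count but dividing by one worker's batch
# inflates the "accuracy" by a factor of num_workers:
inflated = sum(per_worker_top5_correct) / batch_size
print(inflated)  # 3.71875 -- impossible as an accuracy

# Dividing by the total number of images across workers fixes it:
correct = sum(per_worker_top5_correct) / (batch_size * num_workers)
print(correct)  # 0.9296875
```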

In benchmark_cnn.py,

  def _build_fetches(self, global_step, all_logits, losses, device_grads,
                     enqueue_ops, update_ops, all_accuracy_ops, phase_train):
    """Complete construction of model graph, populating the fetches map."""
    fetches = {}
    if enqueue_ops:
      fetches['enqueue_ops'] = enqueue_ops
    for name, ops in all_accuracy_ops.items():
      if name.startswith(constants.UNREDUCED_ACCURACY_OP_PREFIX):
        key = name[len(constants.UNREDUCED_ACCURACY_OP_PREFIX):]
        fetches[key] = tf.concat(ops, 0)
      else:
        # This line is the problem when single_session is used.
        fetches[name] = tf.reduce_sum(ops) / self.batch_size
        if self.task_index == 0 and self.params.summary_verbosity >= 1:
          tf.summary.scalar(name, fetches[name])

I think the following would compute the accuracy ops correctly when single_session is used:

fetches[name] = (tf.reduce_sum(ops) / self.batch_size * 
                 (self.num_workers  if self.single_session else 1))

Best regards,
Gyeongchan Yun

@googlebot

Thanks for your pull request. It looks like this may be your first contribution to a Google open source project (if not, look below for help). Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

📝 Please visit https://cla.developers.google.com/ to sign.

Once you've signed (or fixed any issues), please reply here with @googlebot I signed it! and we'll verify it.



@gyeongchan-yun (Contributor Author)

@googlebot I signed it!

@googlebot

CLAs look good, thanks!


@@ -2903,7 +2903,8 @@ def _build_fetches(self, global_step, all_logits, losses, device_grads,
         key = name[len(constants.UNREDUCED_ACCURACY_OP_PREFIX):]
         fetches[key] = tf.concat(ops, 0)
       else:
-        fetches[name] = tf.reduce_sum(ops) / self.batch_size
+        fetches[name] = (tf.reduce_sum(ops) / self.batch_size *
+                         (self.num_workers if self.single_session else 1))
Member

I think you want to divide by num_workers here, not multiply. Or alternatively, put parentheses around self.batch_size * (self.num_workers if self.single_session else 1)
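To make the precedence point concrete, here is a tiny sketch (hypothetical values): `a / b * c` evaluates left to right and multiplies, while `a / (b * c)` divides by both factors.

```python
ops_sum = 8.0      # hypothetical summed accuracy count across workers
batch_size = 4.0
num_workers = 2

# As written in the diff: the division happens first, then the
# multiplication, which inflates the result even further.
multiplied = ops_sum / batch_size * num_workers   # (8 / 4) * 2

# With parentheses around the full divisor, num_workers divides instead.
divided = ops_sum / (batch_size * num_workers)    # 8 / (4 * 2)

print(multiplied, divided)  # 4.0 1.0
```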

Contributor Author

You're right. I want to divide by num_workers when single_session is used. Your alternative suggestion is right. 👍

Contributor Author

I've updated it.

Best regards,
Gyeongchan Yun

@reedwm (Member) left a comment
You are failing some tests, like testReplicated. To reproduce, run:

python -m unittest -v benchmark_cnn_test.TfCnnBenchmarksTest.testReplicated

The issue is self.single_session is not set during evaluation. Probably the best way to fix this would be to set self.single_session in the constructor:

self.single_session = params.variable_update == 'distributed_all_reduce'

Then, change the if-statement here to check if self.single_session is True instead, and no longer set self.single_session in the _build_graph function.
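A minimal, self-contained sketch of that suggestion, using `AccuracyFetches` as a hypothetical stand-in for the relevant `BenchmarkCNN` logic (plain Python in place of TensorFlow ops):

```python
class AccuracyFetches:
  """Hypothetical stand-in for the accuracy-fetch logic in BenchmarkCNN."""

  def __init__(self, variable_update, batch_size, num_workers):
    # Set in the constructor so the attribute exists during both
    # training and evaluation graph construction (the reviewer's point:
    # _build_graph is not the place to set it).
    self.single_session = variable_update == 'distributed_all_reduce'
    self.batch_size = batch_size
    self.num_workers = num_workers

  def accuracy(self, per_device_correct):
    # Parenthesized divisor: total images across all workers when a
    # single session drives them all, otherwise one worker's batch.
    divisor = self.batch_size * (
        self.num_workers if self.single_session else 1)
    return sum(per_device_correct) / divisor
```

Setting the flag once in the constructor means evaluation no longer hits an undefined attribute, which is what made tests like testReplicated fail.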

@gyeongchan-yun (Contributor Author)


Thanks for your kind suggestion!
I've updated it and confirmed that the tests pass.

Best Regards,
Gyeongchan Yun

@reedwm reedwm merged commit baa105b into tensorflow:master May 19, 2020