SyncBatchNormalization layer segfaults on multi-worker with NCCL #41113

Closed
MinasTyuru opened this issue Jul 5, 2020 · 20 comments
Labels: comp:dist-strat, comp:keras, TF 2.2, type:bug

Comments

@MinasTyuru

MinasTyuru commented Jul 5, 2020

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux Ubuntu 14.04 (in a Docker container, on an 18.04 host)
  • Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device: N/A
  • TensorFlow installed from (source or binary): Source
  • TensorFlow version (use command below): 2.2.0
  • Python version: 3.6.8
  • Bazel version (if compiling from source): 0.24.1-3.0
  • GCC/Compiler version (if compiling from source): 4.8.5
  • CUDA/cuDNN version: 10.0
  • GPU model and memory: Nvidia TITAN X 11GB

Describe the current behavior
When training a model that contains the tf.keras.layers.experimental.SyncBatchNormalization layer, using tf.distribute.experimental.MultiWorkerMirroredStrategy across multiple workers with tf.distribute.experimental.CollectiveCommunication.NCCL communication, the model trains for some time (e.g. several thousand steps) and then crashes with a segfault.

Describe the expected behavior
The model should train without segfaulting.

Standalone code to reproduce the issue
An example is below. Please note that this code must be run on multiple workers, and the TF_CONFIG environment variable must be set appropriately for your multi-worker configuration (a sketch of a possible TF_CONFIG value follows the script).

import tensorflow as tf
from tensorflow import keras

def get_dataset():
    x = tf.zeros([10], dtype=tf.float32)
    x = tf.data.Dataset.from_tensors(x)

    y = tf.constant([5])
    y = tf.data.Dataset.from_tensor_slices(y)

    dataset = tf.data.Dataset.zip((x, y))
    dataset = dataset.batch(1)
    dataset = dataset.repeat()
    return dataset

def main():
    # NOTE: You must set os.environ["TF_CONFIG"] as appropriate
    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(tf.distribute.experimental.CollectiveCommunication.NCCL)
    # assert strategy.num_replicas_in_sync == 2

    # Create dataset
    dataset = get_dataset()

    with strategy.scope():
        # Construct model
        model = keras.Sequential(
            layers=[
                tf.keras.layers.experimental.SyncBatchNormalization(),
                tf.keras.layers.Dense(1),
            ]
        )
        model.compile(
            optimizer=keras.optimizers.Adam(),
            loss=keras.losses.MeanSquaredError(),
        )

    model.fit(x=dataset, steps_per_epoch=10 ** 6, epochs=10 ** 3)


if __name__ == "__main__":
    main()
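
For reference, a minimal sketch of what the TF_CONFIG value might look like for a hypothetical two-worker run; the hostnames and ports are placeholders, and the task index must be changed on each worker:

import json
import os

# Hypothetical two-worker cluster; replace the addresses with your own.
os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {
        "worker": ["host1.example.com:12345", "host2.example.com:12345"]
    },
    # Use index 1 on the second worker.
    "task": {"type": "worker", "index": 0},
})

TF_CONFIG must be set before the strategy is constructed.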

This is reproducible across a wide array of contexts, for example a Keras model, an Estimator model, different GPU types, etc.

Other info / logs
I used gdb to inspect a coredump from the crashed process. The backtrace is:

#0  0x00007f1d68cef711 in tensorflow::NcclReducer::Run(std::function<void (tensorflow::Status const&)>) (this=this@entry=0x7f18cc011af0, 
    done=...) at external/org_tensorflow/tensorflow/core/kernels/collective_nccl_reducer.cc:185
#1  0x00007f1d71797ca6 in tensorflow::BaseCollectiveExecutor::<lambda()>::operator()(void) const (__closure=<optimized out>)
    at external/org_tensorflow/tensorflow/core/common_runtime/base_collective_executor.cc:276
#2  0x00007f1d71797efe in std::_Function_handler<void(), tensorflow::BaseCollectiveExecutor::ExecuteAsync(tensorflow::OpKernelContext*, const tensorflow::CollectiveParams&, const string&, tensorflow::StatusCallback)::<lambda()> >::_M_invoke(const std::_Any_data &) (__functor=...)
    at external/gcc_7_4/usr/include/c++/7/bits/std_function.h:316
#3  0x00007f1d669a0e08 in std::function<void ()>::operator()() const (this=this@entry=0x7f185e7fbe60)
    at external/gcc_7_4/usr/include/c++/7/bits/std_function.h:706
#4  0x00007f1d71cc4f44 in tensorflow::UnboundedWorkQueue::PooledThreadFunc (this=0x20758660)
    at external/org_tensorflow/tensorflow/core/platform/default/unbounded_work_queue.cc:99
#5  0x00007f1d71cc5004 in tensorflow::UnboundedWorkQueue::<lambda()>::operator() (__closure=<optimized out>)
    at external/org_tensorflow/tensorflow/core/platform/default/unbounded_work_queue.cc:68
#6  std::_Function_handler<void(), tensorflow::UnboundedWorkQueue::Schedule(tensorflow::UnboundedWorkQueue::WorkFunction)::<lambda()> >::_M_invoke(const std::_Any_data &) (__functor=...) at external/gcc_7_4/usr/include/c++/7/bits/std_function.h:316
#7  0x00007f1d669a0e08 in std::function<void ()>::operator()() const (this=<optimized out>)
    at external/gcc_7_4/usr/include/c++/7/bits/std_function.h:706
#8  0x00007f1d71d136dd in std::__invoke_impl<void, std::function<void ()>>(std::__invoke_other, std::function<void ()>&&) (__f=...)
    at external/gcc_7_4/usr/include/c++/7/bits/invoke.h:60
#9  std::__invoke<std::function<void ()>>(std::function<void ()>&&) (__fn=...) at external/gcc_7_4/usr/include/c++/7/bits/invoke.h:95
#10 std::thread::_Invoker<std::tuple<std::function<void ()> > >::_M_invoke<0ul>(std::_Index_tuple<0ul>) (this=<optimized out>)
    at external/gcc_7_4/usr/include/c++/7/thread:234
#11 std::thread::_Invoker<std::tuple<std::function<void ()> > >::operator()() (this=<optimized out>)
    at external/gcc_7_4/usr/include/c++/7/thread:243
#12 std::thread::_State_impl<std::thread::_Invoker<std::tuple<std::function<void ()> > > >::_M_run() (this=<optimized out>)
    at external/gcc_7_4/usr/include/c++/7/thread:186
#13 0x00007f1d2eb72ae0 in ?? () from /usr/lib/x86_64-linux-gnu/libstdc++.so.6
#14 0x00007f1da0c12184 in start_thread (arg=0x7f185e7fc700) at pthread_create.c:312
#15 0x00007f1d9fdef03d in clone () at ../sysdeps/unix/sysv/linux/x86_64/clone.S:111

Disassembling the function showed that this was the offending instruction:

   0x00007f1d68cef707 <+2651>:  jg     0x7f1d68cf0642 <tensorflow::NcclReducer::Run(std::function<void (tensorflow::Status const&)>)+6550>
   0x00007f1d68cef70d <+2657>:  mov    0x18(%rbx),%rax
=> 0x00007f1d68cef711 <+2661>:  mov    0x8(%rax),%rdi
   0x00007f1d68cef715 <+2665>:  mov    (%rdi),%rax
   0x00007f1d68cef718 <+2668>:  mov    0x20(%rbx),%rsi

Printing the registers shows rax is 0x0, so some pointer is null. It therefore looks like there is a null-pointer dereference at line 185 of collective_nccl_reducer.cc, which I believe is the line col_ctx_->col_exec->UnblockDependencies(*col_params_);. I don't have any idea why it would segfault there, however; the same call appears shortly above on line 176, so it's strange that it would segfault the second time.

Also, a log is attached; however, it is not very interesting, as the run just proceeds for a while and then segfaults.

@ravikyram
Contributor

ravikyram commented Jul 6, 2020

@MinasTyuru

I have tried in Colab with TF version 2.2 and I am seeing the error message ModuleNotFoundError: No module named 'tflight'. Could you share a Colab link or simple standalone code with supporting files to reproduce the issue in our environment? It helps us localize the issue faster. Thanks!

@ravikyram ravikyram added the stat:awaiting response Status - Awaiting response from author label Jul 6, 2020
@MinasTyuru
Author

@ravikyram Oops, sorry about that. I removed the line that was the problem and it should work now. Please keep in mind that the code must run on multiple workers to reproduce the issue. As far as I know, this is not possible on Google Colab.

@ravikyram ravikyram added comp:dist-strat Distribution Strategy related issues comp:keras Keras related issues TF 2.2 Issues related to TF 2.2 and removed stat:awaiting response Status - Awaiting response from author labels Jul 7, 2020
@ravikyram ravikyram assigned gowthamkpr and unassigned ravikyram Jul 7, 2020
@gowthamkpr gowthamkpr assigned nikitamaia and unassigned gowthamkpr Jul 7, 2020
@nikitamaia
Member

Hi @MinasTyuru, thanks for providing very detailed debugging information. By the way, you can use virtual GPUs in Colab to simulate a multi-worker environment, but NCCL is not supported in that case, so it is probably not helpful for debugging this particular issue.

Can you confirm that you do not see this issue if the CollectiveCommunication is set to AUTO? Additionally, can you try running on TF nightly and let me know what happens?
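
For reference, a rough sketch of the virtual-GPU setup mentioned above, using the tf.config.experimental API: it splits one physical GPU into two logical devices for a single-process MirroredStrategy, and, as noted, NCCL cannot be used in this configuration (the memory limits are arbitrary):

import tensorflow as tf

# Must run before any GPUs are initialized.
gpus = tf.config.experimental.list_physical_devices("GPU")
if gpus:
    # Split the first physical GPU into two logical GPUs.
    tf.config.experimental.set_virtual_device_configuration(
        gpus[0],
        [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=1024),
         tf.config.experimental.VirtualDeviceConfiguration(memory_limit=1024)])

# Avoid NCCL, which is not supported over logical devices.
strategy = tf.distribute.MirroredStrategy(
    cross_device_ops=tf.distribute.ReductionToOneDevice())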

@nikitamaia nikitamaia added the stat:awaiting response Status - Awaiting response from author label Jul 8, 2020
@MinasTyuru
Author

@nikitamaia Thank you for the information! Yes, that's correct: using AUTO or RING seems to work fine. Running on TF nightly may be a bit tricky on our infrastructure; I will get back to you on that.
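
For reference, the only difference between the failing and working configurations is the communication argument passed to the strategy; a minimal sketch:

# Segfaults in this report:
communication = tf.distribute.experimental.CollectiveCommunication.NCCL
# Works in this report (AUTO also works):
# communication = tf.distribute.experimental.CollectiveCommunication.RING

strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(communication)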

@MinasTyuru
Author

@nikitamaia Hi Nikita, I was able to reproduce on TF-nightly 2.4.0.dev20200708. When using AUTO it works fine, but when using NCCL it immediately segfaults. Attached are the failed log with NCCL and the successful log with AUTO.

@tensorflowbutler tensorflowbutler removed the stat:awaiting response Status - Awaiting response from author label Jul 10, 2020
@nikitamaia nikitamaia added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Jul 12, 2020
@guptapriya
Contributor

@MinasTyuru thank you for the report. It seems to be the case from the title, but I would like to confirm: does this issue only show up when using the SyncBatchNormalization layer (and go away when not using it, or when using a regular batch norm)?

Getting a segfault is terrible; we will look into it.

We have used SyncBatchNormalization with multi-worker training for some models before, so I am trying to see what is different here, since your repro is quite simple. Would you mind trying one thing: can you make the batch size bigger? Currently it seems to be 1. Usually we try to split the batch across the replicas, but since a batch size of 1 cannot be split, the other replicas will end up getting empty batches. I am wondering if that is what triggers this, although it would not explain why it works for some steps and then fails.
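
A minimal sketch of the suggested change to the repro, assuming two replicas and a global batch size of 8 so that each replica receives a non-empty per-replica batch; the single example is repeated before batching so that batches are actually full:

import tensorflow as tf

def get_dataset(global_batch_size=8):
    # Repeat the single example before batching so each global batch is full.
    x = tf.data.Dataset.from_tensors(tf.zeros([10], dtype=tf.float32)).repeat()
    y = tf.data.Dataset.from_tensor_slices(tf.constant([5])).repeat()
    return tf.data.Dataset.zip((x, y)).batch(global_batch_size)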

@MinasTyuru
Author

After some more investigation, it looks like we can reproduce the issue in some contexts but not others, so it's possible that it is some sort of configuration issue. I thought that I was able to reproduce on TF-nightly, but I believe that was due to a mistake I made when installing CUDA, and now I think it may work fine on TF-nightly. It may be related to how we build TensorFlow from source or something of that nature. I will investigate some more and let you know if it seems to be a problem on our end or not.

@guptapriya
Contributor

Thanks for the update, @MinasTyuru.
@dubey is looking into the segfault further; he thinks a possible race condition might be leading to it, and he is looking into a possible fix.

Re: nightly, yes, your logs indicated that the nightly build was not able to find the GPUs and hence failed right away. I would guess that once you fix that, it will probably run into the segfault at some later point, as with 2.2.

@MinasTyuru
Author

MinasTyuru commented Jul 14, 2020

Yes, we are seeing something similar. When I used a debugger, it changed the behavior. For example, it seems like adding a breakpoint after line 176 of collective_nccl_reducer.cc, which introduces a very small delay, causes TensorFlow to immediately segfault. In this case, it seems that the col_ctx_ variable is set to a null pointer by bfc_allocator.cc. Perhaps this memory is inappropriately freed before it is finished being used? Is this behavior consistent with your hypothesis of the race condition? I'll post some more logs in ~12 hours.

We tried to reproduce on TF 2.2 and TF-nightly, and with a clean build (i.e. installing TensorFlow from pip, etc.) it seems to train for a long time without any problems. But in our production environment (which changes some of the dependencies in the TensorFlow build, among other build differences), we encounter the segfault, so perhaps subtle timing differences are triggering the underlying issue.

@MinasTyuru
Author

Attached are a stacktrace from when col_ctx_ is set to 0 by the memory allocator, and the assembly around that line. Thanks again!


@dubey
Member

dubey commented Jul 15, 2020

@MinasTyuru I just submitted a change that should help with this issue. Feel free to reopen if you encounter the segfault again.

@guoshimin
Contributor

Should col_impl get a similar treatment too? If col_ctx could be freed prematurely, then so could col_impl.

@MinasTyuru
Author

MinasTyuru commented Jul 16, 2020

Thanks Ayush! I will test this out and see if it works.

Shimin, it looks to me like col_impl is a member of the col_ctx object and therefore does not need to be managed separately (it will be allocated and freed along with col_ctx). Oops, I meant col_params_, never mind.

@MinasTyuru
Author

MinasTyuru commented Jul 16, 2020

@dubey It seems like the change fixes the col_ctx_ problem; however, I'm still experiencing segfaults. I now observe that col_ctx_ is a nonzero pointer, but col_ctx_->col_exec is a null pointer, like so:

(gdb) bt
#0  0x00007fffbf8911a7 in tensorflow::NcclReducer::Run(std::function<void (tensorflow::Status const&)>) (this=0x7ffd9800f360, done=...)
    at external/org_tensorflow/tensorflow/core/kernels/collective_nccl_reducer.cc:185
#1  0x00007fffc83399e4 in tensorflow::BaseCollectiveExecutor::<lambda()>::operator() (__closure=<optimized out>)
    at external/org_tensorflow/tensorflow/core/common_runtime/base_collective_executor.cc:275
#2  std::_Function_handler<void(), tensorflow::BaseCollectiveExecutor::ExecuteAsync(tensorflow::OpKernelContext*, const tensorflow::CollectiveParams&, const string&, tensorflow::StatusCallback)::<lambda()> >::_M_invoke(const std::_Any_data &) (__functor=...)
    at external/gcc_7_4/usr/include/c++/7/bits/std_function.h:316
#3  0x00007fffbd5425c8 in std::function<void ()>::operator()() const (this=this@entry=0x7ffabcffce60)
    at external/gcc_7_4/usr/include/c++/7/bits/std_function.h:706
#4  0x00007fffc88674de in tensorflow::UnboundedWorkQueue::PooledThreadFunc (this=0x1ed20c40)
    at external/org_tensorflow/tensorflow/core/platform/default/unbounded_work_queue.cc:99
...
(gdb) p col_ctx_
$1 = std::shared_ptr<tensorflow::CollectiveContext> (use count 33194781, weak count -92074895) = {get() = 0x7ffd00555043}
(gdb) p col_params_
$2 = (const tensorflow::CollectiveParams *) 0x7fffc838ca46
     <std::_Function_handler<void(), tensorflow::(anonymous namespace)::ExecutorState::Process(tensorflow::(anonymous namespace)::ExecutorState::TaggedNode, tensorflow::int64)::<lambda()> >::_M_invoke(const std::_Any_data &)>
(gdb) p col_ctx_->col_exec
$3 = (tensorflow::CollectiveExecutor *) 0x0

I think that maybe CollectiveExecutor needs to be a shared pointer as well? And, as Shimin said, possibly col_impl also.

Also, I don't think I have the required permissions to reopen this issue, so I'm just commenting on it.

@dubey
Member

dubey commented Jul 16, 2020

Thanks for the update. Let me follow up internally.

@dubey dubey reopened this Jul 16, 2020
@dubey dubey self-assigned this Jul 16, 2020
@tensorflowbutler tensorflowbutler removed the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Jul 18, 2020
@MinasTyuru
Author

Shimin implemented the change he suggested for col_impl:

diff --git a/tensorflow/core/common_runtime/base_collective_executor.cc b/tensorflow/core/common_runtime/base_collective_executor.cc
index 0eebab9cf53..4d95059d018 100644
--- a/tensorflow/core/common_runtime/base_collective_executor.cc
+++ b/tensorflow/core/common_runtime/base_collective_executor.cc
@@ -246,20 +246,20 @@ void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx,
                           col_params.is_source))
                             ? &ctx->input(0)
                             : nullptr;
-  CollectiveImplementationInterface* col_impl = nullptr;
-  Status status = CreateCollective(col_params, &col_impl);
+  CollectiveImplementationInterface* col_impl_ptr = nullptr;
+  Status status = CreateCollective(col_params, &col_impl_ptr);
   if (!status.ok()) {
     done_safe(status);
-    TF_DCHECK_EQ(nullptr, col_impl);
+    TF_DCHECK_EQ(nullptr, col_impl_ptr);
     return;
   }
+  auto col_impl = std::shared_ptr<CollectiveImplementationInterface>(col_impl_ptr);
   auto col_ctx = std::make_shared<CollectiveContext>(
       this, dev_mgr_, ctx, CtxParams(ctx), col_params, exec_key, step_id_,
       input, output);
   status = col_impl->InitializeCollectiveContext(col_ctx);
   if (!status.ok()) {
     done_safe(status);
-    delete col_impl;
     return;
   }
   // Run on an unbounded work queue that can handle blocking work so as to not
@@ -274,7 +274,6 @@ void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx,
         profiler::TraceMeLevel::kInfo);
     col_impl->Run([col_impl, col_ctx, done_safe](const Status& s) {
       done_safe(s);
-      delete col_impl;
     });
   });
 }

I'm not sure what the relationship is between col_impl and the col_exec pointer that seemed to be null previously, but this change seems to fix the segfault. I'm doing some additional testing to be sure and will keep you posted.

@dubey
Member

dubey commented Jul 20, 2020

Thanks, I was working on a similar change internally, this time with a unit test that can reproduce the issue. Yes please do keep me posted.

@MinasTyuru
Author

Apologies for the delay; it looks like everything works without segfaulting now.

