gRPC: terminate called after throwing an instance of 'std::bad_alloc' #31245

Closed
karkadad opened this issue Aug 1, 2019 · 15 comments
Labels: comp:apis (High-level API related issues), stat:awaiting response (Status - Awaiting response from author), TF 1.12 (Issues related to TF 1.12), type:support (Support issues)

Comments

karkadad commented Aug 1, 2019

System information
• OS Platform and Distribution: CentOS Linux release 7.4.1708 (Core)
• TensorFlow version: TensorFlow 1.12.0, built from source

Describe the problem:
Running a model-parallel implementation results in the following error:

terminate called after throwing an instance of 'std::bad_alloc'
  what():  std::bad_alloc
Aborted

The code runs fine on a single node, i.e. with a single worker, but distributing across 2 nodes/2 workers results in the above error. I suspect it is related to gRPC.

Source code / logs
On the chief worker:

import tensorflow as tf
from time import time

def conv_op(inputs, kernel_, name):
    with tf.variable_scope(name) as scope:
        conv = tf.nn.conv3d(inputs, kernel_, [1, 1, 1, 1, 1], dilations=[1, 1, 1, 1, 1], padding='SAME')
    return conv

def inference_withconv(inputs, kernel_):
    # conv1 runs on worker 0, conv2 on worker 1, and the add back on worker 0,
    # so the intermediate tensors are exchanged between the two workers over gRPC.
    with tf.device('/job:worker/task:{}'.format(0)):
        conv1_woc = conv_op(inputs, kernel_, 'conv1_woc')
    with tf.device('/job:worker/task:{}'.format(1)):
        conv2_woc = conv_op(conv1_woc, kernel_, 'conv2_woc')
    with tf.device('/job:worker/task:{}'.format(0)):
        convadd_results = tf.math.add(conv1_woc, conv2_woc)
        return convadd_results

def run_benchmark():
    image_shape = (1,1024,1024,1024,1)
    kernel_1_shape = (5,5,5,1,1)
    bias_shape = (1)

    dummy_image = tf.truncated_normal(
       image_shape,
       dtype=tf.float32,
       mean=0,
       stddev=1,
       name='3Dconv_image_1')

    dummy_kernel_1 = tf.truncated_normal(
       kernel_1_shape,
       dtype=tf.float32,
       mean=0,
       stddev=1,
       name='3Dconv_kernel_1')

    image_shape = (1, 1024, 1024, 1024, 1)
    image_init = tf.placeholder(tf.float32, shape=image_shape, name='input_1')

    res_ = inference_withconv(image_init, dummy_kernel_1)

    # Define the cluster spec
    cluster_spec = tf.train.ClusterSpec({'worker' : [('<ip_address_1>' + ":" + '2222'), ('<ip_address_2>' + ":" + '2222')]})

    task_id=0 # Chief worker
    server_config = tf.ConfigProto(inter_op_parallelism_threads=2, intra_op_parallelism_threads=20)
    server = tf.train.Server(cluster_spec, job_name='worker', task_index=task_id, config=server_config)

    session_config = tf.ConfigProto(
      inter_op_parallelism_threads=2,
      intra_op_parallelism_threads=20)

    with tf.Session(server.target, config=session_config) as sess:
        sess.run(tf.initialize_all_variables())
        image_, kernel_1 = sess.run([dummy_image, dummy_kernel_1])
        infer_results_ = sess.run(res_, feed_dict={'input_1:0': image_, '3Dconv_kernel_1:0': kernel_1})

if __name__ == '__main__':
    run_benchmark()

On the non-chief worker (on another node with a different IP address):

import tensorflow as tf

# Define the cluster spec
cluster_spec = tf.train.ClusterSpec({'worker' : [('<ip_address_1>' + ":" + '2222'), ('<ip_address_2>' + ":" + '2222')]})

task_id=1 # Non-chief worker
server_config = tf.ConfigProto(inter_op_parallelism_threads=2, intra_op_parallelism_threads=20)
server = tf.train.Server(cluster_spec, job_name='worker', task_index=task_id, config=server_config)

server.join()


oanush self-assigned this Aug 2, 2019
oanush added the TF 1.12, comp:apis, and type:support labels Aug 2, 2019

oanush commented Aug 2, 2019

@karkadad,
Can you please go through this similar issue: #9487? Thanks!

oanush added the stat:awaiting response label Aug 2, 2019

karkadad commented Aug 2, 2019

This is not a memory allocation issue and therefore not related to this issue: #9487. Consider the memory profile for a single node run:
[Screenshot: memory profile of the single-node run (MemProfile)]
There are no issues running on a single node/single worker, and as you can see, there is plenty of memory available, so it doesn't run out of memory (the total memory available is ~384 GB and the peak utilization is ~208 GB).
The issue exists only for the model-parallel version scaled across 2 nodes using gRPC.
Also, if it helps, there are no issues when running the code with a smaller image size, e.g. (1, 512, 512, 512, 1). It's only when the image size is increased, e.g. to (1, 1024, 1024, 1024, 1), that this error is encountered:
terminate called after throwing an instance of 'std::bad_alloc'
what(): std::bad_alloc
Aborted
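
For reference, the raw float32 tensor sizes straddle the ~2 GiB mark (the largest value a signed 32-bit byte count can hold), in case that is relevant:

# Rough payload sizes of the tensors exchanged between the workers (float32 = 4 bytes/element).
print(1 * 512 * 512 * 512 * 1 * 4)     # 536870912 bytes (~0.5 GiB) -- this size works
print(1 * 1024 * 1024 * 1024 * 1 * 4)  # 4294967296 bytes (~4 GiB)  -- this size fails
print(2**31 - 1)                       # 2147483647 bytes (~2 GiB), the signed 32-bit ceiling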

tensorflowbutler removed the stat:awaiting response label Aug 3, 2019
oanush assigned ymodak and unassigned oanush Aug 6, 2019

ymodak commented Aug 7, 2019

Since you are able to execute the parallel version with a smaller image size but encounter the error for larger image sizes, it looks like a failure to allocate memory.

ymodak added the stat:awaiting response label Aug 8, 2019

ymodak commented Aug 13, 2019

Automatically closing due to lack of recent activity. Please update the issue when new information becomes available, and we will reopen the issue. Thanks!

ymodak closed this as completed Aug 13, 2019

schoi-habana commented Aug 17, 2019

@ymodak Can we reopen this issue? I took over this work from @karkadad and I'm hitting the same problem.

With the same image size (1, 900, 900, 900, 1), a single node can complete the run, but the parallel version with 2 nodes cannot because of std::bad_alloc. I think this means there is enough memory on one node and the error happens somewhere in the communication layer.

ymodak reopened this Aug 19, 2019

mrry commented Aug 20, 2019

Can you run the code under gdb and let us know what the stack trace is for the std::bad_alloc exception? It's likely that the communication layers are creating additional backing buffers, but seeing the actual source of the error will help us work out what in TensorFlow we can change to fix the problem.
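
For example, something along these lines should work (assuming the chief-worker script is saved as chief_worker.py; substitute your actual file name):

gdb --args python chief_worker.py
(gdb) run
# ... wait for "terminate called after throwing an instance of 'std::bad_alloc'" ...
(gdb) bt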


schoi-habana commented Aug 20, 2019

#0 0x00007ffff71281f7 in raise () from /lib64/libc.so.6
#1 0x00007ffff71298e8 in abort () from /lib64/libc.so.6
#2 0x00007fffed24a3df in __gnu_cxx::__verbose_terminate_handler () at /opt/conda/conda-bld/compilers_linux-64_1534514838838/work/.build/x86_64-conda_cos6-linux-gnu/src/gcc/libstdc++-v3/libsupc++/vterminate.cc:95
#3 0x00007fffed248b16 in __cxxabiv1::__terminate (handler=) at /opt/conda/conda-bld/compilers_linux-64_1534514838838/work/.build/x86_64-conda_cos6-linux-gnu/src/gcc/libstdc++-v3/libsupc++/eh_terminate.cc:47
#4 0x00007fffed248b4c in std::terminate () at /opt/conda/conda-bld/compilers_linux-64_1534514838838/work/.build/x86_64-conda_cos6-linux-gnu/src/gcc/libstdc++-v3/libsupc++/eh_terminate.cc:57
#5 0x00007fffed248d28 in __cxxabiv1::__cxa_throw (obj=0x7ffcc40c9e60, tinfo=0x7fffed2ddab0 <typeinfo for std::bad_alloc>, dest=0x7fffed2477c0 <std::bad_alloc::~bad_alloc()>)
at /opt/conda/conda-bld/compilers_linux-64_1534514838838/work/.build/x86_64-conda_cos6-linux-gnu/src/gcc/libstdc++-v3/libsupc++/eh_throw.cc:95
#6 0x00007fffed2490c1 in operator new (sz=18446744069414584364)
at /opt/conda/conda-bld/compilers_linux-64_1534514838838/work/.build/x86_64-conda_cos6-linux-gnu/build/build-cc-gcc-final/x86_64-conda_cos6-linux-gnu/libstdc++-v3/include/bits/exception.h:63
#7 0x00007ffec6848b18 in tensorflow::grpc::EncodeTensorToByteBuffer(bool, tensorflow::Tensor const&, grpc::ByteBuffer*) ()
from /miniconda3/envs/tf2.0_2.7/lib/python2.7/site-packages/tensorflow/python/_pywrap_tensorflow_internal.so
#8 0x00007ffec684367f in tensorflow::GrpcWorker::GrpcRecvTensorAsync(tensorflow::CallOptions*, tensorflow::RecvTensorRequest const*, grpc::ByteBuffer*, std::function<void (tensorflow::Status const&)>)::{lambda(tensorflow::Status const&, tensorflow::Rendezvous::Args const&, tensorflow::Rendezvous::Args const&, tensorflow::Tensor const&, bool)#2}::operator()(tensorflow::Status const&, tensorflow::Rendezvous::Args const&, tensorflow::Rendezvous::Args const&, tensorflow::Tensor const&, bool) const [clone .constprop.369] () from /miniconda3/envs/tf2.0_2.7/lib/python2.7/site-packages/tensorflow/python/_pywrap_tensorflow_internal.so
#9 0x00007ffec6854742 in std::_Function_handler<void (tensorflow::Status const&, tensorflow::Rendezvous::Args const&, tensorflow::Rendezvous::Args const&, tensorflow::Tensor const&, bool), std::_Bind<tensorflow::BaseRendezvousMgr::RecvLocalAsync(long long, tensorflow::Rendezvous::ParsedKey const&, std::function<void (tensorflow::Status const&, tensorflow::Rendezvous::Args const&, tensorflow::Rendezvous::Args const&, tensorflow::Tensor const&, bool)>)::{lambda(std::function<void (tensorflow::Status const&, tensorflow::Rendezvous::Args const&, tensorflow::Rendezvous::Args const&, tensorflow::Tensor const&, bool)>, tensorflow::Status const&, tensorflow::Rendezvous::Args const&, tensorflow::Rendezvous::Args const&, tensorflow::Tensor const&, bool)#1} (std::function<void (tensorflow::Status const&, tensorflow::Rendezvous::Args const&, tensorflow::Rendezvous::Args const&, tensorflow::Tensor const&, bool)>, std::_Placeholder<1>, tensorflow::BaseRendezvousMgr::RecvLocalAsync(long long, tensorflow::Rendezvous::ParsedKey const&, std::function<void (tensorflow::Status const&, tensorflow::Rendezvous::Args const&, tensorflow::Rendezvous::Args const&, tensorflow::Tensor const&, bool)>)::{lambda(std::function<void (tensorflow::Status const&, tensorflow::Rendezvous::Args const&, tensorflow::Rendezvous::Args const&, tensorflow::Tensor const&, bool)>, tensorflow::Status const&, tensorflow::Rendezvous::Args const&, tensorflow::Rendezvous::Args const&, tensorflow::Tensor const&, bool)#1}<2>, tensorflow::BaseRendezvousMgr::RecvLocalAsync(long long, tensorflow::Rendezvous::ParsedKey const&, std::function<void (tensorflow::Status const&, tensorflow::Rendezvous::Args const&, tensorflow::Rendezvous::Args const&, tensorflow::Tensor const&, bool)>)::{lambda(std::function<void (tensorflow::Status const&, tensorflow::Rendezvous::Args const&, tensorflow::Rendezvous::Args const&, tensorflow::Tensor const&, bool)>, tensorflow::Status const&, tensorflow::Rendezvous::Args const&, tensorflow::Rendezvous::Args const&, tensorflow::Tensor const&, bool)#1}<3>, tensorflow::BaseRendezvousMgr::RecvLocalAsync(long long, tensorflow::Rendezvous::ParsedKey const&, std::function<void (tensorflow::Status const&, tensorflow::Rendezvous::Args const&, tensorflow::Rendezvous::Args const&, tensorflow::Tensor const&, bool)>)::{lambda(std::function<void (tensorflow::Status const&, tensorflow::Rendezvous::Args const&, tensorflow::Rendezvous::Args const&, tensorflow::Tensor const&, bool)>, tensorflow::Status const&, tensorflow::Rendezvous::Args const&, tensorflow::Rendezvous::Args const&, tensorflow::Tensor const&, bool)#1}<4>, tensorflow::BaseRendezvousMgr::RecvLocalAsync(long long, tensorflow::Rendezvous::ParsedKey const&, std::function<void (tensorflow::Status const&, tensorflow::Rendezvous::Args const&, tensorflow::Rendezvous::Args const&, tensorflow::Tensor const&, bool)>)::{lambda(std::function<void (tensorflow::Status const&, tensorflow::Rendezvous::Args const&, tensorflow::Rendezvous::Args const&, tensorflow::Tensor const&, bool)>, tensorflow::Status const&, tensorflow::Rendezvous::Args const&, tensorflow::Rendezvous::Args const&, tensorflow::Tensor const&, bool)#1}<5>)> >::_M_invoke(std::_Any_data const&, tensorflow::Status const&, tensorflow::Rendezvous::Args const&, tensorflow::Rendezvous::Args const&, tensorflow::Tensor const&, bool&&) ()
from /miniconda3/envs/tf2.0_2.7/lib/python2.7/site-packages/tensorflow/python/_pywrap_tensorflow_internal.so
#10 0x00007ffec2594eb0 in tensorflow::LocalRendezvousImpl::Send(tensorflow::Rendezvous::ParsedKey const&, tensorflow::Rendezvous::Args const&, tensorflow::Tensor const&, bool) ()
from /miniconda3/envs/tf2.0_2.7/lib/python2.7/site-packages/tensorflow/python/../libtensorflow_framework.so
#11 0x00007ffec6858f1e in tensorflow::BaseRemoteRendezvous::Send(tensorflow::Rendezvous::ParsedKey const&, tensorflow::Rendezvous::Args const&, tensorflow::Tensor const&, bool) ()
from /miniconda3/envs/tf2.0_2.7/lib/python2.7/site-packages/tensorflow/python/_pywrap_tensorflow_internal.so
#12 0x00007ffec6e9999a in tensorflow::SendOp::Compute(tensorflow::OpKernelContext*) ()
from /miniconda3/envs/tf2.0_2.7/lib/python2.7/site-packages/tensorflow/python/_pywrap_tensorflow_internal.so
#13 0x00007ffec2764d31 in tensorflow::(anonymous namespace)::ExecutorState::Process(tensorflow::(anonymous namespace)::ExecutorState::TaggedNode, long long) ()
from /miniconda3/envs/tf2.0_2.7/lib/python2.7/site-packages/tensorflow/python/../libtensorflow_framework.so
#14 0x00007ffec2758d65 in std::_Function_handler<void (), std::_Bind<void (tensorflow::(anonymous namespace)::ExecutorState::(tensorflow::(anonymous namespace)::ExecutorState, tensorflow::(anonymous namespace)::ExecutorState::TaggedNode, long long))(tensorflow::(anonymous namespace)::ExecutorState::TaggedNode, long long)> >::_M_invoke(std::_Any_data const&) ()
from /miniconda3/envs/tf2.0_2.7/lib/python2.7/site-packages/tensorflow/python/../libtensorflow_framework.so
#15 0x00007ffec2ba31d2 in Eigen::NonBlockingThreadPoolTempl<tensorflow::thread::EigenEnvironment>::WorkerLoop(int) ()
from /miniconda3/envs/tf2.0_2.7/lib/python2.7/site-packages/tensorflow/python/../libtensorflow_framework.so
#16 0x00007ffec2ba0fb7 in std::_Function_handler<void (), tensorflow::thread::EigenEnvironment::CreateThread(std::function<void ()>)::{lambda()#1}>::_M_invoke(std::_Any_data const&) ()
from /miniconda3/envs/tf2.0_2.7/lib/python2.7/site-packages/tensorflow/python/../libtensorflow_framework.so
#17 0x00007fffed264408 in std::execute_native_thread_routine (__p=0x55555afb59d0) at /opt/conda/conda-bld/compilers_linux-64_1534514838838/work/.build/x86_64-conda_cos6-linux-gnu/src/gcc/libstdc++-v3/src/c++11/thread.cc:80
#18 0x00007ffff7bc6e25 in start_thread () from /lib64/libpthread.so.0
#19 0x00007ffff71eb34d in clone () from /lib64/libc.so.6

@mrry Thanks for jumping on this issue. Do you think this back trace is helpful? Please do let me know if you need any further info.


mrry commented Aug 20, 2019

Thank you! This back trace is perfect, and I think it gives us enough information to find a fix. I'll ping this thread when we have more information.

ymodak removed their assignment Aug 21, 2019
@schoi-habana

Yes, this is still an issue. @mrry, any updates?

@schoi-habana

Still waiting. Any help would be appreciated.


nwatab commented Oct 9, 2019

I saw the same error on CPU, and reducing batch_size addressed the problem.

@schoi-habana

@asterisk37n Thanks for your comment. Unfortunately, we are facing this issue with batch_size 1 because the input size is really big.

@tensorflowbutler

We are closing this issue for now due to lack of activity. Please comment if this is still an issue for you. Thanks!


yy2261 commented Jun 28, 2020

I am facing the same issue. When I trained my input embedding with shape [20K, 128] the training phase looked fine, but when the embedding shape was changed to [20M, 128] the std::bad_alloc error occurred in the parameter server. At that time the memory used was less than 50% of the total.
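
For what it's worth, a rough size check (assuming float32 embeddings) shows the larger variable blows past the ~2 GiB ceiling that a single protobuf-encoded gRPC message (a signed 32-bit byte count) can carry, while the smaller one does not:

print(20000 * 128 * 4)     # 10240000 bytes (~10 MB) for the [20K, 128] embedding -- fine
print(20000000 * 128 * 4)  # 10240000000 bytes (~10 GB) for the [20M, 128] embedding
print(2**31 - 1)           # 2147483647 bytes (~2 GiB), the signed 32-bit ceiling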
Are there any updates about this issue?


tishion commented May 26, 2021

#6 0x00007fffed2490c1 in operator new (sz=18446744069414584364)
The call stack shows it was requesting to allocate a memory buffer of 18446744069414584364 bytes (~18 billion GB). Of course it failed.

I think there is a bug in tensorflow::grpc::EncodeTensorToByteBuffer(bool, tensorflow::Tensor const&, grpc::ByteBuffer*) when it calculates the required buffer size.
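
A quick sanity check on that number is consistent with a signed-overflow bug in the size calculation: read as a signed 64-bit integer, the requested size is negative, and its magnitude is within a few dozen bytes of the ~4 GiB float32 tensor from the original repro:

sz = 18446744069414584364              # size passed to operator new in frame #6
print(sz - 2**64)                      # -4294967252: the same bit pattern read as a signed 64-bit int
print(1 * 1024 * 1024 * 1024 * 1 * 4)  # 4294967296: bytes in the (1, 1024, 1024, 1024, 1) float32 tensor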
