CUDA illegal memory access error when running distributed mixed precision #40814

Closed
lminer opened this issue Jun 26, 2020 · 23 comments
Labels: comp:gpu (GPU related issues) · stale · stat:awaiting response (awaiting response from author) · TF 2.0 · type:bug

Comments

@lminer

lminer commented Jun 26, 2020

Whenever I try to train a model using MirroredStrategy and mixed precision, at an indeterminate time, I get the following error:

./tensorflow/core/kernels/conv_2d_gpu.h:970] Non-OK-status: GpuLaunchKernel( SwapDimension1And2InTensor3UsingTiles<T, kNumThreads, kTileSize, kTileSize, conjugate>, total_tiles_count, kNumThreads, 0, d.stream(), input, input_dims, output) status: Internal: an illegal memory access was encountered
2020-06-25 00:45:27.788127: E tensorflow/stream_executor/cuda/cuda_event.cc:29] Error polling for event status: failed to query event: CUDA_ERROR_ILLEGAL_ADDRESS: an illegal memory access was encountered
2020-06-25 00:45:27.788208: F tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc:273] Unexpected Event status: 1

Unfortunately, I don't have a simple example to reproduce this and can't include my entire code. But maybe other people are having similar issues and can produce a better example.

I'm running TensorFlow 2.2.0 on Ubuntu 18.04 with CUDA 10.1.243 and cuDNN 7.6.5, using two RTX 2080 Ti cards. I get the same error on a V100.

@ravikyram
Contributor

@lminer

Please provide the exact sequence of commands/steps that you executed before running into the problem. Thanks!

@ravikyram ravikyram added the stat:awaiting response Status - Awaiting response from author label Jun 26, 2020
@lminer
Author

lminer commented Jun 26, 2020

I can't give you all the code, but the basic approach I use is below:

policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
tf.keras.mixed_precision.experimental.set_policy(policy)
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    logits = get_logits()
    model = tf.keras.Model(inputs, logits)


model.fit(X, y)  # X and y are datasets read from tfrecords

I run this from the official tensorflow docker container.

@lminer
Author

lminer commented Jun 26, 2020

Is there a flag I can use to get a more detailed stack trace?

@lminer
Author

lminer commented Jun 27, 2020

I ran this with cuda-memcheck and the error occurred at an earlier point:

========= Internal Memcheck Error: Initialization failed
=========     Saved host backtrace up to driver entry point at error
=========     Host Frame:/usr/lib/x86_64-linux-gnu/libcuda.so.1 [0x1403fc]
=========     Host Frame:/usr/local/cuda-10.1/targets/x86_64-linux/lib/libcufft.so.10 [0x3d7e4a]
=========     Host Frame:/usr/local/cuda-10.1/targets/x86_64-linux/lib/libcufft.so.10 [0x3caf70]
=========     Host Frame:/usr/local/cuda-10.1/targets/x86_64-linux/lib/libcufft.so.10 [0x3d719a]
=========     Host Frame:/usr/local/cuda-10.1/targets/x86_64-linux/lib/libcufft.so.10 [0x3dae9f]
=========     Host Frame:/usr/local/cuda-10.1/targets/x86_64-linux/lib/libcufft.so.10 [0x3db60a]
=========     Host Frame:/usr/local/cuda-10.1/targets/x86_64-linux/lib/libcufft.so.10 [0x3cec3c]
=========     Host Frame:/usr/local/cuda-10.1/targets/x86_64-linux/lib/libcufft.so.10 [0x3bed7e]
=========     Host Frame:/usr/local/cuda-10.1/targets/x86_64-linux/lib/libcufft.so.10 [0x3f022c]
=========     Host Frame:/usr/local/cuda-10.1/targets/x86_64-linux/lib/libcufft.so.10 [0x379a2]
=========     Host Frame:/usr/local/cuda-10.1/targets/x86_64-linux/lib/libcufft.so.10 [0x37fa6]
=========     Host Frame:/usr/local/cuda-10.1/targets/x86_64-linux/lib/libcufft.so.10 [0x39af2]
=========     Host Frame:/usr/local/cuda-10.1/targets/x86_64-linux/lib/libcufft.so.10 [0x37fa6]
=========     Host Frame:/usr/local/cuda-10.1/targets/x86_64-linux/lib/libcufft.so.10 [0x39af2]
=========     Host Frame:/usr/local/cuda-10.1/targets/x86_64-linux/lib/libcufft.so.10 (cufftXtMakePlanMany + 0x63a) [0x4d0ca]
=========     Host Frame:/usr/local/cuda-10.1/targets/x86_64-linux/lib/libcufft.so.10 (cufftMakePlanMany64 + 0x157) [0x4e087]
=========     Host Frame:/usr/local/cuda-10.1/targets/x86_64-linux/lib/libcufft.so.10 (cufftMakePlanMany + 0x193) [0x4aaf3]
=========     Host Frame:/home/lminer/anaconda3/envs/separate2/lib/python3.7/site-packages/tensorflow/python/../libtensorflow_framework.so.2 (_ZN15stream_executor3gpu11CUDAFftPlan10InitializeEPNS0_11GpuExecutorEPNS_6StreamEiPyS6_yyS6_yyNS_3fft4TypeEiPNS_16ScratchAllocatorE + 0x1e6) [0x15949b6]
=========     Host Frame:/home/lminer/anaconda3/envs/separate2/lib/python3.7/site-packages/tensorflow/python/../libtensorflow_framework.so.2 (_ZN15stream_executor3gpu7CUDAFft37CreateBatchedPlanWithScratchAllocatorEPNS_6StreamEiPyS4_yyS4_yyNS_3fft4TypeEbiPNS_16ScratchAllocatorE + 0xca) [0x159614a]
=========     Host Frame:/home/lminer/anaconda3/envs/separate2/lib/python3.7/site-packages/tensorflow/python/_pywrap_tensorflow_internal.so (_ZN10tensorflow10FFTGPUBase5DoFFTEPNS_15OpKernelContextERKNS_6TensorEPyPS3_ + 0x3bf) [0x5157bff]
=========     Host Frame:/home/lminer/anaconda3/envs/separate2/lib/python3.7/site-packages/tensorflow/python/_pywrap_tensorflow_internal.so (_ZN10tensorflow7FFTBase7ComputeEPNS_15OpKernelContextE + 0x453) [0x50f2fa3]
=========     Host Frame:/home/lminer/anaconda3/envs/separate2/lib/python3.7/site-packages/tensorflow/python/../libtensorflow_framework.so.2 (_ZN10tensorflow13BaseGPUDevice7ComputeEPNS_8OpKernelEPNS_15OpKernelContextE + 0xe6) [0xf385f6]
=========     Host Frame:/home/lminer/anaconda3/envs/separate2/lib/python3.7/site-packages/tensorflow/python/_pywrap_tensorflow_internal.so (_ZN10tensorflow17KernelAndDeviceOp3RunEPNS_19ScopedStepContainerERKNS_15EagerKernelArgsEPSt6vectorINS_6TensorESaIS7_EEPNS_19CancellationManagerERKN4absl8optionalINS_25EagerRemoteFunctionParamsEEE + 0x64f) [0x378dd3f]
=========     Host Frame:/home/lminer/anaconda3/envs/separate2/lib/python3.7/site-packages/tensorflow/python/_pywrap_tensorflow_internal.so (_ZN10tensorflow17KernelAndDeviceOp3RunERKNS_15EagerKernelArgsEPSt6vectorINS_6TensorESaIS5_EEPNS_19CancellationManagerERKN4absl8optionalINS_25EagerRemoteFunctionParamsEEE + 0x2f) [0x378e23f]
=========     Host Frame:/home/lminer/anaconda3/envs/separate2/lib/python3.7/site-packages/tensorflow/python/_pywrap_tensorflow_internal.so (_ZN10tensorflow18EagerKernelExecuteEPNS_12EagerContextERKN4absl13InlinedVectorIPNS_12TensorHandleELm4ESaIS5_EEERKNS2_8optionalINS_25EagerRemoteFunctionParamsEEERKSt10unique_ptrINS_15KernelAndDeviceENS_4core15RefCountDeleterEEPNS_14GraphCollectorEPNS_19CancellationManagerENS2_4SpanIS5_EE + 0x60b) [0x376100b]
=========     Host Frame:/home/lminer/anaconda3/envs/separate2/lib/python3.7/site-packages/tensorflow/python/_pywrap_tensorflow_internal.so (_ZN10tensorflow11ExecuteNode3RunEv + 0x170) [0x3761be0]
=========     Host Frame:/home/lminer/anaconda3/envs/separate2/lib/python3.7/site-packages/tensorflow/python/_pywrap_tensorflow_internal.so (_ZN10tensorflow13EagerExecutor11SyncExecuteEPNS_9EagerNodeE + 0x1b$) [0x3789684]
=========     Host Frame:/home/lminer/anaconda3/envs/separate2/lib/python3.7/site-packages/tensorflow/python/_pywrap_tensorflow_internal.so [0x375bb8c]
=========     Host Frame:/home/lminer/anaconda3/envs/separate2/lib/python3.7/site-packages/tensorflow/python/_pywrap_tensorflow_internal.so (_ZN10tensorflow12EagerExecuteEPNS_14EagerOperationEPPNS_12TensorHandleEPi + 0x2d2) [0x375f1a2]
=========     Host Frame:/home/lminer/anaconda3/envs/separate2/lib/python3.7/site-packages/tensorflow/python/_pywrap_tensorflow_internal.so (_ZN10tensorflow18OperationInterface7ExecuteEPN4absl10FixedArrayISt10unique_ptrI29AbstractTensorHandleInterfaceSt14default_deleteIS4_EELm18446744073709551615ESaIS7_EEEPi + 0x6b) [0x332917b]
=========     Host Frame:/home/lminer/anaconda3/envs/separate2/lib/python3.7/site-packages/tensorflow/python/_pywrap_tensorflow_internal.so (TFE_Execute + 0x91) [0x3317e91]
=========     Host Frame:/home/lminer/anaconda3/envs/separate2/lib/python3.7/site-packages/tensorflow/python/_pywrap_tensorflow_internal.so (_Z24TFE_Py_FastPathExecute_CP7_object + 0x1af1) [0x2fc9931]
=========     Host Frame:/home/lminer/anaconda3/envs/separate2/lib/python3.7/site-packages/tensorflow/python/_pywrap_tfe.so [0x20ef7]
=========     Host Frame:/home/lminer/anaconda3/envs/separate2/lib/python3.7/site-packages/tensorflow/python/_pywrap_tfe.so [0x42c09]
=========     Host Frame:python (_PyMethodDef_RawFastCallKeywords + 0x264) [0x163c34]
=========     Host Frame:python (_PyCFunction_FastCallKeywords + 0x21) [0x163d51]
=========     Host Frame:python (_PyEval_EvalFrameDefault + 0x4ebc) [0x1d00ac]
=========     Host Frame:python (_PyEval_EvalCodeWithName + 0x2f9) [0x1131b9]
=========     Host Frame:python (_PyFunction_FastCallKeywords + 0x387) [0x163437]
=========     Host Frame:python (_PyEval_EvalFrameDefault + 0x14eb) [0x1cc6db]
=========     Host Frame:python (_PyEval_EvalCodeWithName + 0xab8) [0x113978]
=========     Host Frame:python (_PyFunction_FastCallKeywords + 0x387) [0x163437]
=========     Host Frame:python (_PyEval_EvalFrameDefault + 0x4b29) [0x1cfd19]
=========     Host Frame:python (_PyEval_EvalCodeWithName + 0x2f9) [0x1131b9]
=========     Host Frame:python (_PyFunction_FastCallKeywords + 0x387) [0x163437]
=========     Host Frame:python (_PyEval_EvalFrameDefault + 0x14eb) [0x1cc6db]
=========     Host Frame:python (_PyFunction_FastCallKeywords + 0xfb) [0x1631ab]
=========     Host Frame:python (_PyEval_EvalFrameDefault + 0x416) [0x1cb606]
=========     Host Frame:python (_PyEval_EvalCodeWithName + 0x2f9) [0x1131b9]
=========     Host Frame:python (_PyFunction_FastCallKeywords + 0x387) [0x163437]
=========     Host Frame:python (_PyEval_EvalFrameDefault + 0x14eb) [0x1cc6db]
=========     Host Frame:python (_PyEval_EvalCodeWithName + 0x2f9) [0x1131b9]
=========     Host Frame:python (_PyFunction_FastCallKeywords + 0x325) [0x1633d5]
=========     Host Frame:python (_PyEval_EvalFrameDefault + 0x416) [0x1cb606]
=========     Host Frame:python (_PyFunction_FastCallKeywords + 0xfb) [0x1631ab]
=========     Host Frame:python (_PyEval_EvalFrameDefault + 0x6a3) [0x1cb893]
=========
2020-06-27 16:22:40.585959: E tensorflow/stream_executor/cuda/cuda_event.cc:29] Error polling for event status: failed to query event: CUDA_ERROR_ILLEGAL_INSTRUCTION: an illegal instruction was encountered
2020-06-27 16:22:40.586021: F tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc:273] Unexpected Event status: 1

@ravikyram ravikyram added comp:gpu GPU related issues TF 2.0 Issues relating to TensorFlow 2.0 labels Jun 28, 2020
@ravikyram ravikyram assigned gowthamkpr and unassigned ravikyram Jun 28, 2020
@tensorflowbutler tensorflowbutler removed the stat:awaiting response Status - Awaiting response from author label Jun 30, 2020
@gowthamkpr

@lminer Can you please provide the full code needed to reproduce this issue? We currently can't reproduce it, as the snippet you shared isn't enough. Thanks!

@gowthamkpr gowthamkpr added the stat:awaiting response Status - Awaiting response from author label Jul 6, 2020
@ben0it8

ben0it8 commented Jul 8, 2020

@lminer did you figure out what causes this error?

@lminer
Author

lminer commented Jul 8, 2020

@ben0it8 no, and I'm having trouble creating a reproducible example that isn't just my entire code base. Do you have one?

@lminer
Author

lminer commented Jul 9, 2020

@gowthamkpr, I have a reproducible example. This will crash if I run: TF_FORCE_GPU_ALLOW_GROWTH=true python fail.py.

Where fail.py is as below. Interestingly, it does not crash if I don't set TF_FORCE_GPU_ALLOW_GROWTH to true. I'm running TensorFlow 2.2.0 on Ubuntu 18.04 with CUDA 10.1.243 and cuDNN 7.6.5, using two RTX 2080 Ti cards. I get the same error on a V100. This only happens when I enable both mixed precision and MirroredStrategy.

import numpy as np
import tensorflow as tf


class Conv2dLayer(tf.keras.layers.Layer):
    def __init__(self, filters, kernel_size, strides=1, **kwargs):
        super().__init__(**kwargs)
        self.activation = tf.keras.layers.LeakyReLU()
        self.conv = tf.keras.layers.Conv2D(
            filters, kernel_size, strides=strides, padding="same", kernel_initializer="he_normal",
        )
        self.batch_norm = tf.keras.layers.BatchNormalization()
        self.filters = filters
        self.kernel_size = kernel_size
        self.strides = strides

    def call(self, inputs, **kwargs):
        x = self.conv(inputs)
        x = self.activation(x)
        x = self.batch_norm(x)
        return x

    def get_config(self):
        config = super().get_config()
        config["filters"] = self.filters
        config["kernel_size"] = self.kernel_size
        config["strides"] = self.strides
        return config

    def compute_output_shape(self, input_shape):
        return self.conv.compute_output_shape(input_shape)


class UpSampleLayer(tf.keras.layers.Layer):
    def __init__(self, filters, strides=2, **kwargs):
        super().__init__(**kwargs)
        self.dropout = tf.keras.layers.Dropout(0.5)
        self.activation = tf.keras.layers.LeakyReLU()
        self.upconv = tf.keras.layers.Conv2DTranspose(
            filters, 4, strides=strides, padding="same", kernel_initializer="he_normal"
        )
        self.batch_norm = tf.keras.layers.BatchNormalization()
        self.filters = filters
        self.strides = strides

    def call(self, inputs, **kwargs):
        x = self.upconv(inputs)
        x = self.batch_norm(x)
        x = self.dropout(x)
        return self.activation(x)

    def get_config(self):
        config = super().get_config()
        config["filters"] = self.filters
        config["strides"] = self.strides
        return config


class DownsampleBlock(tf.keras.layers.Layer):
    def __init__(self, filters, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters

        self.conv1 = Conv2dLayer(filters, 4)
        self.conv2 = Conv2dLayer(filters, 4)
        self.downsample_conv = Conv2dLayer(filters, 4, strides=2)
        self.dropout = tf.keras.layers.Dropout(0.5)

    def call(self, inputs, **kwargs):
        x = self.conv1(inputs)
        x = self.conv2(x)
        x = self.downsample_conv(x)
        x = self.dropout(x)
        return x

    def get_config(self):
        config = super().get_config()
        config["filters"] = self.filters
        return config


class Unet(tf.keras.models.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.mask = tf.keras.layers.Activation("relu")
        self.axis = -1
        self.downsample_blocks = []
        self.upsample_blocks = []

        n_maps_list = []

        for i in range(6):
            n_maps = 16 * 2 ** i
            n_maps_list.insert(0, n_maps)
            self.downsample_blocks.append(DownsampleBlock(n_maps))

        for i, n_maps in enumerate(n_maps_list[1:]):
            self.upsample_blocks.append(UpSampleLayer(n_maps, strides=2))
        self.upsample_blocks.append(UpSampleLayer(2, strides=2))

    def call(self, inputs, training=None, mask=None):
        skip_connections = []
        x = inputs
        for downsample_block in self.downsample_blocks:
            x = downsample_block(x)
            skip_connections.insert(0, x)

        x = self.upsample_blocks[0](x)  # no skip connection used for first block
        for upsample_block, h in zip(self.upsample_blocks[1:], skip_connections[1:]):
            x = upsample_block(tf.keras.layers.concatenate([x, h], axis=self.axis))
        return self.mask(x)


def train():
    BATCH_SIZE = 16
    WIDTH = 256
    HEIGHT = 512
    CHANNELS = 2
    policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
    tf.keras.mixed_precision.experimental.set_policy(policy)
    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
        model = Unet()
        model.build(input_shape=(None, WIDTH, HEIGHT, CHANNELS))
        model.compile(optimizer="adam", loss="mean_absolute_error")

    examples = np.random.rand(BATCH_SIZE * 20, WIDTH, HEIGHT, CHANNELS)
    target = np.random.rand(BATCH_SIZE * 20, WIDTH, HEIGHT, CHANNELS)

    ds = tf.data.Dataset.from_tensor_slices((examples, target))
    ds = ds.repeat()
    ds = ds.batch(BATCH_SIZE)
    model.fit(ds, steps_per_epoch=1875, epochs=10)


train()

The error is as follows:

2020-07-09 12:29:38.319163: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudart.so.10.1
2020-07-09 12:29:39.314182: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcuda.so.1                                                           
2020-07-09 12:29:39.348856: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1716] Found device 0 with properties:
pciBusID: 0000:0a:00.0 name: GeForce RTX 2080 Ti computeCapability: 7.5
coreClock: 1.635GHz coreCount: 68 deviceMemorySize: 10.76GiB deviceMemoryBandwidth: 573.69GiB/s
2020-07-09 12:29:39.349533: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1716] Found device 1 with properties:
pciBusID: 0000:42:00.0 name: GeForce RTX 2080 Ti computeCapability: 7.5
coreClock: 1.635GHz coreCount: 68 deviceMemorySize: 10.76GiB deviceMemoryBandwidth: 573.69GiB/s
2020-07-09 12:29:39.349555: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudart.so.10.1
2020-07-09 12:29:39.350837: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcublas.so.10
2020-07-09 12:29:39.351923: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcufft.so.10
2020-07-09 12:29:39.352101: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcurand.so.10
2020-07-09 12:29:39.353272: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcusolver.so.10
2020-07-09 12:29:39.353937: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcusparse.so.10
2020-07-09 12:29:39.356280: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudnn.so.7
2020-07-09 12:29:39.358938: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1858] Adding visible gpu devices: 0, 1
2020-07-09 12:29:39.359275: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2020-07-09 12:29:39.383097: I tensorflow/core/platform/profile_utils/cpu_utils.cc:104] CPU Frequency: 3499590000 Hz
2020-07-09 12:29:39.384148: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x55d32e968bc0 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2020-07-09 12:29:39.384186: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Host, Default Version
2020-07-09 12:29:39.566332: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1716] Found device 0 with properties:
pciBusID: 0000:0a:00.0 name: GeForce RTX 2080 Ti computeCapability: 7.5
coreClock: 1.635GHz coreCount: 68 deviceMemorySize: 10.76GiB deviceMemoryBandwidth: 573.69GiB/s
2020-07-09 12:29:39.566949: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1716] Found device 1 with properties:
pciBusID: 0000:42:00.0 name: GeForce RTX 2080 Ti computeCapability: 7.5
coreClock: 1.635GHz coreCount: 68 deviceMemorySize: 10.76GiB deviceMemoryBandwidth: 573.69GiB/s
2020-07-09 12:29:39.566982: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudart.so.10.1
2020-07-09 12:29:39.567005: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcublas.so.10
2020-07-09 12:29:39.567016: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcufft.so.10
2020-07-09 12:29:39.567027: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcurand.so.10
2020-07-09 12:29:39.567036: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcusolver.so.10
2020-07-09 12:29:39.567045: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcusparse.so.10
2020-07-09 12:29:39.567057: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudnn.so.7
2020-07-09 12:29:39.569326: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1858] Adding visible gpu devices: 0, 1
2020-07-09 12:29:39.569507: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudart.so.10.1
2020-07-09 12:29:40.324951: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1257] Device interconnect StreamExecutor with strength 1 edge matrix:
2020-07-09 12:29:40.324995: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1263]      0 1
2020-07-09 12:29:40.325002: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1276] 0:   N N
2020-07-09 12:29:40.325007: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1276] 1:   N N
2020-07-09 12:29:40.327546: W tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:39] Overriding allow_growth setting because the TF_FORCE_GPU_ALLOW_GROWTH environment variable is set. Original config value was 0.
2020-07-09 12:29:40.327582: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1402] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 10066 MB memory) -> physical GPU (device: 0, name: GeForce RTX 2080 Ti, pci bus id: 0000:0a:00.0, compute capability: 7.5)
2020-07-09 12:29:40.329152: W tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:39] Overriding allow_growth setting because the TF_FORCE_GPU_ALLOW_GROWTH environment variable is set. Original config value was 0.
2020-07-09 12:29:40.329169: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1402] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:1 with 10064 MB memory) -> physical GPU (device: 1, name: GeForce RTX 2080 Ti, pci bus id: 0000:42:00.0, compute capability: 7.5)
2020-07-09 12:29:40.330787: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x55d35311c4c0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2020-07-09 12:29:40.330807: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): GeForce RTX 2080 Ti, Compute Capability 7.5
2020-07-09 12:29:40.330814: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (1): GeForce RTX 2080 Ti, Compute Capability 7.5
Epoch 1/10
WARNING:tensorflow:From /home/lminer/anaconda3/envs/separate2/lib/python3.7/site-packages/tensorflow/python/data/ops/multi_device_iterator_ops.py:601: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Iterator.get_next_as_optional()` instead.
2020-07-09 12:30:03.105866: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudnn.so.7
2020-07-09 12:30:04.367518: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcublas.so.10
tensor1:73142:73386 [1] NCCL INFO Bootstrap : Using [0]enp8s0:10.10.2.159<0>
tensor1:73142:73386 [1] NCCL INFO NET/Plugin : No plugin found (libnccl-net.so), using internal implementation
tensor1:73142:73386 [1] NCCL INFO NET/IB : No device found.
tensor1:73142:73386 [1] NCCL INFO NET/Socket : Using [0]enp8s0:10.10.2.159<0>
tensor1:73142:73386 [1] NCCL INFO Using network Socket
NCCL version 2.7.3+cudaCUDA_MAJOR.CUDA_MINOR
tensor1:73142:73601 [0] NCCL INFO Channel 00/02 :    0   1
tensor1:73142:73601 [0] NCCL INFO Channel 01/02 :    0   1
tensor1:73142:73601 [0] NCCL INFO threadThresholds 8/8/64 | 16/8/64 | 8/8/64
tensor1:73142:73601 [0] NCCL INFO Trees [0] 1/-1/-1->0->-1|-1->0->1/-1/-1 [1] 1/-1/-1->0->-1|-1->0->1/-1/-1
tensor1:73142:73601 [0] NCCL INFO Setting affinity for GPU 0 to ffff
tensor1:73142:73602 [1] NCCL INFO threadThresholds 8/8/64 | 16/8/64 | 8/8/64
tensor1:73142:73602 [1] NCCL INFO Trees [0] -1/-1/-1->1->0|0->1->-1/-1/-1 [1] -1/-1/-1->1->0|0->1->-1/-1/-1
tensor1:73142:73602 [1] NCCL INFO Setting affinity for GPU 1 to ffff0000
tensor1:73142:73602 [1] NCCL INFO Could not enable P2P between dev 1(=42000) and dev 0(=a000)
tensor1:73142:73601 [0] NCCL INFO Could not enable P2P between dev 0(=a000) and dev 1(=42000)
tensor1:73142:73601 [0] NCCL INFO Could not enable P2P between dev 0(=a000) and dev 1(=42000)
tensor1:73142:73602 [1] NCCL INFO Could not enable P2P between dev 1(=42000) and dev 0(=a000)
tensor1:73142:73601 [0] NCCL INFO Channel 00 : 0[a000] -> 1[42000] via direct shared memory
tensor1:73142:73602 [1] NCCL INFO Channel 00 : 1[42000] -> 0[a000] via direct shared memory
tensor1:73142:73601 [0] NCCL INFO Could not enable P2P between dev 0(=a000) and dev 1(=42000)
tensor1:73142:73602 [1] NCCL INFO Could not enable P2P between dev 1(=42000) and dev 0(=a000)
tensor1:73142:73601 [0] NCCL INFO Could not enable P2P between dev 0(=a000) and dev 1(=42000)
tensor1:73142:73602 [1] NCCL INFO Could not enable P2P between dev 1(=42000) and dev 0(=a000)
tensor1:73142:73601 [0] NCCL INFO Channel 01 : 0[a000] -> 1[42000] via direct shared memory
tensor1:73142:73602 [1] NCCL INFO Channel 01 : 1[42000] -> 0[a000] via direct shared memory
tensor1:73142:73601 [0] NCCL INFO 2 coll channels, 2 p2p channels, 2 p2p channels per peer
tensor1:73142:73601 [0] NCCL INFO comm 0x55d409f64000 rank 0 nranks 2 cudaDev 0 busId a000 - Init COMPLETE
tensor1:73142:73602 [1] NCCL INFO 2 coll channels, 2 p2p channels, 2 p2p channels per peer
tensor1:73142:73602 [1] NCCL INFO comm 0x55d3fee0e000 rank 1 nranks 2 cudaDev 1 busId 42000 - Init COMPLETE
tensor1:73142:73598 [0] NCCL INFO Launch mode Group/CGMD
 256/1875 [===>..........................] - ETA: 2:30 - loss: 0.4589
2020-07-09 12:30:29.803542: E tensorflow/stream_executor/cuda/cuda_event.cc:29] Error polling for event status: failed to query event: CUDA_ERROR_ILLEGAL_ADDRESS: an illegal memory access was encountered
2020-07-09 12:30:29.803584: F tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc:220] Unexpected Event status: 1
2020-07-09 12:30:29.803545: E tensorflow/stream_executor/cuda/cuda_event.cc:29] Error polling for event status: failed to query event: CUDA_ERROR_ILLEGAL_ADDRESS: an illegal memory access was encountered
2020-07-09 12:30:29.803665: F tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc:220] Unexpected Event status: 1
./fail.sh: line 7: 73142 Aborted                 (core dumped) NCCL_DEBUG=INFO TF_FORCE_GPU_ALLOW_GROWTH=true LD_LIBRARY_PATH=/usr/local/cuda-10.1/extras/CUPTI/lib64/:$LD_LIBRARY_PATH TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD=99999999999999999999999999999999 LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc.so.4 /home/lminer/anaconda3/envs/separate2/bin/python fail.py

@gowthamkpr gowthamkpr assigned sanjoy and unassigned gowthamkpr Jul 10, 2020
@gowthamkpr gowthamkpr added stat:awaiting tensorflower Status - Awaiting response from tensorflower and removed stat:awaiting response Status - Awaiting response from author labels Jul 10, 2020
@sanjoy
Contributor

sanjoy commented Jul 11, 2020

Unfortunately I was not able to reproduce this on a P100 or Titan V. Can you try running with CUDA_LAUNCH_BLOCKING=1? That will tell us which kernel caused the crash.
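
For reference, a minimal sketch of one way to set this flag, assuming it just needs to be in the process environment before TensorFlow initializes CUDA (setting it in Python before the tensorflow import, or exporting it in the shell, should both work):

import os

# CUDA_LAUNCH_BLOCKING must be set before the CUDA context is created,
# so set it before importing tensorflow.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"  # serialize kernel launches so the failing kernel is reported synchronously

import tensorflow as tf  # imported after the env var so CUDA initialization picks it up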

@lminer
Author

lminer commented Jul 12, 2020

@sanjoy when I run it with that option, the model is loaded into the memory of both GPUs, but only one GPU actually sees any utilization and there is no crash.

@tensorflowbutler tensorflowbutler removed the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Jul 14, 2020
@sanjoy
Contributor

sanjoy commented Jul 21, 2020

@dubey Have you seen similar issues before?

@dolhasz

dolhasz commented Jul 24, 2020

I get the same error when running multi-GPU training with 2 or 3 RTX 2080 Tis. My code is very similar to yours, except that I do not use mixed precision.

@dubey
Member

dubey commented Jul 24, 2020

@sanjoy No I haven't seen this issue before.

@dolhasz

dolhasz commented Jul 25, 2020

Ok guys, I think I've found a solution, which seems to work for me.

I followed the instructions here: https://github.com/NVIDIA/framework-determinism - I enabled os.environ['TF_CUDNN_DETERMINISTIC']='1'

Then I fixed all the random seeds:

random.seed(42)
np.random.seed(42)
tf.random.set_seed(42)

The model has been running without a hitch for many epochs now. It seems that the non-determinism of some operations might be causing these multi-GPU issues. Keep in mind, I don't fully understand WHY this works; I just know that it does work for a similar problem. Do let me know if this helps.

Also, keep in mind the instructions here: https://github.com/NVIDIA/framework-determinism are a bit different from the ones I originally used (here: https://stackoverflow.com/questions/50744565/how-to-handle-non-determinism-when-training-on-a-gpu/62712389#62712389). It might be worth trying both sets.
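
Putting the pieces together, a minimal sketch of the workaround described above (assuming TF_CUDNN_DETERMINISTIC is read when TensorFlow initializes, so it is set before the import; the seed value 42 is arbitrary):

import os

# Request deterministic cuDNN kernels before TensorFlow is imported/initialized
os.environ["TF_CUDNN_DETERMINISTIC"] = "1"

import random
import numpy as np
import tensorflow as tf

# Fix all random seeds (42 is an arbitrary choice)
random.seed(42)
np.random.seed(42)
tf.random.set_seed(42)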

@Kevin0624

Kevin0624 commented Aug 31, 2020

I get the same error when running multi-GPU mixed precision training with 2 RTX 2080 Ti cards. Any solutions?

@ben0it8

ben0it8 commented Aug 31, 2020

@dolhasz can you confirm that enforcing determinism indeed solved your issue? Also, which layer/op do you suspect of causing the error?

@xiumingzhang

I have a very similar setup to @lminer's, and Luke's solution of NOT setting TF_FORCE_GPU_ALLOW_GROWTH works for me too. Thanks, @lminer!

@wenxichen

(Quoting @dolhasz's workaround above: enable os.environ['TF_CUDNN_DETERMINISTIC']='1' and fix all random seeds.)

This works for me. Thanks @dolhasz!
My setup is CUDA 10.1, TensorFlow 2.4.1, and two GTX 1080 Ti cards. Previously, I ran into this error whenever I used my newly added GPU to train. Interestingly, the newly added GPU has GPU id 0.

@rajprasad001

I faced a very similar issue.
I was loading a pretrained TensorFlow model on every function call.
I'm not certain, but for me that appeared to be the problem, as the GPU was running out of memory.

Loading the model just once solved the issue.
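
A minimal sketch of the load-once pattern described above; the path and function names are hypothetical placeholders, not from this thread:

import tensorflow as tf

# Load the pretrained model once at module import time instead of inside the handler,
# so repeated calls reuse the same weights (and GPU memory) rather than reloading them.
# "saved_model_dir" is a hypothetical placeholder path.
MODEL = tf.keras.models.load_model("saved_model_dir")

def predict(batch):
    # Reuse the already-loaded model on every call
    return MODEL(batch, training=False)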

@tilakrayal
Contributor

@lminer,
Can you please try this with the latest stable version, v2.7, and let us know if the issue still persists? Thanks!
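
For anyone retrying on newer releases: the mixed-precision API used in the repro script was promoted out of experimental around TF 2.4, so the setup on 2.7 would look roughly like the sketch below (a sketch only, reusing the Unet from the repro script above):

import tensorflow as tf

# TF 2.4+ replaces the experimental policy calls with a single global setter
tf.keras.mixed_precision.set_global_policy("mixed_float16")

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    model = Unet()  # same Unet as in the repro script above
    model.build(input_shape=(None, 256, 512, 2))
    model.compile(optimizer="adam", loss="mean_absolute_error")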

@tilakrayal tilakrayal self-assigned this Dec 8, 2021
@tilakrayal tilakrayal added the stat:awaiting response Status - Awaiting response from author label Dec 8, 2021
@google-ml-butler

This issue has been automatically marked as stale because it has no recent activity. It will be closed if no further activity occurs. Thank you.

@google-ml-butler google-ml-butler bot added the stale This label marks the issue/pr stale - to be closed automatically if no activity label Dec 15, 2021
@google-ml-butler

Closing as stale. Please reopen if you'd like to work on this further.
