
Incorrect gradient for ctc_loss on GPU when using logit_length #41280

Closed
pvanhaes opened this issue Jul 10, 2020 · 21 comments
Labels: comp:gpu (GPU related issues), comp:ops (OPs related issues), stale, stat:awaiting response (awaiting response from author), TF 2.2, type:bug

Comments

@pvanhaes

pvanhaes commented Jul 10, 2020

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Debian 9.12 (TF2.2 DeepLearning image on GCP)
  • Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device: N/A
  • TensorFlow installed from (source or binary): Preinstalled
  • TensorFlow version (use command below): v2.2.0-0-g2b96f36 2.2.0-dlenv
  • Python version: 3.7.6
  • Bazel version (if compiling from source):
  • GCC/Compiler version (if compiling from source):
  • CUDA/cuDNN version: V10.1.243
  • GPU model and memory: Nvidia tesla P100

Describe the current behavior

I have experienced inconsistencies in the gradient of tf.nn.ctc_loss between the CPU and GPU implementations when the logit_length argument contains anything other than [num_frames]*batch_size.
Specifically, the gradient with respect to logits from the GPU implementation does not contain zeros after the end of each sequence as given by logit_length, whereas the CPU implementation does zero those positions and appears to work correctly.

I have noticed that the unit tests for this op do not test this case in particular (see https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/kernel_tests/ctc_loss_op_test.py#L993).

Standalone code to reproduce the issue

import tensorflow as tf

use_logits_lengths = True

batch_size = 8
num_labels = 27
max_labels_length = 32
max_logits_length = 128

labels = []
labels_lengths = []
logits = []
logits_lengths = []
for i in range(batch_size):
    labels_lengths.append(tf.random.uniform([], 1, max_labels_length, tf.int32))
    labels.extend(tf.random.uniform([labels_lengths[-1]], 0, num_labels-1, tf.int32))

    # I multiply label_length by 2 to make sure there are enough frames
    logits_lengths.append(tf.random.uniform([], labels_lengths[-1].numpy()*2, max_logits_length+1, tf.int32))

labels = tf.RaggedTensor.from_row_lengths(labels, labels_lengths).to_sparse()
labels_lengths = tf.concat(labels_lengths, 0)
logits = tf.random.uniform([batch_size, max_logits_length, num_labels])
logits_lengths = tf.concat(logits_lengths, 0)
logits_lengths_full = tf.constant([max_logits_length]*batch_size)

def ctc_compare_cpu_gpu(logits_lengths):
    print("logits_lengths", logits_lengths.numpy())

    with tf.device("/gpu:0"):
        with tf.GradientTape() as t:
            t.watch(logits)
            gpu_loss = tf.nn.ctc_loss(labels, logits, labels_lengths, logits_lengths, logits_time_major=False, blank_index=-1)
        gpu_grad = t.gradient(gpu_loss, [logits])[0]

    with tf.device("/cpu:0"):
        with tf.GradientTape() as t:
            t.watch(logits)
            cpu_loss = tf.nn.ctc_loss(labels, logits, labels_lengths, logits_lengths, logits_time_major=False, blank_index=-1)
        cpu_grad = t.gradient(cpu_loss, [logits])[0]

    print("Max loss error", tf.math.abs(gpu_loss - cpu_loss).numpy().max())
    print("Max grad error", tf.math.abs(gpu_grad - cpu_grad).numpy().max())
    print()
    return cpu_loss, gpu_loss, cpu_grad, gpu_grad

ctc_compare_cpu_gpu(logits_lengths_full)
ctc_compare_cpu_gpu(logits_lengths)

Output:

logits_lengths [128 128 128 128 128 128 128 128]
Max loss error 0.00012207031
Max grad error 0.00014734268

logits_lengths [ 70  86  22  74 112 121 103 123]
Max loss error 6.1035156e-05
Max grad error 0.9669469
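
For completeness, here is a quick check of the zero tail (a sketch reusing the tensors returned by ctc_compare_cpu_gpu above; on CPU the tail maximum should be 0.0, while on GPU it is not):

_, _, cpu_grad, gpu_grad = ctc_compare_cpu_gpu(logits_lengths)
# True for frames at or past each sequence's logit_length
tail = tf.logical_not(tf.sequence_mask(logits_lengths, max_logits_length))
print("CPU max |grad| past logit_length:", tf.reduce_max(tf.abs(tf.boolean_mask(cpu_grad, tail))).numpy())
print("GPU max |grad| past logit_length:", tf.reduce_max(tf.abs(tf.boolean_mask(gpu_grad, tail))).numpy())
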
@pvanhaes pvanhaes added the type:bug Bug label Jul 10, 2020
@Saduf2019 Saduf2019 added TF 2.2 Issues related to TF 2.2 comp:gpu GPU related issues labels Jul 10, 2020
@Saduf2019
Contributor

@pvanhaes
Please let us know if this gist confirms your issue.

@Saduf2019 Saduf2019 added the stat:awaiting response Status - Awaiting response from author label Jul 10, 2020
@pvanhaes
Author

@Saduf2019
Indeed, this displays the same issue.

@tensorflowbutler tensorflowbutler removed the stat:awaiting response Status - Awaiting response from author label Jul 12, 2020
@jvishnuvardhan jvishnuvardhan added the comp:ops OPs related issues label Jul 13, 2020
@jvishnuvardhan jvishnuvardhan added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Nov 25, 2020
@sanjoy
Contributor

sanjoy commented Apr 14, 2021

@kaixih I believe you added the CTC loss GPU implementation in #32302, can you please take a look to see if this issue is related?

@rmothukuru rmothukuru added stat:awaiting response Status - Awaiting response from author and removed stat:awaiting tensorflower Status - Awaiting response from tensorflower labels May 4, 2021
@rmothukuru rmothukuru assigned sanjoy and rmothukuru and unassigned kaixih and jvishnuvardhan May 4, 2021
@rmothukuru
Contributor

@pvanhaes,
Can you please respond to the above comment? Thanks!

@pvanhaes
Author

pvanhaes commented May 4, 2021

As far as I can tell, yes: it seems to be related, since the problem only occurs when using the cuDNN CTC loss implementation.
I also looked at PyTorch to see whether it has the same issue. It has test cases where logit_length is not uniform, but they are disabled due to flakiness; its gradients nevertheless appear to be correct.

@rmothukuru rmothukuru removed their assignment May 4, 2021
@rmothukuru rmothukuru added stat:awaiting tensorflower Status - Awaiting response from tensorflower and removed stat:awaiting response Status - Awaiting response from author labels May 4, 2021
@kaixih
Contributor

kaixih commented May 4, 2021

I tried the script on a V100 a couple of times and I can see the flakiness:
Run 1:

Max loss error 0.00021362305
Max grad error 0.00022548437

logits_lengths [  8 108  86  90  66  53 110  97]
Max loss error 9.1552734e-05
Max grad error 0.00022548437

Run X:

Max loss error 9.1552734e-05
Max grad error 9.518862e-05

logits_lengths [ 61  88  79  14  42 112  95  60]
Max loss error 6.1035156e-05
Max grad error 0.55553436

Looking into it.

@kaixih
Contributor

kaixih commented May 8, 2021

I think the issue is that cuDNN doesn't zero out the gradients beyond the sequence length. So, if the gradient buffer happens to contain some large values, you will see the reported high error. One workaround is to apply the mask explicitly, as in the script below, where I use tf.sequence_mask() to build a mask from logits_lengths and then zero out the unwanted gradients.

I will also file a bug with our cuDNN team to fix this issue.

import tensorflow as tf

tf.random.set_seed(1)
use_logits_lengths = True

batch_size = 8
num_labels = 27
max_labels_length = 32
max_logits_length = 128
#batch_size = 4
#num_labels = 6
#max_labels_length = 32
#max_logits_length = 64

labels = []
labels_lengths = []
logits = []
logits_lengths = []
for i in range(batch_size):
    labels_lengths.append(tf.random.uniform([], 1, max_labels_length, tf.int32))
    labels.extend(tf.random.uniform([labels_lengths[-1]], 0, num_labels-1, tf.int32))

    # I multiply label_length by 2 to make sure there are enough frames
    logits_lengths.append(tf.random.uniform([], labels_lengths[-1].numpy()*2, max_logits_length+1, tf.int32))

labels = tf.RaggedTensor.from_row_lengths(labels, labels_lengths).to_sparse()
labels_lengths = tf.concat(labels_lengths, 0)

logits_lengths = tf.concat(logits_lengths, 0)
logits_lengths_full = tf.constant([max_logits_length]*batch_size)

logits = tf.random.uniform([batch_size, max_logits_length, num_labels])

logit_mask = tf.sequence_mask(logits_lengths, max_logits_length,
                              tf.dtypes.float32)
logit_mask = tf.expand_dims(logit_mask, axis=2)
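# logit_mask now has shape [batch, time, 1]: 1.0 for frames within logits_lengths,
# 0.0 past the end; it is used below to zero out the gradient tail that cuDNN
# leaves uninitialized.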
#print("XXX", logit_mask)

def ctc_compare_cpu_gpu(logits_lengths, mask=None):

    print("logits_lengths", logits_lengths.numpy())
    print("labels_lengths", labels_lengths.numpy())

    with tf.device("/gpu:0"):
        with tf.GradientTape() as t:
            t.watch(logits)
            gpu_loss = tf.nn.ctc_loss(labels, logits, labels_lengths, logits_lengths, logits_time_major=False, blank_index=-1)
        gpu_grad = t.gradient(gpu_loss, [logits])[0]
        if mask is not None:
          gpu_grad = gpu_grad * mask

    with tf.device("/cpu:0"):
        with tf.GradientTape() as t:
            t.watch(logits)
            cpu_loss = tf.nn.ctc_loss(labels, logits, labels_lengths, logits_lengths, logits_time_major=False, blank_index=-1)
        cpu_grad = t.gradient(cpu_loss, [logits])[0]

    print("Max loss error", tf.math.abs(gpu_loss - cpu_loss).numpy().max())
    print("Max grad error", tf.math.abs(gpu_grad - cpu_grad).numpy().max())
    return cpu_loss, gpu_loss, cpu_grad, gpu_grad

ctc_compare_cpu_gpu(logits_lengths_full)

ctc_compare_cpu_gpu(logits_lengths, mask=logit_mask)
#ctc_compare_cpu_gpu(logits_lengths)

@pvanhaes
Author

pvanhaes commented May 8, 2021

Hi @kaixih
Indeed, I remember trying that when I first discovered this bug. However, I still ended up using the CPU loss because training was very unstable; I wasn't able to explain why, but I frequently ran into NaN gradients and/or loss.

@kaixih
Contributor

kaixih commented May 10, 2021

I'm not sure the NaN issue is caused by those "unused and not correctly initialized" gradients output by the cuDNN backend. Do you mean you still hit the NaN issue even after manually applying the mask to the gradients returned by the CTC loss call on GPU?

@pvanhaes
Author

What happened was that even when masking the gradients (using tf.where rather than multiplying by zeros), I was unable to fully train a model and always ended up with NaN weights. Maybe the gradients were just very large rather than NaN, but in any case the behavior was different from the CPU loss.
I'll see if I get the chance to try that again soon.
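
For reference, this is roughly what I mean by tf.where-based masking (a sketch; mask_grad is just an illustrative name):

def mask_grad(grad, logits_lengths, max_logits_length):
    # True for frames before the end of each sequence, shape [batch, time, 1]
    keep = tf.expand_dims(tf.sequence_mask(logits_lengths, max_logits_length), axis=2)
    # tf.where replaces the out-of-range entries outright, so a NaN/Inf past the
    # end is dropped rather than propagated (whereas NaN * 0 would still be NaN)
    return tf.where(keep, grad, tf.zeros_like(grad))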

@kaixih
Contributor

kaixih commented May 10, 2021

Thanks. It sounds like a numeric precision issue. If possible, could you also give it a shot with TF_CUDNN_DETERMINISTIC=1, which forces TF to use a deterministic CTC algorithm (note that this requires the label size to be under 256). By default, TF uses a non-deterministic algorithm.
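
For example, one way to set it from Python (it has to be set before TensorFlow initializes the GPU kernels, so do it before the import; exporting the variable in the shell before launching works as well):

import os
os.environ["TF_CUDNN_DETERMINISTIC"] = "1"  # request the deterministic cuDNN CTC algorithm

import tensorflow as tf  # import only after the variable is set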

@AveryLiu

AveryLiu commented May 24, 2021

Hi, I encountered the same problem: the CPU version of ctc_loss works fine, but the GPU version gives NaN. I've tried setting TF_CUDNN_DETERMINISTIC=1 with batch_size=512, but the issue persists. However, setting batch_size to 128 or 256 seems to fix it.

@kaixih
Contributor

kaixih commented May 24, 2021

Hi @AveryLiu, that is interesting. The deterministic algorithm shouldn't depend on the batch size. Are you working with variable sequence lengths or fixed sequence lengths? If they are variable, it may just be a fluke that batch sizes of 128 or 256 appear to work.

@AveryLiu

Hi @AveryLiu, that is interesting. The deterministic algorithm shouldn't depend on the batch size. Are you working with variable sequence lengths or fixed sequence lengths? If they are variable, it may just be a fluke that batch sizes of 128 or 256 appear to work.

I'm using variable sequence lengths. Indeed, setting batch_size to 128 or 256 does not solve the problem: it does not give me NaNs, but the gradients seem incorrect and the loss stagnates (the CPU version is fine). I am not sure how to apply the gradient masking mentioned above, since I am using a fully encapsulated Keras model. Another thing I observed is that the model can be trained on data where every label length is 1.
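
One possible approach (a sketch with hypothetical names that I have not verified) might be to wrap tf.nn.ctc_loss in tf.custom_gradient, so the mask is applied inside the loss itself and the rest of the Keras model is untouched:

import tensorflow as tf

def masked_ctc_loss(labels, logits, label_length, logit_length):
    """Hypothetical wrapper: CTC loss whose gradient is zeroed past logit_length."""
    @tf.custom_gradient
    def _loss(logits):
        with tf.GradientTape() as tape:
            tape.watch(logits)
            loss = tf.nn.ctc_loss(labels, logits, label_length, logit_length,
                                  logits_time_major=False, blank_index=-1)
        grad = tape.gradient(loss, logits)                        # [batch, time, labels]
        keep = tf.expand_dims(tf.sequence_mask(logit_length, tf.shape(logits)[1]), 2)

        def grad_fn(upstream):                                    # upstream: [batch]
            masked = tf.where(keep, grad, tf.zeros_like(grad))
            return masked * tf.reshape(upstream, [-1, 1, 1])

        return loss, grad_fn

    return _loss(logits)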

@sushreebarsa
Contributor

I was able to replicate the issue in TF v2.5; please find the gist here. Thanks!

@gadagashwini
Contributor

Hi @pvanhaes, I tried to replicate the issue with the tf-nightly 2.11.0-dev20220801 build. I am getting a different error; could you confirm whether the issue persists? Please find the gist. Thank you!

@gadagashwini gadagashwini added stat:awaiting response Status - Awaiting response from author and removed stat:awaiting tensorflower Status - Awaiting response from tensorflower labels Aug 2, 2022
@google-ml-butler

This issue has been automatically marked as stale because it has no recent activity. It will be closed if no further activity occurs. Thank you.

@google-ml-butler google-ml-butler bot added the stale This label marks the issue/pr stale - to be closed automatically if no activity label Aug 9, 2022
@google-ml-butler

Closing as stale. Please reopen if you'd like to work on this further.


@gadagashwini
Contributor

Hi @pvanhaes, you need to set GPU memory growth to True. The code snippet below worked.

import tensorflow as tf
config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True
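# Note: under TF2 eager execution this ConfigProto is only applied if it is passed
# to a tf.compat.v1.Session; the TF2-native way to enable memory growth is
# tf.config.experimental.set_memory_growth(gpu, True).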
use_logits_lengths = True

batch_size = 8
num_labels = 27
max_labels_length = 32
max_logits_length = 128

labels = []
labels_lengths = []
logits = []
logits_lengths = []
for i in range(batch_size):
    labels_lengths.append(tf.random.uniform([], 1, max_labels_length, tf.int32))
    labels.extend(tf.random.uniform([labels_lengths[-1]], 0, num_labels-1, tf.int32))

    # I multiply label_length by 2 to make sure there are enough frames
    logits_lengths.append(tf.random.uniform([], labels_lengths[-1].numpy()*2, max_logits_length+1, tf.int32))

labels = tf.RaggedTensor.from_row_lengths(labels, labels_lengths).to_sparse()
labels_lengths = tf.concat(labels_lengths, 0)
logits = tf.random.uniform([batch_size, max_logits_length, num_labels])
logits_lengths = tf.concat(logits_lengths, 0)
logits_lengths_full = tf.constant([max_logits_length]*batch_size)

def ctc_compare_cpu_gpu(logits_lengths):
    print("logits_lengths", logits_lengths.numpy())

    with tf.device("/gpu:0"):
        with tf.GradientTape() as t:
            t.watch(logits)
            gpu_loss = tf.nn.ctc_loss(labels, logits, labels_lengths, logits_lengths, logits_time_major=False, blank_index=-1)
        gpu_grad = t.gradient(gpu_loss, [logits])[0]

    with tf.device("/cpu:0"):
        with tf.GradientTape() as t:
            t.watch(logits)
            cpu_loss = tf.nn.ctc_loss(labels, logits, labels_lengths, logits_lengths, logits_time_major=False, blank_index=-1)
        cpu_grad = t.gradient(cpu_loss, [logits])[0]

    print("Max loss error", tf.math.abs(gpu_loss - cpu_loss).numpy().max())
    print("Max grad error", tf.math.abs(gpu_grad - cpu_grad).numpy().max())
    print()
    return cpu_loss, gpu_loss, cpu_grad, gpu_grad

ctc_compare_cpu_gpu(logits_lengths_full)
ctc_compare_cpu_gpu(logits_lengths)

Output

logits_lengths [128 128 128 128 128 128 128 128]
Max loss error 9.1552734e-05
Max grad error 9.608269e-05

logits_lengths [126  94  77  82 108 109  86 116]
Max loss error 6.1035156e-05
Max grad error 0.89445907

(<tf.Tensor: shape=(8,), dtype=float32, numpy=
 array([361.34167, 247.30183, 191.76134, 240.63629, 264.55078, 274.72452,
        223.27164, 348.48816], dtype=float32)>,
 <tf.Tensor: shape=(8,), dtype=float32, numpy=
 array([361.34167, 247.3019 , 191.76135, 240.63626, 264.5508 , 274.7245 ,
        223.27162, 348.48813], dtype=float32)>,
 <tf.Tensor: shape=(8, 128, 27), dtype=float32, numpy=
 array([[[ 0.02538534,  0.05628331,  0.06113573, ...,  0.02440587,
           0.03686846, -0.74192816],
         [ 0.02119422,  0.0558586 ,  0.03859304, ...,  0.04988675,
           0.02495128, -0.57147044],
         [-0.01486497,  0.02755027,  0.04740864, ...,  0.05236494,
           0.02235019, -0.43538034],
         ...,
         [ 0.03877452,  0.05716185,  0.02642595, ...,  0.04564801,
           0.03043702, -0.8778108 ],
         [ 0.        ,  0.        ,  0.        , ...,  0.        ,
           0.        ,  0.        ],
         [ 0.        ,  0.        ,  0.        , ...,  0.        ,
           0.        ,  0.        ]],
 
        [[ 0.02517861,  0.02451471,  0.02614532, ...,  0.02736092,
           0.04163093, -0.60451835],
         [ 0.0282798 ,  0.03323127,  0.02769315, ...,  0.02504516,
           0.04138051, -0.4912519 ],
         [ 0.0264154 ,  0.05226399,  0.04047335, ...,  0.02718884,
           0.02691971, -0.34927228],
         ...,
         [ 0.        ,  0.        ,  0.        , ...,  0.        ,
           0.        ,  0.        ],
         [ 0.        ,  0.        ,  0.        , ...,  0.        ,
           0.        ,  0.        ],
         [ 0.        ,  0.        ,  0.        , ...,  0.        ,
           0.        ,  0.        ]],
 
        [[ 0.02442619, -0.46186063,  0.02670945, ...,  0.05209699,
           0.0565046 , -0.4266413 ],
         [ 0.04467556, -0.42586833,  0.05540843, ...,  0.04710019,
           0.05417868, -0.33000907],
         [ 0.05996322, -0.33492774,  0.04255395, ...,  0.05255849,
           0.04373875, -0.28247663],
         ...,
         [ 0.        ,  0.        ,  0.        , ...,  0.        ,
           0.        ,  0.        ],
         [ 0.        ,  0.        ,  0.        , ...,  0.        ,
           0.        ,  0.        ],
         [ 0.        ,  0.        ,  0.        , ...,  0.        ,
           0.        ,  0.        ]],
 
        ...,
 
        [[ 0.03733097,  0.04586417,  0.05981163, ...,  0.05459556,
           0.03668549, -0.6331417 ],
         [ 0.03733912,  0.0331376 ,  0.02435465, ...,  0.05194985,
           0.03071492, -0.45605576],
         [ 0.039736  ,  0.03019123,  0.04211826, ...,  0.04825637,
           0.03321901, -0.3475042 ],
         ...,
         [ 0.        ,  0.        ,  0.        , ...,  0.        ,
           0.        ,  0.        ],
         [ 0.        ,  0.        ,  0.        , ...,  0.        ,
           0.        ,  0.        ],
         [ 0.        ,  0.        ,  0.        , ...,  0.        ,
           0.        ,  0.        ]],
 
        [[ 0.02234682,  0.03743339,  0.02483355, ...,  0.04442452,
           0.05282348, -0.6971684 ],
         [ 0.03536738,  0.05733758,  0.05181414, ...,  0.05734098,
           0.03816937, -0.53875005],
         [ 0.04522794,  0.03714109,  0.04891694, ...,  0.04756849,
           0.02380415, -0.29162395],
         ...,
         [ 0.        ,  0.        ,  0.        , ...,  0.        ,
           0.        ,  0.        ],
         [ 0.        ,  0.        ,  0.        , ...,  0.        ,
           0.        ,  0.        ],
         [ 0.        ,  0.        ,  0.        , ...,  0.        ,
           0.        ,  0.        ]],
 
        [[-0.04192416,  0.0368094 ,  0.03975597, ...,  0.03538512,
           0.02966222, -0.85794085],
         [-0.12186902,  0.031844  ,  0.05078887, ...,  0.03326574,
           0.03228034, -0.7990319 ],
         [-0.20600453,  0.06134344,  0.02815821, ...,  0.0276357 ,
           0.0273573 , -0.7342564 ],
         ...,
         [ 0.        ,  0.        ,  0.        , ...,  0.        ,
           0.        ,  0.        ],
         [ 0.        ,  0.        ,  0.        , ...,  0.        ,
           0.        ,  0.        ],
         [ 0.        ,  0.        ,  0.        , ...,  0.        ,
           0.        ,  0.        ]]], dtype=float32)>,
 <tf.Tensor: shape=(8, 128, 27), dtype=float32, numpy=
 array([[[ 0.02538534,  0.05628331,  0.06113573, ...,  0.02440587,
           0.03686845, -0.7419282 ],
         [ 0.02119419,  0.0558586 ,  0.03859304, ...,  0.04988675,
           0.02495128, -0.5714753 ],
         [-0.01486521,  0.02755027,  0.04740864, ...,  0.05236493,
           0.02235018, -0.43537524],
         ...,
         [ 0.03877452,  0.05716185,  0.02642595, ...,  0.04564801,
           0.03043702, -0.8777982 ],
         [ 0.04676461,  0.05559509,  0.03807037, ...,  0.04894428,
           0.02546765,  0.03845723],
         [ 0.04474259,  0.04894428,  0.02546765, ...,  0.        ,
           0.        ,  0.04527081]],
 
        [[ 0.02517861,  0.02451471,  0.02614532, ...,  0.02736092,
           0.04163093, -0.60455686],
         [ 0.02827979,  0.03323127,  0.02769314, ...,  0.02504516,
           0.04138051, -0.49128848],
         [ 0.0264154 ,  0.05226399,  0.04047334, ...,  0.02718884,
           0.02691971, -0.3492974 ],
         ...,
         [ 0.03921108,  0.04774981,  0.03753965, ...,  0.02316646,
           0.02603946,  0.04136527],
         [ 0.02690112,  0.02592757,  0.04045248, ...,  0.02984707,
           0.02707733,  0.02219072],
         [ 0.        ,  0.        ,  0.        , ...,  0.        ,
           0.        ,  0.        ]],
 
        [[ 0.02442619, -0.46186855,  0.02670945, ...,  0.05209699,
           0.0565046 , -0.42664865],
         [ 0.04467556, -0.42587164,  0.05540843, ...,  0.04710019,
           0.05417868, -0.33001226],
         [ 0.05996322, -0.33493617,  0.04255395, ...,  0.05255849,
           0.04373875, -0.28247923],
         ...,
         [ 0.04819704,  0.03707863,  0.0289698 , ...,  0.04979009,
           0.03335232,  0.04069154],
         [ 0.03642191,  0.0523469 ,  0.04264299, ...,  0.0456151 ,
           0.03836384,  0.05885974],
         [ 0.        ,  0.        ,  0.        , ...,  0.        ,
           0.        ,  0.        ]],
 
        ...,
 
        [[ 0.03733097,  0.04586417,  0.05981163, ...,  0.05459555,
           0.03668549, -0.63312095],
         [ 0.03733912,  0.0331376 ,  0.02435465, ...,  0.05194985,
           0.03071492, -0.4560436 ],
         [ 0.039736  ,  0.03019122,  0.04211826, ...,  0.04825637,
           0.03321901, -0.3474995 ],
         ...,
         [ 0.04013193,  0.03555685,  0.03492175, ...,  0.02668005,
           0.05200338,  0.0318239 ],
         [ 0.03211536,  0.04654953,  0.04517095, ...,  0.03997811,
           0.03028551,  0.04155886],
         [ 0.        ,  0.        ,  0.        , ...,  0.        ,
           0.        ,  0.        ]],
 
        [[ 0.02234682,  0.03743339,  0.02483355, ...,  0.04442453,
           0.05282348, -0.6971574 ],
         [ 0.03536739,  0.05733759,  0.05181414, ...,  0.05734098,
           0.03816937, -0.53873646],
         [ 0.04522794,  0.03714109,  0.04891693, ...,  0.04756848,
           0.02380415, -0.29161617],
         ...,
         [ 0.04204896,  0.03710284,  0.04811438, ...,  0.05168063,
           0.0527318 ,  0.04406216],
         [ 0.03108212,  0.04031079,  0.04655052, ...,  0.03853708,
           0.05555028,  0.03550655],
         [ 0.        ,  0.        ,  0.        , ...,  0.        ,
           0.        ,  0.        ]],
 
        [[-0.0419212 ,  0.03680939,  0.03975596, ...,  0.03538512,
           0.02966222, -0.8579134 ],
         [-0.12186177,  0.031844  ,  0.05078886, ...,  0.0332658 ,
           0.03228033, -0.7990284 ],
         [-0.20600203,  0.06134345,  0.02815821, ...,  0.02763586,
           0.02735731, -0.734236  ],
         ...,
         [ 0.04372887,  0.03230944,  0.03206547, ...,  0.0366188 ,
           0.03010188,  0.02637064],
         [ 0.04893318,  0.0248911 ,  0.04472371, ...,  0.05552451,
           0.03580674,  0.03061763],
         [ 0.        ,  0.        ,  0.        , ...,  0.        ,
           0.        ,  0.        ]]], dtype=float32)>)

@Mddct

Mddct commented Dec 15, 2022

Hi @AveryLiu, that is interesting. The deterministic algorithm shouldn't depend on the batch size. Are you working with variable sequence lengths or fixed sequence lengths? If they are variable, it may just be a fluke that batch sizes of 128 or 256 appear to work.

I'm using variable sequence lengths. Indeed, setting batch_size to 128 or 256 does not solve the problem: it does not give me NaNs, but the gradients seem incorrect and the loss stagnates (the CPU version is fine). I am not sure how to apply the gradient masking mentioned above, since I am using a fully encapsulated Keras model. Another thing I observed is that the model can be trained on data where every label length is 1.

Any update on this? In TF 2.11 the gradients still seem to differ when using a SparseTensor.
