CTCLoss with empty target doesn't work well #18215
Comments
Test data and script found here. Script reproduced below:

```python
import sys

import torch
import torch.nn


def run_test(ctc_type, data_path, zero_inf):
    test_data = torch.load(data_path)
    inp = test_data['inp']
    inp_len = torch.tensor((inp.size(0),) * inp.size(1), dtype=torch.int32)
    tar = test_data['tar']
    tar_len = test_data['tar_len']
    if ctc_type == 'cudnn':
        inp = inp.cuda().detach()
        inp_len = inp_len.cuda()
    elif ctc_type == 'plain_cuda':
        inp = inp.double().cuda().detach()
        inp_len = inp_len.long().cuda()
        tar = tar.long().cuda()
        tar_len = tar_len.long().cuda()
    else:
        inp = inp.double().detach()
    assert bool(torch.all((inp.exp().sum(dim=-1) - 1).abs() < 1e-5).item())
    inp.requires_grad = True
    loss_fn = torch.nn.CTCLoss(reduction='none', zero_infinity=zero_inf)
    loss = loss_fn(inp, tar, inp_len, tar_len)
    loss[-1].backward()
    grad_sum = inp.grad.sum()
    grad_abs_sum = inp.grad.abs().sum()
    print(f'{ctc_type:11} '
          f'tar_len: {tar_len.tolist()} '
          f'loss: {loss[0].item():.10f}, {loss[1].item():.10f} '
          f'grad_sum: {grad_sum.item():.10f} '
          f'grad_abs_sum: {grad_abs_sum.item():.10f}')


if __name__ == '__main__':
    print('python version:', sys.version)
    print('torch version:', torch.__version__)
    print('GPU:', torch.cuda.get_device_name())
    for i in range(4):
        data_path = f'ctc_test_data_{i}.pt'
        print()
        for zero_inf in [True, False]:
            print(f'[{data_path}] zero_inf={zero_inf}')
            run_test('cpu', data_path, zero_inf)
            run_test('plain_cuda', data_path, zero_inf)
            run_test('cudnn', data_path, zero_inf)
```

Running the script we got:
... and then the script hangs. A few problems can be seen from the result (besides the problem mentioned above and the problem with the CuDNN implementation noted in #21680):
I would expect a nonzero loss when the target length is 0, and the negative losses do indicate a problem.
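For an empty target, the only valid CTC alignment is the all-blank path, so the loss should be the negative sum of the blank log-probabilities over all frames (which is strictly positive). A minimal check of that identity, which should hold on a PyTorch build containing the fix (tensor sizes here are illustrative):

```python
import torch

T, C = 8, 5  # frames, classes (class 0 is the blank by default)
log_probs = torch.randn(T, 1, C, dtype=torch.double).log_softmax(-1)

# Empty target: CTC admits only the all-blank path, so the loss
# should equal the negative sum of blank log-probs over the frames.
loss = torch.nn.functional.ctc_loss(
    log_probs,
    torch.zeros(0, dtype=torch.long),      # empty (flattened) target
    torch.tensor([T]), torch.tensor([0]),
    reduction='sum')
expected = -log_probs[:, 0, 0].sum()
print(torch.allclose(loss, expected))
```

Since every blank log-probability is negative, `expected` (and hence the loss) is always positive, which is why the negative values above point to a bug.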
This looks like an uninitialized-variable problem. Here are results from several different runs:
There is nothing stochastic about the tests being run, yet for test_data_2 the CPU answer is different every time. For test_data_3, where all targets are empty, the CPU answer seems better (at least the runs all agree).
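Varying output on fixed input is the classic symptom of reading uninitialized memory. A quick way to probe for it is to evaluate the same loss twice on identical inputs and require bit-identical results; on a fixed build this sketch should print `True` (shapes are illustrative):

```python
import torch

torch.manual_seed(0)
T, C = 6, 4
log_probs = torch.randn(T, 1, C, dtype=torch.double).log_softmax(-1)
args = (log_probs,
        torch.zeros(0, dtype=torch.long),  # empty target
        torch.tensor([T]), torch.tensor([0]))

# A deterministic op must produce bit-identical results across calls
# on the same input; any drift suggests uninitialized reads.
a = torch.nn.functional.ctc_loss(*args, reduction='sum')
b = torch.nn.functional.ctc_loss(*args, reduction='sum')
print(torch.equal(a, b))
```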
I opened a pull request (#21910) for a fix for the CPU implementation. After the fix, my test for the CPU case looks good now:
If anyone can tell me how to make a custom-built PyTorch work in Python without having to recompile the whole thing after every change, I can look into the GPU case too ...
Thank you! Awesome! CONTRIBUTING.md has some hints on how to keep your sanity while re-building. I deviate from that at times, but the documented method is probably best.
Summary: The bug is that when target_length == 0, there is no preceding BLANK state and the original implementation will lead to out of bound pointer access. Pull Request resolved: #21910 Differential Revision: D15960239 Pulled By: ezyang fbshipit-source-id: 7bbbecb7bf91842735c14265612c7e5049c4d9b3
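For context on the fix: the CTC forward (alpha) recursion runs over the extended label sequence `blank, t1, blank, t2, ..., blank` of length 2S+1. When the target is empty (S = 0), the extended sequence is a single blank, so the usual `s-1`/`s-2` lookbacks must be guarded; accessing them unconditionally is exactly the out-of-bound read described above. A minimal pure-Python reference sketch (not PyTorch's implementation) with those guards in place:

```python
import math

NEG_INF = float('-inf')

def logsumexp(xs):
    m = max(xs)
    if m == NEG_INF:
        return NEG_INF
    return m + math.log(sum(math.exp(x - m) for x in xs))

def ctc_nll(log_probs, target, blank=0):
    # Extended sequence: blank, t1, blank, t2, ..., blank (length 2S+1).
    ext = [blank]
    for t in target:
        ext += [t, blank]
    S = len(ext)
    # alpha[s]: log-prob of path prefixes ending at ext[s].
    alpha = [NEG_INF] * S
    alpha[0] = log_probs[0][ext[0]]
    if S > 1:                      # guard: with an empty target, S == 1
        alpha[1] = log_probs[0][ext[1]]
    for frame in log_probs[1:]:
        new = [NEG_INF] * S
        for s in range(S):
            cands = [alpha[s]]
            if s >= 1:
                cands.append(alpha[s - 1])
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                cands.append(alpha[s - 2])
            new[s] = logsumexp(cands) + frame[ext[s]]
        alpha = new
    # Paths may end in the final blank or the final label (if any).
    finals = [alpha[-1]] + ([alpha[-2]] if S > 1 else [])
    return -logsumexp(finals)
```

With `target=[]` this degenerates to summing the blank log-probability per frame, matching the closed-form expectation for empty targets.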
Summary: Fixes: pytorch/pytorch#18215 at last! Also sprinkle tests... Pull Request resolved: pytorch/pytorch#23298 Differential Revision: D16582145 Pulled By: soumith fbshipit-source-id: bc8b1a629de0c2606e70a2218ccd135f4a9cdc5d
🐛 Bug
CTCLoss doesn't provide the correct gradient when the target sequence is empty.
To Reproduce
Also, the default `reduction` doesn't play well with zero-length targets.
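The default `reduction='mean'` divides each sample's loss by its target length before averaging, which is degenerate when a length is 0. One workaround is to take per-sample losses with `reduction='none'` and normalize with a clamped denominator; a sketch (the batch contents here are illustrative):

```python
import torch

T, N, C = 8, 2, 5
log_probs = torch.randn(T, N, C, dtype=torch.double).log_softmax(-1)
targets = torch.tensor([1, 2, 3])          # flattened targets
target_lengths = torch.tensor([3, 0])      # second sample is empty
input_lengths = torch.full((N,), T, dtype=torch.long)

# 'mean' divides by target length (0 for the empty sample); computing
# per-sample losses and clamping the denominator sidesteps that.
losses = torch.nn.functional.ctc_loss(
    log_probs, targets, input_lengths, target_lengths, reduction='none')
mean_loss = (losses / target_lengths.clamp(min=1)).mean()
```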
Expected behavior
Compute the proper loss and gradient (which would point in the direction of less "blank").
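One way to verify the expected behavior end-to-end is a finite-difference gradient check of the empty-target loss against the analytic backward pass. This sketch assumes a CPU build containing the fix and uses double precision, as `gradcheck` requires:

```python
import torch

torch.manual_seed(0)
T, C = 4, 3
logits = torch.randn(T, 1, C, dtype=torch.double, requires_grad=True)

def empty_target_loss(x):
    return torch.nn.functional.ctc_loss(
        x.log_softmax(-1),
        torch.zeros(0, dtype=torch.long),   # empty target
        torch.tensor([T]), torch.tensor([0]),
        reduction='sum')

# gradcheck compares the analytic gradient with finite differences
# and raises (or returns False) on a mismatch.
print(torch.autograd.gradcheck(empty_target_loss, (logits,)))
```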
Acknowledgement
This has been pointed out by Evgeni Kirov, thank you for tracking this down!