
CTCLoss with empty target doesn't work well #18215

Closed
t-vi opened this issue Mar 20, 2019 · 5 comments
Labels
module: bootcamp · We plan to do a full writeup on the issue, and then get someone to do it for onboarding
module: derivatives · Related to derivatives of operators
module: nn · Related to torch.nn
triaged · This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@t-vi
Collaborator

t-vi commented Mar 20, 2019

🐛 Bug

CTCLoss doesn't provide the correct gradient when the target sequence is empty.

To Reproduce

import torch

probs = torch.randn(2, 2, 3, dtype=torch.double).log_softmax(-1).requires_grad_()  # (T, N, C) log-probabilities
labels = torch.tensor([1, 2])   # targets, concatenated over the batch
label_sizes = [2, 0]            # the second sample has an empty target
sizes = [2, 2]                  # input lengths
loss = torch.nn.functional.ctc_loss(probs, labels, sizes, label_sizes, reduction='sum', zero_infinity=True)
loss2 = torch.nn.functional.ctc_loss(probs, labels, sizes, label_sizes, reduction='none', zero_infinity=True)
grad, = torch.autograd.grad(loss, probs)

probs_gpu = probs.detach().cuda().requires_grad_()
loss_gpu = torch.nn.functional.ctc_loss(probs_gpu, labels.cuda(), sizes, label_sizes, reduction='sum', zero_infinity=True)
loss2_gpu = torch.nn.functional.ctc_loss(probs_gpu, labels.cuda(), sizes, label_sizes, reduction='none', zero_infinity=True)
grad_gpu, = torch.autograd.grad(loss_gpu, probs_gpu)

print('loss:', loss, loss_gpu)
print('loss2:', loss2, loss2_gpu)
print('grad:', grad, "\n", grad_gpu)

print("grad_check cpu: ",
      torch.autograd.gradcheck(lambda logits: torch.nn.functional.ctc_loss(logits.log_softmax(-1), labels, sizes, label_sizes, reduction='sum', zero_infinity=True), (torch.randn(2, 2, 3, dtype=torch.double, requires_grad=True),), raise_exception=False))
print("grad_check gpu: ",
      torch.autograd.gradcheck(lambda logits: torch.nn.functional.ctc_loss(logits.log_softmax(-1), labels.cuda(), sizes, label_sizes, reduction='sum', zero_infinity=True), (torch.randn(2, 2, 3, dtype=torch.double, device='cuda', requires_grad=True),), raise_exception=False))

Also, the default reduction ('mean', which divides each per-sample loss by its target length) doesn't play well with zero-length targets.
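
For illustration, a minimal sketch of why the default reduction is a problem (this relies on the documented 'mean' behaviour of dividing each per-sample loss by its target length; the tensor values are arbitrary):

import torch

log_probs = torch.randn(2, 2, 3, dtype=torch.double).log_softmax(-1)  # (T, N, C)
targets = torch.tensor([1, 2])
input_lengths = [2, 2]
target_lengths = [2, 0]   # the second sample has an empty target

# reduction='mean' divides each per-sample loss by its target length before
# averaging, so the zero-length target triggers a division by zero.
loss = torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths)
print(loss)  # non-finite on affected builds because of the division by zero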

Expected behavior

Compute the proper loss and gradient (which would point in the direction of less "blank").
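
A minimal sketch of the reference computation for a single empty-target sample (hand-rolled from the standard CTC formulation, not PyTorch's kernel; T and C are toy sizes): the only valid alignment is all blanks, so the loss is -sum_t log p_t(blank), and differentiating that by hand shows the gradient behaviour described above.

import torch

T, C = 2, 3                      # toy sizes; class 0 is the blank
logits = torch.randn(T, C, dtype=torch.double, requires_grad=True)
log_probs = logits.log_softmax(-1)

# Empty target: the only valid path is all blanks.
reference_loss = -log_probs[:, 0].sum()
grad, = torch.autograd.grad(reference_loss, logits)

# The gradient w.r.t. the logits is softmax(logits) - onehot(blank):
# negative for the blank column, positive elsewhere, i.e. it points in the
# direction of less "blank" (so gradient descent produces more blank).
print(reference_loss.item(), grad)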

Acknowledgement

This was pointed out by Evgeni Kirov; thank you for tracking this down!

@fmassa added the bug and todo (Not as important as medium or high priority tasks, but we will work on these.) labels on Mar 25, 2019
@ezyang added the module: nn, module: derivatives, triaged, and module: bootcamp labels and removed the bug and todo labels on Apr 6, 2019
@zh217
Contributor

zh217 commented Jun 18, 2019

Test data and a script can be found here; the script is reproduced below:

import sys

import torch
import torch.nn

def run_test(ctc_type, data_path, zero_inf):
    test_data = torch.load(data_path)
    inp = test_data['inp']                  # (T, N, C) log-probabilities
    inp_len = torch.tensor((inp.size(0),) * inp.size(1), dtype=torch.int32)
    tar = test_data['tar']                  # targets, concatenated over the batch
    tar_len = test_data['tar_len']

    if ctc_type == 'cudnn':
        # CuDNN backend: float32 activations, int32 lengths.
        inp = inp.cuda().detach()
        inp_len = inp_len.cuda()
    elif ctc_type == 'plain_cuda':
        # Double precision selects the native (non-CuDNN) CUDA kernel.
        inp = inp.double().cuda().detach()
        inp_len = inp_len.long().cuda()
        tar = tar.long().cuda()
        tar_len = tar_len.long().cuda()
    else:
        # Native CPU implementation.
        inp = inp.double().detach()

    assert bool(torch.all((inp.exp().sum(dim=-1) - 1).abs() < 1e-5).item())
    inp.requires_grad = True

    loss_fn = torch.nn.CTCLoss(reduction='none', zero_infinity=zero_inf)

    loss = loss_fn(inp, tar, inp_len, tar_len)

    loss[-1].backward()

    grad_sum = inp.grad.sum()
    grad_abs_sum = inp.grad.abs().sum()
    print(f'{ctc_type:11} '
          f'tar_len: {tar_len.tolist()}  '
          f'loss:  {loss[0].item():.10f}, {loss[1].item():.10f}  '
          f'grad_sum: {grad_sum.item():.10f}  '
          f'grad_abs_sum: {grad_abs_sum.item():.10f}')


if __name__ == '__main__':
    print('python version:', sys.version)
    print('torch version:', torch.__version__)
    print('GPU:', torch.cuda.get_device_name())
    for i in range(4):
        data_path = f'ctc_test_data_{i}.pt'
        print()
        for zero_inf in [True, False]:
            print(f'[{data_path}] zero_inf={zero_inf}')
            run_test('cpu', data_path, zero_inf)
            run_test('plain_cuda', data_path, zero_inf)
            run_test('cudnn', data_path, zero_inf)
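
The ctc_test_data_*.pt files are not attached here, so below is a hypothetical generator for files in the format run_test() expects (keys 'inp', 'tar', 'tar_len'); T and C are made up, and only the target lengths mirror the cases reported below. The output that follows is from the original files, not from this generator.

import torch

def make_test_file(path, target_lengths, T=50, C=32):
    """Write a .pt file with the fields run_test() reads."""
    N = len(target_lengths)
    inp = torch.randn(T, N, C).log_softmax(-1)            # (T, N, C) log-probs
    tar_len = torch.tensor(target_lengths, dtype=torch.int32)
    # Labels 1..C-1 (0 is the blank), concatenated over the batch.
    tar = torch.randint(1, C, (int(tar_len.sum()),), dtype=torch.int32)
    torch.save({'inp': inp, 'tar': tar, 'tar_len': tar_len}, path)

if __name__ == '__main__':
    # The last file exercises the all-empty-targets case that hangs on CUDA.
    for i, lengths in enumerate([[9, 20], [9, 16], [0, 13], [0, 0]]):
        make_test_file(f'ctc_test_data_{i}.pt', lengths)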

Running the script, we got:

python version: 3.7.2 (default, Jan 10 2019, 07:33:16) 
[GCC 7.3.0]
torch version: 1.2.0.dev20190617
GPU: GeForce RTX 2080 Ti

[ctc_test_data_0.pt] zero_inf=True
cpu         tar_len: [9, 20]  loss:  30.5302109565, 203.0446978194  grad_sum: 0.0000362272  grad_abs_sum: 39.9975735736
plain_cuda  tar_len: [9, 20]  loss:  30.5302109565, 203.0446978194  grad_sum: 0.0000362272  grad_abs_sum: 39.9975735736
cudnn       tar_len: [9, 20]  loss:  30.5302200317, 203.0447387695  grad_sum: nan  grad_abs_sum: nan
[ctc_test_data_0.pt] zero_inf=False
cpu         tar_len: [9, 20]  loss:  30.5302109565, 203.0446978194  grad_sum: 0.0000362272  grad_abs_sum: 39.9975735736
plain_cuda  tar_len: [9, 20]  loss:  30.5302109565, 203.0446978194  grad_sum: 0.0000362272  grad_abs_sum: 39.9975735736
cudnn       tar_len: [9, 20]  loss:  30.5302200317, 203.0447387695  grad_sum: nan  grad_abs_sum: nan

[ctc_test_data_1.pt] zero_inf=True
cpu         tar_len: [9, 16]  loss:  70.8375443172, 47.4589806627  grad_sum: 0.0000198770  grad_abs_sum: 29.1309699646
plain_cuda  tar_len: [9, 16]  loss:  70.8375443172, 47.4589806627  grad_sum: 0.0000198770  grad_abs_sum: 29.1309699646
cudnn       tar_len: [9, 16]  loss:  70.8375549316, 47.4589996338  grad_sum: nan  grad_abs_sum: nan
[ctc_test_data_1.pt] zero_inf=False
cpu         tar_len: [9, 16]  loss:  70.8375443172, 47.4589806627  grad_sum: 0.0000198770  grad_abs_sum: 29.1309699646
plain_cuda  tar_len: [9, 16]  loss:  70.8375443172, 47.4589806627  grad_sum: 0.0000198770  grad_abs_sum: 29.1309699646
cudnn       tar_len: [9, 16]  loss:  70.8375549316, 47.4589996338  grad_sum: nan  grad_abs_sum: nan

[ctc_test_data_2.pt] zero_inf=True
cpu         tar_len: [0, 13]  loss:  -283370451696837273446472351913216827957500774163770208211728206452598691021902455185383562402427213323298073623160868620504739983843715661191102982207577537963450571928728162671698356893809834924626107503807329543018887885814664393278625591652199348165283200514350389395456.0000000000, 51.7903951958  grad_sum: 0.0000134892  grad_abs_sum: 25.1179445297
plain_cuda  tar_len: [0, 13]  loss:  0.0000000000, 51.7903951958  grad_sum: 0.0000134892  grad_abs_sum: 25.1179445297
cudnn       tar_len: [0, 13]  loss:  0.2770059705, 51.7903862000  grad_sum: nan  grad_abs_sum: nan
[ctc_test_data_2.pt] zero_inf=False
cpu         tar_len: [0, 13]  loss:  -394896264825696949569105967955929059680720400074394644640100835601224760994490590630075362547316590191301792896509674486620213409344982968744119015528623767552.0000000000, 51.7903951958  grad_sum: 0.0000134892  grad_abs_sum: 25.1179445297
plain_cuda  tar_len: [0, 13]  loss:  inf, 51.7903951958  grad_sum: nan  grad_abs_sum: nan
cudnn       tar_len: [0, 13]  loss:  0.2770059705, 51.7903862000  grad_sum: nan  grad_abs_sum: nan

[ctc_test_data_3.pt] zero_inf=True
cpu         tar_len: [0, 0]  loss:  -0.6079427618, -0.2247722130  grad_sum: 62.4999991092  grad_abs_sum: 62.4999991092

... and then the script hangs.

A few problems can be seen from these results (besides the problem mentioned above and the problem with the CuDNN implementation noted in #21680); a possible stopgap is sketched after the list:

  • the CPU implementation does not respect zero_infinity when the target is empty (see the huge loss printed for ctc_test_data_2 with zero_inf=True);
  • the non-CuDNN CUDA implementation hangs when all targets have length 0.
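
Until fixes land, a possible stopgap (my own helper, not a PyTorch API) is to route empty-target samples around F.ctc_loss entirely and use the closed-form all-blank loss for them; this sketch assumes the 1-D concatenated targets layout used in the snippets above:

import torch
import torch.nn.functional as F

def ctc_loss_skipping_empty(log_probs, targets, input_lengths, target_lengths, blank=0):
    """Summed CTC loss that handles empty targets outside of F.ctc_loss."""
    input_lengths = torch.as_tensor(input_lengths, dtype=torch.long)
    target_lengths = torch.as_tensor(target_lengths, dtype=torch.long)
    nonempty = target_lengths > 0
    loss = log_probs.new_zeros(())

    if nonempty.any():
        # Targets are concatenated over the batch, so re-slice them per sample.
        kept, offset = [], 0
        for length in target_lengths.tolist():
            if length > 0:
                kept.append(targets[offset:offset + length])
            offset += length
        loss = loss + F.ctc_loss(
            log_probs[:, nonempty], torch.cat(kept),
            input_lengths[nonempty], target_lengths[nonempty],
            blank=blank, reduction='sum')

    # For an empty target the only valid alignment is all blanks, so its loss
    # is -sum_t log p_t(blank); this is finite and differentiable.
    for n, length in enumerate(target_lengths.tolist()):
        if length == 0:
            loss = loss - log_probs[:int(input_lengths[n]), n, blank].sum()

    return loss

For feasible targets this should agree with reduction='sum' on a build where the empty-target path already works.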

@t-vi
Collaborator Author

t-vi commented Jun 18, 2019

I would expect a nonzero loss when you have target length 0, but negative losses do indicate a problem.
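
For reference, the reason a negative value can never be a valid CTC loss (standard formulation):

\mathcal{L} = -\log p(\mathbf{y} \mid \mathbf{x}), \qquad 0 < p(\mathbf{y} \mid \mathbf{x}) \le 1 \;\Longrightarrow\; \mathcal{L} \ge 0,

and for an empty target

\mathcal{L}_{\varnothing} = -\sum_{t=1}^{T} \log p_t(\mathrm{blank}) \ge 0,

with equality only if every frame puts all of its probability mass on the blank.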

@zh217
Contributor

zh217 commented Jun 18, 2019

This looks like an uninitialized-variable problem.

Here are results from several different runs:

[ctc_test_data_2.pt] zero_inf=True
cpu         tar_len: [0, 13]  loss:  0.2770090103, 51.7903951958  grad_sum: 0.0000134892  grad_abs_sum: 25.1179445297
plain_cuda  tar_len: [0, 13]  loss:  0.0000000000, 51.7903951958  grad_sum: 0.0000134892  grad_abs_sum: 25.1179445297
cudnn       tar_len: [0, 13]  loss:  0.2770059705, 51.7903862000  grad_sum: nan  grad_abs_sum: nan
[ctc_test_data_2.pt] zero_inf=False
cpu         tar_len: [0, 13]  loss:  0.2770090103, 51.7903951958  grad_sum: 0.0000134892  grad_abs_sum: 25.1179445297
plain_cuda  tar_len: [0, 13]  loss:  inf, 51.7903951958  grad_sum: nan  grad_abs_sum: nan
cudnn       tar_len: [0, 13]  loss:  0.2770059705, 51.7903862000  grad_sum: nan  grad_abs_sum: nan

[ctc_test_data_3.pt] zero_inf=True
cpu         tar_len: [0, 0]  loss:  -0.6079427618, -0.2247722130  grad_sum: 62.4999991092  grad_abs_sum: 62.4999991092
[ctc_test_data_2.pt] zero_inf=True
cpu         tar_len: [0, 13]  loss:  -0.5642039131, 51.7903951958  grad_sum: 0.0000134892  grad_abs_sum: 25.1179445297
plain_cuda  tar_len: [0, 13]  loss:  0.0000000000, 51.7903951958  grad_sum: 0.0000134892  grad_abs_sum: 25.1179445297
cudnn       tar_len: [0, 13]  loss:  0.2770059705, 51.7903862000  grad_sum: nan  grad_abs_sum: nan
[ctc_test_data_2.pt] zero_inf=False
cpu         tar_len: [0, 13]  loss:  0.2770090103, 51.7903951958  grad_sum: 0.0000134892  grad_abs_sum: 25.1179445297
plain_cuda  tar_len: [0, 13]  loss:  inf, 51.7903951958  grad_sum: nan  grad_abs_sum: nan
cudnn       tar_len: [0, 13]  loss:  0.2770059705, 51.7903862000  grad_sum: nan  grad_abs_sum: nan

[ctc_test_data_3.pt] zero_inf=True
cpu         tar_len: [0, 0]  loss:  -0.6079427618, -0.2247722130  grad_sum: 62.4999991092  grad_abs_sum: 62.4999991092
[ctc_test_data_2.pt] zero_inf=True
cpu         tar_len: [0, 13]  loss:  -283370451696837273446472351913216827957500774163770208211728206452598691021902455185383562402427213323298073623160868620504739983843715661191102982207577537963450571928728162671698356893809834924626107503807329543018887885814664393278625591652199348165283200514350389395456.0000000000, 51.7903951958  grad_sum: 0.0000134892  grad_abs_sum: 25.1179445297
plain_cuda  tar_len: [0, 13]  loss:  0.0000000000, 51.7903951958  grad_sum: 0.0000134892  grad_abs_sum: 25.1179445297
cudnn       tar_len: [0, 13]  loss:  0.2770059705, 51.7903862000  grad_sum: nan  grad_abs_sum: nan
[ctc_test_data_2.pt] zero_inf=False
cpu         tar_len: [0, 13]  loss:  -394896264825696949569105967955929059680720400074394644640100835601224760994490590630075362547316590191301792896509674486620213409344982968744119015528623767552.0000000000, 51.7903951958  grad_sum: 0.0000134892  grad_abs_sum: 25.1179445297
plain_cuda  tar_len: [0, 13]  loss:  inf, 51.7903951958  grad_sum: nan  grad_abs_sum: nan
cudnn       tar_len: [0, 13]  loss:  0.2770059705, 51.7903862000  grad_sum: nan  grad_abs_sum: nan

[ctc_test_data_3.pt] zero_inf=True
cpu         tar_len: [0, 0]  loss:  -0.6079427618, -0.2247722130  grad_sum: 62.4999991092  grad_abs_sum: 62.4999991092

There is nothing stochastic about the tests being run, yet for test_data_2 the CPU answer is different every time.

For test_data_3, where all targets are empty, the CPU answers seem better behaved, though (at least they are the same across runs).

zh217 added a commit to zh217/pytorch that referenced this issue Jun 18, 2019
@zh217
Contributor

zh217 commented Jun 18, 2019

I opened a pull request (#21910) with a fix for the CPU implementation. After the fix, my test results for the CPU case look good:

python version: 3.7.2 (default, Jan 10 2019, 07:33:16) 
[GCC 7.3.0]
torch version: 1.2.0a0+08a0ac8
GPU: GeForce RTX 2080 Ti

[ctc_test_data_0.pt] zero_inf=True
cpu         tar_len: [9, 20]  loss:  30.5302109565, 203.0446978194  grad_sum: 0.0000362272  grad_abs_sum: 39.9975735736
[ctc_test_data_0.pt] zero_inf=False
cpu         tar_len: [9, 20]  loss:  30.5302109565, 203.0446978194  grad_sum: 0.0000362272  grad_abs_sum: 39.9975735736

[ctc_test_data_1.pt] zero_inf=True
cpu         tar_len: [9, 16]  loss:  70.8375443172, 47.4589806627  grad_sum: 0.0000198770  grad_abs_sum: 29.1309699646
[ctc_test_data_1.pt] zero_inf=False
cpu         tar_len: [9, 16]  loss:  70.8375443172, 47.4589806627  grad_sum: 0.0000198770  grad_abs_sum: 29.1309699646

[ctc_test_data_2.pt] zero_inf=True
cpu         tar_len: [0, 13]  loss:  0.2770090103, 51.7903951958  grad_sum: 0.0000134892  grad_abs_sum: 25.1179445297
[ctc_test_data_2.pt] zero_inf=False
cpu         tar_len: [0, 13]  loss:  0.2770090103, 51.7903951958  grad_sum: 0.0000134892  grad_abs_sum: 25.1179445297

[ctc_test_data_3.pt] zero_inf=True
cpu         tar_len: [0, 0]  loss:  0.0852044187, 0.4683749676  grad_sum: -0.0000008908  grad_abs_sum: 0.9290851062
[ctc_test_data_3.pt] zero_inf=False
cpu         tar_len: [0, 0]  loss:  0.0852044187, 0.4683749676  grad_sum: -0.0000008908  grad_abs_sum: 0.9290851062

If anyone can tell me how to make a custom-built PyTorch work in Python without having to recompile the whole thing after every change, I can look into the GPU case too ...

@t-vi
Collaborator Author

t-vi commented Jun 18, 2019

Thank you! Awesome! CONTRIBUTING.md has some hints on how to keep your sanity while rebuilding. I deviate from that at times, but the documented method is probably best.

facebook-github-bot pushed a commit that referenced this issue Jun 24, 2019
Summary:
The bug is that when target_length == 0, there is no preceding BLANK state, and the original implementation leads to an out-of-bounds pointer access.
Pull Request resolved: #21910

Differential Revision: D15960239

Pulled By: ezyang

fbshipit-source-id: 7bbbecb7bf91842735c14265612c7e5049c4d9b3
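
To illustrate where that boundary sits, here is a pedagogical pure-Python sketch of the log-domain CTC forward recursion (my own illustration, not the ATen kernel): with an empty target the blank-interleaved extended sequence has length 1, so every access to the "previous" states must be guarded.

import math

def log_add(a, b):
    """Numerically stable log(exp(a) + exp(b))."""
    if a == float('-inf'):
        return b
    if b == float('-inf'):
        return a
    m, n = max(a, b), min(a, b)
    return m + math.log1p(math.exp(n - m))

def ctc_nll(log_probs, target, blank=0):
    """Negative log-likelihood of `target` given (T, C) log-probs (nested lists)."""
    T = len(log_probs)
    ext = [blank]
    for y in target:
        ext += [y, blank]                    # extended target, length 2*len(target) + 1
    S = len(ext)

    NEG_INF = float('-inf')
    alpha = [NEG_INF] * S
    alpha[0] = log_probs[0][ext[0]]
    if S > 1:                                # guard: no second state when the target is empty
        alpha[1] = log_probs[0][ext[1]]

    for t in range(1, T):
        new = [NEG_INF] * S
        for s in range(S):
            a = alpha[s]
            if s >= 1:                       # these guards keep an empty target in bounds
                a = log_add(a, alpha[s - 1])
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                a = log_add(a, alpha[s - 2])
            new[s] = a + log_probs[t][ext[s]]
        alpha = new

    total = alpha[S - 1]
    if S > 1:                                # with an empty target only the blank state exists
        total = log_add(total, alpha[S - 2])
    return -total

# Toy check: T=2 frames, C=3 classes (0 = blank).
lp = [[math.log(0.6), math.log(0.3), math.log(0.1)],
      [math.log(0.5), math.log(0.2), math.log(0.3)]]
print(ctc_nll(lp, [1]))   # -log(0.06 + 0.12 + 0.15), the three valid alignments
print(ctc_nll(lp, []))    # -log(0.6 * 0.5), the all-blank path
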
t-vi added a commit to t-vi/pytorch that referenced this issue Jul 24, 2019
zdevito pushed a commit to zdevito/ATen that referenced this issue Jul 31, 2019
Summary:
Fixes: pytorch/pytorch#18215 at last!

Also sprinkle tests...
Pull Request resolved: pytorch/pytorch#23298

Differential Revision: D16582145

Pulled By: soumith

fbshipit-source-id: bc8b1a629de0c2606e70a2218ccd135f4a9cdc5d
ssnl pushed a commit to ssnl/pytorch that referenced this issue Aug 2, 2019
soumith pushed a commit that referenced this issue Aug 2, 2019