This error occurs with multiple loss functions.
import torch
import torch_xla
import torch.nn as nn
import pdb
# Reproduction script: evaluate nn.CTCLoss on tensors placed on an XLA device.
device = 'xla:0'
# Disabled alternatives kept from the original script:
# device = 'cpu'
# m = nn.LSTM(3, 4, bidirectional=True, num_layers=2).to(device)

# Label index 14 is the CTC blank (the last of the 15 classes below).
m = nn.CTCLoss(blank=14)

# Log-probabilities of shape (50, 3, 15): built without a device argument,
# normalised over the last (class) dimension, then moved to `device`.
a = torch.rand(50, 3, 15)
a = a.log_softmax(dim=2)
a = a.to(device)

# Integer targets of shape (3, 30), drawn from [0, 14) directly on `device`,
# so the blank index 14 never appears in them.
b = torch.randint(low=0, high=14, size=(3, 30), device=device)

# Every sequence uses all 50 input steps; target lengths differ per batch item.
loss = m(
    log_probs=a,
    targets=b,
    input_lengths=[50, 50, 50],
    target_lengths=[30, 25, 20],
)