In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [31]:
# Seed the RNG so that Restart & Run All regenerates the same logits and labels.
torch.manual_seed(0)

# Toy setup: batch_size=2, seq_len=3, vocabulary size=4.
logits: torch.Tensor = torch.randn(2, 3, 4)
print(logits, "#logits")
# F.cross_entropy expects the class dimension second, i.e. (batch_size, C, d1, ...),
# so swap the sequence and vocabulary axes: (2, 3, 4) -> (2, 4, 3).
logits = logits.transpose(1, 2)
# Every position of every sentence in the batch gets a target word index in [0, C-1].
label = torch.randint(0, 4, (2, 3))
print(label, "# label")
# Default reduction='mean': one cross entropy per token, averaged over all tokens.
loss = F.cross_entropy(logits, label)
print(loss, "# mean loss")
# reduction='none' keeps the per-token losses, shape (batch_size, seq_len).
loss = F.cross_entropy(logits, label, reduction='none')
print(loss, "# loss reduction=none")

tensor([[[-1.3811e-01, -5.3651e-02,  1.6222e+00, -3.7813e-01],
         [-1.7034e+00,  2.9650e-01,  1.1526e-03, -2.5275e-01],
         [-3.9760e-01, -1.6542e-01,  2.4655e-01,  1.2284e-01]],

        [[-7.0630e-02,  8.6487e-01,  3.7207e-01,  5.2766e-01],
         [-7.1508e-01, -6.3658e-01, -2.1381e-01, -4.3843e-02],
         [-1.1693e+00,  2.1271e-02,  3.1061e-01, -9.8025e-01]]]) #logits
tensor([[0, 0, 1],
        [0, 0, 2]]) # label
tensor(1.8467) # mean loss
tensor([[2.1620, 2.8988, 1.5340],
        [1.9351, 1.7386, 0.8116]]) # loss reduction=none


In [33]:
# Padding case: the first sentence has only 2 real tokens, so its third position
# is padding and must not contribute to the loss.
tgt_len = torch.tensor([2, 3], dtype=torch.int32)
print(tgt_len, "# tgt_len")
# Position j of sentence i is a real token iff j < tgt_len[i]. Broadcasting a
# (1, seq_len) row of positions against a (batch_size, 1) column of lengths builds
# the whole mask in one comparison. The width is taken from `label` so the mask
# always lines up with the per-token loss, even if no sentence fills the full length.
positions = torch.arange(label.size(1)).unsqueeze(0)
mask = (positions < tgt_len.unsqueeze(1)).float()
print(mask, "# mask")
# reduction='none' is required so the mask can zero out individual tokens.
# For a mean over real tokens only, divide loss.sum() by mask.sum().
loss = F.cross_entropy(logits, label, reduction='none') * mask
print(loss, "# mask loss")


tensor([2, 3], dtype=torch.int32) # tgt_len
tensor([[1., 1., 0.],
        [1., 1., 1.]]) # mask
tensor([[2.1620, 2.8988, 0.0000],
        [1.9351, 1.7386, 0.8116]]) # mask loss


In [36]:
# Alternative to an explicit mask: tag the padded target with ignore_index
# (-100 is the default of F.cross_entropy) and let the loss skip it.
label[0, 2] = -100
print(label, "# with ignore_index label")
# Targets equal to ignore_index yield a loss of 0 under reduction='none'.
loss = F.cross_entropy(
    input=logits,
    target=label,
    ignore_index=-100,
    reduction='none',
)
print(loss, "# loss ignore_index，默认是-100，自动mask操作")


tensor([[   0,    0, -100],
        [   0,    0,    2]]) # with ignore_index label
tensor([[2.1620, 2.8988, 0.0000],
        [1.9351, 1.7386, 0.8116]]) # loss ignore_index，默认是-100，自动mask操作
