# KL Divergence Loss

In [1]:
import torch
from torch import nn
import torch.nn.functional as F

In [14]:
kl_loss = nn.KLDivLoss(reduction="batchmean")

In [30]:
input = torch.log(torch.FloatTensor([[1 / 3, 1 / 3, 1 / 3]]))
print(input.shape, input)
target = torch.FloatTensor([[9 / 25, 12 / 25, 4 / 25]])
print(target.shape)

torch.Size([1, 3]) tensor([[-1.0986, -1.0986, -1.0986]])
torch.Size([1, 3])


In [31]:
kl_loss(input, target)

tensor(0.0853)

# KL Loss & CE Loss

In [32]:
kl_loss = nn.KLDivLoss(reduction="batchmean")
ce_loss = nn.CrossEntropyLoss()

In [33]:
input = torch.randn(3, 5, requires_grad=True)
input

tensor([[-0.2886, -1.1185, -0.3406, -0.2233, -1.4214],
        [ 0.7044, -0.1608,  1.4654,  1.3905,  0.5196],
        [-0.8406,  1.6600,  1.9194,  0.8524,  0.8311]], requires_grad=True)

In [35]:
target = torch.empty(3, dtype=torch.long).random_(5)
target

tensor([2, 4, 1])

In [43]:
ce_value = ce_loss(input, target)
ce_value

tensor(1.5334, grad_fn=<NllLossBackward0>)

In [44]:
input_log_softmax = F.log_softmax(input, dim=1)
input_log_softmax

tensor([[-1.3284, -2.1583, -1.3804, -1.2631, -2.4612],
        [-1.8530, -2.7181, -1.0920, -1.1668, -2.0378],
        [-3.6826, -1.1819, -0.9225, -1.9895, -2.0108]],
       grad_fn=<LogSoftmaxBackward0>)

In [45]:
target_shaped = torch.tensor([
    [0, 0, 1, 0, 0],
    [0, 0, 0, 0, 1],
    [0, 1, 0, 0, 0]
], dtype=torch.float)

In [46]:
ce_loss(input, target_shaped)

tensor(1.5334, grad_fn=<DivBackward1>)

In [47]:
kl_loss(input_log_softmax, target_shaped)

tensor(1.5334, grad_fn=<DivBackward0>)

KL散度和交叉熵损失函数之前是差一个target分布的熵，一个input分布的熵，然后两个分布的交叉熵，然后两个分布的KL散度，然后两个分布的交叉熵减去两个分布的KL散度，然后两个分布的交叉熵减去两个分布的KL散度的结果是一样的。

In [48]:
kl_loss = torch.nn.KLDivLoss(reduction='none')
ce_loss = torch.nn.CrossEntropyLoss(reduction='none')

In [57]:
input = torch.tensor([[-0.1, 0.2, -0.4, 0.3]], dtype=torch.float)
target = torch.tensor([[-0.7, 0.1, -0.1, 0.1]], dtype=torch.float)

In [58]:
kl_output = kl_loss(F.log_softmax(input, dim=1), F.softmax(target, dim=1))
print(kl_output)
print(kl_output.mean())
print(kl_output.sum())

tensor([[-0.0635,  0.0116,  0.1097, -0.0190]])
tensor(0.0097)
tensor(0.0389)


In [59]:
F.log_softmax(input, dim=1)

tensor([[-1.5222, -1.2222, -1.8222, -1.1222]])

In [79]:
p = F.softmax(input, dim=1)
log_p = torch.log(p)
p, log_p

(tensor([[0.2182, 0.2946, 0.1617, 0.3255]]),
 tensor([[-1.5222, -1.2222, -1.8222, -1.1222]]))

In [80]:
q = F.softmax(target, dim=1)
log_q = torch.log(q)
q, log_q

(tensor([[0.1375, 0.3060, 0.2505, 0.3060]]),
 tensor([[-1.9842, -1.1842, -1.3842, -1.1842]]))

In [83]:
q * (log_q - log_p)

tensor([[-0.0635,  0.0116,  0.1097, -0.0190]])