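## Calculate Cross-Entropy Loss Manually

For logits $z_n$ and true class index $y_n$ over a batch of $N$ samples, the mean cross-entropy loss is $\mathcal{L} = -\frac{1}{N}\sum_{n=1}^{N} \log \operatorname{softmax}(z_n)_{y_n}$.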

In [1]:
import torch
import torch.nn.functional as F

# Raw, unnormalized scores (logits): 3 samples x 4 classes.
# NOTE: each row happens to sum to 1, but cross entropy treats these values as
# logits and applies softmax to them anyway -- they are NOT used as probabilities.
prediction = torch.tensor([[0.3, 0.2, 0.4, 0.1], [0.4, 0.1, 0.3, 0.2], [0.3, 0.5, 0.1, 0.2]])

# Index of the true class for each sample (one entry per row of `prediction`)
target = torch.tensor([2, 0, 1])

# Row indices 0..N-1, used to pair each sample with its own target class
rows = torch.arange(len(target))

# Softmax over the class dimension turns each row of logits into probabilities
proba_sofmax = F.softmax(prediction, dim=1)

# Probability assigned to the true class of each sample (for display)
selected_proba = proba_sofmax[rows, target]

# Negative log likelihood averaged over the batch.
# F.log_softmax computes log(softmax(x)) via log-sum-exp, which stays finite for
# large-magnitude logits; torch.log(F.softmax(x)) can underflow to log(0) = -inf.
log_proba = F.log_softmax(prediction, dim=1)
manual_cross_entropy_loss = -log_proba[rows, target].mean()

# Sanity check: the manual computation must agree with PyTorch's built-in loss
assert torch.allclose(manual_cross_entropy_loss, F.cross_entropy(prediction, target))

# Print out results
print("prediction: {} \n".format(prediction))
print("target: {} \n".format(target))
print("proba_sofmax: {}\n".format(proba_sofmax))
print("selected_proba: {}\n".format(selected_proba))
print("manual_cross_entropy_loss: {}\n".format(manual_cross_entropy_loss))

prediction: tensor([[0.3000, 0.2000, 0.4000, 0.1000],
        [0.4000, 0.1000, 0.3000, 0.2000],
        [0.3000, 0.5000, 0.1000, 0.2000]]) 

target: tensor([2, 0, 1]) 

proba_sofmax: tensor([[0.2612, 0.2363, 0.2887, 0.2138],
        [0.2887, 0.2138, 0.2612, 0.2363],
        [0.2535, 0.3096, 0.2075, 0.2294]])

selected_proba: tensor([0.2887, 0.2887, 0.3096])

manual_cross_entropy_loss: 1.2191709280014038



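## Calculate Cross-Entropy Loss with `nn.CrossEntropyLoss()`

`nn.CrossEntropyLoss` takes raw logits and applies log-softmax internally, so it should reproduce the manual result above.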

In [2]:
import torch
import torch.nn as nn

# Raw logits for 3 samples x 4 classes, and the true class index of each sample.
# The rows summing to 1 is incidental; they are treated as logits, not probabilities.
prediction = torch.tensor([[0.3, 0.2, 0.4, 0.1], [0.4, 0.1, 0.3, 0.2], [0.3, 0.5, 0.1, 0.2]])
target = torch.tensor([2, 0, 1])

# nn.CrossEntropyLoss applies log-softmax to the logits internally and then
# averages the negative log likelihood of the target classes (default
# reduction="mean"), so `prediction` must not be softmaxed beforehand.
loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(prediction, target)

# Show inputs and result, one labelled block each
for label, value in (("prediction", prediction), ("target", target), ("cross_entropy_loss", loss)):
    print(f"{label}: {value} \n")

prediction: tensor([[0.3000, 0.2000, 0.4000, 0.1000],
        [0.4000, 0.1000, 0.3000, 0.2000],
        [0.3000, 0.5000, 0.1000, 0.2000]]) 

target: tensor([2, 0, 1]) 

cross_entropy_loss: 1.2191709280014038 

