## Cross-entropy Loss in PyTorch and its relation with Softmax 

# [Link to my YouTube Video Explaining this whole Notebook](https://www.youtube.com/watch?v=h3M3Ob0zgkc&list=PLxqBkZuBynVRMORlFw95iNTp9aZzmmz4Y&index=7)

[![Imgur](https://imgur.com/r2Qhgo3.png)](https://www.youtube.com/watch?v=h3M3Ob0zgkc&list=PLxqBkZuBynVRMORlFw95iNTp9aZzmmz4Y&index=7)


### Cross-entropy is a function that compares two probability distributions.

![](assets/2022-02-21-19-07-53.png)

The key thing to pay attention to is that cross-entropy takes two probability distributions, q and p, as input and returns a value that, for a fixed p, is smallest when q equals p. Here q is the estimated (predicted) distribution and p is the true distribution.
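Here is a minimal sketch of this property, assuming the standard definition H(p, q) = -Σ p_i log q_i with natural logarithms. `cross_entropy` is a hypothetical helper written for illustration, not a PyTorch function. With p fixed, the value is lowest when q equals p (where it reduces to the entropy of p) and grows as q moves away from p:

```python
import torch

def cross_entropy(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: H(p, q) = -sum_i p_i * log(q_i) for 1-D probability vectors."""
    return -(p * torch.log(q)).sum()

p = torch.tensor([0.7, 0.2, 0.1])        # "true" distribution
q_close = torch.tensor([0.6, 0.3, 0.1])  # estimate close to p
q_far = torch.tensor([0.1, 0.2, 0.7])    # estimate far from p

h_pp = cross_entropy(p, p)               # q == p: equals the entropy of p
h_close = cross_entropy(p, q_close)
h_far = cross_entropy(p, q_far)
print(h_pp.item(), h_close.item(), h_far.item())
```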

In the context of ML classification, we know the actual label of each training example, so the true (target) distribution p has a probability of 1 for the true label and 0 elsewhere, i.e. p is a one-hot vector.
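Because every entry of a one-hot p is 0 except at the true class, the sum collapses to a single term, -log q[true class]. A small check with made-up numbers:

```python
import torch

# Predicted distribution q over 4 classes (e.g. a softmax output); values are made up
q = torch.tensor([0.1, 0.6, 0.2, 0.1])
true_class = 1

# One-hot target distribution p: 1 at the true class, 0 elsewhere
p = torch.zeros(4)
p[true_class] = 1.0

full_sum = -(p * torch.log(q)).sum()   # -sum_i p_i * log(q_i)
shortcut = -torch.log(q[true_class])   # only the true-class term survives
print(full_sum.item(), shortcut.item())
```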

------------------

#### Softmax is a non-linear activation function

![](assets/2022-02-21-19-08-58.png)

Softmax is not a loss function. It has a very specific task: in multi-class classification, it normalizes the raw scores (logits) of the classes into probabilities. Each output is `exp(z_i) / sum_j exp(z_j)`, so every probability is positive and together they sum to 1.
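As a sketch, softmax can be written out directly. `softmax_rows` is a hypothetical helper, and subtracting each row's maximum before exponentiating is a common stabilization trick that leaves the result unchanged. It should agree with `torch.softmax`:

```python
import torch

def softmax_rows(z: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: row-wise exp(z_i) / sum_j exp(z_j) for a 2-D tensor."""
    # Subtracting the row max does not change the result but avoids overflow in exp
    z_shifted = z - z.max(dim=1, keepdim=True).values
    exp_z = torch.exp(z_shifted)
    return exp_z / exp_z.sum(dim=1, keepdim=True)

scores = torch.tensor([[2.0, 1.0, 0.1],
                       [0.5, 0.5, 3.0]])
probs = softmax_rows(scores)
print(probs)
print(probs.sum(dim=1))  # each row sums to 1 (up to float rounding)
```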

### In PyTorch, `nn.CrossEntropyLoss` already combines softmax (in the form of log-softmax) with the negative log-likelihood loss.


The documentation of [nn.CrossEntropyLoss](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html#torch.nn.CrossEntropyLoss) says,

#### This criterion combines `nn.LogSoftmax()` and `nn.NLLLoss()` in one single class.
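A quick way to check this equivalence, using made-up logits and labels: applying `nn.LogSoftmax(dim=1)` followed by `nn.NLLLoss()` should give the same value as `nn.CrossEntropyLoss()` on the raw logits.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(4, 3)            # 4 samples, 3 classes (raw scores)
targets = torch.tensor([0, 2, 1, 2])  # true class index per sample

ce = nn.CrossEntropyLoss()(logits, targets)

# The same loss, built from the two pieces the documentation names
log_probs = nn.LogSoftmax(dim=1)(logits)
nll = nn.NLLLoss()(log_probs, targets)

print(ce.item(), nll.item())
```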

#### So DO NOT use `nn.Softmax()` in the output layer of a neural net when using `nn.CrossEntropyLoss` as the loss function; the model should output raw, unnormalized scores (logits).
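To see why, here is a sketch with hand-picked logits (chosen for illustration) where the model is confident and correct. Fed raw logits, `nn.CrossEntropyLoss` gives a small loss; fed softmax outputs, it applies softmax a second time, squashes the scores into a narrow range, and reports a much larger loss for the same predictions:

```python
import torch
import torch.nn as nn

# Confident, correct predictions: class 0 for sample 0, class 1 for sample 1
logits = torch.tensor([[5.0, 0.0, 0.0],
                       [0.0, 4.0, 0.0]])
targets = torch.tensor([0, 1])
criterion = nn.CrossEntropyLoss()

# Correct: feed raw logits; the loss applies log-softmax internally
loss_correct = criterion(logits, targets)

# Incorrect: an extra softmax layer means softmax is effectively applied twice
probs = nn.Softmax(dim=1)(logits)
loss_double_softmax = criterion(probs, targets)

print(loss_correct.item(), loss_double_softmax.item())

# Softmax is still useful outside the loss, e.g. at inference time, to read probabilities
with torch.no_grad():
    predicted_probs = torch.softmax(logits, dim=1)
```

If you need probabilities for reporting or inference, apply `torch.softmax` to the logits outside the loss computation, as in the last lines.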



```python
import torch
import torch.nn as nn
```

```python
output = torch.randn(3, 5, requires_grad=True)  # raw logits: 3 samples, 5 classes

output
```

```
tensor([[-0.7055,  2.2556,  0.4626, -0.1980,  1.0355],
        [-0.7134, -1.1858,  0.6833,  0.3093, -1.6893],
        [-0.8832,  0.5251,  1.1047,  1.4685, -0.9776]], requires_grad=True)
```

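```python
ce_loss_pytorch = nn.CrossEntropyLoss()

# For this demo the targets are taken as the argmax of the logits.
# In real training they are the ground-truth class indices of each sample.
targets = torch.argmax(output, dim=1)

targets
```

```
tensor([1, 2, 3])
```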
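Here `targets` holds class indices (an `int64` tensor of shape `(N,)`), which is the usual target format. Assuming PyTorch 1.10 or newer, where probability targets were introduced, `nn.CrossEntropyLoss` also accepts class probabilities of shape `(N, C)`. A sketch with fresh random logits showing that one-hot probability targets reproduce the index-based loss, which ties back to p being a one-hot vector:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(3, 5)
targets = torch.tensor([1, 2, 3])  # class indices (dtype int64), shape (N,)

loss_from_indices = F.cross_entropy(logits, targets)

# The same targets written as one-hot probability vectors, shape (N, C), float dtype
one_hot_targets = F.one_hot(targets, num_classes=5).float()
loss_from_probs = F.cross_entropy(logits, one_hot_targets)

print(loss_from_indices.item(), loss_from_probs.item())
```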

```python
loss_pytorch = ce_loss_pytorch(output, targets)
loss_pytorch
```

```
tensor(0.6895, grad_fn=<NllLossBackward0>)
```
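By default `nn.CrossEntropyLoss` uses `reduction='mean'`, so the single number above is the average of per-sample losses. A sketch with fresh random logits (not the `output` tensor above) showing that `reduction='none'` exposes the per-sample terms, whose mean matches the default:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(3, 5)
targets = torch.tensor([1, 2, 3])

per_sample = nn.CrossEntropyLoss(reduction="none")(logits, targets)  # shape (3,)
mean_loss = nn.CrossEntropyLoss()(logits, targets)                   # default reduction="mean"

print(per_sample)
print(per_sample.mean().item(), mean_loss.item())
```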

## Manual Cross Entropy Loss

```python
output
```

```
tensor([[-0.7055,  2.2556,  0.4626, -0.1980,  1.0355],
        [-0.7134, -1.1858,  0.6833,  0.3093, -1.6893],
        [-0.8832,  0.5251,  1.1047,  1.4685, -0.9776]], requires_grad=True)
```

```python
softmax_proba = torch.softmax(output, dim=1)  # row-wise softmax: each row sums to 1
softmax_proba
```

```
tensor([[0.0324, 0.6252, 0.1041, 0.0538, 0.1846],
        [0.1133, 0.0707, 0.4581, 0.3152, 0.0427],
        [0.0420, 0.1718, 0.3067, 0.4413, 0.0382]], grad_fn=<SoftmaxBackward0>)
```

### Cross-entropy loss is nothing but softmax followed by the negative log-likelihood loss. Hence, the loss can be computed by taking the negative log of each sample's predicted probability for its true label and averaging over the batch.
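The same recipe can be vectorized with advanced indexing, `probs[torch.arange(N), targets]`, which picks each row's probability at its true label and works for any labels, not only the argmax. A sketch with random logits and arbitrary labels chosen for illustration:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(3, 5)
targets = torch.tensor([4, 0, 2])  # arbitrary true labels, not necessarily the argmax

probs = torch.softmax(logits, dim=1)
# Pick each row's probability at its true label: probs[0, 4], probs[1, 0], probs[2, 2]
true_label_probs = probs[torch.arange(len(targets)), targets]
loss_manual = -torch.log(true_label_probs).mean()

loss_pytorch = nn.CrossEntropyLoss()(logits, targets)
print(loss_manual.item(), loss_pytorch.item())
```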

![](assets/2022-02-21-19-16-14.png)
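Written out for the batch in this notebook (natural log, N = 3 samples with logits z_n, PyTorch's 0-based class indices, targets y = (1, 2, 3), and probabilities read from `softmax_proba` rounded to four decimals), the formula gives the same value PyTorch reported:

$$
\mathcal{L} = -\frac{1}{N}\sum_{n=1}^{N}\log \operatorname{softmax}(z_n)_{y_n}
= \frac{1}{3}\left(-\log 0.6252 - \log 0.4581 - \log 0.4413\right)
\approx \frac{0.4697 + 0.7807 + 0.8180}{3} \approx 0.6895
$$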

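```python
# Average of the negative log of each sample's probability at its true label.
# targets = [1, 2, 3], and log(1/p) = -log(p).
loss_manual = (torch.log(1/softmax_proba[0, 1]) + torch.log(1/softmax_proba[1, 2]) + torch.log(1/softmax_proba[2, 3])) / 3

loss_manual
```

```
tensor(0.6895, grad_fn=<DivBackward0>)
```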

```python
max_from_tensor = torch.max(softmax_proba, dim=1)  # named tuple: (values, indices)

print('max_from_tensor : ', max_from_tensor)
```

```
max_from_tensor :  torch.return_types.max(
values=tensor([0.6252, 0.4581, 0.4413], grad_fn=<MaxBackward0>),
indices=tensor([1, 2, 3]))
```


```python
row_1_max = max_from_tensor.values[0]
row_2_max = max_from_tensor.values[1]
row_3_max = max_from_tensor.values[2]
```

```python
loss_manual = (torch.log(1/row_1_max) + torch.log(1/row_2_max) + torch.log(1/row_3_max)) / 3
loss_manual
```

```
tensor(0.6895, grad_fn=<DivBackward0>)
```

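### So overall, the Cross Entropy Loss computed with `nn.CrossEntropyLoss()` and the one computed step by step with softmax give the same result.

Note that taking each row's maximum probability only gives the right answer here because the targets were defined as the argmax of the logits. In general, the loss must use the probability at each sample's true label, which need not be the largest one.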
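A sketch of that caveat, reusing the logits printed above (rounded to four decimals) with hypothetical labels `[3, 3, 2]` that are not the row argmax: `nn.CrossEntropyLoss` matches the true-label recipe, while the row-max shortcut gives a smaller, wrong value.

```python
import torch
import torch.nn as nn

# Logits copied from the printed `output` tensor above (rounded to 4 decimals)
logits = torch.tensor([[-0.7055,  2.2556,  0.4626, -0.1980,  1.0355],
                       [-0.7134, -1.1858,  0.6833,  0.3093, -1.6893],
                       [-0.8832,  0.5251,  1.1047,  1.4685, -0.9776]])
probs = torch.softmax(logits, dim=1)

# Hypothetical labels that are NOT the argmax of each row
targets = torch.tensor([3, 3, 2])

loss_pytorch = nn.CrossEntropyLoss()(logits, targets)
loss_true_labels = -torch.log(probs[torch.arange(3), targets]).mean()
loss_row_max = -torch.log(probs.max(dim=1).values).mean()  # only valid when targets == argmax

print(loss_pytorch.item(), loss_true_labels.item(), loss_row_max.item())
```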

#### Hence DO NOT use `nn.Softmax()` in the output layer of a neural net when using `nn.CrossEntropyLoss` as the loss function, because PyTorch's `nn.CrossEntropyLoss` already applies (log-)softmax internally.
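There is also a numerical reason to let the loss handle the softmax. A sketch with an extreme, made-up logit: taking `log` of a softmax output can hit `log(0) = -inf` once a probability underflows to 0 in float32, while `torch.log_softmax` (the LogSoftmax step in the documentation quoted above) stays finite:

```python
import torch
import torch.nn as nn

logits = torch.tensor([[0.0, 1000.0]])  # a very confident prediction...
targets = torch.tensor([0])             # ...for the wrong class

# Softmax first, then log: class 0's probability underflows to exactly 0, so log gives -inf
log_of_softmax = torch.log(torch.softmax(logits, dim=1))

# Log-softmax computed directly stays finite
stable_log_probs = torch.log_softmax(logits, dim=1)

loss = nn.CrossEntropyLoss()(logits, targets)

print(log_of_softmax)    # first entry is -inf
print(stable_log_probs)  # first entry is -1000
print(loss)              # finite loss of 1000
```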