# Understanding weighted Cross Entropy loss and its effects


__Problem description:__
- We have a multiclass classification problem.
- The class imbalance is very large, so class weights are applied to the loss.
- At inference time one class (debris) occurs very frequently, so we want to detect the rare classes in "large noise".
- The rare classes receive a large number of false positive predictions from the numerous debris class, so we should penalize this during training.

Cross entropy loss:
$$ - \sum_{c=1}^{N}y_c \log(p_c)$$
where $N$ is the number of classes and $c$ iterates over the classes; $y_c \in \{0, 1\}$ is the binary (one-hot) label and $p_c$ is the predicted probability of class $c$. We can obtain a probability for each class by applying the softmax function to the raw network outputs $x$:

$$ p_i = \frac{e^{x_i}}{\sum_{c=1}^{N}e^{x_c}} $$
This function exponentially emphasizes the class with the largest raw output and suppresses the others. Because the denominator sums over all classes, each probability also carries information about the predictions for the other classes: a high raw output for another class lowers the probability of the current class.
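
The formula as written overflows when a raw output is large, because $e^{x}$ grows very quickly. A minimal sketch of the common max-shift variant follows; it gives the same probabilities because the shift cancels in the ratio. The name `stable_softmax` is illustrative and not part of this notebook, and only the standard library is used so that the snippet runs on its own.

```python
import math

def stable_softmax(x):
    """Softmax with the maximum subtracted from every raw output.

    Subtracting a constant from all inputs leaves the ratio unchanged,
    but keeps every exponent <= 0, so math.exp cannot overflow.
    """
    m = max(x)
    exps = [math.exp(v - m) for v in x]
    total = sum(exps)
    return [e / total for e in exps]

# Raw outputs of the kind used in the cases of this notebook.
p = stable_softmax([10, 0.1, 0.1, 0.1])

# math.exp(1000) would raise OverflowError in the naive formula.
q = stable_softmax([1000, 0, 0, 0])
```

For the small raw outputs used in this notebook the naive and the shifted versions agree, so the choice only matters for large values.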

If we have an imbalanced dataset we can apply class weights to correct the bias:
$$ - \sum_{c=1}^{N} w_c y_c \log(p_c)$$

Note that using LogSoftmax in the output layer together with NLLLoss (negative log likelihood loss) is equivalent to the cross entropy loss.
The output layer of the NN is LogSoftmax:
$$ \mathrm{LogSoftmax}(x)_i = \log \frac{e^{x_i}}{\sum_{c=1}^{N} e^{x_c}} \in (-\infty, 0] $$

where $x \in \mathbb{R}^{N}$ is the raw output of the NN and $N$ is the number of classes.

$$ \mathrm{NLLLoss}(z) = -\sum_{c=1}^{N}y_c z_c $$

where $z = \mathrm{LogSoftmax}(x)$. For the sake of simplicity we will stick to the cross entropy formulation (although our implementation uses the LogSoftmax + NLLLoss scheme).
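
The equivalence can be checked numerically. The sketch below is a plain-Python illustration, not the implementation referred to above: it computes the cross entropy from softmax probabilities and the negative log likelihood from log-softmax values for one sample. Since the label is one-hot, both sums reduce to the single term of the ground-truth class. The function names and the raw outputs are chosen for illustration.

```python
import math

def softmax(x):
    exps = [math.exp(v) for v in x]
    total = sum(exps)
    return [e / total for e in exps]

def log_softmax(x):
    # log(e^{x_i} / sum_c e^{x_c}) = x_i - log(sum_c e^{x_c})
    log_norm = math.log(sum(math.exp(v) for v in x))
    return [v - log_norm for v in x]

def cross_entropy(x, y):
    """Cross entropy computed from raw outputs x and ground-truth index y."""
    return -math.log(softmax(x)[y])

def nll_loss(z, y):
    """Negative log likelihood computed from log-probabilities z."""
    return -z[y]

x = [10, 8, 0.1, 0.1]  # illustrative raw outputs
y = 0                  # ground-truth class index

ce = cross_entropy(x, y)
nll = nll_loss(log_softmax(x), y)
print(ce, nll)  # the two values agree up to floating point error
```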

We have 4 classes with the following sample distribution and assigned weights:

- class1 (260) -> 0.8
- class2 (10) -> 26
- class3 (56) -> 4
- class4 (500) -> 0.4

The first 3 classes are rare objects and the fourth is the numerous debris class.
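
This notebook does not state how the weights were derived from the sample counts. One common heuristic, assumed here purely for illustration, is inverse-frequency weighting, $w_c = n_{total} / (N \cdot n_c)$. It reproduces the ordering and rough magnitude of the weights listed above, but not their exact values (for example, it gives 20.65 rather than 26 for class2). The function name is hypothetical.

```python
counts = {"class1": 260, "class2": 10, "class3": 56, "class4": 500}

def inverse_frequency_weights(counts):
    """w_c = total / (num_classes * count_c); an assumed heuristic,
    not necessarily the rule used for the weights in this notebook."""
    total = sum(counts.values())
    n_classes = len(counts)
    return {name: total / (n_classes * n) for name, n in counts.items()}

weights = inverse_frequency_weights(counts)
for name, w in weights.items():
    print(f"{name}: {w:.3f}")
```

With these weights every class contributes the same total weight (count times weight) to the loss, which is the sense in which the bias is corrected.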

In [3]:
import numpy as np

In [5]:
class_weights = np.array([0.8, 26, 4, 0.4])

In [10]:
# softmax
def softmax(x: np.ndarray) -> np.ndarray:
    """Convert raw network outputs to class probabilities."""
    e = np.exp(x)
    return e / np.sum(e)

# CrossEntropyLoss
def cross_entropy_loss(p: np.ndarray, y: int, w: np.ndarray) -> float:
    """
    :param p: class probabilities
    :param y: ground truth class index
    :param w: class weights
    :return: weighted cross entropy loss of a single sample
    """
    return -1 * np.log(p[y]) * w[y]

In [11]:
# case 1
# class 1 predicted correctly with large confidence
y = 0
x = [10, 0.1, 0.1, 0.1]
p = softmax(x)
l = cross_entropy_loss(p, y, class_weights)
print(p)
print(l)

[9.99849499e-01 5.01671307e-05 5.01671307e-05 5.01671307e-05]
0.00012041017484880806


In [12]:
np.sum(p)

1.0

In [13]:
# case 2
# class 2 predicted correctly with large confidence
y = 1
x = [0.10, 10, 0.1, 0.1]
p = softmax(x)
l = cross_entropy_loss(p, y, class_weights)
print(p)
print(l)

[5.01671307e-05 9.99849499e-01 5.01671307e-05 5.01671307e-05]
0.003913330682586261


In [14]:
# case 3
# class 3 predicted correctly with large confidence
y = 2
x = [0.10, 0.10, 10, 0.1]
p = softmax(x)
l = cross_entropy_loss(p, y, class_weights)
print(p)
print(l)

[5.01671307e-05 5.01671307e-05 9.99849499e-01 5.01671307e-05]
0.0006020508742440402


In [15]:
# case 4
# class 4 predicted correctly with large confidence
y = 3
x = [0.10, 0.10, 0.10, 10]
p = softmax(x)
l = cross_entropy_loss(p, y, class_weights)
print(p)
print(l)

[5.01671307e-05 5.01671307e-05 5.01671307e-05 9.99849499e-01]
6.020508742440403e-05


The trivial cases are clear: a rare class is assigned a larger weight, so the network is updated for that class rarely, but each update is large. With identical confidence, the loss scales linearly with the class weight.
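
Why a larger weight means a larger update can be seen from the gradient of the weighted loss with respect to the raw outputs, which for softmax followed by cross entropy is $w_y (p - \mathrm{onehot}(y))$. The sketch below evaluates this standard result for an undecided network; the helper name `weighted_ce_grad` and the raw outputs are illustrative.

```python
import math

class_weights = [0.8, 26, 4, 0.4]

def softmax(x):
    exps = [math.exp(v) for v in x]
    total = sum(exps)
    return [e / total for e in exps]

def weighted_ce_grad(x, y, w):
    """Gradient of -w[y] * log(softmax(x)[y]) with respect to x."""
    p = softmax(x)
    return [w[y] * (p_c - (1.0 if c == y else 0.0)) for c, p_c in enumerate(p)]

x = [1.0, 1.0, 1.0, 1.0]  # undecided network: p = 0.25 for every class

g_rare = weighted_ce_grad(x, 1, class_weights)    # true class: class2 (rare)
g_debris = weighted_ce_grad(x, 3, class_weights)  # true class: class4 (debris)

print(g_rare)
print(g_debris)
```

For the same prediction, the gradient for a class2 sample is larger than for a debris sample by the ratio of the two weights.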

In [17]:
# class 1 predicted correctly with smaller confidence
y = 0   # GT
x = [10, 8, 0.1, 0.1]   # Network raw output
p = softmax(x)
l = cross_entropy_loss(p, y, class_weights)
print(p)
print(l)

[8.80719233e-01 1.19192387e-01 4.41898075e-05 4.41898075e-05]
0.10161311565097696


In [21]:
y = 0   # GT
x = [10, 0.1, 0.1, 0.8]   # Network raw output
p = softmax(x)
l = cross_entropy_loss(p, y, class_weights)
print(p)
print(l)

[9.99798652e-01 5.01645795e-05 5.01645795e-05 1.01019058e-04]
0.00016109479196323826


In [18]:
y = 0   # GT
x = [10, 8, 5, 0.1]   # Network raw output
p = softmax(x)
l = cross_entropy_loss(p, y, class_weights)
print(p)
print(l)

[8.75562129e-01 1.18494449e-01 5.89949122e-03 4.39310514e-05]
0.10631133259521273


In [20]:
x = [10, 8, 5, 5]   # Network raw output
p = softmax(x)
l = cross_entropy_loss(p, y, class_weights)
print(p)
print(l)

[0.87046507 0.11780464 0.00586515 0.00586515]
0.11098211900127541


In [22]:
x = [10, 9, 9, 9]   # Network raw output
p = softmax(x)
l = cross_entropy_loss(p, y, class_weights)
print(p)
print(l)

[0.47536689 0.1748777  0.1748777  0.1748777 ]
0.5949347045029433


The critical case for us is when debris is classified as another class.

In [25]:
y = 3
x = [10, 0.1, 0.1, 5]   # Network raw output
p = softmax(x)
l = cross_entropy_loss(p, y, class_weights)
print(p)
print(l)

[9.93208148e-01 4.98339031e-05 4.98339031e-05 6.69218386e-03]
2.0027260085049603


In [26]:
x = [10, 0.1, 0.1, 0.1]   # Network raw output
p = softmax(x)
l = cross_entropy_loss(p, y, class_weights)
print(p)
print(l)

[9.99849499e-01 5.01671307e-05 5.01671307e-05 5.01671307e-05]
3.9600602050874247


# Ideas

To increase the penalty when debris is misclassified, we could adjust the assigned weights.
Instead of the weight derived from the sample distribution, we could set a custom value for the debris class.

In [27]:
class_weights[3] = 1

x = [10, 0.1, 0.1, 0.1]   # Network raw output
p = softmax(x)
l = cross_entropy_loss(p, y, class_weights)
print(p)
print(l)

[9.99849499e-01 5.01671307e-05 5.01671307e-05 5.01671307e-05]
9.900150512718561


In [28]:
class_weights[3] = 0.8

x = [10, 0.1, 0.1, 0.1]   # Network raw output
p = softmax(x)
l = cross_entropy_loss(p, y, class_weights)
print(p)
print(l)

[9.99849499e-01 5.01671307e-05 5.01671307e-05 5.01671307e-05]
7.9201204101748495


The pitfall of this approach is that the number of false negative predictions for the rare classes might increase, i.e. the network might tend to classify samples from the rare classes as debris.
Thus, the weights must be chosen carefully.
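
One way to reason about the choice is to compare, for several candidate debris weights, the penalty for the two opposing mistakes: debris predicted as class1, and class1 predicted as debris. The sketch below assumes both mistakes are made with the same confidence (raw output 10 for the wrong class, 0.1 elsewhere); the candidate values are the ones tried in this notebook, and plain Python is used so that the snippet runs on its own.

```python
import math

def softmax(x):
    exps = [math.exp(v) for v in x]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy_loss(p, y, w):
    return -math.log(p[y]) * w[y]

base_weights = [0.8, 26, 4, 0.4]
ratios = {}
for debris_weight in (0.4, 0.8, 1.0):
    w = list(base_weights)
    w[3] = debris_weight
    # debris (index 3) confidently predicted as class1 (index 0)
    debris_as_class1 = cross_entropy_loss(softmax([10, 0.1, 0.1, 0.1]), 3, w)
    # class1 (index 0) confidently predicted as debris (index 3)
    class1_as_debris = cross_entropy_loss(softmax([0.1, 0.1, 0.1, 10]), 0, w)
    ratios[debris_weight] = debris_as_class1 / class1_as_debris
    print(debris_weight, debris_as_class1, class1_as_debris)
```

Under this assumption the ratio of the two penalties equals the ratio of the two class weights, so a debris weight above the class1 weight makes a false positive on debris costlier than a missed class1 sample.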