In [None]:
import numpy as np 
import pandas as pd 

import torch 
import seaborn as sns

In [None]:
# Read the competition's training labels and preview the first few rows.
TRAIN_LABELS_CSV = '../input/seti-breakthrough-listen/train_labels.csv'
data = pd.read_csv(TRAIN_LABELS_CSV)
data.head()

# Compute positive and negative frequencies

In [None]:
def compute_class_freqs(labels):
    """Return the fraction of positive and of negative labels.

    Args:
        labels: array-like of 0/1 labels. The first axis indexes examples;
            any further axes (e.g. one column per class) are preserved.

    Returns:
        A tuple ``(positive_frequencies, negative_frequencies)`` where the
        first entry is the sum over the first axis divided by the number of
        examples and the second is ``1 - positive_frequencies``.
    """
    label_array = np.asarray(labels)
    num_examples = label_array.shape[0]

    pos_share = label_array.sum(axis=0) / num_examples
    return pos_share, 1 - pos_share

In [None]:
# Share of positive / negative examples in the training labels.
target_labels = data["target"]
freq_pos, freq_neg = compute_class_freqs(target_labels)

In [None]:
# Relative frequency of each class, one bar per class.
# Both rows are built in a single constructor call: DataFrame.append was
# deprecated in pandas 1.4 and removed in pandas 2.0, so the previous
# row-by-row construction fails on current pandas.
df = pd.DataFrame({
    "Targets": ["0", "1"],
    "Label": ["Negative", "Positive"],
    "Value": [freq_neg, freq_pos],
})
sns.barplot(x="Targets", y="Value", data=df)

- If we compute the BCE loss of each example individually, the total loss coming from the positive examples is smaller than the total loss coming from the negative examples, because positive examples are much rarer in the data.

<center><img src="https://i.pinimg.com/originals/a6/62/32/a66232f444f7eb9452c2868a37c3be0c.png" width="440" height="440" ></center>

The contribution of the positive cases is significantly lower than that of the negative ones. However, we want the contributions to be equal. One way of doing this is to multiply the loss of each example by a class-specific weight factor, $pos_{weights}$ or $neg_{weights}$, so that the overall contribution of each class is the same.

To have this, we want

$$pos_{weights} \times freq_{pos} = neg_{weights} \times freq_{neg},$$
which we can do simply by taking

$$pos_{weights} = freq_{neg}$$
$$neg_{weights} = freq_{pos}$$
This way, we will be balancing the contribution of positive and negative labels.

In [None]:
# Each class is weighted by the frequency of the *other* class, so that
# weight * frequency is the same product for both classes.
pos_weights, neg_weights = freq_neg, freq_pos

pos_contribution = pos_weights * freq_pos
neg_contribution = neg_weights * freq_neg

In [None]:
# Weighted contribution of each class, one bar per class.
# Both rows are built in a single constructor call: DataFrame.append was
# deprecated in pandas 1.4 and removed in pandas 2.0, so the previous
# row-by-row construction fails on current pandas.
df = pd.DataFrame({
    "Targets": ["0", "1"],
    "Label": ["Negative", "Positive"],
    "Value": [neg_contribution, pos_contribution],
})
f = sns.barplot(x="Targets", y="Value", data=df)

# Weighted Binary Cross Entropy 

- If we multiply the loss of the positive examples by the positive weight and the loss of the negative examples by the negative weight, both classes contribute equally to the total loss.

<center><img src="https://i.pinimg.com/originals/ba/e0/5d/bae05dffdcb9efdd4430d877febe6dbf.png" width="440" height="440" ></center>

In [None]:
class W_BCEWithLogitsLoss(torch.nn.Module):
    """Binary cross-entropy on raw logits with one weight per class.

    The loss is the sum of two batch means:

        loss_pos = -mean(w_p * y       * log(sigmoid(x)     + epsilon))
        loss_neg = -mean(w_n * (1 - y) * log(1 - sigmoid(x) + epsilon))

    Args:
        w_p: weight applied to the positive-label term. ``None`` means 1.0.
        w_n: weight applied to the negative-label term. ``None`` means 1.0.
    """

    def __init__(self, w_p = None, w_n = None):
        super(W_BCEWithLogitsLoss, self).__init__()

        # The signature advertises None as a default, but None cannot be
        # multiplied with a tensor; treat it as "no re-weighting" instead.
        self.w_p = 1.0 if w_p is None else w_p
        self.w_n = 1.0 if w_n is None else w_n

    def forward(self, logits, labels, epsilon = 1e-7):
        """Compute the weighted loss for a batch.

        Args:
            logits: tensor of raw (pre-sigmoid) scores.
            labels: tensor of 0/1 targets with one entry per logit.
            epsilon: small constant added inside the logarithms to avoid
                ``log(0)``.

        Returns:
            A scalar tensor: ``loss_pos + loss_neg`` as defined on the class.
        """
        # Give logits the same shape as labels whenever they hold the same
        # number of elements. A bare squeeze() would turn (N, 1) logits into
        # (N,), which then broadcasts against (N, 1) labels into an (N, N)
        # matrix and silently yields a wrong loss.
        if logits.numel() == labels.numel():
            logits = logits.reshape(labels.shape)
        else:
            # Element counts differ: keep the previous behaviour and rely on
            # ordinary broadcasting.
            logits = logits.squeeze()

        ps = torch.sigmoid(logits)

        loss_pos = -1 * torch.mean(self.w_p * labels * torch.log(ps + epsilon))
        loss_neg = -1 * torch.mean(self.w_n * (1 - labels) * torch.log((1 - ps) + epsilon))

        return loss_pos + loss_neg

In [None]:
# Toy batch of 12 labels, 3 of which are positive.
targets = torch.tensor([0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0]).float()
logits = torch.zeros_like(targets)  # probs will be 0.5

# Positive weight = share of negatives; negative weight = the remainder.
num_negatives = (targets == 0).sum().item()
w_p = num_negatives / len(targets)
w_n = 1 - w_p

# Weighted loss vs. PyTorch's unweighted reference, for comparison.
criterion1 = W_BCEWithLogitsLoss(w_p, w_n)
criterion2 = torch.nn.BCEWithLogitsLoss()

In [None]:
# Evaluate both criteria on the same toy batch, then report them.
weighted_bce_value = criterion1(logits, targets)
plain_bce_value = criterion2(logits, targets)

print("Loss from Weighted BCE : {}".format(weighted_bce_value))
print("Loss from BCE : {}".format(plain_bce_value))