# Use weighted loss function to solve imbalanced data classification problems

Imbalanced datasets are a common problem in classification tasks: the number of instances in one class is significantly smaller than the number of instances in another class. Trained naively on such data, models tend to be biased toward the majority class and perform poorly on the minority class.

- A weighted loss function is a modification of the standard loss function used to train a model.
- The weights assign a higher penalty to misclassifications of the minority class.
- The idea is to make the model more sensitive to the minority class by increasing the cost of misclassifying it.
- The most common implementation is to **assign a higher weight to the minority class** and a lower weight to the majority class.
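The sketch below makes the idea concrete. It uses made-up logits and class weights, which are illustrative assumptions and not from a real dataset. It recomputes `nn.CrossEntropyLoss(weight=...)` by hand: each sample's negative log-likelihood is scaled by the weight of its true class. With the default `reduction="mean"`, the total is divided by the sum of those weights.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Toy batch (assumed values): 4 samples, 3 classes; class 2 plays the rare class
logits = torch.randn(4, 3)
targets = torch.tensor([0, 1, 2, 0])
class_weights = torch.tensor([1.0, 1.0, 5.0])  # heavier penalty for class 2

# Built-in weighted cross-entropy (default reduction="mean")
builtin = nn.CrossEntropyLoss(weight=class_weights)(logits, targets)

# Same loss by hand: per-sample NLL scaled by the weight of the true class,
# then divided by the sum of the weights that were actually used
nll = -F.log_softmax(logits, dim=1)[torch.arange(len(targets)), targets]
w = class_weights[targets]
manual = (w * nll).sum() / w.sum()

print(builtin.item(), manual.item())
```

A mistake on the class-2 sample now costs five times as much as the same mistake on the other classes.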
    

## How to add weights to PyTorch's common loss functions

#### **Binary Classification (torch.nn.BCEWithLogitsLoss)**

**torch.nn.BCEWithLogitsLoss** is a commonly used loss function for binary classification problems. The model outputs raw scores (logits), not probabilities: the loss applies a sigmoid internally and then computes binary cross-entropy. Combining the two steps is more numerically stable than a separate sigmoid followed by `BCELoss`.

For imbalanced datasets, `BCEWithLogitsLoss` accepts a `weight` argument that rescales the loss of each element. If you build that tensor from each sample's class label, positive and negative samples receive different weights.

**The `weight` tensor must be broadcastable to the input shape. In the example below it has shape [batch_size, 1] and holds one weight per sample, so it only fits batches of exactly that size.**
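To check what `weight` does, the sketch below recomputes the loss by hand using the formula from the PyTorch documentation: ℓ_n = -w_n [y_n log σ(x_n) + (1 - y_n) log(1 - σ(x_n))], averaged over all elements. The logits are arbitrary assumed values. Note that the mean divides by the number of elements, not by the sum of the weights.

```python
import torch
import torch.nn as nn

# Assumed toy logits and labels; one weight per sample, shape (N, 1)
logits = torch.tensor([[0.5], [-1.0], [2.0]])
target = torch.tensor([[0.], [1.], [1.]])
weight = torch.tensor([[0.375], [0.625], [0.625]])

builtin = nn.BCEWithLogitsLoss(weight=weight)(logits, target)

# Manual version of the same formula
p = torch.sigmoid(logits)
per_elem = -weight * (target * torch.log(p) + (1 - target) * torch.log(1 - p))
manual = per_elem.mean()  # mean over elements, NOT divided by weight.sum()

print(builtin.item(), manual.item())
```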

```python
import torch
import torch.nn as nn

# Class counts: 1000 samples of class 0, 600 samples of class 1
class_counts = torch.tensor([1000, 600], dtype=torch.int)
print(f"class counts: {class_counts}")

# Inverse-frequency class weights, normalized to sum to 1
class_weights = 1.0 / class_counts
class_weights = class_weights / class_weights.sum()  # higher weight for the rarer class
print(f"class weights : {class_weights}", end="\n\n")

# Targets for a batch of 3 samples (must exist before the weight lookup below)
target = torch.tensor([[0.], [1.], [1.]])

# Assign each sample the weight of its target class
sample_weights = class_weights[target.view(-1).long()]
sample_weights = sample_weights.view(-1, 1)  # reshape to (N, 1) to match the input
print(f"Sample weights: {sample_weights}")

# Random logits for the binary classification problem
input_ = torch.randn(3, 1)
print(f"Inputs vector : {input_}, {input_.shape}, ndim : {input_.ndim}")
print(f"Target vector : {target}")

# BCEWithLogitsLoss with per-sample weights
criterion = nn.BCEWithLogitsLoss(weight=sample_weights)
loss = criterion(input_, target)
print(f"Loss : {loss}", end="\n\n")
```

```
class counts: tensor([1000,  600], dtype=torch.int32)
class weights : tensor([0.3750, 0.6250])

Sample weights: tensor([[0.3750],
        [0.6250],
        [0.6250]])
Inputs vector : tensor([[0.2851],
        [0.6095],
        [0.1164]]), torch.Size([3, 1]), ndim : 2
Target vector : tensor([[0.],
        [1.],
        [1.]])
Loss : 0.3287941515445709
```
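A per-sample `weight` passed to the constructor is tied to one batch size. One way around this is the functional form `torch.nn.functional.binary_cross_entropy_with_logits`, which takes `weight` on every call, so the weights can be rebuilt per batch. The sketch below does this. The helper `weighted_bce_for_batch` is hypothetical, and the class counts are the same assumed 1000/600 split as above.

```python
import torch
import torch.nn.functional as F

# Class weights computed once from the (assumed) training-set counts
class_counts = torch.tensor([1000.0, 600.0])
class_weights = (1.0 / class_counts) / (1.0 / class_counts).sum()

def weighted_bce_for_batch(logits, target, class_weights):
    """Hypothetical helper: look up a weight for each sample from its label."""
    sample_weights = class_weights[target.long()]  # same shape as target
    return F.binary_cross_entropy_with_logits(logits, target, weight=sample_weights)

# Works for any batch size, because the weights are rebuilt on every call
for batch_size in (3, 5):
    logits = torch.randn(batch_size, 1)
    target = torch.randint(0, 2, (batch_size, 1)).float()
    print(batch_size, weighted_bce_for_batch(logits, target, class_weights).item())
```

Feeding it the logits from the output above (0.2851, 0.6095, 0.1164) reproduces the loss of about 0.3288.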



#### Alternative: weight_for_class_i = total_samples / num_samples_in_class_i

```python
# Another way: weight = total samples / samples in the class
total_samples = 1000 + 600
weight_for_class_0 = total_samples / 1000
weight_for_class_1 = total_samples / 600

print(f"weight_for_class_0 : {weight_for_class_0}")
print(f"weight_for_class_1 : {weight_for_class_1}")  # higher weight for the rarer class
```

```
weight_for_class_0 : 1.6
weight_for_class_1 : 2.6666666666666665
```
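The two binary schemes (normalized 1/count and total/count) favour class 1 by the same ratio. They differ only by a constant factor. `BCEWithLogitsLoss` averages over elements rather than over the weights, so that constant carries straight into the loss value and therefore into the gradient scale. The sketch below shows this. The logits are assumed toy values.

```python
import torch
import torch.nn.functional as F

# Assumed toy logits and labels
logits = torch.tensor([[0.3], [0.6], [0.1]])
target = torch.tensor([[0.], [1.], [1.]])

normalized = torch.tensor([0.375, 0.625])            # 1/count, normalized to sum to 1
total_over_count = torch.tensor([1.6, 1600 / 600])   # total_samples / count

def bce(weights):
    w = weights[target.long()]  # per-sample weight looked up from each label
    return F.binary_cross_entropy_with_logits(logits, target, weight=w)

# Same class-1 : class-0 ratio (about 1.667) in both schemes ...
print(normalized[1] / normalized[0], total_over_count[1] / total_over_count[0])
# ... but the loss itself is scaled by the constant factor 1.6 / 0.375
print(bce(total_over_count) / bce(normalized))
```

In practice, a larger overall weight scale behaves much like a larger learning rate.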


In addition to `weight`, `torch.nn.BCEWithLogitsLoss` has a `pos_weight` parameter, which is a simpler way to up-weight the positive class in a binary classification problem. It multiplies only the loss term of positive targets, so you do not need to build a per-sample weight tensor.

```python
import torch
import torch.nn as nn

# BCEWithLogitsLoss with pos_weight: the positive-class term is multiplied by 3
pos_weight = torch.tensor([3.0])
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

# Random logits for the binary classification problem
logits = torch.randn(3, 1)
target = torch.tensor([[0.], [1.], [1.]])

# Compute the loss with the specified pos_weight
loss = criterion(logits, target)
print(loss)
```

```
tensor(0.7315)
```
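The PyTorch documentation gives the `pos_weight` formula as ℓ_n = -[p · y_n log σ(x_n) + (1 - y_n) log(1 - σ(x_n))]. It also suggests setting p to (#negatives / #positives), for example 300/100 = 3 for 100 positive and 300 negative examples. The sketch below uses that ratio for the assumed 1000/600 split and checks the built-in loss against the formula. The logits are made-up values.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# pos_weight = #negatives / #positives for the assumed 1000 / 600 split
pos_weight = torch.tensor([1000 / 600])

logits = torch.tensor([[0.2], [-0.4], [1.1]])
target = torch.tensor([[0.], [1.], [1.]])

builtin = nn.BCEWithLogitsLoss(pos_weight=pos_weight)(logits, target)

# Manual: pos_weight scales only the positive-target term
log_p = F.logsigmoid(logits)      # log(sigmoid(x))
log_1mp = F.logsigmoid(-logits)   # log(1 - sigmoid(x))
manual = -(pos_weight * target * log_p + (1 - target) * log_1mp).mean()

print(builtin.item(), manual.item())
```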


#### **Multiclass Classification (torch.nn.CrossEntropyLoss)**

Suppose we have a dataset with 1000 samples, and the target variable has three classes: Class A, Class B, and Class C. The samples are distributed as follows:

```python
import torch

# class A: 100 samples
# class B: 800 samples
# class C: 100 samples
counts = [100, 800, 100]
class_counts = torch.tensor(counts)
print(f"class counts : {class_counts}")

# Inverse-frequency class weights
class_weights = 1.0 / class_counts
print(f"Initial class weights : {class_weights}")

# Normalized class weights (sum to 1)
class_weights = class_weights / class_weights.sum()
print(f"Normalized class weights : {class_weights}")
print(f"Normalized class weights in percentages : {class_weights*100}")

# Weighted loss function
loss_fn = torch.nn.CrossEntropyLoss(weight=class_weights)
```

```
class counts : tensor([100, 800, 100])
Initial class weights : tensor([0.0100, 0.0012, 0.0100])
Normalized class weights : tensor([0.4706, 0.0588, 0.4706])
Normalized class weights in percentages : tensor([47.0588,  5.8824, 47.0588])
```
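In a real pipeline the counts usually come from the label tensor, not from hand-typed numbers. The sketch below derives them with `torch.bincount`, assuming integer labels 0, 1, 2 for classes A, B, C. It builds a synthetic label tensor that matches the 100/800/100 split. `minlength` guarantees one count per class. A class that never appears still gets a count of 0, which would give an infinite inverse weight, so check for that case.

```python
import torch

# Synthetic labels matching the example: 100 of A (0), 800 of B (1), 100 of C (2)
labels = torch.cat([
    torch.zeros(100, dtype=torch.long),
    torch.ones(800, dtype=torch.long),
    torch.full((100,), 2, dtype=torch.long),
])

num_classes = 3
# One count per class label; minlength pads with zeros for unseen classes
class_counts = torch.bincount(labels, minlength=num_classes)

# Same inverse-frequency, normalized weights as above
class_weights = 1.0 / class_counts.float()
class_weights = class_weights / class_weights.sum()
print(class_counts, class_weights)

loss_fn = torch.nn.CrossEntropyLoss(weight=class_weights)
```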


```python
import torch

# Another way: weight = total samples / samples in the class
counts = [100, 800, 100]
count_sum = sum(counts)
class_weights_updated = torch.tensor([count_sum / x for x in counts])
print("Updated class weights :", class_weights_updated)

# Loss function
loss_fn = torch.nn.CrossEntropyLoss(weight=class_weights_updated)
print(loss_fn.weight)
```

```
Updated class weights : tensor([10.0000,  1.2500, 10.0000])
tensor([10.0000,  1.2500, 10.0000])
```
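The two multiclass schemes differ only by a constant factor: every total/count weight is 21.25 times the corresponding normalized weight. `CrossEntropyLoss` with the default `reduction="mean"` divides by the sum of the weights it used, so that constant cancels and both schemes give the same loss. The sketch below checks this on random logits and an assumed batch of targets.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(8, 3)
targets = torch.tensor([0, 1, 1, 1, 2, 1, 0, 2])  # assumed toy labels

counts = torch.tensor([100.0, 800.0, 100.0])
normalized = (1.0 / counts) / (1.0 / counts).sum()   # about [0.4706, 0.0588, 0.4706]
total_over_count = counts.sum() / counts              # [10.0, 1.25, 10.0]

loss_a = nn.CrossEntropyLoss(weight=normalized)(logits, targets)
loss_b = nn.CrossEntropyLoss(weight=total_over_count)(logits, targets)
print(loss_a.item(), loss_b.item())  # identical up to float rounding
```

This cancellation holds only for `reduction="mean"`. With `reduction="sum"` the constant factor would scale the loss.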