# A Tutorial on How to Address Class Imbalance Using WeightedRandomSampler in PyTorch

Class imbalance is a very common problem in real-world datasets. For example, a medical diagnosis dataset may contain a large number of samples from the healthy class and very few from the disease class. Class imbalance hurts model performance: a model trained on such data tends to favor the majority class and can generalize poorly. There are many ways to address this issue. In this article, we focus on sampling strategies, which are easy to implement in PyTorch.

# Imbalanced Dataset

First, we create a dummy dataset with imbalanced classes: 90% of the samples belong to *class 0* and 10% belong to *class 1*. Then, we wrap it in a PyTorch [TensorDataset](https://pytorch.org/docs/stable/data.html#torch.utils.data.TensorDataset) and [DataLoader](https://pytorch.org/docs/stable/data.html).

```python
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# create a dummy dataset: 1024 samples with 10 random features each
num_samples = 1024
data = torch.randn(num_samples, 10)
num_high_freq_samples = int(num_samples * 0.9)  # 921 samples of class 0, the rest (103) are class 1
targets = [0] * num_high_freq_samples + [1] * (num_samples - num_high_freq_samples)

# convert targets to a LongTensor (int64 class labels)
targets = torch.LongTensor(targets)

# create the Dataset and DataLoader
dataset = TensorDataset(data, targets)
dataloader = DataLoader(dataset, batch_size=128, num_workers=1, shuffle=True)
```

Now, we iterate through the dataloader and plot the class distribution in each batch.

```python
def plot_function(x1, x2):
    """Grouped bar chart of class-0 and class-1 counts for each mini-batch."""
    width = 0.3
    fig = plt.figure()
    ax = fig.add_axes([0, 0, 1, 1])
    ax.bar(np.arange(1, len(x1) + 1), x1, width=width)
    ax.bar(np.arange(1, len(x2) + 1) + width, x2, width=width)
    ax.set_xlabel('Mini-batch')
    ax.set_ylabel('Number of samples')
    ax.legend(labels=['Class 0', 'Class 1'])
    plt.show()

zeros = []
ones = []
for x, y in dataloader:
    # minlength=2 returns a count for both classes, even if a batch lacks one of them
    counts = np.bincount(y.numpy(), minlength=2)
    zeros.append(counts[0])
    ones.append(counts[1])

plot_function(zeros, ones)
```

Fig. 1 shows a strong imbalance: every mini-batch is dominated by class 0. Because only 10% of the samples belong to *class 1*, the model sees few examples of this class and is less likely to learn useful features from it.
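Because `shuffle=True` visits every sample exactly once per epoch, the per-batch counts behind Fig. 1 must add up to the dataset's class totals, so an average batch holds only about 128 × 103 / 1024 ≈ 12.9 class-1 samples. A minimal sketch that checks this numerically without plotting; it rebuilds the same dummy dataset so it runs on its own, and the fixed seed is an assumption added only for reproducibility:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)  # assumption: fixed seed, only for reproducibility

# same dummy dataset as above: 921 samples of class 0, 103 of class 1
num_samples = 1024
num_high_freq_samples = int(num_samples * 0.9)
data = torch.randn(num_samples, 10)
targets = torch.LongTensor([0] * num_high_freq_samples + [1] * (num_samples - num_high_freq_samples))
dataloader = DataLoader(TensorDataset(data, targets), batch_size=128, shuffle=True)

# one row per mini-batch: [count of class 0, count of class 1]
per_batch = torch.stack([torch.bincount(y, minlength=2) for _, y in dataloader])
print(per_batch.tolist())

# shuffling only reorders samples, so the column sums equal the dataset totals
print(per_batch.sum(dim=0).tolist())  # [921, 103]
print(per_batch[:, 1].double().mean().item())  # 12.875 (= 103 / 8 batches)
```

The individual rows vary with the seed, but the column sums and the mean do not.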

# Oversampling with WeightedRandomSampler

One way to address class imbalance is [oversampling](https://arxiv.org/pdf/1710.05381.pdf). The key idea is to balance the two classes by sampling the minority class more often. With oversampling, each mini-batch contains a nearly equal number of samples from both classes.
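To make the idea concrete, here is a minimal sketch of plain random oversampling. This is a simpler variant, not the method used in the rest of this tutorial: minority-class indices are drawn with replacement until both classes have the same count, and the result is wrapped in a `torch.utils.data.Subset`. The dummy dataset is rebuilt so the block runs on its own; names such as `balanced_idx` are hypothetical.

```python
import torch
from torch.utils.data import Subset, TensorDataset

torch.manual_seed(0)  # assumption: fixed seed, only for reproducibility

# same dummy dataset as above: 921 samples of class 0, 103 of class 1
data = torch.randn(1024, 10)
targets = torch.LongTensor([0] * 921 + [1] * 103)
dataset = TensorDataset(data, targets)

majority_idx = torch.nonzero(targets == 0).flatten()
minority_idx = torch.nonzero(targets == 1).flatten()

# draw extra minority indices (with replacement) until both classes have 921 samples
num_extra = len(majority_idx) - len(minority_idx)
extra_idx = minority_idx[torch.randint(len(minority_idx), (num_extra,))]

balanced_idx = torch.cat([majority_idx, minority_idx, extra_idx])
balanced_dataset = Subset(dataset, balanced_idx.tolist())
print(len(balanced_dataset), torch.bincount(targets[balanced_idx]).tolist())  # 1842 [921, 921]
```

Shuffling `balanced_dataset` in a `DataLoader` then yields roughly 50/50 mini-batches. Compared with the `WeightedRandomSampler` approach, this grows each epoch to 1,842 samples and keeps every majority-class sample, at the cost of materializing a fixed list of duplicated indices.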

In PyTorch, oversampling can be implemented easily with [WeightedRandomSampler](https://pytorch.org/docs/stable/data.html#torch.utils.data.WeightedRandomSampler). Internally, `WeightedRandomSampler` draws dataset indices from a *multinomial distribution* controlled by two parameters, `weights` and `num_samples`. `weights` holds one weight per sample in the dataset (not per class), and `num_samples` is the number of indices drawn per epoch; by default, indices are drawn with replacement, and the weights do not need to sum to one. To draw a nearly equal number of samples from both classes (that is, to give both classes an equal probability of being sampled), each sample is weighted inversely to the size of its class, so minority-class samples receive a higher weight. Below is an example of how to compute these weights.
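Why inverse-frequency weights balance the classes: if class $c$ has $n_c$ samples and each gets weight $1/n_c$, the total weight of class $c$ is $n_c \cdot 1/n_c = 1$, so with $C$ classes each class is drawn with probability $1/C$ (here $1/2$). The sketch below checks this and mimics one epoch of the sampler's default draw with `torch.multinomial`. That `WeightedRandomSampler` calls `torch.multinomial` internally reflects the current PyTorch source and could change between versions; the seed is an assumption for reproducibility.

```python
import torch

# class labels of the dummy dataset: 921 samples of class 0, 103 of class 1
targets = torch.LongTensor([0] * 921 + [1] * 103)

# per-sample weight = 1 / (number of samples in that sample's class)
class_counts = torch.bincount(targets)
sample_weights = 1.0 / class_counts[targets].double()

# normalized probability of drawing each index in a single draw
probs = sample_weights / sample_weights.sum()

# total probability mass per class: each class gets 1/C = 0.5
class_probs = torch.bincount(targets, weights=probs)
print(class_probs.tolist())

# one epoch of draws with replacement, as WeightedRandomSampler does by default
generator = torch.Generator().manual_seed(0)  # assumption: seed for reproducibility
indices = torch.multinomial(sample_weights, num_samples=len(targets), replacement=True, generator=generator)
drawn = torch.bincount(targets[indices], minlength=2).tolist()
print(drawn)  # roughly [512, 512]; exact values depend on the seed
```

The per-class counts fluctuate around 512 from run to run, which is why the mini-batches are only *nearly* balanced.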

```python
# count occurrences of each class
class_counts = torch.bincount(targets)  # tensor([921, 103])

# weight of each class is inversely proportional to its frequency
class_weights = 1.0 / class_counts.double()

# assign each sample the weight of its class
sample_weights = class_weights[targets]

# create a WeightedRandomSampler that draws len(dataset) indices with replacement
sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)

# pass the sampler to the DataLoader (sampler and shuffle=True are mutually exclusive)
dataloader = DataLoader(dataset, batch_size=128, num_workers=1, sampler=sampler)

# iterate through the dataloader and plot the class distribution in each batch
zeros = []
ones = []
for x, y in dataloader:
    counts = np.bincount(y.numpy(), minlength=2)
    zeros.append(counts[0])
    ones.append(counts[1])

plot_function(zeros, ones)
```

After computing a weight for each sample, we initialize the `WeightedRandomSampler` with these weights and pass it to the `sampler` argument of the `DataLoader`. Iterating through the dataloader again, we plot the class distribution in each batch. Fig. 2 shows that after oversampling with `WeightedRandomSampler`, each mini-batch contains a nearly equal number of samples from both classes. This should help the model learn important features from the minority class and improve its generalization.
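One side effect is worth checking. With `replacement=True` and `num_samples` equal to the dataset size, an epoch still has 1,024 draws, about half of them from class 1. Each of the 103 minority samples is therefore drawn roughly 512 / 103 ≈ 5 times per epoch, while each majority sample is drawn only about 512 / 921 ≈ 0.56 times on average, so typically more than half of the majority samples do not appear in a given epoch. A minimal sketch that measures this; the seed is an assumption, and the `generator` argument requires a PyTorch version that supports it.

```python
import torch
from torch.utils.data import WeightedRandomSampler

# class labels of the dummy dataset: 921 samples of class 0, 103 of class 1
targets = torch.LongTensor([0] * 921 + [1] * 103)
class_counts = torch.bincount(targets)
sample_weights = 1.0 / class_counts[targets].double()

sampler = WeightedRandomSampler(
    sample_weights,
    num_samples=len(sample_weights),
    replacement=True,
    generator=torch.Generator().manual_seed(0),  # assumption: seed for reproducibility
)
indices = torch.tensor(list(sampler))  # one epoch of sampled dataset indices

stats = {}  # class -> (number of draws, number of distinct samples drawn)
for c in (0, 1):
    drawn = indices[targets[indices] == c]  # indices drawn from class c
    stats[c] = (drawn.numel(), drawn.unique().numel())
    print(f"class {c}: {stats[c][0]} draws, {stats[c][1]} distinct samples out of {class_counts[c].item()}")
```

Raising `num_samples` (for example to `2 * 921`) lengthens each epoch and lets each majority sample appear about once on average, but the draws remain random, so some samples may still be skipped in any given epoch.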