# Example Weighted Sampler

In [1]:
import torch
from torch.utils.data import WeightedRandomSampler, DataLoader, TensorDataset
import numpy as np

In [2]:
# Settings
numDataPoints = 1000
data_dim = 5
bs = 100

In [3]:
# Create dummy data with class imbalance 9:1 (0s and 1s)
data = torch.randn(numDataPoints, data_dim)
target = np.hstack((
    np.zeros(int(numDataPoints * 0.9), dtype=np.int32),
    np.ones(int(numDataPoints * 0.1), dtype=np.int32)
))

In [10]:
print(data)

tensor([[-0.7313, -0.2483, -0.0854,  1.1607, -1.4849],
        [ 0.5510,  0.8307,  1.2464,  0.7996, -0.3997],
        [-0.4807,  0.3722, -0.9829, -0.8983, -0.5159],
        ...,
        [-0.3974,  0.7881, -0.9042,  1.5625,  0.6863],
        [ 1.0850, -0.0175,  0.3553, -0.4402, -0.0762],
        [-1.2026,  0.9135,  0.4817, -0.5554, -0.9916]])


In [11]:
print(target)

[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 

In [4]:
print('Original class distribution (0/1): {}/{}'.format(
    len(np.where(target == 0)[0]),
    len(np.where(target == 1)[0])
))

Original class distribution (0/1): 900/100


In [5]:
# Calculate weights for each class (inverse frequency)
class_sample_count = np.array(
    [len(np.where(target == t)[0]) for t in np.unique(target)]
)
weight = 1. / class_sample_count
samples_weight = np.array([weight[t] for t in target])
samples_weight = torch.from_numpy(samples_weight).double()

In [12]:
print(weight)

[0.00111111 0.01      ]


In [13]:
print(samples_weight)

tensor([0.0011, 0.0011, 0.0011, 0.0011, 0.0011, 0.0011, 0.0011, 0.0011, 0.0011,
        0.0011, 0.0011, 0.0011, 0.0011, 0.0011, 0.0011, 0.0011, 0.0011, 0.0011,
        0.0011, 0.0011, 0.0011, 0.0011, 0.0011, 0.0011, 0.0011, 0.0011, 0.0011,
        0.0011, 0.0011, 0.0011, 0.0011, 0.0011, 0.0011, 0.0011, 0.0011, 0.0011,
        0.0011, 0.0011, 0.0011, 0.0011, 0.0011, 0.0011, 0.0011, 0.0011, 0.0011,
        0.0011, 0.0011, 0.0011, 0.0011, 0.0011, 0.0011, 0.0011, 0.0011, 0.0011,
        0.0011, 0.0011, 0.0011, 0.0011, 0.0011, 0.0011, 0.0011, 0.0011, 0.0011,
        0.0011, 0.0011, 0.0011, 0.0011, 0.0011, 0.0011, 0.0011, 0.0011, 0.0011,
        0.0011, 0.0011, 0.0011, 0.0011, 0.0011, 0.0011, 0.0011, 0.0011, 0.0011,
        0.0011, 0.0011, 0.0011, 0.0011, 0.0011, 0.0011, 0.0011, 0.0011, 0.0011,
        0.0011, 0.0011, 0.0011, 0.0011, 0.0011, 0.0011, 0.0011, 0.0011, 0.0011,
        0.0011, 0.0011, 0.0011, 0.0011, 0.0011, 0.0011, 0.0011, 0.0011, 0.0011,
        0.0011, 0.0011, 0.0011, 0.0011, 

In [6]:
# Create sampler
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))

In [7]:
# Create dataset
target_tensor = torch.from_numpy(target).long()
train_dataset = TensorDataset(data, target_tensor)

In [8]:
# Create DataLoader with the sampler
train_loader = DataLoader(
    train_dataset, batch_size=bs, num_workers=0, sampler=sampler
)

In [9]:
# Count class balance in each batch
for i, (data_batch, target_batch) in enumerate(train_loader):
    count_0 = (target_batch == 0).sum().item()
    count_1 = (target_batch == 1).sum().item()
    print(f"Batch {i}: class 0 = {count_0}, class 1 = {count_1}")

Batch 0: class 0 = 49, class 1 = 51
Batch 1: class 0 = 43, class 1 = 57
Batch 2: class 0 = 55, class 1 = 45
Batch 3: class 0 = 52, class 1 = 48
Batch 4: class 0 = 52, class 1 = 48
Batch 5: class 0 = 47, class 1 = 53
Batch 6: class 0 = 48, class 1 = 52
Batch 7: class 0 = 43, class 1 = 57
Batch 8: class 0 = 44, class 1 = 56
Batch 9: class 0 = 44, class 1 = 56
