In [12]:
import torch
from torch.utils.data import WeightedRandomSampler, DataLoader

# Per-class sample counts for the synthetic, deliberately imbalanced dataset
# (class 0 has far fewer samples than classes 1 and 2).
class_counts = torch.tensor([104, 642, 784])

# Total number of samples across all classes, as a 0-d integer tensor.
numDataPoints = torch.sum(class_counts, dim=0)
numDataPoints

tensor(1530)

In [13]:
# Seed the global RNG so torch.randn below (and any later sampling that uses
# the default generator) is reproducible under Restart Kernel -> Run All.
torch.manual_seed(0)

data_dim = 5  # number of random features per sample
bs = 170      # batch size for the DataLoader

# Random features, one row per sample; the size is passed as a plain int
# rather than a 0-d tensor.
data = torch.randn(int(numDataPoints), data_dim)
data

tensor([[-0.6557,  0.2413, -0.5738,  1.5683, -0.5675],
        [ 0.7457,  1.1364, -0.4761, -0.2889, -2.2014],
        [-0.1682, -0.8786, -0.3168, -2.1195, -0.7793],
        ...,
        [ 1.0239, -0.2812, -1.0666, -0.3779, -0.4048],
        [ 1.4318, -1.4207, -0.2470, -0.4998,  0.0420],
        [-1.1152,  3.3121,  2.2242,  1.1222, -0.1707]])

In [14]:
# Build integer labels: class k is repeated class_counts[k] times, in class
# order (all 0s, then all 1s, ...). Works for any number of classes instead of
# hard-coding three torch.cat pieces.
class_ids = torch.arange(len(class_counts))
target = torch.repeat_interleave(class_ids, class_counts)
assert target.numel() == int(class_counts.sum()), "one label per sample expected"

# Sanity check: the per-label counts should reproduce class_counts.
label_counts = torch.bincount(target, minlength=len(class_counts)).tolist()
print('target train {}: {}'.format(
    '/'.join(str(k) for k in class_ids.tolist()),
    '/'.join(str(c) for c in label_counts)))
target

target train 0/1/2: 104/642/784


tensor([0, 0, 0,  ..., 2, 2, 2])

In [15]:
# Count samples per class, indexed by label value:
# class_sample_count[k] == number of samples with target == k.
# torch.bincount (unlike iterating over torch.unique) keeps a slot for every
# label in 0..max(target), so indexing a per-class tensor by label later stays
# aligned even if some class id has no samples (its count is 0).
class_sample_count = torch.bincount(target)
class_sample_count

tensor([104, 642, 784])

In [16]:
# Inverse-frequency class weights: rarer classes get proportionally larger
# weight. NOTE: an entry of 0 in class_sample_count would yield an inf weight;
# that entry is only harmful if some sample actually carries that label.
weight = 1. / class_sample_count.float()

# Per-sample weight = weight of that sample's class. A vectorized gather
# (one lookup per label) instead of a Python loop over every sample.
samples_weight = weight[target]
samples_weight

tensor([0.0096, 0.0096, 0.0096,  ..., 0.0013, 0.0013, 0.0013])

In [17]:
# Create sampler, dataset, loader.
# Draw len(samples_weight) indices *with replacement* according to the
# per-sample weights, so minority-class samples are oversampled and batches
# come out roughly class-balanced. A dedicated, seeded generator makes the
# drawn indices reproducible regardless of other global RNG use.
sampler_generator = torch.Generator().manual_seed(0)
sampler = WeightedRandomSampler(
    samples_weight,
    num_samples=len(samples_weight),
    replacement=True,
    generator=sampler_generator,
)
train_dataset = torch.utils.data.TensorDataset(data, target)
train_loader = DataLoader(
    train_dataset, batch_size=bs, num_workers=0, sampler=sampler)

# Iterate DataLoader and check class balance for each batch.
for i, (x, y) in enumerate(train_loader):
    print("batch index {}, 0/1/2: {}/{}/{}".format(
        i, (y == 0).sum(), (y == 1).sum(), (y == 2).sum()))

batch index 0, 0/1/2: 60/59/51
batch index 1, 0/1/2: 41/66/63
batch index 2, 0/1/2: 49/59/62
batch index 3, 0/1/2: 55/58/57
batch index 4, 0/1/2: 47/67/56
batch index 5, 0/1/2: 54/55/61
batch index 6, 0/1/2: 45/64/61
batch index 7, 0/1/2: 48/54/68
batch index 8, 0/1/2: 58/53/59
