In [11]:
from itertools import islice

import torch
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

In [43]:
# Seed the global RNG so the random data below -- and the WeightedRandomSampler draws in the
# following cells, which use the same global generator -- are reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

ds = torch.randn(100, 1)  # our dataset will consist of 100 random numbers (one feature each)
lbls = torch.tensor([0]*70 + [1]*26 + [2]*4)  # 3 very unbalanced labels: 70 / 26 / 4 samples

dataset = torch.utils.data.TensorDataset(ds, lbls)

In [82]:
# Undersampling: give every sample a weight of 1 / (size of its class), so each class is drawn
# with roughly equal probability, then draw only `smallest class count * number of classes`
# samples per epoch.
class_counts = torch.bincount(lbls)         # samples per label; [70, 26, 4] for the labels above
weights = 1.0 / class_counts[lbls].float()  # one weight per sample, looked up by its label

# WeightedRandomSampler draws with replacement by default, so minority samples may repeat.
batch_size = 4
element_count = int(class_counts.min()) * len(class_counts)  # 4 (smallest class) * 3 (classes) = 12
dataloader = DataLoader(dataset, batch_size=batch_size, sampler=WeightedRandomSampler(weights, element_count))

In [100]:
# By undersampling we draw 12 samples in batches of size 4, i.e. 3 batches. The classes appear in
# roughly equal proportion, but since sampling is with replacement the same sample can show up twice.
print(len(dataloader))
for samples, labels in dataloader:
    print(samples.tolist(), labels.tolist())

3
[[-2.122056245803833], [-0.8459286093711853], [-0.46079280972480774], [1.240039348602295]] [1, 1, 1, 2]
[[-0.40949711203575134], [-1.2838244438171387], [1.240039348602295], [1.1774694919586182]] [0, 2, 2, 0]
[[1.146072268486023], [0.7432038187980652], [0.555371105670929], [0.7432038187980652]] [1, 2, 0, 2]


In [102]:
# Oversampling: reuse the same per-sample inverse-frequency `weights`, but draw
# `largest class count * number of classes` samples per epoch, so minority-class samples
# are repeated many times to keep the classes balanced.
class_counts = torch.bincount(lbls)  # samples per label; [70, 26, 4] for the labels above
batch_size = 4
element_count = int(class_counts.max()) * len(class_counts)  # 70 (largest class) * 3 (classes) = 210
dataloader = DataLoader(dataset, batch_size=batch_size, sampler=WeightedRandomSampler(weights, element_count))

In [103]:
# By oversampling we draw 210 samples in batches of size 4, i.e. ceil(210 / 4) = 53 batches
# (the last one holds only 2 samples); minority-class samples now repeat many times.
print(len(dataloader))
# Iterate a single epoch and stop after 10 batches. (Calling `next(iter(dataloader))` in the loop
# would start a new epoch -- and a new sampler draw -- every time, printing only first batches.)
for samples, labels in islice(dataloader, 10):  # print only the first 10 of the 53 batches
    print(samples.tolist(), labels.tolist())

53
[[0.6577771902084351], [1.240039348602295], [0.7432038187980652], [-0.3209373354911804]] [0, 2, 2, 1]
[[0.7432038187980652], [1.0059230327606201], [-0.40949711203575134], [-0.8459286093711853]] [2, 1, 0, 1]
[[-0.5922260284423828], [-0.9726153016090393], [-2.0352909564971924], [0.41047900915145874]] [1, 1, 1, 2]
[[1.0059230327606201], [-0.46079280972480774], [-0.6088237166404724], [1.1872705221176147]] [1, 1, 0, 0]
[[-0.8459286093711853], [1.8363431692123413], [-1.2838244438171387], [1.8363431692123413]] [1, 0, 2, 0]
[[1.240039348602295], [0.8751641511917114], [1.240039348602295], [-1.2838244438171387]] [2, 1, 2, 2]
[[-0.5922260284423828], [0.41047900915145874], [0.7432038187980652], [0.8751641511917114]] [1, 2, 2, 1]
[[0.7432038187980652], [1.3181135654449463], [-0.31453022360801697], [-0.10773799568414688]] [2, 0, 0, 0]
[[1.240039348602295], [1.240039348602295], [0.8386877179145813], [0.41047900915145874]] [2, 2, 1, 2]
[[-1.2838244438171387], [1.1774694919586182], [0.83868771791458