In [1]:
import torch
import numpy as np

from imbalanced_sampler_3 import MultilabelBalancedRandomSampler
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable


In [2]:
class RandomDataset(Dataset):
    def __init__(self, n_examples, n_features, n_classes, mean_labels_per_example):
        self.n_examples = n_examples
        self.n_features = n_features
        self.n_classes = n_classes
        self.X = np.random.random([self.n_examples, self.n_features])

        class_probabilities = np.random.random([self.n_classes])
        class_probabilities = class_probabilities / sum(class_probabilities)
        class_probabilities *= mean_labels_per_example
        self.y = (
            np.random.random([self.n_examples, self.n_classes]) < class_probabilities
        ).astype(int)

    def __len__(self):
        return self.n_examples

    def __getitem__(self, index):
        example = Variable(torch.tensor(self.X[index]), requires_grad=False)
        labels = Variable(torch.tensor(self.y[index]), requires_grad=False)
        return {"example": example, "labels": labels}


In [3]:
dataset = RandomDataset(20000, 100, 20, 2)


batch_size=512
val_size=0.2

# Split into training and validation
indices = list(range(len(dataset)))
np.random.shuffle(indices)
split = int(np.floor(val_size * len(dataset)))
train_idx, validate_idx = indices[split:], indices[:split]

In [4]:
indices

[1441,
 18710,
 19466,
 449,
 19892,
 8283,
 12847,
 12155,
 18307,
 12512,
 11040,
 14212,
 12639,
 5086,
 5842,
 4413,
 11206,
 11331,
 13243,
 6196,
 8297,
 5717,
 16388,
 12592,
 2282,
 12412,
 14005,
 9012,
 931,
 8938,
 9343,
 16256,
 12532,
 10676,
 6204,
 11627,
 10114,
 2632,
 8167,
 19325,
 17563,
 182,
 645,
 17783,
 7760,
 17781,
 474,
 8043,
 18332,
 7351,
 1574,
 2599,
 11776,
 19215,
 1826,
 3645,
 5112,
 520,
 14437,
 11881,
 16938,
 6813,
 13354,
 19618,
 7917,
 9292,
 463,
 13551,
 901,
 4496,
 18039,
 10095,
 16993,
 859,
 19536,
 13518,
 5355,
 15227,
 6468,
 1839,
 1717,
 17486,
 11409,
 6183,
 1111,
 12578,
 6737,
 11761,
 3109,
 8823,
 15459,
 915,
 3665,
 8798,
 781,
 7627,
 3089,
 10856,
 18674,
 2316,
 9724,
 4645,
 6052,
 19704,
 8602,
 9695,
 4406,
 3854,
 892,
 10531,
 834,
 8478,
 5351,
 15183,
 12355,
 9601,
 7398,
 19791,
 3633,
 12726,
 18889,
 12599,
 18756,
 9753,
 18877,
 1523,
 12823,
 4849,
 8850,
 8810,
 7691,
 14528,
 2959,
 19603,
 10955,
 16238

In [5]:
dataset[0]

{'example': tensor([0.1777, 0.1627, 0.5648, 0.7032, 0.3972, 0.7236, 0.9172, 0.6639, 0.1889,
         0.9120, 0.2158, 0.8815, 0.1764, 0.1974, 0.4394, 0.1648, 0.5718, 0.8816,
         0.6830, 0.4217, 0.2708, 0.6863, 0.8665, 0.6342, 0.2555, 0.1367, 0.9010,
         0.3420, 0.4653, 0.0789, 0.2746, 0.3843, 0.9441, 0.8581, 0.4386, 0.7385,
         0.7001, 0.1534, 0.7710, 0.4055, 0.9824, 0.2047, 0.4890, 0.4962, 0.2533,
         0.1967, 0.4250, 0.3629, 0.9061, 0.4114, 0.0651, 0.7133, 0.5172, 0.4118,
         0.9512, 0.3664, 0.1353, 0.4587, 0.6645, 0.8712, 0.8250, 0.7423, 0.5225,
         0.8813, 0.0606, 0.5417, 0.5538, 0.2889, 0.9137, 0.4619, 0.7860, 0.3375,
         0.1629, 0.8661, 0.7950, 0.9295, 0.7734, 0.0956, 0.5210, 0.8411, 0.0701,
         0.7121, 0.7066, 0.7022, 0.1633, 0.3544, 0.1583, 0.7403, 0.4577, 0.7071,
         0.1144, 0.4646, 0.8724, 0.3635, 0.3081, 0.2149, 0.4857, 0.9745, 0.9619,
         0.3888], dtype=torch.float64),
 'labels': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

In [6]:
train_sampler = MultilabelBalancedRandomSampler(
    dataset.y, train_idx, class_choice="least_sampled"
)

In [11]:
len(dataset.y)

20000

In [13]:
dataset.y

array([[0, 0, 0, ..., 0, 0, 0],
       [1, 0, 0, ..., 0, 0, 0],
       [1, 0, 0, ..., 0, 0, 0],
       ...,
       [1, 0, 0, ..., 0, 0, 0],
       [0, 1, 0, ..., 0, 0, 0],
       [0, 1, 0, ..., 1, 0, 0]])

In [12]:
len(train_idx)

16000

In [14]:
train_idx

[9136,
 12967,
 12919,
 18491,
 16901,
 17784,
 223,
 15566,
 7542,
 1357,
 3930,
 3904,
 18289,
 8765,
 10369,
 15612,
 18777,
 5496,
 16546,
 9345,
 5760,
 12419,
 2829,
 11301,
 17410,
 9873,
 12183,
 6206,
 13952,
 8571,
 15787,
 11257,
 7763,
 5301,
 148,
 4499,
 4991,
 17292,
 7030,
 16106,
 10612,
 15536,
 14589,
 1972,
 8448,
 5172,
 2078,
 9211,
 9173,
 7431,
 19779,
 8403,
 14226,
 10634,
 6294,
 8580,
 3552,
 12332,
 17174,
 17890,
 17338,
 7723,
 16850,
 5953,
 11136,
 10841,
 12563,
 11928,
 6227,
 17288,
 14509,
 19529,
 10448,
 252,
 12054,
 5371,
 14606,
 11065,
 13621,
 18165,
 1040,
 15272,
 17748,
 2247,
 15846,
 19633,
 16811,
 1197,
 8489,
 3150,
 1277,
 17614,
 5833,
 5168,
 9358,
 8652,
 336,
 7607,
 11555,
 12623,
 15720,
 8949,
 12809,
 7796,
 6361,
 1108,
 3249,
 7458,
 12525,
 10495,
 4462,
 12646,
 1800,
 5108,
 3353,
 9221,
 2644,
 6874,
 9214,
 6419,
 5551,
 3410,
 2120,
 5193,
 8520,
 4520,
 7337,
 16651,
 8032,
 4741,
 6824,
 7498,
 6433,
 14229,
 6706,


In [7]:
validate_sampler = SubsetRandomSampler(validate_idx)

# Create data loaders
train_loader = DataLoader(dataset, batch_size=batch_size, sampler=train_sampler,)
validate_loader = DataLoader(
    dataset, batch_size=batch_size, sampler=validate_sampler,
)

In [8]:
epochs = 2
for epoch in range(epochs):
        print("================ Training phase ===============")
        for batch in train_loader:
            examples = batch["example"]
            labels = batch["labels"]
            print("Label counts per class:")
            sum_ = labels.sum(axis=0)
            print(sum_)
            print("Difference between min and max")
            print(max(sum_) - min(sum_))
            print("")
        print("")

        print("=============== Validation phase ==============")
        for batch in validate_loader:
            examples = batch["example"]
            labels = batch["labels"]
            print("Label counts per class:")
            sum_ = labels.sum(axis=0)
            print(sum_)
            print("Difference between min and max")
            print(max(sum_) - min(sum_))
            print("")
        print("")


Label counts per class:
tensor([105, 126,  66,  67,  64,  64,  64,  79,  79,  74, 154,  64,  64,  66,
         64,  66,  64,  85,  64,  64])
Difference between min and max
tensor(90)

Label counts per class:
tensor([112, 126,  60,  60,  62,  62,  62,  80,  55,  72, 142,  62,  62,  62,
         63,  66,  62,  79,  62,  62])
Difference between min and max
tensor(87)

Label counts per class:
tensor([113, 145,  64,  64,  65,  64,  64,  69,  72,  76, 156,  63,  63,  61,
         62,  59,  64,  78,  63,  63])
Difference between min and max
tensor(97)

Label counts per class:
tensor([103, 146,  63,  65,  62,  63,  64,  77,  72,  72, 137,  64,  64,  66,
         64,  62,  63,  80,  64,  64])
Difference between min and max
tensor(84)

Label counts per class:
tensor([100, 150,  61,  62,  62,  61,  60,  82,  53,  71, 150,  61,  61,  62,
         61,  61,  61,  82,  60,  60])
Difference between min and max
tensor(97)

Label counts per class:
tensor([125, 126,  64,  68,  64,  64,  66,  63,  70,  79