class EnlargeLabelShufflingSampler(Sampler):
    """Samples dataset positions so that every class is drawn equally often.

    Label-shuffling technique for imbalanced classification: each class is
    enlarged to the size of the most frequent class by cycling through its
    members in a random order, so every original sample appears at least once
    and all classes contribute the same number of positions per epoch.

    Arguments:
        indices (sequence): class label of every sample in the dataset, i.e.
            ``indices[i]`` is the label of the sample at position ``i``.
            Labels only need to be sortable and comparable for equality.
    """

    def __init__(self, indices):
        # Pair each dataset position with its label and order the pairs by
        # label so that every class forms one contiguous run.
        sorted_labels = sorted(enumerate(indices), key=lambda pair: pair[1])

        # Collect the dataset positions of each class. Grouping on "is there
        # an open group with the same label" avoids a magic sentinel value
        # that could collide with a genuine label such as -1.
        positions_by_class = []
        current_label = None
        for position, label in sorted_labels:
            if positions_by_class and label == current_label:
                positions_by_class[-1].append(position)
            else:
                positions_by_class.append([position])
                current_label = label

        # Kept as a public attribute: number of samples of each class, in
        # ascending label order.
        self.count_of_each_label = [len(group) for group in positions_by_class]

        # Every class is enlarged to the size of the biggest one; an empty
        # dataset simply produces an empty sampler.
        largest = max(self.count_of_each_label) if self.count_of_each_label else 0

        self.enlarged_index = []
        for group in positions_by_class:
            # A random permutation of range(largest) taken modulo the class
            # size hits every member of the class (largest >= class size) and
            # repeats members as evenly as possible. The remainder is taken on
            # the tensor itself; torch.remainder does not accept NumPy arrays.
            picks = torch.remainder(torch.randperm(largest), len(group)).tolist()
            self.enlarged_index.extend(group[pick] for pick in picks)

    def __iter__(self):
        # Reshuffle in place on every pass so each epoch sees a new order.
        random.shuffle(self.enlarged_index)
        return iter(self.enlarged_index)

    def __len__(self):
        # Equals (size of the largest class) * (number of classes).
        return len(self.enlarged_index)