Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

new enlarge label shuffling sampler #4153

Closed
wants to merge 4 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
47 changes: 47 additions & 0 deletions torch/utils/data/sampler.py
@@ -1,4 +1,5 @@
import torch
import random


class Sampler(object):
Expand Down Expand Up @@ -93,6 +94,52 @@ def __len__(self):
return self.num_samples


class EnlargeLabelShufflingSampler(Sampler):
    """Samples dataset indices so that every class is drawn equally often.

    Label-shuffling technique for imbalanced classification problems. Every
    class is enlarged to the size of the most frequent class: for a class
    with ``c`` samples, a random permutation of ``range(largest)`` is taken
    modulo ``c`` and used to pick samples of that class, so members of
    smaller classes are repeated until the class reaches ``largest``
    entries. The enlarged index list is built once at construction time and
    served in a freshly shuffled order on every iteration.

    Arguments:
        indices (sequence): the class label of every sample in the dataset;
            ``indices[i]`` is the label of sample ``i``. Labels must be
            mutually comparable so that they can be sorted.
    """

    def __init__(self, indices):
        # Pair every dataset position with its label, then bring equal labels
        # next to each other. ``sorted`` is stable, so positions inside one
        # class keep their original relative order.
        sorted_labels = sorted(enumerate(indices), key=lambda pair: pair[1])

        # Run-length count of each label in the sorted sequence. The first
        # element is detected by position rather than by a sentinel value, so
        # any label (including -1) is counted correctly.
        count_of_each_label = []
        previous = None
        for position, (_, label) in enumerate(sorted_labels):
            if position > 0 and label == previous:
                count_of_each_label[-1] += 1
            else:
                count_of_each_label.append(1)
            previous = label

        # Size every class is enlarged to; 0 when there are no samples at all,
        # which yields an empty sampler.
        largest = max(count_of_each_label) if count_of_each_label else 0
        self.count_of_each_label = count_of_each_label
        self.enlarged_index = []

        # ``offset`` is where the current class starts inside ``sorted_labels``.
        offset = 0
        for class_size in count_of_each_label:
            # ``largest`` picks inside the class: values of a random
            # permutation folded into [0, class_size) by the modulo.
            picks = (torch.randperm(largest) % class_size).tolist()
            for pick in picks:
                self.enlarged_index.append(sorted_labels[offset + pick][0])
            offset += class_size

    def __iter__(self):
        # Shuffle a copy so that an iterator handed out earlier is not
        # reordered underneath its consumer by a later ``iter()`` call.
        shuffled = list(self.enlarged_index)
        random.shuffle(shuffled)
        return iter(shuffled)

    def __len__(self):
        # Equals (size of the largest class) * (number of classes).
        return len(self.enlarged_index)


class BatchSampler(object):
"""Wraps another sampler to yield a mini-batch of indices.

Expand Down