Skip to content

Commit

Permalink
reduce the time when calling sampler (#29)
Browse files Browse the repository at this point in the history
  • Loading branch information
hwany-j committed May 31, 2021
1 parent 6685679 commit ad50e22
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 14 deletions.
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
torch>=1.3
torchvision>=0.5
torchvision>=0.5
pandas
29 changes: 16 additions & 13 deletions torchsampler/imbalanced.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Callable

import pandas as pd
import torch
import torch.utils.data
import torchvision
Expand All @@ -25,28 +26,30 @@ def __init__(self, dataset, indices: list = None, num_samples: int = None, callb
self.num_samples = len(self.indices) if num_samples is None else num_samples

# distribution of classes in the dataset
label_to_count = {}
for idx in self.indices:
label = self._get_label(dataset, idx)
label_to_count[label] = label_to_count.get(label, 0) + 1
df = pd.DataFrame()
df["label"] = self._get_labels(dataset)
df.index = self.indices
df = df.sort_index()

label_to_count = df["label"].value_counts()

weights = 1.0 / label_to_count(df["label"])

# weight for each sample
weights = [1.0 / label_to_count[self._get_label(dataset, idx)] for idx in self.indices]
self.weights = torch.DoubleTensor(weights)

def _get_label(self, dataset, idx):
def _get_labels(self, dataset):
if self.callback_get_label:
return self.callback_get_label(dataset, idx)
return self.callback_get_label(dataset)
elif isinstance(dataset, torchvision.datasets.MNIST):
return dataset.train_labels[idx].item()
return dataset.train_labels.tolist()
elif isinstance(dataset, torchvision.datasets.ImageFolder):
return dataset.imgs[idx][1]
return dataset.imgs[:][1]
elif isinstance(dataset, torchvision.datasets.DatasetFolder):
return dataset.samples[idx][1]
return dataset.samples[:][1]
elif isinstance(dataset, torch.utils.data.Subset):
return dataset.dataset.imgs[idx][1]
return dataset.dataset.imgs[:][1]
elif isinstance(dataset, torch.utils.data.Dataset):
return dataset.get_label(idx)
return dataset.get_labels()
else:
raise NotImplementedError

Expand Down

0 comments on commit ad50e22

Please sign in to comment.