NotImplemented Error while running ImbalancedDatasetSampler #18

Closed
aryamansriram opened this issue Jun 18, 2020 · 8 comments

Comments

@aryamansriram

aryamansriram commented Jun 18, 2020

I followed the steps in the README exactly, yet I get a `NotImplementedError`, and the error comes with no explanation.

Here's my code:
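```python
import numpy as np
import torch
from torchvision import transforms
from torchsampler import ImbalancedDatasetSampler

batch_size = 128
val_split = 0.2
shuffle_dataset = True
random_seed = 42

# melanoma_dataset is a custom Dataset (definition not shown in this report)
dataset_size = len(melanoma_dataset)
indices = list(range(dataset_size))
split = int(np.floor(val_split * dataset_size))
if shuffle_dataset:
    np.random.seed(random_seed)
    np.random.shuffle(indices)
train_indices, test_indices = indices[split:], indices[:split]

# Note: train_indices is never used, so the sampler draws from the whole
# dataset, including the samples held out in test_indices.
train_loader = torch.utils.data.DataLoader(
    melanoma_dataset,
    batch_size=batch_size,
    sampler=ImbalancedDatasetSampler(melanoma_dataset),
)
# test_sampler is defined elsewhere (not shown in this report)
test_loader = torch.utils.data.DataLoader(melanoma_dataset, batch_size=batch_size, sampler=test_sampler)
```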

@polm

polm commented Jul 2, 2020

Having the same issue: no detailed error message, just a `NotImplementedError`.

It looks like, out of the box, the code only supports MNIST, `Subset`, and `ImageFolder` datasets. For a custom dataset you need to implement `callback_get_label`.
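The behavior described above can be pictured with a simplified, hypothetical sketch of a label-lookup dispatch: a user-supplied callback takes priority, known dataset types are read directly, and anything else falls through to a bare `NotImplementedError`. The classes `ImageFolderLike` and `CustomDataset` and the helper `get_labels` are stand-ins invented for this illustration, not torchsampler's actual code.

```python
from typing import Callable, List, Optional, Sequence, Tuple


class ImageFolderLike:
    """Hypothetical stand-in for a dataset type the sampler knows about."""

    def __init__(self, imgs: List[Tuple[str, int]]):
        self.imgs = imgs  # (path, class_index) pairs, similar to torchvision's ImageFolder


class CustomDataset:
    """Hypothetical user dataset with no special handling in the sampler."""

    def __init__(self, samples: Sequence[Tuple[object, int]]):
        self.samples = list(samples)

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> Tuple[object, int]:
        return self.samples[idx]


def get_labels(dataset, callback: Optional[Callable] = None) -> List[int]:
    # A user-supplied callback(dataset, idx) always wins.
    if callback is not None:
        return [callback(dataset, i) for i in range(len(dataset))]
    # Known dataset type: read labels straight from its attribute.
    if isinstance(dataset, ImageFolderLike):
        return [label for _, label in dataset.imgs]
    # Unknown type: nothing to read labels from, hence the bare error.
    raise NotImplementedError


custom = CustomDataset([("a", 0), ("b", 1), ("c", 1)])
try:
    get_labels(custom)
except NotImplementedError:
    print("NotImplementedError: no label lookup for CustomDataset")
print(get_labels(custom, callback=lambda ds, i: ds[i][1]))  # [0, 1, 1]
```

In this sketch, the error carries no message because the fallback branch simply has nothing more specific to report, which matches what both posters observed.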

@SorenJ89

SorenJ89 commented Oct 28, 2020

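In case it helps anyone, this worked for me (it assumes `target` is a one-hot tensor):

```python
import numpy as np

def callback_get_label(dataset, idx):
    # Callback used by ImbalancedDatasetSampler: return the class index of sample idx.
    _, target = dataset[idx]
    # target is one-hot, so the class index is the position of its single nonzero entry.
    return np.argwhere(target.numpy()).item()
```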

Edit: I suspect it would be faster not to convert the tensor to NumPy, so this version should do the same using tensor operations only:

```python
def callback_get_label(dataset, idx):
    # Callback used by ImbalancedDatasetSampler: return the class index of sample idx.
    _, target = dataset[idx]
    # nonzero() on a one-hot tensor yields a single index; .item() turns it into a Python int.
    return target.nonzero().item()
```
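Both callbacks above rely on the target being strictly one-hot, so that exactly one entry is nonzero. A quick NumPy check of that step, using a made-up label vector, shows what `argwhere(...).item()` returns and what happens when the assumption does not hold:

```python
import numpy as np

# One-hot target for class 2 out of 4 classes (illustrative values).
target = np.array([0.0, 0.0, 1.0, 0.0])

# np.argwhere returns the indices of nonzero entries, here with shape (1, 1);
# .item() converts that single element to a plain Python int.
class_index = np.argwhere(target).item()
print(class_index)  # 2

# A multi-label (or all-zero) target does not have exactly one nonzero entry,
# so .item() raises ValueError instead of returning a class index.
multi_label = np.array([1.0, 0.0, 1.0, 0.0])
try:
    np.argwhere(multi_label).item()
except ValueError as err:
    print("not one-hot:", err)
```

For multi-label targets, neither callback applies as written; each sample would first need to be mapped to a single class in some other way.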

@PaleNeutron

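If your labels are integers, try this:

```python
def callback_get_label(dataset, idx):
    # Callback used by ImbalancedDatasetSampler: return the class index of sample idx.
    _, target = dataset[idx]
    # int() works for a plain int as well as a single-element tensor.
    return int(target)
```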
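Whatever form the callback takes, its only job is to give the sampler one integer class label per sample. To see why that is enough, here is a minimal pure-Python sketch of inverse-frequency oversampling, assuming the sampler weights each sample by 1 / (number of samples with the same label) and draws indices with replacement. It illustrates the idea rather than reproducing torchsampler's source; `imbalanced_indices` is a hypothetical helper.

```python
import random
from collections import Counter
from typing import List


def imbalanced_indices(labels: List[int], num_samples: int, seed: int = 0) -> List[int]:
    """Draw dataset indices so that each class is picked about equally often."""
    counts = Counter(labels)
    # Rare classes get large per-sample weights, common classes small ones,
    # so every class ends up with the same total weight.
    weights = [1.0 / counts[label] for label in labels]
    rng = random.Random(seed)
    # Sampling with replacement lets minority samples repeat within an epoch.
    return rng.choices(range(len(labels)), weights=weights, k=num_samples)


labels = [0] * 90 + [1] * 10  # 90:10 class imbalance
drawn = imbalanced_indices(labels, num_samples=10_000)
share_of_minority = sum(labels[i] == 1 for i in drawn) / len(drawn)
print(f"minority share: {share_of_minority:.2f}")  # close to 0.5 rather than 0.1
```

Because both classes carry a total weight of 1 here, each batch drawn this way is roughly balanced even though the underlying data is 90:10.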

@robinzhaorr

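For those who are new to Python like me: define the `callback_get_label` function before you create `train_loader`, then pass the sampler to `DataLoader` by keyword. The second positional parameter of `DataLoader` is `batch_size`, so the sampler has to go in as `sampler=`:

```python
train_loader = DataLoader(
    dataset,
    sampler=ImbalancedDatasetSampler(dataset, callback_get_label=callback_get_label),
    batch_size=batch_size,
)
```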

@Borda
Contributor

Borda commented Apr 9, 2021

@ufoym this is solved, can be closed 🐰

@ufoym
Owner

ufoym commented Apr 10, 2021

@Borda Thanks very much for your contribution!

ufoym closed this as completed Apr 10, 2021
@Andybrizt

Following the suggestion above, I got this error:

`TypeError: callback_get_label() missing 1 required positional argument: 'idx'`

Could you tell me where to define the `callback_get_label()` function?
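One possible explanation for this `TypeError`, offered as an assumption rather than a confirmed diagnosis: after the fix mentioned earlier in the thread, the installed torchsampler version may call `callback_get_label(dataset)` with the dataset only and expect the complete list of labels back, instead of calling it once per index. If that is the case, the two-argument callbacks above no longer fit, and a one-argument version might look roughly like this (the toy dataset is a stand-in for a real one):

```python
from typing import List, Sequence, Tuple


def callback_get_label(dataset: Sequence[Tuple[object, int]]) -> List[int]:
    """One-argument callback: return the label of every sample, in index order.

    Assumes dataset[i] yields an (input, label) pair with an int-like label,
    as in the earlier examples in this thread.
    """
    return [int(dataset[i][1]) for i in range(len(dataset))]


# Toy stand-in for a real dataset: (input, label) pairs.
toy_dataset = [("img0", 0), ("img1", 1), ("img2", 0), ("img3", 2)]
print(callback_get_label(toy_dataset))  # [0, 1, 0, 2]
```

Defined at module level before the loader is built, it would be passed the same way as before, `ImbalancedDatasetSampler(dataset, callback_get_label=callback_get_label)`. Note that indexing `dataset[i]` loads every sample (for example, every image); if the dataset already stores its labels in an attribute, returning that attribute directly is much cheaper.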

@robinzhaorr

robinzhaorr commented May 20, 2022 via email
