Skip to content

Conversation

ptrblck
Copy link
Contributor

@ptrblck ptrblck commented May 6, 2018

This tutorial deals with the problem of an imbalanced dataset and how to train a classifier on it.

After training a CNN on the original CIFAR10 dataset, we resample it to create an artificially imbalanced dataset. Since the CNN performs quite poorly on this new dataset, we use the WeightedRandomSampler in the first step and a weighted criterion afterwards to tackle the problem.

I've created the tutorial in the intermediate section, but I'm not sure if it's the right place.

Feedback regarding the text and code is very welcome!

@chsasank chsasank self-assigned this May 8, 2018
###############################################################################
# Let's have a look at the class distribution in the datasets.

# Get all training targets and count the number of class instances

This comment was marked as off-topic.

This comment was marked as off-topic.

# The last 5 classes will keep their samples.

# Create class proportions
imbal_class_prop = imbal_class_prop = np.hstack(([0.1] * 5, [1.0] * 5))

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

Copy link
Contributor

@chsasank chsasank left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have roughly gone through half of the tutorial.

I also have some minor comments with formatting etc. But I'll create a PR to your branch later with those changes rather than commenting.

# The last 5 classes will keep their samples.

# Create class proportions
imbal_class_prop = imbal_class_prop = np.hstack(([0.1] * 5, [1.0] * 5))

This comment was marked as off-topic.

@ptrblck
Copy link
Contributor Author

ptrblck commented May 14, 2018

@chsasank Thank you for the review! I've added your suggestions. Let me know, what you think about the changes.

# Let's have a look at the class distribution in the datasets.


def get_labels_and_class_counts(labels_list):

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

f, (ax1, ax2) = plt.subplots(1, 2, sharey=True, figsize=(15, 6))
ax1.bar(class_names, train_class_counts)
ax1.set_title('Training dataset distribution')
ax1.set_xlabel('Classes')

This comment was marked as off-topic.

optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)


def train(epoch):

This comment was marked as off-topic.

@chsasank
Copy link
Contributor

Thanks, I think this looks good. Let me add some small formatting changes to your branch.

# Let's have a look at the class distribution in the datasets.

# Get all training targets and count the number of class instances
train_targets = np.array(train_dataset.train_labels)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Getting the following error on this line:
AttributeError: 'CIFAR10' object has no attribute 'train_labels'
Could you please help with this

@facebook-github-bot
Copy link
Contributor

Hi @ptrblck!

Thank you for your pull request. We require contributors to sign our Contributor License Agreement, and yours needs attention.

You currently have a record in our system, but we do not have a signature on file.

In order for us to review and merge your code, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

If you have received this in error or have any questions, please contact us at cla@fb.com. Thanks!

Base automatically changed from master to main February 16, 2021 19:32
Base automatically changed from main to master February 16, 2021 19:37
@facebook-github-bot
Copy link
Contributor

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Facebook open source project. Thanks!

Copy link

@kyscg kyscg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated variables because in recent versions, the variable for the labels for the training samples in torchvision.datasets.CIFAR10 has been changed to targets


# Get all training targets and count the number of class instances
train_targets, train_class_counts = get_labels_and_class_counts(
train_dataset.train_labels)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
train_dataset.train_labels)
train_dataset.targets)

'''
if self.train:
targets, class_counts = get_labels_and_class_counts(
self.dataset.train_labels)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
self.dataset.train_labels)
self.dataset.targets)

@svekars
Copy link
Contributor

svekars commented Mar 24, 2023

Closing this as it's been quite some time since it was created and no longer relevant.

@svekars svekars closed this Mar 24, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants