This repository has been archived by the owner on Aug 31, 2021. It is now read-only.

Class weight support #57

Closed
vinhqdang opened this issue Dec 28, 2015 · 13 comments

@vinhqdang

Hi,

I am using skflow.ops.dnn to classify a two-class dataset (True and False). The percentage of True examples is very small, so I have an imbalanced dataset.

It seems to me that one way to resolve the issue is to use class weights. However, looking at the implementation of skflow.ops.dnn, I do not see how I could use weighted classes with a DNN.

Is it possible to do that with skflow, or is there another technique in skflow to deal with the imbalanced-dataset problem?

Thanks

@ilblackdragon
Contributor

Usually, there are two ways to handle imbalanced data:

  1. Oversample the under-represented class.
     e.g. you can just copy each positive record roughly false_rate/positive_rate times.
  2. Class weights.
     This really just needs to be implemented in the loss function - if you are interested, you need to change skflow.ops.losses_ops.softmax_classifier to take a loss_weight tensor as an argument and do something like this at line 37:
     xent = tf.mul(xent, loss_weight)

and then pass it through models.logistic_regression (see the sketch below).
If you want to try it on your case and send a PR, that would be greatly appreciated :) Otherwise, I'll try adding this later this week.
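
A minimal sketch of what that change could look like (illustrative function and argument names, old TF 0.x-era API with tf.mul and positional softmax_cross_entropy_with_logits; this is not the actual skflow source):

```python
import tensorflow as tf

# Sketch only: softmax_cross_entropy_with_logits returns one loss value per
# example, so loss_weight needs shape [batch_size] (or be a scalar) for the
# element-wise tf.mul below to line up.
def softmax_classifier_with_weights(logits, one_hot_labels, loss_weight=None):
    xent = tf.nn.softmax_cross_entropy_with_logits(logits, one_hot_labels)
    if loss_weight is not None:
        xent = tf.mul(xent, loss_weight)  # the change suggested above
    return tf.reduce_mean(xent, name="loss")
```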

ilblackdragon changed the title from "Weighted classes with DNN?" to "Class weight support" on Dec 29, 2015
ilblackdragon added this to the 0.1 milestone on Dec 29, 2015
@lopuhin
Contributor

lopuhin commented Feb 2, 2016

@vinhqdang
Author

Thanks @lopuhin, so what is the correct way to use it (in the case of an unbalanced dataset with 90% class A and 10% class B)?

@lopuhin
Contributor

lopuhin commented Feb 2, 2016

@vinhqdang sorry, I was wrong - I don't think the existing implementation is correct, because softmax_cross_entropy_with_logits already returns a loss for each example in the mini-batch. For now you can try to over-sample class B in the training data.
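
For a 90/10 split like the one above, a simple way to over-sample is to repeat the class-B rows before calling fit; a rough numpy sketch (function and argument names are made up for illustration, not part of skflow):

```python
import numpy as np

def oversample_minority(X, y, minority_label=1, times=8):
    """Repeat each minority-class row `times` extra times and reshuffle."""
    minority_idx = np.where(y == minority_label)[0]
    # with a 90/10 split, times=8 makes the two classes roughly balanced
    extra_idx = np.repeat(minority_idx, times)
    X_over = np.concatenate([X, X[extra_idx]])
    y_over = np.concatenate([y, y[extra_idx]])
    perm = np.random.permutation(len(y_over))  # keep minibatches mixed
    return X_over[perm], y_over[perm]
```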

@lopuhin
Contributor

lopuhin commented Feb 2, 2016

And I am not sure it is possible to implement in terms of the existing TensorFlow loss functions. It seems one would need to define a loss function similar to softmax_cross_entropy_with_logits from scratch.

@lopuhin
Contributor

lopuhin commented Feb 2, 2016

Ah, no, it should be possible - we just need to multiply xent by weights that depend on the labels in the minibatch. Sorry for the noise :)
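
Something along those lines, as a rough sketch (assumes one-hot labels and a [n_classes] weight vector; names are illustrative and this is not the skflow implementation):

```python
import tensorflow as tf

def class_weighted_xent(logits, one_hot_labels, class_weight):
    # pick each example's weight from the weight of its true class:
    # [batch_size, n_classes] * [n_classes] -> reduce_sum -> [batch_size]
    example_weight = tf.reduce_sum(one_hot_labels * class_weight, 1)
    xent = tf.nn.softmax_cross_entropy_with_logits(logits, one_hot_labels)
    return tf.reduce_mean(tf.mul(xent, example_weight))
```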

@ilblackdragon
Contributor

@lopuhin It's possible, and as you mentioned, https://github.com/tensorflow/skflow/blob/master/skflow/ops/losses_ops.py#L52 partially implements this already (it multiplies the xent for each class by that class's weight). The only missing piece is passing it from the estimator (e.g. TFLinearClassifier(..., class_weights={1: 0.9, 0: 0.1})) to the models and losses.

I haven't thought of a good interface for this yet (right now it would need to be an argument to every model function).

@ilblackdragon
Contributor

What's currently there can be used by creating an explicit TF constant and initializing it with your weights:

def my_model(X, y):
    class_weight = tf.constant([0.9, 0.1])
    return skflow.models.logistic_regression(X, y, class_weight=class_weight)

estimator = skflow.TensorFlowEstimator(model_fn=my_model, n_classes=2, ...other args...)

@lopuhin
Contributor

lopuhin commented Feb 2, 2016

That's what I thought @ilblackdragon, but for me it fails with tensorflow.python.framework.errors.InvalidArgumentError: Incompatible shapes: [32] vs. [2] here https://github.com/tensorflow/skflow/blob/master/skflow/ops/losses_ops.py#L53, because xent is a tensor of shape [32], which is the number of examples in the minibatch. Maybe I'm using it wrong, though.

@ilblackdragon
Contributor

@lopuhin, you are right, softmax_cross_entropy_with_logits returns just [batch_size] values... so at that point it is already too late to apply the class weights. I'll take a look at how to get around that.

ilblackdragon added a commit that referenced this issue Feb 3, 2016
…he math should work as -weight[class]*x[class] + log( sum ( exp weighted x))
@ilblackdragon
Contributor

OK, so instead I moved it up to multiply the logits.
I think the math still works out; I just need to try it on some imbalanced dataset and then add initialization from the constructor. @lopuhin let me know what you think.
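
Roughly, the logit-scaling variant looks like this (a sketch of the idea, not the exact commit; names are illustrative):

```python
import tensorflow as tf

def softmax_classifier_scaled_logits(logits, one_hot_labels, class_weight=None):
    if class_weight is not None:
        # class_weight has shape [n_classes] and broadcasts across the batch,
        # so the weighting happens before the softmax cross-entropy
        logits = tf.mul(logits, class_weight)
    xent = tf.nn.softmax_cross_entropy_with_logits(logits, one_hot_labels)
    return tf.reduce_mean(xent)
```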

@lopuhin
Contributor

lopuhin commented Feb 3, 2016

@ilblackdragon for me a more natural solution would be something like this: lopuhin@5c97849 - here I apply a weight to the xent of each example depending on what its label is. I think this is mathematically different from scaling the logits. But I am still learning, so take this with a grain of salt please :)

@ilblackdragon
Contributor

@lopuhin Changing the weights inside the cross-entropy adjusts the relative importance of all the classes to each other (skewing the distribution), whereas your option only adjusts the weight of one class. But your option may work in practice. I'll double-check with a few people what the best way is.
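
For reference, one rough way to write down the two variants for an example whose true class is c, with z the logits and w the per-class weights (the first line matches the formula in the commit message above):

```latex
L_{\text{scaled logits}} = -\,w_c z_c + \log \sum_j \exp(w_j z_j)
\qquad
L_{\text{weighted xent}} = w_c \Bigl( -\,z_c + \log \sum_j \exp(z_j) \Bigr)
```

The first changes the shape of the softmax distribution itself, while the second only rescales each example's contribution to the loss.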
