Use tf.keras.backend.epsilon() as dtype #2008
Conversation
Hi, Tzu-Wei, I know there is one way to serialize tf.dtypes.DType:

```python
In [21]: tf.float32.as_datatype_enum
Out[21]: 1

In [22]: tf.as_dtype(tf.float32.as_datatype_enum)
Out[22]: tf.float32
```
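To spell out that round trip, here is a minimal sketch of how a loss config could carry the enum value (the config keys below are illustrative, not this PR's actual code):

```python
import tensorflow as tf

# A tf.dtypes.DType is not JSON-serializable, but its enum value is a plain int.
enum_value = tf.float32.as_datatype_enum  # -> 1
restored = tf.as_dtype(enum_value)        # -> tf.float32
assert restored == tf.float32

# A loss's get_config() could therefore store the enum, and from_config()
# could rebuild the DType on load (hypothetical config keys).
config = {"num_classes": 4, "dtype": tf.float32.as_datatype_enum}
loaded_dtype = tf.as_dtype(config["dtype"])
```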
Awesome! Let me try it.
```python
def test_dtype_config():
    wkl = WeightedKappaLoss(num_classes=4, dtype=tf.float32)
```
Is this the only loss in the repo where we handle dtype in the loss class?
yes it is :-)
is there a particular reason?
Because it is very weird to pass dtype into losses. In core TF, the dtype of a loss is determined either from its inputs or from backend.floatx(). Considering that the loss is going to be passed into an optimizer, which (mostly) creates its slots with the same dtype as the variables, I think the loss is better off with either floatx() or the same dtype as y_pred/y_true, so that computations in the backward pass won't throw errors like incompatible dtypes.
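For reference, this is roughly the pattern core tf.keras losses follow, as a minimal sketch (not the actual losses.py source):

```python
import tensorflow as tf

def mean_squared_error(y_true, y_pred):
    # Core losses infer the compute dtype from y_pred and cast y_true to
    # match, rather than accepting a dtype argument.
    y_pred = tf.convert_to_tensor(y_pred)
    y_true = tf.cast(y_true, y_pred.dtype)
    return tf.reduce_mean(tf.math.squared_difference(y_pred, y_true), axis=-1)
```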
I meant: is there any particular reason why we are using it in this case?
Ohoh, sorry for the misunderstanding :-(
Actually, I think there's no reason we do that. I would say backend.floatx() is the better solution. What do you think?
If we don't break users' serialized models, I think it is better.
Similar thoughts. That argument doesn't make sense, as this should be inferred from the input dtype.
Ok, I'll revert the commit and use tf.keras.backend.floatx() instead.
This reverts commit 6ace2cb.
* Use tf.keras.backend.epsilon() as dtype (Add test, Fix test)
* Fix test
* Fix casting
* Fix typo
* Use as_datatype_enum
* Revert "Use as_datatype_enum" (this reverts commit 6ace2cb)
Fixes #2006.

There is no way to serialize tf.dtypes.DType, and in tf.keras.losses.* the computation dtype depends on either the inputs (y_true, y_pred) or tf.keras.backend.floatx(). It turns out the most reasonable way for me is to cast values to tf.keras.backend.floatx() for computation, because it involves operations like log and div. https://github.com/tensorflow/tensorflow/blob/v2.2.0/tensorflow/python/keras/losses.py
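A minimal sketch of that casting approach (the helper name is hypothetical; the real change lives in the WeightedKappaLoss implementation):

```python
import tensorflow as tf

def kappa_log_term(y_true, y_pred):
    # Hypothetical helper: cast both inputs to the global float type before
    # any log/div math, so mixed-dtype inputs don't raise incompatible-dtype
    # errors in the backward pass.
    floatx = tf.keras.backend.floatx()  # "float32" unless reconfigured
    y_true = tf.cast(y_true, dtype=floatx)
    y_pred = tf.cast(y_pred, dtype=floatx)
    eps = tf.keras.backend.epsilon()    # keeps the log away from log(0)
    return tf.math.log(y_pred + eps)
```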