Fix for overly-aggressive dtype checking for symbolic tensors. Rely instead on casting. #256
Conversation
… tests for desired type-casting behavior.
Review status: 0 of 1 LGTMs obtained

src/engine/executor.ts, line 24 at r1 (raw file):

Thanks for sending the PR. I think this is an important issue to solve, especially considering that in the future we will have more dtypes, such as string, float16, and unsigned ints. I think we should do automatic casting of the input Tensor value, just like Python Keras and TensorFlow do. I also think casting is better than removing this check. With this change, if a model-internal operation doesn't support the dtype provided, the error message and stack will be confusing. It also goes against the principle of "fail fast": it will sometimes waste time on computation in early parts of the model until hitting a part of the graph that can't handle the dtype, instead of telling the user at the outset that the dtype won't work.

Below I describe an alternative that I think is better. Instead of removing this rigid type check, replace it with a best-effort cast for the user. This method should probably be renamed and its signature changed, to something like:

```typescript
function checkAndMaybeCastFeed(key: SymbolicTensor, val: Tensor): Tensor {
  // 1. Check shape compatibility. If shapes are not compatible, error out as is.
  // ...
  // 2. Check dtype compatibility.
  //    If dtypes match, return the Tensor as is, without any change.
  //    Else:
  //    try {
  //      return tf.cast(val, key.dtype);
  //    } catch (err) {
  //      throw an Error with an informative error message.
  //    }
}
```
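As a concrete, self-contained sketch of this check-and-cast idea (the `DType`, `FakeSymbolicTensor`, and `FakeTensor` types below are stand-ins invented for illustration; a real implementation would operate on tfjs `SymbolicTensor` and `Tensor` and call `tf.cast`):

```typescript
// Stand-in types, invented for illustration only.
type DType = 'float32' | 'int32' | 'bool' | 'string';

interface FakeTensor { dtype: DType; values: number[]; }
interface FakeSymbolicTensor { dtype: DType | null; }

// Assume bool <-> int32 <-> float32 casts are allowed; 'string' is not.
const CASTABLE: DType[] = ['bool', 'int32', 'float32'];

function checkAndMaybeCastFeed(key: FakeSymbolicTensor, val: FakeTensor): FakeTensor {
  // 1. (Shape check omitted in this sketch.)
  // 2. If dtypes match, or the key is unconstrained, pass through unchanged.
  if (key.dtype == null || key.dtype === val.dtype) {
    return val;
  }
  // Otherwise, attempt a best-effort cast, analogous to tf.cast(val, key.dtype).
  if (CASTABLE.indexOf(key.dtype) !== -1 && CASTABLE.indexOf(val.dtype) !== -1) {
    return {dtype: key.dtype, values: val.values};
  }
  // Fail fast with an informative message.
  throw new Error(
      `Feed dtype '${val.dtype}' cannot be cast to expected dtype '${key.dtype}'.`);
}
```

With this sketch, an int32 feed into a float32 input is cast silently, while an uncastable feed fails immediately with a clear message rather than deep inside the graph.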
Updated with explicit cast
src/models_test.ts, line 1108 at r2 (raw file):
> const dtypes: DataType[] = ['int32', 'float32'];

These tests are nice.

src/models_test.ts, line 1116 at r2 (raw file):
> y.dtype === dtype

Can you help me understand why the output dtype is determined by the input dtype?

src/models_test.ts, line 1123 at r2 (raw file):
> embModel.compile({optimizer: 'sgd', loss: 'meanSquaredError'});

I don't think this test is really testing anything. For starters, dtype is not used.

src/engine/executor.ts, line 53 at r2 (raw file):
> is incompatible with

Maybe be more explicit and say "cannot be cast to...".

src/engine/executor.ts, line 104 at r2 (raw file):
> typeCompatibleValue;

This can be replaced with `assertFeedCompatibility(key, value)`, and line 102 can be removed.
src/models_test.ts, line 1108 at r2 (raw file):

Previously, caisq (Shanqing Cai) wrote…
> const dtypes: DataType[] = ['int32', 'float32'];
> These tests are nice

... but is it possible to conceive a test where the casting fails?
src/models_test.ts, line 1108 at r2 (raw file):

Previously, caisq (Shanqing Cai) wrote…
> ... but is it possible to conceive a test where the casting fails?

The original motivation for this work was clearing the path for 'string'-type tensors. We can extend this test to fail for 'string'-to-X conversions. As it stands now, however, all conversions are allowed by tf.cast (bool <--> int32 <--> float32).
src/models_test.ts, line 1116 at r2 (raw file):

Previously, caisq (Shanqing Cai) wrote…
> y.dtype === dtype
> Can you help me understand why the output dtype is determined by the input dtype?

This is a double bug in the test. Two bugs in one! y.dtype isn't supposed to match dtype, and expect(whatever), with no matcher, does not fail.

The expected behavior, per Python, is that the return dtype is 'float32', independent of the input dtype. The test checks this now:

```python
In [12]: emb_model.predict(np.array([[0], [0], [1]]))
Out[12]:
array([[[ 0.04032261, -0.03752766]],
       [[ 0.04032261, -0.03752766]],
       [[-0.04086472,  0.03731059]]], dtype=float32)

In [13]: emb_model.predict(np.array([[0], [0], [1]], dtype='int32'))
Out[13]:
array([[[ 0.04032261, -0.03752766]],
       [[ 0.04032261, -0.03752766]],
       [[-0.04086472,  0.03731059]]], dtype=float32)

In [14]: np.array([[0], [0], [1]], dtype='int32')
Out[14]:
array([[0],
       [0],
       [1]], dtype=int32)
```
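The reasoning can be illustrated with a toy stand-in for an embedding lookup (invented for illustration; not the tfjs API): the output rows come from a float32 weight table, so the output dtype is fixed by the weights, not by the dtype of the index input.

```typescript
// Toy embedding lookup: gathers rows of a float32 weight table by integer
// index. The output dtype is always 'float32', regardless of how the
// caller represented the indices.
function embeddingPredict(indices: number[], table: number[][]):
    {dtype: string, values: number[][]} {
  return {dtype: 'float32', values: indices.map(i => table[i])};
}
```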
src/models_test.ts, line 1123 at r2 (raw file):

Previously, caisq (Shanqing Cai) wrote…
> embModel.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
> I don't think this test is really testing anything. For starters, dtype is not used.

Agreed. Removed.
src/engine/executor.ts, line 24 at r1 (raw file):

Previously, caisq (Shanqing Cai) wrote…
> if (key.dtype != null && key.dtype !== val.dtype) {
>
> Thanks for sending the PR. I think this is an important issue to solve, especially considering that in the future, we will have more dtypes, like string, float16, unsigned ints, etc. I think we should do automatic casting of the input Tensor value, just like what Python keras and tensorflow do. I also think doing casting is better than removing this check here. [...] Instead of removing this rigid type check, replace it with a best-of-effort casting for the user. This method should probably be renamed and its signature be changed, something like `checkAndMaybeCastFeed(key: SymbolicTensor, val: Tensor): Tensor` (see the pseudocode above).

Thanks for your thoughtful feedback. I agree that it is right to catch these errors as early as possible, and to use early casting as a mechanism to distinguish innocuous casts from troublesome ones, so as to deliver actionable error messages.

There is a problem with the proposed solution, however: the cast is between the wrong two types. key.dtype is the dtype of the expected input tensor, and val.dtype is the dtype of the expected output tensor, and we don't expect these to match. We want to cast between the expected input dtype and the actual input dtype. This wouldn't be an issue, except that the actual input dtype isn't available at the time of this call: this call happens at model compile time, but we don't know the actual dtype of the input until the call to execute.

To summarize, there are four dtypes involved:

1. Expected input dtype
2. Expected output dtype
3. Actual input dtype
4. Actual output dtype

The current implementation compares (1) and (2).
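To illustrate the timing issue, here is a hypothetical sketch (all names invented for illustration; not the tfjs-layers API): the expected input dtype is recorded at compile time, but it can only be reconciled with the actual input dtype at execute time, when a concrete tensor is fed.

```typescript
type DType = 'float32' | 'int32' | 'bool';

// (1) Expected input dtype: recorded when the model is compiled.
interface CompiledInput { name: string; expectedDtype: DType; }

// (3) Actual input dtype: only known when a concrete tensor is fed.
interface FedTensor { dtype: DType; values: number[]; }

// At execute time, reconcile (1) with (3) by casting the fed value.
function reconcileAtExecute(spec: CompiledInput, fed: FedTensor): FedTensor {
  if (fed.dtype === spec.expectedDtype) {
    return fed;  // dtypes already agree
  }
  // Stand-in for tf.cast(fed, spec.expectedDtype).
  return {dtype: spec.expectedDtype, values: fed.values};
}
```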
src/engine/executor.ts, line 53 at r2 (raw file):

Previously, caisq (Shanqing Cai) wrote…
> is incompatible with
> Maybe be more explicit and say "cannot be cast to...".

Done.
src/models_test.ts, line 1108 at r2 (raw file):

Previously, bileschi (Stanley Bileschi) wrote…
> The original motivation for this work was clearing the path for 'string' type tensors. We can extend this test to fail for 'string' to X conversions. The way it stands now, however, all conversions are allowed by tf.cast (bool <--> int32 <--> float32)

OK, I see. In that case, can you add a TODO item here for adding a test that covers casting failure?
src/models_test.ts, line 1123 at r2 (raw file):

Previously, bileschi (Stanley Bileschi) wrote…
> agreed. removed.

Did you forget to push your commit? I still see the test in the latest snapshot.
src/models_test.ts, line 1123 at r2 (raw file):

Previously, caisq (Shanqing Cai) wrote…
> Did you forget to push your commit? I still see the test in the latest snapshot.

Indeed. I pushed to my personal fork, but that didn't carry over to the main repo.
src/engine/executor.ts, line 104 at r2 (raw file):

Previously, caisq (Shanqing Cai) wrote…
> typeCompatibleValue;
> This can be replaced with `assertFeedCompatibility(key, value)` and line 102 can be removed.

Done.
Description

Fix for overly-aggressive dtype checking for symbolic tensors. Rely instead on casting.

BUG: Fixes an issue wherein a model with just an embedding layer could not accept int32-type input.