This repository has been archived by the owner on Oct 17, 2021. It is now read-only.

Fix for overly aggressive dtype checking for symbolic tensors. Rely instead on casting. #256

Merged
merged 11 commits into from
Jul 17, 2018

Conversation

bileschi
Contributor

@bileschi bileschi commented Jun 29, 2018

Description

Fix for overly aggressive dtype checking for symbolic tensors. Rely instead on casting.

For repository owners only:

BUG: Fixes issue wherein a model with just an embedding layer could not accept int32 type input.


@bileschi bileschi requested a review from caisq June 29, 2018 16:36
@caisq
Contributor

caisq commented Jun 29, 2018

Review status: 0 of 1 LGTMs obtained


src/engine/executor.ts, line 24 at r1 (raw file):

if (key.dtype != null && key.dtype !== val.dtype) {

Thanks for sending the PR. I think this is an important issue to solve, especially considering that in the future, we will have more dtypes, like string, float16, unsigned ints, etc. I think we should do automatic casting of the input Tensor value, just like what Python keras and tensorflow do. I also think doing casting is better than removing this check here. With this change, if a model-internal operation doesn't support the dtype provided, the error message and stack will be confusing. It also goes against the principle of "fail fast", i.e., it'll sometimes waste time doing some computation in early parts of the model until hitting a part of the graph that can't handle the dtype, instead of telling the user that the dtype won't work at the beginning.

Below I will describe an alternative that I think is better.

Instead of removing this rigid type check, replace it with best-effort casting on the user's behalf. This method should probably be renamed and its signature changed to something like

function checkAndMaybeCastFeed(key: SymbolicTensor, val: Tensor): Tensor {
  // 1. Check shape compatibility. If shapes are not compatible, error out as is.
  // ...

  // 2. Check dtype compatibility.
  // If dtypes match, return the Tensor as is without any change.
  // Else,
  // try {
  //    return tf.cast(val, key.dtype)
  // } catch(err) {
  //    throw an Error with an informative error message.
  // }
  //
}
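
A minimal, self-contained sketch of the proposal above may make the control flow easier to follow. It does not use the library's real classes: `SymbolicTensorLike`, `TensorLike`, and `castLike` are hypothetical stand-ins introduced only for illustration, and the assumption that any cast involving 'string' throws is likewise an assumption of this sketch, not confirmed behavior of tf.cast.

```typescript
// Hypothetical stand-ins for the library's types; the real SymbolicTensor and
// Tensor classes carry much more state than this.
type DataType = 'float32' | 'int32' | 'bool' | 'string';

interface SymbolicTensorLike {
  name: string;
  dtype: DataType | null;
  shape: Array<number | null>;  // null marks a dimension of unknown size
}

interface TensorLike {
  dtype: DataType;
  shape: number[];
  values: Array<number | boolean | string>;
}

// Stand-in for tf.cast. Assumption of this sketch: bool, int32 and float32
// convert freely, and any cast involving 'string' throws.
function castLike(val: TensorLike, dtype: DataType): TensorLike {
  if (val.dtype === 'string' || dtype === 'string') {
    throw new Error(`Cannot cast ${val.dtype} to ${dtype}`);
  }
  const values = val.values.map((v) => {
    const n = Number(v);
    if (dtype === 'bool') {
      return n !== 0;
    }
    return dtype === 'int32' ? Math.trunc(n) : n;
  });
  return {dtype, shape: val.shape, values};
}

function checkAndMaybeCastFeed(
    key: SymbolicTensorLike, val: TensorLike): TensorLike {
  // 1. Shape check: ranks must match, and every known dimension must agree.
  const shapesCompatible = key.shape.length === val.shape.length &&
      key.shape.every((dim, i) => dim == null || dim === val.shape[i]);
  if (!shapesCompatible) {
    throw new Error(
        `Shape [${val.shape}] of the feed is incompatible with shape ` +
        `[${key.shape}] of key '${key.name}'.`);
  }
  // 2. Dtype check: return the value untouched when no cast is needed.
  if (key.dtype == null || key.dtype === val.dtype) {
    return val;
  }
  try {
    return castLike(val, key.dtype);
  } catch (err) {
    // Fail fast, before any part of the model runs.
    throw new Error(
        `The dtype of the feed (${val.dtype}) cannot be cast to the dtype ` +
        `of key '${key.name}' (${key.dtype}): ${(err as Error).message}`);
  }
}
```

Under these assumptions, an int32 feed for a float32 key comes back as a float32 value, a feed whose dtype already matches is returned as the same object, and a 'string' feed is rejected up front with a message that names both dtypes.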

@bileschi
Contributor Author

bileschi commented Jul 6, 2018 via email

Contributor

@caisq caisq left a comment

Reviewable status: 0 of 1 LGTMs obtained


src/models_test.ts, line 1108 at r2 (raw file):

const dtypes: DataType[] = ['int32', 'float32'];

These tests are nice


src/models_test.ts, line 1116 at r2 (raw file):

y.dtype === dtype

Can you help me understand why the output dtype is determined by the input dtype?


src/models_test.ts, line 1123 at r2 (raw file):

embModel.compile({optimizer: 'sgd', loss: 'meanSquaredError'});

I don't think this test is really testing anything. For starters, dtype is not used.


src/engine/executor.ts, line 53 at r2 (raw file):

is incompatible with

Maybe be more explicit and say "cannot be cast to...".


src/engine/executor.ts, line 104 at r2 (raw file):

typeCompatibleValue;

This can be replaced with assertFeedCompatibility(key, value) and Line 102 can be removed.

Contributor

@caisq caisq left a comment

Reviewable status: 0 of 1 LGTMs obtained


src/models_test.ts, line 1108 at r2 (raw file):

Previously, caisq (Shanqing Cai) wrote…
const dtypes: DataType[] = ['int32', 'float32'];

These tests are nice

... but is it possible to conceive a test where the casting fails?

Contributor Author

@bileschi bileschi left a comment

Reviewable status: 0 of 1 LGTMs obtained


src/models_test.ts, line 1108 at r2 (raw file):

Previously, caisq (Shanqing Cai) wrote…

... but is it possible to conceive a test where the casting fails?

The original motivation for this work was clearing the path for 'string' type tensors. We can extend this test to fail for 'string' to X conversions. The way it stands now, however, all conversions are allowed by tf.cast (bool <--> int32 <--> float32)


src/models_test.ts, line 1116 at r2 (raw file):

Previously, caisq (Shanqing Cai) wrote…
y.dtype === dtype

Can you help me understand why the output dtype is determined by the input dtype?

This is a double bug in the test. Two bugs in one! y.dtype isn't supposed to match dtype, but expect(whatever), with no matcher, does not fail.

The expected behavior, per python, is that the return type is 'float32', independent of the input type. The test checks this now.

In [12]: emb_model.predict(np.array([[0], [0], [1]]))
Out[12]: 
array([[[ 0.04032261, -0.03752766]],

       [[ 0.04032261, -0.03752766]],

       [[-0.04086472,  0.03731059]]], dtype=float32)

In [13]: emb_model.predict(np.array([[0], [0], [1]], dtype='int32'))
Out[13]: 
array([[[ 0.04032261, -0.03752766]],

       [[ 0.04032261, -0.03752766]],

       [[-0.04086472,  0.03731059]]], dtype=float32)

In [14]: np.array([[0], [0], [1]], dtype='int32')
Out[14]: 
array([[0],
       [0],
       [1]], dtype=int32)
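
The first of the two test bugs is easy to miss, so here is a small sketch of it. `expectLike` is a hypothetical, minimal stand-in for a Jasmine-style `expect`, written only so the example is self-contained; it is not the test framework's real implementation.

```typescript
// Hypothetical stand-in for a Jasmine-style expect(): the comparison lives
// in the matcher, so building the expectation alone checks nothing.
function expectLike<T>(actual: T) {
  return {
    toEqual(expected: T): void {
      if (actual !== expected) {
        throw new Error(
            `Expected ${String(actual)} to equal ${String(expected)}`);
      }
    },
  };
}

const outputDtype: string = 'float32';
const inputDtype: string = 'int32';

// No matcher is called: the boolean is wrapped and discarded, so this line
// cannot fail even though the condition is false.
expectLike(outputDtype === inputDtype);

// With a matcher the comparison actually runs. Here it passes because the
// output dtype is 'float32' regardless of the input dtype.
expectLike(outputDtype).toEqual('float32');
```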

src/models_test.ts, line 1123 at r2 (raw file):

Previously, caisq (Shanqing Cai) wrote…
embModel.compile({optimizer: 'sgd', loss: 'meanSquaredError'});

I don't think this test is really testing anything. For starters, dtype is not used.

agreed. removed.


src/engine/executor.ts, line 24 at r1 (raw file):

Previously, caisq (Shanqing Cai) wrote…
if (key.dtype != null && key.dtype !== val.dtype) {

Thanks for sending the PR. I think this is an important issue to solve, especially considering that in the future, we will have more dtypes, like string, float16, unsigned ints, etc. I think we should do automatic casting of the input Tensor value, just like what Python keras and tensorflow do. I also think doing casting is better than removing this check here. With this change, if a model-internal operation doesn't support the dtype provided, the error message and stack will be confusing. It also goes against the principle of "fail fast", i.e., it'll sometimes waste time doing some computation in early parts of the model until hitting a part of the graph that can't handle the dtype, instead of telling the user that the dtype won't work at the beginning.

Below I will describe an alternative that I think is better.

Instead of removing this rigid type check, replace it with best-effort casting on the user's behalf. This method should probably be renamed and its signature changed to something like

function checkAndMaybeCastFeed(key: SymbolicTensor, val: Tensor): Tensor {
  // 1. Check shape compatibility. If shapes are not compatible, error out as is.
  // ...

  // 2. Check dtype compatibility.
  // If dtypes match, return the Tensor as is without any change.
  // Else,
  // try {
  //    return tf.cast(val, key.dtype)
  // } catch(err) {
  //    throw an Error with an informative error message.
  // }
  //
}

Thanks for your thoughtful feedback. I agree that it is right to catch these errors as early as possible, and to use early casting as a mechanism to distinguish innocuous casts from troublesome ones, so as to deliver actionable error messages.

There is a problem with the proposed solution, however: the cast is between the wrong two types. key.dtype is the type of the expected input tensor, and val.dtype is the type of the expected output tensor, and we don't expect these to match. What we want is a cast between the expected input dtype and the actual input dtype. That would be straightforward, except that the actual input dtype isn't available at the time of this call. This call happens at model compile time, but we don't know the actual type of the input until the call to execute.

To summarize, there are four dtypes involved.

  1. Expected input dtype
  2. Expected output dtype
  3. Actual input dtype
  4. Actual output dtype

The current implementation is comparing 1 and 2.
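
One way to picture the four dtypes is the toy sketch below. `ToyModel` and `FeedValue` are hypothetical names invented for this illustration and say nothing about how the library's executor is structured; the only point is that dtypes 1 and 2 are fixed when the model is defined, while dtype 3 first becomes visible in the call to execute, which is therefore the earliest place a cast from 3 to 1 can happen.

```typescript
type DataType = 'float32' | 'int32' | 'bool';

interface FeedValue {
  dtype: DataType;
  values: number[];
}

// Hypothetical toy model that declares int32 input and float32 output,
// loosely mirroring the embedding example discussed in this thread.
class ToyModel {
  readonly expectedInputDtype: DataType = 'int32';     // (1) fixed up front
  readonly expectedOutputDtype: DataType = 'float32';  // (2) fixed up front

  execute(input: FeedValue): FeedValue {
    // (3) The actual input dtype is only known here, so the cast from the
    // actual input dtype to the expected input dtype happens at this point.
    const fed: FeedValue = input.dtype === this.expectedInputDtype ?
        input :
        {
          dtype: this.expectedInputDtype,
          values: input.values.map((v) => Math.trunc(v)),
        };
    // (4) The actual output dtype follows from the model, not from the feed.
    // The arithmetic is a placeholder for the real computation.
    return {
      dtype: this.expectedOutputDtype,
      values: fed.values.map((v) => v * 0.5),
    };
  }
}
```

In this sketch, feeding `{dtype: 'float32', values: [0, 0, 1]}` and `{dtype: 'int32', values: [0, 0, 1]}` yields the same float32 result, which matches the Python behavior shown earlier in the thread.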


src/engine/executor.ts, line 53 at r2 (raw file):

Previously, caisq (Shanqing Cai) wrote…
is incompatible with

Maybe be more explicit and say "cannot be cast to...".

Done.

Contributor

@caisq caisq left a comment

Reviewable status: 0 of 1 LGTMs obtained


src/models_test.ts, line 1108 at r2 (raw file):

Previously, bileschi (Stanley Bileschi) wrote…

The original motivation for this work was clearing the path for 'string' type tensors. We can extend this test to fail for 'string' to X conversions. The way it stands now, however, all conversions are allowed by tf.cast (bool <--> int32 <--> float32)

OK, I see. In that case, can you add a TODO item here for adding a test that covers casting failure?


src/models_test.ts, line 1123 at r2 (raw file):

Previously, bileschi (Stanley Bileschi) wrote…

agreed. removed.

Did you forget to push your commit? I still see the test in the latest snapshot.

Contributor Author

@bileschi bileschi left a comment

Reviewable status: 0 of 1 LGTMs obtained


src/models_test.ts, line 1123 at r2 (raw file):

Previously, caisq (Shanqing Cai) wrote…

Did you forget to push your commit? I still see the test in the latest snapshot.

Indeed. Pushed to my personal fork, but that didn't carry over to the main repo.

Contributor Author

@bileschi bileschi left a comment

Reviewable status: 0 of 1 LGTMs obtained


src/engine/executor.ts, line 104 at r2 (raw file):

Previously, caisq (Shanqing Cai) wrote…
typeCompatibleValue;

This can be replaced with assertFeedCompatibility(key, value) and Line 102 can be removed.

Done.

Contributor

@caisq caisq left a comment

:lgtm_strong:

Reviewable status: 0 of 1 approvals obtained (waiting on @caisq)

@bileschi bileschi merged commit 96a242c into master Jul 17, 2018
@dsmilkov dsmilkov deleted the embedding-dtype branch February 21, 2019 15:38