This repository has been archived by the owner on Oct 17, 2021. It is now read-only.

Fix a bug in the handling of array of step outputs by TimeDistributed layer #315

Merged
merged 10 commits into from
Sep 6, 2018



@caisq caisq commented Sep 6, 2018

Fixes: tensorflow/tfjs#681

BUG



@caisq caisq changed the title Time dist fix [WIP; DO NOT REVIEW YET] Fix a bug in the handling of array of step outputs by TimeDistributed layer Sep 6, 2018
@caisq caisq changed the title [WIP; DO NOT REVIEW YET] Fix a bug in the handling of array of step outputs by TimeDistributed layer Fix a bug in the handling of array of step outputs by TimeDistributed layer Sep 6, 2018

@davidsoergel davidsoergel left a comment


Reviewable status: 0 of 1 approvals obtained (waiting on @caisq, @davidsoergel, @ericdnielsen, and @bileschi)


src/layers/wrappers.ts, line 230 at r1 (raw file):

        // TODO(cais): Add useLearningPhase.
        const output =
            getExactlyOneTensor(this.layer.call(inputs, kwargs) as Tensor);

Maybe worth adding a comment about what is going on here (i.e., under what circumstances does call() return a length-1 array?)


src/layers/wrappers.ts, line 230 at r1 (raw file):

        // TODO(cais): Add useLearningPhase.
        const output =
            getExactlyOneTensor(this.layer.call(inputs, kwargs) as Tensor);

The cast should be unnecessary now
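
To illustrate the question raised in the review comments above, here is a minimal sketch of a helper that accepts either a single value or a length-1 array and returns the single value. The names `exactlyOne` and `FakeTensor` are hypothetical stand-ins introduced for illustration; this is not the tfjs-layers implementation of `getExactlyOneTensor`, which may differ in details such as its error type.

```typescript
// Hypothetical stand-in for a tensor; NOT the tfjs Tensor class.
interface FakeTensor {
  shape: number[];
}

// Returns the single tensor when given either a bare tensor or a
// length-1 array of tensors. Throws for arrays of any other length.
function exactlyOne(xs: FakeTensor | FakeTensor[]): FakeTensor {
  if (Array.isArray(xs)) {
    if (xs.length !== 1) {
      throw new Error(`Expected exactly 1 tensor, but got ${xs.length}`);
    }
    return xs[0];
  }
  return xs;
}

const single: FakeTensor = { shape: [2, 3] };
const fromValue = exactlyOne(single);    // bare tensor passes through
const fromArray = exactlyOne([single]);  // length-1 array is unwrapped
```

With a helper of this shape, the call site no longer needs to know whether the wrapped layer returned a bare tensor or a one-element array, which is presumably why the `as Tensor` cast became unnecessary.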


@caisq caisq left a comment


Thanks for the review!

Reviewable status: :shipit: complete! 1 of 1 approvals obtained (waiting on @davidsoergel, @ericdnielsen, and @bileschi)


src/layers/wrappers.ts, line 230 at r1 (raw file):

Previously, davidsoergel (David Soergel) wrote…

Maybe worth adding a comment about what is going on here (i.e., under what circumstances does call() return a length-1 array?)

Done.


src/layers/wrappers.ts, line 230 at r1 (raw file):

Previously, davidsoergel (David Soergel) wrote…

The cast should be unnecessary now

Done.

@caisq caisq merged commit e0b73c1 into tensorflow:master Sep 6, 2018

rodrigopivi commented Sep 12, 2018

hi @caisq

Thanks for this fix. I built tfjs-layers manually and tested the use case of passing a model as a layer. The fix works when training with a validationSplit of 0.5. With any other validationSplit value, the training and validation tensors have different shapes and tfjs throws an error. Here is example code that reproduces the problem:

const inputs = tf.input({ dtype: 'float32', shape: [1, 2] });
const lstm = tf.layers.lstm({ units: 2, returnSequences: true }).apply(inputs) as tf.SymbolicTensor;
const timeAttention = new TimeSeriesAttention({}).apply(lstm) as tf.SymbolicTensor;
const model = tf.model({ inputs, outputs: timeAttention });
const optimize = tf.train.adam(0.0066, 0.0025, 0.1);
model.compile({ loss: 'categoricalCrossentropy', metrics: ['accuracy'], optimizer: optimize });

const inp = tf.tensor3d([[[1,1]],[[2,2]],[[3,3]],[[4,4]],[[5,5]],[[6,6]]],[6,1,2]);
const out = tf.tensor3d([[[1,0]],[[2,0]],[[3,0]],[[4,0]],[[5,0]],[[6,0]]], [6, 1, 2]);

(async () => {
    // NOTE: with a validationSplit of 0.5 this works; with any other value it fails
    // because the training and validation tensors end up with different shapes.
    await model.fit(inp, out, { validationSplit: 0.2 });
})();

NOTE: TimeSeriesAttention is just a model
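
The arithmetic behind the report above can be sketched as follows. This is a hedged illustration only: it assumes the split index is computed as `floor(numSamples * (1 - validationSplit))`, and the helper name `splitSizes` is hypothetical. It shows why 0.5 is the one split that gives the training and validation sets the same first dimension for the 6-sample repro; it does not reproduce or explain the error itself.

```typescript
// Assumption: the first `splitAt` samples are used for training and the
// remainder for validation, with splitAt = floor(n * (1 - validationSplit)).
function splitSizes(
    numSamples: number,
    validationSplit: number): {train: number; val: number} {
  const splitAt = Math.floor(numSamples * (1 - validationSplit));
  return {train: splitAt, val: numSamples - splitAt};
}

// The repro above uses 6 samples of shape [1, 2].
const half = splitSizes(6, 0.5);   // train: 3, val: 3 (equal sizes)
const fifth = splitSizes(6, 0.2);  // train: 4, val: 2 (different sizes)
console.log(half, fifth);
```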


caisq commented Sep 13, 2018

@rodrigopivi Thanks for the report and the code for reproducing the error. It's on my TODO list to look at this issue, but my schedule is a little tight currently, so expect a delay of 1-2 weeks.

@rodrigopivi

thank you
