Skip to content
This repository has been archived by the owner on Oct 17, 2021. It is now read-only.

added masking layer #525

Merged
merged 6 commits into from
Apr 16, 2019
Merged

Conversation

Gurpreetsingh9465
Copy link
Contributor

@Gurpreetsingh9465 Gurpreetsingh9465 commented Apr 13, 2019

this is a PR regarding issue tensorflow/tfjs#1502 usage
var a = tf.tensor([[[1, 2, 3], [1, 2, 3], [1, 2, 3]], [[8, 8, 8], [1, 5, 7], [8, 8, 9]], [[8, 8, 8], [8, 8, 8], [8, 8, 8]]]);
var test2 = tf.layers.masking({maskValue: 8}).apply(a);
test2.print();


This change is Reviewable

Copy link
Contributor

@caisq caisq left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reviewable status: 0 of 1 approvals obtained (waiting on @Gurpreetsingh9465)


src/layers/masking.ts, line 12 at r1 (raw file):

 /**
  * TensorFlow.js Layers: Masking Layer.

Re code location: to be consistent with Python, this code should live in core.ts.


src/layers/masking.ts, line 24 at r1 (raw file):

Masking Value.

Explicitly say what the default value is (0).


src/layers/masking.ts, line 25 at r1 (raw file):

 mas

style: indent should be 2 spaces, not 1.


src/layers/masking.ts, line 36 at r1 (raw file):

exception

In JavaScript/TypeScript, this should say "Error".


src/layers/masking.ts, line 57 at r1 (raw file):

args.maskValue;

This should be args.maskValue == null ? 0 : args.maskValue


src/layers/masking.ts, line 62 at r1 (raw file):

 }

Like I commented above, check all your indentations.


src/layers/masking.ts, line 66 at r1 (raw file):

mask_value

This should be maskValue (lowerCamelCase).


src/layers/masking.ts, line 75 at r1 (raw file):

-1,true)

To make the meaning of these parameters more explicit, do:

axis = -1
keepDims = true

src/layers/masking.ts, line 81 at r1 (raw file):

serialization.registerClass(Masking);

This PR lacks unit tests for the new feature. Please add tests. See examples in core_test.ts for how to add unit tests.

In general, we test both computeOutputShape() and call(). The latter is tested through creating a model that contains
this type of Layer and invoking the predict() method of the model.

For tests with concrete tensor values, we compare our results with reference values from Python. In this case, you can do
something like:

import numpy as np
from tensorflow import keras

model = keras.Sequential()
model.add(keras.layers.Masking(input_shape=[3, 3]))
model.add(keras.layers.LSTM())

xs = np.array([
    [[1, 0, 0], [1, 1, 0], [1, 1, 1]],
    [[0, 0, 0], [1, 0, 0], [1, 1, 0]]], dtype=np.float32)
ys = model.predict(xs)
print(ys)

@Gurpreetsingh9465
Copy link
Contributor Author

is this a bug?
import * as tf from '@tensorflow/tfjs';
const model = tf.sequential();
let t = tf.tensor([
[[1, 0, 0], [1, 1, 0], [1, 1, 1]],
[[0, 0, 0], [1, 0, 0], [1, 1, 0]]]);
model.add(tf.layers.reLU({batchInputShape: t.shape}));
model.add(tf.layers.lstm({units: 4}));
let ans = model.predict(t);
ans.print();

Error:
Property 'print' does not exist on type 'Tensor<Rank>[]'.

@caisq
Copy link
Contributor

caisq commented Apr 15, 2019

@Gurpreetsingh9465 This seems to be due to the union type of the return value of predict(). You should do.

const ans = model.predict(t) as Tensor;
ans.print();

@caisq
Copy link
Contributor

caisq commented Apr 15, 2019

@Gurpreetsingh9465 I see you've been making changes to address my comments. Please let me know when the PR is ready for me to review again. Also don't hesitate to ask if you need additional help.

@Gurpreetsingh9465
Copy link
Contributor Author

Error by using
let ans =model.predict(t) as Tensor;
Argument of type 'Tensor<Rank>' is not assignable to parameter of type 'Tensor<Rank> | Tensor<Rank>[]'.

also tried
let ans =model.predict(t as Tensor) as Tensor;
Argument of type 'Tensor<Rank>' is not assignable to parameter of type 'Tensor<Rank> | Tensor<Rank>[]'.

Copy link
Contributor

@caisq caisq left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Gurpreetsingh9465 I'm going to merge this PR now. I'll add some unit tests myself and cc you on the PR.

Thank you for your contribution!

Reviewable status: :shipit: complete! 1 of 1 approvals obtained (waiting on @Gurpreetsingh9465)

@caisq caisq merged commit db9e23a into tensorflow:master Apr 16, 2019
@Gurpreetsingh9465
Copy link
Contributor Author

thank you @caisq 😃

Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants