Skip to content
This repository was archived by the owner on Oct 17, 2021. It is now read-only.

adds truncatedNormal as a valid distribution identifier #510

Merged
merged 2 commits into from
Mar 28, 2019

Conversation

bileschi
Copy link
Contributor

@bileschi bileschi commented Mar 27, 2019

Expands the spec to allow for an additional string identifier here.

fixes tensorflow/tfjs#1460
"Newly added truncated_normal initializer is not supported in tfjs-layers"


This change is Reviewable

@bileschi bileschi requested review from davidsoergel and caisq March 27, 2019 22:30
Copy link
Member

@davidsoergel davidsoergel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reviewable status: :shipit: complete! 1 of 1 approvals obtained (waiting on @bileschi, @caisq, and @davidsoergel)


src/initializers_test.ts, line 425 at r1 (raw file):

  });

  ['uniform', 'normal', 'truncatedNormal'].forEach(distribution => {

Worth thinking about whether we should purposely duplicate the string constants here, or just use VALID_DISTRIBUTION_VALUES from keras_format (which per my other suggestion might end up renamed to distributionOptions, but anyway).


src/keras_format/initializer_config.ts, line 27 at r1 (raw file):

    ['normal', 'uniform', 'truncatedNormal'];
// These constants have a snake vs. camel distinction.
export type DistributionSerialization = 'normal'|'uniform'|'truncated_normal';

See activation_config.ts for a different mechanism for this. Sorry for the historical inconsistency, but I think the stringLiteralArray thing is better going forward. You could update both FanMode and Distribution to that approach right away, or submit this as is and then I or anyone can do that cleanup in a later PR.

Copy link
Contributor

@caisq caisq left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR, @bileschi !

Can you add a unit test and see if the following model JSON can be loaded, with the fix here?

const obj = JSON.parse('{"modelTopology":{"class_name":"Sequential","config":[{"class_name":"Dense","config":{"units":1,"activation":"linear","use_bias":true,"kernel_initializer":{"class_name":"VarianceScaling","config":{"scale":1,"mode":"fan_avg","distribution":"truncated_normal","seed":null}},"bias_initializer":{"class_name":"Zeros","config":{}},"kernel_regularizer":null,"bias_regularizer":null,"activity_regularizer":null,"kernel_constraint":null,"bias_constraint":null,"name":"dense_Dense1","trainable":true,"batch_input_shape":[null,1],"dtype":"float32"}}],"keras_version":"tfjs-layers 1.0.2","backend":"tensor_flow.js"},"weightsManifest":[{"paths":["weights.bin"],"weights":[{"name":"dense_Dense1/kernel","shape":[1,1],"dtype":"float32"},{"name":"dense_Dense1/bias","shape":[1],"dtype":"float32"}]}]}');
  const model = tf.models.modelFromJSON(obj);

Reviewable status: :shipit: complete! 1 of 1 approvals obtained (waiting on @bileschi and @caisq)

Copy link
Contributor Author

@bileschi bileschi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reviewable status: :shipit: complete! 1 of 1 approvals obtained (waiting on @davidsoergel)


src/initializers_test.ts, line 425 at r1 (raw file):

Previously, davidsoergel (David Soergel) wrote…

Worth thinking about whether we should purposely duplicate the string constants here, or just use VALID_DISTRIBUTION_VALUES from keras_format (which per my other suggestion might end up renamed to distributionOptions, but anyway).

Done.


src/keras_format/initializer_config.ts, line 27 at r1 (raw file):

Previously, davidsoergel (David Soergel) wrote…

See activation_config.ts for a different mechanism for this. Sorry for the historical inconsistency, but I think the stringLiteralArray thing is better going forward. You could update both FanMode and Distribution to that approach right away, or submit this as is and then I or anyone can do that cleanup in a later PR.

Let me submit this as it is now, since I know it works. Then I'll try to straighten this out in a followon.

@bileschi bileschi merged commit fdc82d4 into master Mar 28, 2019
@bileschi bileschi deleted the truncatedNormal branch March 28, 2019 22:52
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Newly added truncated_normal initializer is not supported in tfjs-layers
3 participants