This repository was archived by the owner on Oct 17, 2021. It is now read-only.

Conversation

caisq
Contributor

@caisq caisq commented Nov 30, 2018

  • The training algorithm for BatchNormalization is implemented differently in keras-team/keras and tensorflow.keras.
  • This PR aligns the implementation with the latter. (Previously it was aligned with the former.) Rationale:
    • Alignment with TensorFlow is a higher priority for TensorFlow.js than alignment with keras-team/keras.
    • The convergence on the MNIST-ACGAN example is better with the tf.keras implementation.
  • Although this change does not affect the public API of tf.layers.batchNormalization, it may be breaking for users who rely on the exact numeric values of the batchNormalization layer during training.

BREAKING


@caisq caisq requested a review from bileschi November 30, 2018 19:29
Contributor

@bileschi bileschi left a comment


Reviewable status: 0 of 1 approvals obtained (waiting on @caisq and @bileschi)


src/layers/normalization.ts, line 20 at r1 (raw file):

import {Constraint, ConstraintIdentifier, getConstraint, serializeConstraint} from '../constraints';
import {InputSpec, Layer, LayerConfig} from '../engine/topology';
import {getScalar} from '../backend/state';

Why did this move?


@nsthorat nsthorat left a comment


Reviewable status: :shipit: complete! 1 of 1 approvals obtained (waiting on @caisq)


src/layers/normalization.ts, line 398 at r1 (raw file):

          (variable: LayerVariable, value: Tensor, momentum: number): void => {
            tfc.tidy(() => {
              const decay = getScalar(1.0).sub(getScalar(momentum));

FYI you should no longer need this scalar cache. You can also just pass momentum directly to sub (no need for wrapping in a scalar tensor).
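
For readers following this thread, here is a minimal sketch of the kind of moving-average update the snippet above appears to compute, written with plain number arrays instead of tensors so that it runs without any library. The function name `movingAverageUpdate` and the update rule `variable -= (variable - value) * decay` are assumptions for illustration and may not match the code in this PR. The sketch also shows the point of the comment: `decay` can be computed from `momentum` as an ordinary number, with no scalar wrapper.

```typescript
// Hypothetical helper (not from this PR): a plain-number stand-in for the
// tensor-based moving-average update used during batch-normalization training.
function movingAverageUpdate(
    variable: number[], value: number[], momentum: number): number[] {
  if (variable.length !== value.length) {
    throw new Error('variable and value must have the same length');
  }
  // An ordinary number is enough here; nothing needs to be wrapped in a scalar.
  const decay = 1.0 - momentum;
  // Assumed update rule: variable -= (variable - value) * decay.
  return variable.map((v, i) => v - (v - value[i]) * decay);
}

// Example: a moving mean that starts at zero and sees a batch mean of [1, 2].
const movingMean = movingAverageUpdate([0, 0], [1, 2], 0.99);
console.log(movingMean);
```

With a momentum close to 1 the stored statistic moves only slightly toward the new batch value; with a momentum of 0 it is replaced by the batch value outright.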

@bileschi
Contributor

bileschi commented Nov 30, 2018 via email

Contributor Author

@caisq caisq left a comment


Reviewable status: :shipit: complete! 1 of 1 approvals obtained (waiting on @bileschi)


src/layers/normalization.ts, line 20 at r1 (raw file):

Previously, bileschi (Stanley Bileschi) wrote…

Why did this move?

So that the imports are sorted alphabetically by the source file path.


src/layers/normalization.ts, line 398 at r1 (raw file):

Previously, nsthorat (Nikhil Thorat) wrote…

FYI you should no longer need this scalar cache. You can also just pass momentum directly to sub (no need for wrapping in a scalar tensor).

Ack. Thanks. There are a lot of places like this that should be fixed in a batch. Filed an issue to track this.
tensorflow/tfjs#957

@caisq caisq merged commit 78bcd9a into tensorflow:master Nov 30, 2018