@caisq
Contributor

@caisq caisq commented Aug 2, 2020

Fixes #3705

Previously, the tensorBoard callback in tfjs-node did not honor
the initialEpoch arg passed to the fit() call that uses the
callback. It always incorrectly started from 0.
This CL fixes the bug by using the epoch arg passed to onEpochEnd()
instead of an epochsSeen counter maintained by the callback object
itself.
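
A minimal sketch of the difference, not the actual tfjs-node code. The two classes are hypothetical stand-ins for the callback, and `runFit` is an illustrative helper that mimics how `fit()` hands the epoch index to `onEpochEnd()` starting at `initialEpoch`. Logging `epoch + 1` in the second class is an assumption made so the two step sequences are directly comparable.

```typescript
// Hypothetical stand-ins for the TensorBoard callback. Only the step
// bookkeeping is modeled; no summaries are written.
class CounterBasedCallback {
  // Old behavior: a private counter that always starts at 0.
  private epochsSeen = 0;
  readonly loggedSteps: number[] = [];

  onEpochEnd(_epoch: number): void {
    this.epochsSeen++;
    this.loggedSteps.push(this.epochsSeen);
  }
}

class EpochArgCallback {
  readonly loggedSteps: number[] = [];

  onEpochEnd(epoch: number): void {
    // New behavior: derive the step from the epoch index passed in.
    // The "+ 1" offset is an assumption of this sketch.
    this.loggedSteps.push(epoch + 1);
  }
}

// Illustrative helper: calls onEpochEnd() for epochs
// initialEpoch, ..., epochs - 1, as a fit() loop would.
function runFit(
    cb: {onEpochEnd(epoch: number): void}, initialEpoch: number,
    epochs: number): void {
  for (let epoch = initialEpoch; epoch < epochs; epoch++) {
    cb.onEpochEnd(epoch);
  }
}

const oldCb = new CounterBasedCallback();
const newCb = new EpochArgCallback();
runFit(oldCb, 3, 6);
runFit(newCb, 3, 6);
console.log(oldCb.loggedSteps);  // [ 1, 2, 3 ]
console.log(newCb.loggedSteps);  // [ 4, 5, 6 ]
```

With `initialEpoch` set to 3, the counter-based version restarts its steps at 1, while the version that reads the `epoch` argument continues from where the earlier training left off.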


@caisq caisq requested a review from lina128 August 2, 2020 21:23
Collaborator

@pyu10055 pyu10055 left a comment

Reviewable status: 0 of 1 approvals obtained (waiting on @caisq and @lina128)


tfjs-node/src/callbacks.ts, line 215 at r1 (raw file):

        this.batchesSeen++;
        if (this.args.updateFreq !== 'epoch') {
          this.logMetrics(logs, 'batch_', this.batchesSeen);

should the batch number be updated the same way as epoch?

Contributor Author

@caisq caisq left a comment

Reviewable status: 0 of 1 approvals obtained (waiting on @lina128 and @pyu10055)


tfjs-node/src/callbacks.ts, line 215 at r1 (raw file):

Previously, pyu10055 (Ping Yu) wrote…

should the batch number be updated the same way as epoch?

Good question. In tf.keras in Python, the batch numbers logged by the TensorBoard callback with update_freq='batch' don't reflect the initial_epoch arg, even though the epoch numbers do. For instance, I tested the following code:

import numpy as np
import tensorflow as tf

model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(units=1, input_shape=(4,)))
model.compile(loss="mse", optimizer="sgd")

xs = np.ones([8000, 4])
ys = np.zeros([8000, 1])

model.fit(xs, ys, epochs=3)

callback = tf.keras.callbacks.TensorBoard(
    "/tmp/initial_epochs_logdir", update_freq="batch")
model.fit(xs,
          ys,
          batch_size=40,
          epochs=6,
          initial_epoch=3,
          callbacks=[callback])

Here the tensorboard scalar log "batch_loss" starts at step 1, instead of a larger number that reflects the batches that have already happened in the previous (first) model.fit() call. Therefore the behavior in tfjs-node code here is correct: it keeps track of the batch number by itself.
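
The bookkeeping described above can be sketched roughly as follows. `StepRecorder` and `simulateFit` are hypothetical names used only for illustration; the recorder stores step numbers instead of writing TensorBoard summaries, and it records the epoch index exactly as passed, without modeling any offset the real callbacks may apply.

```typescript
// Hypothetical model of the step bookkeeping: batch steps come from a
// counter owned by the callback, epoch steps from the argument.
class StepRecorder {
  private batchesSeen = 0;
  readonly batchSteps: number[] = [];
  readonly epochIndices: number[] = [];

  onBatchEnd(): void {
    // Self-maintained counter: unaware of any earlier fit() call.
    this.batchesSeen++;
    this.batchSteps.push(this.batchesSeen);
  }

  onEpochEnd(epoch: number): void {
    // Epoch index supplied by the training loop, so it reflects
    // the initial epoch.
    this.epochIndices.push(epoch);
  }
}

// Illustrative stand-in for a fit() loop.
function simulateFit(
    recorder: StepRecorder, numExamples: number, batchSize: number,
    epochs: number, initialEpoch: number): void {
  const batchesPerEpoch = Math.ceil(numExamples / batchSize);
  for (let epoch = initialEpoch; epoch < epochs; epoch++) {
    for (let b = 0; b < batchesPerEpoch; b++) {
      recorder.onBatchEnd();
    }
    recorder.onEpochEnd(epoch);
  }
}

// Same numbers as the second model.fit() call above.
const recorder = new StepRecorder();
simulateFit(recorder, 8000, 40, 6, 3);
console.log(recorder.batchSteps[0]);      // 1
console.log(recorder.batchSteps.length);  // 600
console.log(recorder.epochIndices);       // [ 3, 4, 5 ]
```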

Collaborator

@pyu10055 pyu10055 left a comment

Reviewed 2 of 2 files at r1.
Reviewable status: :shipit: complete! 1 of 1 approvals obtained (waiting on @lina128 and @pyu10055)

@caisq
Contributor Author

caisq commented Aug 3, 2020

@pyu10055 @lina128 Any idea why the CI of tfjs-node is failing? It seems to be unrelated to this PR.

Collaborator

@pyu10055 pyu10055 left a comment

The error seems to be related to a lint error in the test file:

ERROR: (no-any) /workspace/tfjs-node/src/tensorboard_test.ts[309, 18]: Type declaration of 'any' loses type-safety. Consider replacing it with a more precise type.
ERROR: (no-any) /workspace/tfjs-node/src/tensorboard_test.ts[310, 18]: Type declaration of 'any' loses type-safety. Consider replacing it with a more precise type.
ERROR: (no-any) /workspace/tfjs-node/src/tensorboard_test.ts[311, 53]: Type declaration of 'any' loses type-safety. Consider replacing it with a more precise type.
ERROR: (no-any) /workspace/tfjs-node/src/tensorboard_test.ts[312, 51]: Type declaration of 'any' loses type-safety. Consider replacing it with a more precise type.

Reviewable status: :shipit: complete! 1 of 1 approvals obtained (waiting on @lina128 and @pyu10055)

@lina128
Collaborator

lina128 commented Aug 3, 2020

Hi @caisq, it complains about the newly added test; maybe skip lint?
ERROR: (no-any) /workspace/tfjs-node/src/tensorboard_test.ts[309, 18]: Type declaration of 'any' loses type-safety. Consider replacing it with a more precise type.
ERROR: (no-any) /workspace/tfjs-node/src/tensorboard_test.ts[310, 18]: Type declaration of 'any' loses type-safety. Consider replacing it with a more precise type.
ERROR: (no-any) /workspace/tfjs-node/src/tensorboard_test.ts[311, 53]: Type declaration of 'any' loses type-safety. Consider replacing it with a more precise type.
ERROR: (no-any) /workspace/tfjs-node/src/tensorboard_test.ts[312, 51]: Type declaration of 'any' loses type-safety. Consider replacing it with a more precise type.
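
One general alternative to skipping the lint check is to replace `any` with a declared type plus a runtime guard. The flagged test lines are not shown in this thread, so the sketch below does not reproduce them; `ScalarRecord`, `isScalarRecord`, and `parseRecords` are hypothetical names, and the record shape is an assumption chosen only to illustrate the pattern.

```typescript
// Hypothetical shape of a scalar record read back from a log.
interface ScalarRecord {
  tag: string;
  step: number;
  value: number;
}

// Type guard: narrows an unknown value to ScalarRecord without
// declaring anything as 'any'.
function isScalarRecord(x: unknown): x is ScalarRecord {
  if (typeof x !== 'object' || x === null) {
    return false;
  }
  const r = x as Record<string, unknown>;
  return typeof r['tag'] === 'string' && typeof r['step'] === 'number' &&
      typeof r['value'] === 'number';
}

// Parses a JSON array and keeps only well-formed records.
function parseRecords(json: string): ScalarRecord[] {
  const parsed: unknown = JSON.parse(json);
  if (!Array.isArray(parsed)) {
    throw new Error('Expected a JSON array of records.');
  }
  return parsed.filter(isScalarRecord);
}

const records = parseRecords(
    '[{"tag": "epoch_loss", "step": 4, "value": 0.5}, {"bad": true}]');
console.log(records.length);   // 1
console.log(records[0].step);  // 4
```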

@caisq
Contributor Author

caisq commented Aug 3, 2020

@lina128 @pyu10055 Ah - good find. Sorry I didn't notice that.

@caisq caisq self-assigned this Aug 3, 2020
@caisq caisq merged commit 9b8c963 into tensorflow:master Aug 3, 2020
Development

Successfully merging this pull request may close these issues.

Tensorboard logging does not honor initialEpoch

4 participants