This repository has been archived by the owner on Oct 17, 2021. It is now read-only.

Add option yieldEvery to ModelFitConfig; add heuristics for auto yielding #274

Merged: 11 commits merged into tensorflow:master on Jul 30, 2018

@caisq caisq commented Jul 30, 2018

Previously, users had to call await tf.nextFrame() in the callbacks of Model.fit() calls
to keep the page responsive during long-running Model.fit() calls.

This PR adds logic to automatically call await tf.nextFrame(). Specifically, a new field is added
to the config object for Model.fit(): yieldEvery.

  • The default value of 'auto' lets TensorFlow.js measure the batch duration during the first
    several batches and calculate how many batches should pass before calling await
    tf.nextFrame() again.
  • The values 'batch' and 'epoch' let TensorFlow.js call await tf.nextFrame() every batch or
    every epoch, respectively.
  • The value 'never' disables the automatic await tf.nextFrame() calls. This is the legacy
    behavior. The user can still call await tf.nextFrame() in custom callbacks.

The 'auto' approach is a tradeoff between the need to ensure page responsiveness and
the need to prevent short-running Model.fit() calls from slowing down too much.
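
To illustrate the 'auto' heuristic described above, here is a minimal sketch, not the code from this PR. The function name batchesPerYield and the constants SKIP_FIRST_BATCHES, MEASURE_BATCHES, and TARGET_MILLIS_BETWEEN_YIELDS (including their values) are assumptions made for illustration; the description only states that batch durations are measured over the first several batches.

```typescript
// Hypothetical constants; the PR description does not state actual values.
const SKIP_FIRST_BATCHES = 1;             // Assumed warm-up batches to ignore.
const MEASURE_BATCHES = 2;                // Assumed batches used for timing.
const TARGET_MILLIS_BETWEEN_YIELDS = 16;  // Assumed budget between yields.

/**
 * Given the durations (in ms) of the batches seen so far, returns how many
 * batches should pass between yields, or null while still measuring.
 */
function batchesPerYield(batchDurationsMillis: number[]): number | null {
  const needed = SKIP_FIRST_BATCHES + MEASURE_BATCHES;
  if (batchDurationsMillis.length < needed) {
    return null;  // Not enough measurements yet.
  }
  const measured = batchDurationsMillis.slice(SKIP_FIRST_BATCHES, needed);
  const mean = measured.reduce((a, b) => a + b, 0) / measured.length;
  // Guard against a zero mean, and never return fewer than one batch.
  const safeMean = Math.max(mean, 0.001);
  return Math.max(1, Math.round(TARGET_MILLIS_BETWEEN_YIELDS / safeMean));
}
```

With these assumed values, fast batches (a few milliseconds each) yield only every few batches, while slow batches yield after every batch. In application code the option itself is passed in the fit config, for example model.fit(xs, ys, {epochs: 10, yieldEvery: 'auto'}).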

FEATURE


@caisq caisq requested review from nsthorat and dsmilkov July 30, 2018 14:46

@dsmilkov dsmilkov left a comment

Reviewed 4 of 4 files at r1.
Reviewable status: 0 of 1 approvals obtained (waiting on @caisq, @nsthorat, and @dsmilkov)


src/engine/training.ts, line 589 at r1 (raw file):

   * manner.
   *
   * - The value can also be a string from the closed set of:

remove "can also" and just say is a string of the following (i'm also not a big fan of closed set, very mathy :))


src/engine/training_test.ts, line 16 at r1 (raw file):

// tslint:disable:max-line-length
import * as tfc from '@tensorflow/tfjs-core';

unused import?


src/engine/training_test.ts, line 1690 at r1 (raw file):

    const history = await model.fit(xs, ys, {epochs, batchSize: numExamples});
    expect(history.history.loss.length).toEqual(epochs);
    // There are 20 batches in total. The first 3 batch are for

optional: I wonder if you can make this test a bit less dependent on internal details - otherwise changing some of the internal constants/heuristics will make this test fail. On the other hand, I do like that it gives us confidence in the logic.

@caisq caisq left a comment

Reviewable status: :shipit: complete! 1 of 1 approvals obtained (waiting on @caisq and @nsthorat)


src/engine/training.ts, line 589 at r1 (raw file):

Previously, dsmilkov (Daniel Smilkov) wrote…

remove "can also" and just say is a string of the following (i'm also not a big fan of closed set, very mathy :))

Removed "also". Thanks for catching this.


src/engine/training_test.ts, line 16 at r1 (raw file):

Previously, dsmilkov (Daniel Smilkov) wrote…

unused import?

It is used in the line spyOn(tfc, 'nextFrame').and.callFake(async () => {....


src/engine/training_test.ts, line 1690 at r1 (raw file):

Previously, dsmilkov (Daniel Smilkov) wrote…

optional: I wonder if you can make this test a bit less dependent on internal details - otherwise changing some of the internal constants/heuristics will make this test fail. On the other hand, I do like that it gives us confidence in the logic.

OK. I exposed some of the properties of the ModelTrainingYielder as public static properties and use them in the tests here. This reduces the likelihood of future changes in the internals of that class causing test failures.
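
As a rough sketch of that idea (not the actual ModelTrainingYielder), a test can derive its expected count from public static properties instead of hard-coding it. The class YielderSketch, its property names and values, and the helpers simulate and expectedYields are all hypothetical.

```typescript
class YielderSketch {
  // Hypothetical tuning constants, public so tests can reference them.
  static readonly MEASURE_BATCHES = 3;
  static readonly THRESHOLD_MILLIS = 16;

  private seen = 0;
  private every: number | null = null;
  yields = 0;

  // Called once per batch with that batch's duration in milliseconds.
  onBatchEnd(batchMillis: number): void {
    this.seen++;
    if (this.seen < YielderSketch.MEASURE_BATCHES) {
      return;  // Still measuring; no yields yet.
    }
    if (this.seen === YielderSketch.MEASURE_BATCHES) {
      // Decide the yield interval at the end of the measurement phase.
      this.every = Math.max(
          1, Math.round(YielderSketch.THRESHOLD_MILLIS / batchMillis));
      return;
    }
    const sinceDecision = this.seen - YielderSketch.MEASURE_BATCHES;
    if (this.every !== null && sinceDecision % this.every === 0) {
      this.yields++;  // A real implementation would await nextFrame() here.
    }
  }
}

// Runs a simulated training loop and returns how many yields occurred.
function simulate(totalBatches: number, batchMillis: number): number {
  const yielder = new YielderSketch();
  for (let i = 0; i < totalBatches; i++) {
    yielder.onBatchEnd(batchMillis);
  }
  return yielder.yields;
}

// Expected count computed from the exposed constants, not magic numbers.
function expectedYields(totalBatches: number, batchMillis: number): number {
  const every = Math.max(
      1, Math.round(YielderSketch.THRESHOLD_MILLIS / batchMillis));
  const remaining = Math.max(0, totalBatches - YielderSketch.MEASURE_BATCHES);
  return Math.floor(remaining / every);
}
```

Because the expectation is computed from the same static properties the class uses, changing a constant's value updates both sides of the comparison.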

@nsthorat nsthorat left a comment

Reviewed 1 of 4 files at r1, 3 of 3 files at r2.
Reviewable status: :shipit: complete! 1 of 1 approvals obtained (waiting on @dsmilkov, @caisq, and @nsthorat)


src/base_callbacks.ts, line 13 at r2 (raw file):

/* Original source: keras/callbacks.py */

// tslint:disable:max-line-length

for a follow up CL, you can add what daniel just added to automatically disable max line length for imports with this:

https://github.com/tensorflow/tfjs-core/blob/master/tslint.json#L29


src/base_callbacks.ts, line 246 at r2 (raw file):

    this.batchDurationsMillis = [];
    this.autoYieldEveryBatches = null;
    this.batchStartMillis = Date.now();

util.now()


src/base_callbacks.ts, line 253 at r2 (raw file):

   *
   * This causes a data download (e.g., from GPU) and therefore clears the
   * queued operations.

what do you mean by "clears the queued operations"?


src/base_callbacks.ts, line 290 at r2 (raw file):

        // the measurement phase.
        await this.resolveOneTensorInLogs(logs);
        const t = Date.now();

use util.now() not Date.now() here and elsewhere


src/engine/training.ts, line 589 at r1 (raw file):

Previously, caisq (Shanqing Cai) wrote…

Removed "also". Thanks for catching this.

I still see "closed set"; I agree that the terminology is confusing.


src/engine/training_test.ts, line 1676 at r2 (raw file):

    let counter = 0;
    spyOn(Date, 'now').and.callFake(() => presetBatchTimestamps[counter++]);
    let nextFrameCallCount = 0;

instead of keeping your own, you can simply spyOn(tfc, 'nextFrame').and.callThrough();

then later

tfc.nextFrame.calls.count()

@caisq caisq left a comment

Reviewable status: :shipit: complete! 2 of 1 approvals obtained (waiting on @dsmilkov and @caisq)


src/base_callbacks.ts, line 13 at r2 (raw file):

Previously, nsthorat (Nikhil Thorat) wrote…

for a follow up CL, you can add what daniel just added to automatically disable max line length for imports with this:

https://github.com/tensorflow/tfjs-core/blob/master/tslint.json#L29

Ack. Will do. Tracking it with GitHub issue: tensorflow/tfjs#561


src/base_callbacks.ts, line 246 at r2 (raw file):

Previously, nsthorat (Nikhil Thorat) wrote…

util.now()

Done.


src/base_callbacks.ts, line 253 at r2 (raw file):

Previously, nsthorat (Nikhil Thorat) wrote…

what do you mean by "clears the queued operations"?

By that I mean the queued operations on the GPU. Added some words to clarify that.


src/base_callbacks.ts, line 290 at r2 (raw file):

Previously, nsthorat (Nikhil Thorat) wrote…

use util.now() not Date.now() here and elsewhere

Done.


src/engine/training.ts, line 589 at r1 (raw file):

Previously, nsthorat (Nikhil Thorat) wrote…

I still see "closed set", I agree that that terminology is confusing.

Done.


src/engine/training_test.ts, line 1676 at r2 (raw file):

Previously, nsthorat (Nikhil Thorat) wrote…

instead of keeping your own, you can simply spyOn(tfc, 'nextFrame').and.callThrough();

then later

tfc.nextFrame.calls.count()

The current approach is a little faster because it never actually calls nextFrame(), which saves about 16 ms per call.
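
The difference can be sketched without Jasmine. This is a hand-rolled stand-in for spyOn(...).and.callFake(...), not the test code from this PR; the names NextFrame, makeCountingFake, and runBatches are hypothetical.

```typescript
type NextFrame = () => Promise<void>;

// A fake that resolves immediately and records how often it was called,
// so a test never waits for a real animation frame.
function makeCountingFake(): {nextFrame: NextFrame; count: () => number} {
  let calls = 0;
  return {
    nextFrame: async () => {
      calls++;
    },
    count: () => calls,
  };
}

// Hypothetical stand-in for a training loop that yields after every batch.
async function runBatches(
    numBatches: number, nextFrame: NextFrame): Promise<void> {
  for (let i = 0; i < numBatches; i++) {
    await nextFrame();
  }
}
```

A callThrough-style spy would give the same call count but would also wait for a real frame on every call, which is the per-call cost mentioned above.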

@caisq caisq closed this Jul 30, 2018
@caisq caisq reopened this Jul 30, 2018
@caisq caisq merged commit 1a9bd2e into tensorflow:master Jul 30, 2018