-
Notifications
You must be signed in to change notification settings - Fork 95
Add option yieldEvery to ModelFitConfig; add heuristics for auto yielding #274
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reviewed 4 of 4 files at r1.
Reviewable status: 0 of 1 approvals obtained (waiting on @caisq, @nsthorat, and @dsmilkov)
src/engine/training.ts, line 589 at r1 (raw file):
* manner. * * - The value can also be a string from the closed set of:
remove "can also" and just say is a string of the following
(i'm also not a big fan of closed set
, very mathy :))
src/engine/training_test.ts, line 16 at r1 (raw file):
// tslint:disable:max-line-length import * as tfc from '@tensorflow/tfjs-core';
unused import?
src/engine/training_test.ts, line 1690 at r1 (raw file):
const history = await model.fit(xs, ys, {epochs, batchSize: numExamples}); expect(history.history.loss.length).toEqual(epochs); // There are 20 batches in total. The first 3 batch are for
optional: I wonder if you can make this test a bit less dependent on internal details - otherwise changing some of the internal constants/heuristics will make this test fail. On the other hand, I do like that it gives us confidence in the logic.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reviewable status: complete! 1 of 1 approvals obtained (waiting on @caisq and @nsthorat)
src/engine/training.ts, line 589 at r1 (raw file):
Previously, dsmilkov (Daniel Smilkov) wrote…
remove "can also" and just say
is a string of the following
(i'm also not a big fan ofclosed set
, very mathy :))
Removed "also". Thanks for catching this.
src/engine/training_test.ts, line 16 at r1 (raw file):
Previously, dsmilkov (Daniel Smilkov) wrote…
unused import?
It is used in the line spyOn(tfc, 'nextFrame').and.callFake(async () => {...
.
src/engine/training_test.ts, line 1690 at r1 (raw file):
Previously, dsmilkov (Daniel Smilkov) wrote…
optional: I wonder if you can make this test a bit less dependent on internal details - otherwise changing some of the internal constants/heuristics will make this test fail. On the other hand, I do like that it gives us confidence in the logic.
OK. I exposed some of the properties of the ModelTrainingYielder
as public static properties and use them in the tests here. This reduces the likelihood of future changes in the internals of that class causing test failures.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reviewed 1 of 4 files at r1, 3 of 3 files at r2.
Reviewable status: complete! 1 of 1 approvals obtained (waiting on @dsmilkov, @caisq, and @nsthorat)
src/base_callbacks.ts, line 13 at r2 (raw file):
/* Original source: keras/callbacks.py */ // tslint:disable:max-line-length
for a follow up CL, you can add what daniel just added to automatically disable max line length for imports with this:
https://github.com/tensorflow/tfjs-core/blob/master/tslint.json#L29
src/base_callbacks.ts, line 246 at r2 (raw file):
this.batchDurationsMillis = []; this.autoYieldEveryBatches = null; this.batchStartMillis = Date.now();
util.now()
src/base_callbacks.ts, line 253 at r2 (raw file):
* * This causes a data download (e.g., from GPU) and therefore clears the * queued operations.
what do you mean by "clears the queued operations"?
src/base_callbacks.ts, line 290 at r2 (raw file):
// the measurement phase. await this.resolveOneTensorInLogs(logs); const t = Date.now();
use util.now() not Date.now() here and elsewhere
src/engine/training.ts, line 589 at r1 (raw file):
Previously, caisq (Shanqing Cai) wrote…
Removed "also". Thanks for catching this.
I still see "closed set", I agree that that terminology is confusing.
src/engine/training_test.ts, line 1676 at r2 (raw file):
let counter = 0; spyOn(Date, 'now').and.callFake(() => presetBatchTimestamps[counter++]); let nextFrameCallCount = 0;
instead of keeping your own, you can simply spyOn(tfc, 'nextFrame').and.callThrough();
then later
tfc.nextFrame.calls.count()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reviewable status: complete! 2 of 1 approvals obtained (waiting on @dsmilkov and @caisq)
src/base_callbacks.ts, line 13 at r2 (raw file):
Previously, nsthorat (Nikhil Thorat) wrote…
for a follow up CL, you can add what daniel just added to automatically disable max line length for imports with this:
https://github.com/tensorflow/tfjs-core/blob/master/tslint.json#L29
Ack. Will do. Tracking it with GitHub issue: tensorflow/tfjs#561
src/base_callbacks.ts, line 246 at r2 (raw file):
Previously, nsthorat (Nikhil Thorat) wrote…
util.now()
Done.
src/base_callbacks.ts, line 253 at r2 (raw file):
Previously, nsthorat (Nikhil Thorat) wrote…
what do you mean by "clears the queued operations"?
By that I mean the queued operations on the GPU. Added some words to clarify that.
src/base_callbacks.ts, line 290 at r2 (raw file):
Previously, nsthorat (Nikhil Thorat) wrote…
use util.now() not Date.now() here and elsewhere
Done.
src/engine/training.ts, line 589 at r1 (raw file):
Previously, nsthorat (Nikhil Thorat) wrote…
I still see "closed set", I agree that that terminology is confusing.
Done.
src/engine/training_test.ts, line 1676 at r2 (raw file):
Previously, nsthorat (Nikhil Thorat) wrote…
instead of keeping your own, you can simply spyOn(tfc, 'nextFrame').and.callThrough();
then later
tfc.nextFrame.calls.count()
The current approach is a little faster, because it doesn't actually call nextFrame()
, which saves 16 ms for each call.
Previously, users have to call await tf.nextFrame() in the callbacks of Model.fit() calls in order This PR adds logic to automatically call await tf.nextFrame(). Specifically, a new field is added The default value of 'auto' lets TensorFlow.js measure the batch duration during the first FEATURE |
Previously, users have to call await tf.nextFrame() in the callbacks of Model.fit() calls in order
to ensure page responsiveness during long-running Model.fit() calls.
This PR adds logic to automatically call await tf.nextFrame(). Specifically, a new field is added
to the config object for Model.fit():
yieldEvery
.several batches and calculate how many batches should pass before calling await
tf.nextFrame() again.
every epoch, respectively.
can still call await tf.nextFrame() in custom callbacks.
The 'auto' approach is a tradeoff between the need to ensure page responsiveness and
the need to prevent short-running Model.fit() calls from slowing down too much.
FEATURE
This change is