This repository has been archived by the owner on Aug 15, 2019. It is now read-only.
/
tracking.ts
176 lines (171 loc) · 6.2 KB
/
tracking.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
/**
* @license
* Copyright 2018 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import {doc} from './doc';
import {ScopeFn, TimingInfo} from './engine';
import {ENV} from './environment';
import {Tensor} from './tensor';
import {TensorContainer} from './types';
import {extractTensorsFromAny} from './util';
/**
 * Static entry points for scoped tensor memory management (`tidy`, `keep`,
 * `dispose`) and for timing a function (`time`).
 */
export class Tracking {
  /**
   * Executes the provided function `fn` and, after it is executed, cleans up
   * all intermediate tensors allocated by `fn` except those returned by `fn`.
   * `fn` must not return a Promise (async functions not allowed).
   * The returned result can be a complex object, however tidy only walks the
   * top-level properties (depth 1) of that object to search for tensors, or
   * lists of tensors that need to be tracked in the parent scope.
   *
   * Using this method helps avoid memory leaks. In general, wrap calls to
   * operations in `tidy` for automatic memory cleanup.
   *
   * When in safe mode, you must enclose all `Tensor` creation and ops
   * inside a `tidy` to prevent memory leaks.
   *
   * If `fn` throws, the scope opened by this call is still closed (via
   * `endScope`) before the error is rethrown, so every `startScope` made here
   * is paired with an `endScope`.
   *
   * ```js
   * // y = 2 ^ 2 + 1
   * const y = tf.tidy(() => {
   *   // a, b, and one will be cleaned up when the tidy ends.
   *   const one = tf.scalar(1);
   *   const a = tf.scalar(2);
   *   const b = a.square();
   *
   *   console.log('numTensors (in tidy): ' + tf.memory().numTensors);
   *
   *   // The value returned inside the tidy function will return
   *   // through the tidy, in this case to the variable y.
   *   return b.add(one);
   * });
   *
   * console.log('numTensors (outside tidy): ' + tf.memory().numTensors);
   * y.print();
   * ```
   *
   * @param nameOrFn The name of the closure, or the function to execute.
   *     If a name is provided, the 2nd argument should be the function.
   *     The name is forwarded to the engine when the scope is opened.
   *     NOTE(review): logging/profiling by name is still a TODO in this
   *     method; confirm whether the engine reports timing/memory per name
   *     in debug mode before relying on it.
   * @param fn The function to execute.
   * @param gradMode If true, starts a tape and doesn't dispose tensors.
   * @returns Whatever `fn` returns.
   * @throws Error if the arguments match neither call form; otherwise
   *     rethrows anything thrown by `fn`.
   */
  @doc({heading: 'Performance', subheading: 'Memory'})
  static tidy<T extends TensorContainer>(
      nameOrFn: string|ScopeFn<T>, fn?: ScopeFn<T>, gradMode = false): T {
    let name = null;
    if (fn == null) {
      // Called with only 1 argument: the function itself.
      if (typeof nameOrFn !== 'function') {
        throw new Error('Please provide a function to tidy()');
      }
      fn = nameOrFn;
    } else {
      // Called with 2 arguments: a name followed by the function.
      if (typeof nameOrFn !== 'string' && !(nameOrFn instanceof String)) {
        throw new Error(
            'When calling with two arguments, the first argument ' +
            'to tidy() must be a string');
      }
      if (typeof fn !== 'function') {
        throw new Error(
            'When calling with two arguments, the 2nd argument ' +
            'to tidy() must be a function');
      }
      // String() yields a primitive even when a boxed `String` object was
      // passed, rather than merely re-labelling its type.
      name = String(nameOrFn);
      // TODO(nsthorat,smilkov): Do operation logging and performance
      // profiling.
    }

    ENV.engine.startScope(name, gradMode);
    let result: T;
    try {
      result = fn();
    } catch (ex) {
      // Close the scope opened above before propagating the error, so a
      // throwing `fn` cannot leave this startScope unpaired. `fn` produced
      // no value, so nothing is handed to endScope to keep for the parent.
      ENV.engine.endScope(undefined, gradMode);
      throw ex;
    }
    if (result instanceof Promise) {
      console.error('Cannot return a Promise inside of tidy.');
    }
    ENV.engine.endScope(result, gradMode);
    return result;
  }

  /**
   * Disposes any `Tensor`s found within the provided object up to depth 1.
   *
   * @param container an object that may be a `Tensor` or may directly contain
   *     `Tensor`s, such as a `Tensor[]` or `{key: Tensor, ...}`. If the
   *     object is not a `Tensor` or does not contain `Tensors`, nothing
   *     happens. In general it is safe to pass any object here, except that
   *     `Promise`s are not supported.
   */
  @doc({heading: 'Performance', subheading: 'Memory'})
  // tslint:disable-next-line:no-any
  static dispose(container: any) {
    const tensors = extractTensorsFromAny(container);
    tensors.forEach(tensor => tensor.dispose());
  }

  /**
   * Keeps a `Tensor` generated inside a `tidy` from being disposed
   * automatically.
   *
   * ```js
   * let b;
   * const y = tf.tidy(() => {
   *   const one = tf.scalar(1);
   *   const a = tf.scalar(2);
   *
   *   // b will not be cleaned up by the tidy. a and one will be cleaned up
   *   // when the tidy ends.
   *   b = tf.keep(a.square());
   *
   *   console.log('numTensors (in tidy): ' + tf.memory().numTensors);
   *
   *   // The value returned inside the tidy function will return
   *   // through the tidy, in this case to the variable y.
   *   return b.add(one);
   * });
   *
   * console.log('numTensors (outside tidy): ' + tf.memory().numTensors);
   * console.log('y:');
   * y.print();
   * console.log('b:');
   * b.print();
   * ```
   *
   * @param result The tensor to keep from being disposed.
   * @returns The value returned by the engine's `keep` for `result`.
   */
  @doc({heading: 'Performance', subheading: 'Memory'})
  static keep<T extends Tensor>(result: T): T {
    return ENV.engine.keep(result);
  }

  /**
   * Executes `f()` and returns a promise that resolves with timing
   * information.
   *
   * The result is an object with the following properties:
   *
   * - `wallMs`: Wall execution time.
   * - `kernelMs`: Kernel execution time, ignoring data transfer.
   * - On `WebGL`, the following additional properties exist:
   *   - `uploadWaitMs`: CPU blocking time on texture uploads.
   *   - `downloadWaitMs`: CPU blocking time on texture downloads (readPixels).
   *
   * ```js
   * const x = tf.randomNormal([20, 20]);
   * const time = await tf.time(() => x.matMul(x));
   *
   * console.log(`kernelMs: ${time.kernelMs}, wallTimeMs: ${time.wallMs}`);
   * ```
   *
   * @param f The function to execute and time.
   * @returns The engine's timing promise for `f`.
   */
  @doc({heading: 'Performance', subheading: 'Timing'})
  static time(f: () => void): Promise<TimingInfo> {
    return ENV.engine.time(f);
  }
}