Skip to content
32 changes: 26 additions & 6 deletions tfjs-core/src/debug_mode_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,18 +55,38 @@ describeWithFlags('debug on', SYNC_BACKEND_ENVS, () => {
expect(a).toThrowError();
});

it('debug mode errors when infinities in op output', () => {
it('debug mode errors when infinities in op output', async () => {
const a = tf.tensor1d([1, 2, 3, 4]);
const b = tf.tensor1d([2, -1, 0, 3]);
const c = () => a.div(b);
expect(c).toThrowError();

spyOn(console, 'error');

const c = async () => {
const result = a.div(b);
// Must await result so we know exception would have happened by the
// time we call `expect`.
await result.data();
};

await c();

expect(console.error).toHaveBeenCalled();
});

it('debug mode errors when nans in op output', () => {
it('debug mode errors when nans in op output', async () => {
const a = tf.tensor1d([-1, 2]);
const b = tf.tensor1d([0.5, 1]);
const c = () => a.pow(b);
expect(c).toThrowError();

spyOn(console, 'error');

const c = async () => {
const result = a.pow(b);
await result.data();
};

await c();

expect(console.error).toHaveBeenCalled();
});

it('debug mode errors when nans in oneHot op (tensorlike), int32', () => {
Expand Down
4 changes: 2 additions & 2 deletions tfjs-core/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ export {RMSPropOptimizer} from './optimizers/rmsprop_optimizer';
export {SGDOptimizer} from './optimizers/sgd_optimizer';
export {Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorBuffer, variable, Variable} from './tensor';
export {GradSaveFunc, NamedTensorMap, TensorContainer, TensorContainerArray, TensorContainerObject} from './tensor_types';
export {DataType, DataTypeMap, DataValues, Rank, ShapeMap, TensorLike} from './types';
export {DataType, DataTypeMap, DataValues, Rank, RecursiveArray, ShapeMap, TensorLike} from './types';

export * from './ops/ops';
export {LSTMCellFunc} from './ops/lstm';
Expand All @@ -65,7 +65,7 @@ export * from './train';
export * from './globals';
export {customGrad, grad, grads, valueAndGrad, valueAndGrads, variableGrads} from './gradients';

export {TimingInfo} from './engine';
export {TimingInfo, MemoryInfo} from './engine';
export {ENV, Environment} from './environment';
export {Platform} from './platforms/platform';

Expand Down
43 changes: 33 additions & 10 deletions tfjs-core/src/profiler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import {BackendTimer} from './backends/backend';
import {Tensor} from './tensor';
import {NamedTensorMap} from './tensor_types';
import {TypedArray} from './types';
import {DataType, DataTypeMap, TypedArray} from './types';
import * as util from './util';

export class Profiler {
Expand All @@ -39,24 +39,47 @@ export class Profiler {
const results: Tensor[] =
Array.isArray(result) ? result : [result] as Tensor[];
results.forEach(r => {
const vals = r.dataSync();
util.checkComputationForErrors(vals, r.dtype, name);
// Dangling promise here because we don't want to propagate up
// asynchronicity.
r.data().then(vals => {
checkComputationForErrors(vals, r.dtype, name);

timer.then(timing => {
let extraInfo = '';
if (timing.getExtraProfileInfo != null) {
extraInfo = timing.getExtraProfileInfo();
}
timer.then(timing => {
let extraInfo = '';
if (timing.getExtraProfileInfo != null) {
extraInfo = timing.getExtraProfileInfo();
}

this.logger.logKernelProfile(
name, r, vals, timing.kernelMs, inputs, extraInfo);
this.logger.logKernelProfile(
name, r, vals, timing.kernelMs, inputs, extraInfo);
});
});
});

return result as T;
}
}

// Create a custom exception class so it can be stubbed in tests.
export function ProfilerException(msg: string) {
console.error(msg);
}

export function checkComputationForErrors<D extends DataType>(
vals: DataTypeMap[D], dtype: D, name: string): void {
if (dtype !== 'float32') {
// Only floating point computations will generate NaN values
return;
}
for (let i = 0; i < vals.length; i++) {
const num = vals[i] as number;
if (isNaN(num) || !isFinite(num)) {
// Throwing custom exception so behavior is testable.
throw ProfilerException(`The result of the '${name}' is ${num}.`);
}
}
}

export class Logger {
logKernelProfile(
name: string, result: Tensor, vals: TypedArray, timeMs: number,
Expand Down
26 changes: 25 additions & 1 deletion tfjs-core/src/profiler_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import {BackendTimer, BackendTimingInfo} from './backends/backend';
import * as tf from './index';
import {describeWithFlags, SYNC_BACKEND_ENVS} from './jasmine_util';
import {Logger, Profiler} from './profiler';
import {checkComputationForErrors, Logger, Profiler} from './profiler';
import {Tensor} from './tensor';
import {TypedArray} from './types';

Expand Down Expand Up @@ -129,3 +129,27 @@ describeWithFlags('profiler.Profiler', SYNC_BACKEND_ENVS, () => {
}, delayMs * 2);
});
});

describe('profiler.checkComputationForErrors', () => {
it('Float32Array has NaN', () => {
expect(
() => checkComputationForErrors(
new Float32Array([1, 2, 3, NaN, 4, 255]), 'float32', ''))
.toThrow();
});

it('Float32Array has Infinity', () => {
expect(
() => checkComputationForErrors(
new Float32Array([1, 2, 3, Infinity, 4, 255]), 'float32', ''))
.toThrow();
});

it('Float32Array no NaN', () => {
// Int32 and Bool NaNs should not trigger an error.
expect(
() => checkComputationForErrors(
new Float32Array([1, 2, 3, -1, 4, 255]), 'float32', ''))
.not.toThrow();
});
});
14 changes: 0 additions & 14 deletions tfjs-core/src/util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -405,20 +405,6 @@ export function getArrayFromDType<D extends DataType>(
return values as DataTypeMap[D];
}

export function checkComputationForErrors<D extends DataType>(
vals: DataTypeMap[D], dtype: D, name: string): void {
if (dtype !== 'float32') {
// Only floating point computations will generate NaN values
return;
}
for (let i = 0; i < vals.length; i++) {
const num = vals[i] as number;
if (isNaN(num) || !isFinite(num)) {
throw Error(`The result of the '${name}' is ${num}.`);
}
}
}

export function checkConversionForErrors<D extends DataType>(
vals: DataTypeMap[D]|number[], dtype: D): void {
for (let i = 0; i < vals.length; i++) {
Expand Down
24 changes: 0 additions & 24 deletions tfjs-core/src/util_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -407,30 +407,6 @@ describe('util.squeezeShape', () => {
});
});

describe('util.checkComputationForErrors', () => {
it('Float32Array has NaN', () => {
expect(
() => util.checkComputationForErrors(
new Float32Array([1, 2, 3, NaN, 4, 255]), 'float32', ''))
.toThrowError();
});

it('Float32Array has Infinity', () => {
expect(
() => util.checkComputationForErrors(
new Float32Array([1, 2, 3, Infinity, 4, 255]), 'float32', ''))
.toThrowError();
});

it('Float32Array no NaN', () => {
// Int32 and Bool NaNs should not trigger an error.
expect(
() => util.checkComputationForErrors(
new Float32Array([1, 2, 3, 4, -1, 255]), 'float32', ''))
.not.toThrowError();
});
});

describe('util.checkConversionForErrors', () => {
it('Float32Array has NaN', () => {
expect(
Expand Down
5 changes: 3 additions & 2 deletions tfjs-webgpu/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
},
"license": "Apache-2.0",
"devDependencies": {
"@tensorflow/tfjs-core": "1.2.1",
"@tensorflow/tfjs-core": "1.2.7",
"@types/jasmine": "~2.5.53",
"clang-format": "~1.2.2",
"http-server": "~0.10.0",
Expand All @@ -34,6 +34,7 @@
"rollup-plugin-commonjs": "9.1.3",
"rollup-plugin-node-resolve": "3.3.0",
"rollup-plugin-typescript2": "0.13.0",
"rollup-plugin-terser": "^5.1.1",
"rollup-plugin-uglify": "~3.0.0",
"tslint": "~5.11.0",
"tslint-no-circular-imports": "^0.5.0",
Expand All @@ -47,4 +48,4 @@
"peerDependencies": {
"@tensorflow/tfjs-core": "1.2.1"
}
}
}
4 changes: 1 addition & 3 deletions tfjs-webgpu/rollup.config.js
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@ function config({plugins = [], output = {}, external = []}) {
return {
input: 'src/index.ts',
plugins: [
typescript({
tsconfigOverride: {compilerOptions: {module: 'ES2015'}}
}),
typescript({tsconfigOverride: {compilerOptions: {module: 'ES2015'}}}),
node(),
// Polyfill require() from dependencies.
commonjs({
Expand Down
Loading