Skip to content

Commit b05cfd2

Browse files
authored
Remove cpu backend in tfjs-node testing, use fake instead. (#3082)
BUG Remove cpu backend in tfjs-node testing, use fake instead.
1 parent 53fed37 commit b05cfd2

File tree

3 files changed

+30
-13
lines changed

3 files changed

+30
-13
lines changed

tfjs-node/src/image_test.ts

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
/**
22
* @license
3-
* Copyright 2019 Google Inc. All Rights Reserved.
3+
* Copyright 2020 Google Inc. All Rights Reserved.
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
66
* You may obtain a copy of the License at
@@ -14,10 +14,12 @@
1414
* limitations under the License.
1515
* =============================================================================
1616
*/
17-
18-
import {memory, setBackend, test_util} from '@tensorflow/tfjs';
17+
import {memory, registerBackend, setBackend, test_util} from '@tensorflow/tfjs';
18+
// tslint:disable-next-line: no-imports-from-dist
19+
import {TestKernelBackend} from '@tensorflow/tfjs-core/dist/jasmine_util';
1920
import * as fs from 'fs';
2021
import {promisify} from 'util';
22+
2123
import {getImageType, ImageType} from './image';
2224
import * as tf from './index';
2325

@@ -221,14 +223,18 @@ describe('decode images', () => {
221223

222224
it('throw error if backend is not tensorflow', async done => {
223225
try {
224-
setBackend('cpu');
226+
const testBackend = new TestKernelBackend();
227+
registerBackend('fake', () => testBackend);
228+
setBackend('fake');
229+
225230
const uint8array = await getUint8ArrayFromImage(
226231
'test_objects/images/image_png_test.png');
227232
tf.node.decodeImage(uint8array);
228233
done.fail();
229234
} catch (err) {
230235
expect(err.message)
231-
.toBe('Expect the current backend to be "tensorflow", but got "cpu"');
236+
.toBe(
237+
'Expect the current backend to be "tensorflow", but got "fake"');
232238
setBackend('tensorflow');
233239
done();
234240
}

tfjs-node/src/nodejs_kernel_backend_test.ts

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616
*/
1717

1818
import * as tf from '@tensorflow/tfjs';
19+
// tslint:disable-next-line: no-imports-from-dist
20+
import {TestKernelBackend} from '@tensorflow/tfjs-core/dist/jasmine_util';
21+
1922
import {createTensorsTypeOpAttr, createTypeOpAttr, ensureTensorflowBackend, getTFDType, nodeBackend, NodeJSKernelBackend} from './nodejs_kernel_backend';
2023

2124
describe('delayed upload', () => {
@@ -74,12 +77,16 @@ describe('Exposes Backend for internal Op execution.', () => {
7477

7578
it('throw error if backend is not tensorflow', async done => {
7679
try {
77-
tf.setBackend('cpu');
80+
const testBackend = new TestKernelBackend();
81+
tf.registerBackend('fake', () => testBackend);
82+
tf.setBackend('fake');
83+
7884
ensureTensorflowBackend();
7985
done.fail();
8086
} catch (err) {
8187
expect(err.message)
82-
.toBe('Expect the current backend to be "tensorflow", but got "cpu"');
88+
.toBe(
89+
'Expect the current backend to be "tensorflow", but got "fake"');
8390
tf.setBackend('tensorflow');
8491
done();
8592
}

tfjs-node/src/run_tests.ts

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,12 @@ process.on('unhandledRejection', e => {
3535
throw e;
3636
});
3737

38-
jasmine_util.setTestEnvs(
39-
[{name: 'test-tensorflow', backendName: 'tensorflow', flags: {}}]);
38+
jasmine_util.setTestEnvs([{
39+
name: 'test-tensorflow',
40+
backendName: 'tensorflow',
41+
flags: {},
42+
isDataSync: true
43+
}]);
4044

4145
const IGNORE_LIST: string[] = [
4246
// Always ignore version tests:
@@ -90,11 +94,11 @@ if (process.platform === 'win32') {
9094
'maxPool test-tensorflow {} [x=[3,3,1] f=[2,2] s=1 ignores NaNs');
9195
}
9296

93-
const coreTests = 'node_modules/@tensorflow/tfjs-core/dist/**/*_test.js';
94-
const nodeTests = 'src/**/*_test.ts';
95-
9697
const runner = new jasmineCtor();
97-
runner.loadConfig({spec_files: [coreTests, nodeTests], random: false});
98+
runner.loadConfig({spec_files: ['src/**/*_test.ts'], random: false});
99+
// Also import tests from core.
100+
// tslint:disable-next-line: no-imports-from-dist
101+
import '@tensorflow/tfjs-core/dist/tests';
98102

99103
if (process.env.JASMINE_SEED) {
100104
runner.seed(process.env.JASMINE_SEED);

0 commit comments

Comments
 (0)