From 3bcf7d3b0b5bcba337e0bc2d9b4b0dd6c9361f77 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Thu, 30 Apr 2020 09:05:13 -0400 Subject: [PATCH 01/10] register kernel --- tfjs-backend-wasm/src/index_test.ts | 30 ++++++++++++++- tfjs-backend-wasm/src/kernels/Split.ts | 40 ++++++++++++++++++++ tfjs-backend-wasm/src/kernels/all_kernels.ts | 1 + 3 files changed, 69 insertions(+), 2 deletions(-) create mode 100644 tfjs-backend-wasm/src/kernels/Split.ts diff --git a/tfjs-backend-wasm/src/index_test.ts b/tfjs-backend-wasm/src/index_test.ts index fb1404f8416..280064c776d 100644 --- a/tfjs-backend-wasm/src/index_test.ts +++ b/tfjs-backend-wasm/src/index_test.ts @@ -58,8 +58,8 @@ describeWithFlags('wasm init', BROWSER_ENVS, () => { }, 100); // Silences backend registration warnings. - spyOn(console, 'warn'); - spyOn(console, 'log'); + // spyOn(console, 'warn'); + // spyOn(console, 'log'); }); afterEach(() => { @@ -121,4 +121,30 @@ describeWithFlags('wasm init', BROWSER_ENVS, () => { expect(() => setWasmPath('too/late')) .toThrowError(/The WASM backend was already initialized. Make sure/); }); + + fit('split by number', async () => { + const x = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8], [2, 4]); + const res = tf.split(x, 2, 1); + expect(res.length).toEqual(2); + expect(res[0].shape).toEqual([2, 2]); + const res0data = await res[0].data(); + console.log(Array.from(res0data)); + // expectArraysClose(await res[0].data(), [1, 2, 5, 6]); + expect(res[1].shape).toEqual([2, 2]); + const res1data = await res[1].data(); + console.log(Array.from(res1data)); + // expectArraysClose(await res[1].data(), [3, 4, 7, 8]); + }); + + it('split by sizes', async () => { + const x = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8], [2, 4]); + const res = tf.split(x, [1, 2, 1], 1); + expect(res.length).toEqual(3); + expect(res[0].shape).toEqual([2, 1]); + // expectArraysClose(await res[0].data(), [1, 5]); + expect(res[1].shape).toEqual([2, 2]); + // expectArraysClose(await res[1].data(), [2, 3, 6, 7]); + expect(res[2].shape).toEqual([2, 1]); + // expectArraysClose(await res[2].data(), [4, 8]); + }); }); diff --git a/tfjs-backend-wasm/src/kernels/Split.ts b/tfjs-backend-wasm/src/kernels/Split.ts new file mode 100644 index 00000000000..eb4790c0ff5 --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/Split.ts @@ -0,0 +1,40 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {NamedAttrMap, NamedTensorInfoMap, registerKernel} from '@tensorflow/tfjs-core'; +import {TensorInfo} from '@tensorflow/tfjs-core'; + +import {BackendWasm} from '../backend_wasm'; + +interface SplitInputs extends NamedTensorInfoMap { + x: TensorInfo; +} + +interface SplitAttrs extends NamedAttrMap { + numOrSizeSplits: number[]; + axis: number; +} + +export function split( + args: {inputs: SplitInputs, attrs: SplitAttrs, backend: BackendWasm}) { + const {inputs: {x}, attrs: {numOrSizeSplits, axis}, backend} = args; + const out = backend.makeOutput(x.shape, x.dtype); + console.log(numOrSizeSplits, axis); + return out; +} + +registerKernel({kernelName: 'SplitV', backendName: 'wasm', kernelFunc: split}); diff --git a/tfjs-backend-wasm/src/kernels/all_kernels.ts b/tfjs-backend-wasm/src/kernels/all_kernels.ts index 454fc89d85b..50ea4ab4bd0 100644 --- a/tfjs-backend-wasm/src/kernels/all_kernels.ts +++ b/tfjs-backend-wasm/src/kernels/all_kernels.ts @@ -70,6 +70,7 @@ import './Sigmoid'; import './Sin'; import './Slice'; import './Softmax'; +import './Split'; import './Square'; import './Sub'; import './Sum'; From 164a1aa5ac8f3b3740fbb04d678bcdac77ed4479 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Thu, 30 Apr 2020 09:37:22 -0400 Subject: [PATCH 02/10] add --- tfjs-backend-wasm/src/index_test.ts | 10 +++++++++- tfjs-backend-wasm/src/kernels/Split.ts | 26 ++++++++++++++++++++++---- tfjs-core/src/ops/split.ts | 1 - 3 files changed, 31 insertions(+), 6 deletions(-) diff --git a/tfjs-backend-wasm/src/index_test.ts b/tfjs-backend-wasm/src/index_test.ts index 280064c776d..a593792bee7 100644 --- a/tfjs-backend-wasm/src/index_test.ts +++ b/tfjs-backend-wasm/src/index_test.ts @@ -125,6 +125,8 @@ describeWithFlags('wasm init', BROWSER_ENVS, () => { fit('split by number', async () => { const x = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8], [2, 4]); const res = tf.split(x, 2, 1); + console.log('OUTPUT'); + console.log(res); expect(res.length).toEqual(2); expect(res[0].shape).toEqual([2, 2]); const res0data = await res[0].data(); @@ -136,15 +138,21 @@ describeWithFlags('wasm init', BROWSER_ENVS, () => { // expectArraysClose(await res[1].data(), [3, 4, 7, 8]); }); - it('split by sizes', async () => { + fit('split by sizes', async () => { const x = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8], [2, 4]); const res = tf.split(x, [1, 2, 1], 1); expect(res.length).toEqual(3); expect(res[0].shape).toEqual([2, 1]); + const res0data = await res[0].data(); + console.log(Array.from(res0data)); // expectArraysClose(await res[0].data(), [1, 5]); expect(res[1].shape).toEqual([2, 2]); + const res1data = await res[0].data(); + console.log(Array.from(res1data)); // expectArraysClose(await res[1].data(), [2, 3, 6, 7]); expect(res[2].shape).toEqual([2, 1]); + const res2data = await res[2].data(); + console.log(Array.from(res2data)); // expectArraysClose(await res[2].data(), [4, 8]); }); }); diff --git a/tfjs-backend-wasm/src/kernels/Split.ts b/tfjs-backend-wasm/src/kernels/Split.ts index eb4790c0ff5..1fd8c3be6ba 100644 --- a/tfjs-backend-wasm/src/kernels/Split.ts +++ b/tfjs-backend-wasm/src/kernels/Split.ts @@ -15,11 +15,13 @@ * ============================================================================= */ -import {NamedAttrMap, NamedTensorInfoMap, registerKernel} from '@tensorflow/tfjs-core'; +import {NamedAttrMap, NamedTensorInfoMap, registerKernel, util} from '@tensorflow/tfjs-core'; import {TensorInfo} from '@tensorflow/tfjs-core'; import {BackendWasm} from '../backend_wasm'; +import {slice} from './Slice'; + interface SplitInputs extends NamedTensorInfoMap { x: TensorInfo; } @@ -32,9 +34,25 @@ interface SplitAttrs extends NamedAttrMap { export function split( args: {inputs: SplitInputs, attrs: SplitAttrs, backend: BackendWasm}) { const {inputs: {x}, attrs: {numOrSizeSplits, axis}, backend} = args; - const out = backend.makeOutput(x.shape, x.dtype); - console.log(numOrSizeSplits, axis); - return out; + + const $axis = util.parseAxisParam(axis, x.shape)[0]; + + let splitSizes: number[]; + if (typeof (numOrSizeSplits) === 'number') { + splitSizes = + new Array(numOrSizeSplits).fill(x.shape[$axis] / numOrSizeSplits); + } else { + splitSizes = numOrSizeSplits; + } + + const begin = new Array(x.shape.length).fill(0); + const size = x.shape.slice(); + return splitSizes.map(s => { + size[$axis] = s; + const xSlice = slice({inputs: {x}, attrs: {begin, size}, backend}); + begin[$axis] += s; + return xSlice; + }); } registerKernel({kernelName: 'SplitV', backendName: 'wasm', kernelFunc: split}); diff --git a/tfjs-core/src/ops/split.ts b/tfjs-core/src/ops/split.ts index 5751f9499f7..38d2e03567b 100644 --- a/tfjs-core/src/ops/split.ts +++ b/tfjs-core/src/ops/split.ts @@ -80,7 +80,6 @@ function split_( } const forward: ForwardFunc = (backend, _) => { - const $axis = parseAxisParam(axis, $x.shape)[0]; return backend.split($x, splitSizes, $axis) as {} as T; }; From e174848c3277f1817347d891cda3d16b7299428a Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Thu, 30 Apr 2020 09:58:29 -0400 Subject: [PATCH 03/10] add sqrt --- tfjs-backend-wasm/src/cc/kernels/Sqrt.cc | 36 ++++++++++++++++++++ tfjs-backend-wasm/src/kernels/Sqrt.ts | 19 +++++++++++ tfjs-backend-wasm/src/kernels/all_kernels.ts | 1 + tfjs-core/src/ops/unary_ops.ts | 4 +-- 4 files changed, 58 insertions(+), 2 deletions(-) create mode 100644 tfjs-backend-wasm/src/cc/kernels/Sqrt.cc create mode 100644 tfjs-backend-wasm/src/kernels/Sqrt.ts diff --git a/tfjs-backend-wasm/src/cc/kernels/Sqrt.cc b/tfjs-backend-wasm/src/cc/kernels/Sqrt.cc new file mode 100644 index 00000000000..a920d6f3527 --- /dev/null +++ b/tfjs-backend-wasm/src/cc/kernels/Sqrt.cc @@ -0,0 +1,36 @@ +/* Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ===========================================================================*/ + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include + +#include "src/cc/backend.h" +#include "src/cc/unary.h" + +namespace tfjs { +namespace wasm { +// We use C-style API to interface with Javascript. +extern "C" { + +#ifdef __EMSCRIPTEN__ +EMSCRIPTEN_KEEPALIVE +#endif +void Sqrt(const int x_id, const int out_id) { unary(x_id, out_id, sqrt); } + +} // extern "C" +} // namespace wasm +} // namespace tfjs diff --git a/tfjs-backend-wasm/src/kernels/Sqrt.ts b/tfjs-backend-wasm/src/kernels/Sqrt.ts new file mode 100644 index 00000000000..906cfe1d467 --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/Sqrt.ts @@ -0,0 +1,19 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {registerUnaryKernel} from './unary_kernel'; +registerUnaryKernel('Sqrt'); diff --git a/tfjs-backend-wasm/src/kernels/all_kernels.ts b/tfjs-backend-wasm/src/kernels/all_kernels.ts index 50ea4ab4bd0..0bf887355d4 100644 --- a/tfjs-backend-wasm/src/kernels/all_kernels.ts +++ b/tfjs-backend-wasm/src/kernels/all_kernels.ts @@ -71,6 +71,7 @@ import './Sin'; import './Slice'; import './Softmax'; import './Split'; +import './Sqrt'; import './Square'; import './Sub'; import './Sum'; diff --git a/tfjs-core/src/ops/unary_ops.ts b/tfjs-core/src/ops/unary_ops.ts index 7a33f1e0b42..1c50420490e 100644 --- a/tfjs-core/src/ops/unary_ops.ts +++ b/tfjs-core/src/ops/unary_ops.ts @@ -325,13 +325,13 @@ function sqrt_(x: T|TensorLike): T { const grad = (dy: T, saved: Tensor[]) => { const [$x] = saved; - return {$x: () => dy.div($x.toFloat().sqrt().mul(2))} as {$x: () => T}; + return {x: () => dy.div($x.toFloat().sqrt().mul(2))} as {x: () => T}; }; return ENGINE.runKernelFunc((backend, save) => { const res = backend.sqrt($x); save([$x]); return res; - }, {$x}, grad); + }, {x: $x}, grad, 'Sqrt', {}); } /** From ac5c447a58d49d8aec543707f8973e742e3f4188 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Fri, 1 May 2020 16:30:58 -0400 Subject: [PATCH 04/10] fix --- tfjs-backend-wasm/src/index_test.ts | 34 -------------------------- tfjs-backend-wasm/src/kernels/Split.ts | 3 ++- tfjs-backend-wasm/src/setup_test.ts | 1 + 3 files changed, 3 insertions(+), 35 deletions(-) diff --git a/tfjs-backend-wasm/src/index_test.ts b/tfjs-backend-wasm/src/index_test.ts index a593792bee7..6c498788956 100644 --- a/tfjs-backend-wasm/src/index_test.ts +++ b/tfjs-backend-wasm/src/index_test.ts @@ -121,38 +121,4 @@ describeWithFlags('wasm init', BROWSER_ENVS, () => { expect(() => setWasmPath('too/late')) .toThrowError(/The WASM backend was already initialized. Make sure/); }); - - fit('split by number', async () => { - const x = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8], [2, 4]); - const res = tf.split(x, 2, 1); - console.log('OUTPUT'); - console.log(res); - expect(res.length).toEqual(2); - expect(res[0].shape).toEqual([2, 2]); - const res0data = await res[0].data(); - console.log(Array.from(res0data)); - // expectArraysClose(await res[0].data(), [1, 2, 5, 6]); - expect(res[1].shape).toEqual([2, 2]); - const res1data = await res[1].data(); - console.log(Array.from(res1data)); - // expectArraysClose(await res[1].data(), [3, 4, 7, 8]); - }); - - fit('split by sizes', async () => { - const x = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8], [2, 4]); - const res = tf.split(x, [1, 2, 1], 1); - expect(res.length).toEqual(3); - expect(res[0].shape).toEqual([2, 1]); - const res0data = await res[0].data(); - console.log(Array.from(res0data)); - // expectArraysClose(await res[0].data(), [1, 5]); - expect(res[1].shape).toEqual([2, 2]); - const res1data = await res[0].data(); - console.log(Array.from(res1data)); - // expectArraysClose(await res[1].data(), [2, 3, 6, 7]); - expect(res[2].shape).toEqual([2, 1]); - const res2data = await res[2].data(); - console.log(Array.from(res2data)); - // expectArraysClose(await res[2].data(), [4, 8]); - }); }); diff --git a/tfjs-backend-wasm/src/kernels/Split.ts b/tfjs-backend-wasm/src/kernels/Split.ts index 1fd8c3be6ba..57b1a3f08a8 100644 --- a/tfjs-backend-wasm/src/kernels/Split.ts +++ b/tfjs-backend-wasm/src/kernels/Split.ts @@ -49,7 +49,8 @@ export function split( const size = x.shape.slice(); return splitSizes.map(s => { size[$axis] = s; - const xSlice = slice({inputs: {x}, attrs: {begin, size}, backend}); + const xSlice = + slice({inputs: {x}, attrs: {begin, size: [...size]}, backend}); begin[$axis] += s; return xSlice; }); diff --git a/tfjs-backend-wasm/src/setup_test.ts b/tfjs-backend-wasm/src/setup_test.ts index 56e212e7f49..8a1d62ade37 100644 --- a/tfjs-backend-wasm/src/setup_test.ts +++ b/tfjs-backend-wasm/src/setup_test.ts @@ -224,6 +224,7 @@ const TEST_FILTERS: TestFilter[] = [ include: 'transpose', excludes: ['oneHot'] // oneHot not yet implemented. }, + {include: 'split'}, {include: 'pad ', excludes: ['complex', 'zerosLike']}, {include: 'clip', excludes: ['gradient']}, {include: 'addN'}, From 641a2fa65f8de53dc868393e2d6c1c5382096aa4 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Mon, 4 May 2020 07:07:37 -0400 Subject: [PATCH 05/10] revive --- tfjs-backend-wasm/src/index_test.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tfjs-backend-wasm/src/index_test.ts b/tfjs-backend-wasm/src/index_test.ts index 6c498788956..fb1404f8416 100644 --- a/tfjs-backend-wasm/src/index_test.ts +++ b/tfjs-backend-wasm/src/index_test.ts @@ -58,8 +58,8 @@ describeWithFlags('wasm init', BROWSER_ENVS, () => { }, 100); // Silences backend registration warnings. - // spyOn(console, 'warn'); - // spyOn(console, 'log'); + spyOn(console, 'warn'); + spyOn(console, 'log'); }); afterEach(() => { From 6ca0544cfbca4e200d673068ea0104e2e6605248 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Mon, 4 May 2020 07:16:32 -0400 Subject: [PATCH 06/10] fix --- tfjs-backend-wasm/src/kernels/Split.ts | 6 ++++-- tfjs-core/src/backends/split_shared.ts | 2 ++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/tfjs-backend-wasm/src/kernels/Split.ts b/tfjs-backend-wasm/src/kernels/Split.ts index 57b1a3f08a8..6f6957b5429 100644 --- a/tfjs-backend-wasm/src/kernels/Split.ts +++ b/tfjs-backend-wasm/src/kernels/Split.ts @@ -16,6 +16,7 @@ */ import {NamedAttrMap, NamedTensorInfoMap, registerKernel, util} from '@tensorflow/tfjs-core'; + import {TensorInfo} from '@tensorflow/tfjs-core'; import {BackendWasm} from '../backend_wasm'; @@ -48,9 +49,10 @@ export function split( const begin = new Array(x.shape.length).fill(0); const size = x.shape.slice(); return splitSizes.map(s => { - size[$axis] = s; + const xSliceSize = [...size]; + xSliceSize[$axis] = s; const xSlice = - slice({inputs: {x}, attrs: {begin, size: [...size]}, backend}); + slice({inputs: {x}, attrs: {begin, size: xSliceSize}, backend}); begin[$axis] += s; return xSlice; }); diff --git a/tfjs-core/src/backends/split_shared.ts b/tfjs-core/src/backends/split_shared.ts index 487a5d0d288..29789b8b7cb 100644 --- a/tfjs-core/src/backends/split_shared.ts +++ b/tfjs-core/src/backends/split_shared.ts @@ -17,6 +17,8 @@ import {Tensor} from '../tensor'; +// TODO(annxingyuan): Use this helper in WASM Split kernel once intermediate +// kernels have been modularized in WebGL and CPU. /** Shared implementation of the split kernel across WebGL and CPU. */ export function split( x: T, sizeSplits: number[], axis: number): T[] { From a160974ccb72055082284c718a9e7117790b643e Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Mon, 4 May 2020 07:20:27 -0400 Subject: [PATCH 07/10] add sqrt --- tfjs-backend-wasm/src/setup_test.ts | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tfjs-backend-wasm/src/setup_test.ts b/tfjs-backend-wasm/src/setup_test.ts index 8a1d62ade37..8cbf2950147 100644 --- a/tfjs-backend-wasm/src/setup_test.ts +++ b/tfjs-backend-wasm/src/setup_test.ts @@ -329,6 +329,10 @@ const TEST_FILTERS: TestFilter[] = [ startsWith: 'rsqrt ', excludes: ['gradient'] // Gradient not yet implemented. }, + { + startsWith: 'sqrt ', + excludes: ['gradient'] // Gradient not yet implemented. + }, { startsWith: 'zerosLike', // Complex numbers not supported yet. From 87254adcfd835eae60b56f325370ab680a55bb5f Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Mon, 4 May 2020 07:23:06 -0400 Subject: [PATCH 08/10] fix --- tfjs-core/src/backends/split_shared.ts | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tfjs-core/src/backends/split_shared.ts b/tfjs-core/src/backends/split_shared.ts index 29789b8b7cb..fd45862eb88 100644 --- a/tfjs-core/src/backends/split_shared.ts +++ b/tfjs-core/src/backends/split_shared.ts @@ -25,8 +25,9 @@ export function split( const begin = new Array(x.rank).fill(0); const size = x.shape.slice(); return sizeSplits.map(s => { - size[axis] = s; - const slice = x.slice(begin, size); + const sliceSize = [...size]; + sliceSize[axis] = s; + const slice = x.slice(begin, sliceSize); begin[axis] += s; return slice; }); From 23160ce932a0f452c6318831f75b2f9a71f02f23 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Thu, 7 May 2020 07:34:37 -0400 Subject: [PATCH 09/10] use kernelnames --- tfjs-backend-wasm/src/kernels/Split.ts | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/tfjs-backend-wasm/src/kernels/Split.ts b/tfjs-backend-wasm/src/kernels/Split.ts index 6f6957b5429..7925a6dda28 100644 --- a/tfjs-backend-wasm/src/kernels/Split.ts +++ b/tfjs-backend-wasm/src/kernels/Split.ts @@ -15,26 +15,20 @@ * ============================================================================= */ -import {NamedAttrMap, NamedTensorInfoMap, registerKernel, util} from '@tensorflow/tfjs-core'; - -import {TensorInfo} from '@tensorflow/tfjs-core'; +import {NamedAttrMap, NamedTensorInfoMap, registerKernel, SplitV, SplitVAttrs, SplitVInputs, util} from '@tensorflow/tfjs-core'; import {BackendWasm} from '../backend_wasm'; import {slice} from './Slice'; -interface SplitInputs extends NamedTensorInfoMap { - x: TensorInfo; -} - -interface SplitAttrs extends NamedAttrMap { - numOrSizeSplits: number[]; - axis: number; -} - -export function split( - args: {inputs: SplitInputs, attrs: SplitAttrs, backend: BackendWasm}) { - const {inputs: {x}, attrs: {numOrSizeSplits, axis}, backend} = args; +export function split(args: { + inputs: NamedTensorInfoMap, + attrs: NamedAttrMap, + backend: BackendWasm +}) { + const {inputs, attrs, backend} = args; + const {x} = inputs as {} as SplitVInputs; + const {numOrSizeSplits, axis} = attrs as {} as SplitVAttrs; const $axis = util.parseAxisParam(axis, x.shape)[0]; @@ -58,4 +52,4 @@ export function split( }); } -registerKernel({kernelName: 'SplitV', backendName: 'wasm', kernelFunc: split}); +registerKernel({kernelName: SplitV, backendName: 'wasm', kernelFunc: split}); From 725219fa2f51800954c71c13e135b9eabeedba42 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Thu, 7 May 2020 07:36:49 -0400 Subject: [PATCH 10/10] add issue ref --- tfjs-core/src/backends/split_shared.ts | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tfjs-core/src/backends/split_shared.ts b/tfjs-core/src/backends/split_shared.ts index fd45862eb88..ca6e23a1aa9 100644 --- a/tfjs-core/src/backends/split_shared.ts +++ b/tfjs-core/src/backends/split_shared.ts @@ -18,7 +18,8 @@ import {Tensor} from '../tensor'; // TODO(annxingyuan): Use this helper in WASM Split kernel once intermediate -// kernels have been modularized in WebGL and CPU. +// kernels have been modularized in WebGL and CPU +// https://github.com/tensorflow/tfjs/issues/2822. /** Shared implementation of the split kernel across WebGL and CPU. */ export function split( x: T, sizeSplits: number[], axis: number): T[] {