diff --git a/tfjs-backend-wasm/src/kernels/Concat.ts b/tfjs-backend-wasm/src/kernels/Concat.ts index 5cf36b35593..73e2e9efc06 100644 --- a/tfjs-backend-wasm/src/kernels/Concat.ts +++ b/tfjs-backend-wasm/src/kernels/Concat.ts @@ -25,7 +25,10 @@ interface ConcatAttrs extends NamedAttrMap { function concat( args: {inputs: TensorInfo[], backend: BackendWasm, attrs: ConcatAttrs}) { - const {inputs, backend, attrs: {axis}} = args; + const {inputs, backend} = args; + + const axis = util.parseAxisParam(args.attrs.axis, inputs[0].shape)[0]; + const outShape = backend_util.computeOutShape(inputs.map(t => t.shape), axis); const out = backend.makeOutput(outShape, inputs[0].dtype); diff --git a/tfjs-core/src/ops/concat.ts b/tfjs-core/src/ops/concat.ts index 98b9056bfe5..1651dd37a4f 100644 --- a/tfjs-core/src/ops/concat.ts +++ b/tfjs-core/src/ops/concat.ts @@ -94,7 +94,6 @@ function concat_(tensors: Array, axis = 0): T { assertParamsConsistent(shapes, $axis); const forward: ForwardFunc = (backend, save) => { - const $axis = parseAxisParam(axis, $tensors[0].shape)[0]; const res = backend.concat($tensors, $axis); save($tensors); return res; diff --git a/tfjs-core/src/ops/concat_test.ts b/tfjs-core/src/ops/concat_test.ts index d0dc79f4560..811308981c2 100644 --- a/tfjs-core/src/ops/concat_test.ts +++ b/tfjs-core/src/ops/concat_test.ts @@ -262,6 +262,14 @@ describeWithFlags('concat2d', ALL_ENVS, () => { }); describeWithFlags('concat3d', ALL_ENVS, () => { + it('shapes correct concat axis=-1', async () => { + const tensor1 = tf.tensor3d([1, 2, 3], [1, 1, 3]); + const tensor2 = tf.tensor3d([4, 5, 6], [1, 1, 3]); + const values = tf.concat3d([tensor1, tensor2], -1); + expect(values.shape).toEqual([1, 1, 6]); + expectArraysClose(await values.data(), [1, 2, 3, 4, 5, 6]); + }); + it('shapes correct concat axis=0', async () => { const tensor1 = tf.tensor3d([1, 2, 3], [1, 1, 3]); const tensor2 = tf.tensor3d([4, 5, 6], [1, 1, 3]);