diff --git a/tfjs-core/src/gradients/Reverse_grad.ts b/tfjs-core/src/gradients/Reverse_grad.ts new file mode 100644 index 00000000000..028ebbb8755 --- /dev/null +++ b/tfjs-core/src/gradients/Reverse_grad.ts @@ -0,0 +1,30 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {Reverse, ReverseAttrs} from '../kernel_names'; +import {GradConfig, NamedAttrMap} from '../kernel_registry'; +import {reverse} from '../ops/reverse'; +import {Tensor} from '../tensor'; +import {parseAxisParam} from '../util'; + +export const reverseGradConfig: GradConfig = { + kernelName: Reverse, + gradFunc: (dy: Tensor, saved: Tensor[], attrs: NamedAttrMap) => { + const {dims} = attrs as {} as ReverseAttrs; + const axes = parseAxisParam(dims, dy.shape); + return {x: () => reverse(dy, axes)}; + } +}; diff --git a/tfjs-core/src/kernel_names.ts b/tfjs-core/src/kernel_names.ts index 1a7e4992b94..64c790ddd78 100644 --- a/tfjs-core/src/kernel_names.ts +++ b/tfjs-core/src/kernel_names.ts @@ -440,6 +440,12 @@ export type ResizeBilinearGradInputs = Pick; export const Relu6 = 'Relu6'; export type Relu6Inputs = Pick; +export const Reverse = 'Reverse'; +export type ReverseInputs = Pick; +export interface ReverseAttrs { + dims: number|number[]; +} + export const SelectV2 = 'SelectV2'; export type SelectV2Inputs = Pick; diff --git a/tfjs-core/src/ops/ops.ts b/tfjs-core/src/ops/ops.ts index f4a236ade57..5ba687ee716 100644 --- a/tfjs-core/src/ops/ops.ts +++ b/tfjs-core/src/ops/ops.ts @@ -93,6 +93,11 @@ export {randomUniform} from './random_uniform'; export {real} from './real'; export {relu} from './relu'; export {relu6} from './relu6'; +export {reverse} from './reverse'; +export {reverse1d} from './reverse_1d'; +export {reverse2d} from './reverse_2d'; +export {reverse3d} from './reverse_3d'; +export {reverse4d} from './reverse_4d'; export {selu} from './selu'; export {separableConv2d} from './separable_conv2d'; export {spaceToBatchND} from './space_to_batch_nd'; @@ -106,7 +111,6 @@ export {where} from './where'; export {whereAsync} from './where_async'; export * from './boolean_mask'; -export * from './reverse'; export * from './slice'; export * from './unary_ops'; export * from './reduction_ops'; diff --git a/tfjs-core/src/ops/reverse.ts b/tfjs-core/src/ops/reverse.ts index 1e72f4af8da..875d8978e88 100644 --- a/tfjs-core/src/ops/reverse.ts +++ b/tfjs-core/src/ops/reverse.ts @@ -15,70 +15,18 @@ * ============================================================================= */ -import {ENGINE} from '../engine'; -import {Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D} from '../tensor'; +import {ENGINE, ForwardFunc} from '../engine'; +import {Reverse, ReverseAttrs, ReverseInputs} from '../kernel_names'; +import {NamedAttrMap} from '../kernel_registry'; +import {Tensor} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; -import * as util from '../util'; -import {op} from './operation'; +import {parseAxisParam} from '../util'; -/** - * Reverses a `tf.Tensor1D`. - * - * @param x The input tensor. - */ -function reverse1d_(x: Tensor1D|TensorLike): Tensor1D { - const $x = convertToTensor(x, 'x', 'reverse'); - util.assert( - $x.rank === 1, - () => `Error in reverse1D: x must be rank 1 but got rank ${$x.rank}.`); - return reverse($x, 0); -} - -/** - * Reverses a `tf.Tensor2D` along a specified axis. - * - * @param x The input tensor. - * @param axis The set of dimensions to reverse. Must be in the - * range [-rank(x), rank(x)). Defaults to all axes. - */ -function reverse2d_(x: Tensor2D|TensorLike, axis?: number|number[]): Tensor2D { - const $x = convertToTensor(x, 'x', 'reverse'); - util.assert( - $x.rank === 2, - () => `Error in reverse2D: x must be rank 2 but got rank ${$x.rank}.`); - return reverse($x, axis); -} - -/** - * Reverses a `tf.Tensor3D` along a specified axis. - * - * @param x The input tensor. - * @param axis The set of dimensions to reverse. Must be in the - * range [-rank(x), rank(x)). Defaults to all axes. - */ -function reverse3d_(x: Tensor3D|TensorLike, axis?: number|number[]): Tensor3D { - const $x = convertToTensor(x, 'x', 'reverse'); - util.assert( - $x.rank === 3, - () => `Error in reverse3D: x must be rank 3 but got rank ${$x.rank}.`); - return reverse($x, axis); -} - -/** - * Reverses a `tf.Tensor4D` along a specified axis. - * - * @param x The input tensor. - * @param axis The set of dimensions to reverse. Must be in the - * range [-rank(x), rank(x)). Defaults to all axes. - */ -function reverse4d_(x: Tensor4D|TensorLike, axis?: number|number[]): Tensor4D { - const $x = convertToTensor(x, 'x', 'reverse'); - util.assert( - $x.rank === 4, - () => `Error in reverse4D: x must be rank 4 but got rank ${$x.rank}.`); - return reverse($x, axis); -} +import {reshape} from './array_ops'; +import {clone} from './clone'; +import {op} from './operation'; /** * Reverses a `tf.Tensor` along a specified axis. @@ -114,20 +62,21 @@ function reverse_( x: T|TensorLike, axis?: number|number[]): T { const $x = convertToTensor(x, 'x', 'reverse'); - if ($x.rank === 0) { - return $x.clone(); - } - const axes = util.parseAxisParam(axis, $x.shape); - const grad = (dy: T) => { - return {$x: () => dy.reverse(axes)}; + const forward: ForwardFunc = (backend) => { + const axes = parseAxisParam(axis, $x.shape); + if ($x.rank === 0) { + return clone($x); + } + const res = backend.reverse($x, axes); + return reshape(res, $x.shape); }; - const res = - ENGINE.runKernelFunc(backend => backend.reverse($x, axes), {$x}, grad); - return res.reshapeAs($x); + + const inputs: ReverseInputs = {x: $x}; + const attrs: ReverseAttrs = {dims: axis}; + + return ENGINE.runKernelFunc( + forward, inputs as {} as NamedTensorMap, null /* gradient */, + Reverse, attrs as {} as NamedAttrMap) as T; } export const reverse = op({reverse_}); -export const reverse1d = op({reverse1d_}); -export const reverse2d = op({reverse2d_}); -export const reverse3d = op({reverse3d_}); -export const reverse4d = op({reverse4d_}); diff --git a/tfjs-core/src/ops/reverse_1d.ts b/tfjs-core/src/ops/reverse_1d.ts new file mode 100644 index 00000000000..2c980242e0c --- /dev/null +++ b/tfjs-core/src/ops/reverse_1d.ts @@ -0,0 +1,38 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Tensor1D} from '../tensor'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; +import * as util from '../util'; +import {op} from './operation'; +import {reverse} from './reverse'; + +/** + * Reverses a `tf.Tensor1D`. + * + * @param x The input tensor. + */ +function reverse1d_(x: Tensor1D|TensorLike): Tensor1D { + const $x = convertToTensor(x, 'x', 'reverse'); + util.assert( + $x.rank === 1, + () => `Error in reverse1D: x must be rank 1 but got rank ${$x.rank}.`); + return reverse($x, 0); +} + +export const reverse1d = op({reverse1d_}); diff --git a/tfjs-core/src/ops/reverse_1d_test.ts b/tfjs-core/src/ops/reverse_1d_test.ts new file mode 100644 index 00000000000..4e11f70be7c --- /dev/null +++ b/tfjs-core/src/ops/reverse_1d_test.ts @@ -0,0 +1,60 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import * as tf from '../index'; +import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; +import {expectArraysClose} from '../test_util'; + +describeWithFlags('reverse1d', ALL_ENVS, () => { + it('reverse a 1D array', async () => { + const input = tf.tensor1d([1, 2, 3, 4, 5]); + const result = tf.reverse1d(input); + expect(result.shape).toEqual(input.shape); + expectArraysClose(await result.data(), [5, 4, 3, 2, 1]); + }); + + it('reverse a 1D array, even length', async () => { + const input = tf.tensor1d([1, 2, 3, 4]); + const result = tf.reverse1d(input); + expect(result.shape).toEqual(input.shape); + expectArraysClose(await result.data(), [4, 3, 2, 1]); + }); + + it('grad', async () => { + const a = tf.tensor1d([1, 2, 3]); + const dy = tf.tensor1d([10, 20, 30]); + const da = tf.grad((a: tf.Tensor1D) => tf.reverse1d(a))(a, dy); + expect(da.shape).toEqual([3]); + expectArraysClose(await da.data(), [30, 20, 10]); + }); + + it('gradient with clones', async () => { + const a = tf.tensor1d([1, 2, 3]); + const dy = tf.tensor1d([10, 20, 30]); + const da = + tf.grad((a: tf.Tensor1D) => tf.reverse1d(a.clone()).clone())(a, dy); + expect(da.shape).toEqual([3]); + expectArraysClose(await da.data(), [30, 20, 10]); + }); + + it('accepts a tensor-like object', async () => { + const input = [1, 2, 3, 4, 5]; + const result = tf.reverse1d(input); + expect(result.shape).toEqual([5]); + expectArraysClose(await result.data(), [5, 4, 3, 2, 1]); + }); +}); diff --git a/tfjs-core/src/ops/reverse_2d.ts b/tfjs-core/src/ops/reverse_2d.ts new file mode 100644 index 00000000000..96fe4d7e05f --- /dev/null +++ b/tfjs-core/src/ops/reverse_2d.ts @@ -0,0 +1,40 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Tensor2D} from '../tensor'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; +import * as util from '../util'; +import {op} from './operation'; +import {reverse} from './reverse'; + +/** + * Reverses a `tf.Tensor2D` along a specified axis. + * + * @param x The input tensor. + * @param axis The set of dimensions to reverse. Must be in the + * range [-rank(x), rank(x)). Defaults to all axes. + */ +function reverse2d_(x: Tensor2D|TensorLike, axis?: number|number[]): Tensor2D { + const $x = convertToTensor(x, 'x', 'reverse'); + util.assert( + $x.rank === 2, + () => `Error in reverse2D: x must be rank 2 but got rank ${$x.rank}.`); + return reverse($x, axis); +} + +export const reverse2d = op({reverse2d_}); diff --git a/tfjs-core/src/ops/reverse_2d_test.ts b/tfjs-core/src/ops/reverse_2d_test.ts new file mode 100644 index 00000000000..30b69a2c164 --- /dev/null +++ b/tfjs-core/src/ops/reverse_2d_test.ts @@ -0,0 +1,102 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import * as tf from '../index'; +import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; +import {expectArraysClose} from '../test_util'; + +describeWithFlags('reverse2d', ALL_ENVS, () => { + it('reverse a 2D array at axis [0]', async () => { + const axis = [0]; + const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]); + const result = tf.reverse2d(a, axis); + + expect(result.shape).toEqual(a.shape); + expectArraysClose(await result.data(), [4, 5, 6, 1, 2, 3]); + }); + + it('reverse a 2D array at axis [1]', async () => { + const axis = [1]; + const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]); + const result = tf.reverse2d(a, axis); + + expect(result.shape).toEqual(a.shape); + expectArraysClose(await result.data(), [3, 2, 1, 6, 5, 4]); + }); + + it('reverse a 2D array odd rows and columns at axis [0, 1]', async () => { + const axis = [0, 1]; + const a = tf.tensor2d( + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], [3, 5]); + const result = tf.reverse2d(a, axis); + + expect(result.shape).toEqual(a.shape); + expectArraysClose( + await result.data(), + [15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]); + }); + + it('throws error with invalid input', () => { + // tslint:disable-next-line:no-any + const x: any = tf.tensor1d([1, 20, 300, 4]); + expect(() => tf.reverse2d(x, [0])).toThrowError(); + }); + + it('throws error with invalid axis param', () => { + const x = tf.tensor2d([1, 20, 300, 4], [1, 4]); + expect(() => tf.reverse2d(x, [2])).toThrowError(); + expect(() => tf.reverse2d(x, [-3])).toThrowError(); + }); + + it('throws error with non integer axis param', () => { + const x = tf.tensor2d([1, 20, 300, 4], [1, 4]); + expect(() => tf.reverse2d(x, [0.5])).toThrowError(); + }); + + it('grad', async () => { + const a = tf.tensor2d([[1, 2, 3], [4, 5, 6]]); + const dy = tf.tensor2d([[10, 20, 30], [40, 50, 60]]); + const da = tf.grad((a: tf.Tensor2D) => tf.reverse2d(a))(a, dy); + expect(da.shape).toEqual([2, 3]); + expectArraysClose(await da.data(), [60, 50, 40, 30, 20, 10]); + }); + + it('grad with reverse(axis=0)', async () => { + const a = tf.tensor2d([[1, 2, 3], [4, 5, 6]]); + const dy = tf.tensor2d([[10, 20, 30], [40, 50, 60]]); + const da = tf.grad((a: tf.Tensor2D) => tf.reverse2d(a, 0))(a, dy); + expect(da.shape).toEqual([2, 3]); + expectArraysClose(await da.data(), [40, 50, 60, 10, 20, 30]); + }); + + it('grad with reverse(axis=1)', async () => { + const a = tf.tensor2d([[1, 2, 3], [4, 5, 6]]); + const dy = tf.tensor2d([[10, 20, 30], [40, 50, 60]]); + const da = tf.grad((a: tf.Tensor2D) => tf.reverse2d(a, 1))(a, dy); + expect(da.shape).toEqual([2, 3]); + expectArraysClose(await da.data(), [30, 20, 10, 60, 50, 40]); + }); + + it('accepts a tensor-like object', async () => { + const axis = [0]; + const a = [[1, 2, 3], [4, 5, 6]]; // 2x3 + const result = tf.reverse2d(a, axis); + + expect(result.shape).toEqual([2, 3]); + expectArraysClose(await result.data(), [4, 5, 6, 1, 2, 3]); + }); +}); diff --git a/tfjs-core/src/ops/reverse_3d.ts b/tfjs-core/src/ops/reverse_3d.ts new file mode 100644 index 00000000000..9b6e4615513 --- /dev/null +++ b/tfjs-core/src/ops/reverse_3d.ts @@ -0,0 +1,40 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Tensor3D} from '../tensor'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; +import * as util from '../util'; +import {op} from './operation'; +import {reverse} from './reverse'; + +/** + * Reverses a `tf.Tensor3D` along a specified axis. + * + * @param x The input tensor. + * @param axis The set of dimensions to reverse. Must be in the + * range [-rank(x), rank(x)). Defaults to all axes. + */ +function reverse3d_(x: Tensor3D|TensorLike, axis?: number|number[]): Tensor3D { + const $x = convertToTensor(x, 'x', 'reverse'); + util.assert( + $x.rank === 3, + () => `Error in reverse3D: x must be rank 3 but got rank ${$x.rank}.`); + return reverse($x, axis); +} + +export const reverse3d = op({reverse3d_}); diff --git a/tfjs-core/src/ops/reverse_3d_test.ts b/tfjs-core/src/ops/reverse_3d_test.ts new file mode 100644 index 00000000000..0b6216347a4 --- /dev/null +++ b/tfjs-core/src/ops/reverse_3d_test.ts @@ -0,0 +1,124 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import * as tf from '../index'; +import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; +import {expectArraysClose} from '../test_util'; + +describeWithFlags('reverse3d', ALL_ENVS, () => { + // [ + // [ + // [0, 1, 2, 3], + // [4, 5, 6, 7], + // [8, 9, 10, 11] + // ], + // [ + // [12, 13, 14, 15], + // [16, 17, 18, 19], + // [20, 21, 22, 23] + // ] + // ] + const shape: [number, number, number] = [2, 3, 4]; + const data = [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23 + ]; + + it('reverse a 3D array at axis [0]', async () => { + const input = tf.tensor3d(data, shape); + const result = tf.reverse3d(input, [0]); + expect(result.shape).toEqual(input.shape); + expectArraysClose(await result.data(), [ + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11 + ]); + }); + + it('reverse a 3D array at axis [1]', async () => { + const input = tf.tensor3d(data, shape); + const result = tf.reverse3d(input, [1]); + expect(result.shape).toEqual(input.shape); + expectArraysClose(await result.data(), [ + 8, 9, 10, 11, 4, 5, 6, 7, 0, 1, 2, 3, + 20, 21, 22, 23, 16, 17, 18, 19, 12, 13, 14, 15 + ]); + }); + + it('reverse a 3D array at axis [2]', async () => { + const input = tf.tensor3d(data, shape); + const result = tf.reverse3d(input, [2]); + expect(result.shape).toEqual(input.shape); + expectArraysClose(await result.data(), [ + 3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8, + 15, 14, 13, 12, 19, 18, 17, 16, 23, 22, 21, 20 + ]); + }); + + it('reverse a 3D array at axis [0, 1]', async () => { + const input = tf.tensor3d(data, shape); + const result = tf.reverse3d(input, [0, 1]); + expect(result.shape).toEqual(input.shape); + expectArraysClose(await result.data(), [ + 20, 21, 22, 23, 16, 17, 18, 19, 12, 13, 14, 15, + 8, 9, 10, 11, 4, 5, 6, 7, 0, 1, 2, 3 + ]); + }); + + it('reverse a 3D array at axis [0, 2]', async () => { + const input = tf.tensor3d(data, shape); + const result = tf.reverse3d(input, [0, 2]); + expect(result.shape).toEqual(input.shape); + expectArraysClose(await result.data(), [ + 15, 14, 13, 12, 19, 18, 17, 16, 23, 22, 21, 20, + 3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8 + ]); + }); + + it('reverse a 3D array at axis [1, 2]', async () => { + const input = tf.tensor3d(data, shape); + const result = tf.reverse3d(input, [1, 2]); + expect(result.shape).toEqual(input.shape); + expectArraysClose(await result.data(), [ + 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, + 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12 + ]); + }); + + it('throws error with invalid input', () => { + // tslint:disable-next-line:no-any + const x: any = tf.tensor2d([1, 20, 300, 4], [1, 4]); + expect(() => tf.reverse3d(x, [1])).toThrowError(); + }); + + it('throws error with invalid axis param', () => { + const x = tf.tensor3d([1, 20, 300, 4], [1, 1, 4]); + expect(() => tf.reverse3d(x, [3])).toThrowError(); + expect(() => tf.reverse3d(x, [-4])).toThrowError(); + }); + + it('throws error with non integer axis param', () => { + const x = tf.tensor3d([1, 20, 300, 4], [1, 1, 4]); + expect(() => tf.reverse3d(x, [0.5])).toThrowError(); + }); + + it('accepts a tensor-like object', async () => { + const input = [[[1], [2], [3]], [[4], [5], [6]]]; // 2x3x1 + const result = tf.reverse3d(input, [0]); + expect(result.shape).toEqual([2, 3, 1]); + expectArraysClose(await result.data(), [4, 5, 6, 1, 2, 3]); + }); +}); diff --git a/tfjs-core/src/ops/reverse_4d.ts b/tfjs-core/src/ops/reverse_4d.ts new file mode 100644 index 00000000000..72554eb5327 --- /dev/null +++ b/tfjs-core/src/ops/reverse_4d.ts @@ -0,0 +1,40 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Tensor4D} from '../tensor'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; +import * as util from '../util'; +import {op} from './operation'; +import {reverse} from './reverse'; + +/** + * Reverses a `tf.Tensor4D` along a specified axis. + * + * @param x The input tensor. + * @param axis The set of dimensions to reverse. Must be in the + * range [-rank(x), rank(x)). Defaults to all axes. + */ +function reverse4d_(x: Tensor4D|TensorLike, axis?: number|number[]): Tensor4D { + const $x = convertToTensor(x, 'x', 'reverse'); + util.assert( + $x.rank === 4, + () => `Error in reverse4D: x must be rank 4 but got rank ${$x.rank}.`); + return reverse($x, axis); +} + +export const reverse4d = op({reverse4d_}); diff --git a/tfjs-core/src/ops/reverse_4d_test.ts b/tfjs-core/src/ops/reverse_4d_test.ts new file mode 100644 index 00000000000..7e46cc4767e --- /dev/null +++ b/tfjs-core/src/ops/reverse_4d_test.ts @@ -0,0 +1,164 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import * as tf from '../index'; +import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; +import {expectArraysClose} from '../test_util'; + +describeWithFlags('reverse4d', ALL_ENVS, () => { + // [ + // [ + // [ + // [0, 1, 2, 3], + // [4, 5, 6, 7], + // [8, 9, 10, 11] + // ], + // [ + // [12, 13, 14, 15], + // [16, 17, 18, 19], + // [20, 21, 22, 23] + // ] + // ], + // [ + // [ + // [24, 25, 26, 27], + // [28, 29, 30, 31], + // [32, 33, 34, 35] + // ], + // [ + // [36, 37, 38, 39], + // [40, 41, 42, 43], + // [44, 45, 46, 47] + // ] + // ], + // [ + // [ + // [48, 49, 50, 51], + // [52, 53, 54, 55], + // [56, 57, 58, 59] + // ], + // [ + // [60, 61, 62, 63], + // [64, 65, 66, 67], + // [68, 69, 70, 71] + // ] + // ] + // ] + const shape: [number, number, number, number] = [3, 2, 3, 4]; + const data = [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, + 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, + 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, + 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71 + ]; + + it('reverse a 4D array at axis [0]', async () => { + const input = tf.tensor4d(data, shape); + const result = tf.reverse4d(input, [0]); + expect(result.shape).toEqual(input.shape); + expectArraysClose(await result.data(), [ + 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, + 66, 67, 68, 69, 70, 71, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, + 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 0, 1, 2, 3, 4, 5, + 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23 + ]); + }); + + it('reverse a 4D array at axis [1]', async () => { + const input = tf.tensor4d(data, shape); + const result = tf.reverse4d(input, [1]); + expect(result.shape).toEqual(input.shape); + expectArraysClose(await result.data(), [ + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 0, 1, 2, 3, 4, 5, + 6, 7, 8, 9, 10, 11, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, + 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 60, 61, 62, 63, 64, 65, + 66, 67, 68, 69, 70, 71, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59 + ]); + }); + + it('reverse a 4D array at axis [2]', async () => { + const input = tf.tensor4d(data, shape); + const result = tf.reverse4d(input, [2]); + expect(result.shape).toEqual(input.shape); + expectArraysClose(await result.data(), [ + 8, 9, 10, 11, 4, 5, 6, 7, 0, 1, 2, 3, 20, 21, 22, 23, 16, 17, + 18, 19, 12, 13, 14, 15, 32, 33, 34, 35, 28, 29, 30, 31, 24, 25, 26, 27, + 44, 45, 46, 47, 40, 41, 42, 43, 36, 37, 38, 39, 56, 57, 58, 59, 52, 53, + 54, 55, 48, 49, 50, 51, 68, 69, 70, 71, 64, 65, 66, 67, 60, 61, 62, 63 + ]); + }); + + it('reverse a 4D array at axis [3]', async () => { + const input = tf.tensor4d(data, shape); + const result = tf.reverse4d(input, [3]); + expect(result.shape).toEqual(input.shape); + expectArraysClose(await result.data(), [ + 3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8, 15, 14, 13, 12, 19, 18, + 17, 16, 23, 22, 21, 20, 27, 26, 25, 24, 31, 30, 29, 28, 35, 34, 33, 32, + 39, 38, 37, 36, 43, 42, 41, 40, 47, 46, 45, 44, 51, 50, 49, 48, 55, 54, + 53, 52, 59, 58, 57, 56, 63, 62, 61, 60, 67, 66, 65, 64, 71, 70, 69, 68 + ]); + }); + + it('reverse a 4D array at axis [0, 2]', async () => { + const input = tf.tensor4d(data, shape); + const result = tf.reverse4d(input, [0, 2]); + expect(result.shape).toEqual(input.shape); + expectArraysClose(await result.data(), [ + 56, 57, 58, 59, 52, 53, 54, 55, 48, 49, 50, 51, 68, 69, 70, 71, 64, 65, + 66, 67, 60, 61, 62, 63, 32, 33, 34, 35, 28, 29, 30, 31, 24, 25, 26, 27, + 44, 45, 46, 47, 40, 41, 42, 43, 36, 37, 38, 39, 8, 9, 10, 11, 4, 5, + 6, 7, 0, 1, 2, 3, 20, 21, 22, 23, 16, 17, 18, 19, 12, 13, 14, 15 + ]); + }); + + it('reverse a 4D array at axis [1, 3]', async () => { + const input = tf.tensor4d(data, shape); + const result = tf.reverse4d(input, [1, 3]); + expect(result.shape).toEqual(input.shape); + expectArraysClose(await result.data(), [ + 15, 14, 13, 12, 19, 18, 17, 16, 23, 22, 21, 20, 3, 2, 1, 0, 7, 6, + 5, 4, 11, 10, 9, 8, 39, 38, 37, 36, 43, 42, 41, 40, 47, 46, 45, 44, + 27, 26, 25, 24, 31, 30, 29, 28, 35, 34, 33, 32, 63, 62, 61, 60, 67, 66, + 65, 64, 71, 70, 69, 68, 51, 50, 49, 48, 55, 54, 53, 52, 59, 58, 57, 56 + ]); + }); + + it('throws error with invalid input', () => { + // tslint:disable-next-line:no-any + const x: any = tf.tensor3d([1, 20, 300, 4], [1, 1, 4]); + expect(() => tf.reverse4d(x, [1])).toThrowError(); + }); + + it('throws error with invalid axis param', () => { + const x = tf.tensor4d([1, 20, 300, 4], [1, 1, 1, 4]); + expect(() => tf.reverse4d(x, [4])).toThrowError(); + expect(() => tf.reverse4d(x, [-5])).toThrowError(); + }); + + it('throws error with non integer axis param', () => { + const x = tf.tensor4d([1, 20, 300, 4], [1, 1, 1, 4]); + expect(() => tf.reverse4d(x, [0.5])).toThrowError(); + }); + + it('accepts a tensor-like object', async () => { + const input = [[[[1]], [[2]], [[3]]], [[[4]], [[5]], [[6]]]]; // 2x3x1x1 + const result = tf.reverse4d(input, [0]); + expect(result.shape).toEqual([2, 3, 1, 1]); + expectArraysClose(await result.data(), [4, 5, 6, 1, 2, 3]); + }); +}); diff --git a/tfjs-core/src/ops/reverse_test.ts b/tfjs-core/src/ops/reverse_test.ts index 3ebf3351f49..ab377449a22 100644 --- a/tfjs-core/src/ops/reverse_test.ts +++ b/tfjs-core/src/ops/reverse_test.ts @@ -19,376 +19,6 @@ import * as tf from '../index'; import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; import {expectArraysClose} from '../test_util'; -describeWithFlags('reverse1d', ALL_ENVS, () => { - it('reverse a 1D array', async () => { - const input = tf.tensor1d([1, 2, 3, 4, 5]); - const result = tf.reverse1d(input); - expect(result.shape).toEqual(input.shape); - expectArraysClose(await result.data(), [5, 4, 3, 2, 1]); - }); - - it('reverse a 1D array, even length', async () => { - const input = tf.tensor1d([1, 2, 3, 4]); - const result = tf.reverse1d(input); - expect(result.shape).toEqual(input.shape); - expectArraysClose(await result.data(), [4, 3, 2, 1]); - }); - - it('grad', async () => { - const a = tf.tensor1d([1, 2, 3]); - const dy = tf.tensor1d([10, 20, 30]); - const da = tf.grad((a: tf.Tensor1D) => tf.reverse1d(a))(a, dy); - expect(da.shape).toEqual([3]); - expectArraysClose(await da.data(), [30, 20, 10]); - }); - - it('gradient with clones', async () => { - const a = tf.tensor1d([1, 2, 3]); - const dy = tf.tensor1d([10, 20, 30]); - const da = - tf.grad((a: tf.Tensor1D) => tf.reverse1d(a.clone()).clone())(a, dy); - expect(da.shape).toEqual([3]); - expectArraysClose(await da.data(), [30, 20, 10]); - }); - - it('accepts a tensor-like object', async () => { - const input = [1, 2, 3, 4, 5]; - const result = tf.reverse1d(input); - expect(result.shape).toEqual([5]); - expectArraysClose(await result.data(), [5, 4, 3, 2, 1]); - }); -}); - -describeWithFlags('reverse2d', ALL_ENVS, () => { - it('reverse a 2D array at axis [0]', async () => { - const axis = [0]; - const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]); - const result = tf.reverse2d(a, axis); - - expect(result.shape).toEqual(a.shape); - expectArraysClose(await result.data(), [4, 5, 6, 1, 2, 3]); - }); - - it('reverse a 2D array at axis [1]', async () => { - const axis = [1]; - const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]); - const result = tf.reverse2d(a, axis); - - expect(result.shape).toEqual(a.shape); - expectArraysClose(await result.data(), [3, 2, 1, 6, 5, 4]); - }); - - it('reverse a 2D array odd rows and columns at axis [0, 1]', async () => { - const axis = [0, 1]; - const a = tf.tensor2d( - [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], [3, 5]); - const result = tf.reverse2d(a, axis); - - expect(result.shape).toEqual(a.shape); - expectArraysClose( - await result.data(), - [15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]); - }); - - it('throws error with invalid input', () => { - // tslint:disable-next-line:no-any - const x: any = tf.tensor1d([1, 20, 300, 4]); - expect(() => tf.reverse2d(x, [0])).toThrowError(); - }); - - it('throws error with invalid axis param', () => { - const x = tf.tensor2d([1, 20, 300, 4], [1, 4]); - expect(() => tf.reverse2d(x, [2])).toThrowError(); - expect(() => tf.reverse2d(x, [-3])).toThrowError(); - }); - - it('throws error with non integer axis param', () => { - const x = tf.tensor2d([1, 20, 300, 4], [1, 4]); - expect(() => tf.reverse2d(x, [0.5])).toThrowError(); - }); - - it('grad', async () => { - const a = tf.tensor2d([[1, 2, 3], [4, 5, 6]]); - const dy = tf.tensor2d([[10, 20, 30], [40, 50, 60]]); - const da = tf.grad((a: tf.Tensor2D) => tf.reverse2d(a))(a, dy); - expect(da.shape).toEqual([2, 3]); - expectArraysClose(await da.data(), [60, 50, 40, 30, 20, 10]); - }); - - it('grad with reverse(axis=0)', async () => { - const a = tf.tensor2d([[1, 2, 3], [4, 5, 6]]); - const dy = tf.tensor2d([[10, 20, 30], [40, 50, 60]]); - const da = tf.grad((a: tf.Tensor2D) => tf.reverse2d(a, 0))(a, dy); - expect(da.shape).toEqual([2, 3]); - expectArraysClose(await da.data(), [40, 50, 60, 10, 20, 30]); - }); - - it('grad with reverse(axis=1)', async () => { - const a = tf.tensor2d([[1, 2, 3], [4, 5, 6]]); - const dy = tf.tensor2d([[10, 20, 30], [40, 50, 60]]); - const da = tf.grad((a: tf.Tensor2D) => tf.reverse2d(a, 1))(a, dy); - expect(da.shape).toEqual([2, 3]); - expectArraysClose(await da.data(), [30, 20, 10, 60, 50, 40]); - }); - - it('accepts a tensor-like object', async () => { - const axis = [0]; - const a = [[1, 2, 3], [4, 5, 6]]; // 2x3 - const result = tf.reverse2d(a, axis); - - expect(result.shape).toEqual([2, 3]); - expectArraysClose(await result.data(), [4, 5, 6, 1, 2, 3]); - }); -}); - -describeWithFlags('reverse3d', ALL_ENVS, () => { - // [ - // [ - // [0, 1, 2, 3], - // [4, 5, 6, 7], - // [8, 9, 10, 11] - // ], - // [ - // [12, 13, 14, 15], - // [16, 17, 18, 19], - // [20, 21, 22, 23] - // ] - // ] - const shape: [number, number, number] = [2, 3, 4]; - const data = [ - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, - 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23 - ]; - - it('reverse a 3D array at axis [0]', async () => { - const input = tf.tensor3d(data, shape); - const result = tf.reverse3d(input, [0]); - expect(result.shape).toEqual(input.shape); - expectArraysClose(await result.data(), [ - 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11 - ]); - }); - - it('reverse a 3D array at axis [1]', async () => { - const input = tf.tensor3d(data, shape); - const result = tf.reverse3d(input, [1]); - expect(result.shape).toEqual(input.shape); - expectArraysClose(await result.data(), [ - 8, 9, 10, 11, 4, 5, 6, 7, 0, 1, 2, 3, - 20, 21, 22, 23, 16, 17, 18, 19, 12, 13, 14, 15 - ]); - }); - - it('reverse a 3D array at axis [2]', async () => { - const input = tf.tensor3d(data, shape); - const result = tf.reverse3d(input, [2]); - expect(result.shape).toEqual(input.shape); - expectArraysClose(await result.data(), [ - 3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8, - 15, 14, 13, 12, 19, 18, 17, 16, 23, 22, 21, 20 - ]); - }); - - it('reverse a 3D array at axis [0, 1]', async () => { - const input = tf.tensor3d(data, shape); - const result = tf.reverse3d(input, [0, 1]); - expect(result.shape).toEqual(input.shape); - expectArraysClose(await result.data(), [ - 20, 21, 22, 23, 16, 17, 18, 19, 12, 13, 14, 15, - 8, 9, 10, 11, 4, 5, 6, 7, 0, 1, 2, 3 - ]); - }); - - it('reverse a 3D array at axis [0, 2]', async () => { - const input = tf.tensor3d(data, shape); - const result = tf.reverse3d(input, [0, 2]); - expect(result.shape).toEqual(input.shape); - expectArraysClose(await result.data(), [ - 15, 14, 13, 12, 19, 18, 17, 16, 23, 22, 21, 20, - 3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8 - ]); - }); - - it('reverse a 3D array at axis [1, 2]', async () => { - const input = tf.tensor3d(data, shape); - const result = tf.reverse3d(input, [1, 2]); - expect(result.shape).toEqual(input.shape); - expectArraysClose(await result.data(), [ - 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, - 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12 - ]); - }); - - it('throws error with invalid input', () => { - // tslint:disable-next-line:no-any - const x: any = tf.tensor2d([1, 20, 300, 4], [1, 4]); - expect(() => tf.reverse3d(x, [1])).toThrowError(); - }); - - it('throws error with invalid axis param', () => { - const x = tf.tensor3d([1, 20, 300, 4], [1, 1, 4]); - expect(() => tf.reverse3d(x, [3])).toThrowError(); - expect(() => tf.reverse3d(x, [-4])).toThrowError(); - }); - - it('throws error with non integer axis param', () => { - const x = tf.tensor3d([1, 20, 300, 4], [1, 1, 4]); - expect(() => tf.reverse3d(x, [0.5])).toThrowError(); - }); - - it('accepts a tensor-like object', async () => { - const input = [[[1], [2], [3]], [[4], [5], [6]]]; // 2x3x1 - const result = tf.reverse3d(input, [0]); - expect(result.shape).toEqual([2, 3, 1]); - expectArraysClose(await result.data(), [4, 5, 6, 1, 2, 3]); - }); -}); - -describeWithFlags('reverse4d', ALL_ENVS, () => { - // [ - // [ - // [ - // [0, 1, 2, 3], - // [4, 5, 6, 7], - // [8, 9, 10, 11] - // ], - // [ - // [12, 13, 14, 15], - // [16, 17, 18, 19], - // [20, 21, 22, 23] - // ] - // ], - // [ - // [ - // [24, 25, 26, 27], - // [28, 29, 30, 31], - // [32, 33, 34, 35] - // ], - // [ - // [36, 37, 38, 39], - // [40, 41, 42, 43], - // [44, 45, 46, 47] - // ] - // ], - // [ - // [ - // [48, 49, 50, 51], - // [52, 53, 54, 55], - // [56, 57, 58, 59] - // ], - // [ - // [60, 61, 62, 63], - // [64, 65, 66, 67], - // [68, 69, 70, 71] - // ] - // ] - // ] - const shape: [number, number, number, number] = [3, 2, 3, 4]; - const data = [ - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, - 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, - 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, - 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71 - ]; - - it('reverse a 4D array at axis [0]', async () => { - const input = tf.tensor4d(data, shape); - const result = tf.reverse4d(input, [0]); - expect(result.shape).toEqual(input.shape); - expectArraysClose(await result.data(), [ - 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, - 66, 67, 68, 69, 70, 71, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, - 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 0, 1, 2, 3, 4, 5, - 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23 - ]); - }); - - it('reverse a 4D array at axis [1]', async () => { - const input = tf.tensor4d(data, shape); - const result = tf.reverse4d(input, [1]); - expect(result.shape).toEqual(input.shape); - expectArraysClose(await result.data(), [ - 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 0, 1, 2, 3, 4, 5, - 6, 7, 8, 9, 10, 11, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, - 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 60, 61, 62, 63, 64, 65, - 66, 67, 68, 69, 70, 71, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59 - ]); - }); - - it('reverse a 4D array at axis [2]', async () => { - const input = tf.tensor4d(data, shape); - const result = tf.reverse4d(input, [2]); - expect(result.shape).toEqual(input.shape); - expectArraysClose(await result.data(), [ - 8, 9, 10, 11, 4, 5, 6, 7, 0, 1, 2, 3, 20, 21, 22, 23, 16, 17, - 18, 19, 12, 13, 14, 15, 32, 33, 34, 35, 28, 29, 30, 31, 24, 25, 26, 27, - 44, 45, 46, 47, 40, 41, 42, 43, 36, 37, 38, 39, 56, 57, 58, 59, 52, 53, - 54, 55, 48, 49, 50, 51, 68, 69, 70, 71, 64, 65, 66, 67, 60, 61, 62, 63 - ]); - }); - - it('reverse a 4D array at axis [3]', async () => { - const input = tf.tensor4d(data, shape); - const result = tf.reverse4d(input, [3]); - expect(result.shape).toEqual(input.shape); - expectArraysClose(await result.data(), [ - 3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8, 15, 14, 13, 12, 19, 18, - 17, 16, 23, 22, 21, 20, 27, 26, 25, 24, 31, 30, 29, 28, 35, 34, 33, 32, - 39, 38, 37, 36, 43, 42, 41, 40, 47, 46, 45, 44, 51, 50, 49, 48, 55, 54, - 53, 52, 59, 58, 57, 56, 63, 62, 61, 60, 67, 66, 65, 64, 71, 70, 69, 68 - ]); - }); - - it('reverse a 4D array at axis [0, 2]', async () => { - const input = tf.tensor4d(data, shape); - const result = tf.reverse4d(input, [0, 2]); - expect(result.shape).toEqual(input.shape); - expectArraysClose(await result.data(), [ - 56, 57, 58, 59, 52, 53, 54, 55, 48, 49, 50, 51, 68, 69, 70, 71, 64, 65, - 66, 67, 60, 61, 62, 63, 32, 33, 34, 35, 28, 29, 30, 31, 24, 25, 26, 27, - 44, 45, 46, 47, 40, 41, 42, 43, 36, 37, 38, 39, 8, 9, 10, 11, 4, 5, - 6, 7, 0, 1, 2, 3, 20, 21, 22, 23, 16, 17, 18, 19, 12, 13, 14, 15 - ]); - }); - - it('reverse a 4D array at axis [1, 3]', async () => { - const input = tf.tensor4d(data, shape); - const result = tf.reverse4d(input, [1, 3]); - expect(result.shape).toEqual(input.shape); - expectArraysClose(await result.data(), [ - 15, 14, 13, 12, 19, 18, 17, 16, 23, 22, 21, 20, 3, 2, 1, 0, 7, 6, - 5, 4, 11, 10, 9, 8, 39, 38, 37, 36, 43, 42, 41, 40, 47, 46, 45, 44, - 27, 26, 25, 24, 31, 30, 29, 28, 35, 34, 33, 32, 63, 62, 61, 60, 67, 66, - 65, 64, 71, 70, 69, 68, 51, 50, 49, 48, 55, 54, 53, 52, 59, 58, 57, 56 - ]); - }); - - it('throws error with invalid input', () => { - // tslint:disable-next-line:no-any - const x: any = tf.tensor3d([1, 20, 300, 4], [1, 1, 4]); - expect(() => tf.reverse4d(x, [1])).toThrowError(); - }); - - it('throws error with invalid axis param', () => { - const x = tf.tensor4d([1, 20, 300, 4], [1, 1, 1, 4]); - expect(() => tf.reverse4d(x, [4])).toThrowError(); - expect(() => tf.reverse4d(x, [-5])).toThrowError(); - }); - - it('throws error with non integer axis param', () => { - const x = tf.tensor4d([1, 20, 300, 4], [1, 1, 1, 4]); - expect(() => tf.reverse4d(x, [0.5])).toThrowError(); - }); - - it('accepts a tensor-like object', async () => { - const input = [[[[1]], [[2]], [[3]]], [[[4]], [[5]], [[6]]]]; // 2x3x1x1 - const result = tf.reverse4d(input, [0]); - expect(result.shape).toEqual([2, 3, 1, 1]); - expectArraysClose(await result.data(), [4, 5, 6, 1, 2, 3]); - }); -}); - describeWithFlags('reverse', ALL_ENVS, () => { it('throws when passed a non-tensor', () => { expect(() => tf.reverse({} as tf.Tensor)) diff --git a/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts b/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts index 5c71468725f..3267fbc4f58 100644 --- a/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts +++ b/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts @@ -62,6 +62,7 @@ import './relu'; import './resize_bilinear'; import './resize_nearest_neighbor'; import './relu6'; +import './reverse'; import './selu'; import './separable_conv2d'; import './split'; diff --git a/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts b/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts index a35227e811c..33a984536ff 100644 --- a/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts +++ b/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts @@ -72,6 +72,7 @@ const CHAINED_OPS = [ 'resizeBilinear', 'resizeNearestNeighbor', 'relu6', + 'reverse', 'selu', 'separableConv2d', 'spaceToBatchND', diff --git a/tfjs-core/src/public/chained_ops/reverse.ts b/tfjs-core/src/public/chained_ops/reverse.ts new file mode 100644 index 00000000000..d81fa1acd68 --- /dev/null +++ b/tfjs-core/src/public/chained_ops/reverse.ts @@ -0,0 +1,31 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {reverse} from '../../ops/reverse'; +import {Tensor} from '../../tensor'; +import {Rank} from '../../types'; + +declare module '../../tensor' { + interface Tensor { + reverse(this: T, axis?: number|number[]): T; + } +} + +Tensor.prototype.reverse = function( + this: T, axis?: number|number[]): T { + this.throwIfDisposed(); + return reverse(this, axis); +}; diff --git a/tfjs-core/src/register_all_gradients.ts b/tfjs-core/src/register_all_gradients.ts index 563fa8d81d7..0bbf7c0d7ee 100644 --- a/tfjs-core/src/register_all_gradients.ts +++ b/tfjs-core/src/register_all_gradients.ts @@ -51,6 +51,7 @@ import {relu6GradConfig} from './gradients/Relu6_grad'; import {reluGradConfig} from './gradients/Relu_grad'; import {resizeBilinearGradConfig} from './gradients/ResizeBilinear_grad'; import {resizeNearestNeighborGradConfig} from './gradients/ResizeNearestNeighbor_grad'; +import {reverseGradConfig} from './gradients/Reverse_grad'; import {selectV2PoolGradConfig} from './gradients/SelectV2_grad'; import {seluGradConfig} from './gradients/Selu_grad'; import {spaceToBatchNDGradConfig} from './gradients/SpaceToBatchND_grad'; @@ -107,6 +108,7 @@ const gradConfigs: GradConfig[] = [ resizeBilinearGradConfig, resizeNearestNeighborGradConfig, relu6GradConfig, + reverseGradConfig, seluGradConfig, selectV2PoolGradConfig, spaceToBatchNDGradConfig, diff --git a/tfjs-core/src/tensor.ts b/tfjs-core/src/tensor.ts index d0231ac3bf9..5dfe6e719d8 100644 --- a/tfjs-core/src/tensor.ts +++ b/tfjs-core/src/tensor.ts @@ -184,7 +184,6 @@ export interface OpHandler { keepDims: boolean): Tensor; slice>( x: T, begin: number|number[], size?: number|number[]): T; - reverse(x: T, axis?: number|number[]): T; stack(tensors: Array, axis: number): Tensor; unstack(value: T, axis: number): Tensor[]; all(x: Tensor, axis: number|number[], keepDims: boolean): T; @@ -678,10 +677,6 @@ export class Tensor { this.throwIfDisposed(); return opHandler.slice(this, begin, size); } - reverse(this: T, axis?: number|number[]): T { - this.throwIfDisposed(); - return opHandler.reverse(this, axis); - } stack(x: Tensor, axis = 0): Tensor { return opHandler.stack([this, x], axis); } diff --git a/tfjs-core/src/tests.ts b/tfjs-core/src/tests.ts index 3cc1ebe433a..e9a68fe0ac4 100644 --- a/tfjs-core/src/tests.ts +++ b/tfjs-core/src/tests.ts @@ -122,6 +122,10 @@ import './ops/relu6_test'; import './ops/relu_test'; import './ops/resize_bilinear_test'; import './ops/resize_nearest_neighbor_test'; +import './ops/reverse_1d_test'; +import './ops/reverse_2d_test'; +import './ops/reverse_3d_test'; +import './ops/reverse_4d_test'; import './ops/reverse_test'; import './ops/scatter_nd_test'; import './ops/segment_ops_test';