This repository has been archived by the owner on Aug 15, 2019. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 952
/
reverse.ts
133 lines (125 loc) · 4.09 KB
/
reverse.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
/**
* @license
* Copyright 2018 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import {ENV} from '../environment';
import {Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D} from '../tensor';
import {convertToTensor} from '../tensor_util_env';
import {TensorLike} from '../types';
import * as util from '../util';
import {op} from './operation';
/**
 * Reverses the elements of a rank-1 `tf.Tensor1D`.
 *
 * This is a strict variant of `tf.reverse`: it first asserts that the input
 * is rank 1 and then reverses it along its only axis (axis 0).
 *
 * @param x The input tensor (or tensor-like value) to reverse.
 * @returns A `tf.Tensor1D` holding the elements of `x` in reverse order.
 */
function reverse1d_(x: Tensor1D|TensorLike): Tensor1D {
  const input = convertToTensor(x, 'x', 'reverse');
  const rank = input.rank;
  util.assert(
      rank === 1,
      () => `Error in reverse1D: x must be rank 1 but got rank ${rank}.`);
  // A 1D tensor has exactly one axis, so no axis argument is exposed.
  return reverse(input, 0);
}
/**
 * Reverses a rank-2 `tf.Tensor2D` along one or more axes.
 *
 * This is a strict variant of `tf.reverse`: it asserts that the input is
 * rank 2 before delegating.
 *
 * @param x The input tensor (or tensor-like value) to reverse.
 * @param axis The dimension(s) to reverse. Each must lie in the range
 *     [-rank(x), rank(x)). When omitted, every axis is reversed.
 * @returns The reversed `tf.Tensor2D`.
 */
function reverse2d_(x: Tensor2D|TensorLike, axis?: number|number[]): Tensor2D {
  const input = convertToTensor(x, 'x', 'reverse');
  const isRank2 = input.rank === 2;
  util.assert(
      isRank2,
      () => `Error in reverse2D: x must be rank 2 but got rank ${input.rank}.`);
  return reverse(input, axis);
}
/**
 * Reverses a rank-3 `tf.Tensor3D` along one or more axes.
 *
 * This is a strict variant of `tf.reverse`: it asserts that the input is
 * rank 3 before delegating.
 *
 * @param x The input tensor (or tensor-like value) to reverse.
 * @param axis The dimension(s) to reverse. Each must lie in the range
 *     [-rank(x), rank(x)). When omitted, every axis is reversed.
 * @returns The reversed `tf.Tensor3D`.
 */
function reverse3d_(x: Tensor3D|TensorLike, axis?: number|number[]): Tensor3D {
  const input = convertToTensor(x, 'x', 'reverse');
  const isRank3 = input.rank === 3;
  util.assert(
      isRank3,
      () => `Error in reverse3D: x must be rank 3 but got rank ${input.rank}.`);
  return reverse(input, axis);
}
/**
 * Reverses a rank-4 `tf.Tensor4D` along one or more axes.
 *
 * This is a strict variant of `tf.reverse`: it asserts that the input is
 * rank 4 before delegating.
 *
 * @param x The input tensor (or tensor-like value) to reverse.
 * @param axis The dimension(s) to reverse. Each must lie in the range
 *     [-rank(x), rank(x)). When omitted, every axis is reversed.
 * @returns The reversed `tf.Tensor4D`.
 */
function reverse4d_(x: Tensor4D|TensorLike, axis?: number|number[]): Tensor4D {
  const input = convertToTensor(x, 'x', 'reverse');
  const isRank4 = input.rank === 4;
  util.assert(
      isRank4,
      () => `Error in reverse4D: x must be rank 4 but got rank ${input.rank}.`);
  return reverse(input, axis);
}
/**
 * Reverses a `tf.Tensor` along one or more axes.
 *
 * Stricter rank-specific variants are also available. Each asserts that `x`
 * has the matching rank:
 * - `tf.reverse1d`
 * - `tf.reverse2d`
 * - `tf.reverse3d`
 * - `tf.reverse4d`
 *
 * Apart from `tf.reverse1d`, which takes no axis parameter, they all share
 * the same signature as this method.
 *
 * ```js
 * const x = tf.tensor1d([1, 2, 3, 4]);
 *
 * x.reverse().print();
 * ```
 *
 * ```js
 * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
 *
 * const axis = 1;
 * x.reverse(axis).print();
 * ```
 * @param x The input tensor to be reversed.
 * @param axis The dimension(s) to reverse. Each must lie in the range
 *     [-rank(x), rank(x)). When omitted, every axis is reversed.
 */
/** @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */
function reverse_<T extends Tensor>(
    x: T|TensorLike, axis?: number|number[]): T {
  const input = convertToTensor(x, 'x', 'reverse');

  // A scalar has no axes to reverse, so a copy of the input is returned.
  if (input.rank === 0) {
    return input.clone();
  }

  const axes = util.parseAxisParam(axis, input.shape);

  // The gradient w.r.t. the input is the upstream gradient reversed along the
  // same axes. The key `$x` must match the key in the kernel inputs map below.
  const grad = (dy: T) => ({$x: () => dy.reverse(axes)});

  const reversed = ENV.engine.runKernel(
      backend => backend.reverse(input, axes), {$x: input}, grad);
  return reversed.reshapeAs(input);
}
// Public ops. Each one wraps an underscore-suffixed implementation defined
// above by passing it to `op` as a shorthand-property object literal.
// NOTE(review): `op` presumably derives the op's name from the object key
// (e.g. `reverse_`), so keep the shorthand form. Confirm in ./operation.
export const reverse = op({reverse_});
export const reverse1d = op({reverse1d_});
export const reverse2d = op({reverse2d_});
export const reverse3d = op({reverse3d_});
export const reverse4d = op({reverse4d_});