Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions tfjs-core/src/gradients/GreaterEqual_grad.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/**
* @license
* Copyright 2020 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import {GreaterEqual} from '../kernel_names';
import {GradConfig} from '../kernel_registry';
import {zerosLike} from '../ops/tensor_ops';
import {Tensor} from '../tensor';

export const greaterEqualGradConfig: GradConfig = {
kernelName: GreaterEqual,
inputsToSave: ['a', 'b'],
gradFunc: (dy: Tensor, saved: Tensor[]) => {
const [a, b] = saved;
return {a: () => zerosLike(a), b: () => zerosLike(b)};
}
};
3 changes: 3 additions & 0 deletions tfjs-core/src/kernel_names.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ export interface FusedBatchNormAttrs {
export const Greater = 'Greater';
export type GreaterInputs = BinaryInputs;

export const GreaterEqual = 'GreaterEqual';
export type GreaterEqualInputs = BinaryInputs;

export const Identity = 'Identity';
export type IdentityInputs = Pick<NamedTensorInfoMap, 'x'>;

Expand Down
41 changes: 0 additions & 41 deletions tfjs-core/src/ops/compare.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,11 @@
* limitations under the License.
* =============================================================================
*/

import {ENGINE} from '../engine';
import {Tensor} from '../tensor';
import {makeTypesMatch} from '../tensor_util';
import {convertToTensor} from '../tensor_util_env';
import {TensorLike} from '../types';
import {assertShapesMatch} from '../util';
import {assertAndGetBroadcastShape} from './broadcast_util';
import {op} from './operation';
import {zerosLike} from './tensor_ops';

/**
* Strict version of `tf.notEqual` that forces `a` and `b` to be of the same
Expand Down Expand Up @@ -78,41 +73,6 @@ function greaterStrict_<T extends Tensor>(a: T|TensorLike, b: T|TensorLike): T {
return $a.greater($b);
}

/**
* Returns the truth value of (a >= b) element-wise. Supports broadcasting.
*
* We also expose `tf.greaterEqualStrict` which has the same signature as this
* op and asserts that `a` and `b` are the same shape (does not broadcast).
*
* ```js
* const a = tf.tensor1d([1, 2, 3]);
* const b = tf.tensor1d([2, 2, 2]);
*
* a.greaterEqual(b).print();
* ```
*
* @param a The first input tensor.
* @param b The second input tensor. Must have the same dtype as `a`.
*/
/** @doc {heading: 'Operations', subheading: 'Logical'} */
function greaterEqual_<T extends Tensor>(
a: Tensor|TensorLike, b: Tensor|TensorLike): T {
let $a = convertToTensor(a, 'a', 'greaterEqual');
let $b = convertToTensor(b, 'b', 'greaterEqual');
[$a, $b] = makeTypesMatch($a, $b);
assertAndGetBroadcastShape($a.shape, $b.shape);

const grad = (dy: T, saved: Tensor[]) => {
const [$a, $b] = saved;
return {a: () => zerosLike($a), b: () => zerosLike($b)};
};
return ENGINE.runKernelFunc((backend, save) => {
const res = backend.greaterEqual($a, $b);
save([$a, $b]);
return res;
}, {a: $a, b: $b}, grad, 'GreaterEqual') as T;
}

function greaterEqualStrict_<T extends Tensor>(
a: T|TensorLike, b: T|TensorLike): T {
const $a = convertToTensor(a, 'a', 'greaterEqualStrict');
Expand All @@ -122,7 +82,6 @@ function greaterEqualStrict_<T extends Tensor>(
}

export const equalStrict = op({equalStrict_});
export const greaterEqual = op({greaterEqual_});
export const greaterEqualStrict = op({greaterEqualStrict_});
export const greaterStrict = op({greaterStrict_});
export const lessEqualStrict = op({lessEqualStrict_});
Expand Down
66 changes: 66 additions & 0 deletions tfjs-core/src/ops/greater_equal.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/**
* @license
* Copyright 2020 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import {ENGINE, ForwardFunc} from '../engine';
import {GreaterEqual, GreaterEqualInputs} from '../kernel_names';
import {Tensor} from '../tensor';
import {NamedTensorMap} from '../tensor_types';
import {makeTypesMatch} from '../tensor_util';
import {convertToTensor} from '../tensor_util_env';
import {TensorLike} from '../types';

import {assertAndGetBroadcastShape} from './broadcast_util';
import {op} from './operation';

/**
* Returns the truth value of (a >= b) element-wise. Supports broadcasting.
*
* We also expose `tf.greaterEqualStrict` which has the same signature as this
* op and asserts that `a` and `b` are the same shape (does not broadcast).
*
* ```js
* const a = tf.tensor1d([1, 2, 3]);
* const b = tf.tensor1d([2, 2, 2]);
*
* a.greaterEqual(b).print();
* ```
*
* @param a The first input tensor.
* @param b The second input tensor. Must have the same dtype as `a`.
*/
/** @doc {heading: 'Operations', subheading: 'Logical'} */
function greaterEqual_<T extends Tensor>(
a: Tensor|TensorLike, b: Tensor|TensorLike): T {
let $a = convertToTensor(a, 'a', 'greaterEqual');
let $b = convertToTensor(b, 'b', 'greaterEqual');
[$a, $b] = makeTypesMatch($a, $b);

assertAndGetBroadcastShape($a.shape, $b.shape);

const forward: ForwardFunc<Tensor> = (backend, save) => {
const res = backend.greaterEqual($a, $b);
save([$a, $b]);
return res;
};

const inputs: GreaterEqualInputs = {a: $a, b: $b};

return ENGINE.runKernelFunc(
forward, inputs as {} as NamedTensorMap, null /* grad */,
GreaterEqual) as T;
}

export const greaterEqual = op({greaterEqual_});
Loading