diff --git a/tfjs-core/src/gradients/LogSoftmax_grad.ts b/tfjs-core/src/gradients/LogSoftmax_grad.ts
new file mode 100644
index 00000000000..acc8e61000a
--- /dev/null
+++ b/tfjs-core/src/gradients/LogSoftmax_grad.ts
@@ -0,0 +1,37 @@
+/**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import {LogSoftmax, LogSoftmaxAttrs} from '../kernel_names';
+import {GradConfig, NamedAttrMap} from '../kernel_registry';
+import {Tensor} from '../tensor';
+
+export const logSoftmaxGradConfig: GradConfig = {
+  kernelName: LogSoftmax,
+  inputsToSave: [],
+  outputsToSave: [true],
+  gradFunc: (dy: Tensor, saved: Tensor[], attrs: NamedAttrMap) => {
+    const [value] = saved;
+    const {axis} = attrs as {} as LogSoftmaxAttrs;
+    return {
+      logits: () => {
+        const keepDims = true;
+        const softmax = value.exp();
+        return dy.sub(dy.sum(axis, keepDims).mul(softmax));
+      }
+    };
+  }
+};
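Note on the gradient above: with `y = logSoftmax(x)`, each output is `y_i = x_i - logsumexp(x)`, so the vector-Jacobian product is `dx = dy - softmax(x) * sum(dy)`, which is exactly what `gradFunc` computes from the saved output (`softmax = exp(y)`). A quick sanity check against autodiff (a sketch; the tensor values are arbitrary):

```ts
import * as tf from '@tensorflow/tfjs-core';

// Autodiff result for logSoftmax vs. the closed-form VJP above:
// dx = dy - sum(dy) * softmax(x).
const x = tf.tensor1d([1, 2, 10]);
const dy = tf.tensor1d([1, 2, 3]);

const dxAutodiff = tf.grad((t: tf.Tensor) => tf.logSoftmax(t))(x, dy);
const dxManual = tf.sub(dy, tf.mul(tf.sum(dy), tf.softmax(x)));

dxAutodiff.print();  // ~[0.99926, 1.99799, -2.99725]
dxManual.print();    // same values (up to float error)
```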
diff --git a/tfjs-core/src/gradients/Softmax_grad.ts b/tfjs-core/src/gradients/Softmax_grad.ts
new file mode 100644
index 00000000000..c8bbedf6dbc
--- /dev/null
+++ b/tfjs-core/src/gradients/Softmax_grad.ts
@@ -0,0 +1,38 @@
+/**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import {Softmax, SoftmaxAttrs} from '../kernel_names';
+import {GradConfig, NamedAttrMap} from '../kernel_registry';
+import {mul} from '../ops/mul';
+import {sub} from '../ops/sub';
+import {sum} from '../ops/sum';
+import {Tensor} from '../tensor';
+
+export const softmaxGradConfig: GradConfig = {
+  kernelName: Softmax,
+  outputsToSave: [true],
+  gradFunc: (dy: Tensor, saved: Tensor[], attrs: NamedAttrMap) => {
+    const [y] = saved;
+    const {dim} = attrs as {} as SoftmaxAttrs;
+    const keepDims = true;
+
+    const dyTimesY = mul(dy, y);
+    return {
+      logits: () => sub(dyTimesY, mul(sum(dyTimesY, [dim], keepDims), y))
+    };
+  }
+};
diff --git a/tfjs-core/src/kernel_names.ts b/tfjs-core/src/kernel_names.ts
index d9150a040b2..2020d63ee8a 100644
--- a/tfjs-core/src/kernel_names.ts
+++ b/tfjs-core/src/kernel_names.ts
@@ -376,6 +376,12 @@ export type LogicalNotInputs = Pick<NamedTensorMap, 'x'>;
 export const LogicalOr = 'LogicalOr';
 export type LogicalOrInputs = BinaryInputs;
 
+export const LogSoftmax = 'LogSoftmax';
+export type LogSoftmaxInputs = Pick<NamedTensorMap, 'logits'>;
+export interface LogSoftmaxAttrs {
+  axis: number;
+}
+
 export const LRN = 'LRN';
 export type LRNInputs = Pick<NamedTensorMap, 'x'>;
 export interface LRNAttrs {
@@ -637,6 +643,12 @@ export interface SplitVAttrs {
   axis: number;
 }
 
+export const Softmax = 'Softmax';
+export type SoftmaxInputs = Pick<NamedTensorMap, 'logits'>;
+export interface SoftmaxAttrs {
+  dim: number;
+}
+
 export const SquaredDifference = 'SquaredDifference';
 export type SquaredDifferenceInputs = BinaryInputs;
 
@@ -646,6 +658,26 @@ export type SquareInputs = Pick<NamedTensorMap, 'x'>;
 export const Sub = 'Sub';
 export type SubInputs = BinaryInputs;
 
+export const SparseToDense = 'SparseToDense';
+export type SparseToDenseInputs =
+    Pick<NamedTensorMap, 'sparseIndices'|'sparseValues'|'defaultValue'>;
+export interface SparseToDenseAttrs {
+  outputShape: number[];
+}
+
+export const StridedSlice = 'StridedSlice';
+export type StridedSliceInputs = Pick<NamedTensorMap, 'x'>;
+export interface StridedSliceAttrs {
+  begin: number[];
+  end: number[];
+  strides: number[];
+  beginMask: number;
+  endMask: number;
+  ellipsisMask: number;
+  newAxisMask: number;
+  shrinkAxisMask: number;
+}
+
 export const Tan = 'Tan';
 export type TanInputs = UnaryInputs;
 
@@ -658,6 +690,13 @@ export interface TileAttrs {
   reps: number[];
 }
 
+export const TopK = 'TopK';
+export type TopKInputs = Pick<NamedTensorMap, 'x'>;
+export interface TopKAttrs {
+  k: number;
+  sorted: boolean;
+}
+
 export const Transpose = 'Transpose';
 export type TransposeInputs = Pick<NamedTensorMap, 'x'>;
 export interface TransposeAttrs {
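These kernel names plus their `Pick<NamedTensorMap, ...>` input types and attr interfaces are the contract between ops and backends: `runKernelFunc` dispatches by name once a backend registers a matching kernel. As an illustration only (hypothetical backend name, stubbed body; the real registrations live in the cpu/webgl backends), a registration would look roughly like this:

```ts
import {KernelConfig, registerKernel} from '@tensorflow/tfjs-core';

// Hypothetical sketch: 'my-backend' and the kernel body are placeholders.
const softmaxKernelConfig: KernelConfig = {
  kernelName: 'Softmax',  // must match the constant in kernel_names.ts
  backendName: 'my-backend',
  kernelFunc: ({inputs, attrs}) => {
    const {logits} = inputs;               // keys fixed by SoftmaxInputs
    const {dim} = attrs as {dim: number};  // shape fixed by SoftmaxAttrs
    // ...compute softmax(logits) along `dim` on the backend and return it...
    throw new Error('stub - not a real kernel');
  }
};
registerKernel(softmaxKernelConfig);
```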
diff --git a/tfjs-core/src/ops/log_softmax.ts b/tfjs-core/src/ops/log_softmax.ts
new file mode 100644
index 00000000000..e39a9d15eb9
--- /dev/null
+++ b/tfjs-core/src/ops/log_softmax.ts
@@ -0,0 +1,80 @@
+/**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import {ENGINE, ForwardFunc} from '../engine';
+import {LogSoftmax, LogSoftmaxAttrs, LogSoftmaxInputs} from '../kernel_names';
+import {NamedAttrMap} from '../kernel_registry';
+import {Tensor} from '../tensor';
+import {NamedTensorMap} from '../tensor_types';
+import {convertToTensor} from '../tensor_util_env';
+import {TensorLike} from '../types';
+
+import {max} from './max';
+import {op} from './operation';
+import {sub} from './sub';
+
+/**
+ * Computes the log softmax.
+ *
+ * ```js
+ * const a = tf.tensor1d([1, 2, 3]);
+ *
+ * a.logSoftmax().print(); // or tf.logSoftmax(a)
+ * ```
+ *
+ * ```js
+ * const a = tf.tensor2d([2, 4, 6, 1, 2, 3], [2, 3]);
+ *
+ * a.logSoftmax().print(); // or tf.logSoftmax(a)
+ * ```
+ *
+ * @param logits The logits array.
+ * @param axis The dimension softmax would be performed on. Defaults to `-1`
+ *     which indicates the last dimension.
+ */
+/** @doc {heading: 'Operations', subheading: 'Normalization'} */
+function logSoftmax_<T extends Tensor>(logits: T|TensorLike, axis = -1): T {
+  const $logits = convertToTensor(logits, 'logits', 'logSoftmax');
+
+  if (axis === -1) {
+    axis = $logits.rank - 1;
+  }
+  if (axis !== $logits.rank - 1) {
+    throw Error(
+        'Log Softmax along a non-last dimension is not yet supported. ' +
+        `Logits was rank ${$logits.rank} and axis was ${axis}`);
+  }
+
+  const forward: ForwardFunc<Tensor> = (backend, save) => {
+    const keepDims = true;
+    const xMax = max($logits, axis, true);
+    const shifted = sub($logits, xMax);
+    const value =
+        shifted.toFloat().sub(shifted.exp().sum(axis, keepDims).log());
+    save([value]);
+    return value;
+  };
+
+  const inputs: LogSoftmaxInputs = {logits: $logits};
+  const attrs: LogSoftmaxAttrs = {axis};
+
+  return ENGINE.runKernelFunc(
+             forward, inputs as {} as NamedTensorMap, null /* grad */,
+             LogSoftmax, attrs as {} as NamedAttrMap) as T;
+}
+
+export const logSoftmax = op({logSoftmax_});
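The `xMax` subtraction in `forward` is the standard log-sum-exp stabilization: it is what keeps the "Huge difference" test below finite, where a naive `log(exp(x) / sum(exp(x)))` would overflow. The same arithmetic spelled out with public ops (a sketch):

```ts
import * as tf from '@tensorflow/tfjs-core';

const x = tf.tensor1d([-1000, 1000]);

// Naive form: exp(1000) overflows to Infinity.
tf.log(tf.div(tf.exp(x), tf.sum(tf.exp(x)))).print();  // [-Infinity, NaN]

// Shifted form used above: (x - max(x)) - log(sum(exp(x - max(x)))).
const shifted = tf.sub(x, tf.max(x, -1, true));               // [-2000, 0]
const logSumExp = tf.log(tf.sum(tf.exp(shifted), -1, true));  // log(1 + e^-2000) = 0
tf.sub(shifted, logSumExp).print();                           // [-2000, 0]
```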
diff --git a/tfjs-core/src/ops/log_softmax_test.ts b/tfjs-core/src/ops/log_softmax_test.ts
new file mode 100644
index 00000000000..56c9b442e07
--- /dev/null
+++ b/tfjs-core/src/ops/log_softmax_test.ts
@@ -0,0 +1,84 @@
+/**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import * as tf from '../index';
+import {ALL_ENVS, describeWithFlags} from '../jasmine_util';
+import {expectArraysClose} from '../test_util';
+
+describeWithFlags('logSoftmax', ALL_ENVS, () => {
+  it('regular test', async () => {
+    const y = tf.logSoftmax(tf.tensor1d([2, 1, 3]));
+
+    expectArraysClose(await y.data(), [-1.407606, -2.4076061, -0.407606]);
+  });
+
+  it('Huge difference', async () => {
+    const y = tf.logSoftmax(tf.tensor1d([-1000, +1000]));
+
+    expectArraysClose(await y.data(), [-2000, 0]);
+  });
+
+  it('Propagates NaNs', async () => {
+    const a = tf.tensor1d([2, 1, NaN]);
+    const y = tf.logSoftmax(a);
+    expectArraysClose(await y.data(), [NaN, NaN, NaN]);
+  });
+
+  it('2D, axis=1', async () => {
+    const y = tf.logSoftmax(tf.tensor2d([[2, 1, 3], [1, 3, 2]], [2, 3]), 1);
+    const expected =
+        [-1.407606, -2.4076061, -0.407606, -2.4076061, -0.4076061, -1.4076061];
+    expect(y.rank).toBe(2);
+    expectArraysClose(await y.data(), expected);
+  });
+
+  it('2D, implicit axis=1', async () => {
+    const y = tf.logSoftmax(tf.tensor2d([[2, 1, 3], [1, 3, 2]], [2, 3]));
+    const expected =
+        [-1.407606, -2.4076061, -0.407606, -2.4076061, -0.4076061, -1.4076061];
+    expect(y.rank).toBe(2);
+    expectArraysClose(await y.data(), expected);
+  });
+
+  it('1D gradient', async () => {
+    const x = tf.tensor1d([1, 2, 10]);
+    const dy = tf.tensor1d([1, 2, 3]);
+    const dx = tf.grad((x) => x.logSoftmax())(x, dy);
+
+    expect(dx.shape).toEqual(x.shape);
+    expectArraysClose(await dx.data(), [0.9992599, 1.9979881, -2.9972477]);
+  });
+
+  it('2D, axis=0 throws error', () => {
+    const f = () => {
+      tf.logSoftmax(tf.tensor2d([[2, 1, 3], [1, 3, 2]], [2, 3]), 0);
+    };
+    expect(f).toThrowError();
+  });
+
+  it('throws when passed a non-tensor', () => {
+    expect(() => tf.logSoftmax({} as tf.Tensor))
+        .toThrowError(
+            /Argument 'logits' passed to 'logSoftmax' must be a Tensor/);
+  });
+
+  it('accepts a tensor-like object', async () => {
+    const y = tf.logSoftmax([2, 1, 3]);
+
+    expectArraysClose(await y.data(), [-1.407606, -2.4076061, -0.407606]);
+  });
+});
diff --git a/tfjs-core/src/ops/ops.ts b/tfjs-core/src/ops/ops.ts
index ba0663e9683..91c36a2fd63 100644
--- a/tfjs-core/src/ops/ops.ts
+++ b/tfjs-core/src/ops/ops.ts
@@ -87,6 +87,7 @@ export {localResponseNormalization} from './local_response_normalization';
 export {log} from './log';
 export {log1p} from './log1p';
 export {logSigmoid} from './log_sigmoid';
+export {logSoftmax} from './log_softmax';
 export {logSumExp} from './log_sum_exp';
 export {logicalAnd} from './logical_and';
 export {logicalNot} from './logical_not';
@@ -149,12 +150,14 @@ export {slice1d} from './slice1d';
 export {slice2d} from './slice2d';
 export {slice3d} from './slice3d';
 export {slice4d} from './slice4d';
+export {softmax} from './softmax';
 export {spaceToBatchND} from './space_to_batch_nd';
 export {split} from './split';
 export {square} from './square';
 export {squaredDifference} from './squared_difference';
 export {squeeze} from './squeeze';
 export {stack} from './stack';
+export {stridedSlice} from './strided_slice';
 export {sub} from './sub';
 export {sum} from './sum';
 export {tan} from './tan';
@@ -167,6 +170,7 @@ export {tensor4d} from './tensor4d';
 export {tensor5d} from './tensor5d';
 export {tensor6d} from './tensor6d';
 export {tile} from './tile';
+export {topk} from './topk';
 export {truncatedNormal} from './truncated_normal';
 export {unsortedSegmentSum} from './unsorted_segment_sum';
 export {unstack} from './unstack';
@@ -181,11 +185,8 @@ export * from './unary_ops';
 export * from './compare';
 export * from './binary_ops';
 export * from './transpose';
-export * from './softmax';
 export * from './norm';
 export * from './moving_average';
-export * from './strided_slice';
-export * from './topk';
 export * from './scatter_nd';
 export * from './spectral_ops';
 export * from './sparse_to_dense';
diff --git a/tfjs-core/src/ops/softmax.ts b/tfjs-core/src/ops/softmax.ts
index 9e8fc5ec294..568f23cdec1 100644
--- a/tfjs-core/src/ops/softmax.ts
+++ b/tfjs-core/src/ops/softmax.ts
@@ -16,9 +16,10 @@
  */
 
 import {ENGINE} from '../engine';
-import {customGrad} from '../gradients';
+import {Softmax, SoftmaxAttrs, SoftmaxInputs} from '../kernel_names';
+import {NamedAttrMap} from '../kernel_registry';
 import {Tensor} from '../tensor';
-import {GradSaveFunc} from '../tensor_types';
+import {NamedTensorMap} from '../tensor_types';
 import {convertToTensor} from '../tensor_util_env';
 import {TensorLike} from '../types';
 
@@ -56,8 +57,8 @@ function softmax_<T extends Tensor>(logits: T|TensorLike, dim = -1): T {
         `Logits was rank ${$logits.rank} and dim was ${dim}`);
   }
 
-  const inputsToSave: Tensor[] = [];
-  const outputsToSave = [true];
+  const inputs: SoftmaxInputs = {logits: $logits};
+  const attrs: SoftmaxAttrs = {dim};
 
   return ENGINE.runKernelFunc(
       (backend, save) => {
@@ -65,69 +66,8 @@ function softmax_<T extends Tensor>(logits: T|TensorLike, dim = -1): T {
         save([y]);
         return y;
       },
-      {logits: $logits},
-      (dy: T, saved: Tensor[]) => {
-        const [y] = saved;
-        const dyTimesY = dy.mul(y);
-        const keepDims = true;
-
-        return {
-          logits: () => dyTimesY.sub(dyTimesY.sum([dim], keepDims).mul(y))
-        };
-      },
-      'Softmax', {dim}, inputsToSave, outputsToSave);
-}
-
-/**
- * Computes the log softmax.
- *
- * ```js
- * const a = tf.tensor1d([1, 2, 3]);
- *
- * a.logSoftmax().print(); // or tf.logSoftmax(a)
- * ```
- *
- * ```js
- * const a = tf.tensor2d([2, 4, 6, 1, 2, 3], [2, 3]);
- *
- * a.logSoftmax().print(); // or tf.logSoftmax(a)
- * ```
- *
- * @param logits The logits array.
- * @param axis The dimension softmax would be performed on. Defaults to `-1`
- *     which indicates the last dimension.
- */
-/** @doc {heading: 'Operations', subheading: 'Normalization'} */
-function logSoftmax_<T extends Tensor>(logits: T|TensorLike, axis = -1): T {
-  const $logits = convertToTensor(logits, 'logits', 'logSoftmax');
-
-  if (axis === -1) {
-    axis = $logits.rank - 1;
-  }
-  if (axis !== $logits.rank - 1) {
-    throw Error(
-        'Log Softmax along a non-last dimension is not yet supported. ' +
-        `Logits was rank ${$logits.rank} and axis was ${axis}`);
-  }
-
-  const customOp = customGrad((logits: Tensor, save: GradSaveFunc) => {
-    const keepDims = true;
-    const xMax = logits.max(axis, true);
-    const shifted = logits.sub(xMax);
-    const value =
-        shifted.toFloat().sub(shifted.exp().sum(axis, keepDims).log());
-    save([value]);
-    const gradFunc = (dy: T, saved: Tensor[]) => {
-      const [value] = saved;
-      const softmax = value.exp();
-      return dy.sub(dy.sum(axis, keepDims).mul(softmax));
-    };
-
-    return {value, gradFunc};
-  });
-
-  return customOp($logits) as T;
+      inputs as {} as NamedTensorMap, null /* grad */, Softmax,
+      attrs as {} as NamedAttrMap);
 }
 
 export const softmax = op({softmax_});
-export const logSoftmax = op({logSoftmax_});
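For reference, the gradient moved into Softmax_grad.ts implements the softmax VJP `dx = dy*y - sum(dy*y) * y`, with `y = softmax(x)` saved from the forward pass. A quick check against autodiff (a sketch; values arbitrary):

```ts
import * as tf from '@tensorflow/tfjs-core';

// Autodiff result for softmax vs. the closed-form VJP in Softmax_grad.ts.
const x = tf.tensor1d([10, 0, -1]);
const dy = tf.tensor1d([1, 2, 3]);

const dxAutodiff = tf.grad((t: tf.Tensor) => tf.softmax(t))(x, dy);
const y = tf.softmax(x);
const dyTimesY = tf.mul(dy, y);
const dxManual = tf.sub(dyTimesY, tf.mul(tf.sum(dyTimesY, -1, true), y));

dxAutodiff.print();
dxManual.print();  // same values
```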
diff --git a/tfjs-core/src/ops/softmax_test.ts b/tfjs-core/src/ops/softmax_test.ts
index 66f0cd5a804..d916ba76f2c 100644
--- a/tfjs-core/src/ops/softmax_test.ts
+++ b/tfjs-core/src/ops/softmax_test.ts
@@ -138,67 +138,3 @@ describeWithFlags('softmax', ALL_ENVS, () => {
     expectArraysClose(await y.sum().data(), 1);
   });
 });
-
-describeWithFlags('logSoftmax', ALL_ENVS, () => {
-  it('regular test', async () => {
-    const y = tf.logSoftmax(tf.tensor1d([2, 1, 3]));
-
-    expectArraysClose(await y.data(), [-1.407606, -2.4076061, -0.407606]);
-  });
-
-  it('Huge difference', async () => {
-    const y = tf.logSoftmax(tf.tensor1d([-1000, +1000]));
-
-    expectArraysClose(await y.data(), [-2000, 0]);
-  });
-
-  it('Propagates NaNs', async () => {
-    const a = tf.tensor1d([2, 1, NaN]);
-    const y = tf.logSoftmax(a);
-    expectArraysClose(await y.data(), [NaN, NaN, NaN]);
-  });
-
-  it('2D, axis=1', async () => {
-    const y = tf.logSoftmax(tf.tensor2d([[2, 1, 3], [1, 3, 2]], [2, 3]), 1);
-    const expected =
-        [-1.407606, -2.4076061, -0.407606, -2.4076061, -0.4076061, -1.4076061];
-    expect(y.rank).toBe(2);
-    expectArraysClose(await y.data(), expected);
-  });
-
-  it('2D, implicit axis=1', async () => {
-    const y = tf.logSoftmax(tf.tensor2d([[2, 1, 3], [1, 3, 2]], [2, 3]));
-    const expected =
-        [-1.407606, -2.4076061, -0.407606, -2.4076061, -0.4076061, -1.4076061];
-    expect(y.rank).toBe(2);
-    expectArraysClose(await y.data(), expected);
-  });
-
-  it('1D gradient', async () => {
-    const x = tf.tensor1d([1, 2, 10]);
-    const dy = tf.tensor1d([1, 2, 3]);
-    const dx = tf.grad((x) => x.logSoftmax())(x, dy);
-
-    expect(dx.shape).toEqual(x.shape);
-    expectArraysClose(await dx.data(), [0.9992599, 1.9979881, -2.9972477]);
-  });
-
-  it('2D, axis=0 throws error', () => {
-    const f = () => {
-      tf.logSoftmax(tf.tensor2d([[2, 1, 3], [1, 3, 2]], [2, 3]), 0);
-    };
-    expect(f).toThrowError();
-  });
-
-  it('throws when passed a non-tensor', () => {
-    expect(() => tf.logSoftmax({} as tf.Tensor))
-        .toThrowError(
-            /Argument 'logits' passed to 'logSoftmax' must be a Tensor/);
-  });
-
-  it('accepts a tensor-like object', async () => {
-    const y = tf.logSoftmax([2, 1, 3]);
-
-    expectArraysClose(await y.data(), [-1.407606, -2.4076061, -0.407606]);
-  });
-});
diff --git a/tfjs-core/src/ops/sparse_to_dense.ts b/tfjs-core/src/ops/sparse_to_dense.ts
index 6184845007d..a2022e31caa 100644
--- a/tfjs-core/src/ops/sparse_to_dense.ts
+++ b/tfjs-core/src/ops/sparse_to_dense.ts
@@ -16,10 +16,14 @@
  */
 
 import {ENGINE} from '../engine';
+import {SparseToDense, SparseToDenseAttrs, SparseToDenseInputs} from '../kernel_names';
+import {NamedAttrMap} from '../kernel_registry';
 import * as sparse_to_dense from '../ops/sparse_to_dense_util';
 import {Scalar, Tensor} from '../tensor';
+import {NamedTensorMap} from '../tensor_types';
 import {convertToTensor} from '../tensor_util_env';
 import {Rank, ScalarLike, ShapeMap, TensorLike} from '../types';
+
 import {op} from './operation';
 
 /**
@@ -72,10 +76,19 @@ function sparseToDense_<R extends Rank>(
   sparse_to_dense.validateInput(
       $sparseIndices, $sparseValues, outputShape, $defaultValue);
 
+  const inputs: SparseToDenseInputs = {
+    sparseIndices: $sparseIndices,
+    sparseValues: $sparseValues,
+    defaultValue: $defaultValue
+  };
+
+  const attrs: SparseToDenseAttrs = {outputShape};
+
   return ENGINE.runKernelFunc(
       backend => backend.sparseToDense(
           $sparseIndices, $sparseValues, outputShape, $defaultValue),
-      {$sparseIndices, $sparseValues, $defaultValue});
+      inputs as {} as NamedTensorMap, null /* grad */, SparseToDense,
+      attrs as {} as NamedAttrMap);
 }
 
 export const sparseToDense = op({sparseToDense_});
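`sparseToDense` is unchanged behaviorally; only the dispatch is modularized. For context, a minimal usage example:

```ts
import * as tf from '@tensorflow/tfjs-core';

// Scatters `values` at `indices` into a dense shape-[4] tensor; all other
// positions take the default value.
const indices = tf.tensor1d([0, 2], 'int32');
const values = tf.tensor1d([10, 20], 'float32');
tf.sparseToDense(indices, values, [4], tf.scalar(0)).print();
// [10, 0, 20, 0]
```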
diff --git a/tfjs-core/src/ops/strided_slice.ts b/tfjs-core/src/ops/strided_slice.ts
index 9ad1b855e60..cfba1f383ef 100644
--- a/tfjs-core/src/ops/strided_slice.ts
+++ b/tfjs-core/src/ops/strided_slice.ts
@@ -15,12 +15,16 @@
  * =============================================================================
  */
 
-import {ENGINE} from '../engine';
+import {ENGINE, ForwardFunc} from '../engine';
+import {StridedSlice, StridedSliceAttrs, StridedSliceInputs} from '../kernel_names';
+import {NamedAttrMap} from '../kernel_registry';
 import {Tensor} from '../tensor';
+import {NamedTensorMap} from '../tensor_types';
 import {convertToTensor} from '../tensor_util_env';
 import {TensorLike} from '../types';
 
 import {op} from './operation';
+import {reshape} from './reshape';
 import {slice} from './slice';
 import {computeOutShape, maskToAxes, startForAxis, startIndicesWithElidedDims, stopForAxis, stopIndicesWithElidedDims, stridesForAxis, stridesWithElidedDims} from './slice_util';
 
@@ -64,76 +68,95 @@ function stridedSlice_(
     strides = new Array(begin.length);
   }
 
-  const ellipsisAxes = maskToAxes(ellipsisMask);
-  if (ellipsisAxes.length > 1) {
-    throw new Error('Multiple ellipses in slice is not allowed.');
-  }
+  let $x = convertToTensor(x, 'x', 'stridedSlice');
 
-  if (ellipsisMask !== 0 && newAxisMask !== 0) {
-    throw new Error(
-        'Using both ellipsisMask and newAxisMask is not yet supported.');
-  }
+  const forward: ForwardFunc<Tensor> = (backend) => {
+    const ellipsisAxes = maskToAxes(ellipsisMask);
+    if (ellipsisAxes.length > 1) {
+      throw new Error('Multiple ellipses in slice is not allowed.');
+    }
 
-  if (ellipsisMask !== 0 && shrinkAxisMask !== 0) {
-    throw new Error(
-        'Using both ellipsisMask and shrinkAxisMask is not yet supported.');
-  }
+    if (ellipsisMask !== 0 && newAxisMask !== 0) {
+      throw new Error(
+          'Using both ellipsisMask and newAxisMask is not yet supported.');
+    }
 
-  let $x = convertToTensor(x, 'x', 'stridedSlice');
-  const numInterpolatedAxes = $x.rank - begin.length;
-
-  // Expand the dims of x based on the newAxisMask.
-  const expandAxes = maskToAxes(newAxisMask);
-  const newShape = $x.shape.slice();
-  expandAxes.forEach(axis => {
-    begin[axis] = 0;
-    end[axis] = 1;
-    newShape.splice(axis, 0, 1);
-  });
-  $x = $x.reshape(newShape);
-
-  // Normalize the start, end and strides.
-  if (ellipsisAxes.length && numInterpolatedAxes > 0) {
-    const fullIndex = ellipsisAxes[0];
-
-    // The ellipsis applies to the masked index as well as any dimensions
-    // that are interpolated.
-    const numElidedAxes = numInterpolatedAxes + 1;
-    begin = startIndicesWithElidedDims(
-        beginMask, fullIndex, numElidedAxes, begin, $x.shape);
-    end = stopIndicesWithElidedDims(
-        endMask, fullIndex, numElidedAxes, end, $x.shape);
-    strides =
-        stridesWithElidedDims(strides, fullIndex, numElidedAxes, $x.shape);
-  } else {
-    for (let axis = 0; axis < $x.rank; axis++) {
-      begin[axis] =
-          startForAxis(beginMask, begin, strides, $x.shape, axis, ellipsisMask);
-      end[axis] =
-          stopForAxis(endMask, end, strides, $x.shape, axis, ellipsisMask);
-      strides[axis] = stridesForAxis(strides, axis, ellipsisMask);
+    if (ellipsisMask !== 0 && shrinkAxisMask !== 0) {
+      throw new Error(
+          'Using both ellipsisMask and shrinkAxisMask is not yet supported.');
     }
-  }
 
-  const shrinkAxes = maskToAxes(shrinkAxisMask);
-  // Adjust the ends based on the shrink mask.
-  shrinkAxes.forEach(axis => {
-    end[axis] = begin[axis] + 1;
-    strides[axis] = 1;
-  });
-
-  // Figure out the output shape.
-  const size = computeOutShape(begin, end, strides);
-  // Remove the axes based on shrinkMask.
-  const outShape = size.filter((_, axis) => shrinkAxes.indexOf(axis) === -1);
-
-  const nonStrided = strides.every(v => v === 1);
-  if (nonStrided) {
-    return slice($x, begin, size).reshape(outShape);
-  }
-  const res = ENGINE.runKernelFunc(
-      backend => backend.stridedSlice($x, begin, end, strides), {$x});
-  return res.reshape(outShape);
+    const numInterpolatedAxes = $x.rank - begin.length;
+
+    // Expand the dims of x based on the newAxisMask.
+    const expandAxes = maskToAxes(newAxisMask);
+    const newShape = $x.shape.slice();
+    expandAxes.forEach(axis => {
+      begin[axis] = 0;
+      end[axis] = 1;
+      newShape.splice(axis, 0, 1);
+    });
+    $x = reshape($x, newShape);
+
+    // Normalize the start, end and strides.
+    if (ellipsisAxes.length && numInterpolatedAxes > 0) {
+      const fullIndex = ellipsisAxes[0];
+
+      // The ellipsis applies to the masked index as well as any dimensions
+      // that are interpolated.
+      const numElidedAxes = numInterpolatedAxes + 1;
+      begin = startIndicesWithElidedDims(
+          beginMask, fullIndex, numElidedAxes, begin, $x.shape);
+      end = stopIndicesWithElidedDims(
+          endMask, fullIndex, numElidedAxes, end, $x.shape);
+      strides =
+          stridesWithElidedDims(strides, fullIndex, numElidedAxes, $x.shape);
+    } else {
+      for (let axis = 0; axis < $x.rank; axis++) {
+        begin[axis] = startForAxis(
+            beginMask, begin, strides, $x.shape, axis, ellipsisMask);
+        end[axis] =
+            stopForAxis(endMask, end, strides, $x.shape, axis, ellipsisMask);
+        strides[axis] = stridesForAxis(strides, axis, ellipsisMask);
+      }
+    }
+
+    const shrinkAxes = maskToAxes(shrinkAxisMask);
+    // Adjust the ends based on the shrink mask.
+    shrinkAxes.forEach(axis => {
+      end[axis] = begin[axis] + 1;
+      strides[axis] = 1;
+    });
+
+    // Figure out the output shape.
+    const size = computeOutShape(begin, end, strides);
+    // Remove the axes based on shrinkMask.
+    const outShape = size.filter((_, axis) => shrinkAxes.indexOf(axis) === -1);
+
+    const nonStrided = strides.every(v => v === 1);
+    if (nonStrided) {
+      return reshape(slice($x, begin, size), outShape);
+    }
+
+    const res = backend.stridedSlice($x, begin, end, strides);
+    return res.reshape(outShape);
+  };
+
+  const inputs: StridedSliceInputs = {x: $x};
+  const attrs: StridedSliceAttrs = {
+    begin,
+    end,
+    strides,
+    beginMask,
+    endMask,
+    ellipsisMask,
+    newAxisMask,
+    shrinkAxisMask
+  };
+
+  return ENGINE.runKernelFunc(
+      forward, inputs as {} as NamedTensorMap, null /* grad */, StridedSlice,
+      attrs as {} as NamedAttrMap);
 }
 
 export const stridedSlice = op({stridedSlice_});
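Behavior of `stridedSlice` is likewise unchanged; the mask handling simply moved inside `forward`. A minimal usage example:

```ts
import * as tf from '@tensorflow/tfjs-core';

// begin/end/strides per axis: rows [0, 2) step 1, columns [0, 3) step 2.
const t = tf.tensor2d([[1, 2, 3], [4, 5, 6]]);
tf.stridedSlice(t, [0, 0], [2, 3], [1, 2]).print();
// [[1, 3],
//  [4, 6]]
```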
diff --git a/tfjs-core/src/ops/topk.ts b/tfjs-core/src/ops/topk.ts
index 6f8ba88fb0b..50dc9d87b39 100644
--- a/tfjs-core/src/ops/topk.ts
+++ b/tfjs-core/src/ops/topk.ts
@@ -16,9 +16,13 @@
  */
 
 import {ENGINE} from '../engine';
+import {TopK, TopKAttrs, TopKInputs} from '../kernel_names';
+import {NamedAttrMap} from '../kernel_registry';
 import {NumericTensor, Tensor} from '../tensor';
+import {NamedTensorMap} from '../tensor_types';
 import {convertToTensor} from '../tensor_util_env';
 import {TensorLike} from '../types';
+
 import {op} from './operation';
 
 /**
@@ -57,8 +61,14 @@ function topk_<T extends Tensor>(
         `but got ${k}`);
   }
 
-  const [values, indices] =
-      ENGINE.runKernelFunc(b => b.topk($x as NumericTensor, k, sorted), {$x});
+  const inputs: TopKInputs = {x: $x};
+  const attrs: TopKAttrs = {k, sorted};
+
+  const [values, indices] = ENGINE.runKernelFunc(
+      b => b.topk($x as NumericTensor, k, sorted),
+      inputs as {} as NamedTensorMap, null /* grad */, TopK,
+      attrs as {} as NamedAttrMap);
+
   return {values, indices} as {values: T, indices: T};
 }
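`topk` keeps its forward-func dispatch but now advertises the `TopK` kernel name and typed attrs. A minimal usage example:

```ts
import * as tf from '@tensorflow/tfjs-core';

// Returns the k largest values (and their indices) along the last axis.
const {values, indices} = tf.topk(tf.tensor1d([1, 20, 3]), 2);
values.print();   // [20, 3]
indices.print();  // [1, 2]
```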
diff --git a/tfjs-core/src/ops/transpose.ts b/tfjs-core/src/ops/transpose.ts
index c44d39485a5..6f4dee3d661 100644
--- a/tfjs-core/src/ops/transpose.ts
+++ b/tfjs-core/src/ops/transpose.ts
@@ -16,10 +16,14 @@
  */
 
 import {ENGINE} from '../engine';
+import {Transpose, TransposeAttrs, TransposeInputs} from '../kernel_names';
+import {NamedAttrMap} from '../kernel_registry';
 import {Tensor} from '../tensor';
+import {NamedTensorMap} from '../tensor_types';
 import {convertToTensor} from '../tensor_util_env';
 import {TensorLike} from '../types';
 import * as util from '../util';
+
 import {op} from './operation';
 
 /**
@@ -61,10 +65,12 @@ function transpose_<T extends Tensor>(x: T|TensorLike, perm?: number[]): T {
     return $x.clone();
   }
 
-  const attrs = {perm};
+  const inputs: TransposeInputs = {x: $x};
+  const attrs: TransposeAttrs = {perm};
+
   return ENGINE.runKernelFunc(
-      backend => backend.transpose($x, perm), {x: $x}, null /* gradient */,
-      'Transpose', attrs);
+      backend => backend.transpose($x, perm), inputs as {} as NamedTensorMap,
+      null /* gradient */, Transpose, attrs as {} as NamedAttrMap);
 }
 
 export const transpose = op({transpose_});
diff --git a/tfjs-core/src/register_all_gradients.ts b/tfjs-core/src/register_all_gradients.ts
index 1fde3d29c1f..9caa7a6fbb9 100644
--- a/tfjs-core/src/register_all_gradients.ts
+++ b/tfjs-core/src/register_all_gradients.ts
@@ -56,6 +56,7 @@ import {greaterEqualGradConfig} from './gradients/GreaterEqual_grad';
 import {identityGradConfig} from './gradients/Identity_grad';
 import {log1pGradConfig} from './gradients/Log1p_grad';
 import {logGradConfig} from './gradients/Log_grad';
+import {logSoftmaxGradConfig} from './gradients/LogSoftmax_grad';
 import {lrnGradConfig} from './gradients/LRN_grad';
 import {maxGradConfig} from './gradients/Max_grad';
 import {maximumGradConfig} from './gradients/Maximum_grad';
@@ -84,6 +85,7 @@ import {signGradConfig} from './gradients/Sign_grad';
 import {sinGradConfig} from './gradients/Sin_grad';
 import {sinhGradConfig} from './gradients/Sinh_grad';
 import {sliceGradConfig} from './gradients/Slice_grad';
+import {softmaxGradConfig} from './gradients/Softmax_grad';
 import {spaceToBatchNDGradConfig} from './gradients/SpaceToBatchND_grad';
 import {splitVGradConfig} from './gradients/SplitV_grad';
 import {squareGradConfig} from './gradients/Square_grad';
@@ -145,6 +147,7 @@ const gradConfigs: GradConfig[] = [
   log1pGradConfig,
   logGradConfig,
   lrnGradConfig,
+  logSoftmaxGradConfig,
   maxGradConfig,
   maxGradConfig,
   maximumGradConfig,
@@ -182,6 +185,7 @@ const gradConfigs: GradConfig[] = [
   squareGradConfig,
   subGradConfig,
   sumGradConfig,
+  softmaxGradConfig,
   tanGradConfig,
   tanhGradConfig,
   tileGradConfig,
diff --git a/tfjs-core/src/tests.ts b/tfjs-core/src/tests.ts
index a69dec5f7a5..5b2cbfcbaff 100644
--- a/tfjs-core/src/tests.ts
+++ b/tfjs-core/src/tests.ts
@@ -121,6 +121,7 @@ import './ops/local_response_normalization_test';
 import './ops/log1p_test';
 import './ops/log_loss_test';
 import './ops/log_sigmoid_test';
+import './ops/log_softmax_test';
 import './ops/log_sum_exp_test';
 import './ops/log_test';
 import './ops/logical_and_test';