From 5f3ce11d1f34ec8efcfdfe2c29f21d5a670db1de Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Wed, 3 Jun 2020 10:51:54 -0400 Subject: [PATCH 1/4] basic --- tfjs-backend-wasm/src/kernels/Fill.ts | 35 +++++++++++++++++++++++++++ tfjs-core/src/kernel_names.ts | 9 ++++++- tfjs-core/src/ops/tensor_ops.ts | 9 +++++-- 3 files changed, 50 insertions(+), 3 deletions(-) create mode 100644 tfjs-backend-wasm/src/kernels/Fill.ts diff --git a/tfjs-backend-wasm/src/kernels/Fill.ts b/tfjs-backend-wasm/src/kernels/Fill.ts new file mode 100644 index 00000000000..68e17f56032 --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/Fill.ts @@ -0,0 +1,35 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {KernelFunc, registerKernel} from '@tensorflow/tfjs-core'; +import {Fill, FillAttrs} from '@tensorflow/tfjs-core'; + +import {BackendWasm} from '../backend_wasm'; + +function fill(args: {attrs: FillAttrs, backend: BackendWasm}) { + const {attrs: {shape, value, dtype}, backend} = args; + const out = backend.makeOutput(shape, dtype); + const outVals = backend.typedArrayFromHeap(out); + outVals.fill(value); + return out; +} + +registerKernel({ + kernelName: Fill, + backendName: 'wasm', + kernelFunc: fill as {} as KernelFunc, +}); diff --git a/tfjs-core/src/kernel_names.ts b/tfjs-core/src/kernel_names.ts index 4d92b597b71..ac0805d4b59 100644 --- a/tfjs-core/src/kernel_names.ts +++ b/tfjs-core/src/kernel_names.ts @@ -19,7 +19,7 @@ // Unfortunately just enabling PascalCase per file (tslint:enable: // allow-pascal-case) doesn't work. import {NamedTensorInfoMap, TensorInfo} from './kernel_registry'; -import {PixelData} from './types'; +import {DataType, PixelData} from './types'; export const Add = 'Add'; export type AddInputs = BinaryInputs; @@ -44,6 +44,13 @@ export interface AvgPoolBackpropAttrs { pad: 'valid'|'same'|number; } +export const Fill = 'Fill'; +export interface FillAttrs { + shape: number[]; + value: number|string; + dtype: DataType; +} + export const AvgPool3D = 'AvgPool3D'; export type AvgPool3DInputs = Pick; export interface AvgPool3DAttrs { diff --git a/tfjs-core/src/ops/tensor_ops.ts b/tfjs-core/src/ops/tensor_ops.ts index b3b9ad7ea6b..b6171a49b07 100644 --- a/tfjs-core/src/ops/tensor_ops.ts +++ b/tfjs-core/src/ops/tensor_ops.ts @@ -17,7 +17,8 @@ import {ENGINE} from '../engine'; import {env} from '../environment'; - +import {Fill, FillAttrs} from '../kernel_names'; +import {NamedAttrMap} from '../kernel_registry'; import {Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, Tensor6D, Variable} from '../tensor'; import {convertToTensor, inferShape} from '../tensor_util_env'; import {TensorLike, TensorLike1D, TensorLike2D, TensorLike3D, TensorLike4D, TensorLike5D, TensorLike6D, TypedArray} from '../types'; @@ -469,7 +470,11 @@ function zeros( /** @doc {heading: 'Tensors', subheading: 'Creation'} */ function fill( shape: ShapeMap[R], value: number|string, dtype?: DataType): Tensor { - return ENGINE.runKernelFunc(backend => backend.fill(shape, value, dtype), {}); + const attrs: FillAttrs = {shape, value, dtype}; + + return ENGINE.runKernelFunc( + backend => backend.fill(shape, value, dtype), {}, null, Fill, + attrs as {} as NamedAttrMap); } /** From 57d0497c49d830615e5747648fccae27c3fdc07f Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Wed, 3 Jun 2020 10:57:29 -0400 Subject: [PATCH 2/4] lint --- tfjs-core/src/ops/fill.ts | 46 +++++++++++++++++++ tfjs-core/src/ops/ops.ts | 1 + tfjs-core/src/ops/signal_ops.ts | 3 +- tfjs-core/src/ops/tensor_ops.ts | 25 ---------- tfjs-core/src/optimizers/adagrad_optimizer.ts | 2 +- 5 files changed, 50 insertions(+), 27 deletions(-) create mode 100644 tfjs-core/src/ops/fill.ts diff --git a/tfjs-core/src/ops/fill.ts b/tfjs-core/src/ops/fill.ts new file mode 100644 index 00000000000..7edf56516e3 --- /dev/null +++ b/tfjs-core/src/ops/fill.ts @@ -0,0 +1,46 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {ENGINE} from '../engine'; +import {Fill, FillAttrs} from '../kernel_names'; +import {NamedAttrMap} from '../kernel_registry'; +import {Tensor} from '../tensor'; +import {DataType, Rank, ShapeMap} from '../types'; + +/** + * Creates a `tf.Tensor` filled with a scalar value. + * + * ```js + * tf.fill([2, 2], 4).print(); + * ``` + * + * @param shape An array of integers defining the output tensor shape. + * @param value The scalar value to fill the tensor with. + * @param dtype The type of an element in the resulting tensor. Defaults to + * 'float'. + */ +/** @doc {heading: 'Tensors', subheading: 'Creation'} */ +function fill( + shape: ShapeMap[R], value: number|string, dtype?: DataType): Tensor { + const attrs: FillAttrs = {shape, value, dtype}; + + return ENGINE.runKernelFunc( + backend => backend.fill(shape, value, dtype), {}, null, Fill, + attrs as {} as NamedAttrMap); +} + +export {fill}; diff --git a/tfjs-core/src/ops/ops.ts b/tfjs-core/src/ops/ops.ts index 37f3e61dd4e..c8ee81a74bd 100644 --- a/tfjs-core/src/ops/ops.ts +++ b/tfjs-core/src/ops/ops.ts @@ -46,6 +46,7 @@ export {divNoNan} from './div_no_nan'; export {dot} from './dot'; export {equal} from './equal'; export {eye} from './eye'; +export {fill} from './fill'; export {greater} from './greater'; export {greaterEqual} from './greater_equal'; export {less} from './less'; diff --git a/tfjs-core/src/ops/signal_ops.ts b/tfjs-core/src/ops/signal_ops.ts index 899bf8ecb59..d6f25da0404 100644 --- a/tfjs-core/src/ops/signal_ops.ts +++ b/tfjs-core/src/ops/signal_ops.ts @@ -20,9 +20,10 @@ import {Tensor, Tensor1D} from '../tensor'; import {mul} from './binary_ops'; import {concat} from './concat'; +import {fill} from './fill'; import {slice} from './slice'; import {rfft} from './spectral_ops'; -import {fill, tensor1d, tensor2d} from './tensor_ops'; +import {tensor1d, tensor2d} from './tensor_ops'; /** * Generate a Hann window. diff --git a/tfjs-core/src/ops/tensor_ops.ts b/tfjs-core/src/ops/tensor_ops.ts index b6171a49b07..d614f51cebb 100644 --- a/tfjs-core/src/ops/tensor_ops.ts +++ b/tfjs-core/src/ops/tensor_ops.ts @@ -17,8 +17,6 @@ import {ENGINE} from '../engine'; import {env} from '../environment'; -import {Fill, FillAttrs} from '../kernel_names'; -import {NamedAttrMap} from '../kernel_registry'; import {Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, Tensor6D, Variable} from '../tensor'; import {convertToTensor, inferShape} from '../tensor_util_env'; import {TensorLike, TensorLike1D, TensorLike2D, TensorLike3D, TensorLike4D, TensorLike5D, TensorLike6D, TypedArray} from '../types'; @@ -455,28 +453,6 @@ function zeros( return ENGINE.makeTensor(values, shape, dtype) as Tensor; } -/** - * Creates a `tf.Tensor` filled with a scalar value. - * - * ```js - * tf.fill([2, 2], 4).print(); - * ``` - * - * @param shape An array of integers defining the output tensor shape. - * @param value The scalar value to fill the tensor with. - * @param dtype The type of an element in the resulting tensor. Defaults to - * 'float'. - */ -/** @doc {heading: 'Tensors', subheading: 'Creation'} */ -function fill( - shape: ShapeMap[R], value: number|string, dtype?: DataType): Tensor { - const attrs: FillAttrs = {shape, value, dtype}; - - return ENGINE.runKernelFunc( - backend => backend.fill(shape, value, dtype), {}, null, Fill, - attrs as {} as NamedAttrMap); -} - /** * Creates a `tf.Tensor` with all elements set to 1 with the same shape as the * given tensor. @@ -589,7 +565,6 @@ function range( } export { - fill, linspace, ones, range, diff --git a/tfjs-core/src/optimizers/adagrad_optimizer.ts b/tfjs-core/src/optimizers/adagrad_optimizer.ts index e2f561e2d2a..f7e79a7a432 100644 --- a/tfjs-core/src/optimizers/adagrad_optimizer.ts +++ b/tfjs-core/src/optimizers/adagrad_optimizer.ts @@ -17,7 +17,7 @@ import {ENGINE} from '../engine'; import {dispose, tidy} from '../globals'; -import {fill} from '../ops/ops'; +import {fill} from '../ops/fill'; import {ConfigDict, registerClass, Serializable, SerializableConstructor} from '../serialization'; import {NamedTensor, NamedVariableMap} from '../tensor_types'; From 9592e6077f0ee69debd9fd03bc8a80495bb40177 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Wed, 3 Jun 2020 11:07:10 -0400 Subject: [PATCH 3/4] add cast --- tfjs-backend-wasm/src/kernels/Fill.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tfjs-backend-wasm/src/kernels/Fill.ts b/tfjs-backend-wasm/src/kernels/Fill.ts index 68e17f56032..5a84bbb46d4 100644 --- a/tfjs-backend-wasm/src/kernels/Fill.ts +++ b/tfjs-backend-wasm/src/kernels/Fill.ts @@ -24,7 +24,7 @@ function fill(args: {attrs: FillAttrs, backend: BackendWasm}) { const {attrs: {shape, value, dtype}, backend} = args; const out = backend.makeOutput(shape, dtype); const outVals = backend.typedArrayFromHeap(out); - outVals.fill(value); + outVals.fill(value as number); return out; } From 7551fdc09d343fa0b79c46e1470365b2506ec3f0 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Wed, 3 Jun 2020 11:07:48 -0400 Subject: [PATCH 4/4] oops --- tfjs-backend-wasm/src/kernels/all_kernels.ts | 1 + 1 file changed, 1 insertion(+) diff --git a/tfjs-backend-wasm/src/kernels/all_kernels.ts b/tfjs-backend-wasm/src/kernels/all_kernels.ts index 0bf887355d4..a12e73194eb 100644 --- a/tfjs-backend-wasm/src/kernels/all_kernels.ts +++ b/tfjs-backend-wasm/src/kernels/all_kernels.ts @@ -34,6 +34,7 @@ import './CropAndResize'; import './DepthwiseConv2dNative'; import './Div'; import './Exp'; +import './Fill'; import './FloorDiv'; import './FusedBatchNorm'; import './FusedConv2D';