diff --git a/tfjs-backend-wasm/src/kernels/Fill.ts b/tfjs-backend-wasm/src/kernels/Fill.ts new file mode 100644 index 00000000000..5a84bbb46d4 --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/Fill.ts @@ -0,0 +1,35 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {KernelFunc, registerKernel} from '@tensorflow/tfjs-core'; +import {Fill, FillAttrs} from '@tensorflow/tfjs-core'; + +import {BackendWasm} from '../backend_wasm'; + +function fill(args: {attrs: FillAttrs, backend: BackendWasm}) { + const {attrs: {shape, value, dtype}, backend} = args; + const out = backend.makeOutput(shape, dtype); + const outVals = backend.typedArrayFromHeap(out); + outVals.fill(value as number); + return out; +} + +registerKernel({ + kernelName: Fill, + backendName: 'wasm', + kernelFunc: fill as {} as KernelFunc, +}); diff --git a/tfjs-backend-wasm/src/kernels/all_kernels.ts b/tfjs-backend-wasm/src/kernels/all_kernels.ts index 0bf887355d4..a12e73194eb 100644 --- a/tfjs-backend-wasm/src/kernels/all_kernels.ts +++ b/tfjs-backend-wasm/src/kernels/all_kernels.ts @@ -34,6 +34,7 @@ import './CropAndResize'; import './DepthwiseConv2dNative'; import './Div'; import './Exp'; +import './Fill'; import './FloorDiv'; import './FusedBatchNorm'; import './FusedConv2D'; diff --git a/tfjs-core/src/kernel_names.ts b/tfjs-core/src/kernel_names.ts index c9ce90b2320..8a51369521c 100644 --- a/tfjs-core/src/kernel_names.ts +++ b/tfjs-core/src/kernel_names.ts @@ -19,7 +19,7 @@ // Unfortunately just enabling PascalCase per file (tslint:enable: // allow-pascal-case) doesn't work. import {NamedTensorInfoMap, TensorInfo} from './kernel_registry'; -import {PixelData} from './types'; +import {DataType, PixelData} from './types'; export const Add = 'Add'; export type AddInputs = BinaryInputs; @@ -44,6 +44,13 @@ export interface AvgPoolBackpropAttrs { pad: 'valid'|'same'|number; } +export const Fill = 'Fill'; +export interface FillAttrs { + shape: number[]; + value: number|string; + dtype: DataType; +} + export const AvgPool3D = 'AvgPool3D'; export type AvgPool3DInputs = Pick; export interface AvgPool3DAttrs { diff --git a/tfjs-core/src/ops/fill.ts b/tfjs-core/src/ops/fill.ts new file mode 100644 index 00000000000..7edf56516e3 --- /dev/null +++ b/tfjs-core/src/ops/fill.ts @@ -0,0 +1,46 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {ENGINE} from '../engine'; +import {Fill, FillAttrs} from '../kernel_names'; +import {NamedAttrMap} from '../kernel_registry'; +import {Tensor} from '../tensor'; +import {DataType, Rank, ShapeMap} from '../types'; + +/** + * Creates a `tf.Tensor` filled with a scalar value. + * + * ```js + * tf.fill([2, 2], 4).print(); + * ``` + * + * @param shape An array of integers defining the output tensor shape. + * @param value The scalar value to fill the tensor with. + * @param dtype The type of an element in the resulting tensor. Defaults to + * 'float'. + */ +/** @doc {heading: 'Tensors', subheading: 'Creation'} */ +function fill( + shape: ShapeMap[R], value: number|string, dtype?: DataType): Tensor { + const attrs: FillAttrs = {shape, value, dtype}; + + return ENGINE.runKernelFunc( + backend => backend.fill(shape, value, dtype), {}, null, Fill, + attrs as {} as NamedAttrMap); +} + +export {fill}; diff --git a/tfjs-core/src/ops/ops.ts b/tfjs-core/src/ops/ops.ts index 8164696dc8e..b46a88bd141 100644 --- a/tfjs-core/src/ops/ops.ts +++ b/tfjs-core/src/ops/ops.ts @@ -47,6 +47,7 @@ export {divNoNan} from './div_no_nan'; export {dot} from './dot'; export {equal} from './equal'; export {eye} from './eye'; +export {fill} from './fill'; export {greater} from './greater'; export {greaterEqual} from './greater_equal'; export {imag} from './imag'; diff --git a/tfjs-core/src/ops/signal_ops.ts b/tfjs-core/src/ops/signal_ops.ts index 899bf8ecb59..d6f25da0404 100644 --- a/tfjs-core/src/ops/signal_ops.ts +++ b/tfjs-core/src/ops/signal_ops.ts @@ -20,9 +20,10 @@ import {Tensor, Tensor1D} from '../tensor'; import {mul} from './binary_ops'; import {concat} from './concat'; +import {fill} from './fill'; import {slice} from './slice'; import {rfft} from './spectral_ops'; -import {fill, tensor1d, tensor2d} from './tensor_ops'; +import {tensor1d, tensor2d} from './tensor_ops'; /** * Generate a Hann window. diff --git a/tfjs-core/src/ops/tensor_ops.ts b/tfjs-core/src/ops/tensor_ops.ts index cb32bfe234c..1255a1ac7e8 100644 --- a/tfjs-core/src/ops/tensor_ops.ts +++ b/tfjs-core/src/ops/tensor_ops.ts @@ -456,24 +456,6 @@ function zeros( return ENGINE.makeTensor(values, shape, dtype) as Tensor; } -/** - * Creates a `tf.Tensor` filled with a scalar value. - * - * ```js - * tf.fill([2, 2], 4).print(); - * ``` - * - * @param shape An array of integers defining the output tensor shape. - * @param value The scalar value to fill the tensor with. - * @param dtype The type of an element in the resulting tensor. Defaults to - * 'float'. - */ -/** @doc {heading: 'Tensors', subheading: 'Creation'} */ -function fill( - shape: ShapeMap[R], value: number|string, dtype?: DataType): Tensor { - return ENGINE.runKernelFunc(backend => backend.fill(shape, value, dtype), {}); -} - /** * Creates a `tf.Tensor` with all elements set to 1 with the same shape as the * given tensor. @@ -586,7 +568,6 @@ function range( } export { - fill, linspace, ones, range, diff --git a/tfjs-core/src/optimizers/adagrad_optimizer.ts b/tfjs-core/src/optimizers/adagrad_optimizer.ts index e2f561e2d2a..f7e79a7a432 100644 --- a/tfjs-core/src/optimizers/adagrad_optimizer.ts +++ b/tfjs-core/src/optimizers/adagrad_optimizer.ts @@ -17,7 +17,7 @@ import {ENGINE} from '../engine'; import {dispose, tidy} from '../globals'; -import {fill} from '../ops/ops'; +import {fill} from '../ops/fill'; import {ConfigDict, registerClass, Serializable, SerializableConstructor} from '../serialization'; import {NamedTensor, NamedVariableMap} from '../tensor_types';