diff --git a/tfjs-backend-cpu/src/backend_cpu.ts b/tfjs-backend-cpu/src/backend_cpu.ts index e0736c021bb..5c4cddec857 100644 --- a/tfjs-backend-cpu/src/backend_cpu.ts +++ b/tfjs-backend-cpu/src/backend_cpu.ts @@ -37,7 +37,7 @@ function mapActivation( } else if (activation === 'relu') { return backend.relu(x); } else if (activation === 'elu') { - return backend.elu(x); + return tf.elu(x); } else if (activation === 'relu6') { return backend.relu6(x); } else if (activation === 'prelu') { @@ -164,7 +164,7 @@ export class MathBackendCPU extends KernelBackend { return tf.buffer(t.shape, t.dtype, decodedData) as TensorBuffer; } - private makeOutput( + makeOutput( values: backend_util.BackendValues, shape: number[], dtype: DataType): T { const dataId = this.write(values, shape, dtype); return engine().makeTensorFromDataId(dataId, shape, dtype, this) as T; @@ -318,7 +318,7 @@ export class MathBackendCPU extends KernelBackend { // TODO(lina128): Use sub directly once softmax is modularized. const a = tf.sub(logits, maxLogit.reshape(expandedShape)); - const b = this.exp(a); + const b = tf.exp(a); const sumExp = this.sum(b, axes).reshape(expandedShape); // TODO(annxingyuan): Call divImpl rather than op as part of softmax @@ -615,17 +615,6 @@ export class MathBackendCPU extends KernelBackend { }); } - logicalNot(x: T): T { - assertNotComplex(x, 'logicalNot'); - - const values = this.readSync(x.dataId) as TypedArray; - const newValues = new Uint8Array(values.length); - for (let i = 0; i < values.length; ++i) { - newValues[i] = values[i] ? 0 : 1; - } - return this.makeOutput(newValues, x.shape, 'bool'); - } - logicalAnd(a: Tensor, b: Tensor): Tensor { assertNotComplex([a, b], 'logicalAnd'); @@ -789,188 +778,6 @@ export class MathBackendCPU extends KernelBackend { }); } - ceil(x: T): T { - assertNotComplex(x, 'ceil'); - - const values = this.readSync(x.dataId) as TypedArray; - const newValues = new Float32Array(values.length); - for (let i = 0; i < values.length; ++i) { - newValues[i] = Math.ceil(values[i]); - } - return this.makeOutput(newValues, x.shape, 'float32'); - } - - floor(x: T): T { - assertNotComplex(x, 'floor'); - - const values = this.readSync(x.dataId) as TypedArray; - const newValues = new Float32Array(values.length); - for (let i = 0; i < values.length; ++i) { - newValues[i] = Math.floor(values[i]); - } - return this.makeOutput(newValues, x.shape, 'float32'); - } - - sign(x: T): T { - assertNotComplex(x, 'x'); - - const values = this.readSync(x.dataId) as TypedArray; - const newValues = new Float32Array(values.length); - for (let i = 0; i < values.length; ++i) { - if (values[i] < 0) { - newValues[i] = -1; - } else if (values[i] > 0) { - newValues[i] = 1; - } else { - newValues[i] = 0; - } - } - return this.makeOutput(newValues, x.shape, 'float32'); - } - - isNaN(x: T): T { - assertNotComplex(x, 'x'); - - const values = this.readSync(x.dataId) as TypedArray; - const newValues = new Uint8Array(values.length); - for (let i = 0; i < values.length; ++i) { - if (Number.isNaN(values[i])) { - newValues[i] = 1; - } - } - return this.makeOutput(newValues, x.shape, 'bool'); - } - - isInf(x: T): T { - assertNotComplex(x, 'x'); - - const values = this.readSync(x.dataId) as TypedArray; - const newValues = new Uint8Array(values.length); - for (let i = 0; i < values.length; ++i) { - if (Math.abs(values[i]) === Infinity) { - newValues[i] = 1; - } - } - return this.makeOutput(newValues, x.shape, 'bool'); - } - - isFinite(x: T): T { - assertNotComplex(x, 'x'); - - const values = this.readSync(x.dataId) as TypedArray; - const newValues = new Uint8Array(values.length); - for (let i = 0; i < values.length; ++i) { - if (Number.isFinite(values[i])) { - newValues[i] = 1; - } - } - return this.makeOutput(newValues, x.shape, 'bool'); - } - - round(x: T): T { - assertNotComplex(x, 'round'); - - const values = this.readSync(x.dataId) as TypedArray; - const newValues = new Float32Array(values.length); - for (let i = 0; i < values.length; ++i) { - // The algorithm is based on banker's rounding. - const base = Math.floor(values[i]); - if (values[i] - base < 0.5) { - newValues[i] = Math.floor(values[i]); - } else if (values[i] - base > 0.5) { - newValues[i] = Math.ceil(values[i]); - } else { - if (base % 2.0 === 0.0) { - newValues[i] = base; - } else { - newValues[i] = base + 1.0; - } - } - } - return this.makeOutput(newValues, x.shape, 'float32'); - } - - exp(x: T): T { - assertNotComplex(x, 'exp'); - - const values = this.readSync(x.dataId) as TypedArray; - const newValues = new Float32Array(values.length); - for (let i = 0; i < values.length; ++i) { - newValues[i] = Math.exp(values[i]); - } - return this.makeOutput(newValues, x.shape, 'float32'); - } - - expm1(x: T): T { - assertNotComplex(x, 'expm1'); - - const values = this.readSync(x.dataId) as TypedArray; - const newValues = new Float32Array(values.length); - for (let i = 0; i < values.length; ++i) { - newValues[i] = Math.expm1(values[i]); - } - return this.makeOutput(newValues, x.shape, 'float32'); - } - - log(x: T): T { - assertNotComplex(x, 'log'); - - const values = this.readSync(x.dataId) as TypedArray; - const newValues = new Float32Array(values.length); - for (let i = 0; i < values.length; ++i) { - const value = values[i]; - newValues[i] = Math.log(value); - } - return this.makeOutput(newValues, x.shape, 'float32'); - } - - log1p(x: T): T { - assertNotComplex(x, 'log1p'); - - const values = this.readSync(x.dataId) as TypedArray; - const newValues = new Float32Array(values.length); - for (let i = 0; i < values.length; ++i) { - const value = values[i]; - newValues[i] = Math.log1p(value); - } - return this.makeOutput(newValues, x.shape, 'float32'); - } - - sqrt(x: T): T { - assertNotComplex(x, 'sqrt'); - - const values = this.readSync(x.dataId) as TypedArray; - const newValues = new Float32Array(values.length); - for (let i = 0; i < values.length; ++i) { - const value = values[i]; - newValues[i] = Math.sqrt(value); - } - return this.makeOutput(newValues, x.shape, 'float32'); - } - - rsqrt(x: T): T { - assertNotComplex(x, 'rsqrt'); - - const values = this.readSync(x.dataId) as TypedArray; - const newValues = new Float32Array(values.length); - for (let i = 0; i < values.length; ++i) { - const value = values[i]; - newValues[i] = 1 / Math.sqrt(value); - } - return this.makeOutput(newValues, x.shape, 'float32'); - } - - reciprocal(x: T): T { - assertNotComplex(x, 'reciprocal'); - - const values = this.readSync(x.dataId) as TypedArray; - const newValues = new Float32Array(values.length); - for (let i = 0; i < values.length; ++i) { - newValues[i] = 1 / values[i]; - } - return this.makeOutput(newValues, x.shape, 'float32'); - } - linear(x: T): T { return x; } @@ -1007,22 +814,6 @@ export class MathBackendCPU extends KernelBackend { (xValue, aValue) => xValue < 0 ? aValue * xValue : xValue) as T; } - elu(x: T): T { - assertNotComplex(x, 'elu'); - - const resultValues = new Float32Array(x.size); - const values = this.readSync(x.dataId) as TypedArray; - for (let i = 0; i < values.length; ++i) { - const v = values[i]; - if (v >= 0) { - resultValues[i] = v; - } else { - resultValues[i] = (Math.exp(v) - 1); - } - } - return this.makeOutput(resultValues, x.shape, 'float32'); - } - eluDer(dy: T, y: T): T { assertNotComplex([dy, y], 'eluDer'); @@ -1040,165 +831,6 @@ export class MathBackendCPU extends KernelBackend { return this.makeOutput(resultValues, y.shape, 'float32'); } - selu(x: T): T { - assertNotComplex(x, 'selu'); - - // Stable and Attracting Fixed Point (0, 1) for Normalized Weights. - // see: https://arxiv.org/abs/1706.02515 - const scaleAlpha = backend_util.SELU_SCALEALPHA; - const scale = backend_util.SELU_SCALE; - - const resultValues = new Float32Array(x.size); - const values = this.readSync(x.dataId) as TypedArray; - for (let i = 0; i < values.length; ++i) { - const v = values[i]; - if (v >= 0) { - resultValues[i] = scale * v; - } else { - resultValues[i] = scaleAlpha * (Math.exp(v) - 1); - } - } - return this.makeOutput(resultValues, x.shape, 'float32'); - } - - clip(x: T, min: number, max: number): T { - assertNotComplex(x, 'clip'); - - const resultValues = new Float32Array(x.size); - const values = this.readSync(x.dataId) as TypedArray; - for (let i = 0; i < values.length; ++i) { - const v = values[i]; - resultValues[i] = v > max ? max : (v < min ? min : v); - } - return this.makeOutput(resultValues, x.shape, x.dtype); - } - - abs(x: T): T { - const resultValues = new Float32Array(x.size); - const values = this.readSync(x.dataId) as TypedArray; - for (let i = 0; i < values.length; ++i) { - resultValues[i] = Math.abs(values[i]); - } - - return this.makeOutput(resultValues, x.shape, 'float32'); - } - - complexAbs(x: T): T { - const resultValues = new Float32Array(x.size); - const values = this.readSync(x.dataId) as TypedArray; - - for (let i = 0; i < x.size; ++i) { - const real = values[i * 2]; - const imag = values[i * 2 + 1]; - resultValues[i] = Math.hypot(real, imag); - } - return this.makeOutput(resultValues, x.shape, 'float32'); - } - - sigmoid(x: T): T { - assertNotComplex(x, 'sigmoid'); - - const resultValues = new Float32Array(x.size); - const values = this.readSync(x.dataId) as TypedArray; - for (let i = 0; i < values.length; ++i) { - resultValues[i] = 1 / (1 + Math.exp(-values[i])); - } - return this.makeOutput(resultValues, x.shape, 'float32'); - } - - softplus(x: T): T { - assertNotComplex(x, 'softplus'); - - // mirrors the implementation of tf.nn.softplus: https://goo.gl/vkcvwX - - // epsilon is the difference between 1.0 and the next representable float. - // For a single precision 32 bit float this should be 2^-23, see: - // https://math.byu.edu/~schow/work/IEEEFloatingPoint.htm - const epsilon = 1.1920928955078125e-7; - const threshold = Math.log(epsilon) + 2.0; - - const resultValues = new Float32Array(x.size); - const values = this.readSync(x.dataId) as TypedArray; - - for (let i = 0; i < values.length; ++i) { - // Value above which exp(x) may overflow, but softplus(x) == x - // is within machine epsilon. - const tooLarge = values[i] > -threshold; - - // Value below which exp(x) may underflow, but softplus(x) == exp(x) - // is within machine epsilon. - const tooSmall = values[i] < threshold; - - const expX = Math.exp(values[i]); - let result; - - if (tooSmall) { - result = expX; - } else if (tooLarge) { - result = values[i]; - } else { - result = Math.log(1.0 + expX); - } - resultValues[i] = result; - } - return this.makeOutput(resultValues, x.shape, 'float32'); - } - - sin(x: T): T { - assertNotComplex(x, 'sin'); - - const resultValues = new Float32Array(x.size); - const values = this.readSync(x.dataId) as TypedArray; - for (let i = 0; i < values.length; ++i) { - resultValues[i] = Math.sin(values[i]); - } - return this.makeOutput(resultValues, x.shape, 'float32'); - } - - tan(x: T): T { - assertNotComplex(x, 'tan'); - - const resultValues = new Float32Array(x.size); - const values = this.readSync(x.dataId) as TypedArray; - for (let i = 0; i < values.length; ++i) { - resultValues[i] = Math.tan(values[i]); - } - return this.makeOutput(resultValues, x.shape, 'float32'); - } - - asin(x: T): T { - assertNotComplex(x, 'asin'); - - const resultValues = new Float32Array(x.size); - const values = this.readSync(x.dataId) as TypedArray; - for (let i = 0; i < values.length; ++i) { - resultValues[i] = Math.asin(values[i]); - } - return this.makeOutput(resultValues, x.shape, 'float32'); - } - - acos(x: T): T { - assertNotComplex(x, 'acos'); - - const resultValues = new Float32Array(x.size); - const values = this.readSync(x.dataId) as TypedArray; - for (let i = 0; i < values.length; ++i) { - resultValues[i] = Math.acos(values[i]); - } - return this.makeOutput(resultValues, x.shape, 'float32'); - } - - atan(x: T): T { - assertNotComplex(x, 'atan'); - - const resultValues = new Float32Array(x.size); - const values = this.readSync(x.dataId) as TypedArray; - for (let i = 0; i < values.length; ++i) { - resultValues[i] = Math.atan(values[i]); - } - return this.makeOutput(resultValues, x.shape, 'float32'); - } - atan2(a: T, b: T): T { assertNotComplex([a, b], 'atan2'); @@ -1207,111 +839,6 @@ export class MathBackendCPU extends KernelBackend { T; } - sinh(x: T): T { - assertNotComplex(x, 'sinh'); - - const resultValues = new Float32Array(x.size); - const values = this.readSync(x.dataId) as TypedArray; - for (let i = 0; i < values.length; ++i) { - resultValues[i] = Math.sinh(values[i]); - } - return this.makeOutput(resultValues, x.shape, 'float32'); - } - - cosh(x: T): T { - assertNotComplex(x, 'cosh'); - - const resultValues = new Float32Array(x.size); - const values = this.readSync(x.dataId) as TypedArray; - for (let i = 0; i < values.length; ++i) { - resultValues[i] = Math.cosh(values[i]); - } - return this.makeOutput(resultValues, x.shape, 'float32'); - } - - tanh(x: T): T { - assertNotComplex(x, 'tanh'); - - const resultValues = new Float32Array(x.size); - const values = this.readSync(x.dataId) as TypedArray; - for (let i = 0; i < values.length; ++i) { - resultValues[i] = util.tanh(values[i]); - } - return this.makeOutput(resultValues, x.shape, 'float32'); - } - - asinh(x: T): T { - assertNotComplex(x, 'asinh'); - - const resultValues = new Float32Array(x.size); - const values = this.readSync(x.dataId) as TypedArray; - for (let i = 0; i < values.length; ++i) { - resultValues[i] = Math.asinh(values[i]); - } - return this.makeOutput(resultValues, x.shape, 'float32'); - } - - acosh(x: T): T { - assertNotComplex(x, 'acosh'); - - const resultValues = new Float32Array(x.size); - const values = this.readSync(x.dataId) as TypedArray; - for (let i = 0; i < values.length; ++i) { - resultValues[i] = Math.acosh(values[i]); - } - return this.makeOutput(resultValues, x.shape, 'float32'); - } - - atanh(x: T): T { - assertNotComplex(x, 'atanh'); - - const resultValues = new Float32Array(x.size); - const values = this.readSync(x.dataId) as TypedArray; - for (let i = 0; i < values.length; ++i) { - resultValues[i] = Math.atanh(values[i]); - } - return this.makeOutput(resultValues, x.shape, 'float32'); - } - - erf(x: T): T { - assertNotComplex(x, 'erf'); - - const resultValues = new Float32Array(x.size); - const values = this.readSync(x.dataId) as TypedArray; - const p = backend_util.ERF_P; - const a1 = backend_util.ERF_A1; - const a2 = backend_util.ERF_A2; - const a3 = backend_util.ERF_A3; - const a4 = backend_util.ERF_A4; - const a5 = backend_util.ERF_A5; - for (let i = 0; i < values.length; ++i) { - const sign = Math.sign(values[i]); - const v = Math.abs(values[i]); - const t = 1.0 / (1.0 + p * v); - resultValues[i] = sign * - (1.0 - - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * - Math.exp(-v * v)); - } - return this.makeOutput(resultValues, x.shape, 'float32'); - } - - step(x: T, alpha = 0): T { - assertNotComplex(x, 'step'); - - const resultValues = new Float32Array(x.size); - const values = this.readSync(x.dataId) as TypedArray; - for (let i = 0; i < values.length; ++i) { - const value = values[i]; - if (isNaN(value)) { - resultValues[i] = NaN; - } else { - resultValues[i] = value > 0 ? 1 : alpha; - } - } - return this.makeOutput(resultValues, x.shape, 'float32'); - } - fusedConv2d( {input, filter, convInfo, bias, activation, preluActivationWeights}: backend_util.FusedConv2DConfig): Tensor4D { diff --git a/tfjs-backend-cpu/src/kernels/Abs.ts b/tfjs-backend-cpu/src/kernels/Abs.ts new file mode 100644 index 00000000000..c8cc06ecb93 --- /dev/null +++ b/tfjs-backend-cpu/src/kernels/Abs.ts @@ -0,0 +1,53 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the License); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Abs, AbsInputs, KernelConfig, KernelFunc, TypedArray, util} from '@tensorflow/tfjs-core'; + +import {MathBackendCPU} from '../backend_cpu'; + +export const absKernelFunc = + (args: {inputs: AbsInputs, backend: MathBackendCPU}) => { + const {x} = args.inputs; + const cpuBackend = args.backend; + const resultValues = new Float32Array(util.sizeFromShape(x.shape)); + if (x.dtype !== 'complex64') { + const values = cpuBackend.data.get(x.dataId).values as TypedArray; + for (let i = 0; i < values.length; ++i) { + resultValues[i] = Math.abs(values[i]); + } + } else { + const complexVals = cpuBackend.data.get(x.dataId); + const real = complexVals.complexTensorInfos.real; + const imag = complexVals.complexTensorInfos.imag; + const realVals = + cpuBackend.data.get(real.dataId).values as Float32Array; + const imagVals = + cpuBackend.data.get(imag.dataId).values as Float32Array; + for (let i = 0; i < realVals.length; i++) { + const real = realVals[i]; + const imag = imagVals[i]; + resultValues[i] = Math.hypot(real, imag); + } + } + return cpuBackend.makeOutput(resultValues, x.shape, 'float32'); + }; + +export const absConfig: KernelConfig = { + kernelName: Abs, + backendName: 'cpu', + kernelFunc: absKernelFunc as {} as KernelFunc, +}; diff --git a/tfjs-backend-cpu/src/kernels/Acos.ts b/tfjs-backend-cpu/src/kernels/Acos.ts new file mode 100644 index 00000000000..8f64f517117 --- /dev/null +++ b/tfjs-backend-cpu/src/kernels/Acos.ts @@ -0,0 +1,28 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the License); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Acos, KernelConfig} from '@tensorflow/tfjs-core'; + +import {unaryKernelFunc} from '../utils/unary_utils'; + +export const acosKernelFunc = unaryKernelFunc(Acos, (xi) => Math.acos(xi)); + +export const acosConfig: KernelConfig = { + kernelName: Acos, + backendName: 'cpu', + kernelFunc: acosKernelFunc, +}; diff --git a/tfjs-backend-cpu/src/kernels/Acosh.ts b/tfjs-backend-cpu/src/kernels/Acosh.ts new file mode 100644 index 00000000000..e3371a18aa1 --- /dev/null +++ b/tfjs-backend-cpu/src/kernels/Acosh.ts @@ -0,0 +1,28 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the License); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Acosh, KernelConfig} from '@tensorflow/tfjs-core'; + +import {unaryKernelFunc} from '../utils/unary_utils'; + +export const acoshKernelFunc = unaryKernelFunc(Acosh, (xi) => Math.acosh(xi)); + +export const acoshConfig: KernelConfig = { + kernelName: Acosh, + backendName: 'cpu', + kernelFunc: acoshKernelFunc, +}; diff --git a/tfjs-backend-cpu/src/kernels/Asin.ts b/tfjs-backend-cpu/src/kernels/Asin.ts new file mode 100644 index 00000000000..6db150bf19f --- /dev/null +++ b/tfjs-backend-cpu/src/kernels/Asin.ts @@ -0,0 +1,28 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the License); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Asin, KernelConfig} from '@tensorflow/tfjs-core'; + +import {unaryKernelFunc} from '../utils/unary_utils'; + +export const asinKernelFunc = unaryKernelFunc(Asin, (xi) => Math.asin(xi)); + +export const asinConfig: KernelConfig = { + kernelName: Asin, + backendName: 'cpu', + kernelFunc: asinKernelFunc, +}; diff --git a/tfjs-backend-cpu/src/kernels/Asinh.ts b/tfjs-backend-cpu/src/kernels/Asinh.ts new file mode 100644 index 00000000000..253d1256bcf --- /dev/null +++ b/tfjs-backend-cpu/src/kernels/Asinh.ts @@ -0,0 +1,28 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the License); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Asinh, KernelConfig} from '@tensorflow/tfjs-core'; + +import {unaryKernelFunc} from '../utils/unary_utils'; + +export const asinhKernelFunc = unaryKernelFunc(Asinh, (xi) => Math.asinh(xi)); + +export const asinhConfig: KernelConfig = { + kernelName: Asinh, + backendName: 'cpu', + kernelFunc: asinhKernelFunc, +}; diff --git a/tfjs-backend-cpu/src/kernels/Atan.ts b/tfjs-backend-cpu/src/kernels/Atan.ts new file mode 100644 index 00000000000..37ee01a8fe7 --- /dev/null +++ b/tfjs-backend-cpu/src/kernels/Atan.ts @@ -0,0 +1,28 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the License); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Atan, KernelConfig} from '@tensorflow/tfjs-core'; + +import {unaryKernelFunc} from '../utils/unary_utils'; + +export const atanKernelFunc = unaryKernelFunc(Atan, (xi) => Math.atan(xi)); + +export const atanConfig: KernelConfig = { + kernelName: Atan, + backendName: 'cpu', + kernelFunc: atanKernelFunc, +}; diff --git a/tfjs-backend-cpu/src/kernels/Atanh.ts b/tfjs-backend-cpu/src/kernels/Atanh.ts new file mode 100644 index 00000000000..f0cd938852c --- /dev/null +++ b/tfjs-backend-cpu/src/kernels/Atanh.ts @@ -0,0 +1,28 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the License); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Atanh, KernelConfig} from '@tensorflow/tfjs-core'; + +import {unaryKernelFunc} from '../utils/unary_utils'; + +export const atanhKernelFunc = unaryKernelFunc(Atanh, (xi) => Math.atanh(xi)); + +export const atanhConfig: KernelConfig = { + kernelName: Atanh, + backendName: 'cpu', + kernelFunc: atanhKernelFunc, +}; diff --git a/tfjs-backend-cpu/src/kernels/Ceil.ts b/tfjs-backend-cpu/src/kernels/Ceil.ts new file mode 100644 index 00000000000..a714d6fc622 --- /dev/null +++ b/tfjs-backend-cpu/src/kernels/Ceil.ts @@ -0,0 +1,28 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the License); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Ceil, KernelConfig} from '@tensorflow/tfjs-core'; + +import {unaryKernelFunc} from '../utils/unary_utils'; + +export const ceilKernelFunc = unaryKernelFunc(Ceil, (xi) => Math.ceil(xi)); + +export const ceilConfig: KernelConfig = { + kernelName: Ceil, + backendName: 'cpu', + kernelFunc: ceilKernelFunc, +}; diff --git a/tfjs-backend-cpu/src/kernels/Clip.ts b/tfjs-backend-cpu/src/kernels/Clip.ts new file mode 100644 index 00000000000..ec97f84cff2 --- /dev/null +++ b/tfjs-backend-cpu/src/kernels/Clip.ts @@ -0,0 +1,34 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the License); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {ClipByValue, ClipByValueAttrs, KernelConfig} from '@tensorflow/tfjs-core'; + +import {unaryKernelFunc} from '../utils/unary_utils'; + +export const clipKernelFunc = unaryKernelFunc(ClipByValue, (xi, attrs) => { + const clipAttrs = attrs as {} as ClipByValueAttrs; + if (xi > clipAttrs.clipValueMax) { + return clipAttrs.clipValueMax; + } + return xi < clipAttrs.clipValueMin ? clipAttrs.clipValueMin : xi; +}); + +export const clipConfig: KernelConfig = { + kernelName: ClipByValue, + backendName: 'cpu', + kernelFunc: clipKernelFunc, +}; diff --git a/tfjs-backend-cpu/src/kernels/Cos.ts b/tfjs-backend-cpu/src/kernels/Cos.ts index 3e57b5b5d46..2c12eef242c 100644 --- a/tfjs-backend-cpu/src/kernels/Cos.ts +++ b/tfjs-backend-cpu/src/kernels/Cos.ts @@ -15,26 +15,14 @@ * ============================================================================= */ -import {Cos, CosInputs, KernelConfig, TypedArray, util} from '@tensorflow/tfjs-core'; +import {Cos, KernelConfig} from '@tensorflow/tfjs-core'; -import {MathBackendCPU} from '../backend_cpu'; -import {assertNotComplex} from '../cpu_util'; +import {unaryKernelFunc} from '../utils/unary_utils'; + +export const cosKernelFunc = unaryKernelFunc(Cos, (xi) => Math.cos(xi)); export const cosConfig: KernelConfig = { kernelName: Cos, backendName: 'cpu', - kernelFunc: ({inputs, backend}) => { - const {x} = inputs as CosInputs; - const cpuBackend = backend as MathBackendCPU; - assertNotComplex(x, 'cos'); - - const values = cpuBackend.data.get(x.dataId).values as TypedArray; - const xSize = util.sizeFromShape(x.shape); - const newValues = new Float32Array(xSize); - for (let i = 0; i < xSize; ++i) { - newValues[i] = Math.cos(values[i]); - } - const dataId = cpuBackend.write(newValues, x.shape, x.dtype); - return {dataId, shape: x.shape, dtype: x.dtype}; - } + kernelFunc: cosKernelFunc, }; diff --git a/tfjs-backend-cpu/src/kernels/Cosh.ts b/tfjs-backend-cpu/src/kernels/Cosh.ts new file mode 100644 index 00000000000..572dcbb9ee1 --- /dev/null +++ b/tfjs-backend-cpu/src/kernels/Cosh.ts @@ -0,0 +1,28 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the License); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Cosh, KernelConfig} from '@tensorflow/tfjs-core'; + +import {unaryKernelFunc} from '../utils/unary_utils'; + +export const coshKernelFunc = unaryKernelFunc(Cosh, (xi) => Math.cosh(xi)); + +export const coshConfig: KernelConfig = { + kernelName: Cosh, + backendName: 'cpu', + kernelFunc: coshKernelFunc, +}; diff --git a/tfjs-backend-cpu/src/kernels/Elu.ts b/tfjs-backend-cpu/src/kernels/Elu.ts new file mode 100644 index 00000000000..f93fcd107f3 --- /dev/null +++ b/tfjs-backend-cpu/src/kernels/Elu.ts @@ -0,0 +1,29 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the License); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Elu, KernelConfig} from '@tensorflow/tfjs-core'; + +import {unaryKernelFunc} from '../utils/unary_utils'; + +export const eluKernelFunc = + unaryKernelFunc(Elu, (xi) => xi >= 0 ? xi : (Math.exp(xi) - 1)); + +export const eluConfig: KernelConfig = { + kernelName: Elu, + backendName: 'cpu', + kernelFunc: eluKernelFunc, +}; diff --git a/tfjs-backend-cpu/src/kernels/Erf.ts b/tfjs-backend-cpu/src/kernels/Erf.ts new file mode 100644 index 00000000000..c3bda037f77 --- /dev/null +++ b/tfjs-backend-cpu/src/kernels/Erf.ts @@ -0,0 +1,46 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the License); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {backend_util, Erf, KernelConfig} from '@tensorflow/tfjs-core'; + +import {unaryKernelFunc} from '../utils/unary_utils'; + +const p = backend_util.ERF_P; +const a1 = backend_util.ERF_A1; +const a2 = backend_util.ERF_A2; +const a3 = backend_util.ERF_A3; +const a4 = backend_util.ERF_A4; +const a5 = backend_util.ERF_A5; + +export const erfKernelFunc = unaryKernelFunc( + Erf, + (xi) => { + const sign = Math.sign(xi); + const v = Math.abs(xi); + const t = 1.0 / (1.0 + p * v); + return sign * + (1.0 - + (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * + Math.exp(-v * v)); + }, +); + +export const erfConfig: KernelConfig = { + kernelName: Erf, + backendName: 'cpu', + kernelFunc: erfKernelFunc, +}; diff --git a/tfjs-backend-cpu/src/kernels/Exp.ts b/tfjs-backend-cpu/src/kernels/Exp.ts new file mode 100644 index 00000000000..c8e0acfd327 --- /dev/null +++ b/tfjs-backend-cpu/src/kernels/Exp.ts @@ -0,0 +1,28 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the License); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Exp, KernelConfig} from '@tensorflow/tfjs-core'; + +import {unaryKernelFunc} from '../utils/unary_utils'; + +export const expKernelFunc = unaryKernelFunc(Exp, (xi) => Math.exp(xi)); + +export const expConfig: KernelConfig = { + kernelName: Exp, + backendName: 'cpu', + kernelFunc: expKernelFunc, +}; diff --git a/tfjs-backend-cpu/src/kernels/Expm1.ts b/tfjs-backend-cpu/src/kernels/Expm1.ts new file mode 100644 index 00000000000..0ee7ac706b7 --- /dev/null +++ b/tfjs-backend-cpu/src/kernels/Expm1.ts @@ -0,0 +1,28 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the License); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Expm1, KernelConfig} from '@tensorflow/tfjs-core'; + +import {unaryKernelFunc} from '../utils/unary_utils'; + +export const expm1KernelFunc = unaryKernelFunc(Expm1, (xi) => Math.expm1(xi)); + +export const expm1Config: KernelConfig = { + kernelName: Expm1, + backendName: 'cpu', + kernelFunc: expm1KernelFunc, +}; diff --git a/tfjs-backend-cpu/src/kernels/Floor.ts b/tfjs-backend-cpu/src/kernels/Floor.ts new file mode 100644 index 00000000000..6613239cbb5 --- /dev/null +++ b/tfjs-backend-cpu/src/kernels/Floor.ts @@ -0,0 +1,28 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the License); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Floor, KernelConfig} from '@tensorflow/tfjs-core'; + +import {unaryKernelFunc} from '../utils/unary_utils'; + +export const floorKernelFunc = unaryKernelFunc(Floor, (xi) => Math.floor(xi)); + +export const floorConfig: KernelConfig = { + kernelName: Floor, + backendName: 'cpu', + kernelFunc: floorKernelFunc, +}; diff --git a/tfjs-backend-cpu/src/kernels/IsFinite.ts b/tfjs-backend-cpu/src/kernels/IsFinite.ts new file mode 100644 index 00000000000..7fd5ef7ecfd --- /dev/null +++ b/tfjs-backend-cpu/src/kernels/IsFinite.ts @@ -0,0 +1,29 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the License); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {IsFinite, KernelConfig} from '@tensorflow/tfjs-core'; + +import {unaryKernelFunc} from '../utils/unary_utils'; + +export const isFiniteKernelFunc = + unaryKernelFunc(IsFinite, (xi) => Number.isFinite(xi) ? 1 : 0, 'bool'); + +export const isFiniteConfig: KernelConfig = { + kernelName: IsFinite, + backendName: 'cpu', + kernelFunc: isFiniteKernelFunc, +}; diff --git a/tfjs-backend-cpu/src/kernels/IsInf.ts b/tfjs-backend-cpu/src/kernels/IsInf.ts new file mode 100644 index 00000000000..27d86afd9a2 --- /dev/null +++ b/tfjs-backend-cpu/src/kernels/IsInf.ts @@ -0,0 +1,29 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the License); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {IsInf, KernelConfig} from '@tensorflow/tfjs-core'; + +import {unaryKernelFunc} from '../utils/unary_utils'; + +export const isInfKernelFunc = + unaryKernelFunc(IsInf, (xi) => Math.abs(xi) === Infinity ? 1 : 0, 'bool'); + +export const isInfConfig: KernelConfig = { + kernelName: IsInf, + backendName: 'cpu', + kernelFunc: isInfKernelFunc, +}; diff --git a/tfjs-backend-cpu/src/kernels/IsNaN.ts b/tfjs-backend-cpu/src/kernels/IsNaN.ts new file mode 100644 index 00000000000..44ef8d19c8c --- /dev/null +++ b/tfjs-backend-cpu/src/kernels/IsNaN.ts @@ -0,0 +1,29 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the License); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {IsNan, KernelConfig} from '@tensorflow/tfjs-core'; + +import {unaryKernelFunc} from '../utils/unary_utils'; + +export const isNaNKernelFunc = + unaryKernelFunc(IsNan, (xi) => Number.isNaN(xi) ? 1 : 0, 'bool'); + +export const isNaNConfig: KernelConfig = { + kernelName: IsNan, + backendName: 'cpu', + kernelFunc: isNaNKernelFunc, +}; diff --git a/tfjs-backend-cpu/src/kernels/Log.ts b/tfjs-backend-cpu/src/kernels/Log.ts new file mode 100644 index 00000000000..a1ced04d4a4 --- /dev/null +++ b/tfjs-backend-cpu/src/kernels/Log.ts @@ -0,0 +1,28 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the License); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {KernelConfig, Log} from '@tensorflow/tfjs-core'; + +import {unaryKernelFunc} from '../utils/unary_utils'; + +export const logKernelFunc = unaryKernelFunc(Log, (xi) => Math.log(xi)); + +export const logConfig: KernelConfig = { + kernelName: Log, + backendName: 'cpu', + kernelFunc: logKernelFunc, +}; diff --git a/tfjs-backend-cpu/src/kernels/Log1p.ts b/tfjs-backend-cpu/src/kernels/Log1p.ts new file mode 100644 index 00000000000..b5549642a26 --- /dev/null +++ b/tfjs-backend-cpu/src/kernels/Log1p.ts @@ -0,0 +1,28 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the License); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {KernelConfig, Log1p} from '@tensorflow/tfjs-core'; + +import {unaryKernelFunc} from '../utils/unary_utils'; + +export const log1pKernelFunc = unaryKernelFunc(Log1p, (xi) => Math.log1p(xi)); + +export const log1pConfig: KernelConfig = { + kernelName: Log1p, + backendName: 'cpu', + kernelFunc: log1pKernelFunc, +}; diff --git a/tfjs-backend-cpu/src/kernels/LogicalNot.ts b/tfjs-backend-cpu/src/kernels/LogicalNot.ts new file mode 100644 index 00000000000..0ef72834323 --- /dev/null +++ b/tfjs-backend-cpu/src/kernels/LogicalNot.ts @@ -0,0 +1,29 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the License); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {KernelConfig, LogicalNot} from '@tensorflow/tfjs-core'; + +import {unaryKernelFunc} from '../utils/unary_utils'; + +export const logicalNotKernelFunc = + unaryKernelFunc(LogicalNot, (xi) => xi ? 0 : 1, 'bool'); + +export const logicalNotConfig: KernelConfig = { + kernelName: LogicalNot, + backendName: 'cpu', + kernelFunc: logicalNotKernelFunc, +}; diff --git a/tfjs-backend-cpu/src/kernels/Reciprocal.ts b/tfjs-backend-cpu/src/kernels/Reciprocal.ts new file mode 100644 index 00000000000..f33fd23d045 --- /dev/null +++ b/tfjs-backend-cpu/src/kernels/Reciprocal.ts @@ -0,0 +1,28 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the License); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {KernelConfig, Reciprocal} from '@tensorflow/tfjs-core'; + +import {unaryKernelFunc} from '../utils/unary_utils'; + +export const reciprocalKernelFunc = unaryKernelFunc(Reciprocal, (xi) => 1 / xi); + +export const reciprocalConfig: KernelConfig = { + kernelName: Reciprocal, + backendName: 'cpu', + kernelFunc: reciprocalKernelFunc, +}; diff --git a/tfjs-backend-cpu/src/kernels/Round.ts b/tfjs-backend-cpu/src/kernels/Round.ts new file mode 100644 index 00000000000..ae93e9644ef --- /dev/null +++ b/tfjs-backend-cpu/src/kernels/Round.ts @@ -0,0 +1,42 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the License); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {KernelConfig, Round} from '@tensorflow/tfjs-core'; + +import {unaryKernelFunc} from '../utils/unary_utils'; + +export const roundKernelFunc = unaryKernelFunc(Round, (xi) => { + // The algorithm is based on banker's rounding. + const base = Math.floor(xi); + if (xi - base < 0.5) { + return Math.floor(xi); + } else if (xi - base > 0.5) { + return Math.ceil(xi); + } else { + if (base % 2.0 === 0.0) { + return base; + } else { + return base + 1.0; + } + } +}); + +export const roundConfig: KernelConfig = { + kernelName: Round, + backendName: 'cpu', + kernelFunc: roundKernelFunc, +}; diff --git a/tfjs-backend-cpu/src/kernels/Rsqrt.ts b/tfjs-backend-cpu/src/kernels/Rsqrt.ts new file mode 100644 index 00000000000..6a54ff4fed5 --- /dev/null +++ b/tfjs-backend-cpu/src/kernels/Rsqrt.ts @@ -0,0 +1,29 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the License); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {KernelConfig, Rsqrt} from '@tensorflow/tfjs-core'; + +import {unaryKernelFunc} from '../utils/unary_utils'; + +export const rsqrtKernelFunc = + unaryKernelFunc(Rsqrt, (xi) => 1 / Math.sqrt(xi)); + +export const rsqrtConfig: KernelConfig = { + kernelName: Rsqrt, + backendName: 'cpu', + kernelFunc: rsqrtKernelFunc, +}; diff --git a/tfjs-backend-cpu/src/kernels/Selu.ts b/tfjs-backend-cpu/src/kernels/Selu.ts new file mode 100644 index 00000000000..38133ca6510 --- /dev/null +++ b/tfjs-backend-cpu/src/kernels/Selu.ts @@ -0,0 +1,37 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the License); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {backend_util, KernelConfig, Selu} from '@tensorflow/tfjs-core'; + +import {unaryKernelFunc} from '../utils/unary_utils'; + +const scaleAlpha = backend_util.SELU_SCALEALPHA; +const scale = backend_util.SELU_SCALE; + +export const seluKernelFunc = unaryKernelFunc(Selu, (xi) => { + if (xi >= 0) { + return scale * xi; + } else { + return scaleAlpha * (Math.exp(xi) - 1); + } +}); + +export const seluConfig: KernelConfig = { + kernelName: Selu, + backendName: 'cpu', + kernelFunc: seluKernelFunc, +}; diff --git a/tfjs-backend-cpu/src/kernels/Sigmoid.ts b/tfjs-backend-cpu/src/kernels/Sigmoid.ts new file mode 100644 index 00000000000..1d8e0a03eed --- /dev/null +++ b/tfjs-backend-cpu/src/kernels/Sigmoid.ts @@ -0,0 +1,29 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the License); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {KernelConfig, Sigmoid} from '@tensorflow/tfjs-core'; + +import {unaryKernelFunc} from '../utils/unary_utils'; + +export const sigmoidKernelFunc = + unaryKernelFunc(Sigmoid, (xi) => 1 / (1 + Math.exp(-xi))); + +export const sigmoidConfig: KernelConfig = { + kernelName: Sigmoid, + backendName: 'cpu', + kernelFunc: sigmoidKernelFunc, +}; diff --git a/tfjs-backend-cpu/src/kernels/Sign.ts b/tfjs-backend-cpu/src/kernels/Sign.ts new file mode 100644 index 00000000000..0eec5243291 --- /dev/null +++ b/tfjs-backend-cpu/src/kernels/Sign.ts @@ -0,0 +1,36 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the License); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {KernelConfig, Sign} from '@tensorflow/tfjs-core'; + +import {unaryKernelFunc} from '../utils/unary_utils'; + +export const signKernelFunc = unaryKernelFunc(Sign, (xi) => { + if (xi < 0) { + return -1; + } else if (xi > 0) { + return 1; + } else { + return 0; + } +}); + +export const signConfig: KernelConfig = { + kernelName: Sign, + backendName: 'cpu', + kernelFunc: signKernelFunc, +}; diff --git a/tfjs-backend-cpu/src/kernels/Sin.ts b/tfjs-backend-cpu/src/kernels/Sin.ts new file mode 100644 index 00000000000..fd07bfaf109 --- /dev/null +++ b/tfjs-backend-cpu/src/kernels/Sin.ts @@ -0,0 +1,28 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the License); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {KernelConfig, Sin} from '@tensorflow/tfjs-core'; + +import {unaryKernelFunc} from '../utils/unary_utils'; + +export const sinKernelFunc = unaryKernelFunc(Sin, (xi) => Math.sin(xi)); + +export const sinConfig: KernelConfig = { + kernelName: Sin, + backendName: 'cpu', + kernelFunc: sinKernelFunc, +}; diff --git a/tfjs-backend-cpu/src/kernels/Sinh.ts b/tfjs-backend-cpu/src/kernels/Sinh.ts new file mode 100644 index 00000000000..e70a85716ed --- /dev/null +++ b/tfjs-backend-cpu/src/kernels/Sinh.ts @@ -0,0 +1,28 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the License); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {KernelConfig, Sinh} from '@tensorflow/tfjs-core'; + +import {unaryKernelFunc} from '../utils/unary_utils'; + +export const sinhKernelFunc = unaryKernelFunc(Sinh, (xi) => Math.sinh(xi)); + +export const sinhConfig: KernelConfig = { + kernelName: Sinh, + backendName: 'cpu', + kernelFunc: sinhKernelFunc, +}; diff --git a/tfjs-backend-cpu/src/kernels/Softplus.ts b/tfjs-backend-cpu/src/kernels/Softplus.ts new file mode 100644 index 00000000000..bcc94592192 --- /dev/null +++ b/tfjs-backend-cpu/src/kernels/Softplus.ts @@ -0,0 +1,56 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the License); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {KernelConfig, Softplus} from '@tensorflow/tfjs-core'; + +import {unaryKernelFunc} from '../utils/unary_utils'; + +// mirrors the implementation of tf.nn.softplus: https://goo.gl/vkcvwX + +// epsilon is the difference between 1.0 and the next representable float. +// For a single precision 32 bit float this should be 2^-23, see: +// https://math.byu.edu/~schow/work/IEEEFloatingPoint.htm +const epsilon = 1.1920928955078125e-7; +const threshold = Math.log(epsilon) + 2.0; + +export const softplusKernelFunc = unaryKernelFunc(Softplus, (xi) => { + // Value above which exp(x) may overflow, but softplus(x) == x + // is within machine epsilon. + const tooLarge = xi > -threshold; + + // Value below which exp(x) may underflow, but softplus(x) == exp(x) + // is within machine epsilon. + const tooSmall = xi < threshold; + + const expX = Math.exp(xi); + let result; + + if (tooSmall) { + result = expX; + } else if (tooLarge) { + result = xi; + } else { + result = Math.log(1.0 + expX); + } + return result; +}); + +export const softplusConfig: KernelConfig = { + kernelName: Softplus, + backendName: 'cpu', + kernelFunc: softplusKernelFunc, +}; diff --git a/tfjs-backend-cpu/src/kernels/Sqrt.ts b/tfjs-backend-cpu/src/kernels/Sqrt.ts new file mode 100644 index 00000000000..f6d50df1bca --- /dev/null +++ b/tfjs-backend-cpu/src/kernels/Sqrt.ts @@ -0,0 +1,28 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the License); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {KernelConfig, Sqrt} from '@tensorflow/tfjs-core'; + +import {unaryKernelFunc} from '../utils/unary_utils'; + +export const sqrtKernelFunc = unaryKernelFunc(Sqrt, (xi) => Math.sqrt(xi)); + +export const sqrtConfig: KernelConfig = { + kernelName: Sqrt, + backendName: 'cpu', + kernelFunc: sqrtKernelFunc, +}; diff --git a/tfjs-backend-cpu/src/kernels/Step.ts b/tfjs-backend-cpu/src/kernels/Step.ts new file mode 100644 index 00000000000..d4b1a79f23d --- /dev/null +++ b/tfjs-backend-cpu/src/kernels/Step.ts @@ -0,0 +1,35 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the License); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {KernelConfig, Step, StepAttrs} from '@tensorflow/tfjs-core'; + +import {unaryKernelFunc} from '../utils/unary_utils'; + +export const stepKernelFunc = unaryKernelFunc(Step, (xi, attrs) => { + const stepAttrs = attrs as {} as StepAttrs; + if (isNaN(xi)) { + return NaN; + } else { + return xi > 0 ? 1 : stepAttrs.alpha; + } +}); + +export const stepConfig: KernelConfig = { + kernelName: Step, + backendName: 'cpu', + kernelFunc: stepKernelFunc, +}; diff --git a/tfjs-backend-cpu/src/kernels/Tan.ts b/tfjs-backend-cpu/src/kernels/Tan.ts new file mode 100644 index 00000000000..35a69ff4b25 --- /dev/null +++ b/tfjs-backend-cpu/src/kernels/Tan.ts @@ -0,0 +1,28 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the License); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {KernelConfig, Tan} from '@tensorflow/tfjs-core'; + +import {unaryKernelFunc} from '../utils/unary_utils'; + +export const tanKernelFunc = unaryKernelFunc(Tan, (xi) => Math.tan(xi)); + +export const tanConfig: KernelConfig = { + kernelName: Tan, + backendName: 'cpu', + kernelFunc: tanKernelFunc, +}; diff --git a/tfjs-backend-cpu/src/kernels/Tanh.ts b/tfjs-backend-cpu/src/kernels/Tanh.ts new file mode 100644 index 00000000000..3311d27b90d --- /dev/null +++ b/tfjs-backend-cpu/src/kernels/Tanh.ts @@ -0,0 +1,28 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the License); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {KernelConfig, Tanh} from '@tensorflow/tfjs-core'; + +import {unaryKernelFunc} from '../utils/unary_utils'; + +export const tanhKernelFunc = unaryKernelFunc(Tanh, (xi) => Math.tanh(xi)); + +export const tanhConfig: KernelConfig = { + kernelName: Tanh, + backendName: 'cpu', + kernelFunc: tanhKernelFunc, +}; diff --git a/tfjs-backend-cpu/src/register_all_kernels.ts b/tfjs-backend-cpu/src/register_all_kernels.ts index 5051806001c..9403e7754aa 100644 --- a/tfjs-backend-cpu/src/register_all_kernels.ts +++ b/tfjs-backend-cpu/src/register_all_kernels.ts @@ -19,20 +19,41 @@ // the contents of this file and import only the kernels that are needed. import {KernelConfig, registerKernel} from '@tensorflow/tfjs-core'; +import {absConfig} from './kernels/Abs'; +import {acosConfig} from './kernels/Acos'; +import {acoshConfig} from './kernels/Acosh'; import {addConfig} from './kernels/Add'; +import {asinConfig} from './kernels/Asin'; +import {asinhConfig} from './kernels/Asinh'; +import {atanConfig} from './kernels/Atan'; +import {atanhConfig} from './kernels/Atanh'; import {castConfig} from './kernels/Cast'; +import {ceilConfig} from './kernels/Ceil'; +import {clipConfig} from './kernels/Clip'; import {complexConfig} from './kernels/Complex'; import {concatConfig} from './kernels/Concat'; import {cosConfig} from './kernels/Cos'; +import {coshConfig} from './kernels/Cosh'; import {dilation2dConfig} from './kernels/Dilation2D'; import {dilation2dBackpropFilterConfig} from './kernels/Dilation2DBackpropFilter'; import {dilation2dBackpropInputConfig} from './kernels/Dilation2DBackpropInput'; import {divConfig} from './kernels/Div'; +import {eluConfig} from './kernels/Elu'; +import {erfConfig} from './kernels/Erf'; +import {expConfig} from './kernels/Exp'; +import {expm1Config} from './kernels/Expm1'; import {fftConfig} from './kernels/FFT'; import {flipLeftRightConfig} from './kernels/FlipLeftRight'; +import {floorConfig} from './kernels/Floor'; import {identityConfig} from './kernels/Identity'; import {ifftConfig} from './kernels/IFFT'; import {imagConfig} from './kernels/Imag'; +import {isFiniteConfig} from './kernels/IsFinite'; +import {isInfConfig} from './kernels/IsInf'; +import {isNaNConfig} from './kernels/IsNaN'; +import {logConfig} from './kernels/Log'; +import {log1pConfig} from './kernels/Log1p'; +import {logicalNotConfig} from './kernels/LogicalNot'; import {maxConfig} from './kernels/Max'; import {maxPoolWithArgmaxConfig} from './kernels/MaxPoolWithArgmax'; import {multiplyConfig} from './kernels/Multiply'; @@ -41,31 +62,65 @@ import {nonMaxSuppressionV5Config} from './kernels/NonMaxSuppressionV5'; import {notEqualConfig} from './kernels/NotEqual'; import {padV2Config} from './kernels/PadV2'; import {realConfig} from './kernels/Real'; +import {reciprocalConfig} from './kernels/Reciprocal'; import {reshapeConfig} from './kernels/Reshape'; import {rotateWithOffsetConfig} from './kernels/RotateWithOffset'; +import {roundConfig} from './kernels/Round'; +import {rsqrtConfig} from './kernels/Rsqrt'; +import {seluConfig} from './kernels/Selu'; +import {sigmoidConfig} from './kernels/Sigmoid'; +import {signConfig} from './kernels/Sign'; +import {sinConfig} from './kernels/Sin'; +import {sinhConfig} from './kernels/Sinh'; import {sliceConfig} from './kernels/Slice'; +import {softplusConfig} from './kernels/Softplus'; import {spaceToBatchNDConfig} from './kernels/SpaceToBatchND'; +import {sqrtConfig} from './kernels/Sqrt'; import {squareConfig} from './kernels/Square'; import {squaredDifferenceConfig} from './kernels/SquaredDifference'; +import {stepConfig} from './kernels/Step'; import {subConfig} from './kernels/Sub'; +import {tanConfig} from './kernels/Tan'; +import {tanhConfig} from './kernels/Tanh'; import {transposeConfig} from './kernels/Transpose'; // List all kernel configs here const kernelConfigs: KernelConfig[] = [ + absConfig, + acosConfig, + acoshConfig, addConfig, + asinConfig, + asinhConfig, + atanConfig, + atanhConfig, castConfig, + ceilConfig, + clipConfig, complexConfig, concatConfig, cosConfig, + coshConfig, dilation2dConfig, dilation2dBackpropInputConfig, dilation2dBackpropFilterConfig, divConfig, + eluConfig, + erfConfig, + expConfig, + expm1Config, fftConfig, flipLeftRightConfig, + floorConfig, identityConfig, ifftConfig, imagConfig, + isFiniteConfig, + isInfConfig, + isNaNConfig, + logConfig, + log1pConfig, + logicalNotConfig, maxPoolWithArgmaxConfig, maxConfig, multiplyConfig, @@ -74,13 +129,26 @@ const kernelConfigs: KernelConfig[] = [ notEqualConfig, padV2Config, realConfig, + reciprocalConfig, reshapeConfig, rotateWithOffsetConfig, + roundConfig, + rsqrtConfig, + seluConfig, + sigmoidConfig, + signConfig, + sinConfig, + sinhConfig, sliceConfig, + softplusConfig, spaceToBatchNDConfig, + sqrtConfig, squareConfig, squaredDifferenceConfig, + stepConfig, subConfig, + tanConfig, + tanhConfig, transposeConfig ]; diff --git a/tfjs-backend-cpu/src/utils/unary_types.ts b/tfjs-backend-cpu/src/utils/unary_types.ts new file mode 100644 index 00000000000..0c411da817e --- /dev/null +++ b/tfjs-backend-cpu/src/utils/unary_types.ts @@ -0,0 +1,20 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {NamedAttrMap} from '@tensorflow/tfjs-core'; + +export type SimpleUnaryOperation = (x: number, attrs?: NamedAttrMap) => number; diff --git a/tfjs-backend-cpu/src/utils/unary_utils.ts b/tfjs-backend-cpu/src/utils/unary_utils.ts new file mode 100644 index 00000000000..df0d76d1149 --- /dev/null +++ b/tfjs-backend-cpu/src/utils/unary_utils.ts @@ -0,0 +1,54 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {DataType, KernelFunc, TypedArray, UnaryInputs, util} from '@tensorflow/tfjs-core'; + +import {MathBackendCPU} from '../backend_cpu'; +import {assertNotComplex} from '../cpu_util'; + +import {SimpleUnaryOperation} from './unary_types'; + +/** + * Template that creates a `KernelFunc` for unary ops. + * @param name Kernel name. + * @param op A `SimpleUnaryOperation` for the kernel. + * @param dtype Optional. If set, the result has this dtype. Otherwise, the + * result has the same dtype as the input. This is mainly used in certain + * kernels that return bool type, such as isFinite, isInf, etc. + * @param opComplex A `ComplexUnaryOperation` for the kernel to handle complex64 + * inputs. + */ +export function unaryKernelFunc( + name: string, op: SimpleUnaryOperation, dtype?: DataType): KernelFunc { + return ({inputs, attrs, backend}) => { + const {x} = inputs as UnaryInputs; + assertNotComplex(x, name); + if (x.dtype === 'string' || dtype === 'string') { + throw new Error('unaryKernelFunc does not support string input/output'); + } + + const cpuBackend = backend as MathBackendCPU; + const values = cpuBackend.data.get(x.dataId).values as TypedArray; + const xSize = util.sizeFromShape(x.shape); + const $dtype = dtype || x.dtype; + const newValues = util.getArrayFromDType($dtype, xSize); + for (let i = 0; i < xSize; ++i) { + newValues[i] = op(values[i], attrs); + } + return cpuBackend.makeTensorInfo(x.shape, $dtype, newValues); + }; +}