diff --git a/tfjs-backend-webgl/src/backend_webgl.ts b/tfjs-backend-webgl/src/backend_webgl.ts index 30c2168e9c6..cd0583c1743 100644 --- a/tfjs-backend-webgl/src/backend_webgl.ts +++ b/tfjs-backend-webgl/src/backend_webgl.ts @@ -1743,11 +1743,6 @@ export class MathBackendWebGL extends KernelBackend { return this.compileAndRun(program, [x]); } - cos(x: T): T { - const program = new UnaryOpProgram(x.shape, unary_op.COS); - return this.compileAndRun(program, [x]); - } - tan(x: T): T { const program = new UnaryOpProgram(x.shape, unary_op.TAN); return this.compileAndRun(program, [x]); diff --git a/tfjs-backend-webgl/src/kernels/Cos.ts b/tfjs-backend-webgl/src/kernels/Cos.ts new file mode 100644 index 00000000000..39bd53fee7d --- /dev/null +++ b/tfjs-backend-webgl/src/kernels/Cos.ts @@ -0,0 +1,32 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Cos, CosInputs, KernelConfig} from '@tensorflow/tfjs-core'; + +import {MathBackendWebGL} from '../backend_webgl'; +import {COS, UnaryOpProgram} from '../unaryop_gpu'; + +export const cosConfig: KernelConfig = { + kernelName: Cos, + backendName: 'webgl', + kernelFunc: ({inputs, backend}) => { + const {x} = inputs as CosInputs; + const webglBackend = backend as MathBackendWebGL; + const program = new UnaryOpProgram(x.shape, COS); + return webglBackend.runWebGLProgram(program, [x], x.dtype); + } +}; diff --git a/tfjs-backend-webgl/src/register_all_kernels.ts b/tfjs-backend-webgl/src/register_all_kernels.ts index 9c26002d7d4..786284c97c3 100644 --- a/tfjs-backend-webgl/src/register_all_kernels.ts +++ b/tfjs-backend-webgl/src/register_all_kernels.ts @@ -16,6 +16,7 @@ */ import {KernelConfig, registerKernel} from '@tensorflow/tfjs-core'; +import {cosConfig} from './kernels/Cos'; import {divConfig} from './kernels/Div'; import {flipLeftRightConfig} from './kernels/FlipLeftRight'; import {fromPixelsConfig} from './kernels/FromPixels'; @@ -31,7 +32,7 @@ import {transposeConfig} from './kernels/Transpose'; // List all kernel configs here const kernelConfigs: KernelConfig[] = [ - maxConfig, flipLeftRightConfig, fromPixelsConfig, divConfig, + cosConfig, maxConfig, flipLeftRightConfig, fromPixelsConfig, divConfig, maxPoolWithArgmaxConfig, nonMaxSuppressionV3Config, nonMaxSuppressionV4Config, nonMaxSuppressionV5Config, rotateWithOffsetConfig, squareConfig, squaredDifferenceConfig, transposeConfig