diff --git a/tfjs-node/src/kernels/Softmax.ts b/tfjs-node/src/kernels/Softmax.ts new file mode 100644 index 00000000000..af08a9f15d0 --- /dev/null +++ b/tfjs-node/src/kernels/Softmax.ts @@ -0,0 +1,37 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {NamedTensorInfoMap, registerKernel, TensorInfo} from '@tensorflow/tfjs-core'; + +import {createTypeOpAttr, NodeJSKernelBackend} from '../nodejs_kernel_backend'; + +interface SoftmaxInputs extends NamedTensorInfoMap { + logits: TensorInfo; +} + +registerKernel({ + kernelName: 'Softmax', + backendName: 'tensorflow', + kernelFunc: ({inputs, backend}) => { + const {logits} = inputs as SoftmaxInputs; + const opAttrs = [createTypeOpAttr('T', logits.dtype)]; + + const nodeBackend = backend as NodeJSKernelBackend; + + return nodeBackend.executeSingleOutput('Softmax', opAttrs, [logits]); + } +}); diff --git a/tfjs-node/src/kernels/all_kernels.ts b/tfjs-node/src/kernels/all_kernels.ts index c6eea842964..6b5f6bf8245 100644 --- a/tfjs-node/src/kernels/all_kernels.ts +++ b/tfjs-node/src/kernels/all_kernels.ts @@ -19,3 +19,4 @@ // global registry when we compile the library. A modular build would replace // the contents of this file and import only the kernels that are needed. import './non_max_suppression_v5'; +import './Softmax';