diff --git a/tfjs-layers/src/activations.ts b/tfjs-layers/src/activations.ts
index 3e9a0d38a28..97c2c02f668 100644
--- a/tfjs-layers/src/activations.ts
+++ b/tfjs-layers/src/activations.ts
@@ -208,6 +208,25 @@ export class LogSoftmax extends Activation {
 }
 serialization.registerClass(LogSoftmax);
 
+/**
+ * Swish activation function
+ */
+export class Swish extends Activation {
+  /** @nocollapse */
+  static readonly className = 'swish';
+  /**
+   * Calculate the activation function.
+   *
+   * @param x Tensor.
+   * @param alpha Scaling factor for the sigmoid function.
+   * @returns a Tensor of the same shape as x
+   */
+  apply(x: Tensor, alpha = 1): Tensor {
+    return tidy(() => tfc.sigmoid(x.mul(alpha)).mul(x));
+  }
+}
+serialization.registerClass(Swish);
+
 export function serializeActivation(activation: Activation): string {
   return activation.getClassName();
 }
diff --git a/tfjs-layers/src/activations_test.ts b/tfjs-layers/src/activations_test.ts
index 0b9b4271b95..69dd91c27eb 100644
--- a/tfjs-layers/src/activations_test.ts
+++ b/tfjs-layers/src/activations_test.ts
@@ -13,7 +13,7 @@
  */
 
 import {scalar, tensor1d, tensor2d, tensor3d} from '@tensorflow/tfjs-core';
-import {Elu, HardSigmoid, Linear, LogSoftmax, Relu, Relu6, Selu, Sigmoid, Softmax, Softplus, Softsign, Tanh} from './activations';
+import {Elu, HardSigmoid, Linear, LogSoftmax, Relu, Relu6, Selu, Sigmoid, Softmax, Softplus, Softsign, Swish, Tanh} from './activations';
 import {describeMathCPUAndGPU, expectNoLeakedTensors, expectTensorsClose} from './utils/test_utils';
 
 describeMathCPUAndGPU('linear activation', () => {
@@ -300,3 +300,36 @@ describeMathCPUAndGPU('logsoftmax activation', () => {
     expectNoLeakedTensors(() => logsoftmax(initX), 1);
   });
 });
+
+describeMathCPUAndGPU('swish activation', () => {
+  const swish = new Swish().apply;
+  // Setup: Array with initial values.
+  // Execute: Swish on the last dimension.
+  // Expect: Output array matches size and approximate expected values.
+  it('1D', () => {
+    const initX = tensor1d([0, 1, 3, 9]);
+    const expectedVals = tensor1d([0, .731, 2.857, 8.998]);
+    expectTensorsClose(swish(initX), expectedVals);
+  });
+  it('1D all equal', () => {
+    const initX = tensor1d([-1, -1, -1, -1]);
+    const expectedVals = tensor1d([-.268, -.268, -.268, -.268]);
+    expectTensorsClose(swish(initX), expectedVals);
+  });
+  it('2D', () => {
+    const initX = tensor2d([[0, 1, 3, 9], [0, 1, 3, 9]]);
+    const expectedVals = tensor2d(
+        [[0, .731, 2.857, 8.998], [0, .731, 2.857, 8.998]]);
+    expectTensorsClose(swish(initX), expectedVals);
+  });
+  it('3D', () => {
+    const initX = tensor3d([[[0, 1, 3, 9], [0, 1, 3, 9]]]);
+    const expectedVals = tensor3d(
+        [[[0, .731, 2.857, 8.998], [0, .731, 2.857, 8.998]]]);
+    expectTensorsClose(swish(initX), expectedVals);
+  });
+  it('Does not leak', () => {
+    const initX = tensor1d([0, 1, 3, 9]);
+    expectNoLeakedTensors(() => swish(initX), 1);
+  });
+});