Feat/add mish (#4950)
FEATURE
* add impl

* add tests

* Update activations_test.ts

Co-authored-by: Ping Yu <4018+pyu10055@users.noreply.github.com>
WenheLI and pyu10055 committed Apr 19, 2021
1 parent cfe4e1d commit ac9d67d
Showing 2 changed files with 52 additions and 1 deletion.
18 changes: 18 additions & 0 deletions tfjs-layers/src/activations.ts
@@ -227,6 +227,24 @@ export class Swish extends Activation {
}
serialization.registerClass(Swish);

/**
* Mish activation function: mish(x) = x * tanh(softplus(x)).
*
* Reference: https://arxiv.org/abs/1908.08681
*/
export class Mish extends Activation {
/** @nocollapse */
static readonly className = 'mish';
/**
* Calculate the activation function.
*
* @param x Tensor.
* @returns A Tensor of the same shape as x.
*/
apply(x: Tensor): Tensor {
return tidy(() => tfc.mul(x, tfc.tanh(tfc.softplus(x))));
}
}
serialization.registerClass(Mish);

export function serializeActivation(activation: Activation): string {
return activation.getClassName();
}
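
Below is a minimal standalone sketch (not part of the commit) that exercises the identity implemented above, mish(x) = x * tanh(softplus(x)), against the reference values used in the tests further down; it assumes @tensorflow/tfjs-core is available as tfc.

import * as tfc from '@tensorflow/tfjs-core';

// Same computation as Mish.apply, inlined for a quick numerical check.
function mish(x: tfc.Tensor): tfc.Tensor {
  return tfc.tidy(() => tfc.mul(x, tfc.tanh(tfc.softplus(x))));
}

const x = tfc.tensor1d([0, 1, 3, 9]);
mish(x).print();  // ≈ [0, 0.865, 2.987, 9]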
35 changes: 34 additions & 1 deletion tfjs-layers/src/activations_test.ts
@@ -13,7 +13,7 @@
*/
import {scalar, tensor1d, tensor2d, tensor3d} from '@tensorflow/tfjs-core';

-import {Elu, HardSigmoid, Linear, LogSoftmax, Relu, Relu6, Selu, Sigmoid, Softmax, Softplus, Softsign, Tanh, Swish} from './activations';
+import {Elu, HardSigmoid, Linear, LogSoftmax, Relu, Relu6, Selu, Sigmoid, Softmax, Softplus, Softsign, Tanh, Swish, Mish} from './activations';
import {describeMathCPUAndGPU, expectNoLeakedTensors, expectTensorsClose} from './utils/test_utils';

describeMathCPUAndGPU('linear activation', () => {
@@ -333,3 +333,36 @@ describeMathCPUAndGPU('swish activation', () => {
expectNoLeakedTensors(() => swish(initX), 1);
});
});

describeMathCPUAndGPU('mish activation', () => {
const mish = new Mish().apply;
// Setup: Array with initial values.
// Execute: Mish (applied elementwise).
// Expect: Output matches the input shape and approximate expected values.
it('1D', () => {
const initX = tensor1d([0, 1, 3, 9]);
const expectedVals = tensor1d([0., .865, 2.987, 9.]);
expectTensorsClose(mish(initX), expectedVals);
});
it('1D all equal', () => {
const initX = tensor1d([-1, -1, -1, -1]);
const expectedVals = tensor1d([-0.303, -0.303, -0.303, -0.303]);
expectTensorsClose(mish(initX), expectedVals);
});
it('2D', () => {
const initX = tensor2d([[0, 1, 3, 9], [0, 1, 3, 9]]);
const expectedVals = tensor2d(
[[0., .865, 2.987, 9.], [0., .865, 2.987, 9.]]);
expectTensorsClose(mish(initX), expectedVals);
});
it('3D', () => {
const initX = tensor3d([[[0, 1, 3, 9], [0, 1, 3, 9]]]);
const expectedVals = tensor3d(
[[[0., .865, 2.987, 9.], [0., .865, 2.987, 9.]]]);
expectTensorsClose(mish(initX), expectedVals);
});
it('Does not leak', () => {
const initX = tensor1d([0, 1, 3, 9]);
expectNoLeakedTensors(() => mish(initX), 1);
});
});
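
Because serialization.registerClass(Mish) registers the class under the className 'mish', the activation should also be reachable by its string name through the layers activation registry. A hedged usage sketch follows; the 'mish' string identifier is an assumption here, since this commit does not touch the keras_format identifier types (hence the cast):

import * as tf from '@tensorflow/tfjs';

// Hypothetical: request the new activation by name in a layer config.
const model = tf.sequential();
model.add(tf.layers.dense({
  units: 8,
  inputShape: [4],
  activation: 'mish' as any,  // assumption: string resolves via the registry
}));
model.predict(tf.ones([1, 4]));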
