In [1]:
import numpy as np
import torch 
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Benchmark inputs: one float32 batch of shape (100, 3, 256, 256) per device.
# Both tensors are created the same way (neither requires grad), so the CPU and
# GPU timings cover the same forward-only work and the in-place ReLU timing does
# not chain a new autograd node onto x_cuda on every timed iteration.
# The GPU tensor is sampled directly on the device instead of being copied over.
x_cuda = torch.rand((100, 3, 256, 256), dtype=torch.float32, device='cuda')
x_cpu = torch.rand((100, 3, 256, 256), dtype=torch.float32)

# staticmethod: nn.Module whose forward calls a @staticmethod

In [3]:
class Mish(nn.Module):
    '''
    Element-wise Mish activation packaged as a module.

    mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))

    The computation lives in the static helper ``Mish.mish`` so it can be
    called without creating an instance; ``forward`` simply delegates to it.

    Shape:
        - Input: (N, *) where * means, any number of additional
          dimensions
        - Output: (N, *), same shape as the input
    Examples:
        >>> m = Mish()
        >>> input = torch.randn(2)
        >>> output = m(input)
    '''
    def __init__(self):
        super().__init__()

    @staticmethod
    def mish(x):
        '''
        Return x * tanh(softplus(x)), evaluated element-wise on ``x``.
        '''
        softplus_x = F.softplus(x)
        return x * torch.tanh(softplus_x)

    def forward(self, x):
        '''
        Apply Mish to ``x`` through the static helper.
        '''
        return Mish.mish(x)

In [4]:
# CPU timing: first the static helper called directly on the class,
# then the same computation through a module instance (forward -> Mish.mish).
%timeit Mish.mish(x_cpu)
mish = Mish()
%timeit mish(x_cpu)

86.5 ms ± 930 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
86.4 ms ± 907 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [5]:
%timeit Mish.mish(x_cuda)

mish = Mish()
%timeit mish(x_cuda)

1.01 ms ± 8.34 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
1.02 ms ± 15 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


# normal class: nn.Module whose forward calls a regular instance method

In [6]:
class Mish(nn.Module):
    '''
    Element-wise Mish activation implemented with an ordinary instance method.

    mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))

    Shape:
        - Input: (N, *) where * means, any number of additional
          dimensions
        - Output: (N, *), same shape as the input
    Examples:
        >>> m = Mish()
        >>> input = torch.randn(2)
        >>> output = m(input)
    '''
    def __init__(self):
        super().__init__()

    def mish(self, x):
        '''
        Return x * tanh(softplus(x)), evaluated element-wise on ``x``.
        '''
        gate = torch.tanh(F.softplus(x))
        return x * gate

    def forward(self, x):
        '''
        Apply Mish to ``x`` by calling the bound ``mish`` method.
        '''
        return self.mish(x)

In [7]:
# CPU timing of the instance-method variant (forward -> self.mish).
mish = Mish()
%timeit mish(x_cpu)

86.5 ms ± 604 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [8]:
mish = Mish()
%timeit mish(x_cuda)

1.01 ms ± 7.41 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


# function: plain function called directly (no nn.Module)

In [9]:
def mish(x):
    '''
    Element-wise Mish activation.

    mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
    '''
    softplus_x = F.softplus(x)
    gate = torch.tanh(softplus_x)
    return x * gate

In [10]:
# CPU timing of the plain function, called directly without a module wrapper.
%timeit mish(x_cpu)

86 ms ± 723 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [11]:
%timeit mish(x_cuda)

1 ms ± 694 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


# global function: nn.Module whose forward calls a module-level function

In [12]:
def mish(x):
    '''
    Module-level Mish activation, applied element-wise.

    mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
    '''
    tanh_softplus = F.softplus(x).tanh()
    return x * tanh_softplus


class Mish(nn.Module):
    '''
    Module wrapper around the module-level ``mish`` function.

    mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))

    Shape:
        - Input: (N, *) where * means, any number of additional
          dimensions
        - Output: (N, *), same shape as the input
    Examples:
        >>> m = Mish()
        >>> input = torch.randn(2)
        >>> output = m(input)
    '''
    def __init__(self):
        super().__init__()

    def forward(self, x):
        '''
        Apply Mish to ``x`` by calling the global ``mish`` function.
        '''
        activated = mish(x)
        return activated

In [13]:
# CPU timing of the module that forwards to the module-level mish().
# The instance is named mish_act so it does not shadow the mish function.
mish_act = Mish()
%timeit mish_act(x_cpu)

85.4 ms ± 1.3 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [14]:
mish_act = Mish()
%timeit mish_act(x_cuda)

1.01 ms ± 9.76 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


# memory-efficient variant v1: module-level torch.autograd.Function

In [15]:
class MishAutoFn(torch.autograd.Function):
    """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
    Experimental memory-efficient variant: only the input is saved for the
    backward pass, and softplus/tanh/sigmoid are recomputed from it there.
    """

    @staticmethod
    def forward(ctx, x):
        """Save ``x`` for backward and return x * tanh(ln(1 + exp(x)))."""
        ctx.save_for_backward(x)
        return x * F.softplus(x).tanh()

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``grad_output`` scaled by the derivative of Mish at the saved input."""
        (x,) = ctx.saved_tensors
        tanh_sp = torch.tanh(F.softplus(x))
        sig = torch.sigmoid(x)
        # d/dx [x * tanh(sp(x))] = tanh(sp) + x * sigmoid(x) * (1 - tanh(sp)^2)
        local_grad = tanh_sp + x * sig * (1 - tanh_sp * tanh_sp)
        return grad_output * local_grad
    

class Mish(nn.Module):
    """Module wrapper that applies Mish through ``MishAutoFn``.

    Args:
        inplace: Stored as ``self.inplace``; ``forward`` does not read it.
    """

    def __init__(self, inplace: bool = False):
        super().__init__()
        self.inplace = inplace

    def forward(self, x):
        """Run ``x`` through the custom autograd function."""
        result = MishAutoFn.apply(x)
        return result

In [16]:
# CPU timing of the module backed by the module-level MishAutoFn.
mish = Mish()
%timeit mish(x_cpu)

85.8 ms ± 643 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [17]:
mish = Mish()
%timeit mish(x_cuda)

1.01 ms ± 4.15 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


# memory-efficient variant v2: torch.autograd.Function nested inside the module

In [18]:
class Mish(nn.Module):
    """Mish activation module that carries its own autograd function.

    Args:
        inplace: Stored as ``self.inplace``; ``forward`` does not read it.
    """

    def __init__(self, inplace: bool = False):
        super().__init__()
        self.inplace = inplace

    class MishAutoFn(torch.autograd.Function):
        """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
        Experimental memory-efficient variant: only the input is saved for the
        backward pass, and softplus/tanh/sigmoid are recomputed from it there.
        """

        @staticmethod
        def forward(ctx, x):
            """Save ``x`` for backward and return x * tanh(ln(1 + exp(x)))."""
            ctx.save_for_backward(x)
            return x * F.softplus(x).tanh()

        @staticmethod
        def backward(ctx, grad_output):
            """Return ``grad_output`` scaled by the derivative of Mish at the saved input."""
            (x,) = ctx.saved_tensors
            tanh_sp = torch.tanh(F.softplus(x))
            sig = torch.sigmoid(x)
            # d/dx [x * tanh(sp(x))] = tanh(sp) + x * sigmoid(x) * (1 - tanh(sp)^2)
            local_grad = tanh_sp + x * sig * (1 - tanh_sp * tanh_sp)
            return grad_output * local_grad

    def forward(self, x):
        """Run ``x`` through the nested autograd function."""
        return self.MishAutoFn.apply(x)

In [19]:
# CPU timing of the module with the nested MishAutoFn.
mish = Mish()
%timeit mish(x_cpu)

87.2 ms ± 564 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [20]:
mish = Mish()
%timeit mish(x_cuda)

1 ms ± 869 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


# test whether a power-of-two batch size (128) runs faster than nearby sizes (120, 125)

In [23]:
# Three GPU batches that differ only in batch size: 120 and 125 are not powers
# of two, 128 is. Each is sampled on the host and then moved to the device.
x_cuda1, x_cuda2, x_cuda3 = (
    torch.rand((batch_size, 3, 256, 256), dtype=torch.float32).to('cuda')
    for batch_size in (120, 125, 128)
)

In [24]:
%timeit mish(x_cuda1)
%timeit mish(x_cuda2)
%timeit mish(x_cuda3)

1.2 ms ± 1.41 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
1.25 ms ± 5.74 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1.28 ms ± 3.73 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


# relu

In [25]:
# CPU ReLU baseline: out-of-place first, then in-place.
# inplace=True writes the result back into x_cpu itself. x_cpu was sampled with
# torch.rand (uniform on [0, 1)), so ReLU leaves its values unchanged.
%timeit F.relu(x_cpu)
%timeit F.relu(x_cpu, inplace=True)

14.5 ms ± 78 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
5.87 ms ± 70.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [26]:
%timeit F.relu(x_cuda)
%timeit F.relu(x_cuda, inplace=True)

290 µs ± 1.12 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
289 µs ± 4.43 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


# sigmoid

In [27]:
# CPU sigmoid baseline for comparison with the Mish timings.
%timeit torch.sigmoid(x_cpu)

29 ms ± 214 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [28]:
%timeit torch.sigmoid(x_cuda)

291 µs ± 1.45 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
