# Mish Derivatives

In [1]:
import torch
from torch.nn import functional as F

In [2]:
# 100 sample points: a grid from -500 to 490 in steps of 10, each jittered
# with standard-normal noise, so the samples fall on both sides of the
# softplus threshold (20) used by the test Functions in this notebook.
inp = torch.randn(100) + (torch.arange(0, 1000, 10, dtype=torch.float)-500.)
inp

tensor([-500.3069, -490.6361, -480.3858, -471.2755, -459.0872, -451.1570,
        -440.2400, -429.6230, -419.9467, -408.3055, -402.3395, -389.1660,
        -380.1614, -369.7649, -359.7261, -348.8759, -338.7170, -329.2680,
        -319.7470, -309.6079, -301.3083, -290.9236, -279.3832, -267.8622,
        -259.3479, -249.7400, -240.8742, -229.2343, -219.3999, -210.0166,
        -199.8259, -191.5603, -178.9595, -171.4488, -160.3362, -150.1327,
        -139.2230, -130.8046, -121.8909, -108.4913, -100.5724,  -88.9087,
         -79.9365,  -70.3478,  -60.1005,  -49.9595,  -37.6322,  -29.9353,
         -18.9407,  -11.9213,   -2.5633,   10.6869,   18.9005,   29.4622,
          41.7188,   49.6080,   59.3583,   71.3071,   80.2604,   91.4908,
         100.2913,  108.7626,  118.8391,  129.8859,  139.5593,  150.6612,
         161.5152,  170.5409,  179.0472,  187.5896,  199.0938,  210.3955,
         221.2551,  229.1151,  240.5497,  250.7286,  260.3474,  268.6524,
         280.6704,  291.0199,  302.052

In [48]:
import sympy
from sympy import Symbol, Function, Expr, diff, simplify, exp, log, tanh
# Symbolic variable, plus an undefined function f(x) that stands in for an
# arbitrary inner function when deriving the chain/product-rule forms.
x = Symbol('x')
f =  Function('f')

## Overall Derivative

In [20]:
# Derivative of Mish, x * tanh(softplus(x)), with softplus written out as
# log(exp(x) + 1).
diff(x*tanh(log(exp(x)+1)))

x*(1 - tanh(log(exp(x) + 1))**2)*exp(x)/(exp(x) + 1) + tanh(log(exp(x) + 1))

In [22]:
# The same derivative, passed through sympy's simplify().
simplify(diff(x*tanh(log(exp(x)+1))))

-(x*(tanh(log(exp(x) + 1))**2 - 1)*exp(x) - (exp(x) + 1)*tanh(log(exp(x) + 1)))/(exp(x) + 1)

## Softplus

$ \Large \frac{\partial}{\partial x} Softplus(x) = 1 - \frac{1}{e^{x} + 1} $

Or, from PyTorch:

$ \Large \frac{\partial}{\partial x} Softplus(x) = 1 - e^{-Y} $

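Here $Y = Softplus(x)$ is the output saved during the forward pass; the two forms agree because $e^{-Y} = \frac{1}{e^{x} + 1}$.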

In [4]:
class SoftPlusTest(torch.autograd.Function):
    """Softplus, ``log(1 + exp(x))``, with a hand-written backward pass.

    The backward pass is computed from the saved forward output ``y`` using
    ``d/dx softplus(x) = 1 - exp(-y)``.
    """

    @staticmethod
    def forward(ctx, inp, threshold=20):
        """Compute softplus element-wise.

        Args:
            ctx: autograd context; the output is saved on it for backward.
            inp: input tensor.
            threshold: elements with ``inp >= threshold`` are returned
                unchanged instead of going through ``log1p(exp(inp))``.

        Returns:
            Tensor with the same shape as ``inp``.
        """
        # torch.where picks `inp` for large values, so an overflowing
        # exp(inp) in the other branch never reaches the result.
        y = torch.where(inp < threshold, torch.log1p(torch.exp(inp)), inp)
        ctx.save_for_backward(y)
        return y

    @staticmethod
    def backward(ctx, grad_out):
        """Return the gradients for ``(inp, threshold)``.

        Args:
            ctx: autograd context holding the saved forward output.
            grad_out: gradient of the loss with respect to the output.

        Returns:
            Tuple ``(grad_inp, None)``; ``threshold`` is not differentiable.
        """
        y, = ctx.saved_tensors
        # 1 - exp(-y), written with expm1 so the result does not cancel to
        # zero when y is tiny (very negative inputs, low-precision dtypes).
        res = -torch.expm1(-y)
        # One slot per forward argument. A trailing None is also accepted
        # when forward was called with `inp` only.
        return grad_out * res, None


In [5]:
# The custom forward pass should match PyTorch's built-in softplus.
torch.allclose(F.softplus(inp), SoftPlusTest.apply(inp))

True

In [6]:
# Check the hand-written softplus backward with gradcheck, using a float64
# copy of the samples.
torch.autograd.gradcheck(SoftPlusTest.apply, inp.to(torch.float64).requires_grad_())

True

## $tanh(Softplus(x))$

In [7]:
# Chain rule for tanh applied to an arbitrary inner function f(x).
diff(tanh(f(x)))

(1 - tanh(f(x))**2)*Derivative(f(x), x)

In [8]:
class TanhSPTest(torch.autograd.Function):
    """``tanh(softplus(x))`` with a hand-written backward pass.

    Only the input is saved; softplus and tanh are recomputed in backward.
    """

    @staticmethod
    def forward(ctx, inp, threshold=20):
        """Compute ``tanh(softplus(inp))`` element-wise.

        Args:
            ctx: autograd context; ``inp`` and ``threshold`` are stored on it.
            inp: input tensor.
            threshold: elements with ``inp >= threshold`` skip
                ``log1p(exp(inp))`` and use ``inp`` itself as the softplus
                value.

        Returns:
            Tensor with the same shape as ``inp``.
        """
        ctx.save_for_backward(inp)
        # Remember the cutoff so backward rebuilds exactly the same softplus.
        ctx.threshold = threshold
        sp = torch.where(inp < threshold, torch.log1p(torch.exp(inp)), inp)
        y = torch.tanh(sp)
        return y

    @staticmethod
    def backward(ctx, grad_out, threshold=None):
        """Return the gradients for ``(inp, threshold)``.

        Args:
            ctx: autograd context populated by ``forward``.
            grad_out: gradient of the loss with respect to the output.
            threshold: optional override of the softplus cutoff; when left
                as ``None`` the value recorded by ``forward`` is used.

        Returns:
            Tuple ``(grad_inp, None)``; ``threshold`` is not differentiable.
        """
        if threshold is None:
            threshold = ctx.threshold
        inp, = ctx.saved_tensors
        sp = torch.where(inp < threshold, torch.log1p(torch.exp(inp)), inp)
        # d softplus / dx = 1 - exp(-softplus); expm1 avoids cancelling to
        # zero when softplus is tiny.
        grad_sp = -torch.expm1(-sp)
        tanhsp = torch.tanh(sp)
        # Chain rule: (1 - tanh(sp)^2) * d sp / dx.
        grad = (1 - tanhsp*tanhsp) * grad_sp
        # One slot per forward argument; the trailing None is also accepted
        # when forward was called with `inp` only.
        return grad_out * grad, None


In [9]:
# The custom forward pass should match tanh(softplus(x)) built from PyTorch
# built-ins.
torch.allclose(TanhSPTest.apply(inp), torch.tanh(F.softplus(inp)))

True

In [10]:
# Check the hand-written tanh(softplus) backward with gradcheck, using a
# float64 copy of the samples.
torch.autograd.gradcheck(TanhSPTest.apply, inp.to(torch.float64).requires_grad_())

True

## Mish

In [11]:
# Product rule for x times an arbitrary function f(x).
diff(x * f(x))

x*Derivative(f(x), x) + f(x)

In [12]:
# Product and chain rule combined: x * tanh(f(x)) has the shape of Mish,
# with f standing in for softplus.
diff(x*tanh(f(x)))

x*(1 - tanh(f(x))**2)*Derivative(f(x), x) + tanh(f(x))

In [13]:
# The same derivative, passed through sympy's simplify().
simplify(diff(x*tanh(f(x))))

x*Derivative(f(x), x)/cosh(f(x))**2 + tanh(f(x))

In [14]:
# Derivative of the tanh(f(x)) factor on its own.
diff(tanh(f(x)))

(1 - tanh(f(x))**2)*Derivative(f(x), x)

In [15]:
class MishTest(torch.autograd.Function):
    """Mish, ``x * tanh(softplus(x))``, with a hand-written backward pass.

    Only the input is saved; softplus and tanh are recomputed in backward.
    """

    @staticmethod
    def forward(ctx, inp, threshold=20):
        """Compute Mish element-wise.

        Args:
            ctx: autograd context; ``inp`` and ``threshold`` are stored on it.
            inp: input tensor.
            threshold: elements with ``inp >= threshold`` skip
                ``log1p(exp(inp))`` and use ``inp`` itself as the softplus
                value.

        Returns:
            Tensor with the same shape as ``inp``.
        """
        ctx.save_for_backward(inp)
        # Remember the cutoff so backward rebuilds exactly the same softplus.
        ctx.threshold = threshold
        sp = torch.where(inp < threshold, torch.log1p(torch.exp(inp)), inp)
        tsp = torch.tanh(sp)
        y = inp.mul(tsp)
        return y

    @staticmethod
    def backward(ctx, grad_out, threshold=None):
        """Return the gradients for ``(inp, threshold)``.

        Args:
            ctx: autograd context populated by ``forward``.
            grad_out: gradient of the loss with respect to the output.
            threshold: optional override of the softplus cutoff; when left
                as ``None`` the value recorded by ``forward`` is used.

        Returns:
            Tuple ``(grad_inp, None)``; ``threshold`` is not differentiable.
        """
        if threshold is None:
            threshold = ctx.threshold
        inp, = ctx.saved_tensors
        sp = torch.where(inp < threshold, torch.log1p(torch.exp(inp)), inp)
        # d softplus / dx = 1 - exp(-softplus); expm1 avoids cancelling to
        # zero when softplus is tiny, which would otherwise drop the
        # `inp * grad_tsp` term for very negative inputs.
        grad_sp = -torch.expm1(-sp)
        tsp = torch.tanh(sp)
        # Chain rule for tanh(softplus(x)).
        grad_tsp = (1 - tsp*tsp) * grad_sp
        # Product rule for x * tanh(softplus(x)).
        grad = inp * grad_tsp + tsp
        # One slot per forward argument; the trailing None is also accepted
        # when forward was called with `inp` only.
        return grad_out * grad, None


In [16]:
# The custom forward pass should match Mish built from PyTorch built-ins.
torch.allclose(MishTest.apply(inp), inp.mul(torch.tanh(F.softplus(inp))))

True

In [17]:
# Check the hand-written Mish backward with gradcheck, using a float64 copy
# of the samples. This must target MishTest; checking TanhSPTest here would
# leave the Mish gradient unverified.
torch.autograd.gradcheck(MishTest.apply, inp.to(torch.float64).requires_grad_())

True