In [109]:
%matplotlib inline

import autograd.numpy.random as npr
import autograd.numpy as np
from autograd import grad
from autograd.extend import primitive, defvjp
from autograd.misc.optimizers import adam
import scipy.special.lambertw as lambertw_
import autograd.scipy.stats.norm as norm
import matplotlib.pyplot as plt

npr.seed(100)

In [137]:
lambertw = primitive(lambda x: lambertw_(x, 0).real)
defvjp(lambertw, 
        lambda ans, x: lambda g:  g * 1./ (x + np.exp(ans)),
        None 
      )
def slambertw_logpdf_(y, loc, scale, skew, tol = 0.8):
    u = (y - loc)/scale
    if skew != 0:
        cutoff = - tol /(np.e * skew)
        cond = u >= cutoff if skew < 0 else u <= cutoff
        utmp = np.where(cond, cutoff, u)
        wc = lambertw(skew * utmp)
        gc = 1./(np.exp(wc) + skew*utmp)
        z = np.where(cond, wc/skew + (u - cutoff) * gc, wc/skew)
        return norm.logpdf(z) + np.log(gc) - np.log(scale)
    else:
        return norm.logpdf(u) - np.log(scale)

In [138]:
x = np.linspace(-3,3,20)

In [139]:
# generate sample data
N = 1000
D = 50
K = 5

scale = 1.
skew = -0.05
theta_o = theta = npr.randn(N,K) 
beta_o = beta = npr.randn(D,K)
loc = np.matmul(theta, beta.T)
u = npr.randn(N,D)
y = u * np.exp(skew * u) * scale + loc

In [140]:
def make_lambertw_routines(y, fixed_skew):
    
    def skew_function(skew, limit = 0.2, scale = 0.5):
#         return 0.2 * np.tanh(scale * skew)
        return skew

    def unpack_params(params):
        theta = np.reshape(params[:(N*K)], [N, K])
        beta = np.reshape(params[(N*K):-1], [D, K])
#         skew = skew_function(params[-1])
        skew = fixed_skew
        return theta, beta, skew
  
    slambertw_logpdf = lambda loc, scale, skew: slambertw_logpdf_(y, loc, scale, skew, 0.9)

    def objective(params, t):
        theta, beta, skew = unpack_params(params)
        loc = np.matmul(theta, beta.T)
        return -np.sum(slambertw_logpdf(loc, 1., skew))
    return objective, slambertw_logpdf, unpack_params, skew_function

def callback(params, i, g):
    if not i%100: print(i, objective(params, 0), skew_function(params[-1]), np.sum(params[:-1]), end ='\r')

In [None]:
theta_n = npr.randn(N,K) 
beta_n = npr.randn(D,K) 
init_params = np.concatenate([theta_n.flatten(),beta_n.flatten(),np.array([0.0001])])
objective, slambertw_logpdf, unpack_params, skew_function = make_lambertw_routines(y, -0.1)
gradient = grad(objective)
final_params = adam(gradient, init_params, step_size=0.1, num_iters=1000, callback = callback)
# theta_f, beta_f, skew = unpack_params(final_params) 
# Inference ends here. 
# print("\n")
# print("Estimated skew:", skew, "; Original skew:", -0.05)

900 68090.9866712 0.0001 84.08747222771

In [38]:
gradient(init_params, 0)

TypeError: only integer scalar arrays can be converted to a scalar index

In [44]:
np.std(npr.randn(100000)) * 0.3

0.30012155672349367

In [106]:
# def slambertw_logpdf_(y, loc, scale, skew, tol = 0.8):
#     u = (y - loc)/scale
#     if skew != 0:
#         cutoff = - tol * 1/(np.e * skew)
#         cond = u >= cutoff if skew < 0 else u <= cutoff
#         wc = lambertw(skew*u)
#         gc = 1./(np.exp(wc) + skew*u)
#         z = wc/skew
        
#         wc[cond] = wtmp = lambertw(skew * cutoff)
#         gc[cond] = 1./(np.exp(wtmp) + skew*cutoff)
#         z[cond] = wc[cond]/skew + (u[cond] - cutoff) * gc[cond]
        
#         return norm.logpdf(z) + np.log(gc) - np.log(scale)
#     else:
#         return norm.logpdf(u) - np.log(scale)

In [108]:
# def slambertw_(x, skew):
#     cutoff = - 0.9 * 1/(np.e * skew)
#     cond = x >= cutoff if skew < 0 else x <= cutoff
#     wc = lambertw(skew * cutoff)
#     gc = 1. / (np.exp(wc) + skew*cutoff)
#     z = lambertw(skew*x)/skew
#     z[cond] = wc/skew + (x[cond] - cutoff) * gc
#     return z

# def lambertw_logpdf_(y, loc, scale, skew):
#     u = (y - loc)/scale
#     if skew != 0:
#         wc = lambertw(skew*u)
#         gc = 1./(np.exp(wc) + skew*u)
#         z = wc/skew
# #         return norm.logpdf(z) + np.log(gc) - np.log(scale)
#         return gc
#     else:
#         return norm.logpdf(u) - np.log(scale)