In [1]:
%matplotlib inline

import autograd.numpy.random as npr
import autograd.numpy as np
from autograd import grad
from autograd.extend import primitive, defvjp
from autograd.misc.optimizers import adam
import scipy.special.lambertw as lambertw_
import autograd.scipy.stats.norm as norm
import matplotlib.pyplot as plt

npr.seed(100)

In [18]:
@primitive
def lambertw(x):
    """Real part of the k=0 branch of the Lambert W function, as an autograd primitive.

    The value is `lambertw_(x, 0).real`. No domain check is performed: whatever
    imaginary component the underlying routine returns is dropped silently.
    A guard returning NaN for `x < -1/e` was sketched earlier but is disabled.

    NOTE(review): `lambertw_` is bound in the import cell via
    `import scipy.special.lambertw as lambertw_`, which needs that dotted path
    to be importable as a module. `from scipy.special import lambertw as lambertw_`
    is the conventional spelling -- confirm against the installed SciPy.
    """
    principal = lambertw_(x, 0)
    return principal.real


def _lambertw_vjp(ans, x):
    """Return the vector-Jacobian product of `lambertw` at input `x` with output `ans`.

    The derivative used is dW/dx = 1 / (x + exp(W(x))).
    """
    def vjp(g):
        return g * 1. / (x + np.exp(ans))
    return vjp


# `lambertw` takes a single argument; the trailing None is kept from the
# original registration and supplies no gradient for a second position.
defvjp(lambertw, _lambertw_vjp, None)

In [19]:
# Spot check with an argument below -1/e (the threshold used by the disabled
# guard inside `lambertw`): only the real part of the result is displayed.
lambertw(-4)

0.6788119713209454

In [4]:
# Generate the synthetic data set.
# Sizes: theta is N x K, beta is D x K, observations are N x D.
N, D, K = 1000, 100, 5

# Noise parameters used when building y.
scale, skew = 1., -0.05

# Standard-normal factor matrices; the *_o names alias the very same arrays.
theta = npr.randn(N, K)
theta_o = theta
beta = npr.randn(D, K)
beta_o = beta

# Location term: theta times beta transposed, shape (N, D).
loc = theta @ beta.T

# Standard-normal draws u are warped by u * exp(skew * u), scaled, then shifted by loc.
u = npr.randn(N, D)
y = u * np.exp(skew * u) * scale + loc

In [12]:
def make_lambertw_routines(y, fixed_skew):
    """Build the objective and helper closures used to fit the factor model to `y`.

    Parameters
    ----------
    y : 2-D array, shape (N, D)
        Observed data. N and D are taken from `y` itself rather than from
        notebook-level globals, so the closures work for any data matrix.
    fixed_skew : scalar
        Skew used in the likelihood. It is held constant: the trailing entry
        of the parameter vector is not read by `unpack_params`.

    Returns
    -------
    tuple
        `(objective, lambertw_logpdf, unpack_params, skew_function)`.

    Raises
    ------
    ValueError
        If `y` is not 2-D (raised when its shape is unpacked).
    """
    n_rows, n_cols = np.shape(y)

    def skew_function(skew, limit = 0.2, scale = 0.5):
        """Map the raw skew parameter to the skew used by the model.

        Currently the identity; `limit` and `scale` are accepted but unused.
        Disabled bounded alternative: 0.2 * np.tanh(scale * skew).
        """
        return skew

    def unpack_params(params):
        """Split the flat vector `params` into `(theta, beta, skew)`.

        Layout: N*K entries of theta, then D*K entries of beta, then one
        trailing slot that is ignored. K is inferred from the length of
        `params`; `skew` is always `fixed_skew`.

        Raises ValueError if the length is not of the form (N + D) * K + 1.
        """
        factor_rows = n_rows + n_cols
        n_free = len(params) - 1
        if factor_rows == 0 or n_free < 0 or n_free % factor_rows:
            raise ValueError(
                "params of length %d does not match (N + D) * K + 1 with N=%d, D=%d"
                % (len(params), n_rows, n_cols)
            )
        n_latent = n_free // factor_rows
        split = n_rows * n_latent
        theta = np.reshape(params[:split], [n_rows, n_latent])
        beta = np.reshape(params[split:-1], [n_cols, n_latent])
#         skew = skew_function(params[-1])
        skew = fixed_skew
        return theta, beta, skew

    def lambertw_logpdf(loc, log_scale, skew, t):
        """Elementwise log-density of `y` given location, log-scale and skew.

        With u = (y - loc) / scale:
        * skew == 0: standard normal log-density of u, minus log_scale.
        * otherwise: z = W(skew * u) / skew is scored under a standard normal
          and the log absolute Jacobian dz/du = 1 / (skew*u + exp(W)) is added.

        `t` (the optimiser iteration) is accepted but unused.
        """
        scale = np.exp(log_scale)
        u = (y - loc)/scale
        if skew != 0: #and t > 2000:
            u_ = u*skew
            W = lambertw(u_)
            z = W/skew
            jacobian = 1./(u_+np.exp(W))
            return norm.logpdf(z) + np.log(np.abs(jacobian)) - log_scale
        return norm.logpdf(u) - log_scale

    def objective(params, t):
        """Negative total log-likelihood of `y` at `params`, with the scale fixed to 1."""
        theta, beta, skew = unpack_params(params)
        loc = np.matmul(theta, beta.T)
        return -np.sum(lambertw_logpdf(loc, np.log(1.), skew, t))

    return objective, lambertw_logpdf, unpack_params, skew_function

def callback(params, i, g):
    """Progress hook passed to `adam`: report status on every 100th iteration.

    Prints, on one line ending in a carriage return (so successive reports
    overwrite each other): the iteration index, the objective at `params`
    (evaluated with t=0), `skew_function` applied to the last entry of
    `params`, and the sum of all other entries. `g` is unused.
    """
    if i % 100 != 0:
        return
    loss = objective(params, 0)
    skew_slot = skew_function(params[-1])
    rest_total = np.sum(params[:-1])
    print(i, loss, skew_slot, rest_total, end ='\r')

In [15]:
# Fresh standard-normal draws used as the optimiser's starting point.
theta_n = npr.randn(N,K) 
beta_n = npr.randn(D,K) 
# Flat parameter vector: theta entries, then beta entries, then one trailing
# slot (0.0001). unpack_params ignores that slot and returns the fixed skew.
init_params = np.concatenate([theta_n.flatten(),beta_n.flatten(),np.array([0.0001])])
# Build the closures over y with the skew held at 0.01.
objective, lambertw_logpdf, unpack_params, skew_function = make_lambertw_routines(y, 0.01)
gradient = grad(objective)
# Minimise the objective with Adam; `callback` prints progress every 100 steps.
final_params = adam(gradient, init_params, step_size=0.1, num_iters=1000, callback = callback)
# NOTE(review): this rebinds the notebook-level `skew` (set to -0.05 in the
# data-generation cell). The value unpacked here is the fixed 0.01 passed to
# make_lambertw_routines above, not a fitted estimate.
theta_f, beta_f, skew = unpack_params(final_params) 
# Inference ends here. 
print("\n")
# The "original" skew is hard-coded because `skew` was just overwritten above.
print("Estimated skew:", skew, "; Original skew:", -0.05)

700 140092.918587 0.0001 -88.7270686775

KeyboardInterrupt: 