在这我们将考虑stan的一个简单的例子，展示它的点估计/优化，变分推断，HMC采样如何用pytorch简单的实现。

In [1]:
import pystan
import numpy as np
import torch
from torch.autograd import Variable

# MAP点估计

In [2]:
ocode = """
data {
    int<lower=1> N;
    real y[N];
}
parameters {
    real mu;
}
model {
    y ~ normal(mu, 1);
}
"""


In [3]:
%time sm = pystan.StanModel(model_code=ocode)


INFO:pystan:COMPILING THE C++ CODE FOR MODEL anon_model_3aaa1aff3be33470f8a5bfa56085d51c NOW.


Wall time: 1min 13s


In [4]:
np.random.seed(13)

In [5]:
y2 = np.random.normal(size=20)
np.mean(y2)


0.36640264498852165

In [6]:
%time op = sm.optimizing(data=dict(y=y2, N=len(y2)))


Wall time: 45.6 ms


  elif np.issubdtype(np.asarray(v).dtype, float):


In [7]:
op


OrderedDict([('mu', array(0.36640264))])

这里我们手动写出正态分布的似然函数，虽然pytorch也提供一些，但并不是很多，这也许是用它来实现相关贝叶斯统计框架所将遇到的主要困难。

$$
\mathrm{normal}(Y \mid \mu,1) = \prod_i^n C \exp\left(-\frac{(x_i-\mu)^2}{2}\right)  \\
\log \mathrm{normal} (Y \mid \mu,1) = C_2 - \sum_i^n  \frac{(x_i - \mu)^2}{2}
$$

In [8]:
def normal_lp(x,mu,sigma):
    #虽然在stan里这个函数是包含常数的，但这里把它弃掉
    return -((x - mu)**2/(2*sigma**2)).sum()

In [9]:
target = Variable(torch.zeros(1))

In [10]:
y = Variable(torch.from_numpy(y2).float())

In [11]:
mu = Variable(torch.zeros(1),requires_grad=True)
target += normal_lp(y,mu,1)

loss = -target

In [12]:
loss.backward()

In [13]:
mu

Variable containing:
 0
[torch.FloatTensor of size 1]

In [14]:
mu.grad

Variable containing:
-7.3281
[torch.FloatTensor of size 1]

In [15]:
def model(y):
    target = Variable(torch.zeros(1))
    target += normal_lp(y,mu,1)
    return target

optimizer = torch.optim.SGD([mu],lr=0.01)
for epoch in range(100):
    optimizer.zero_grad()
    y = Variable(torch.from_numpy(y2).float())
    target = model(y)
    loss = -target
    loss.backward()
    optimizer.step()

In [16]:
loss.grad

In [17]:
mu

Variable containing:
 0.3664
[torch.FloatTensor of size 1]

# 自动微分变分推断

## 平均场

In [89]:
fit = sm.vb(data = dict(y=y2, N = len(y2)))

  elif np.issubdtype(np.asarray(v).dtype, float):


In [90]:
fit.keys()

odict_keys(['args', 'inits', 'sampler_params', 'sampler_param_names', 'mean_pars'])

In [91]:
fit['mean_pars']

[0.3106542643767246]

In [92]:
fit['sampler_param_names']

['mu']

In [94]:
len(fit['sampler_params'][0])

1001

In [93]:
np.mean(fit['sampler_params'][0]),np.std(fit['sampler_params'][0]) #当然由于这个标准差不是直接从变分参数里拿出来而是按stan的设计莫名其妙又“估计”一次所以并不精确，在多维情况下会更明显

(0.30312868383719116, 0.2286080403322702)

In [96]:
ppost_stan = []
for i in range(1000):
    fit = sm.vb(data = dict(y=y2, N = len(y2)))
    ppost_stan.append([np.mean(fit['sampler_params'][0]),np.std(fit['sampler_params'][0])])

  elif np.issubdtype(np.asarray(v).dtype, float):
























































“随机”梯度下降依据：

期望变分参数之梯度
$$
\nabla_\mu \mathscr{L} = \mathbf{E}_{\mathscr{N}(\mathbf{\eta})} (\nabla_\theta \log p(\mathbf{x},\mathbf{\theta}) \nabla_\xi T^{-1}(\xi)+
\nabla_\xi \log \mid \det J_{T^{-1}}(\xi) \mid)
$$

$\omega$变分参数（$\omega = \log \sigma$,$\omega$为平均场设定下设定的向量，满秩设定采用chol分解因子$L$一般）参数之梯度
$$
\nabla_\omega \mathscr{L} = \mathbf{E}_{\mathscr{N}(\mathbf{\eta})} (\nabla_\theta \log p(\mathbf{x},\mathbf{\theta}) \nabla_\xi T^{-1}(\xi)+
\nabla_\xi \log \mid \det J_{T^{-1}}(\xi) \mid)\eta^T \mathrm{diag}(\exp(\omega)) + \mathbf{1} \\
\nabla_\mathbf{L} \mathscr{L} = \mathbf{E}_{\mathscr{N}(\mathbf{\eta})} (\nabla_\theta \log p(\mathbf{x},\mathbf{\theta}) \nabla_\xi T^{-1}(\xi)+
\nabla_\xi \log \mid \det J_{T^{-1}}(\xi) \mid)\eta^T ) + \mathbf{L}^{-1}
$$

In [23]:
#mu_q_mu = Variable(torch.zeros(1),requires_grad=True)
#mu_q_sigma = Variable(torch.ones(1),requires_grad=True)
mu_q_mu = 0.0
mu_q_omega = 1.0 # omega = log(sigma)

In [25]:
q_size = 10
#mu_q_samples = np.random.normal(loc=mu_q_mu.data.numpy(),scale=mu_q_sigma.data.numpy(),size=q_size)
mu_q_samples_eta = np.random.normal(size=q_size) # theta(constrained) -> xi(unconstrained) ~ normal -> eta(standarded) ~ Normal(0,1)
mu_q_samples = mu_q_samples_eta*np.exp(mu_q_omega) + mu_q_mu

In [33]:
mu_q_mu_grad = 0.0
mu_q_omega_grad = 0.0

for mu_q_samples_eta,mu_q_sample in zip(mu_q_samples_eta,mu_q_samples):
    mu_q = Variable(torch.ones(1)*mu_q_sample, requires_grad=True)
    target = normal_lp(y,mu_q,1)
    loss = -target
    loss.backward()
    mu_q_mu_grad += mu_q.grad.data.numpy()
    
    mu_q_omega_grad += mu_q.grad.data.numpy() * mu_q_samples_eta *np.exp(mu_q_omega)
    
mu_q_mu_grad /= q_size
mu_q_omega_grad /= q_size
mu_q_omega_grad += 1


In [34]:
mu_q_mu_grad

array([-27.935236], dtype=float32)

In [35]:
mu_q_omega_grad

array([226.88248], dtype=float32)

In [136]:
mu_q_samples_eta

-1.6151079632521659

In [184]:
def grad_q(mu_q_mu, mu_q_omega, q_size = 10):
    
    mu_q_samples_eta = np.random.normal(size=q_size) # theta(constrained) -> xi(unconstrained) ~ normal -> eta(standarded) ~ Normal(0,1)
    mu_q_samples = mu_q_samples_eta*np.exp(mu_q_omega) + mu_q_mu
    
    mu_q_mu_grad = 0.0
    mu_q_omega_grad = 0.0

    for mu_q_samples_eta,mu_q_sample in zip(mu_q_samples_eta,mu_q_samples):
        mu_q = Variable(torch.ones(1)*mu_q_sample, requires_grad=True)
        target = normal_lp(y,mu_q,1)
        target.backward()
        #loss = -target
        #loss.backward()
        mu_q_mu_grad += mu_q.grad.data.numpy()
        
        #print(target.data[0],mu_q.grad.data.numpy(), mu_q_samples_eta, np.exp(mu_q_omega))
        mu_q_omega_grad += mu_q.grad.data.numpy() * mu_q_samples_eta #*np.exp(mu_q_omega)

    mu_q_mu_grad /= q_size
    mu_q_omega_grad /= q_size
    mu_q_omega_grad *= np.exp(mu_q_omega)
    mu_q_omega_grad += 1.0
    return mu_q_mu_grad,mu_q_omega_grad

In [180]:
grad_q(mu_q_mu, mu_q_omega)

-7.533953666687012 [-4.2036176e-05] 0.2226337797113641 [9.871969e-06]
-7.533953666687012 [-4.298985e-05] 0.2302885023425095 [9.871969e-06]
-7.533953666687012 [0.00010817] -0.536077660940225 [9.871969e-06]
-7.533953666687012 [-0.00010343] 0.5335959700874251 [9.871969e-06]
-7.533953666687012 [0.00022499] -1.1281482983177094 [9.871969e-06]
-7.53395414352417 [-6.170571e-05] 0.32224425188157674 [9.871969e-06]
-7.533953666687012 [0.00019829] -0.9937985887459845 [9.871969e-06]
-7.533953666687012 [-8.657575e-06] 0.053218432497989285 [9.871969e-06]
-7.533953666687012 [-0.00021227] 1.0875599515138579 [9.871969e-06]
-7.53395414352417 [0.0002145] -1.0757209786479593 [9.871969e-06]


(array([2.7486682e-05], dtype=float32), array([1.], dtype=float32))

In [181]:
grad_q(0.36,0)

-7.534194469451904 [0.09814732] 0.0014952808221046854 1.0
-7.536289215087891 [0.3056382] -0.008879246810882868 1.0
-39.210514068603516 [35.59582] -1.7733883470758618 1.0
-13.829805374145508 [15.869281] -0.7870613821166278 1.0
-13.446975708007812 [-15.379235] 0.7753644691900808 1.0
-12.326577186584473 [-13.845756] 0.6986904397637416 1.0
-30.111600875854492 [-30.05172] 1.508988574791585 1.0
-17.526914596557617 [19.992958] -0.9932453141417751 1.0
-8.411796569824219 [5.9256825] -0.2898814874361849 1.0
-11.739463806152344 [-12.96998] 0.65490166189617 1.0


(array([0.5540838], dtype=float32), array([-16.26337], dtype=float32))

In [182]:
grad_q(0.36,-1)

-8.404897689819336 [5.902353] -0.7848087269967956 0.36787944117144233
-7.925912857055664 [3.9595919] -0.5207601299133467 0.36787944117144233
-7.746644973754883 [-2.9167876] 0.413836683781184 0.36787944117144233
-7.858697891235352 [-3.6041305] 0.5072562783682488 0.36787944117144233
-9.109000205993652 [-7.9373713] 1.0962048420341968 0.36787944117144233
-7.5340070724487305 [-0.0462857] 0.023695054489860024 0.36787944117144233
-12.00731372833252 [-13.3766365] 1.8354775102756988 0.36787944117144233
-7.543758392333984 [0.62625796] -0.06771309829082954 0.36787944117144233
-10.791436195373535 [11.414872] -1.5340378263753989 0.36787944117144233
-7.563039779663086 [-1.0786263] 0.1640046835282517 0.36787944117144233


(array([-0.7056763], dtype=float32), array([-1.2335527], dtype=float32))

In [183]:
grad_q(0.36,-2.5)

-7.543975353240967 [-0.6331364] 0.4636590738460825 0.0820849986238988
-7.55221700668335 [0.85470706] -0.44262310102131475 0.0820849986238988
-7.550801753997803 [0.82092613] -0.4220460816665051 0.0820849986238988
-7.5506486892700195 [-0.81719077] 0.5757711347821525 0.0820849986238988
-7.570662498474121 [1.2117562] -0.660110470888259 0.0820849986238988
-7.546915054321289 [0.7200317] -0.36058896964296 0.0820849986238988
-7.533953666687012 [0.00268261] 0.07636605866997537 0.0820849986238988
-7.546417236328125 [-0.70607436] 0.5080875985790488 0.0820849986238988
-7.561623573303223 [-1.052043] 0.7188257289694143 0.0820849986238988
-7.630311489105225 [-1.9632394] 1.2738578308957338 0.0820849986238988


(array([-0.156158], dtype=float32), array([0.9494024], dtype=float32))

In [158]:
grad_q(0.36,-5)

[-0.13578029] -0.05734153856805679 0.006737946999085467
[-0.26619285] -1.0250902609660595 0.006737946999085467
[-0.01381795] 0.8477007740016017 0.006737946999085467
[-0.22332112] -0.7069538564709593 0.006737946999085467
[-0.20960797] -0.6051907547991681 0.006737946999085467
[-0.06845085] 0.4422873962230695 0.006737946999085467
[-0.02473383] 0.7666948836202403 0.006737946999085467
[-0.1794671] -0.38152417060047134 0.006737946999085467
[0.08567245] 1.5859805823816426 0.006737946999085467
[-0.25378454] -0.9330128622048296 0.006737946999085467


(array([-0.1289484], dtype=float32), array([0.00063712], dtype=float32))

In [159]:
grad_q(0.36,-10)

[-0.12714414] 1.0005218139705558 4.5399929762484854e-05
[-0.12824206] -0.20869664062225887 4.5399929762484854e-05
[-0.12856404] -0.5630704022049372 4.5399929762484854e-05
[-0.12793343] 0.13149534999830156 4.5399929762484854e-05
[-0.12819092] -0.15197463658548493 4.5399929762484854e-05
[-0.12752335] 0.5833494949150893 4.5399929762484854e-05
[-0.1291594] -1.2183812923499213 4.5399929762484854e-05
[-0.12689342] 1.2773605145341744 4.5399929762484854e-05
[-0.12801398] 0.04287152818783283 4.5399929762484854e-05
[-0.12809198] -0.04327223739896249 4.5399929762484854e-05


(array([-0.12797567], dtype=float32), array([-4.7423106e-07], dtype=float32))

In [145]:
grad_q(0.36,-100)

[-0.12805264] -0.3739058200818378 3.720075976020836e-44
[-0.12805264] -1.1252546862048873 3.720075976020836e-44
[-0.12805264] 1.1743753581122953 3.720075976020836e-44
[-0.12805264] 0.2168079444574901 3.720075976020836e-44
[-0.12805264] -1.189568289129904 3.720075976020836e-44
[-0.12805264] 0.4786041197869711 3.720075976020836e-44
[-0.12805264] -0.2648865850212951 3.720075976020836e-44
[-0.12805264] -1.182598209022306 3.720075976020836e-44
[-0.12805264] -1.2357593785415266 3.720075976020836e-44
[-0.12805264] 1.9043439303917344 3.720075976020836e-44


(array([-0.12805262], dtype=float32), array([1.], dtype=float32))

In [55]:
lr = 0.001

In [80]:
for epoch in range(100):
    mu_q_mu_grad, mu_q_omega_grad = grad_q(mu_q_mu, mu_q_omega)
    mu_q_mu -= mu_q_mu_grad * lr
    mu_q_omega -= mu_q_omega_grad * lr

In [81]:
mu_q_mu

array([0.36268085], dtype=float32)

In [82]:
mu_q_omega

array([-2.0681887], dtype=float32)

In [83]:
np.exp(mu_q_omega)

array([0.12641455], dtype=float32)

In [188]:
mu_q_mu = 0.0
mu_q_omega = 0.0

In [189]:
ppost_torch = []
for i in range(100):
    for epoch in range(100):
        mu_q_mu_grad, mu_q_omega_grad = grad_q(mu_q_mu, mu_q_omega)
        mu_q_mu += mu_q_mu_grad * lr
        mu_q_omega += mu_q_omega_grad * lr
    ppost_torch.append([mu_q_mu,np.exp(mu_q_omega)])

In [190]:
mu_q_mu

array([0.37196043], dtype=float32)

In [191]:
np.mean(np.array(ppost_stan),axis=0)

array([-1.68484273e+76,  1.43333369e+78])

In [163]:
_ppost_stan = np.clip(ppost_stan,0.0,1.0)

In [164]:
np.mean(np.array(_ppost_stan),axis=0) # stan这种不靠谱的东西当然日常出一些溢出异常值

array([0.37015113, 0.22651429])

In [192]:
np.mean(np.array(ppost_torch),axis=0)

array([[0.37196007],
       [0.23125193]], dtype=float32)

# 采样

In [165]:
fit = sm.sampling(data = dict(y=y2, N = len(y2)))

  elif np.issubdtype(np.asarray(v).dtype, float):


In [166]:
print(fit)

Inference for Stan model: anon_model_3aaa1aff3be33470f8a5bfa56085d51c.
4 chains, each with iter=2000; warmup=1000; thin=1; 
post-warmup draws per chain=1000, total post-warmup draws=4000.

       mean se_mean     sd   2.5%    25%    50%    75%  97.5%  n_eff   Rhat
mu     0.37  5.7e-3   0.22  -0.05   0.21   0.37   0.52   0.81   1494    1.0
lp__  -8.03    0.02   0.67  -9.95   -8.2  -7.76  -7.59  -7.53   1664    1.0

Samples were drawn using NUTS at Mon Apr  9 09:01:32 2018.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at 
convergence, Rhat=1).


虽然stan默认用的是NUTS采样器，但我们实现较为简单的哈密尔顿蒙特卡洛采样法（HMC）

<img src="HMC.png">

In [211]:
def hmc_grad(mu):
    #mu_q = Variable(torch.ones(1)*mu, requires_grad=True)
    mu_q = Variable(torch.from_numpy(mu).float(), requires_grad=True)
    target = normal_lp(y,mu_q,1)
    target.backward()
    return mu_q.data.numpy()

def hmc_L(mu):
    mu_q = Variable(torch.from_numpy(mu).float(), requires_grad=True)
    target = normal_lp(y,mu_q,1)
    return target.data.numpy()

def hmc(mu_q,epsilon=0.01,L=20,M=100):
    mu_q = np.array(mu_q)
    sample = [mu_q]
    for m in range(1,M):
        r = np.random.normal()
        #sample[m] = sample[m-1]
        sample.append(sample[m-1])
        mu_q_ = sample[m-1]
        r_ = r
        for i in range(L):
            r_ = r_ + epsilon/2.0 * hmc_grad(mu_q_)
            mu_q_ = mu_q_ + epsilon * r_
            r_ = r_ + epsilon/2.0 * hmc_grad(mu_q_)
        p = np.exp(hmc_L(mu_q_) -0.5*r_*r_)/ np.exp(hmc_L(sample[m-1])-0.5*r*r)
        alpha = np.min([1,p])
        if np.random.random() < alpha:
            sample[m] = mu_q_
    return sample

In [212]:
fit2 = hmc([0.36],M=1000)

In [213]:
np.mean(fit2),np.std(fit2)

(0.34214126220486607, 0.22397032206471484)