In [19]:
import numpy as np
from scipy.optimize import minimize
from scipy.stats import norm
import tqdm 

def fi(mu_i, i, mu_minus_i, sigma, gamma, X, Y, lamb):
    f = mu_i * (X.T @ X)[i, np.arange(X.shape[1])!=i] @ (gamma[np.arange(len(gamma))!=i] * mu_minus_i)
    f += 0.5*(X.T @ X)[i, i] * mu_i**2
    f -= (Y.T @ X)[i] * mu_i
    f += lamb * sigma[i] * np.sqrt(2/np.pi) * np.exp(-mu_i**2/(2*sigma[i]**2))
    f += lamb * mu_i * (1 - 2*norm.cdf(-mu_i/sigma[i]))

    return f


def gi(sigma_i, i, mu, X, lamb):

    g = 0.5 * (X.T @ X)[i, i] * sigma_i**2
    g += lamb * mu[i] * sigma_i * np.sqrt(2/np.pi) * np.exp(-mu[i]**2/(2*sigma_i**2))
    g += lamb * mu[i] * (1 - norm.cdf(mu[i]/sigma_i))
    g += - np.log(sigma_i)
    return g

def Gamma_function(i, mu, sigma, gamma, X, Y, a0, b0, lamb):
    Gamma = np.log(a0/b0)
    Gamma += np.log(np.sqrt(np.pi/2)*sigma[i]*lamb)
    Gamma -= mu[i] * (X.T @ X)[i, np.arange(X.shape[1])!=i] @ (gamma[np.arange(len(gamma))!=i] * mu[np.arange(len(mu))!=i])
    Gamma -= 0.5*(X.T @ X)[i, i] * (mu[i]**2 + sigma[i]**2)
    Gamma += (Y.T @ X)[i] * mu[i]
    Gamma -= lamb * sigma[i] * np.sqrt(2/np.pi) * np.exp(-mu[i]**2/(2*sigma[i]**2))
    Gamma -= lamb * mu[i] * (1 - 2*norm.cdf(-mu[i]/sigma[i]))

    return Gamma + 0.5


def H(p, ent):
    for j in range(len(p)):
        if p[j] < 1 - 1e-10 and p[j] > 1e-10:
            ent[j] = -p[j]*np.log(p[j]) - (1-p[j])*np.log(1-p[j])
    return ent


def mu_0(X, Y):
    _, p = X.shape
    return np.linalg.inv(X.T @ X + np.eye(p)) @ X.T @ Y


def inv_logit(p):
    if p > 0:
        return 1. / (1. + np.exp(-p))
    elif p <= 0:
        return np.exp(p) / (1 + np.exp(p))
    else:
        print("AWPER")
        raise ValueError


def variational_bayes(X, Y, sigma, gamma, mu, a0, b0, lamb, eps=1e-5, max_it=1):
    deltaH = 10
    p = len(mu)
    it = 0
    a = np.argsort(np.abs(mu))
    pbar = tqdm.tqdm(total=max_it)
    ent = H(gamma, np.random.random(p))
    while it < max_it and deltaH >= eps:
        pbar.update(1)
        for j in range(p):
            i = a[j]
            ## update mu_i
            mu_minus_i = mu[np.arange(len(mu))!=i]
            res = minimize(fi, mu[i], args=(i, mu_minus_i, sigma, gamma, X, Y, lamb))
            mu[i] = res.x

            ## update gamma_i
            cons = ({'type': 'ineq', 'fun': lambda x:  x})
            res = minimize(gi, sigma[i], args=(i, mu, X, lamb), constraints=cons)
            sigma[i] = res.x

            ## update gamma
            Gamma = Gamma_function(i, mu, sigma, gamma, X, Y, a0, b0, lamb)
            gamma[i] = inv_logit(Gamma)

        it += 1
        ent_old = np.copy(ent)
        ent = H(gamma, ent)
        deltaH = np.max(np.abs(ent) - np.abs(ent_old))
        print(deltaH)
    pbar.close( )
    return mu, sigma, gamma








In [20]:
import warnings

def fxn():
    warnings.warn("deprecated", DeprecationWarning)

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    fxn()


In [21]:
n, p, s = 100, 200, 20

theta = 10*np.ones(p)
theta[:s] = 0
X = np.random.normal(0, 1, size=(n, p))
Y = X @ theta
mu = mu_0(X, Y)
sigma = 10*np.random.random(p)
gamma = np.random.random(p)
mu, sigma, gamma = variational_bayes(X, Y, sigma, gamma, mu, a0=1, b0=p, lamb=1, eps=1e-5, max_it=1000)

  g += - np.log(sigma_i)
  g += - np.log(sigma_i)
  g += - np.log(sigma_i)
  g += lamb * mu[i] * sigma_i * np.sqrt(2/np.pi) * np.exp(-mu[i]**2/(2*sigma_i**2))
  g += lamb * mu[i] * (1 - norm.cdf(mu[i]/sigma_i))
  g += - np.log(sigma_i)
  g += - np.log(sigma_i)
  g += lamb * mu[i] * sigma_i * np.sqrt(2/np.pi) * np.exp(-mu[i]**2/(2*sigma_i**2))
  g += lamb * mu[i] * (1 - norm.cdf(mu[i]/sigma_i))
  g += - np.log(sigma_i)
  g += - np.log(sigma_i)
  g += lamb * mu[i] * sigma_i * np.sqrt(2/np.pi) * np.exp(-mu[i]**2/(2*sigma_i**2))
  g += lamb * mu[i] * (1 - norm.cdf(mu[i]/sigma_i))
  g += - np.log(sigma_i)
  g += - np.log(sigma_i)
  g += - np.log(sigma_i)
  g += lamb * mu[i] * sigma_i * np.sqrt(2/np.pi) * np.exp(-mu[i]**2/(2*sigma_i**2))
  g += lamb * mu[i] * (1 - norm.cdf(mu[i]/sigma_i))
  g += - np.log(sigma_i)
  g += - np.log(sigma_i)
  g += lamb * mu[i] * sigma_i * np.sqrt(2/np.pi) * np.exp(-mu[i]**2/(2*sigma_i**2))
  g += lamb * mu[i] * (1 - norm.cdf(mu[i]/sigma_i))
  g += - np.log(sigm

0.5108917659449239


  0%|          | 3/1000 [00:12<1:07:02,  4.03s/it]

0.05474043554878347


  g += - np.log(sigma_i)
  g += - np.log(sigma_i)
  0%|          | 4/1000 [00:17<1:14:00,  4.46s/it]

0.34366371843237714


  g += lamb * mu[i] * sigma_i * np.sqrt(2/np.pi) * np.exp(-mu[i]**2/(2*sigma_i**2))
  g += lamb * mu[i] * (1 - norm.cdf(mu[i]/sigma_i))
  g += - np.log(sigma_i)
  0%|          | 5/1000 [00:23<1:19:15,  4.78s/it]

0.10804762785229861


  1%|          | 6/1000 [00:28<1:22:10,  4.96s/it]

0.6356028418726719


  g += - np.log(sigma_i)
  g += lamb * mu[i] * sigma_i * np.sqrt(2/np.pi) * np.exp(-mu[i]**2/(2*sigma_i**2))
  g += lamb * mu[i] * (1 - norm.cdf(mu[i]/sigma_i))
  g += - np.log(sigma_i)
  g += lamb * mu[i] * sigma_i * np.sqrt(2/np.pi) * np.exp(-mu[i]**2/(2*sigma_i**2))
  g += lamb * mu[i] * (1 - norm.cdf(mu[i]/sigma_i))
  g += - np.log(sigma_i)
  g += lamb * mu[i] * sigma_i * np.sqrt(2/np.pi) * np.exp(-mu[i]**2/(2*sigma_i**2))
  g += lamb * mu[i] * (1 - norm.cdf(mu[i]/sigma_i))
  g += - np.log(sigma_i)
  1%|          | 7/1000 [00:34<1:24:54,  5.13s/it]

0.14175820298852831


  1%|          | 8/1000 [00:39<1:27:11,  5.27s/it]

0.6926243114419135


  g += - np.log(sigma_i)
  g += lamb * mu[i] * sigma_i * np.sqrt(2/np.pi) * np.exp(-mu[i]**2/(2*sigma_i**2))
  g += lamb * mu[i] * (1 - norm.cdf(mu[i]/sigma_i))
  g += - np.log(sigma_i)
  1%|          | 9/1000 [00:45<1:26:32,  5.24s/it]

0.0010456744918219346


  1%|          | 10/1000 [00:49<1:24:12,  5.10s/it]

0.00401599607339804


  1%|          | 11/1000 [00:55<1:24:29,  5.13s/it]

0.0002535740593255431


  1%|          | 12/1000 [01:00<1:24:30,  5.13s/it]

0.00011360491037955365


  1%|▏         | 13/1000 [01:04<1:21:56,  4.98s/it]

0.000400072899100707


  1%|▏         | 14/1000 [01:09<1:20:33,  4.90s/it]

0.0012371167764277098


  2%|▏         | 15/1000 [01:14<1:20:09,  4.88s/it]

0.0033373898502422304


  2%|▏         | 16/1000 [01:19<1:20:18,  4.90s/it]

0.007983504466992246


  2%|▏         | 17/1000 [01:23<1:18:51,  4.81s/it]

0.017553858924931325


  2%|▏         | 18/1000 [01:28<1:19:13,  4.84s/it]

0.037490122864236905


  2%|▏         | 19/1000 [01:33<1:20:25,  4.92s/it]

0.08465333899250792


  2%|▏         | 20/1000 [01:38<1:20:33,  4.93s/it]

0.22310247257240626


  2%|▏         | 21/1000 [01:43<1:19:20,  4.86s/it]

0.2779236654201525


  2%|▏         | 22/1000 [01:48<1:17:06,  4.73s/it]

0.00024307077938602378


  2%|▏         | 23/1000 [01:53<1:19:37,  4.89s/it]

0.001073658938856398


  2%|▏         | 24/1000 [01:59<1:24:06,  5.17s/it]

0.0023990473774281058


  2%|▎         | 25/1000 [02:08<1:44:41,  6.44s/it]

0.004264090296309442


  3%|▎         | 26/1000 [02:15<1:47:27,  6.62s/it]

0.007392569820392233


  3%|▎         | 27/1000 [02:20<1:40:19,  6.19s/it]

0.012776525571767654


  3%|▎         | 28/1000 [02:26<1:39:28,  6.14s/it]

0.022467075329698893


  3%|▎         | 29/1000 [02:32<1:37:30,  6.03s/it]

0.04161480463498149


  3%|▎         | 30/1000 [02:38<1:37:26,  6.03s/it]

0.08574603731653374


  3%|▎         | 31/1000 [02:44<1:35:13,  5.90s/it]

0.20826691792760682


  3%|▎         | 32/1000 [02:51<1:41:11,  6.27s/it]

0.29359822433369354


  g += lamb * mu[i] * sigma_i * np.sqrt(2/np.pi) * np.exp(-mu[i]**2/(2*sigma_i**2))
  g += lamb * mu[i] * (1 - norm.cdf(mu[i]/sigma_i))
  g += - np.log(sigma_i)
  3%|▎         | 33/1000 [02:57<1:42:08,  6.34s/it]

0.0005876579279427063


  3%|▎         | 34/1000 [03:04<1:41:38,  6.31s/it]

0.0015706711321482906


  4%|▎         | 35/1000 [03:10<1:41:43,  6.33s/it]

0.010150117003277184


  4%|▎         | 36/1000 [03:16<1:38:59,  6.16s/it]

0.025590466973717356


  4%|▎         | 37/1000 [03:21<1:36:32,  6.02s/it]

0.05867972787004116


  4%|▍         | 38/1000 [03:28<1:38:16,  6.13s/it]

0.14557172750314465


  4%|▍         | 39/1000 [03:34<1:37:16,  6.07s/it]

0.3710232676870468


  g += lamb * mu[i] * sigma_i * np.sqrt(2/np.pi) * np.exp(-mu[i]**2/(2*sigma_i**2))
  g += lamb * mu[i] * (1 - norm.cdf(mu[i]/sigma_i))
  g += - np.log(sigma_i)
  4%|▍         | 40/1000 [03:39<1:32:57,  5.81s/it]

0.06321554888309755


  4%|▍         | 41/1000 [03:44<1:30:32,  5.66s/it]

0.1779805408895323


  4%|▍         | 42/1000 [03:49<1:28:02,  5.51s/it]

0.41756985439047384


  4%|▍         | 43/1000 [03:55<1:26:20,  5.41s/it]

0.00013394482750565557


  4%|▍         | 44/1000 [04:00<1:28:29,  5.55s/it]

0.0002306176862396309


  4%|▍         | 45/1000 [04:06<1:29:34,  5.63s/it]

7.119473306484419e-05


  5%|▍         | 46/1000 [04:11<1:26:05,  5.41s/it]

3.225310881347851e-05


  5%|▍         | 47/1000 [04:17<1:25:36,  5.39s/it]

1.8836530044166917e-05


  5%|▍         | 48/1000 [04:21<1:22:10,  5.18s/it]

1.0393247140976988e-05


  5%|▍         | 48/1000 [04:27<1:28:19,  5.57s/it]

5.8650242976748465e-06





In [22]:
np.unique(gamma == 1, return_counts=True)

(array([False,  True]), array([ 20, 180]))

In [25]:
theta = gamma * np.random.normal(mu, sigma)

In [31]:
(X @ theta)[:10]

array([ 142.79439414,   74.7456955 ,   91.51712976,  244.39949731,
       -162.30024156,  -54.54658436,  -21.93495754,  -99.01473075,
        238.943847  ,  120.06397042])

In [32]:
Y[:10]

array([ 142.84459278,   76.49698088,   91.57275278,  243.28727912,
       -163.16451502,  -52.60007252,  -22.5311845 ,  -97.93486912,
        239.14541256,  119.37570615])