<a href="https://colab.research.google.com/github/shainedl/Papers-Colab/blob/master/Variational_Principal_Components.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import numpy as np
import scipy.special as sp
from scipy.stats import multivariate_normal 
from scipy.stats import gamma 

In [0]:
class BayesianPCA():
    """Variational Bayesian PCA (Bishop, "Variational Principal Components").

    Generative model, for observations ``t_n`` stored as columns of a
    ``(d, N)`` matrix::

        t_n = W x_n + mu + eps,      eps ~ N(0, tau^-1 I_d)
        x_n ~ N(0, I_q)
        p(W | alpha) = prod_i N(w_i | 0, alpha_i^-1 I_d)   (w_i: i-th column)
        alpha_i ~ Gamma(a_alpha, b_alpha)
        mu ~ N(0, beta^-1 I_d)
        tau ~ Gamma(a_tau, b_tau)

    The posterior is approximated by the factorised distribution
    ``q(X) q(W) q(mu) q(alpha) q(tau)`` whose parameters are stored on the
    instance after ``fit``:

    * ``mean_x`` ``(q, N)``, ``sigma_x`` ``(q, q)``  -- q(X), shared covariance
    * ``mean_w`` ``(d, q)``, ``sigma_w`` ``(q, q)``  -- q(W), one Gaussian per
      row of W, shared covariance
    * ``mean_mu`` ``(d, 1)``, ``sigma_mu`` ``(d, d)`` -- q(mu)
    * ``a_alpha_tilde`` (scalar), ``b_alpha_tilde`` ``(q,)`` -- q(alpha)
    * ``a_tau_tilde``, ``b_tau_tilde`` (scalars)       -- q(tau)

    Args:
        a_alpha: Shape of the Gamma prior on each ``alpha_i``.
        b_alpha: Rate of the Gamma prior on each ``alpha_i``.
        a_tau: Shape of the Gamma prior on the noise precision ``tau``.
        b_tau: Rate of the Gamma prior on ``tau``.
        beta: Precision of the isotropic Gaussian prior on ``mu``.
        seed: Optional seed (or anything ``np.random.default_rng`` accepts)
            used for the random initialisation and for the Monte Carlo ELBO
            estimate, so that runs can be reproduced.

    Note:
        The defaults are written ``10e-3``, i.e. ``0.01`` (not ``1e-3``).
        All hyperparameters are expected to be strictly positive.
    """

    def __init__(self, a_alpha=10e-3, b_alpha=10e-3, a_tau=10e-3,
                 b_tau=10e-3, beta=10e-3, seed=None):
        # hyperparameters
        self.a_alpha = a_alpha
        self.b_alpha = b_alpha
        self.a_tau = a_tau
        self.b_tau = b_tau
        self.beta = beta

        # The variational parameters depend on the data dimensions (d, q, N),
        # which are only known once fit() receives the data, so they are
        # created there rather than here.
        self._rng = np.random.default_rng(seed)

    def __initialize(self):
        """Randomly initialise the variational parameters for the current data."""
        rng = self._rng
        self.mean_x = rng.standard_normal((self.q, self.N))
        self.sigma_x = np.identity(self.q)
        self.mean_mu = rng.standard_normal((self.d, 1))
        self.sigma_mu = np.identity(self.d)
        self.mean_w = rng.standard_normal((self.d, self.q))
        self.sigma_w = np.identity(self.q)
        # The Gamma shapes are fixed by the model and never re-estimated.
        self.a_alpha_tilde = self.a_alpha + self.d / 2
        self.a_tau_tilde = self.a_tau + self.N * self.d / 2
        # Adding the prior rate keeps the initial rates strictly positive even
        # when the random draw is very close to zero.
        self.b_alpha_tilde = self.b_alpha + np.abs(rng.standard_normal(self.q))
        self.b_tau_tilde = self.b_tau + abs(float(rng.standard_normal()))

    def __reestimate(self):
        """Run one sweep of the coordinate-ascent variational updates."""
        t_n, d, q, N = self.t_n, self.d, self.q, self.N

        # <tau>: expected noise precision
        self.tau = self.a_tau_tilde / self.b_tau_tilde

        # q(X): latent variables.
        # <W^T W> = sum over the d rows of W of (Sigma_w + m_k m_k^T).
        exp_wtw = d * self.sigma_w + self.mean_w.T @ self.mean_w
        self.sigma_x = np.linalg.inv(np.identity(q) + self.tau * exp_wtw)
        self.mean_x = (self.tau * self.sigma_x @ self.mean_w.T
                       @ (t_n - self.mean_mu))

        # q(mu): observation mean; the sum runs over data points (columns).
        self.sigma_mu = np.identity(d) / (self.beta + N * self.tau)
        residual_sum = np.sum(t_n - self.mean_w @ self.mean_x,
                              axis=1, keepdims=True)
        self.mean_mu = self.tau * self.sigma_mu @ residual_sum

        # <alpha>: expected precision of each column of W
        self.alpha = self.a_alpha_tilde / self.b_alpha_tilde

        # q(W): weights.  sum_n <x_n x_n^T> = N Sigma_x + M_x M_x^T.
        exp_xxt = N * self.sigma_x + self.mean_x @ self.mean_x.T
        self.sigma_w = np.linalg.inv(np.diag(self.alpha) + self.tau * exp_xxt)
        # Row k of W has mean tau * Sigma_w * sum_n <x_n> (t_nk - <mu_k>);
        # Sigma_w is symmetric, so all rows can be computed at once.
        self.mean_w = (self.tau * (t_n - self.mean_mu)
                       @ self.mean_x.T @ self.sigma_w)

        # q(alpha): rate uses <||w_i||^2> for each column i of W.
        self.b_alpha_tilde = self.b_alpha + 0.5 * (
            d * np.diag(self.sigma_w) + np.sum(self.mean_w ** 2, axis=0))

        # q(tau): rate is b_tau + 0.5 * sum_n <||t_n - W x_n - mu||^2>.
        exp_wtw = d * self.sigma_w + self.mean_w.T @ self.mean_w
        recon = self.mean_w @ self.mean_x
        exp_mu_sq = np.trace(self.sigma_mu) + np.sum(self.mean_mu ** 2)
        self.b_tau_tilde = float(
            self.b_tau
            + 0.5 * np.sum(t_n ** 2)
            + 0.5 * N * exp_mu_sq
            + 0.5 * np.trace(exp_wtw @ exp_xxt)
            + np.sum(self.mean_mu * recon)
            - np.sum(t_n * recon)
            - np.sum(t_n * self.mean_mu))

    @staticmethod
    def _gaussian_entropy(cov):
        """Return the differential entropy of a Gaussian with covariance ``cov``."""
        dim = cov.shape[0]
        # slogdet avoids the under/overflow that det() hits for large dim.
        _, logdet = np.linalg.slogdet(cov)
        return 0.5 * logdet + 0.5 * dim * (1 + np.log(2 * np.pi))

    @staticmethod
    def _gamma_entropy(shape, rate):
        """Return the entropy of Gamma(shape, rate); broadcasts over arrays."""
        # gammaln instead of log(gamma(.)): gamma() overflows for large shape.
        return (sp.gammaln(shape) - (shape - 1) * sp.digamma(shape)
                - np.log(rate) + shape)

    def __get_elbo(self):
        """Return a single-sample Monte Carlo estimate of the ELBO.

        One joint sample is drawn from the variational posterior and the log
        joint density is evaluated at it; the entropy of the variational
        posterior is added in closed form.  The estimate is therefore noisy
        and need not increase monotonically between iterations.
        """
        rng = self._rng
        d, q, N = self.d, self.q, self.N
        log_2pi = np.log(2 * np.pi)

        # random sample from q(X) q(mu) q(W) q(alpha) q(tau)
        x = self.mean_x + (np.linalg.cholesky(self.sigma_x)
                           @ rng.standard_normal((q, N)))
        mu = self.mean_mu + (np.linalg.cholesky(self.sigma_mu)
                             @ rng.standard_normal((d, 1)))
        # Each row of W is Gaussian with covariance Sigma_w.
        w = self.mean_w + (rng.standard_normal((d, q))
                           @ np.linalg.cholesky(self.sigma_w).T)
        alpha = rng.gamma(self.a_alpha_tilde, 1 / self.b_alpha_tilde)
        tau = rng.gamma(self.a_tau_tilde, 1 / self.b_tau_tilde)

        # priors
        # p(x) = N(x|0,I_q), summed over the N latent vectors
        prior = -0.5 * (N * q * log_2pi + np.sum(x ** 2))

        # p(w|alpha): one isotropic Gaussian per column of W
        prior += np.sum(0.5 * d * (np.log(alpha) - log_2pi)
                        - 0.5 * alpha * np.sum(w ** 2, axis=0))

        # p(alpha) = Gamma(a_alpha, b_alpha)
        prior += np.sum(gamma.logpdf(alpha, self.a_alpha,
                                     scale=1 / self.b_alpha))

        # p(mu) = N(mu|0,Beta^-1I)
        prior += (0.5 * d * (np.log(self.beta) - log_2pi)
                  - 0.5 * self.beta * np.sum(mu ** 2))

        # p(tau) = Gamma(a_tau, b_tau)
        prior += gamma.logpdf(tau, self.a_tau, scale=1 / self.b_tau)

        # log likelihood of the conditional distribution
        # p(t_n | x_n, W, mu, tau), summed over the N observations
        residual = self.t_n - w @ x - mu
        likelihood = (0.5 * N * d * (np.log(tau) - log_2pi)
                      - 0.5 * tau * np.sum(residual ** 2))

        # entropy of the variational posterior
        entropy = N * self._gaussian_entropy(self.sigma_x)      # q(X)
        entropy += self._gaussian_entropy(self.sigma_mu)        # q(mu)
        entropy += d * self._gaussian_entropy(self.sigma_w)     # q(W)
        entropy += np.sum(self._gamma_entropy(self.a_alpha_tilde,
                                              self.b_alpha_tilde))  # q(alpha)
        entropy += self._gamma_entropy(self.a_tau_tilde,
                                       self.b_tau_tilde)        # q(tau)

        # ELBO = E_q[log p(T, theta)] - E_q[log q(theta)]
        #      = E_q[log p(T, theta)] + H[q]
        return float(prior + likelihood + entropy)

    def fit(self, t_n, iterations=10, verbose=True):
        """Fit the variational posterior to the data.

        Args:
            t_n: Array-like of shape ``(d, N)``; each column is one
                observation.  The latent dimension is set to ``q = d - 1``.
            iterations: Number of update sweeps to run.
            verbose: When true, print the iteration index and the ELBO
                estimate after every sweep.

        Raises:
            ValueError: If ``t_n`` is not 2-D, has fewer than two rows
                (``q`` would be zero) or has no columns.

        The ELBO estimate of every sweep is also stored in
        ``self.elbo_history``.
        """
        t_n = np.asarray(t_n, dtype=float)
        if t_n.ndim != 2:
            raise ValueError("t_n must be a 2-D array of shape (d, N)")
        if t_n.shape[0] < 2 or t_n.shape[1] < 1:
            raise ValueError(
                "t_n must have at least 2 rows (dimensions) and 1 column "
                "(observation)")

        self.t_n = t_n
        self.d = self.t_n.shape[0]
        self.q = self.d - 1
        self.N = self.t_n.shape[1]

        self.__initialize()
        self.elbo_history = []

        for i in range(iterations):
            self.__reestimate()
            elbo = self.__get_elbo()
            self.elbo_history.append(elbo)
            if verbose:
                print("Iterations: %d" % i)
                print("ELBO: %f" % elbo)