<a href="https://colab.research.google.com/github/shainedl/Papers-Colab/blob/master/Variational_Principal_Components_PyTorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import torch
import torch.distributions as tdist

In [0]:
class BayesianPCA():
  """
  Variational Bayesian PCA, fitted by cyclically re-estimating a factorised
  posterior over the latent variables X, the weights W, the per-column
  precisions alpha, the mean mu and the noise precision tau.

  Attributes set by fit (d = data dimension, q = d - 1, N = number of points):
    mean_x        : q x N   posterior means of the latent variables
    sigma_x       : q x q   posterior covariance shared by every latent x_n
    mean_mu       : d x 1   posterior mean of mu
    sigma_mu      : d x d   posterior covariance of mu
    mean_w        : d x q   posterior mean of W
    sigma_w       : q x q   posterior covariance shared by every row of W
    a_alpha_tilde, b_alpha_tilde : Gamma parameters of alpha (b has q entries)
    a_tau_tilde, b_tau_tilde     : Gamma parameters of tau (b has 1 entry)
    alpha, tau    : posterior means a_tilde / b_tilde
  """

  def __init__(self, a_alpha=10e-3, b_alpha=10e-3, a_tau=10e-3, b_tau=10e-3, beta=10e-3):
    """
    Stores the prior hyperparameters

    Parameters
    ----------
    a_alpha, b_alpha : float
      Gamma prior parameters of each alpha_i

    a_tau, b_tau : float
      Gamma prior parameters of tau

    beta : float
      precision of the prior on mu

    Note that the literal 10e-3 evaluates to 0.01.
    """

    # hyperparameters
    self.a_alpha = a_alpha
    self.b_alpha = b_alpha
    self.a_tau = a_tau
    self.b_tau = b_tau
    self.beta = beta

  def __reestimate(self):
    """
    Cycle through the groups of variables in turn to re-estimate each distribution

    Each update reads the most recent value of every other factor, so the
    order of the statements below matters.
    """
    # build identity matrices with the dtype/device of the data
    like_data = dict(dtype=self.t_n.dtype, device=self.t_n.device)

    # observation parameter: <tau> is the mean of Gamma(a_tau_tilde, b_tau_tilde)
    self.tau = self.a_tau_tilde / self.b_tau_tilde

    # latent variables
    # <W^T W> = d * sigma_w + <W>^T <W>: each of the d rows of W contributes
    # one copy of the shared covariance sigma_w
    exp_wtw = self.d * self.sigma_w + torch.mm(self.mean_w.t(), self.mean_w)
    self.sigma_x = torch.inverse(torch.eye(self.q, **like_data) + self.tau * exp_wtw)
    self.mean_x = self.tau * torch.mm(torch.mm(self.sigma_x, self.mean_w.t()),
                                      self.t_n - self.mean_mu)

    # observation parameter
    self.sigma_mu = torch.eye(self.d, **like_data) / (self.beta + self.N * self.tau)
    residual = self.t_n - torch.mm(self.mean_w, self.mean_x)
    self.mean_mu = self.tau * torch.mm(self.sigma_mu, residual.sum(dim=1, keepdim=True))

    # hyperparameter controlling the columns of W
    self.alpha = self.a_alpha_tilde / self.b_alpha_tilde

    # weight
    # sum_n <x_n x_n^T> = N * sigma_x + <X> <X>^T
    exp_xxt = self.N * self.sigma_x + torch.mm(self.mean_x, self.mean_x.t())
    self.sigma_w = torch.inverse(torch.diag(self.alpha) + self.tau * exp_xxt)
    centered = self.t_n - self.mean_mu
    # row k of <W> is tau * sigma_w * sum_n <x_n> (t_nk - <mu_k>)
    self.mean_w = (self.tau * torch.mm(self.sigma_w, torch.mm(self.mean_x, centered.t()))).t()

    # second moment of W recomputed from the freshly updated Q(W)
    exp_wtw = self.d * self.sigma_w + torch.mm(self.mean_w.t(), self.mean_w)

    # alpha's gamma distribution parameter: <||w_i||^2> is the i-th diagonal of <W^T W>
    self.b_alpha_tilde = self.b_alpha + 0.5 * torch.diag(exp_wtw)

    # tau's gamma distribution parameter:
    # b_tau + 0.5 * sum_n < ||t_n - W x_n - mu||^2 >, expanded term by term.
    # Element-wise products followed by a full sum give the per-point inner
    # products without forming any N x N matrix.
    w_x = torch.mm(self.mean_w, self.mean_x)
    exp_mu_sq = torch.trace(self.sigma_mu) + torch.sum(self.mean_mu ** 2)
    self.b_tau_tilde = (self.b_tau
                        + 0.5 * torch.sum(self.t_n ** 2)
                        + 0.5 * self.N * exp_mu_sq
                        + 0.5 * torch.trace(torch.mm(exp_wtw, exp_xxt))
                        + torch.sum(self.mean_mu * w_x)
                        - torch.sum(self.t_n * w_x)
                        - torch.sum(self.t_n * self.mean_mu)).reshape(1)

  def fit(self, t_n, iterations = 1000, threshold = 1.0):
    """
    Fits the data

    Parameters
    ----------
    t_n : d x N matrix
      observed data to be fit (floating point tensor, one column per point)

    iterations: int
      number of re-estimation sweeps to run

    threshold: float
      currently unused: no convergence test is performed and exactly
      `iterations` sweeps are run; kept so existing callers keep working

    Raises
    ------
    ValueError
      if t_n is not a 2-D tensor
    """
    if t_n.dim() != 2:
      raise ValueError("t_n must be a 2-D tensor of shape (d, N)")

    self.t_n = t_n
    self.d = self.t_n.shape[0]
    self.q = self.d - 1
    self.N = self.t_n.shape[1]

    # create every parameter with the dtype/device of the data so that the
    # matrix products in the updates never mix tensor types
    like_data = dict(dtype=t_n.dtype, device=t_n.device)

    # variational parameters (random draws kept in their original order)
    self.mean_x = torch.randn(self.q, self.N, **like_data)
    self.sigma_x = torch.eye(self.q, **like_data)
    self.mean_mu = torch.randn(self.d, 1, **like_data)
    self.sigma_mu = torch.eye(self.d, **like_data)
    self.mean_w = torch.randn(self.d, self.q, **like_data)
    self.sigma_w = torch.eye(self.q, **like_data)
    # the a_tilde parameters depend only on the data shape, so they are set
    # once here and never re-estimated
    self.a_alpha_tilde = self.a_alpha + self.d / 2
    self.b_alpha_tilde = torch.abs(torch.randn(self.q, **like_data))
    self.a_tau_tilde = self.a_tau + self.N * self.d / 2
    self.b_tau_tilde = torch.abs(torch.randn(1, **like_data))

    for _ in range(iterations):
      self.__reestimate()

In [0]:
"""
We generate 100 data points in d = 10 dimensions from a Gaussian distribution 
having standard deviations of (5, 4, 3, 2) along four orthogonal directions 
and a standard deviation of 1 in the remaining five directions
"""
m = tdist.multivariate_normal.MultivariateNormal(torch.zeros(10), torch.diag(torch.Tensor([5,4,3,2,1,1,1,1,1,1])))
X = m.sample(sample_shape=torch.Size([100])).t()
"""
Hinton diagram of <W> from variational Bayesian PCA 
"""
test = BayesianPCA()
test.fit(X) 