In [1]:
import pylab as pb
import numpy as np
import random
from matplotlib import pyplot as plt
from numpy.random import multinomial
from scipy.stats import bernoulli
from math import pi
import itertools as it

In [273]:
np.random.seed(2)

#K = Length of sequence 
K = 10
#N = number of players
N = 1
#p = Probability observation is observed
p = 0.3

group1Probs = [0.2, 0.2, 0.3, 0.1, 0.1, 0.1]
group2Probs = [0.1, 0.1, 0.1, 0.3, 0.2, 0.2]

#Dice through
def roll(groupNum=0):
    diceRoll = 0
    #Diceroll if player is at first group of tables
        
    #Random Distribution (default)
    if (groupNum == 0):
        dist = np.random.randn(6)**2
        dist = dist/dist.sum()
        x = np.random.multinomial(1, dist, size=1)
        diceRoll = np.flatnonzero(x) + 1
        
        
    if (groupNum == 1):
        x = np.random.multinomial(1, group1Probs, size=1)
        diceRoll = np.flatnonzero(x) + 1
    #Diceroll if player is at second group of tables
    elif (groupNum == 2):
        x = np.random.multinomial(1, group2Probs, size=1)
        diceRoll = np.flatnonzero(x) + 1
    
    return diceRoll

class Player(object):
    #Initialize player with empty vector of hiddenthrows and array of
    #all throws (hidden throws == 0)
    def __init__(self, tables):
        self.hiddenThrows = []
        self.Throws = np.zeros(tables)
        self.allThrows = np.zeros(tables)
        self.sum = 0
        #Randomly starts in either group of tables
        self.groupNum = bernoulli.rvs(0.5) + 1
        
        #Hidden throws for group 1 and 2
        self.hidden1 = []
        self.hidden2 = []
        
        #True state sequence
        self.sequence = np.zeros(tables)
        
    #Dice throw. Throw is observed according to probability p
    def Throw(self, throwNum, random=0):
        observed = bernoulli.rvs(p)
        
        #Roll dice according to which group of tables player is at
        if (random):
            throw = roll()
        else:
            throw = roll(self.groupNum)
        self.sequence[throwNum] = self.groupNum
        
        #Add roll to the player sum and to obersevation sequence
        self.sum += throw
        self.allThrows[throwNum] = throw
        
        if (observed):
            self.Throws[throwNum] = throw
        else:
            self.hiddenThrows.append(throw)
            if (self.groupNum == 1):
                self.hidden1.append(throw)
            elif (self.groupNum == 2):
                self.hidden2.append(throw)
         
    #Return the sum of all throws
    def SumThrows(self):
        return self.sum
    
    
players = []
group1 = np.zeros(0)
group2 = np.zeros(0)
hidden1 = np.zeros(0)
hidden2 = np.zeros(0)
#Loop through each player
for i in range(N):
    player = Player(K)
    players.append(player)
    #Loop through each table
    switches = 0
    for j in range(K):
        player.Throw(j)
        #Change the table group. 25% chance of staying in same group, 
        #75% chance of switching group
        switch = bernoulli.rvs(.75)
        if (switch):
            player.groupNum = 3 - player.groupNum
            switches += 1
      
    
    print ("sum", player.sum)
    print ("Observation Sequence", player.Throws)
    print ("All throws", player.allThrows)
    print ("Sequence", player.sequence)
    group1 = np.append(group1, player.allThrows[np.where(player.sequence == 1)])
    group2 = np.append(group2, player.allThrows[np.where(player.sequence == 2)])
    hidden1 = np.append(hidden1, player.hidden1)
    hidden2 = np.append(hidden2, player.hidden2)

noswitch = (N*K - switches)
    
print ("Group1 Throws", group1)
print ("Group2 Trhows", group2)

print(switches)
print(noswitch)   
#print (np.zeros(0))    
    
        
    


sum [42]
Observation Sequence [ 0.  0.  1.  6.  0.  0.  0.  5.  4.  0.]
All throws [ 6.  5.  1.  6.  6.  4.  2.  5.  4.  3.]
Sequence [ 1.  2.  1.  2.  1.  2.  1.  2.  1.  2.]
Group1 Throws [ 6.  1.  6.  2.  4.]
Group2 Trhows [ 5.  6.  4.  5.  3.]
9
1


In [None]:
#Distribution:
#x = np.random.multinomial(1, [1/2.] + [0.]*4 + [1/2.], size=1)
#Group 2: x = np.random.multinomial(1, [1/10.]*5 + [1/2.], size=1)
#K = 100
#Players = 10
#p = 0.5

plt.subplot(1,2,1)
n, bins, patches = plt.hist(group1, bins=np.arange(1,8) - 0.5, width=0.95)
plt.xlim(1-0.5, 6+0.5)
plt.xticks(range(1,7))
plt.xlabel("Outcome")
plt.ylabel("Number of Rolls")
plt.title("Group 1 All Rolls")
plt.tight_layout()

plt.subplot(1,2,2)
n, bins, patches = plt.hist(group2, bins=np.arange(1,8) - 0.5, width=0.95)
plt.xlim(1-0.5, 6+0.5)
plt.xticks(range(1,7))
plt.xlabel("Outcome")
plt.ylabel("Number of Rolls")
plt.title("Group 2 All Rolls")
plt.tight_layout()

print(len(group1))
print(len(group2))

labels = 'Group1','Group2'
sizes = [len(group1), len(group2)]
fig1, ax1 = plt.subplots()
ax1.pie(sizes, labels=labels, autopct='%1.1f%%',
        shadow=True, startangle=90)
ax1.axis('equal')  # Equal aspect ratio ensures that pie is drawn as a circle.
plt.title('Fraction of Rolls in each Group')
plt.show()


labels = 'Switch', 'No Switch'
sizes = [switches, noswitch]
fig1, ax1 = plt.subplots()
ax1.pie(sizes, labels=labels, autopct='%1.1f%%',
        shadow=True, startangle=90)
ax1.axis('equal')  # Equal aspect ratio ensures that pie is drawn as a circle.
plt.title('Percentage of Switches Between Groups')
plt.show()



In [None]:
%matplotlib inline
#Distribution:
#Group 1: x = np.random.multinomial(1, [1/3.] + [0.] + [1/3.] + [0.] + [1/3.] + [0.], size=1)
#Group 2: x = np.random.multinomial(1, [1/6.]*6, size=1)
#K = 1
#Players = 10000
#Probability hidden = 0.5
plt.subplot(2,2,1)
n, bins, patches = plt.hist(group1, bins=np.arange(1,8) - 0.5, width=0.95)
plt.xlim(1-0.5, 6+0.5)
plt.xticks(range(1,7))
plt.xlabel("Outcome")
plt.ylabel("Number of Rolls")
plt.title("Group 1 All Throws")
plt.tight_layout()

plt.subplot(2,2,2)
n, bins, patches = plt.hist(group2, bins=np.arange(1,8) - 0.5, width=0.95)
plt.xlim(1-0.5, 6+0.5)
plt.xticks(range(1,7))
plt.xlabel("Outcome")
plt.ylabel("Number of Rolls")
plt.title("Group 2 All Throws")
plt.tight_layout()

plt.subplot(2,2,3)
n, bins, patches = plt.hist(hidden1, bins=np.arange(1,8) - 0.5, width=0.95)
plt.xlim(1-0.5, 6+0.5)
plt.xticks(range(1,7))
plt.xlabel("Outcome")
plt.ylabel("Number of Rolls")
plt.title("Group 1 Hidden Throws")
plt.tight_layout()

plt.subplot(2,2,4)
n, bins, patches = plt.hist(hidden2, bins=np.arange(1,8) - 0.5, width=0.95)
plt.xlim(1-0.5, 6+0.5)
plt.xticks(range(1,7))
plt.xlabel("Outcome")
plt.ylabel("Number of Rolls")
plt.title("Group 2 Hidden Throws")
plt.tight_layout()

In [None]:

#Distribution:
#2K Distributions
#K = 1000
#Players = 1
#Probability hidden = 0.5

plt.subplot(1,2,1)
n, bins, patches = plt.hist(group1, bins=np.arange(1,8) - 0.5, width=0.95)
plt.xlim(1-0.5, 6+0.5)
plt.xticks(range(1,7))
plt.xlabel("Outcome")
plt.ylabel("Number of Rolls")
plt.title("Group 1")
plt.tight_layout()

plt.subplot(1,2,2)
n, bins, patches = plt.hist(group2, bins=np.arange(1,8) - 0.5, width=0.95)
plt.xlim(1-0.5, 6+0.5)
plt.xticks(range(1,7))
plt.xlabel("Outcome")
plt.ylabel("Number of Rolls")
plt.title("Group 2")
plt.tight_layout()

In [None]:
bernoulli.rvs(p)

# Sum - HMM Smoothing

Here an algorithm to compute $p(X_k^n = s, Z_k = t_k^i|s^n,x^n)$ is presented. We will define an alpha variable for the forward pass as, $\alpha(Z_k) = p(Z_k, s_k^n, x_{1:k}^n)$. We will then perform two alpha passes starting from the final table with its observed partial sum and complete observation sequence. One will be a straightforward pass whereas in the second we will set the observation to $s$ and the hidden state to $t_k^i$ when we get to table k. We will recurse until we reach a base state defined by, $\alpha(Z_1) = \sum_{X_1^n}p(Z_1)p(X_1^n|Z_1)$. 

In [282]:
def alphaVar(table, group, partialSum, obSeq):
    #Define transition probabilities based on current state
    if (group == 1):
        transition = [.25, .75]
        groupProbs = group1Probs
    else:
        transition = [.75, .25]
        groupProbs = group2Probs
        
    #Base Case 
    if (table == 0):
        #If unobserved return the probability of the partial sum at that state
        if (obSeq[table] == 0):
            return 0.5*groupProbs[int(partialSum)-1]
        else:
            #Ensure that observed roll at first state is equal to the partial sum
            if (obSeq[table]==partialSum):
                return 0.5*groupProbs[int(obSeq[table]-1)]
            #If observed roll is not the same as first partial sum, probability is zero
            else:
                return 0
    else:
        #Roll at table is unobserved
        if (obSeq[table] == 0):
            alpha = 0
            #Loop through all possible rolls
            for i in range(6):
                #Calculate previous sum based on current roll and partial sum
                prevSum = partialSum - (i+1)
                #Check to make sure the previous partial sum is possible
                if ((prevSum < (table)) or (prevSum > (6*table))):
                    continue
                #Recurse
                else:
                    #Emission probability for that roll
                    emissionProb = groupProbs[i]
                    alpha += emissionProb*alphaVar(table-1, 1, prevSum, obSeq)*transition[0]
                    alpha += emissionProb*alphaVar(table-1, 2, prevSum, obSeq)*transition[1]
            return alpha
        #Roll at table is observed            
        else: 
            alpha = 0
            #Extract observed roll and calculate previous sum 
            roll = int(obSeq[table])
            prevSum = partialSum - (roll)
            #Make sure previous sum is possible, if not we can return 0 and prune that path
            if ((prevSum < (table)) or (prevSum > (6*table))):
                return alpha
            #Recurse
            emissionProb = groupProbs[roll - 1]
            alpha += emissionProb*alphaVar(table-1, 1, prevSum, obSeq)*transition[0]
            alpha += emissionProb*alphaVar(table-1, 2, prevSum, obSeq)*transition[1]
            return alpha
        
        
def alphaPass(table, group, partialSum, obSeq, obs, state, k):
    if (group == 1):
        transition = [.25, .75]
        groupProbs = group1Probs
    else:
        transition = [.75, .25]
        groupProbs = group2Probs
    
    #Make sure to only consider when Z_k=t_k^i
    if (table == k):
        if ((state == 1 and group == 2) or (state == 2 and group == 1)):
            return 0
    
    if (table == 0): 
        if (obSeq[table] == 0):
            if (table==k and partialSum != obs):
                return 0
            else:
                return 0.5*groupProbs[int(partialSum)-1]
        else:
            if (obSeq[table]==partialSum):
                return 0.5*groupProbs[int(obSeq[table]-1)]
            else:
                return 0
    else:
        if (table == k):
            prevSum = partialSum - obs
            alpha = 0
            if ((prevSum < (table)) or (prevSum > (6*table))):
                return alpha
            if (state == 1):
                emissionProbs = group1Probs[obs-1]
            else:
                emissionProbs = group2Probs[obs-1]
            alpha += emissionProbs*alphaPass(table-1, 1, prevSum, obSeq, obs, state, k)*transition[0]
            alpha += emissionProbs*alphaPass(table-1, 2, prevSum, obSeq, obs, state, k)*transition[1]
            return alpha
        else:
            if (obSeq[table] == 0):
                alpha = 0
                for i in range(6):
                    prevSum = partialSum - (i+1)
                    #Check to make sure the previous partial sum is possible
                    if ((prevSum < table) or (prevSum > (6*table))):
                        continue
                    else:
                        emissionProb = groupProbs[i]
                        alpha += emissionProb*alphaPass(table-1, 1, prevSum, obSeq, obs, state, k)*transition[0]
                        alpha += emissionProb*alphaPass(table-1, 2, prevSum, obSeq, obs, state, k)*transition[1]
                return alpha

            else: 
                alpha = 0
                roll = int(obSeq[table])
                prevSum = partialSum - (roll)
                if ((prevSum < table) or (prevSum > (6*table))):
                    return alpha
                emissionProb = groupProbs[roll - 1]
                alpha += emissionProb*alphaPass(table-1, 1, prevSum, obSeq, obs, state, k)*transition[0]
                alpha += emissionProb*alphaPass(table-1, 2, prevSum, obSeq, obs, state, k)*transition[1]
                return alpha
            
            
            
   
def calcProb(player, state, obs, k):
    if (player.Throws[k] != 0 and player.Throws[k] != obs):
        return 0
    else:
        num = 0
        denom = 0
        #print(alphaVar(len(player.Throws)-1, 1, player.sum, player.Throws))
        #print(alphaVar(len(player.Throws)-1, 2, player.sum, player.Throws))
        denom += alphaVar(len(player.Throws)-1, 1, player.sum, player.Throws)
        denom += alphaVar(len(player.Throws)-1, 2, player.sum, player.Throws)
        #print(alphaPass(len(player.Throws)-1, 1, player.sum, player.Throws, obs, state, k))
        #print(alphaPass(len(player.Throws)-1, 2, player.sum, player.Throws, obs, state, k))
        num += alphaPass(len(player.Throws)-1, 1, player.sum, player.Throws, obs, state, k)
        num += alphaPass(len(player.Throws)-1, 2, player.sum, player.Throws, obs, state, k)
        return num/denom



print(players[0].Throws)
print(players[0].allThrows)
print(players[0].sequence)
print(group1Probs)
print(group2Probs)
print(players[0].sum)
calcProb(players[0], 2, 4, 7)



[ 0.  0.  1.  6.  0.  0.  0.  5.  4.  0.]
[ 6.  5.  1.  6.  6.  4.  2.  5.  4.  3.]
[ 1.  2.  1.  2.  1.  2.  1.  2.  1.  2.]
[0.2, 0.2, 0.3, 0.1, 0.1, 0.1]
[0.1, 0.1, 0.1, 0.3, 0.2, 0.2]
[42]


0

In [49]:
np.array(group1Probs)
np.sum(np.array(group1Probs))
print(range(6))
for i in range(6):
    print(i)

range(0, 6)
0
1
2
3
4
5


In [None]:
obSeq = players[0].Throws
print(obSeq)
print(obSeq[9])
print(players[0].allThrows)
group2Probs[int(obSeq[49]-1)]

In [None]:
import scipy.stats as stats
from scipy.special import gamma

## Generate Data

In [None]:
%matplotlib inline
np.random.seed(7)

#Parameters of normal distribution
u0 = 0
tau = 1

#Generate Data
N = 10
X = np.random.normal(u0, 1/tau, N)



xbar = np.mean(X)
sumSq = np.sum((X-xbar)**2)
sumx = np.sum(X)
sumxsq = np.sum(X**2)



#Initial guesses (small positive values to give broad prior distributions indicating ignorance about the prior distributions) 
a_n = .1
b_n = .1
mu_n = 0.1
lambda_n = .1

#Exact posterior parameter vals
aN = a_n + N/2
bN = b_n + 1/2*sumSq + (lambda_n*N*(xbar - mu_n)**2)/(2*(lambda_n + N))
muN = (lambda_n*mu_n + N*xbar)/(lambda_n + N)
lambdaN = lambda_n + N



muVals = np.linspace(-1,1,100)
lamVals = np.linspace(0.1,2,100)



def plot_pdf(pdf, X, Y, color):
    plt.contour(X, Y, pdf, colors=color)
    plt.xlabel("$\mu$")
    plt.ylabel("$\lambda$")
    
def gaussian_gamma(mu, lam, a, b, u, l):
    pdf_gamma = stats.gamma.pdf(l, a=a, scale=1/b)
    pdf_norm = stats.norm.pdf(u, loc=mu, scale=((l*lam)**(-0.5)))
    return pdf_gamma*pdf_norm

def gauss_gam_pdf(mu, lam, a, b, X, Y):
    return [[gaussian_gamma(mu, lam, a, b, u, l) for u in X]for l in Y]

def gaussian_gammaVI(mu, lam, a, b, u, l):
    pdf_gamma = stats.gamma.pdf(l, a=a, scale=1/b)
    pdf_norm = stats.norm.pdf(u, loc=mu, scale=((lam)**(-0.5)))
    return pdf_gamma*pdf_norm

def gauss_gam_pdfVI(mu, lam, a, b, X, Y):
    return [[gaussian_gammaVI(mu, lam, a, b, u, l) for u in X]for l in Y]

posterior_pdf = gauss_gam_pdf(muN, lambdaN, aN, bN, muVals, lamVals)

lambda_nt = lambda_n
mu_nt = mu_n
a_nt = a_n
b_nt = b_n

it = 0
maxIt = 10
while (it < maxIt):
    oldLq = .5*np.log(1/lambda_nt) + np.log(gamma(a_nt)) - a_nt*np.log(b_nt)
    
    VI_pdf = gauss_gam_pdfVI(mu_nt, lambda_nt, a_nt, b_nt, muVals, lamVals)
    
    mu_nt = (lambda_n*mu_n + N*xbar) / (lambda_n + N)
    a_nt = a_n + (N + 1)/2
    b_nt = b_n + 0.5*((lambda_n + N)*((1/lambda_nt) + mu_nt**2) - 
            2*mu_nt*(lambda_n*mu_n + sumx) + sumxsq + lambda_n*mu_n**2)
    lambda_nt = (lambda_n + N)*(a_nt/b_nt)
    
    newLq = .5*np.log(1/lambda_nt) + np.log(gamma(a_nt)) - a_nt*np.log(b_nt)
    if (abs(newLq - oldLq) < .1):
        plt.subplot(2,2,it)
        plot_pdf(posterior_pdf, muVals, lamVals, "green")
        plot_pdf(VI_pdf, muVals, lamVals, "red")
        plt.tight_layout()
        break
    if (it > 0):
        plt.subplot(2,2,it)
        plot_pdf(posterior_pdf, muVals, lamVals, "green")
        plot_pdf(VI_pdf, muVals, lamVals, "blue")
        plt.tight_layout()

        #plt.show()
    
    it += 1


#plt.scatter(W[0], W[1], s=100, marker="x")
#plt.contourf(w0,w1,mv_norm.pdf(w, mu, S0))


In [None]:
%matplotlib inline
np.random.seed(2)

#Parameters of normal distribution
u0 = 0
tau = 1

#Generate Data
N = 25
X = np.random.normal(u0, 1/tau, N)



xbar = np.mean(X)
sumSq = np.sum((X-xbar)**2)
sumx = np.sum(X)
sumxsq = np.sum(X**2)



#Initial guesses (small positive values to give broad prior distributions indicating ignorance about the prior distributions) 
a_n = 2
b_n = 3
mu_n = 5
lambda_n = 1

#Exact posterior parameter vals
aN = a_n + N/2
bN = b_n + 1/2*sumSq + (lambda_n*N*(xbar - mu_n)**2)/(2*(lambda_n + N))
muN = (lambda_n*mu_n + N*xbar)/(lambda_n + N)
lambdaN = lambda_n + N



muVals = np.linspace(-1,1,100)
lamVals = np.linspace(0.1,2,100)



def plot_pdf(pdf, X, Y, color):
    plt.contour(X, Y, pdf, colors=color)
    plt.xlabel("$\mu$")
    plt.ylabel("$\lambda$")
    
def gaussian_gamma(mu, lam, a, b, u, l):
    pdf_gamma = stats.gamma.pdf(l, a=a, scale=1/b)
    pdf_norm = stats.norm.pdf(u, loc=mu, scale=((l*lam)**(-0.5)))
    return pdf_gamma*pdf_norm

def gauss_gam_pdf(mu, lam, a, b, X, Y):
    return [[gaussian_gamma(mu, lam, a, b, u, l) for u in X]for l in Y]

def gaussian_gammaVI(mu, lam, a, b, u, l):
    pdf_gamma = stats.gamma.pdf(l, a=a, scale=1/b)
    pdf_norm = stats.norm.pdf(u, loc=mu, scale=((lam)**(-0.5)))
    return pdf_gamma*pdf_norm

def gauss_gam_pdfVI(mu, lam, a, b, X, Y):
    return [[gaussian_gammaVI(mu, lam, a, b, u, l) for u in X]for l in Y]

posterior_pdf = gauss_gam_pdf(muN, lambdaN, aN, bN, muVals, lamVals)

lambda_nt = lambda_n
mu_nt = mu_n
a_nt = a_n
b_nt = b_n

it = 0
maxIt = 10 
while (it < maxIt):
    oldLq = .5*np.log(1/lambda_nt) + np.log(gamma(a_nt)) - a_nt*np.log(b_nt)
    
    VI_pdf = gauss_gam_pdfVI(mu_nt, lambda_nt, a_nt, b_nt, muVals, lamVals)
    
    mu_nt = (lambda_n*mu_n + N*xbar) / (lambda_n + N)
    a_nt = a_n + (N + 1)/2
    b_nt = b_n + 0.5*((lambda_n + N)*((1/lambda_nt) + mu_nt**2) - 
            2*mu_nt*(lambda_n*mu_n + sumx) + sumxsq + lambda_n*mu_n**2)
    lambda_nt = (lambda_n + N)*(a_nt/b_nt)
    
    newLq = .5*np.log(1/lambda_nt) + np.log(gamma(a_nt)) - a_nt*np.log(b_nt)
    if (abs(newLq - oldLq) < .1):
        plt.subplot(1,3,it)
        plot_pdf(posterior_pdf, muVals, lamVals, "green")
        plot_pdf(VI_pdf, muVals, lamVals, "red")
        plt.tight_layout()
        break
        
    if (it > 0):
        plt.subplot(1,3,it)
        plot_pdf(posterior_pdf, muVals, lamVals, "green")
        plot_pdf(VI_pdf, muVals, lamVals, "blue")
        plt.tight_layout()

        #plt.show()
    it += 1

In [None]:
%matplotlib inline
np.random.seed(11)

#Parameters of normal distribution
u0 = 0
tau = 1

#Generate Data
N = 100
X = np.random.normal(u0, 1/tau, N)



xbar = np.mean(X)
sumSq = np.sum((X-xbar)**2)
sumx = np.sum(X)
sumxsq = np.sum(X**2)



#Initial guesses (small positive values to give broad prior distributions indicating ignorance about the prior distributions) 
a_n = 2
b_n = 1
mu_n = 5
lambda_n = 1

#Exact posterior parameter vals
aN = a_n + N/2
bN = b_n + 1/2*sumSq + (lambda_n*N*(xbar - mu_n)**2)/(2*(lambda_n + N))
muN = (lambda_n*mu_n + N*xbar)/(lambda_n + N)
lambdaN = lambda_n + N



muVals = np.linspace(-1,1,100)
lamVals = np.linspace(0.1,2,100)



def plot_pdf(pdf, X, Y, color):
    plt.contour(X, Y, pdf, colors=color)
    plt.xlabel("$\mu$")
    plt.ylabel("$\lambda$")
    
def gaussian_gamma(mu, lam, a, b, u, l):
    pdf_gamma = stats.gamma.pdf(l, a=a, scale=1/b)
    pdf_norm = stats.norm.pdf(u, loc=mu, scale=((l*lam)**(-0.5)))
    return pdf_gamma*pdf_norm

def gauss_gam_pdf(mu, lam, a, b, X, Y):
    return [[gaussian_gamma(mu, lam, a, b, u, l) for u in X]for l in Y]

def gaussian_gammaVI(mu, lam, a, b, u, l):
    pdf_gamma = stats.gamma.pdf(l, a=a, scale=1/b)
    pdf_norm = stats.norm.pdf(u, loc=mu, scale=((lam)**(-0.5)))
    return pdf_gamma*pdf_norm

def gauss_gam_pdfVI(mu, lam, a, b, X, Y):
    return [[gaussian_gammaVI(mu, lam, a, b, u, l) for u in X]for l in Y]

posterior_pdf = gauss_gam_pdf(muN, lambdaN, aN, bN, muVals, lamVals)

lambda_nt = lambda_n
mu_nt = mu_n
a_nt = a_n
b_nt = b_n

it = 0
maxIt = 10 
while (it < maxIt):
    oldLq = .5*np.log(1/lambda_nt) + np.log(gamma(a_nt)) - a_nt*np.log(b_nt)
    
    VI_pdf = gauss_gam_pdfVI(mu_nt, lambda_nt, a_nt, b_nt, muVals, lamVals)
    
    mu_nt = (lambda_n*mu_n + N*xbar) / (lambda_n + N)
    a_nt = a_n + (N + 1)/2
    b_nt = b_n + 0.5*((lambda_n + N)*((1/lambda_nt) + mu_nt**2) - 
            2*mu_nt*(lambda_n*mu_n + sumx) + sumxsq + lambda_n*mu_n**2)
    lambda_nt = (lambda_n + N)*(a_nt/b_nt)
    
    newLq = .5*np.log(1/lambda_nt) + np.log(gamma(a_nt)) - a_nt*np.log(b_nt)
    if (abs(newLq - oldLq) < .1):
        plt.subplot(1,3,it)
        plot_pdf(posterior_pdf, muVals, lamVals, "green")
        plot_pdf(VI_pdf, muVals, lamVals, "red")
        plt.tight_layout()
        break
        
    if (it > 0):
        plt.subplot(1,3,it)
        plot_pdf(posterior_pdf, muVals, lamVals, "green")
        plot_pdf(VI_pdf, muVals, lamVals, "blue")
        plt.tight_layout()

        #plt.show()
    it += 1