In [1]:
# Convergence tolerance for the EM loop: iteration continues while the
# relative change |(old_ll - new_ll) / new_ll| of the value returned by
# get_ll is strictly greater than this.
THRESHOLD = 0.0001

In [2]:
import sys
import numpy as np
from scipy.stats import multivariate_normal as mn
from scipy.linalg import norm
import pprint
import time
import copy
import random

import matplotlib.pyplot as plt
import seaborn as sea

In [3]:
def kpp_init(docs, mu):
    """Seed the rows of ``mu`` with documents, k-means++ style.

    The first center is a uniformly random document.  Every further center
    is a document drawn with probability proportional to its distance to the
    nearest center chosen so far, so documents that already are centers have
    probability zero (as long as some document is still at non-zero distance).

    Parameters
    ----------
    docs : 2-D numpy array, one document per row.
    mu   : 2-D numpy array with one row per center; overwritten in place.

    Returns
    -------
    None.  ``mu`` is modified in place.

    Raises
    ------
    Whatever ``np.random.choice`` / indexing raise for empty ``docs`` or
    ``mu`` (not handled here).
    """
    num_docs = len(docs)
    K = len(mu)
    index = np.random.choice(num_docs)
    mu[0, :] = docs[index, :]

    for k in range(1, K):
        # Distance from each document to its nearest already-chosen center.
        # All k previous centers (rows 0..k-1 of mu) must be considered for
        # every document; the range must not depend on the document index.
        min_dists = np.array([
            min(norm(docs[i, :] - mu[j, :]) for j in range(k))
            for i in range(num_docs)
        ])
        total = np.sum(min_dists)
        if total > 0:
            # NOTE: weighting is by distance (not squared distance), as in
            # the original implementation.
            index = np.random.choice(num_docs, p=min_dists / total)
        else:
            # Every document coincides with a chosen center, so the weights
            # would be 0/0; fall back to a uniform draw.
            index = np.random.choice(num_docs)
        mu[k, :] = docs[index, :]

def e_step(docs, prior, mu, sigma):
    """Compute the responsibility of every component for every document.

    For document i and component c the unnormalised weight is
    ``prior[c] * N(docs[i]; mu[c], sigma[c])``; each row is then divided by
    its own sum.

    Returns a new array of shape ``[len(docs), len(mu)]``.
    """
    n_points = len(docs)
    n_components = len(mu)
    resp = np.zeros(shape=[n_points, n_components])

    for row in range(n_points):
        point = docs[row, :]
        # unnormalised weights of this document under each component
        weights = np.array([
            prior[comp] * mn.pdf(x=point,
                                 mean=mu[comp, :],
                                 cov=sigma[comp, :, :])
            for comp in range(n_components)
        ])
        # scale the row so that it sums to one
        resp[row, :] = weights / np.sum(weights)

    return resp

def m_step(resp, docs):
    """Re-estimate mixture parameters from the responsibilities.

    For each component (column of ``resp``):
      * prior = column sum / number of documents
      * mu    = responsibility-weighted mean of the documents
      * sigma = responsibility-weighted average of the outer products of the
                documents centred on the new mu

    Returns the list ``[prior, mu, sigma]`` with shapes ``[K]``,
    ``[K, dim]`` and ``[K, dim, dim]``.
    """
    n_points = len(docs)
    n_components = len(resp[0, :])
    n_features = len(docs[0, :])

    prior = np.zeros(n_components)
    mu = np.zeros(shape=[n_components, n_features])
    sigma = np.zeros(shape=[n_components, n_features, n_features])

    for comp in range(n_components):
        weights = resp[:, comp]
        mass = np.sum(weights)

        # mixing proportion of this component
        prior[comp] = mass / n_points

        # weighted mean, accumulated one document at a time
        weighted_total = np.zeros(n_features)
        for w, point in zip(weights, docs):
            weighted_total += w * point
        mu[comp, :] = weighted_total / mass

        # weighted covariance around the freshly computed mean
        scatter = np.zeros(shape=[n_features, n_features])
        for w, point in zip(weights, docs):
            centered = point - mu[comp, :]
            scatter += w * np.outer(centered, centered)
        sigma[comp, :, :] = scatter / mass

    return [prior, mu, sigma]
        
def get_ll(resp, docs, prior, mu, sigma):
    """Responsibility-weighted log-likelihood of the documents.

    Returns the sum, over every document i and component c, of
    ``resp[i, c] * (log(prior[c]) + log N(docs[i]; mu[c], sigma[c]))``.
    """
    n_points = len(docs)
    n_components = len(mu)

    total = 0.0
    for row in range(n_points):
        point = docs[row, :]
        for comp in range(n_components):
            log_density = mn.logpdf(x=point,
                                    mean=mu[comp, :],
                                    cov=sigma[comp, :, :])
            total += resp[row, comp] * (np.log(prior[comp]) + log_density)

    return total
            

In [4]:
# ---------------------------------------------------------------------------
# Load the data set, initialise the mixture parameters and run EM until the
# relative change of the get_ll value drops to THRESHOLD or below.
# NOTE(review): no random seed is set before kpp_init, so runs are not
# reproducible; consider seeding numpy's RNG.
# ---------------------------------------------------------------------------
labels = []      # integer label parsed from the first column of each data row
num_docs = 0     # set to the number of data rows once the file is read
dim = 0          # number of feature columns, derived from the first line
K=3              # number of mixture components
docs_dict = []   # one dict per data row: column index -> float value

# NOTE(review): absolute, machine-specific path; the cell fails on any other
# machine. Consider a configurable data directory.
with open("/Users/waltercai/Documents/cse547/hw2/2DGaussianMixture.csv") as f:
    first_line = True
    for line in f:
        if first_line:
            # The first line is not parsed as data; its comma count gives the
            # number of columns that follow the first (label) column.
            dim = line.count(",")
            first_line = False
        else:
            line_split = line.split(",")
            labels.append(int(line_split[0]))
            row = {}
            for i in range(dim):
                # features start at column 1 (column 0 is the label)
                row[i] = float(line_split[i+1])
            docs_dict.append(row)
    num_docs = len(docs_dict)

# Copy the per-row dicts into a dense [num_docs, dim] array.
docs = np.zeros(shape=[num_docs, dim])
for i in range(num_docs):
    for k in docs_dict[i].keys():
        docs[i,k] = docs_dict[i][k]

# Initial parameters: uniform priors, identity covariances; the means are
# filled in by kpp_init below.
prior = np.zeros(K) + 1.0/K
resp = np.zeros(shape=[num_docs, K]);

mu = np.zeros(shape=[K, dim])
sigma = np.zeros(shape=[K, dim, dim])
for k in range(K):
    sigma[k,:,:] = np.identity(dim)


# Start above the threshold so the loop body runs at least once.
diff = THRESHOLD + 1.0
old_ll = 0.0
kpp_init(docs, mu)   # overwrites the rows of mu in place

iter_count = 0
lls = []             # get_ll value after every iteration
# NOTE(review): there is no cap on the number of iterations.
while diff > THRESHOLD:
# for i in range(500):
    iter_count+=1
    
    resp = e_step(docs=docs, prior=prior, mu=mu, sigma=sigma)
    [prior, mu, sigma] = m_step(resp=resp, docs=docs)
    
    # Evaluated with the responsibilities from before the M-step and the
    # parameters from after it.
    new_ll = get_ll(resp=resp, docs=docs, prior=prior, mu=mu, sigma=sigma)
    lls.append(new_ll)
    # Relative change; with old_ll starting at 0.0 the first value is 1.0
    # for any finite non-zero new_ll.
    diff = np.abs((old_ll - new_ll)/new_ll)
    old_ll = new_ll


In [5]:
# Report the fitted means and covariances, the hard assignment of every
# document (index lists per component, by largest responsibility) and the
# get_ll trace.
# print(x) with a single argument prints the same text as `print x` on
# Python 2 and also runs on Python 3, where the bare statement is a
# SyntaxError.
print(mu)
print(sigma)

# component index -> indices of the documents assigned to it
guess = {k: [] for k in range(K)}
for i in range(num_docs):
    k = np.argmax(resp[i,:])
    guess[k].append(i)
for k in range(K):
    print(guess[k])
print(lls)

[[ 0.92612491  0.22870629]
 [ 0.75965167  0.77212741]
 [ 0.4215486   0.7414443 ]]
[[[ 0.01774189  0.00567881]
  [ 0.00567881  0.00563769]]

 [[ 0.00130078 -0.00135753]
  [-0.00135753  0.00520517]]

 [[ 0.0036664   0.00025332]
  [ 0.00025332  0.00620674]]]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130]
[131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 17