import scipy as sp
import scipy.stats
class GMM(object):
    """Gaussian mixture model fitted with expectation-maximisation (EM).

    Attributes:
        m: number of samples.
        n: dimensionality of one sample.
        data: float ndarray of shape (m, n); a private copy of the input.
        k: number of mixture components.
        mean_arr: component means, ndarray of shape (k, n).
        sigma_arr: component covariances, ndarray of shape (k, n, n).
        phi: mixing weights, ndarray of shape (k,).
        w: responsibilities, ndarray of shape (m, k); ``w[i, j]`` is the
            posterior weight of component ``j`` for sample ``i``.

    ``mean_arr``, ``sigma_arr``, ``phi`` and ``w`` only exist after
    ``_init()`` (which ``fit()`` calls) has run.
    """

    def __init__(self, X, k=2):
        """Store a copy of the samples and the number of components.

        Args:
            X: array-like of samples. A 1-D sequence is treated as ``m``
                scalar samples; a 2-D array as ``m`` rows of ``n`` features.
            k: number of mixture components.

        Raises:
            ValueError: if ``X`` is neither 1-D nor 2-D.
        """
        # np.array (unlike np.asarray) always copies, so the caller's data
        # is never aliased.
        X = np.array(X, dtype=float)
        if X.ndim == 1:
            # Every method indexes rows as data[i, :], so scalars become a
            # single-column matrix.
            X = X.reshape(-1, 1)
        elif X.ndim != 2:
            raise ValueError(
                'X must be 1-D or 2-D, got %d dimensions' % X.ndim)
        self.m, self.n = X.shape
        self.data = X
        # number of mixtures
        self.k = k

    def _init(self):
        """(Re)initialise the parameters: random means in [0, 1), identity
        covariances, uniform mixing weights and an unset responsibility
        table."""
        self.mean_arr = np.random.random((self.k, self.n))
        self.sigma_arr = np.array([np.identity(self.n) for _ in range(self.k)])
        self.phi = np.ones(self.k) / self.k
        self.w = np.empty((self.m, self.k), dtype=float)

    def _weighted_pdf(self):
        """Return the (m, k) table of ``phi[j] * N(x_i | mean_j, sigma_j)``."""
        dens = np.empty((self.m, self.k), dtype=float)
        for j in range(self.k):
            # pdf() evaluates all m rows at once (and returns a scalar when
            # m == 1, which broadcasts into the column).
            dens[:, j] = self.phi[j] * sp.stats.multivariate_normal.pdf(
                self.data, mean=self.mean_arr[j], cov=self.sigma_arr[j])
        return dens

    def fit(self, tol=1e-4):
        """Run EM from a fresh random initialisation until convergence.

        At least one EM iteration is always performed; iteration stops as
        soon as the log-likelihood gain of the last step is not greater
        than ``tol`` (this also stops on a NaN log-likelihood). Progress is
        printed after every iteration.

        Args:
            tol: minimum log-likelihood improvement required to continue.
        """
        self._init()
        num_iters = 0
        ll = self.loglikelihood()
        while True:
            # The previous iteration's value is reused instead of being
            # recomputed from unchanged parameters.
            previous_ll = ll
            self._fit()
            num_iters += 1
            ll = self.loglikelihood()
            print('Iteration %d: log-likelihood is %.6f'%(num_iters, ll))
            if not (ll - previous_ll > tol):
                break
        print('Terminate at %d-th iteration:log-likelihood is %.6f'%(num_iters, ll))

    def loglikelihood(self):
        """Return the log-likelihood of the data under the current
        parameters: the sum over samples of the log mixture density."""
        return float(np.log(self._weighted_pdf().sum(axis=1)).sum())

    def _fit(self):
        """Perform one EM iteration (E-step followed by M-step)."""
        self.e_step()
        self.m_step()

    def e_step(self):
        """Recompute the responsibilities ``w`` from the current parameters.

        Raises:
            AssertionError: if a row of ``w`` does not sum to 1, e.g. when
                every component density underflows to zero for a sample.
        """
        dens = self._weighted_pdf()
        self.w = dens / dens.sum(axis=1, keepdims=True)
        # allclose is False for NaN rows, so degenerate samples are caught.
        assert np.allclose(self.w.sum(axis=1), 1.0, atol=1e-4), \
            'responsibilities do not sum to 1'

    def m_step(self):
        """Re-estimate ``phi``, ``mean_arr`` and ``sigma_arr`` from the
        current responsibilities ``w``."""
        totals = self.w.sum(axis=0)  # effective sample count per component
        self.phi = totals / self.m
        self.mean_arr = self.w.T.dot(self.data) / totals[:, np.newaxis]
        for j in range(self.k):
            # The covariance is taken around the freshly updated mean, as
            # the EM update for a Gaussian mixture prescribes.
            centered = self.data - self.mean_arr[j]
            weighted = centered * self.w[:, j][:, np.newaxis]
            self.sigma_arr[j] = weighted.T.dot(centered) / totals[j]
        
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Parameters of the Gaussian noise added to every simulated observation.
bias = 0
std = 1

# NOTE(review): m and n are not used anywhere else in the visible script;
# kept in case later code relies on them.
m,n = np.array([1,2])
print(m)
print(n)


# Euclidean distance from the point x to each of the four rows of H.
x = np.array([8,8])
H = np.array([[0,0],[10,0],[0,10],[10,10]])
d = np.sqrt(((x - H) ** 2).sum(axis=1))

# Noise-free distance differences relative to the first row of H.
d_col_1 = d[1]-d[0]
d_col_2 = d[2]-d[0]
d_col_3 = d[3]-d[0]
real_arr = np.array([d_col_1,d_col_2,d_col_3])

SAMPLE = 10
# Draw SAMPLE noisy observations around each true difference. The draws are
# grouped by component (first SAMPLE values belong to real_arr[0], and so
# on) so that the slices plotted below each show exactly one component.
X = np.concatenate([
    np.random.normal(loc=mu + bias, scale=std, size=SAMPLE)
    for mu in real_arr
])
print(X.shape)

# NOTE: distplot is deprecated in recent seaborn releases (histplot/displot
# replace it); it is kept so the script still runs on older versions.
fig, ax = plt.subplots(figsize=(15, 4))
sns.distplot(X[:SAMPLE], ax=ax, rug=True)
sns.distplot(X[SAMPLE:SAMPLE*2], ax=ax, rug=True)
sns.distplot(X[SAMPLE*2:], ax=ax, rug=True)

# GMM indexes its data as rows of features, so pass the scalar samples as a
# single-column matrix.
gmm = GMM(X.reshape(-1, 1), 3)
gmm.fit()
