In [1]:
import numpy as np
import pickle
from datetime import datetime

In [2]:
class Bmm():
    def __init__(self, k, dim, max_iter, threshold=1e-3):
        self.k = k
        self.d = dim
        self.max_iter = max_iter
        self.threshold = threshold
        self.means = np.random.rand(self.k, self.d) * 0.5 + 0.25
        self.pis = np.array([1 / self.k for _ in range(self.k)])
        self.EPS = np.finfo(float).eps

    def fit(self, data):
        iter_count = 0
        log_likelihood = -np.inf
        while iter_count < self.max_iter:
            t0 = datetime.now()
            new_log_likelihood, gamma = self._e_step(data)
            #print('e step : {}s'.format(datetime.now()-t0))
            new_log_likelihood = new_log_likelihood.mean()
            self._m_step(data, gamma)
            print(iter_count+1, new_log_likelihood)
            change = abs(new_log_likelihood - log_likelihood)
            if change < self.threshold:
                print('breaking training loop')
                break
            log_likelihood = new_log_likelihood
            iter_count += 1

    def predict(self, data):
        _, gamma = self._e_step(data)
        return np.argmax(gamma, axis=1)

    def _e_step(self, X):
        # cal. laten code gamma(w/z) in log
        X_c = 1-X
        means_c = 1-self.means
        
        log_support = np.ndarray(shape=(X.shape[0], self.k)) 
        t0 = datetime.now()
        
        for i in range(self.k):
            #log_support[:,i] = (np.sum(X*np.log(self.means[i,:].clip(min=1e-50)),1)+np.sum(X_c*np.log(means_c[i,:].clip(min=1e-50)),1))
            log_support[:,i] = (np.sum(X*np.log(self.means[i,:].clip(min=self.EPS*10)),1)+np.sum(X_c*np.log(means_c[i,:].clip(min=self.EPS*10)),1))
        

        #print('log support : {}s'.format(datetime.now()-t0))
        lpr = log_support + np.log(self.pis)
        log_likelihood = np.logaddexp.reduce(lpr, axis=1)
        print(log_likelihood)
        gamma = np.exp(lpr - log_likelihood[:, np.newaxis])
        #print(gamma.shape)
        #print(gamma)
        return log_likelihood, gamma

    def _m_step(self, X, gamma):
        # cal. pi(lambda) and mean(p)
        weights = gamma.sum(axis=0)
        weighted_x_sum = np.dot(gamma.T, X)
        inverse_weights = 1.0 / (weights[:, np.newaxis] + 10 * self.EPS)

        self.pis = (weights / (weights.sum() + 10 * self.EPS) + self.EPS)
        self.means = weighted_x_sum * inverse_weights

    def save(self, path=None):
        path = path or "./model.pkl"
        with open(path, 'wb') as file:
            pickle.dump((self.k, self.d, self.max_iter, self.threshold, self.pis, self.means), file)

    @staticmethod
    def load(path=None):
        path = path or "./model.pkl"
        with open(path, 'rb') as file:
            k, d, max_iter, threshold, pis, means = pickle.load(file)
        model = Bmm(k, d, max_iter, threshold)
        model.pis = pis
        model.means = means
        return model

def get_data(path = None, train='True'):
    imgs_path = './data/train-images.idx3-ubyte'
    labels_path = './data/t10k-labels.idx1-ubyte'
    if not train: 
       c
    if path:
        imgs_path, labels_path = path
        
    with open(imgs_path,'rb') as imgs_file, open(labels_path, 'rb') as labels_file:
        _ = int.from_bytes(imgs_file.read(4), byteorder='big', signed=False)
        no_img = int.from_bytes(imgs_file.read(4), byteorder='big', signed=False)
        no_rows = int.from_bytes(imgs_file.read(4), byteorder='big', signed=False)
        no_cols = int.from_bytes(imgs_file.read(4), byteorder='big', signed=False)
        labels_file.read(8)
        X = list()
        Y = list()
        for _ in range(no_img):
            X.append(list(imgs_file.read(no_rows*no_cols)))
            Y.append(int.from_bytes(labels_file.read(1), byteorder='big', signed=False))
        return np.asarray(X)  , np.asarray(Y)
            

In [3]:


def main():
    train=True
    x, y = get_data(train=True)
    print('start')
    x = np.where(x > 127, 1, 0)
    
    #tx, ty = get_data(train=False)
    if train :
        model = Bmm(10,784, max_iter=100)
        model.fit(x)
        model.save('./model.pkl')
    else:
        model = Bmm.load('./model.pkl')
    result = list()
    all_digit = []
    for i in range(10):
        result.append(model.predict(x[y==i]))
        #print('pred',result)
        unique, counts = np.unique(result[-1], return_counts=True)
        #print('pair',unique,counts)
        sorted_by_value = sorted(dict(zip(unique, counts)).items(), key=lambda kv: kv[1], reverse=True)
        #print('sorted',sorted_by_value)
        #print('digit {} {}'.format(i, sorted_by_value))
        all_digit.append(sorted_by_value)
    
    temp = {}
    for i, digit in enumerate(all_digit):
        check_cluster(temp, i, all_digit)
    digit_to_cluster = {}
    for cluster, (_, digit) in temp.items():
        
        digit_to_cluster[digit] = cluster
    #print(temp)

    matrixs = confusion_matrix(y, digit_to_cluster, all_digit)
    for i, cm in enumerate(matrixs):
        print('digit : {}'.format(i))
        sensitivity = cm[0][0] / (cm[0][0] + cm[1][0])
        specificity = cm[1][1] / (cm[0][1] + cm[1][1])
        print(cm)
        print('sensitivity : {:0.3f}'.format(sensitivity))
        print('specificity : {:0.3f}'.format(specificity))
        print()
        
def confusion_matrix(label, digit_to_cluster, all_digit):
    matrixs = []
    for digit in range(10):
        cluster = digit_to_cluster[digit]
        TP = all_digit[digit][cluster]
        TP = sum([ size for c, size in all_digit[digit] if c == cluster])
        FN = sum([size for _, size in all_digit[digit]]) - TP
        FP = sum([sum([ size for c, size in d if c == cluster]) for d in all_digit ]) - TP
        TN = len(label) - TP - FN - FP
        matrixs.append(np.array([[TP,FP],[FN,TN]]))
    return matrixs
    

def check_cluster(result, n, all_digit):
    #cluster no : (size, n)
    for cluster, size in all_digit[n]:
        if cluster in result and result[cluster][0] < size:
            old_digit = result[cluster][1]
            result[cluster] = (size, n)
            check_cluster(result, old_digit, all_digit)
            break
        elif cluster not in result:
            result[cluster] = (size, n)
            break
            

if __name__ == '__main__':
    
    main()


start
[-572.39858397 -570.9025014  -579.27443297 ... -574.75663086 -569.89334267
 -573.48997844]
1 -574.3659140100036
[-212.1790413  -216.00560462 -236.39535322 ... -197.10926693 -198.32738448
 -170.948256  ]
2 -185.90745617925262
[-199.92836458 -200.10972893 -232.9133649  ... -197.21319545 -194.28797415
 -169.73647839]
3 -174.4931038184912
[-194.56727792 -193.28596234 -242.2530546  ... -199.94579458 -184.78793212
 -172.24203226]
4 -170.25929687075043
[-192.60664789 -188.47145229 -240.0039786  ... -199.4103937  -177.45563545
 -173.30335379]
5 -168.53368155176366
[-191.73474593 -182.78736717 -237.77387788 ... -197.18136251 -175.18112419
 -173.98977084]
6 -167.69747591100676
[-191.11634318 -176.40649118 -235.73730222 ... -194.34485753 -174.0938631
 -173.73682019]
7 -167.16551079410863
[-190.838067   -169.67567122 -234.06307708 ... -191.85611121 -173.48625304
 -172.60500845]
8 -166.71149599075883
[-190.17152617 -163.63990185 -232.9960751  ... -188.93667814 -173.12897177
 -170.96523795]
9 

[ -82.94558886 -122.09582023 -114.39801773 -129.18656033 -179.60573486
 -154.72216733 -107.96844404  -77.32858618 -269.09484765 -170.19916804
 -145.55590024 -131.4196894   -64.70417785 -119.52361242 -127.68494087
 -108.79117803 -118.47726157 -130.73570739 -148.53202415 -218.54208498
 -214.40098208  -67.3034041  -156.60336739 -177.27484109 -135.91711758
 -234.64195299 -130.38767418 -173.59383158 -158.04141051 -217.19117958
 -111.93050448 -181.88704388 -150.00269868 -214.29085236 -115.23120019
 -181.4357078  -157.23415056 -139.16793843 -145.09471911 -209.44130881
 -128.14508103 -110.5549557  -133.3496669  -133.13078488 -202.22155543
 -143.28967065 -194.11912877 -164.84538398 -280.38152037 -157.14632836
 -154.97249381 -195.73903726 -163.98965554 -181.56823155 -239.96594976
 -118.76465885  -71.37069451  -85.53427013  -81.64885631 -188.87160574
 -139.30576611 -130.47356321 -162.5173443  -172.71847561  -80.80889434
 -141.9728848  -127.12303191 -205.95993361 -196.93911566 -152.86438356
 -148.

[-188.76615742 -153.00384343 -115.34054793 ... -131.68468112 -185.92374929
 -159.27259266]
[-145.67059955 -153.21326846 -159.41117141  -75.62773544 -215.55786499
 -191.9879445   -98.4221327  -211.00165902 -222.32147976  -83.68243991
 -116.48658078 -155.75031841 -176.00312022 -214.737762   -160.11219206
 -223.22408615 -187.91718501 -185.60552412 -207.42060174  -64.42052087
 -174.22327811 -140.17000397 -120.14799783 -112.54885641 -149.836315
 -132.79515956 -235.56529155 -118.75034582 -157.41581417 -161.42157162
 -120.05315961 -148.54302061 -120.44695274 -175.33578357 -155.22562455
  -73.5526228  -140.83655768 -171.2038846  -177.66044566 -222.361002
 -122.15148751 -155.85202559  -71.65667802 -196.76614867 -140.16224613
 -105.42803981 -230.90140865  -77.27483314 -115.03572945 -208.51614387
 -272.26849662 -233.5768822  -148.05252567 -202.28704826 -243.78742637
 -121.57714279 -146.40504176 -148.68282993 -176.9029919  -199.54295714
 -157.08373694 -120.80579179 -197.98244595 -166.21721969 -167

[-160.46968535 -154.12086968 -176.21273684 ... -151.44352487 -110.13371549
 -182.95933978]
digit : 0
[[ 6552  1078]
 [44428  7942]]
sensitivity : 0.129
specificity : 0.880

digit : 1
[[  147  6935]
 [  988 51930]]
sensitivity : 0.130
specificity : 0.882

digit : 2
[[  128  6978]
 [  904 51990]]
sensitivity : 0.124
specificity : 0.882

digit : 3
[[  108  5513]
 [  902 53477]]
sensitivity : 0.107
specificity : 0.907

digit : 4
[[   87  4851]
 [  895 54167]]
sensitivity : 0.089
specificity : 0.918

digit : 5
[[   79  4562]
 [  813 54546]]
sensitivity : 0.089
specificity : 0.923

digit : 6
[[   85  4578]
 [  873 54464]]
sensitivity : 0.089
specificity : 0.922

digit : 7
[[  137  6904]
 [  891 52068]]
sensitivity : 0.133
specificity : 0.883

digit : 8
[[   76  4569]
 [  898 54457]]
sensitivity : 0.078
specificity : 0.923

digit : 9
[[  114  6519]
 [  895 52472]]
sensitivity : 0.113
specificity : 0.889

