In [2]:
import numpy as np
import pickle
from datetime import datetime

In [29]:
class Bmm():
    def __init__(self, k, dim, max_iter, threshold=1e-3):
        self.k = k
        self.d = dim
        self.max_iter = max_iter
        self.threshold = threshold
        self.means = np.random.rand(self.k, self.d) * 0.5 + 0.25
        self.pis = np.array([1 / self.k for _ in range(self.k)])
        self.EPS = np.finfo(float).eps

    def fit(self, data):
        iter_count = 0
        log_likelihood = -np.inf
        while iter_count < self.max_iter:
            t0 = datetime.now()
            new_log_likelihood, gamma = self._e_step(data)
            #print('e step : {}s'.format(datetime.now()-t0))
            new_log_likelihood = new_log_likelihood.mean()
            self._m_step(data, gamma)
            print(iter_count+1, new_log_likelihood)
            change = abs(new_log_likelihood - log_likelihood)
            if change < self.threshold:
                print('breaking training loop')
                break
            log_likelihood = new_log_likelihood
            iter_count += 1

    def predict(self, data):
        _, gamma = self._e_step(data)
        return np.argmax(gamma, axis=1)

    def _e_step(self, X):
        # cal. laten code gamma(w/z) in log
        X_c = 1-X
        means_c = 1-self.means
        
        log_support = np.ndarray(shape=(X.shape[0], self.k)) 
        t0 = datetime.now()
        
        for i in range(self.k):
            #log_support[:,i] = (np.sum(X*np.log(self.means[i,:].clip(min=1e-50)),1)+np.sum(X_c*np.log(means_c[i,:].clip(min=1e-50)),1))
            log_support[:,i] = (np.sum(X*np.log(self.means[i,:].clip(min=self.EPS*10)),1)+np.sum(X_c*np.log(means_c[i,:].clip(min=self.EPS*10)),1))
        

        #print('log support : {}s'.format(datetime.now()-t0))
        lpr = log_support + np.log(self.pis)
        log_likelihood = np.logaddexp.reduce(lpr, axis=1)
        print(log_likelihood)
        gamma = np.exp(lpr - log_likelihood[:, np.newaxis])
        #print(gamma.shape)
        #print(gamma)
        return log_likelihood, gamma

    def _m_step(self, X, gamma):
        # cal. pi(lambda) and mean(p)
        weights = gamma.sum(axis=0)
        weighted_x_sum = np.dot(gamma.T, X)
        inverse_weights = 1.0 / (weights[:, np.newaxis] + 10 * self.EPS)

        self.pis = (weights / (weights.sum() + 10 * self.EPS) + self.EPS)
        self.means = weighted_x_sum * inverse_weights

    def save(self, path=None):
        path = path or "./model.pkl"
        with open(path, 'wb') as file:
            pickle.dump((self.k, self.d, self.max_iter, self.threshold, self.pis, self.means), file)

    @staticmethod
    def load(path=None):
        path = path or "./model.pkl"
        with open(path, 'rb') as file:
            k, d, max_iter, threshold, pis, means = pickle.load(file)
        model = Bmm(k, d, max_iter, threshold)
        model.pis = pis
        model.means = means
        return model

def get_data(path = None, train='True'):
    imgs_path = './data/train-images-idx3-ubyte'
    labels_path = './data/train-labels-idx1-ubyte'
    if not train: 
        imgs_path = './data/t10k-images.idx3-ubyte'
        labels_path = './data/t10k-labels.idx1-ubyte'
    if path:
        imgs_path, labels_path = path
        
    with open(imgs_path,'rb') as imgs_file, open(labels_path, 'rb') as labels_file:
        _ = int.from_bytes(imgs_file.read(4), byteorder='big', signed=False)
        no_img = int.from_bytes(imgs_file.read(4), byteorder='big', signed=False)
        no_rows = int.from_bytes(imgs_file.read(4), byteorder='big', signed=False)
        no_cols = int.from_bytes(imgs_file.read(4), byteorder='big', signed=False)
        labels_file.read(8)
        X = list()
        Y = list()
        for _ in range(no_img):
            X.append(list(imgs_file.read(no_rows*no_cols)))
            Y.append(int.from_bytes(labels_file.read(1), byteorder='big', signed=False))
        return np.asarray(X)  , np.asarray(Y)
            

In [32]:


def main():
    train=True
    x, y = get_data(train=True)
    print('start')
    x = np.where(x > 127, 1, 0)
    
    #tx, ty = get_data(train=False)
    if train :
        model = Bmm(10,784, max_iter=100)
        model.fit(x)
        model.save('./model.pkl')
    else:
        model = Bmm.load('./model.pkl')
    result = list()
    all_digit = []
    for i in range(10):
        result.append(model.predict(x[y==i]))
        #print('pred',result)
        unique, counts = np.unique(result[-1], return_counts=True)
        #print('pair',unique,counts)
        sorted_by_value = sorted(dict(zip(unique, counts)).items(), key=lambda kv: kv[1], reverse=True)
        #print('sorted',sorted_by_value)
        #print('digit {} {}'.format(i, sorted_by_value))
        all_digit.append(sorted_by_value)
    print('all digit',all_digit)
    temp = {}
    for i, digit in enumerate(all_digit):
        print('all_digit',i,digit)
        check_cluster(temp, i, all_digit)
        print('temp',temp)
    digit_to_cluster = {}
    for cluster, (_, digit) in temp.items():
        
        digit_to_cluster[digit] = cluster
    print('digit to clus',digit_to_cluster)

    matrixs = confusion_matrix(y, digit_to_cluster, all_digit)
    for i, cm in enumerate(matrixs):
        print('digit : {}'.format(i))
        sensitivity = cm[0][0] / (cm[0][0] + cm[1][0])
        specificity = cm[1][1] / (cm[0][1] + cm[1][1])
        print(cm)
        print('sensitivity : {:0.3f}'.format(sensitivity))
        print('specificity : {:0.3f}'.format(specificity))
        print()
        
def confusion_matrix(label, digit_to_cluster, all_digit):
    matrixs = []
    for digit in range(10):
        cluster = digit_to_cluster[digit]
        TP = all_digit[digit][cluster]
        TP = sum([ size for c, size in all_digit[digit] if c == cluster])
        FN = sum([size for _, size in all_digit[digit]]) - TP
        FP = sum([sum([ size for c, size in d if c == cluster]) for d in all_digit ]) - TP
        TN = len(label) - TP - FN - FP
        matrixs.append(np.array([[TP,FP],[FN,TN]]))
    return matrixs
    

def check_cluster(result, n, all_digit):
    #cluster no : (size, n)
    for cluster, size in all_digit[n]:
        if cluster in result and result[cluster][0] < size:
            old_digit = result[cluster][1]
            result[cluster] = (size, n)
            check_cluster(result, old_digit, all_digit)
            break
        elif cluster not in result:
            result[cluster] = (size, n)
            break
            

if __name__ == '__main__':
    
    main()


start
[-565.81781969 -570.43373094 -575.39228518 ... -568.76348574 -575.60989986
 -573.71027987]
1 -574.032949089063
[-203.51062642 -159.93432265 -244.86422349 ... -180.33195641 -175.77819337
 -186.05481971]
2 -188.18989644339558
[-202.1693992  -143.02775468 -229.90178375 ... -169.58892119 -167.70323509
 -173.11899682]
3 -177.11403492055481
[-206.15310715 -138.21976099 -216.46106815 ... -171.52443664 -166.12538144
 -165.59007922]
4 -171.96522819788728
[-203.00462957 -139.17300239 -208.35183614 ... -174.68127289 -165.13037229
 -167.40594524]
5 -168.35031548624806
[-201.45969525 -140.01655059 -204.29066571 ... -180.32060985 -166.06064902
 -169.50889733]
6 -166.4764743781604
[-201.26613874 -140.86686425 -202.54960251 ... -183.39656966 -167.73117536
 -170.65510388]
7 -165.7861056475601
[-201.23026333 -141.70715607 -201.81434485 ... -183.6155875  -169.30415782
 -171.25030832]
8 -165.47974729915762
[-201.2211267  -142.30728001 -201.69100831 ... -183.90693017 -170.31372868
 -171.64242537]
9 -

IndexError: list index out of range