In [14]:
import matplotlib.pyplot as plt
import numpy as np
from filter import filter_encode, filter_decode
from helper import runlength_encode
import keras
from collections import defaultdict


from sklearn.cluster import KMeans, MiniBatchKMeans

import multiprocessing as multi
from multiprocessing import Pool

import pickle

In [2]:
fashion_mnist = keras.datasets.fashion_mnist
(x_train, y_train), (test_images, test_labels) = fashion_mnist.load_data()

In [3]:
def freq_vector(arr):
    d = {i: 1 for i in range(256)}
    for i in arr:
        d[i] += 1
    return d

def to_freq(img):
    fil = filter_encode(img, fid=1)
    rle = runlength_encode(fil.ravel())
    return list(freq_vector(rle).values())

In [4]:
X = [x_train[i, :, :] for i in range(60000)]

n_cores = multi.cpu_count()
with Pool(n_cores) as p:
    freq = p.map(to_freq, X)

print(np.array(freq))

new_X = np.array(freq)

[[ 68  23  30 ...  16  19  13]
 [102  37  40 ...  23  35  31]
 [105  22  26 ...  10  18  21]
 ...
 [ 86  20  14 ...  13  10  18]
 [ 98  37  43 ...  20  19  16]
 [ 44  17   6 ...   8   3  15]]


In [5]:
new_X.shape

(60000, 256)

In [6]:
km = MiniBatchKMeans(n_clusters=256, max_iter=1000)

km.fit(new_X)
km.cluster_centers_

array([[127.        ,  52.        ,  56.        , ...,  23.        ,
         29.        ,  53.        ],
       [ 93.        ,  32.        ,  42.        , ...,  21.        ,
         20.        ,  31.        ],
       [ 47.23243243,   1.17297297,  11.41081081, ...,  16.09189189,
          2.05405405,   1.51891892],
       ...,
       [103.82330827,  18.0075188 ,  18.44360902, ...,   8.39849624,
          8.72556391,  16.7443609 ],
       [ 90.97530864,  43.11111111,  57.47839506, ...,  18.25925926,
         25.57098765,  35.97839506],
       [ 50.59447005,   2.55299539,  33.21198157, ...,   2.64976959,
         20.0921659 ,   2.57142857]])

In [8]:
type(km.labels_)

numpy.ndarray

In [16]:
from huffman import HuffmanCoding

dicts = []
trees = []

for i in range(256):
    huff = HuffmanCoding()
    label = new_X[km.labels_ == i]

    enc = huff.encode(label.ravel())

    dicts.append(huff.code)
    trees.append(huff.tree)



In [18]:
with open("dicts.pkl", "wb") as f:
    pickle.dump(dicts, f)

with open("trees.pkl", "wb") as f:
    pickle.dump(trees, f)

In [19]:
with open("trees.pkl", "rb") as f:
    trees_loaded = pickle.load(f)

print(trees_loaded)

[<huffman.Node object at 0x14bcce790>, <huffman.Node object at 0x14f0342d0>, <huffman.Node object at 0x14fcdf7d0>, <huffman.Node object at 0x14fd8cad0>, <huffman.Node object at 0x14bcca750>, <huffman.Node object at 0x14ee15910>, <huffman.Node object at 0x1504fad90>, <huffman.Node object at 0x1504bd8d0>, <huffman.Node object at 0x14fcafdd0>, <huffman.Node object at 0x15044d590>, <huffman.Node object at 0x14bcef850>, <huffman.Node object at 0x14f7fa610>, <huffman.Node object at 0x14f548890>, <huffman.Node object at 0x14f5dd710>, <huffman.Node object at 0x14fbd90d0>, <huffman.Node object at 0x14f612710>, <huffman.Node object at 0x14f80ce90>, <huffman.Node object at 0x14f047650>, <huffman.Node object at 0x14bcd6690>, <huffman.Node object at 0x14fab2710>, <huffman.Node object at 0x14f5e3590>, <huffman.Node object at 0x14f938ed0>, <huffman.Node object at 0x14bcea990>, <huffman.Node object at 0x14bce45d0>, <huffman.Node object at 0x14f104ed0>, <huffman.Node object at 0x14f790dd0>, <huffman.No

In [21]:
with open("cluster_centers.pkl", "wb") as f:
    pickle.dump(km.cluster_centers_, f)

km.cluster_centers_

array([[127.        ,  52.        ,  56.        , ...,  23.        ,
         29.        ,  53.        ],
       [ 93.        ,  32.        ,  42.        , ...,  21.        ,
         20.        ,  31.        ],
       [ 47.23243243,   1.17297297,  11.41081081, ...,  16.09189189,
          2.05405405,   1.51891892],
       ...,
       [103.82330827,  18.0075188 ,  18.44360902, ...,   8.39849624,
          8.72556391,  16.7443609 ],
       [ 90.97530864,  43.11111111,  57.47839506, ...,  18.25925926,
         25.57098765,  35.97839506],
       [ 50.59447005,   2.55299539,  33.21198157, ...,   2.64976959,
         20.0921659 ,   2.57142857]])