In [None]:
from sklearn.datasets import fetch_openml
import numpy as np
import pickle
import os

data_file = 'data/mnist_digits_784.pkl'

# Use the local cache when it exists; otherwise download once and cache it.
if os.path.isfile(data_file):
    # NOTE(security): pickle.load can execute arbitrary code. Only load a
    # cache file that was written by this notebook, never an untrusted one.
    with open(data_file, 'rb') as f:
        data = pickle.load(f)
else:
    # Fetch the data from OpenML (requires network access)
    data = fetch_openml('mnist_784', version=1, parser='auto')
    # open(..., 'wb') does not create missing parent directories, so make
    # sure the cache directory exists before writing the file.
    cache_dir = os.path.dirname(data_file)
    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)
    # Save the data to the cache file
    with open(data_file, 'wb') as f:
        pickle.dump(data, f)

# Print the keys of the returned dataset object
print(data.keys())
# Print the shape of the data and of the labels
print(data.data.shape)
print(data.target.shape)

# Description of the data
print(data.DESCR)

# Print the unique labels
print(np.unique(data.target))
# Print the first few data points
print(data.data[:5])

# Extract data and labels.
# The data matrix is transposed, so the rows of data.data become the
# columns of X; the labels are converted to integers.
X, y = np.array(data.data).T, np.array(data.target).astype('int')

# Split data into training and test sets (disabled).
# Note: these slices index the rows of data.data / data.target, not the
# transposed X defined above.
# X_train, X_test = data.data[:60000], data.data[60000:]
# y_train, y_test = data.target[:60000], data.target[60000:]

In [None]:
# ---------------------------------------
def montage(A, m, n):
    '''
    Tile the first m*n columns of A into one montage image.

    Every column of A is treated as a flattened square image whose side
    length is the integer part of sqrt(A.shape[0]). Images are laid out
    left to right, top to bottom: column k of A lands in grid row k // n
    and grid column k % n.

    Inputs:
    A: p x N image matrix with N images (p pixels each), N >= m*n
    m, n: number of grid rows and grid columns (m*n images in total)
    Output:
    M: (m*side) x (n*side) float array containing the m*n images
    '''
    side = int(np.sqrt(A.shape[0]))  # each image is side x side
    M = np.zeros((m * side, n * side))  # blank canvas for the montage

    # np.ndindex visits the grid cells in row-major order
    for r, c in np.ndindex(m, n):
        tile = A[:, r * n + c].reshape(side, side)
        M[r * side:(r + 1) * side, c * side:(c + 1) * side] = tile

    return M

In [None]:
def montage_v2(A, m, n):
    '''
    Read the first m*n images from A and join them into one montage.

    Original implementation by 李晏丞.

    Inputs:
    A: p x N image matrix, one flattened square image per column, N >= m*n
    m, n: number of grid rows and grid columns (m*n images in total)
    Output:
    (m*side) x (n*side) float array, where side is the integer part of
    sqrt(A.shape[0])
    '''
    side = np.sqrt(A.shape[0]).astype("int")
    canvas = np.zeros((side * m, side * n))

    for k in range(m * n):
        # Map the flat image index k to its grid cell:
        # quotient -> grid row, remainder -> grid column.
        grid_row, grid_col = divmod(k, n)
        top = grid_row * side
        left = grid_col * side
        canvas[top:top + side, left:left + side] = A[:, k].reshape(side, side)
    return canvas