In [None]:
import numpy as np
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
import pickle

# .npy array of images; read with np.load in the visualization cell.
IMAGES_PATH = "ml2018spring-hw4-v2/visualization.npy"
# Pickled model unpickled by dimension_reduction() (presumably a fitted PCA — confirm).
PCA_PATH = "model/pca.pickle"
# Pickled model unpickled by clustering() (presumably a fitted KMeans — confirm).
KMEANS_PATH = "model/kmeans.pickle"

def dimension_reduction(X_train):
    """Transform images with the pre-fitted model pickled at ``PCA_PATH``.

    The input is divided by 255 and then handed to the unpickled object's
    ``transform`` method; nothing is fitted here.

    Parameters
    ----------
    X_train : array-like
        Image data to project. Presumably flattened 8-bit pixel rows, hence
        the division by 255 — TODO confirm against the training script.

    Returns
    -------
    Whatever ``transform`` returns for the scaled input.
    """
    scaled = X_train / 255

    # NOTE(review): pickle.load can execute arbitrary code; only point
    # PCA_PATH at model files you produced yourself.
    with open(PCA_PATH, 'rb') as model_file:
        fitted_pca = pickle.load(model_file)

    return fitted_pca.transform(scaled)

def dimension_reduction_visualization(X_train, dimension):
    """Fit a fresh whitened PCA on the data and return the projection.

    Unlike ``dimension_reduction``, no stored model is used: a new
    ``PCA(n_components=dimension, whiten=True)`` is fitted on the input
    after it has been divided by 255.

    Parameters
    ----------
    X_train : array-like
        Data to reduce.
    dimension : int
        Number of principal components to keep.

    Returns
    -------
    The output of ``fit_transform`` on the scaled input.
    """
    scaled = X_train / 255
    return PCA(n_components=dimension, whiten=True).fit_transform(scaled)

def clustering(imgs_pca):
    """Predict a cluster id for every row using the model pickled at ``KMEANS_PATH``.

    Parameters
    ----------
    imgs_pca : array-like
        Features to label; passed unchanged to the unpickled object's
        ``predict`` method.

    Returns
    -------
    The labels returned by ``predict``.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only load
    # trusted model files.
    with open(KMEANS_PATH, 'rb') as model_file:
        fitted_kmeans = pickle.load(model_file)
    return fitted_kmeans.predict(imgs_pca)

## Visualization

In [None]:
# Load the images and project them with the stored (pickled) PCA model.
imgs = np.load(IMAGES_PATH)
imgs_pca = dimension_reduction(imgs)

# For plotting only: compress to 15 whitened components, then embed in 2-D.
imgs_pca_vis = dimension_reduction_visualization(imgs_pca, 15)
# t-SNE is stochastic; seed it so a re-run on the same input does not yield a
# different embedding (and therefore different saved figures).
imgs_embedded = TSNE(n_components=2, random_state=0).fit_transform(imgs_pca_vis)

## Predict label

In [None]:
# Label every image from its PCA features using the pickled model loaded inside clustering().
imgs_label = clustering(imgs_pca)

## Comparison to ground truth

In [None]:
# Ground truth follows the file order: the first 5000 images are dataset A and
# the remainder are dataset B. Predicted label 0 is treated as A and 1 as B,
# matching the legend of the "Predict label" plot.
# NOTE(review): cluster ids from a clustering model are arbitrary; if the
# stored model ever assigns them the other way round this reports 1 - accuracy.
n_dataset_a = 5000
correct_a = np.count_nonzero(imgs_label[:n_dataset_a] == 0)
# Fix: the original counted `imgs_label[:5000] == 0` a second time here, so
# dataset B predictions never contributed to the score.
correct_b = np.count_nonzero(imgs_label[n_dataset_a:] == 1)
acc = (correct_a + correct_b) / len(imgs_label)
print("Accuracy is {:.5f}".format(acc))

In [None]:
# Scatter the 2-D embedding, coloured by predicted label (0 first, then 1).
for cluster_id, colour, legend_name in ((0, 'b', "dataset A"), (1, 'r', "dataset B")):
    members = imgs_embedded[imgs_label == cluster_id]
    plt.scatter(members[:, 0], members[:, 1], c=colour, label=legend_name, s=0.2)
plt.title("Predict label")
plt.legend()
# Save before show(): show() may leave an empty current figure behind.
plt.savefig("report/tsne_predict.png")
plt.show()

In [None]:
# Scatter the same embedding, coloured by ground truth: rows before index 5000
# are drawn as dataset A, the rest as dataset B.
for rows, colour, legend_name in (
    (slice(None, 5000), 'b', "dataset A"),
    (slice(5000, None), 'r', "dataset B"),
):
    plt.scatter(imgs_embedded[rows, 0], imgs_embedded[rows, 1], c=colour, label=legend_name, s=0.2)
plt.title("Ground truth")
plt.legend()
# Save before show(): show() may leave an empty current figure behind.
plt.savefig("report/tsne_truth.png")
plt.show()