In [None]:
from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture
import pickle
import os
import numpy as np

In [None]:
from data_helpers.wine_quality_data_helper import load_wine_quality_data
from data_helpers.mnist_data_helper import load_mnist_data

In [None]:
class GaussianMixtureWrapper:
    """Thin adapter around ``GaussianMixture`` with a uniform constructor.

    The notebook instantiates the selected method as ``METHOD(components)``,
    so the only required argument is the number of mixture components.

    Attributes:
        NAME: Short label for this method; the notebook uses it as a path
            component when writing the transformed data.
        n_components: Number of mixture components requested.
        random_state: Seed forwarded to the estimator (``None`` = unseeded).
        model: The underlying ``GaussianMixture`` instance.
    """

    NAME = 'GMM'

    def __init__(self, n_components, random_state=None):
        """Build the underlying estimator.

        Args:
            n_components: Number of mixture components.
            random_state: Optional seed forwarded to ``GaussianMixture`` so a
                run can be reproduced. The default ``None`` keeps the
                original, unseeded behaviour.
        """
        self.n_components = n_components
        self.random_state = random_state
        self.model = GaussianMixture(
            n_components=n_components,
            random_state=random_state,
        )

    def fit(self, X):
        """Fit the mixture model on ``X``. Returns ``None``."""
        self.model.fit(X)

    def predict(self, X):
        """Return the underlying model's ``predict`` output for ``X``."""
        return self.model.predict(X)

class KMeansWrapper:
    """Thin adapter around ``KMeans`` with a uniform constructor.

    The notebook instantiates the selected method as ``METHOD(components)``,
    so the only required argument is the number of clusters.

    Attributes:
        NAME: Short label for this method; the notebook uses it as a path
            component when writing the transformed data.
        n_clusters: Number of clusters requested.
        random_state: Seed forwarded to the estimator (``None`` = unseeded).
        model: The underlying ``KMeans`` instance.
    """

    NAME = 'KMeans'

    def __init__(self, n_clusters, random_state=None):
        """Build the underlying estimator.

        Args:
            n_clusters: Number of clusters.
            random_state: Optional seed forwarded to ``KMeans`` so a run can
                be reproduced. The default ``None`` keeps the original,
                unseeded behaviour.
        """
        self.n_clusters = n_clusters
        self.random_state = random_state
        self.model = KMeans(n_clusters=n_clusters, random_state=random_state)

    def fit(self, X):
        """Fit the clustering model on ``X``. Returns ``None``."""
        self.model.fit(X)

    def predict(self, X):
        """Return the underlying model's ``predict`` output for ``X``."""
        return self.model.predict(X)

In [None]:
# Dataset selection: keep exactly one DATASET_NAME / DATASET_STR pair active.
# DATASET_NAME chooses which loader runs and is part of the output directory.
# DATASET_STR is not referenced in the visible cells (presumably a display
# label; confirm before removing).
# DATASET_NAME = 'wine_quality'
# DATASET_STR = 'Wine Quality'

DATASET_NAME = 'mnist'
DATASET_STR = 'MNIST'

# Clustering method: keep exactly one METHOD assignment active. METHOD is the
# wrapper *class* (not an instance); it is instantiated later with the
# component count, and METHOD.NAME is used in the output directory.
METHOD = GaussianMixtureWrapper
# METHOD = KMeansWrapper

In [None]:
# Map each supported dataset key to the helper that loads it. Every loader is
# called with no arguments and its result is unpacked into four values.
dataset_loaders = {
    'wine_quality': load_wine_quality_data,
    'mnist': load_mnist_data,
}

# Fail fast on a key we have no loader for.
if DATASET_NAME not in dataset_loaders:
    raise ValueError(f'Unknown dataset: {DATASET_NAME}')

X_train, y_train, X_test, y_test = dataset_loaders[DATASET_NAME]()

In [None]:
# Number of mixture components / clusters; passed as the sole constructor
# argument to the selected METHOD wrapper.
components = 20

In [None]:
# Instantiate the selected wrapper and fit it on the training features only:
# no labels are passed, and the test split is not used for fitting.
model = METHOD(components)
model.fit(X_train)

In [None]:
def get_distances(X, model):
    """Return the Euclidean distance from each row of ``X`` to each cluster centre.

    Args:
        X: 2-D array-like of samples, one sample per row.
        model: Fitted wrapper whose ``model.cluster_centers_`` holds the centres.

    Returns:
        ``np.ndarray`` with one row per sample and one column per centre.
    """
    samples = np.asarray(X)
    centers = model.model.cluster_centers_
    # One vectorised norm per centre: the Python-level loop runs once per
    # centre instead of once per (sample, centre) pair, and never allocates a
    # (samples x centres x features) temporary.
    return np.stack(
        [np.linalg.norm(samples - center, axis=1) for center in centers],
        axis=1,
    )


# Derive the extra features from the fitted model: per-component membership
# probabilities for the mixture model, per-centre distances for k-means.
if METHOD is GaussianMixtureWrapper:
    X_train_star = model.model.predict_proba(X_train)
    X_test_star = model.model.predict_proba(X_test)
elif METHOD is KMeansWrapper:
    X_train_star = get_distances(X_train, model)
    X_test_star = get_distances(X_test, model)
else:
    # Without this branch an unrecognised METHOD would silently leave
    # X_train_star / X_test_star undefined, or stale from an earlier run.
    raise ValueError(f'Unsupported METHOD: {METHOD!r}')

In [None]:
# Horizontally stack the original features with the cluster-derived features.
# NOTE(review): both names are rebound from their own previous values, so
# running this cell twice without re-running the cell that creates
# X_train_star / X_test_star stacks the original features on again.
X_train_star = np.hstack((X_train, X_train_star))
X_test_star = np.hstack((X_test, X_test_star))

In [None]:
# Persist the augmented features and the untouched labels, one pickle each,
# under a directory keyed by dataset and method.
# Named `output_dir` rather than `dir` so the builtin dir() is not shadowed.
output_dir = f"transformed_data/step_5/{DATASET_NAME}/{METHOD.NAME}"
os.makedirs(output_dir, exist_ok=True)

arrays_to_save = {
    "X_train": X_train_star,
    "X_test": X_test_star,
    "y_train": y_train,
    "y_test": y_test,
}
for array_name, array in arrays_to_save.items():
    # `with` guarantees each handle is flushed and closed, even if dump raises.
    with open(os.path.join(output_dir, f"{array_name}.pkl"), "wb") as fh:
        pickle.dump(array, fh)

In [None]:
# Show the test-set shape before and after augmentation as the cell output.
X_test.shape, X_test_star.shape