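#### Implements `FourierProjector` and compares it with PCA and Gaussian random projection as a preprocessing step for k-means clustering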

In [1]:
import os
import time
from pathlib import Path

import skimage
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.random_projection import GaussianRandomProjection
from tqdm import tqdm

In [2]:
# Apply the seaborn theme FIRST: `sns.set()` writes its own rc values
# (font sizes among them), so calling it after the overrides below would
# silently discard them.
sns.set()

plt.rcParams["figure.figsize"] = (12, 10)

SMALL_SIZE = 8
MEDIUM_SIZE = 10
BIGGER_SIZE = 12

plt.rc("font", size=SMALL_SIZE)          # controls default text sizes
plt.rc("axes", titlesize=SMALL_SIZE)     # fontsize of the axes title
plt.rc("axes", labelsize=MEDIUM_SIZE)    # fontsize of the x and y labels
plt.rc("xtick", labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc("ytick", labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc("legend", fontsize=SMALL_SIZE)    # legend fontsize
plt.rc("figure", titlesize=BIGGER_SIZE)  # fontsize of the figure title

# Captured after the theme is applied so it matches the colors actually used.
color_cycle = plt.rcParams["axes.prop_cycle"].by_key()["color"]

In [3]:
# Default folder that `read_images` loads its image files from.
RESIZED_IMG_FOLDER = "img64"
# Side length in pixels that `read_images` assumes for every (square) image.
RESIZED_DIM = 64

In [4]:
def read_images(path=RESIZED_IMG_FOLDER):
    """
    Load every file in `path` as an image and return them as flat rows.

    Each image is assumed to be `RESIZED_DIM` x `RESIZED_DIM` pixels.
    Images that are not 3-dimensional (e.g. grayscale) are replicated into
    three identical channels; images with more than three channels (e.g. an
    alpha channel) keep only the first three.

    Parameters
    ----------
    path : str or path-like
        Folder to read from. Every directory entry is passed to
        `skimage.io.imread`.

    Returns
    -------
    X : np.ndarray, shape (n_images, RESIZED_DIM * RESIZED_DIM * 3), dtype uint8
        One flattened image per row, in sorted filename order.

    Raises
    ------
    ValueError
        Raised by numpy on assignment when an image does not have the
        expected number of values.
    """
    folder = Path(path)
    # `os.listdir` returns entries in arbitrary order; sort so that the row
    # order of X is reproducible across runs and machines.
    filename_lst = sorted(os.listdir(folder))
    width = RESIZED_DIM
    height = width
    n_img = len(filename_lst)
    n_channel = 3

    X = np.empty((n_img, width * height * n_channel), dtype=np.uint8)
    # Wrap the list (not the enumerate object) so tqdm knows the total.
    for idx, filename in enumerate(tqdm(filename_lst)):
        # Read from the requested folder, not the module-level default.
        im = skimage.io.imread(folder / filename)

        if im.ndim != 3:
            im = np.stack((im,) * 3, axis=-1)
        elif im.shape[-1] > n_channel:
            # Drop extra channels (e.g. alpha) so the row width matches.
            im = im[..., :n_channel]

        X[idx] = im.flatten()
    return X

In [5]:
def calc_error(X, label_lst, k):
    """
    Mean squared Euclidean distance from each row of `X` to the mean of the
    rows that share its label.

    `label_lst` holds one integer label in `range(k)` per row of `X`; the
    cluster means are recomputed here from `X` itself.
    """
    n_features = X.shape[1]
    centroids = np.empty((k, n_features))

    # One centroid per label: the mean of all rows carrying that label.
    for cluster_id in range(k):
        members = X[label_lst == cluster_id]
        centroids[cluster_id] = np.mean(members, axis=0)

    # Distance of every row to its own centroid, squared, then averaged.
    residuals = X - centroids[label_lst]
    squared_dist = np.linalg.norm(residuals, axis=1) ** 2
    return squared_dist.mean()

In [6]:
def cluster_with_projection(X, projector, num_clusters=20, random_state=None):
    """
    Project `X`, run k-means on the projection, and score the clustering
    in the ORIGINAL space.

    Parameters
    ----------
    X : array of shape (n_samples, n_features)
        Data to cluster.
    projector : object with a `fit_transform(X)` method
        Dimensionality reducer applied before clustering.
    num_clusters : int, default 20
        Number of k-means clusters.
    random_state : optional
        Forwarded to `KMeans` so the clustering can be made reproducible.
        The default (`None`) keeps the unseeded behavior.

    Returns
    -------
    time_elapsed : float
        Seconds spent on projection plus k-means fitting (the error
        computation is not timed).
    error : float
        Value of `calc_error` for the resulting labels, computed on `X`
        rather than on the projected data.
    """
    start = time.perf_counter()

    # reduce dimensions
    X_tr = projector.fit_transform(X)

    # cluster
    kmeans_model = KMeans(n_clusters=num_clusters, random_state=random_state)
    kmeans_model.fit(X_tr)

    # Stop the clock before scoring so only projection + clustering is timed.
    time_elapsed = time.perf_counter() - start
    error = calc_error(X, kmeans_model.labels_, num_clusters)

    return time_elapsed, error

In [7]:
class FourierProjector(TransformerMixin, BaseEstimator):
    """
    Random cosine feature map: `sqrt(2 / k) * cos(X @ w.T + u)`.

    `w` has i.i.d. normal entries scaled by `sqrt(2 * gamma)` and `u` is
    uniform on `[0, 2 * pi)`; both are drawn in `fit`.

    NOTE(review): with unscaled inputs (e.g. raw 0-255 pixel values) the
    argument of the cosine can be very large, which presumably makes the
    features behave like noise -- consider rescaling X or lowering `gamma`.
    """

    def __init__(self, k=20, gamma=1, random_state=None):
        """
        `k` is the dimension of the new space.
        `gamma` scales the variance of the random weights.
        `random_state` seeds the generator used in `fit`.
        adapted from https://github.com/hichamjanati/srf & modified.
        """
        # Only store the parameters here; the generator is created in `fit`
        # so that `set_params(random_state=...)` is honored.
        self.gamma = gamma
        self.k = k
        self.random_state = random_state

        self._fitted = False

    def fit(self, X, y=None):
        """
        Draw the random weights `w` (`k x d`) and offsets `u` (`k`) for
        data with `d` features. `y` is ignored. Returns `self`.
        """
        _, d = X.shape
        # Build the generator per fit: with an integer `random_state`,
        # fitting twice now yields the same projection instead of
        # continuing the random stream of the previous fit.
        self._rng = np.random.default_rng(seed=self.random_state)
        self.w = (np.sqrt(2 * self.gamma)
                      * self._rng.standard_normal(size=(self.k, d)))
        self.u = 2 * np.pi * self._rng.random(self.k)
        self._fitted = True

        return self

    def transform(self, X):
        """
        From `N x d` to `N x k`.

        Raises `RuntimeError` if called before `fit`.
        """
        if not self._fitted:
            raise RuntimeError("Need to fit prior to transform.")
        # `u[None, :]` broadcasts the per-feature offset over all N rows.
        return (np.sqrt(2 / self.k)
                    * np.cos(X.dot(self.w.T) + self.u[None, :]))

In [8]:
# the data: one flattened image per row, read from the default image folder
X = read_images()

  from .collection import imread_collection_wrapper
5000it [00:03, 1311.05it/s]


In [9]:
# compare PCA and the Fourier projector at a single target dimension
NEW_DIM = 100
NUM_CLUSTERS = 20
time_pca, err_pca = cluster_with_projection(
    X, PCA(NEW_DIM), num_clusters=NUM_CLUSTERS
)
time_fourier, err_fourier = cluster_with_projection(
    X, FourierProjector(NEW_DIM), num_clusters=NUM_CLUSTERS
)

In [10]:
# (seconds elapsed, mean squared distance to cluster means) with PCA
time_pca, err_pca

(6.938468199999999, 33542767.186324146)

In [11]:
# (seconds elapsed, mean squared distance to cluster means) with FourierProjector
time_fourier, err_fourier

(2.3659399000000008, 52364502.73190421)

In [12]:
# compare projections over a range of target dimensions
# (each value is passed to a projector as its first positional argument)
NEW_DIMS = [100, 200, 400, 800, 1000, 2000, 4000]

In [13]:
def experiment(projector_cls, new_dims, num_clusters=20, **projector_kwargs):
    """
    calls `cluster_with_projection` many (`len(new_dims)`) times.

    Operates on the notebook-level array `X`.

    Parameters
    ----------
    projector_cls : callable
        Projector class; instantiated as
        `projector_cls(new_dim, **projector_kwargs)` for every target
        dimension.
    new_dims : sequence of int
        Target dimensions to try, in order.
    num_clusters : int, default 20
        Number of k-means clusters used for every run.
    **projector_kwargs
        Extra keyword arguments forwarded to `projector_cls`.

    Returns
    -------
    dict
        `{"time": [...], "error": [...]}` with one entry per element of
        `new_dims`, as returned by `cluster_with_projection`.
    """
    elapsed_times = []
    errors = []

    for new_dim in tqdm(new_dims):
        projector = projector_cls(new_dim, **projector_kwargs)

        # Delegate to the shared helper instead of duplicating its body;
        # the timing covers fit_transform + k-means, as in the helper.
        time_elapsed, error = cluster_with_projection(
            X, projector, num_clusters=num_clusters
        )

        elapsed_times.append(time_elapsed)
        errors.append(error)

    return {"time": elapsed_times, "error": errors}

In [14]:
# Run the dimension sweep once per projector; the dict keeps insertion order
# (pca, gauss, fourier).
results = {
    name: experiment(projector_cls, NEW_DIMS, NUM_CLUSTERS, **kwargs)
    for name, projector_cls, kwargs in (
        ("pca", PCA, {}),
        ("gauss", GaussianRandomProjection, {}),
        ("fourier", FourierProjector, {"gamma": 10}),
    )
}

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [05:36<00:00, 48.06s/it]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [01:17<00:00, 11.06s/it]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [01:00<00:00,  8.66s/it]


In [15]:
idx = pd.IndexSlice
# One (projector, metric) column pair per entry of `results`, one row per
# target dimension.
df = (pd.concat(map(pd.DataFrame, results.values()), keys=results.keys(), axis=1)
         .set_axis(NEW_DIMS, axis=0)
         .rename_axis(index="$k$"))
df.loc[:, idx[:, "error"]] **= 0.5  # RMSE
# Round AFTER taking the square root so the displayed RMSE values are the
# rounded ones (rounding first left them at full precision).
df = df.round(3)

# Per row, highlight the fastest projector (blue) and the lowest RMSE (green).
(df.style
   .highlight_min(subset=df.loc[:, idx[:, "time"]].columns, axis=1, color="lightblue")
   .highlight_min(subset=df.loc[:, idx[:, "error"]].columns, axis=1, color="lightgreen"))

Unnamed: 0_level_0,pca,pca,gauss,gauss,fourier,fourier
Unnamed: 0_level_1,time,error,time,error,time,error
$k$,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2
100,8.477,5794.985089,2.129,5879.228279,2.085,7234.180505
200,11.307,5797.94322,3.113,5825.607414,2.736,7234.441216
400,18.177,5795.177219,3.987,5810.693412,3.933,7234.715607
800,22.639,5795.563812,7.115,5809.120304,6.19,7234.508245
1000,28.458,5796.037574,8.915,5806.642199,7.092,7234.106746
2000,62.221,5793.571216,15.099,5799.142438,12.109,7233.871945
4000,179.174,5796.013221,31.078,5796.657745,20.615,7235.777532


#### the end