In [None]:
from typing import Sequence, Literal, Optional
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import pandas as pd
import torch
from torch import Tensor

from typing import Literal, Sequence

import torch
from torch import nn, Tensor
import torch.nn.functional as F

In [None]:
# Global matplotlib styling, set one rc group at a time:
# compact fonts and thin lines so small figures stay readable.
plt.rc('figure', titlesize=12)
plt.rc('axes', titlesize=10, labelsize=10)
plt.rc('font', size=8)
plt.rc('xtick', labelsize=8)
plt.rc('ytick', labelsize=8)
plt.rc('legend', fontsize=8)
plt.rc('lines', linewidth=1)

# Plot palette indexed by cluster/class id: ten base matplotlib colour names
# followed by the same ten names with the 'tab:' prefix, in the same order.
COLORS = [
    prefix + name
    for prefix in ('', 'tab:')
    for name in ('red', 'blue', 'green', 'orange', 'purple',
                 'brown', 'pink', 'gray', 'olive', 'cyan')
]

1. Prototype Methods
    + K-means Clustering
    + Learning Vector Quantization
    + Gaussian Mixtures
2. k-Nearest-Neighbors

In [None]:
class Normalizer(nn.Module):
    """Standardize features with the per-column mean and std of a training set.

    The statistics are computed once, along ``dim=0`` of ``X_train``, and then
    reused by :meth:`normalize` and :meth:`unnormalize`.

    Args:
        X_train: Tensor whose first dimension indexes the samples.
    """
    def __init__(self, X_train:Tensor):
        super(Normalizer, self).__init__()

        mean = X_train.mean(dim=0)
        std = X_train.std(dim=0)
        # A constant column has std == 0 and a single-row input has std == NaN;
        # dividing by either would fill the normalized data with inf/NaN.
        # Use a unit scale for those columns instead (NaN > 0 is False, so both
        # cases are caught by the same test).
        std = torch.where(std > 0, std, torch.ones_like(std))

        # Buffers follow the module through .to()/.cuda()/.double().
        # persistent=False keeps the state_dict unchanged (empty).
        self.register_buffer('mean', mean, persistent=False)
        self.register_buffer('std', std, persistent=False)

    def forward(self, input:Tensor) -> Tensor:
        """Alias of :meth:`normalize` so the module can be called directly."""
        return self.normalize(input=input)

    def normalize(self, input:Tensor) -> Tensor:
        """Return ``(input - mean) / std`` using the stored statistics."""
        return (input - self.mean)/self.std

    def unnormalize(self, input:Tensor) -> Tensor:
        """Invert :meth:`normalize`: return ``input * std + mean``."""
        return input*self.std + self.mean

    def extra_repr(self) -> str:
        params = {
            'mean':self.mean,
            'std': self.std,
        }
        return ', '.join([f'{k}={v}' for k, v in params.items() if v is not None])

In [None]:
class KMeansClustering(nn.Module):
    """K-means clustering that alternates assignment and centroid updates.

    ``self.centroids`` (shape ``[num_centroids, D]``) only exists after
    :meth:`fit` has been called.

    Args:
        num_centroids: Number of clusters.
    """
    def __init__(self, num_centroids:int):
        super(KMeansClustering, self).__init__()
        self.num_centroids = num_centroids

    def forward(self, input:Tensor, with_distance:bool=False):
        """Assign every row of ``input`` to its nearest centroid.

        Args:
            input: Tensor of shape ``[N, D]``.
            with_distance: If true, also return the distance matrix.

        Returns:
            ``yhat`` of shape ``[N, 1]`` holding the index of the closest
            centroid, or ``(yhat, distance)`` with ``distance`` of shape
            ``[N, num_centroids]`` (Euclidean) when ``with_distance`` is true.
        """
        # Broadcast [N, 1, D] against [1, K, D] to get all pairwise distances
        # in one shot, on the input's own dtype and device.
        distance = (input.unsqueeze(dim=1) - self.centroids.unsqueeze(dim=0)).norm(p=2, dim=2)
        yhat = distance.argmin(dim=1, keepdim=True)

        if with_distance:
            return yhat, distance
        return yhat

    def backward(self, input:Tensor, pred:Tensor):
        """Move each centroid, in place, to the mean of the rows assigned to it.

        Args:
            input: Tensor of shape ``[N, D]``.
            pred: Cluster index of every row, as returned by :meth:`forward`.
        """
        assignment = pred.reshape(-1)
        for k in range(self.num_centroids):
            members = input[assignment == k]
            # The mean of an empty set is NaN, and a NaN centroid can never
            # compare equal to itself, so fit() would loop forever. An empty
            # cluster keeps its previous position instead.
            if members.shape[0] > 0:
                self.centroids[k, :] = members.mean(dim=0)

    def fit(self, X_train:Tensor, centroids_init:Optional[Tensor]=None, with_logs:bool=False,
            max_iters:Optional[int]=None) -> 'Tensor | tuple[Tensor, Tensor]':
        """Run K-means on ``X_train`` until the centroids stop moving.

        Args:
            X_train: Tensor of shape ``[N, D]``.
            centroids_init: Optional starting centroids of shape
                ``[num_centroids, D]``. When omitted, ``num_centroids`` distinct
                training rows are drawn at random.
            with_logs: If true, also return the centroid history.
            max_iters: Optional cap on the number of iterations; ``None`` (the
                default) iterates until convergence.

        Returns:
            A copy of the final centroids, or ``(centroids, centroids_log)``
            when ``with_logs`` is true. ``centroids_log`` has shape
            ``[T + 1, num_centroids, D]``: the initial centroids followed by
            the centroids after each of the ``T`` iterations.

        Raises:
            ValueError: If there are fewer training rows than centroids, or if
                ``centroids_init`` does not contain ``num_centroids`` rows.
        """
        if centroids_init is None:
            if X_train.shape[0] < self.num_centroids:
                raise ValueError(
                    f'Need at least {self.num_centroids} samples to initialise '
                    f'{self.num_centroids} centroids, got {X_train.shape[0]}.'
                )
            rand_idx = torch.randperm(n=X_train.shape[0])[0:self.num_centroids]
            self.centroids = X_train[rand_idx]
        else:
            if centroids_init.shape[0] != self.num_centroids:
                raise ValueError(
                    f'centroids_init must have {self.num_centroids} rows, '
                    f'got {centroids_init.shape[0]}.'
                )
            # backward() updates the centroids in place; work on a copy so the
            # caller's tensor is left untouched.
            self.centroids = centroids_init.clone()

        # Collect snapshots in a list and stack once at the end instead of
        # re-concatenating the whole history on every iteration.
        history = [self.centroids.clone()] if with_logs else None
        num_iters = 0
        while max_iters is None or num_iters < max_iters:
            previous = self.centroids.clone()
            # Fix centroids, update labels
            yhat = self.forward(input=X_train)
            # Fix labels, update centroids
            self.backward(input=X_train, pred=yhat)
            num_iters += 1
            # Log centroids position
            if with_logs:
                history.append(self.centroids.clone())
            # Check for stopping condition: when centroids stop moving
            if torch.equal(self.centroids, previous):
                break

        if with_logs:
            return self.centroids.clone(), torch.stack(history, dim=0)
        return self.centroids.clone()

    def extra_repr(self) -> str:
        params = {
            'num_centroids':self.num_centroids,
        }
        return ', '.join([f'{k}={v}' for k, v in params.items() if v is not None])

In [None]:
if __name__ == '__main__':
    import matplotlib.pyplot as plt
    from matplotlib.colors import ListedColormap
    from utils_data import get_clusters_2D

    NUM_CLUSTERS = 4
    NUM_CENTROIDS = 4
    PLOT_STEP = 0.01

    X_train = get_clusters_2D(num_clusters=NUM_CLUSTERS, radius=1.5, sigma_diag=0.1)[0]
    normalizer = Normalizer(X_train=X_train)

    # Train in normalized space, then map the centroids (and their history)
    # back to raw coordinates for plotting.
    h = KMeansClustering(num_centroids=NUM_CENTROIDS)
    centroids, centroids_log = h.fit(X_train=normalizer(X_train), with_logs=True)
    centroids = normalizer.unnormalize(centroids)
    centroids_log = normalizer.unnormalize(centroids_log)
    # The fitted centroids live in normalized space, so every input has to be
    # normalized before it is passed to the model.
    yhat = h(normalizer(X_train))

    # Plot all centroids and examples
    fig, ax = plt.subplots()
    ax.set_title(str(h))
    ax.set_xlabel('x1')
    ax.set_ylabel('x2')
    for k in range(h.num_centroids):
        # Centroids' path: small markers for intermediate positions, a large
        # marker for the final one.
        ax.plot(centroids_log[:, k, 0], centroids_log[:, k, 1],
                color=COLORS[k], linestyle='dashed')
        ax.scatter(centroids_log[0:-1, k, 0], centroids_log[0:-1, k, 1],
                color=COLORS[k], marker='o', s=15)
        ax.scatter(centroids_log[-1, k, 0], centroids_log[-1, k, 1],
                color=COLORS[k], marker='o', s=70)

        # Training data, coloured by assigned cluster
        members = yhat.squeeze(dim=1) == k
        ax.scatter(X_train[members, 0], X_train[members, 1],
                color=COLORS[k], alpha=0.7, s=2, zorder=100)

    # Decision boundary: evaluate the model on a raw-space grid padded by 20%
    # of each axis' own range.
    ptp_X = X_train.max(dim=0)[0] - X_train.min(dim=0)[0]
    plot_x1 = torch.arange(X_train[:, 0].min() - 0.2*ptp_X[0], X_train[:, 0].max() + 0.2*ptp_X[0], PLOT_STEP)
    plot_x2 = torch.arange(X_train[:, 1].min() - 0.2*ptp_X[1], X_train[:, 1].max() + 0.2*ptp_X[1], PLOT_STEP)
    x1, x2 = torch.meshgrid(plot_x1 - PLOT_STEP/2, plot_x2 - PLOT_STEP/2, indexing='ij')
    x = torch.stack([x1.flatten(), x2.flatten()], dim=1)
    plot_yhat = h(normalizer(x)).reshape(x1.shape)

    # vmin/vmax pin cluster k to COLORS[k] even if some cluster never shows up
    # on the grid.
    ax.pcolormesh(x1, x2, plot_yhat,
                  cmap=ListedColormap(COLORS[0:h.num_centroids]), alpha=0.2, shading='auto',
                  vmin=-0.5, vmax=h.num_centroids - 0.5)

    plt.show()

In [None]:
class KMeansClassifier(nn.Module):
    """Prototype classifier that runs one K-means clustering per class.

    :meth:`fit` clusters the samples of every class separately; :meth:`forward`
    predicts the class that owns the centroid closest to each input.
    Class labels are expected to be the integers ``0 .. num_classes - 1``.

    Args:
        num_centroids_per_class: Number of K-means centroids fitted per class.
    """
    def __init__(self, num_centroids_per_class:int):
        super(KMeansClassifier, self).__init__()
        self.num_classes = -1  # unknown until fit() has seen the labels
        self.num_centroids_per_class = num_centroids_per_class

    def forward(self, input:Tensor, with_distance:bool=False):
        """Predict the class of every row of ``input``.

        Args:
            input: Tensor of shape ``[N, D]``.
            with_distance: If true, also return the distance tensor.

        Returns:
            ``yhat`` of shape ``[N]`` with the predicted class index, or
            ``(yhat, distance)`` with ``distance`` of shape
            ``[N, num_classes, num_centroids_per_class]``.

        Raises:
            RuntimeError: If called before :meth:`fit`.
        """
        if self.num_classes < 0:
            raise RuntimeError('KMeansClassifier must be fitted before calling forward().')

        # Each per-class clusterer returns (labels, distances); stack the
        # [N, R] distance matrices along a new class axis.
        distance = torch.stack(
            [clusterer(input=input, with_distance=True)[1] for clusterer in self.clusterers],
            dim=1,
        )

        # Closest centroid within each class, then the closest class overall.
        yhat = distance.min(dim=2)[0].argmin(dim=1)

        if with_distance:
            return yhat, distance
        return yhat

    def fit(self,
        X_train:Tensor,
        y_train:Tensor,
        with_logs:bool=False
    ) -> 'Tensor | tuple[Tensor, list]':
        """Fit one K-means clustering on the samples of each class.

        Args:
            X_train: Tensor of shape ``[N, D]``.
            y_train: Integer class labels, shape ``[N]`` or ``[N, 1]``, with
                values ``0 .. num_classes - 1``.
            with_logs: If true, also return the per-class centroid histories.

        Returns:
            A copy of the centroids, shape
            ``[num_classes, num_centroids_per_class, D]``, or
            ``(centroids, centroids_log)`` when ``with_logs`` is true, where
            ``centroids_log`` is a list with one history tensor per class.

        Raises:
            ValueError: If some class ``k`` in ``range(num_classes)`` has fewer
                samples than ``num_centroids_per_class``.
        """
        # Accept both [N] and [N, 1] label tensors.
        labels = y_train.reshape(-1)
        num_classes = labels.unique().shape[0]
        idx_k = [(labels == k).nonzero().squeeze(dim=1) for k in range(num_classes)]
        for k, idx in enumerate(idx_k):
            if idx.shape[0] < self.num_centroids_per_class:
                raise ValueError(
                    f'Class {k} has {idx.shape[0]} samples but '
                    f'{self.num_centroids_per_class} centroids were requested; '
                    f'labels must be the integers 0..{num_classes - 1}.'
                )

        self.num_classes = num_classes
        self.idx_k = idx_k
        self.clusterers = [KMeansClustering(num_centroids=self.num_centroids_per_class) for k in range(self.num_classes)]

        centroids = [None]*self.num_classes
        centroids_log = [None]*self.num_classes
        for k in range(self.num_classes):
            fit_outputs = self.clusterers[k].fit(
                X_train=X_train[self.idx_k[k]],
                with_logs=with_logs,
            )

            if with_logs:
                centroids[k] = fit_outputs[0]
                centroids_log[k] = fit_outputs[1].clone()
            else:
                centroids[k] = fit_outputs

        self.centroids = torch.stack(tensors=centroids, dim=0)

        if with_logs:
            return self.centroids.clone(), centroids_log
        return self.centroids.clone()

    def extra_repr(self) -> str:
        params = {
            'num_classes':             self.num_classes,
            'num_centroids_per_class': self.num_centroids_per_class,
        }
        return ', '.join([f'{k}={v}' for k, v in params.items() if v is not None])

In [None]:
if __name__ == '__main__':
    import matplotlib.pyplot as plt
    from matplotlib.colors import ListedColormap
    from utils_data import get_clusters_2D

    NUM_CLUSTERS = 4
    NUM_CLASSES = 4
    NUM_CENTROIDS_PER_CLASS = 3
    PLOT_STEP = 0.01

    X_train, y_train = get_clusters_2D(num_clusters=NUM_CLUSTERS, radius=1, sigma_diag=0.2)
    normalizer = Normalizer(X_train=X_train)

    # Train in normalized space, then map the centroids (and their history)
    # back to raw coordinates for plotting.
    h = KMeansClassifier(num_centroids_per_class=NUM_CENTROIDS_PER_CLASS)
    centroids, centroids_log = h.fit(X_train=normalizer(X_train), y_train=y_train, with_logs=True)
    centroids = normalizer.unnormalize(centroids)
    centroids_log = [normalizer.unnormalize(cl) for cl in centroids_log]
    # The fitted centroids live in normalized space, so every input has to be
    # normalized before it is passed to the model.
    yhat = h(normalizer(X_train))

    # Plot all centroids and examples
    fig, ax = plt.subplots()
    ax.set_title(str(h))
    ax.set_xlabel('x1')
    ax.set_ylabel('x2')

    labels = y_train.reshape(-1)
    for k in range(h.num_classes):
        for j in range(h.num_centroids_per_class):
            # Centroids' path: small markers for intermediate positions, a
            # large marker for the final one.
            ax.plot(centroids_log[k][:, j, 0], centroids_log[k][:, j, 1],
                    color=COLORS[k], linestyle='dashed')
            ax.scatter(centroids_log[k][0:-1, j, 0], centroids_log[k][0:-1, j, 1],
                    color=COLORS[k], marker='o', s=15)
            ax.scatter(centroids_log[k][-1, j, 0], centroids_log[k][-1, j, 1],
                    color=COLORS[k], marker='o', s=70)

        # Training data, coloured by true class
        ax.scatter(X_train[labels == k, 0], X_train[labels == k, 1],
                color=COLORS[k], alpha=0.7, s=2, zorder=100)

    # Decision boundary: evaluate the model on a raw-space grid padded by 20%
    # of each axis' own range.
    ptp_X = X_train.max(dim=0)[0] - X_train.min(dim=0)[0]
    plot_x1 = torch.arange(X_train[:, 0].min() - 0.2*ptp_X[0], X_train[:, 0].max() + 0.2*ptp_X[0], PLOT_STEP)
    plot_x2 = torch.arange(X_train[:, 1].min() - 0.2*ptp_X[1], X_train[:, 1].max() + 0.2*ptp_X[1], PLOT_STEP)
    x1, x2 = torch.meshgrid(plot_x1 - PLOT_STEP/2, plot_x2 - PLOT_STEP/2, indexing='ij')
    x = torch.stack([x1.flatten(), x2.flatten()], dim=1)
    plot_yhat = h(normalizer(x)).reshape(x1.shape)

    # vmin/vmax pin class k to COLORS[k] even if some class never shows up on
    # the grid.
    ax.pcolormesh(x1, x2, plot_yhat,
                  cmap=ListedColormap(COLORS[0:h.num_classes]), alpha=0.2, shading='auto',
                  vmin=-0.5, vmax=h.num_classes - 0.5)

    plt.show()

In [None]:
class LearningVectorQuantization(nn.Module):
    """Learning Vector Quantization (LVQ1) prototype classifier.

    Every class owns ``num_centroids_per_class`` prototypes. During training
    the prototype closest to a sampled point is pulled towards the point when
    their classes agree and pushed away from it when they differ.
    Class labels are expected to be the integers ``0 .. num_classes - 1``.

    Args:
        num_centroids_per_class: Number of prototypes per class.
    """
    def __init__(self, num_centroids_per_class:int):
        super(LearningVectorQuantization, self).__init__()
        self.num_centroids_per_class = num_centroids_per_class
        self.num_classes = -1  # unknown until fit() has seen the labels

    def forward(self, input:Tensor):
        """Predict, for every row of ``input``, the class of the nearest prototype.

        Args:
            input: Tensor of shape ``[N, D]``.

        Returns:
            Tensor of shape ``[N]`` with the predicted class index.
        """
        # [N, 1, 1, D] broadcasts against the [K, R, D] prototypes.
        # reshape (rather than view) also accepts non-contiguous inputs.
        input = input.reshape(input.shape[0], 1, 1, *input.shape[1:])
        distance = (input - self.centroids).norm(p=2, dim=3)
        yhat = distance.min(dim=2)[0].argmin(dim=1)
        return yhat

    def fit(self, X_train:Tensor, y_train:Tensor, centroids_init:Optional[Tensor]=None, num_iters:Optional[int]=None, lr_init:float=0.1) -> Tensor:
        """Train the prototypes with stochastic LVQ updates.

        Args:
            X_train: Tensor of shape ``[N, D]``.
            y_train: Integer class labels, shape ``[N]`` or ``[N, 1]``, with
                values ``0 .. num_classes - 1``.
            centroids_init: Optional initial prototypes of shape
                ``[num_classes, num_centroids_per_class, D]``. When omitted,
                prototypes are drawn at random from the samples of each class.
            num_iters: Number of single-sample updates; defaults to ``N``.
            lr_init: Initial learning rate. It decays linearly with the
                iteration index, reaching ``lr_init / num_iters`` on the last
                update.

        Returns:
            A copy of the trained prototypes.

        Raises:
            ValueError: If prototypes have to be sampled and some class has
                fewer samples than ``num_centroids_per_class``.
        """
        # Accept both [N] and [N, 1] label tensors.
        labels = y_train.reshape(-1)
        self.num_classes = labels.unique().shape[0]
        self.idx_k = [(labels == k).nonzero().squeeze(dim=1) for k in range(self.num_classes)]
        if num_iters is None:
            num_iters = X_train.shape[0]

        # 1. Choose R initial prototypes for each class: m1(k), m2(k), ..., mR(k),
        # k = 1, 2, ..., K, for example, by sampling R training points at random from each class.
        if centroids_init is None:
            centroids = [None]*self.num_classes
            for k in range(self.num_classes):
                if self.idx_k[k].shape[0] < self.num_centroids_per_class:
                    raise ValueError(
                        f'Class {k} has {self.idx_k[k].shape[0]} samples but '
                        f'{self.num_centroids_per_class} prototypes were requested.'
                    )
                rand_idx_k = torch.randperm(n=self.idx_k[k].shape[0])[0:self.num_centroids_per_class]
                centroids[k] = X_train[self.idx_k[k][rand_idx_k]]
            self.centroids = torch.stack(tensors=centroids, dim=0)
        else:
            # The prototypes are updated in place below; work on a copy so the
            # caller's tensor is left untouched.
            self.centroids = centroids_init.clone()

        # 2. Sample a training point xi randomly (with replacement), and let (j, k)
        # index the closest prototype mj(k) to xi.
        #   (a) If gi = k (same class), move the prototype towards the point:
        #       mj(k) <- mj(k) + epsilon * (xi - mj(k)).
        #   (b) If gi != k (different class), move the prototype away from it:
        #       mj(k) <- mj(k) - epsilon * (xi - mj(k)).
        # 3. Repeat step 2, decreasing the learning rate epsilon with each
        # iteration towards zero.
        for i in range(num_iters):
            lr = lr_init*(num_iters-i)/num_iters

            rand_idx = torch.randint(low=0, high=X_train.shape[0], size=[1]).item()
            x, y = X_train[rand_idx], labels[rand_idx]

            # [D] broadcasts against [K, R, D]; distance has shape [K, R].
            distance = (x - self.centroids).norm(p=2, dim=2)
            # argmin() indexes the flattened [K, R] matrix and returns a single
            # position even when several prototypes are equally close.
            k, j = divmod(distance.argmin().item(), distance.shape[1])

            # Attract when the prototype's class matches the label, repel otherwise.
            sign = 1.0 if k == y.item() else -1.0
            prototype = self.centroids[k, j]
            self.centroids[k, j] = prototype + sign*lr*(x - prototype)
        return self.centroids.clone()

    def extra_repr(self) -> str:
        params = {
            'num_classes':             self.num_classes,
            'num_centroids_per_class': self.num_centroids_per_class,
        }
        return ', '.join([f'{k}={v}' for k, v in params.items() if v is not None])

In [None]:
if __name__ == '__main__':
    import matplotlib.pyplot as plt
    from matplotlib.colors import ListedColormap
    from utils_data import get_clusters_2D

    NUM_CLUSTERS = 4
    NUM_CLASSES = 4
    NUM_CENTROIDS_PER_CLASS = 3
    LR_INIT = 0.01
    NUM_ITERS = 6000
    PLOT_STEP = 0.01

    X_train, y_train = get_clusters_2D(num_clusters=NUM_CLUSTERS, radius=1, sigma_diag=0.4)
    normalizer = Normalizer(X_train=X_train)

    # Train in normalized space, then map the prototypes back to raw
    # coordinates for plotting.
    h = LearningVectorQuantization(num_centroids_per_class=NUM_CENTROIDS_PER_CLASS)
    centroids = h.fit(X_train=normalizer(X_train), y_train=y_train, lr_init=LR_INIT, num_iters=NUM_ITERS)
    centroids = normalizer.unnormalize(centroids)
    # The fitted prototypes live in normalized space, so every input has to be
    # normalized before it is passed to the model.
    yhat = h(normalizer(X_train))

    # Plot all centroids and examples
    fig, ax = plt.subplots()
    ax.set_title(str(h))
    ax.set_xlabel('x1')
    ax.set_ylabel('x2')

    labels = y_train.reshape(-1)
    for k in range(h.num_classes):
        for j in range(h.num_centroids_per_class):
            ax.scatter(centroids[k, j, 0], centroids[k, j, 1],
                    color=COLORS[k], marker='o', s=70)

        # Training data, coloured by true class
        ax.scatter(X_train[labels == k, 0], X_train[labels == k, 1],
                color=COLORS[k], alpha=0.7, s=2, zorder=100)

    # Decision boundary: evaluate the model on a raw-space grid padded by 20%
    # of each axis' own range.
    ptp_X = X_train.max(dim=0)[0] - X_train.min(dim=0)[0]
    plot_x1 = torch.arange(X_train[:, 0].min() - 0.2*ptp_X[0], X_train[:, 0].max() + 0.2*ptp_X[0], PLOT_STEP)
    plot_x2 = torch.arange(X_train[:, 1].min() - 0.2*ptp_X[1], X_train[:, 1].max() + 0.2*ptp_X[1], PLOT_STEP)
    x1, x2 = torch.meshgrid(plot_x1 - PLOT_STEP/2, plot_x2 - PLOT_STEP/2, indexing='ij')
    x = torch.stack([x1.flatten(), x2.flatten()], dim=1)
    plot_yhat = h(normalizer(x)).reshape(x1.shape)

    # vmin/vmax pin class k to COLORS[k] even if some class never shows up on
    # the grid.
    ax.pcolormesh(x1, x2, plot_yhat,
                  cmap=ListedColormap(COLORS[0:h.num_classes]), alpha=0.2, shading='auto',
                  vmin=-0.5, vmax=h.num_classes - 0.5)

    plt.show()

In [None]:
if __name__ == '__main__':
    import matplotlib.pyplot as plt
    from matplotlib.colors import ListedColormap
    from utils_data import get_clusters_2D

    NUM_CLUSTERS = 4
    NUM_CLASSES = 4
    NUM_CENTROIDS_PER_CLASS = 3
    LR_INIT = 0.01
    NUM_ITERS = 6000
    PLOT_STEP = 0.01

    X_train, y_train = get_clusters_2D(num_clusters=NUM_CLUSTERS, radius=1, sigma_diag=0.4)
    normalizer = Normalizer(X_train=X_train)
    labels = y_train.reshape(-1)

    def _scatter_training_data(ax, num_classes):
        """Draw the raw training points on ``ax``, coloured by true class."""
        for k in range(num_classes):
            ax.scatter(X_train[labels == k, 0], X_train[labels == k, 1],
                    color=COLORS[k], alpha=0.7, s=2, zorder=100)

    def _shade_decision_regions(ax, model, num_classes):
        """Shade ``ax`` with the class ``model`` predicts on a raw-space grid.

        The grid is padded by 20% of each axis' own range. ``model`` was
        trained on normalized data, so the grid is normalized before prediction.
        """
        ptp_X = X_train.max(dim=0)[0] - X_train.min(dim=0)[0]
        plot_x1 = torch.arange(X_train[:, 0].min() - 0.2*ptp_X[0], X_train[:, 0].max() + 0.2*ptp_X[0], PLOT_STEP)
        plot_x2 = torch.arange(X_train[:, 1].min() - 0.2*ptp_X[1], X_train[:, 1].max() + 0.2*ptp_X[1], PLOT_STEP)
        x1, x2 = torch.meshgrid(plot_x1 - PLOT_STEP/2, plot_x2 - PLOT_STEP/2, indexing='ij')
        x = torch.stack([x1.flatten(), x2.flatten()], dim=1)
        plot_yhat = model(normalizer(x)).reshape(x1.shape)
        # vmin/vmax pin class k to COLORS[k] even if some class never shows up
        # on the grid.
        ax.pcolormesh(x1, x2, plot_yhat,
                      cmap=ListedColormap(COLORS[0:num_classes]), alpha=0.2, shading='auto',
                      vmin=-0.5, vmax=num_classes - 0.5)

    # Train K-means in normalized space; map centroids and their history back
    # to raw coordinates so they line up with the plotted data.
    kmeans = KMeansClassifier(num_centroids_per_class=NUM_CENTROIDS_PER_CLASS)
    kmeans_centroids, centroids_log = kmeans.fit(X_train=normalizer(X_train), y_train=y_train, with_logs=True)
    kmeans_centroids = normalizer.unnormalize(kmeans_centroids)
    centroids_log = [normalizer.unnormalize(cl) for cl in centroids_log]

    # LVQ starts from the K-means solution (re-normalized, since LVQ is also
    # trained in normalized space).
    lvq = LearningVectorQuantization(num_centroids_per_class=NUM_CENTROIDS_PER_CLASS)
    lvq_centroids = lvq.fit(X_train=normalizer(X_train), y_train=y_train, lr_init=LR_INIT, num_iters=NUM_ITERS, centroids_init=normalizer(kmeans_centroids))
    lvq_centroids = normalizer.unnormalize(lvq_centroids)
    yhat = lvq(normalizer(X_train))

    # Plot all centroids and examples
    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=[12, 5], squeeze=False)

    # K-means
    ax[0, 0].set_title(str(kmeans))
    ax[0, 0].set_xlabel('x1')
    ax[0, 0].set_ylabel('x2')
    for k in range(kmeans.num_classes):
        for j in range(kmeans.num_centroids_per_class):
            # Centroids' path: small markers for intermediate positions, a
            # large marker for the final one.
            ax[0, 0].plot(centroids_log[k][:, j, 0], centroids_log[k][:, j, 1],
                    color=COLORS[k], linestyle='dashed')
            ax[0, 0].scatter(centroids_log[k][0:-1, j, 0], centroids_log[k][0:-1, j, 1],
                    color=COLORS[k], marker='o', s=15)
            ax[0, 0].scatter(centroids_log[k][-1, j, 0], centroids_log[k][-1, j, 1],
                    color=COLORS[k], marker='o', s=70)
    _scatter_training_data(ax[0, 0], kmeans.num_classes)
    _shade_decision_regions(ax[0, 0], kmeans, kmeans.num_classes)

    # LVQ
    ax[0, 1].set_title(str(lvq))
    ax[0, 1].set_xlabel('x1')
    ax[0, 1].set_ylabel('x2')
    for k in range(lvq.num_classes):
        for j in range(lvq.num_centroids_per_class):
            ax[0, 1].scatter(lvq_centroids[k, j, 0], lvq_centroids[k, j, 1],
                    color=COLORS[k], marker='o', s=70)
    _scatter_training_data(ax[0, 1], lvq.num_classes)
    _shade_decision_regions(ax[0, 1], lvq, lvq.num_classes)

    plt.show()