In [45]:
import repliclust
import numpy as np

import repliclust.overlap._gradients as _gradients

# define new version of _optimize_centers that stores the loss after each epoch
def _optimize_centers(self, centers, cov_inv,
                      max_epoch=1000, learning_rate=0.1, tol=1e-10,
                      verbose=False, quiet=False):
    """
    Optimize cluster centers epoch by epoch, recording the loss history.

    Intended as a replacement for the library method of the same name
    (see the original for the full description). After every epoch the
    pair ``(epoch_number, loss)`` is appended to ``self.loss_log``,
    which is reset to an empty list at the start of each call.

    Parameters
    ----------
    centers : array-like
        Cluster centers; ``centers.shape[0]`` is the number of centers
        visited per epoch. Presumably modified in place by
        ``_gradients.update_centers`` -- TODO confirm.
    cov_inv : object
        Passed through unchanged to ``_gradients.update_centers`` and
        ``_gradients.total_loss``.
    max_epoch : int
        Upper bound on the number of epochs.
    learning_rate : float
        Forwarded to ``_gradients.update_centers``.
    tol : float
        Accepted for signature compatibility; not read by this function
        (the stopping test uses a fixed tolerance of 1e-14).
    verbose : bool
        If true, report progress after every epoch.
    quiet : bool
        If true, suppress the banner and the final result message.

    Returns
    -------
    The ``centers`` object that was passed in.
    """
    announce = not quiet
    if announce:
        print("\n[=== optimizing cluster overlaps ===]\n")

    # Field width for epoch numbers in the progress/result messages.
    pad_epoch = int(np.maximum(2, np.floor(1+np.log10(max_epoch))))

    self.loss_log = []

    epochs_done = 0
    while epochs_done < max_epoch:
        # Visit every center once per epoch, in a freshly shuffled order.
        visit_order = repliclust.config._rng.permutation(centers.shape[0])
        for center_idx in visit_order:
            _gradients.update_centers(
                center_idx, centers, cov_inv,
                learning_rate=learning_rate,
                overlap_bounds=self.overlap_bounds,
            )
        epochs_done += 1

        current_loss = _gradients.total_loss(
            centers, cov_inv, self.overlap_bounds
        )
        self.loss_log.append((epochs_done, current_loss))

        if verbose:
            self._print_optimization_progress(
                epochs_done, max_epoch, pad_epoch, current_loss
            )

        converged = np.allclose(current_loss, 0, atol=1e-14, rtol=1e-14)
        if not converged:
            continue

        # Loss is numerically zero: report (unless quiet) and stop early.
        if announce:
            if not verbose:
                print(" "*17 + "...")
            self._print_optimization_result(epochs_done, pad_epoch)
        break

    return centers

# Monkey-patch: rebind the class attribute so that ConstrainedOverlapCenters
# uses the loss-logging _optimize_centers defined in this cell.
repliclust.overlap.centers.ConstrainedOverlapCenters._optimize_centers = _optimize_centers


[=== optimizing cluster overlaps ===]

got that loss!!!
got that loss!!!
got that loss!!!
got that loss!!!
got that loss!!!
got that loss!!!
got that loss!!!
got that loss!!!
got that loss!!!
got that loss!!!
got that loss!!!
got that loss!!!
got that loss!!!
got that loss!!!
got that loss!!!
got that loss!!!
got that loss!!!
got that loss!!!
got that loss!!!
got that loss!!!
got that loss!!!
got that loss!!!
got that loss!!!
got that loss!!!
got that loss!!!
got that loss!!!
got that loss!!!
got that loss!!!
got that loss!!!
got that loss!!!
got that loss!!!
got that loss!!!
got that loss!!!
got that loss!!!
got that loss!!!
                 ...




In [None]:
import matplotlib.pyplot as plt
from statistics import median

# Plot the mean optimization loss per epoch, averaged over `n_std`
# independent runs. The left panel varies the number of clusters (at the
# median dimensionality) and the right panel varies the dimensionality
# (at the median number of clusters). Relies on `loss_log` being recorded
# on `archie.center_sampler` by the patched `_optimize_centers`.

epoch_max = 500   # length of every averaged loss curve
n_std = 10        # number of repeated runs averaged per curve

fig, ax = plt.subplots(figsize=(10,5),dpi=300, ncols=2)
ax[0].set_ylabel('Loss')
ax[0].set_title('Loss vs Epoch (Varying Number of Clusters)')
ax[1].set_title('Loss vs Epoch (Varying Dimensionality)')

linestyles = ['solid','dotted','dashed']
# NOTE(review): not used by the plotting below; kept in case other cells
# reference it.
markers = ['o', '^', 's']

p_vals = [10,50,100]
k_vals = [2,100,200]

for subplot_idx in [0,1]:

    ax[subplot_idx].set_xlabel('Epoch')
    ax[subplot_idx].set_yscale('log')
    ax[subplot_idx].set_xlim([1,50])
    ax[subplot_idx].set_ylim([1e-10,10])

    vals = k_vals if subplot_idx==0 else p_vals
    for j, val in enumerate(vals):

        k = val if subplot_idx==0 else median(k_vals)
        p = val if subplot_idx==1 else median(p_vals)

        # Start a fresh accumulator for every curve; sharing one across
        # the values of `val` would add earlier curves into later ones.
        loss_firstmom = np.zeros(shape=(epoch_max,))

        for i in range(n_std):
            print(k)
            print(p)
            archie = repliclust.Archetype(n_clusters=int(k), dim=int(p))
            dug = repliclust.DataGenerator(archie)
            X, y, _ = dug.synthesize(quiet=True)

            _, loss = list(zip(*archie.center_sampler.loss_log))
            # Keep at most `epoch_max` entries so the padding length below
            # can never be negative when a run logs more epochs than that.
            loss = np.array(loss, dtype='float')[:epoch_max]
            # Runs that stop early are padded with zero loss.
            loss_padded = np.concatenate([loss, np.zeros(shape=(epoch_max-len(loss),))])
            loss_firstmom += loss_padded/n_std

        loss_mean = loss_firstmom

        ax[subplot_idx].plot(np.arange(1,epoch_max+1), loss_mean, zorder=1, color='black', linestyle=linestyles[j])

plt.savefig('convergence.png',bbox_inches='tight')

2
50
2
50
2
50
2
50
2
50
2
50
2
50
2
50
2
50
2
50
100
50
100
50
100
50
100
50
100
50
100
50
100
50
100
50
100
50
100
50
200
50
200
50
200
50
200
50
200
50
200
50
200
50
200
50
