In [None]:
import os
import numpy as np
import scipy as sp

# plotting lib(s) and updated default plot settings
import matplotlib.pyplot as plt
plt.style.use('default')  # reset to matplotlib's stock style before overriding
plt.rcParams.update({
    'figure.dpi': 100,   # on-screen figure resolution
    'savefig.dpi': 200,  # resolution used when figures are saved
})

from mpl_toolkits.mplot3d import Axes3D
%matplotlib widget

In [None]:
EXPERIMENT = "MNIST_FF"  # experiment tag interpolated into the .npy file names loaded below

### Get Level-set Traversal Points

In [None]:
# All saved level-set traversal points for this experiment; rows are presumably
# individual points (e.g. flattened parameter vectors) -- TODO confirm layout.
points = np.load(f'{EXPERIMENT}_all_traversal_points.npy')

In [None]:
# Per-coordinate statistics over all traversal points. They are used to z-score
# both the points and the endpoint vectors before projecting them with PCA.
points_mean, points_std = np.mean(points, axis=0), np.std(points, axis=0)
# Coordinates that never change across the traversal have zero spread. Dividing
# by 0 would produce NaN/inf and make the PCA fit fail, so give them a scale of 1
# (the same convention as sklearn's StandardScaler); they then simply centre to 0.
points_std = np.where(points_std == 0, 1.0, points_std)

In [None]:
# Standardize every coordinate (zero mean, unit spread) ahead of the PCA fit.
points_normalized = np.divide(np.subtract(points, points_mean), points_std)

### Principal Component Analysis

In [None]:
from sklearn import decomposition
# Number of leading principal components to keep (the plots below show PC1-PC6).
NUM_PC = 6
# random_state fixes the seed sklearn uses for its randomized/arpack solvers.
pca = decomposition.PCA(random_state=0, n_components=NUM_PC)

In [None]:
# Learn the principal axes from the standardized traversal points.
pca.fit(X=points_normalized)

In [None]:
# Horizontal bar chart of the variance ratio explained by each retained
# component (blue), followed by their cumulative total (red).
variance_ratio = pca.explained_variance_ratio_
n_components = len(variance_ratio)  # taken from the fitted model, not NUM_PC
bar_values = np.append(variance_ratio, variance_ratio.sum())
bar_positions = np.arange(1, n_components + 2)
bar_labels = [str(i) for i in range(1, n_components + 1)] + ['Cumulative']
x_ticks = np.linspace(0, 1, 11)

fig, ax = plt.subplots()
ax.barh(bar_positions, bar_values, color=['b'] * n_components + ['r'])
ax.set_xlabel('Explained Variance Ratio')
ax.set_ylabel('Principal Components')
ax.set_title('PCA explained variance ({})'.format(EXPERIMENT))
ax.set_yticks(bar_positions)
ax.set_yticklabels(bar_labels)
ax.set_xticks(x_ticks)
# Format labels explicitly instead of relying on a float16 cast to round them.
ax.set_xticklabels(['{:.1f}'.format(t) for t in x_ticks])
ax.invert_yaxis()  # list PC1 at the top and the cumulative bar at the bottom
plt.show()


### Visualize Level-set Traversal 

In [None]:
# Average traversal endpoint, in the same coordinate space as `points`.
avg_endpoint = np.load(f'{EXPERIMENT}_avg_endpoint_traversal.npy')

In [None]:
# Plot the level-set traversals in PCA space: PC1-PC3 on the left, PC4-PC6 on
# the right. `points` is split into `n_traversals` consecutive chunks (assumed to
# be one chunk per traversal -- TODO confirm against the script that saved it).
# Within a chunk, colour encodes the index along the chunk; a red star marks
# each chunk's mean and a blue "P" marks the average endpoint.
n_traversals = 5
chunk_len = len(points) // n_traversals
if chunk_len == 0:
    raise ValueError('need at least {} traversal points, got {}'.format(n_traversals, len(points)))

# The PCA projection acts row by row, so projecting everything once and slicing
# afterwards gives the same coordinates as projecting each chunk separately.
points_pc = pca.transform(points_normalized)
avg_endpoint_pc = pca.transform([(avg_endpoint - points_mean) / points_std]).ravel()

fig = plt.figure(figsize=(12, 6))
for panel, first_pc in enumerate((0, 3), start=1):
    # An explicit 3-wide slice guarantees scatter(*...) receives exactly x, y, z
    # even if NUM_PC is raised above 6.
    pcs = slice(first_pc, first_pc + 3)
    ax = fig.add_subplot(1, 2, panel, projection='3d')
    for chunk_idx, start in enumerate(range(0, len(points), chunk_len)):
        chunk_pc = points_pc[start:start + chunk_len, pcs]
        sctr = ax.scatter(*chunk_pc.T, c=np.arange(len(chunk_pc)))
        ax.scatter(*chunk_pc.mean(axis=0), color='r', marker='*', s=100,
                   label='traversal mean' if chunk_idx == 0 else '_nolegend_')
    # The average endpoint does not depend on the chunk, so draw it once;
    # alpha=0.4 roughly matches the opacity of five overlaid alpha=0.1 markers.
    ax.scatter(*avg_endpoint_pc[pcs], color='b', marker='P', s=200, alpha=0.4,
               label='average endpoint')
    ax.set_xlabel('PC{}'.format(first_pc + 1))
    ax.set_ylabel('PC{}'.format(first_pc + 2))
    ax.set_zlabel('PC{}'.format(first_pc + 3))
    ax.view_init(elev=30, azim=45)
    ax.legend()

# When len(points) is a multiple of n_traversals every chunk spans the same
# 0..chunk_len-1 colour range, so one colorbar (from the last scatter) serves
# both panels. Keep the tick step >= 1: a zero step makes np.arange fail for
# chunks shorter than 10 points.
tick_step = max(chunk_len // 10, 1)
cbar = fig.colorbar(sctr, ax=fig.get_axes(), shrink=0.6, pad=0.15,
                    ticks=np.arange(0, chunk_len + 1, tick_step))
cbar.set_label('Index along traversal')

### Visualize Level-set Traversal and Average Weight-Decay (L2) Endpoint

In [None]:
# Average endpoint of the L2 (weight-decay) runs, in the same space as `points`.
avg_endpoint_l2 = np.load(f'L2_{EXPERIMENT}_avg_endpoint_l2.npy')

In [None]:
# Same traversal plot as before, with the average L2 (weight-decay) endpoint
# added: PC1-PC3 on the left, PC4-PC6 on the right. `points` is split into
# `n_traversals` consecutive chunks (assumed to be one chunk per traversal --
# TODO confirm). Colour encodes the index along each chunk; a red star marks
# each chunk's mean, a blue "P" the average endpoint, and a black pentagon the
# average L2 endpoint.
n_traversals = 5
chunk_len = len(points) // n_traversals
if chunk_len == 0:
    raise ValueError('need at least {} traversal points, got {}'.format(n_traversals, len(points)))

# The PCA projection acts row by row, so projecting everything once and slicing
# afterwards gives the same coordinates as projecting each chunk separately.
points_pc = pca.transform(points_normalized)
avg_endpoint_pc = pca.transform([(avg_endpoint - points_mean) / points_std]).ravel()
avg_endpoint_l2_pc = pca.transform([(avg_endpoint_l2 - points_mean) / points_std]).ravel()

fig = plt.figure(figsize=(12, 6))
for panel, first_pc in enumerate((0, 3), start=1):
    # An explicit 3-wide slice guarantees scatter(*...) receives exactly x, y, z
    # even if NUM_PC is raised above 6.
    pcs = slice(first_pc, first_pc + 3)
    ax = fig.add_subplot(1, 2, panel, projection='3d')
    for chunk_idx, start in enumerate(range(0, len(points), chunk_len)):
        chunk_pc = points_pc[start:start + chunk_len, pcs]
        sctr = ax.scatter(*chunk_pc.T, c=np.arange(len(chunk_pc)))
        ax.scatter(*chunk_pc.mean(axis=0), color='r', marker='*', s=100,
                   label='traversal mean' if chunk_idx == 0 else '_nolegend_')
    # Neither endpoint depends on the chunk, so draw each once; alpha=0.4
    # roughly matches the opacity of five overlaid alpha=0.1 markers.
    ax.scatter(*avg_endpoint_pc[pcs], color='b', marker='P', s=200, alpha=0.4,
               label='average endpoint')
    ax.scatter(*avg_endpoint_l2_pc[pcs], color='k', marker='p', s=200, alpha=0.4,
               label='average L2 (weight-decay) endpoint')
    ax.set_xlabel('PC{}'.format(first_pc + 1))
    ax.set_ylabel('PC{}'.format(first_pc + 2))
    ax.set_zlabel('PC{}'.format(first_pc + 3))
    ax.view_init(elev=30, azim=45)
    ax.legend()

# When len(points) is a multiple of n_traversals every chunk spans the same
# 0..chunk_len-1 colour range, so one colorbar (from the last scatter) serves
# both panels. Keep the tick step >= 1: a zero step makes np.arange fail for
# chunks shorter than 10 points.
tick_step = max(chunk_len // 10, 1)
cbar = fig.colorbar(sctr, ax=fig.get_axes(), shrink=0.6, pad=0.15,
                    ticks=np.arange(0, chunk_len + 1, tick_step))
cbar.set_label('Index along traversal')