In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline
import numpy as np
import matplotlib
import pickle
from binary_mnist_pathnet import *
from sklearn.manifold import TSNE 
from sklearn.decomposition import PCA

# Notebook-wide plot styling. The seaborn call comes first because it writes
# its theme into matplotlib's rcParams; the overrides below must follow it.
sns.set(context='notebook', font_scale=1)
plt.rcParams.update({
    'figure.figsize': (10, 10),      # default figure size, in inches
    'axes.facecolor': '#f5f5f5f5',   # light grey axes background (RGBA hex)
})


# Define Visualization Functions


In [None]:
def visualize_gates(gates, images, num_test_digits=10, shuffle=True, labels=None):
    '''
    Visualize num_test_digits examples of gate activations.

    For every selected example two figures are drawn: a heatmap of the
    example's gate matrix (titled with its label) and the example's image.

    Parameters
    ----------
    gates : ndarray, indexed as gates[example, :, :]
        One 2-D gate matrix per example.
    images : sequence
        images[k] is the image belonging to example k.
    num_test_digits : int
        How many examples to draw. Capped at the number of examples in
        `gates`, so asking for too many no longer raises IndexError.
    shuffle : bool
        If True, draw distinct examples picked at random; otherwise draw
        the first num_test_digits examples in order.
    labels : sequence, optional
        labels[k] is the label of example k. When omitted, a notebook-level
        variable called `labels` is used if one exists (the historical
        behaviour); if there is none the heatmaps are left untitled.
    '''
    if labels is None:
        # Backward compatibility: this function used to read `labels`
        # straight from the notebook namespace.
        labels = globals().get('labels')

    num_examples = gates.shape[0]
    num_shown = max(0, min(num_test_digits, num_examples))

    if shuffle:
        # A permutation gives distinct indices and also works when there is
        # only a single example (randint(0, 0) used to raise here).
        indices = np.random.permutation(num_examples)[:num_shown]
    else:
        indices = np.arange(num_shown)

    for k in indices:
        plt.figure(figsize=(5, 2))
        if labels is not None:
            plt.title('Number is: '+str(labels[k]))
        sns.heatmap(gates[k,:,:])
        plt.figure(figsize=(2, 2))
        plt.imshow(images[k])


def compute_cosine_sim(gates, parameter_dict, num_test_digits, trial=2):
    '''
    Compute the cosine similarity between one gate vector and the first
    num_test_digits gate vectors, print the results and return them.

    Parameters
    ----------
    gates : array-like
        Gate activations; flattened to one vector of length
        parameter_dict['L'] * parameter_dict['M'] per example.
    parameter_dict : dict
        Must contain the integer entries 'L' and 'M'.
    num_test_digits : int
        Number of leading examples to compare against; capped at the number
        of available gate vectors.
    trial : int
        Index of the reference example.

    Returns
    -------
    (output, ranking) : tuple of ndarray
        output[i] is the cosine similarity between example i and the
        reference example; ranking lists the example indices from most to
        least similar. Earlier versions returned nothing and only printed.
    '''
    vector_len = parameter_dict['L']*parameter_dict['M']
    gate_vectors = np.asarray(gates, dtype=float).reshape(-1, vector_len)

    # Normalise every vector once. An all-zero vector has no direction, so
    # its norm is replaced by 1: its similarity then comes out as 0 instead
    # of NaN, which would otherwise sort to the top of the ranking below.
    norms = np.linalg.norm(gate_vectors, axis=1)
    safe_norms = np.where(norms > 0, norms, 1.0)
    unit_vectors = gate_vectors / safe_norms[:, np.newaxis]

    count = max(0, min(num_test_digits, unit_vectors.shape[0]))
    output = np.dot(unit_vectors[:count], unit_vectors[trial])

    print('output', output)
    ranking = np.argsort(output)[::-1]
    print('Most similar elems to trial', ranking)
    return output, ranking
    
    
def visualize_model(gates_reshaped, labels, model):
    '''
    Visualize the data X, with label y with the model (t-SNE or PCA).

    The data is transformed with model.fit_transform and the first two
    output dimensions are drawn as a scatter plot, one colour per label.

    Parameters
    ----------
    gates_reshaped : array-like, one row per example
        Data handed to model.fit_transform.
    labels : array-like of int, one entry per example
        Class label of every row. Plain lists are accepted.
    model : object with a fit_transform(X) method
        Its output must have at least two columns.
    '''
    gates_reshaped = np.asarray(gates_reshaped)
    # A plain list would make `labels == value` a single bool, not a mask.
    labels = np.asarray(labels)

    print(gates_reshaped.shape) # sanity check
    # np.unique returns the labels sorted, so the legend order is stable.
    independent_labels = np.unique(labels)
    print(independent_labels.tolist()) # sanity check

    embedded = model.fit_transform(gates_reshaped)

    # Ten colours, one per MNIST digit; larger label values wrap around.
    colors = ['b','g','#00FFFF','c','m','y','k','#0048BA','r','#F4C2C2']

    fig, ax1 = plt.subplots()

    for label in independent_labels:
        points = embedded[labels == label]
        color = colors[int(label) % len(colors)]
        ax1.scatter(points[:,0], points[:,1], s=10, c=color, label='Label: '+str(label))
    ax1.legend(loc='upper left')
    plt.show()


def visualize_comparison(gates_reshaped, labels, model, a=7, b=8):
    '''
    Scatter-plot the examples of two classes against each other.

    Only the rows labelled `a` or `b` are kept; they are transformed
    together with model.fit_transform and the first two output dimensions
    are plotted, class `a` in blue and class `b` in red.

    Parameters
    ----------
    gates_reshaped : array-like, one row per example
        Data the rows are selected from.
    labels : array-like, one entry per example
        Class label of every row. Plain lists are accepted.
    model : object with a fit_transform(X) method
        Its output must have at least two columns.
    a, b : label values
        The two classes to compare.

    Raises
    ------
    ValueError
        If no example carries label `a` or label `b`.
    '''
    gates_reshaped = np.asarray(gates_reshaped)
    # A plain list would make `labels == a` a single bool, not a mask.
    labels = np.asarray(labels)

    reshapeda = gates_reshaped[labels == a]
    reshapedb = gates_reshaped[labels == b]
    num_a = len(reshapeda) # rows [0, num_a) of the stacked array are class a

    if num_a + len(reshapedb) == 0:
        # Fail early with a clear message instead of fitting on no data.
        raise ValueError('No examples with label ' + str(a) + ' or ' + str(b))

    rab = np.concatenate((reshapeda, reshapedb), axis=0)
    xab = model.fit_transform(rab)

    fig, ax1 = plt.subplots()

    ax1.scatter(xab[:num_a,0], xab[:num_a,1], s=10, c='#0048BA', label='Label: '+str(a))
    ax1.scatter(xab[num_a:,0], xab[num_a:,1], s=10, c='r', label='Label: '+str(b))
    ax1.legend(loc='upper left')
    plt.show()
    



# Define Training, Loading, And Param Setting


In [None]:
# Set params to feed into train
def get_params(M=None, L=None, tensor_size=None, gamma=None, batch_size=None, num_batches=None,
               learning_rate=None, output_file=None):
    '''
    Bundle the given settings into a single dictionary keyed by argument name.

    Every key is always present; a setting that was not supplied maps to None.
    '''
    return {
        'M': M,
        'L': L,
        'tensor_size': tensor_size,
        'gamma': gamma,
        'batch_size': batch_size,
        'num_batches': num_batches,
        'learning_rate': learning_rate,
        'output_file': output_file,
    }

# To call a training function run: 
#train(parameter_dict=param_dict, skip_digit=8, num_gate_vectors_output=2500)

# To load what you trained:
def load_data(filename):
    '''
    Read the pickle at output/<filename>.pkl and split it into its parts.

    The unpickled object is indexed at positions 0-3: the first three
    entries are converted to numpy arrays and the fourth is passed through
    unchanged.

    Returns
    -------
    (gates, labels, images, parameter_dict) : tuple
    '''
    path = 'output/' + filename + '.pkl'
    # NOTE: unpickling can execute arbitrary code; only load files that
    # were written by a trusted source.
    with open(path, 'rb') as handle:
        payload = pickle.load(handle)

    # Entries 0-2 become arrays; entry 3 is returned as stored.
    gates, labels, images = (np.array(payload[position]) for position in range(3))
    return (gates, labels, images, payload[3])




# Encapsulate Train and Vis


In [None]:
def train_and_vis(M=None, L=None, tensor_size=None, gamma=None, batch_size=None, num_batches=None,
                  learning_rate=None, output_file=None,
                  skip_digit=8, num_gate_vectors_output=2500,
                  visualize_gates=True, visualize_gates_num=10, visualize_gates_shuffle=True,
                  compute_cosine_sim=False, compute_cosine_sim_num=10, compute_cosine_sim_trial=2,
                  visualize_model_TSNE=False,
                  visualize_model_PCA=True,
                  visualize_comparison_TSNE=True, visualize_comparison_TSNE_a=7, visualize_comparison_TSNE_b=8,
                  visualize_comparison_PCA=True, visualize_comparison_PCA_a=7, visualize_comparison_PCA_b=8):
    '''
    Train a model, load the saved results and draw the selected plots.

    Parameters
    ----------
    M, L, tensor_size, gamma, batch_size, num_batches, learning_rate, output_file
        Collected with get_params() and handed to train() as parameter_dict.
        output_file is also the name passed to load_data().
    skip_digit, num_gate_vectors_output
        Forwarded unchanged to train().
    visualize_gates, visualize_gates_num, visualize_gates_shuffle
        Whether to draw gate heatmaps, how many, and whether to pick the
        examples at random.
    compute_cosine_sim, compute_cosine_sim_num, compute_cosine_sim_trial
        Whether to print cosine similarities, over how many examples, and
        against which reference example.
    visualize_model_TSNE, visualize_model_PCA
        Whether to draw the all-classes scatter plot with t-SNE / PCA.
    visualize_comparison_TSNE, visualize_comparison_TSNE_a, visualize_comparison_TSNE_b
        Whether to draw the two-class t-SNE comparison, and for which labels.
    visualize_comparison_PCA, visualize_comparison_PCA_a, visualize_comparison_PCA_b
        Whether to draw the two-class PCA comparison, and for which labels.
    '''
    # The boolean flags `visualize_gates` and `compute_cosine_sim` have the
    # same names as the notebook-level functions and hide them inside this
    # body, so those two functions are looked up in the module namespace.
    module_scope = globals()

    # Get Param dict
    param_dict = get_params(M=M, L=L, tensor_size=tensor_size, gamma=gamma, batch_size=batch_size,
                            num_batches=num_batches, learning_rate=learning_rate, output_file=output_file)

    # Train and save
    train(parameter_dict=param_dict, skip_digit=skip_digit, num_gate_vectors_output=num_gate_vectors_output)

    # Load data
    # NOTE(review): assumes train() writes its results to
    # output/<output_file>.pkl, the path load_data() reads - confirm.
    gates, labels, images, parameter_dict = load_data(output_file)

    ## VISUALIZATIONS

    # Preprocess to make gates the correct size
    gates_reshaped = np.reshape(gates, [gates.shape[0], -1])

    # Vis Gates
    if visualize_gates:
        # visualize_gates is not handed any labels by this call and titles
        # its heatmaps from a notebook-level `labels`; publish the labels
        # just loaded so the titles belong to these gates.
        module_scope['labels'] = labels
        module_scope['visualize_gates'](gates, images, num_test_digits=visualize_gates_num,
                                        shuffle=visualize_gates_shuffle)

    # Compute cosine sim
    if compute_cosine_sim:
        module_scope['compute_cosine_sim'](gates, parameter_dict, compute_cosine_sim_num,
                                           trial=compute_cosine_sim_trial)

    # Visualize PCA
    if visualize_model_PCA:
        visualize_model(gates_reshaped, labels, PCA())

    # Visualize t-SNE
    if visualize_model_TSNE:
        visualize_model(gates_reshaped, labels, TSNE())

    # Visualize t-SNE class comparison
    if visualize_comparison_TSNE:
        visualize_comparison(gates_reshaped, labels, TSNE(),
                             a=visualize_comparison_TSNE_a, b=visualize_comparison_TSNE_b)

    # Visualize PCA class comparison
    if visualize_comparison_PCA:
        visualize_comparison(gates_reshaped, labels, PCA(),
                             a=visualize_comparison_PCA_a, b=visualize_comparison_PCA_b)