In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline
import numpy as np
import matplotlib
import pickle
from binary_mnist_pathnet import *
from sklearn.manifold import TSNE 
from sklearn.decomposition import PCA

# Notebook-wide plot styling. Order matters: sns.set() is applied first and
# the two rcParams overrides below are layered on top of it.
sns.set(context='notebook', font_scale=1)
matplotlib.rcParams['figure.figsize'] = (10, 10)  # default figure size (width, height)
plt.rcParams['axes.facecolor'] = '#f5f5f5f5'  # axes background colour (8-digit hex)


# Define Visualization Functions


In [None]:
def visualize_gates(gates, images, num_test_digits=10, shuffle=True, labels=None):
    '''
    Plot gate-activation heat maps together with the images they belong to.

    For every selected example two figures are created: a seaborn heat map of
    that example's activations (titled with its label when one is available)
    and a small figure showing the example's image.

    Parameters
    ----------
    gates : array-like
        Activations indexed by example along the first axis. Each example's
        slice is promoted to at least 2-D before it is drawn.
    images : array-like
        Indexed like `gates`; `images[k]` is handed to `imshow`.
    num_test_digits : int
        Number of examples to plot. Capped at the number of examples present.
    shuffle : bool
        If True, plot distinct randomly chosen examples; otherwise plot the
        first `num_test_digits` examples in order.
    labels : array-like, optional
        Indexed like `gates`. When omitted, a module-level variable named
        `labels` is used if one exists (the behaviour of earlier versions of
        this function); if there is none, the heat maps are left untitled.
    '''
    gates = np.asarray(gates)

    if labels is None:
        # Backward compatibility: this function used to read a global `labels`.
        # Prefer passing `labels` explicitly so the titles cannot go stale.
        labels = globals().get('labels')

    num_examples = gates.shape[0]
    count = max(0, min(num_test_digits, num_examples))
    if shuffle:
        # A permutation yields distinct indices and also works when there is
        # only a single example.
        indices = np.random.permutation(num_examples)[:count]
    else:
        indices = range(count)

    for k in indices:
        _, heat_ax = plt.subplots(figsize=(5, 2))
        if labels is not None:
            heat_ax.set_title('Number is: ' + str(labels[k]))
        sns.heatmap(np.atleast_2d(gates[k]), ax=heat_ax)

        _, image_ax = plt.subplots(figsize=(2, 2))
        image_ax.imshow(images[k])


def compute_cosine_sim(gates, parameter_dict, num_test_digits, trial=2):
    '''
    Compare one example's gate vector with the leading examples by cosine similarity.

    `gates` is flattened to rows of length `parameter_dict['L'] * parameter_dict['M']`.
    Row `trial` is the reference; it is compared against rows
    `0 .. num_test_digits - 1`. The similarities and the row indices ordered
    from most to least similar are printed.

    Parameters
    ----------
    gates : array-like
        Data whose total size is a multiple of `L * M`.
    parameter_dict : dict
        Must provide the integer entries 'L' and 'M'.
    num_test_digits : int
        Number of leading rows to compare. Capped at the number of rows.
    trial : int
        Row index of the reference vector.

    Returns
    -------
    numpy.ndarray
        The cosine similarities, one per compared row. A row (or reference)
        with zero norm has similarity 0.0 rather than NaN.

    Raises
    ------
    IndexError
        If `trial` is not a valid row index.
    '''
    vector_length = parameter_dict['L'] * parameter_dict['M']
    gate_vectors = np.reshape(gates, (-1, vector_length))

    count = max(0, min(num_test_digits, gate_vectors.shape[0]))
    reference = gate_vectors[trial]
    candidates = gate_vectors[:count]

    # The reference norm is loop-invariant, so everything is computed in one
    # vectorised pass instead of re-normalising per row.
    norm_products = np.linalg.norm(candidates, axis=1) * np.linalg.norm(reference)
    dot_products = np.dot(candidates, reference)

    # Entries whose norm product is zero keep the initial 0.0 instead of
    # becoming NaN, which would otherwise distort the ranking below.
    output = np.zeros(count)
    np.divide(dot_products, norm_products, out=output, where=norm_products > 0)

    print('output', output)
    elems = np.argsort(output)
    print('Most similar elems to trial', elems[::-1])
    return output
    
    
def visualize_model(gates_reshaped, labels, model):
    '''
    Project the data with `model` and scatter the first two components by label.

    Parameters
    ----------
    gates_reshaped : array
        One row per example; passed unchanged to `model.fit_transform`.
    labels : array-like
        One label per row. Labels must be convertible to `int`, because the
        label value selects the plotting colour.
    model : object
        Any object exposing `fit_transform` (e.g. t-SNE or PCA) whose result
        has at least two columns.
    '''
    # A plain list would make `labels == value` a scalar instead of a mask.
    labels = np.asarray(labels)

    print(gates_reshaped.shape) # sanity check
    # np.unique returns the labels sorted, so the legend order is stable.
    unique_labels = np.unique(labels)
    print(unique_labels.tolist()) # sanity check

    projected = model.fit_transform(gates_reshaped)

    # Ten colours, one per MNIST digit; other label values wrap around.
    colors = ['b','g','#00FFFF','c','m','y','k','#00A4BA','r','#F4C2C2']

    _, ax1 = plt.subplots()
    for label in unique_labels:
        points = projected[labels == label]
        ax1.scatter(points[:, 0], points[:, 1], s=10,
                    color=colors[int(label) % len(colors)],
                    label='Label: ' + str(label))
    ax1.legend(loc='upper left')
    plt.show()


def visualize_comparison(gates_reshaped, labels, model, digits=(7, 8)):
    '''
    Same as visualize_model, but only the examples labelled with `digits` are used.

    The model is fitted on the selected rows only, so the projection differs
    from the one obtained on the full data set.

    Parameters
    ----------
    gates_reshaped : array-like
        One row per example.
    labels : array-like
        One label per row.
    model : object
        Any object exposing `fit_transform` whose result has at least two
        columns.
    digits : sequence of int
        Label values to keep. Each value also selects its plotting colour.
        If no example carries any of these labels, a notice is printed and
        nothing is plotted.
    '''
    gates_reshaped = np.asarray(gates_reshaped)
    labels = np.asarray(labels)
    digits = list(digits)

    # One membership test replaces OR-ing a mask per digit, and it also copes
    # with an empty `digits` sequence.
    selected = np.isin(labels, digits)
    if not selected.any():
        print('No examples found for digits', digits)
        return

    projected = model.fit_transform(gates_reshaped[selected, :])
    selected_labels = labels[selected]

    # Ten colours, one per MNIST digit; other label values wrap around.
    colors = ['b','g','#00FFFF','c','m','y','k','#00A4BA','r','#F4C2C2']

    _, ax1 = plt.subplots()
    for digit in digits:
        points = projected[selected_labels == digit]
        ax1.scatter(points[:, 0], points[:, 1], s=10,
                    color=colors[int(digit) % len(colors)],
                    label='Label: ' + str(digit))
    ax1.legend(loc='upper left')
    plt.show()


# Define Training, Loading, And Param Setting


In [None]:
# Set params to feed into train
def get_params(M=None, L=None, tensor_size=None, gamma=None, batch_size=None, num_batches=None, 
               learning_rate=None, output_file=None):
    '''
    Collect the given settings into the dictionary that is fed to train.

    Every argument is stored under a key equal to its own name; arguments
    that are not supplied are stored as None.
    '''
    return {
        'M': M,
        'L': L,
        'tensor_size': tensor_size,
        'gamma': gamma,
        'batch_size': batch_size,
        'num_batches': num_batches,
        'learning_rate': learning_rate,
        'output_file': output_file,
    }

# To call a training function run: 
#train(parameter_dict=param_dict, skip_digits=[8], num_gate_vectors_output=2500)

# To load what you trained:
def load_data(filename):
    '''
    Read 'output/<filename>.pkl' and return its first four entries.

    Entries 0, 1 and 2 of the unpickled object are converted to numpy arrays
    and returned as gates, labels and images; entry 3 is returned unchanged
    as the parameter dictionary.

    Returns
    -------
    tuple
        (gates, labels, images, parameter_dict)
    '''
    pickle_path = 'output/' + filename + '.pkl'
    with open(pickle_path, 'rb') as handle:
        # pickle can execute arbitrary code: only load files you produced.
        payload = pickle.load(handle)

    gates, labels, images = (np.array(payload[position]) for position in range(3))
    return (gates, labels, images, payload[3])




# Encapsulate Training and Visualization


In [2]:
def _show_activation_examples(values, labels, images, num_examples=10):
    '''
    Plot a heat map and the input image for randomly chosen examples.

    Labels are passed in explicitly so that the titles always belong to the
    arrays being plotted.
    '''
    values = np.asarray(values)
    count = max(0, min(num_examples, values.shape[0]))
    for k in np.random.permutation(values.shape[0])[:count]:
        _, heat_ax = plt.subplots(figsize=(5, 2))
        heat_ax.set_title('Number is: ' + str(labels[k]))
        # atleast_2d lets a per-example vector be drawn as a one-row heat map.
        sns.heatmap(np.atleast_2d(values[k]), ax=heat_ax)

        _, image_ax = plt.subplots(figsize=(2, 2))
        image_ax.imshow(images[k])


def _visualize_saved_outputs(filename, compare_digits, show_examples=False, cosine_sim=False,
                             pca=False, tsne=False, pca_compare=False, tsne_compare=False):
    '''
    Load one saved output file and run the requested visualizations on it.

    The file is not opened at all when every option is switched off.
    '''
    if not (show_examples or cosine_sim or pca or tsne or pca_compare or tsne_compare):
        return

    values, labels, images, parameter_dict = load_data(filename)

    if show_examples:
        _show_activation_examples(values, labels, images, num_examples=10)
    if cosine_sim:
        compute_cosine_sim(values, parameter_dict, num_test_digits=10, trial=2)

    # One flat row per example, as expected by the projection models.
    flattened = np.reshape(values, [values.shape[0], -1])
    if pca:
        visualize_model(flattened, labels, PCA())
    if tsne:
        visualize_model(flattened, labels, TSNE())
    if pca_compare:
        visualize_comparison(flattened, labels, PCA(), digits=compare_digits)
    if tsne_compare:
        visualize_comparison(flattened, labels, TSNE(), digits=compare_digits)


def train_and_vis(M=None, L=None, tensor_size=None, gamma=None, batch_size=None, num_batches=None, 
                  learning_rate=None, output_file=None,
                  
                  # Train options
                  skip_digits=[8], num_gate_vectors_output=2500,
                  
                  # Visualize gates opts
                  visualize_gates=True, visualize_features=False, visualize_preds=False,
                  
                  # Compute cosin sim function
                  compute_cosine_sim_gates=False, compute_cosine_sim_featues=False, compute_cosine_sim_preds=False,
                  
                  # Visualize PCA opts
                  visualize_gates_PCA=True, visualize_features_PCA=False, visualize_preds_PCA=False,
                  # Visualize t-SNE opts
                  visualize_gates_TSNE=True, visualize_features_TSNE=False, visualize_preds_TSNE=False,
                  
                  # Visualize comparison opts for PCA and t-SNE
                  model_compare_digits=[1,3,5,7,9],
                  # Visualize PCA comparison opts
                  visualize_gates_PCA_compare=True, visualize_features_PCA_compare=False, visualize_preds_PCA_compare=False,
                  # Visualize t-SNE opts
                  visualize_gates_TSNE_compare=True, visualize_features_TSNE_compare=False, visualize_preds_TSNE_compare=False,

                  # Correctly spelled alias of compute_cosine_sim_featues
                  compute_cosine_sim_features=None
                 ):
    '''
    Train the network, then visualize the saved gates, features and predictions.

    Parameters
    ----------
    M, L, tensor_size, gamma, batch_size, num_batches, learning_rate, output_file
        Collected with get_params and passed to train as `parameter_dict`.
        `output_file` is also the base name of the files read back afterwards.
    skip_digits, num_gate_vectors_output
        Forwarded to train unchanged.
    visualize_gates, visualize_features, visualize_preds : bool
        Plot heat maps and input images for ten random examples.
    compute_cosine_sim_gates, compute_cosine_sim_featues, compute_cosine_sim_preds : bool
        Print cosine similarities via compute_cosine_sim. The misspelled
        `compute_cosine_sim_featues` is kept for existing callers;
        `compute_cosine_sim_features` is an alias, and either one enables it.
    visualize_*_PCA, visualize_*_TSNE : bool
        Scatter all examples after a PCA / t-SNE projection.
    visualize_*_PCA_compare, visualize_*_TSNE_compare : bool
        Same, restricted to the labels in `model_compare_digits`.
    model_compare_digits : sequence of int
        Labels used by the *_compare plots.

    Files read (through load_data, i.e. from the 'output/' directory):
    '<output_file>', '<output_file>__feature_layer_outputs' and
    '<output_file>__softmax_outputs'. A file is only opened when at least one
    option of its group is enabled.
    '''
    # Get the param dict
    param_dict = get_params(M=M, L=L, tensor_size=tensor_size, gamma=gamma, batch_size=batch_size, 
                            num_batches=num_batches, learning_rate=learning_rate, output_file=output_file)

    # Train the network and save data
    train(parameter_dict=param_dict, skip_digits=skip_digits, num_gate_vectors_output=num_gate_vectors_output)

    # NOTE(review): the original code read an undefined `saved_filename`.
    # This assumes train writes its results under `output_file` - confirm
    # against binary_mnist_pathnet.train.
    saved_filename = output_file

    # Either spelling of the features cosine-similarity flag enables it.
    cosine_sim_features = bool(compute_cosine_sim_featues) or bool(compute_cosine_sim_features)

    # The boolean options share their names with plotting functions, so the
    # actual work is done in helpers where those names are not shadowed.

    ########## GATES ##########################################
    _visualize_saved_outputs(
        saved_filename, model_compare_digits,
        show_examples=visualize_gates, cosine_sim=compute_cosine_sim_gates,
        pca=visualize_gates_PCA, tsne=visualize_gates_TSNE,
        pca_compare=visualize_gates_PCA_compare, tsne_compare=visualize_gates_TSNE_compare)

    ########## FEATURES #######################################
    _visualize_saved_outputs(
        saved_filename + '__feature_layer_outputs', model_compare_digits,
        show_examples=visualize_features, cosine_sim=cosine_sim_features,
        pca=visualize_features_PCA, tsne=visualize_features_TSNE,
        pca_compare=visualize_features_PCA_compare, tsne_compare=visualize_features_TSNE_compare)

    ########## PREDS ##########################################
    _visualize_saved_outputs(
        saved_filename + '__softmax_outputs', model_compare_digits,
        show_examples=visualize_preds, cosine_sim=compute_cosine_sim_preds,
        pca=visualize_preds_PCA, tsne=visualize_preds_TSNE,
        pca_compare=visualize_preds_PCA_compare, tsne_compare=visualize_preds_TSNE_compare)