This notebook visualizes the aligned and averaged inputs/NAPs.
To this end we perform:
- visualization of GradNAvAI (gradient-weighted normalized averaging of aligned inputs), which is essentially a GradNAP of the input
- clustering of the GradNAPs in each layer, separately for phonemes and graphemes
- ERP curves of filters of highest activation (GradNAPs as line plots)
- feature visualization (AM) for the filters selected by the ERP approach (maybe sets of neurons?)

In [1]:
import sys
import numpy as np
import pickle
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from sklearn.cluster import AgglomerativeClustering
from sklearn.metrics.pairwise import pairwise_distances
from scipy.cluster import hierarchy
from sklearn.metrics import silhouette_score

import time
import os
import copy
import string
# Character inventory: space, apostrophe, a-z, '1', '2' and a second space.
# NOTE(review): appears unused in this notebook section.
char_list = [*" '", *string.ascii_lowercase, *"12 "]

def absmax(nd_array):
    """Return the symmetric interval ``(-m, m)``, where m = max |nd_array|.

    Handy as colour / axis limits centred on zero.
    """
    bound = np.abs(nd_array).max()
    return -bound, bound

def elapsed_time(t_start,unit):
    """Print the wall-clock time elapsed since ``t_start`` (a ``time.time()`` value).

    ``unit`` selects the scale: 'min' -> minutes, 'h' -> hours; any other
    value leaves the number in seconds. The unit string is printed verbatim,
    e.g. ``1.50 min elapsed``.
    """
    seconds = time.time() - t_start
    divisor = {'min': 60, 'h': 3600}.get(unit, 1)
    print('%.2f ' % (seconds / divisor) + unit + ' elapsed')

In [2]:
data_dir = "/data/asr_introspection/"

# Symbol inventories; the pickle holds three items, the third is not needed here.
# joint_list puts the graphemes first, followed by the phonemes.
vocabulary_path = data_dir + "vocabularies.pkl"
with open(vocabulary_path, "rb") as input_file:
    graphemes_list, phonemes_list, _ = pickle.load(input_file)
joint_list = graphemes_list + phonemes_list

In [3]:
# Load the per-layer sums (summation[layer_id][value_set][char_id, :, :]) and
# the per-symbol sample counts (n_samples[char_id]).
summation_path = data_dir + 'summation_abs_actxgrad_all_noblank.pkl'
with open(summation_path, 'rb') as input_file:
    summation, n_samples = pickle.load(input_file)

# Normalizer: mean acts/grads over the first 27 symbols only (characters only).
# It should be similar to the blank-label results -- still to be checked.
normalizer = copy.deepcopy(summation)
for layer_id in range(12):
    for value_set in ['acts', 'grads']:
        char_sums = np.sum(summation[layer_id][value_set][:27, :, :], 0)
        normalizer[layer_id][value_set] = char_sums / np.sum(n_samples[:27])

# Per-symbol means, and the same means minus the normalizer. Symbols without
# samples keep the values copied from summation.
averages = copy.deepcopy(summation)
normalized_averages = copy.deepcopy(summation)
for layer_id in range(12):
    layer_sums = summation[layer_id]
    layer_avgs = averages[layer_id]
    layer_norm_avgs = normalized_averages[layer_id]
    layer_baseline = normalizer[layer_id]
    for char_id in range(67):
        count = n_samples[char_id]
        if not count > 0:
            continue
        for value_set in ['acts', 'grads']:
            layer_avgs[value_set][char_id, :, :] = layer_sums[value_set][char_id, :, :] / count
            # Read back from the stored average so the subtraction sees exactly
            # the value kept in `averages`.
            layer_norm_avgs[value_set][char_id, :, :] = layer_avgs[value_set][char_id, :, :] - layer_baseline[value_set]

In [104]:
def draw_heatmap_on_axis(plot_mat, axes, row_id, col_id, cmap_type, aspect):
    """Draw ``plot_mat`` as a pcolor heatmap on ``axes[row_id, col_id]``.

    cmap_type:
        'div'  -> bwr colormap, colour limits symmetric around 0.
        'seq'  -> white-to-red colormap, colour limits (0, max |plot_mat|).
        other  -> white-to-red colormap, symmetric colour limits.
    ``aspect`` is forwarded to ``Axes.set_aspect``.
    """
    ax = axes[row_id, col_id]
    if cmap_type != 'div':
        cmap = mpl.colors.LinearSegmentedColormap.from_list("whitered", [(1, 1, 1), (1, 0, 0)], N=100)
    else:
        cmap = plt.cm.bwr
    mesh = ax.pcolor(plot_mat, cmap=cmap)
    ax.set_aspect(aspect)
    low, high = absmax(plot_mat)
    if cmap_type == 'seq':
        low = 0
    mesh.set_clim((low, high))

def get_responsive_units(act_mat, n_top, absolute=False):
    """Select the units (rows of ``act_mat``) with the strongest summed response.

    Args:
        act_mat: 2-D numpy array; rows are units and the sums run over axis 1
            (the time axis in the line plots).
        n_top: number of units returned per criterion.
        absolute: choose the ranking criterion (see Returns).

    Returns:
        absolute=True: int32 array of the ``n_top`` rows with the largest
            sum of |act|, in descending order. Each index is multiplied by the
            sign of the row's plain sum, so a negative value marks a
            predominantly negative unit. A row whose plain sum is exactly 0
            is reported with a positive sign rather than collapsing to index 0.
            Unit 0 can never carry a negative sign in this encoding.
        absolute=False: ``2 * n_top`` indices. These are the ``n_top`` rows
            with the largest positive area, followed by the ``n_top`` rows with
            the largest negative area.
    """
    if absolute:
        ranked = np.argsort(np.sum(np.abs(act_mat), 1))[::-1][:n_top]
        signs = np.sign(np.sum(act_mat, 1))[ranked]
        # np.sign gives 0 for a zero net response; multiplying by 0 would turn
        # every such unit into "unit 0", so treat it as positive instead.
        signs = np.where(signs == 0, 1, signs)
        return (signs * ranked).astype('int32')

    pos_auc = np.sum(np.maximum(act_mat, 0), 1)
    neg_auc = -np.sum(np.minimum(act_mat, 0), 1)
    return np.concatenate((np.argsort(pos_auc)[::-1][:n_top],
                           np.argsort(neg_auc)[::-1][:n_top]))

def steps_to_seconds(s):
    """Convert a duration ``s`` given in seconds into input steps.

    Despite the name, the mapping is seconds -> steps. The 206-step input
    window spans 2.075 s, so the result is ``s * 206 / 2.075``.
    """
    n_steps, window_seconds = 206, 2.075
    return s * n_steps / window_seconds

# Layers to plot (use np.arange(12) for all of them).
layers = [1, 2]
# Extent of each layer's time axis (index = layer id); used for panel widths and x limits.
layer_ranges = [206, 80, 74, 68, 62, 56, 50, 44, 38, 32, 1, 1, 1]
n_rows = 2
# Symbols to plot (use np.arange(len(joint_list)) for all of them).
character_id_list = [1]
n_responsive_units = 5
# Signed responsive-unit indices per symbol, position in `layers` and rank;
# filled by the line-plot loop below.
responsive_units_array = np.zeros((len(joint_list), len(layers), n_responsive_units), dtype='int32')


figure_dir = "/project/asr_introspection/figures/"
fig_format = "svg"


def _grad_nap(layer_id, char_id):
    """Return ``(norm_acts, gradNAP)`` for one layer and one symbol.

    norm_acts is the normalized average activation map. gradNAP is norm_acts
    weighted by the average gradient magnitude rescaled to a maximum of 1.
    gradNAP is all zeros when the symbol has no samples.
    """
    norm_acts = normalized_averages[layer_id]['acts'][char_id, :, :]
    if n_samples[char_id] > 0:
        absgrads = np.abs(averages[layer_id]['grads'][char_id, :, :])
        return norm_acts, norm_acts * (absgrads / np.max(absgrads))
    return norm_acts, np.zeros_like(norm_acts)


do_values_plots = True
if do_values_plots:
    for char_id in character_id_list:
        print(joint_list[char_id] + " - " + str(n_samples[char_id]) + " samples")
        # squeeze=False keeps `axes` 2-D even when a single layer is plotted;
        # draw_heatmap_on_axis indexes it as axes[row, col].
        fig, axes = plt.subplots(nrows=n_rows, ncols=len(layers),
                                 figsize=(len(layers) * 4, n_rows * 5.5), squeeze=False)
        for col_id, layer_id in enumerate(layers):
            norm_acts, gradNAP = _grad_nap(layer_id, char_id)
            # layer_ranges lists a width of 1 for layers >= 10, so stretch those panels.
            aspect = 1 if layer_id < 10 else 1 / 50
            draw_heatmap_on_axis(norm_acts, axes, 0, col_id, 'div', aspect)
            draw_heatmap_on_axis(gradNAP, axes, 1, col_id, 'div', aspect)

        plt.tight_layout()
        plt.savefig(figure_dir + "gradNAP_plots_char_" + joint_list[char_id] + "." + fig_format,
                    dpi=300,
                    format=fig_format,
                    bbox_inches='tight')
        plt.close(fig)

do_lineplots = False
if do_lineplots:
    # Line-plot panels exist only for layers 1-9, in the order they appear in
    # `layers`. This replaces the old hard-coded len(layers)-3 columns, which
    # only worked for layers == np.arange(12).
    line_layers = [lyr for lyr in layers if 0 < lyr < 10]
    # Ticks are spaced 0.1 s of input steps apart but labelled in 0.2 s steps.
    # NOTE(review): this assumes one step of layers 1-9 spans two input steps -- confirm.
    step_width = steps_to_seconds(0.1)
    xtick_labels = ["-0.6", "-0.4", "-0.2", "0", "+0.2", "+0.4", "+0.6"]
    for char_id in character_id_list:
        fig = plt.figure(figsize=(80, 8))
        if line_layers:
            grid = plt.GridSpec(1, len(line_layers), hspace=0.2, wspace=0.2,
                                width_ratios=np.array([layer_ranges[lyr] for lyr in line_layers]) / 32)
        print(joint_list[char_id] + " - " + str(n_samples[char_id]) + " samples")
        for col_id, layer_id in enumerate(layers):
            norm_acts, gradNAP = _grad_nap(layer_id, char_id)
            gradNAP = gradNAP * 1000  # plotted in units of 1e-3

            responsive_units = get_responsive_units(gradNAP, n_responsive_units, absolute=True)
            responsive_units_array[char_id, col_id, :] = responsive_units

            if not 0 < layer_id < 10:
                continue
            ax = fig.add_subplot(grid[line_layers.index(layer_id)])
            # Builtin int: np.int was only an alias and is gone since NumPy 1.24.
            center = int(layer_ranges[layer_id] / 2)
            ax.add_line(mpl.lines.Line2D([center, center], [-10, 10], color="grey", linestyle="dashed"))
            # All units faintly in black, the most responsive ones in colour.
            ax.plot(np.transpose(gradNAP), '-', color='black', alpha=0.05)
            ax.plot(np.transpose(gradNAP[np.abs(responsive_units), :]), '-')
            ax.set_xlabel('time (s)')
            ax.set_ylabel('gradient-adjusted activation ($10^{-3}$)')
            y_low, y_high = np.array(absmax(gradNAP)) * 1.1
            ax.set_ylim(y_low, y_high)
            xtick_pos = np.arange(center - (3 * step_width), center + 1 + (3 * step_width), step_width)
            ax.set_xticks(xtick_pos)
            ax.set_xticklabels(xtick_labels)
            ax.set_xlim((0, layer_ranges[layer_id] - 1))
            ax.grid(linestyle='dotted')

        plt.savefig(figure_dir + "activation_line_plot_char_" + joint_list[char_id] + "." + fig_format,
                    dpi=300,
                    format=fig_format,
                    bbox_inches='tight')
        plt.close(fig)
                

a - 32624 samples


In [73]:
# Scratch: duration in seconds of 206 frames, presumably a 400-sample window
# plus 205 hops of 160 samples at 16 kHz -- TODO confirm the front-end settings.
((400+(205*160))/16000)

2.075

In [None]:
2.075  0.25
  206

In [74]:
# Scratch: 0.25 s converted to input steps (same formula as steps_to_seconds).
(0.25*206)/2.075

24.819277108433734

In [31]:
# Scratch: shape of the (transposed) gradNAP left over from the last plotting loop.
np.transpose(gradNAP).shape

(1, 2048)

In [11]:
# Save the signed responsive-unit indices collected by the line-plot loop,
# shape (len(joint_list), len(layers), n_responsive_units).
np.save('/data/asr_introspection/responsive_units_array_signed.npy',responsive_units_array)

In [53]:
# clustering based on gradNAPs

def plot_cluster_heatmap_on_gridspec(profile_values, labels, cluster_metric, linkage_method, score_percentile, gs, gs_from, clim, rotate_x, plot_colorbar):
    """Draw a dendrogram plus a dendrogram-ordered distance heatmap into a GridSpec.

    Column ``gs_from`` receives the dendrogram and ``gs_from + 1`` the
    heatmap. When ``plot_colorbar`` is True, ``gs_from + 2`` receives a colorbar.

    Args:
        profile_values: 2-D array, one row per item to cluster.
        labels: item labels aligned with the rows of profile_values.
        cluster_metric: metric for both the distance heatmap and the linkage.
        linkage_method: ``scipy.cluster.hierarchy.linkage`` method.
        score_percentile: percentile of the merge heights used as the
            dendrogram colour threshold.
        gs, gs_from: target GridSpec and its first column index.
        clim: (low, high) colour limits shared with the other heatmaps.
        rotate_x: rotate the x tick labels by 90 degrees.
        plot_colorbar: also draw the colorbar column.
    """
    distance_mat = pairwise_distances(profile_values, metric=cluster_metric)
    linkage = hierarchy.linkage(profile_values, metric=cluster_metric, method=linkage_method, optimal_ordering=False)
    leaves = hierarchy.dendrogram(linkage, no_plot=True)['leaves']
    # Reversed leaf order, presumably so that the heatmap rows (drawn top-down)
    # line up with the left-oriented dendrogram.
    order = np.asarray(leaves, dtype='int32')[::-1]

    plt.subplot(gs[gs_from])
    merge_threshold = np.percentile(linkage[:, 2], score_percentile)
    hierarchy.dendrogram(linkage, orientation='left', no_labels=True, color_threshold=merge_threshold)
    plt.axis('off')

    ax = plt.subplot(gs[gs_from + 1])
    plt.imshow(distance_mat[order, :][:, order[::-1]], cmap='hot')
    plt.clim(clim)
    ordered_labels = np.asarray(list(labels))[order]
    tick_positions = range(len(ordered_labels))
    ax.xaxis.set_ticks(tick_positions)
    ax.xaxis.set_ticklabels(ordered_labels[::-1])
    plt.xticks(fontsize=8)
    if rotate_x:
        plt.xticks(rotation=90)
    ax.yaxis.set_ticks(tick_positions)
    ax.yaxis.set_ticklabels(ordered_labels)
    plt.yticks(fontsize=8)
    ax.yaxis.tick_left()

    if not plot_colorbar:
        return
    plt.subplot(gs[gs_from + 2])
    # Hidden dummy image: it only provides plt.colorbar() with a mappable
    # that carries the shared colour limits.
    plt.imshow([[clim[1] + 1]], cmap='hot')
    plt.axis('off')
    plt.clim(clim)
    plt.colorbar()

def plot_cluster_heatmaps(profile_values_list, labels_list, cluster_metric='euclidean', linkage_method='complete', score_percentile=80, figure_dir=None, layer_id=None, fig_format='png'):
    """Plot side-by-side clustering heatmaps sharing one colour scale.

    Each entry of ``profile_values_list`` (with its labels from
    ``labels_list``) gets a dendrogram and a distance heatmap. The last one
    also gets rotated x labels and a colorbar. The figure is shown when
    ``figure_dir`` is None; otherwise it is saved as
    ``clustering_heatmap_threshold<score_percentile>_layer<layer_id>.<fig_format>``
    in ``figure_dir`` and closed.

    Raises:
        ValueError: if the two lists have different lengths.
    """
    n_profiles = len(profile_values_list)
    if n_profiles != len(labels_list):
        raise ValueError("list of profile values (" + str(len(profile_values_list)) + ") must be of same length as list of label sets (" + str(len(labels_list)) + ")")

    # Shared upper colour limit: the largest pairwise distance in any profile set.
    max_distance = 0
    for profile in profile_values_list:
        max_distance = max(max_distance, np.max(pairwise_distances(profile, metric=cluster_metric)))
    clim = [0, max_distance]

    # Per profile: a narrow dendrogram column and a wide heatmap column; one colorbar column at the end.
    width_ratios = [1, 10] * n_profiles + [2.5]
    gs = gridspec.GridSpec(1, 2 * n_profiles + 1, width_ratios=width_ratios)
    plt.figure(figsize=(np.sum(width_ratios) / 1.55, 6))

    for profile_id, (profile_values, labels) in enumerate(zip(profile_values_list, labels_list)):
        is_last = profile_id == n_profiles - 1
        plot_cluster_heatmap_on_gridspec(profile_values=profile_values,
                                         labels=labels,
                                         cluster_metric=cluster_metric,
                                         linkage_method=linkage_method,
                                         score_percentile=score_percentile,
                                         gs=gs,
                                         gs_from=profile_id * 2,
                                         clim=clim,
                                         rotate_x=is_last,
                                         plot_colorbar=is_last)

    if figure_dir is None:
        plt.tight_layout()
        plt.show()
        return
    file_name = "clustering_heatmap_threshold" + str(score_percentile) + "_layer" + str(layer_id).zfill(2) + "." + fig_format
    plt.savefig(figure_dir + file_name,
                dpi=300,
                format=fig_format,
                bbox_inches='tight')
    plt.close()

def score(dists, Z,score_percentile):
    """Silhouette score of the flat clustering obtained by cutting linkage ``Z``.

    The cut height is the ``score_percentile``-th percentile of the merge
    heights ``Z[:, 2]``. ``dists`` is the precomputed square distance matrix
    used for the silhouette. Returns 0 when the cut yields a single cluster,
    because a silhouette is undefined in that case.
    """
    cut_height = np.percentile(Z[:, 2], score_percentile)
    cluster_ids = hierarchy.fcluster(Z, cut_height, criterion='distance')
    if np.unique(cluster_ids).size > 1:
        return silhouette_score(dists, cluster_ids, metric='precomputed')
    return 0
    
def silhouette_scores(profile_values, cluster_metric='euclidean', linkage_method='complete', score_percentiles=(80,)):
    """Silhouette scores of one hierarchical clustering cut at several heights.

    Args:
        profile_values: 2-D array, one row per item to cluster.
        cluster_metric: metric for the distance matrix and the linkage.
        linkage_method: ``scipy.cluster.hierarchy.linkage`` method.
        score_percentiles: sequence of percentiles of the merge heights at
            which the tree is cut (see ``score``). The default is an immutable
            tuple rather than a shared mutable list.

    Returns:
        1-D float array with one score per entry of ``score_percentiles``
        (0 where the cut yields a single cluster).
    """
    distance_mat = pairwise_distances(profile_values, metric=cluster_metric)
    linkage = hierarchy.linkage(profile_values, metric=cluster_metric, method=linkage_method, optimal_ordering=False)
    silhouettes = np.zeros(len(score_percentiles))
    for i, percentile in enumerate(score_percentiles):
        silhouettes[i] = score(distance_mat, linkage, percentile)
    return silhouettes
    

layers = np.arange(12)
score_percentiles = np.arange(75, 96, 5)

# silhouettes[s, l, p]: silhouette score for symbol set s (0 = graphemes,
# 1 = phonemes), the l-th entry of `layers` and threshold score_percentiles[p].
silhouettes = np.zeros((2, len(layers), len(score_percentiles)))

figure_dir = "/project/asr_introspection/figures/"
fig_format = "png"
# Set to True to also save one clustering heatmap per layer and percentile.
plot_heatmaps = False

# Heatmap labels do not depend on the layer, so they are built once:
# graphemes are the first 27 symbols (upper-cased), phonemes start at index 28.
# NOTE(review): symbol 27 belongs to neither set -- confirm this is intended.
grapheme_list = [grapheme.upper() for grapheme in joint_list[:27]]
phoneme_list = joint_list[28:]

for l_id, layer_id in enumerate(layers):
    norm_acts = normalized_averages[layer_id]['acts']
    absgrads = averages[layer_id]['grads']
    # One flattened gradNAP per symbol (zeros for symbols without samples).
    gradNAPs_flat = np.zeros((len(joint_list), norm_acts.shape[1] * norm_acts.shape[2]))
    for char_id in range(len(joint_list)):
        if n_samples[char_id] > 0:
            char_absgrads = np.abs(absgrads[char_id, :, :])
            gradient_mask = char_absgrads / np.max(char_absgrads)
            gradNAPs_flat[char_id] = np.reshape(norm_acts[char_id, :, :] * gradient_mask, -1)

    # Scale each symbol set so that its largest absolute value is 1.
    gradNAPs_flat_graphemes = gradNAPs_flat[:27, :]
    gradNAPs_flat_graphemes = gradNAPs_flat_graphemes / np.max(np.abs(gradNAPs_flat_graphemes))
    gradNAPs_flat_phonemes = gradNAPs_flat[28:, :]
    gradNAPs_flat_phonemes = gradNAPs_flat_phonemes / np.max(np.abs(gradNAPs_flat_phonemes))

    if plot_heatmaps:
        for p in score_percentiles:
            plot_cluster_heatmaps(profile_values_list=[gradNAPs_flat_graphemes, gradNAPs_flat_phonemes],
                                  labels_list=[grapheme_list, phoneme_list],
                                  cluster_metric='euclidean',
                                  score_percentile=p,
                                  figure_dir=figure_dir,
                                  layer_id=layer_id,
                                  fig_format=fig_format)

    silhouettes[0, l_id, :] = silhouette_scores(gradNAPs_flat_graphemes, score_percentiles=score_percentiles)
    silhouettes[1, l_id, :] = silhouette_scores(gradNAPs_flat_phonemes, score_percentiles=score_percentiles)


font = {'family': 'sans',
        'weight': 'regular',
        'size': 22}
# NOTE: this changes the global matplotlib font for every later figure.
mpl.rc('font', **font)

titles = ['graphemes', 'phonemes', 'averaged']
n_layers = len(layers)
# x tick labels: 'in' for the input layer, then layer numbers, with every
# second label from index 2 on blanked to reduce clutter.
xticklabs = ['in'] + [str(layer_number) for layer_number in range(1, n_layers)]
for j in range(2, n_layers - 1, 2):
    xticklabs[j] = ''
y_top = np.max(silhouettes) + 0.05
y_ticks = np.arange(0, y_top, 0.1)

plt.figure(figsize=(16, 5))
for i, title in enumerate(titles):
    ax = plt.subplot(1, 3, i + 1)
    if i < 2:
        # One curve per threshold percentile.
        plt.plot(silhouettes[i, :, :], 'o-')
        plt.legend(score_percentiles, ncol=2, columnspacing=0.5, labelspacing=0, loc='upper left', fontsize=20)
    else:
        # Mean over the percentiles, one curve per symbol set.
        plt.plot(np.transpose(np.mean(silhouettes, 2)), 'o-')
        plt.legend(['graphemes', 'phonemes'], ncol=1, columnspacing=0.5, labelspacing=0, loc='upper left', fontsize=20)
    plt.xlabel('layer')
    plt.yticks(y_ticks)
    if i == 0:
        plt.ylabel('silhouette score')
    else:
        plt.ylabel('')
        # Same ticks as the first panel, without repeating the labels.
        ax.tick_params(labelleft=False)
    plt.ylim([0, y_top])
    plt.xticks(range(n_layers), xticklabs)
    plt.grid(linestyle='dotted')
    ttl = plt.title(title)
    ttl.set_position([.5, 1.01])

plt.tight_layout()
plt.savefig(figure_dir + "score_curves." + fig_format,
            dpi=300,
            format=fig_format,
            bbox_inches='tight')

plt.close()

# plt.show()

    
    

In [45]:
# Scratch: inspect the x tick labels used in the silhouette score plot.
xticklabs

['', '1', '', '3', '', '5', '', '7', '', '9', '', '11']

In [46]:
# Scratch: check which indices np.arange(0, 11, 2) yields.
np.arange(0,11,2)

array([ 0,  2,  4,  6,  8, 10])

In [49]:
# Scratch: IPython help for plt.title (notebook syntax, not plain Python).
?plt.title