In [None]:
import os
import numpy as np
from scipy import stats
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
import seaborn as sns
import pandas as pd
from dnnbrain.io.fileio import ActivationFile
from activation_PSI import Dnn_act, sparseness

In [None]:
# Custom parameters: edit these to choose which network / stimulus set to analyse.
net = 'alexnet'       # one of: 'alexnet', 'vgg11'
dataset = 'imagenet'  # one of: 'imagenet', 'caltech256', 'caltech143'
root = os.getcwd()    # root directory of the analysis (current working directory)

In [None]:
# Prepare parameters derived from the custom settings.
net_dir = os.path.join(root, net)  # per-network directory under the analysis root
# Tab-separated label table; its 'imagenet1000' column is used further down to
# select rows when dataset == 'caltech143'.
caltech256_label = pd.read_csv(os.path.join(root, 'caltech256_label'), sep='\t')
# Value passed to Dnn_act as `stim_per_cat` when activations are averaged per category.
if dataset == 'imagenet':
    stim_per_cat = 50
elif dataset in ['caltech256', 'caltech143']:
    stim_per_cat = 80
else:
    # Fail here with a clear message instead of a NameError on `stim_per_cat`
    # in a later cell.
    raise ValueError(
        "unknown dataset '{0}': expected 'imagenet', 'caltech256' or 'caltech143'".format(dataset))

In [None]:
# Read the DNN activation file for the selected network / dataset.
# 'caltech143' has no file name of its own here: it loads the 'caltech256'
# activation file (the category means are subset by label in the PSI cell).
act_dataset = 'caltech256' if dataset == 'caltech143' else dataset
dnnact_path = os.path.join(
        net_dir, 'dnn_activation', '{0}_{1}.act.h5'.format(net, act_dataset))

# Mapping-like object keyed by layer name; the key order defines the layer order
# used by every later cell.
dnnact_alllayer = ActivationFile(dnnact_path).read()
layer_name = list(dnnact_alllayer.keys())

In [None]:
# Compute the PSI values and the activation histograms, layer by layer.
bins = 20  # number of bins for the PSI histogram over [0, 1]
psi_edges = np.linspace(0, 1, bins+1)  # bin edges for the PSI histogram
# Bin edges for the activation histogram: np.arange stops before 1, so the last
# edge is 0.99 (99 bins).
# NOTE(review): min-max scaled values above 0.99 fall outside these edges and are
# not counted by np.histogram; the plotting cell relies on the 99-bin shape, so
# confirm before changing.
act_edges = np.arange(0, 1, 0.01)
sp = []           # PSI values, one array per layer
sp_bincount = []  # PSI histogram (in percent) per layer
pdf_bin = []      # averaged activation density per layer

for layer in layer_name:
    dnnact = Dnn_act(dnnact_alllayer[layer], stim_per_cat=stim_per_cat)
    # presumably (category, unit) after dropping the last axis; verify against Dnn_act
    dnnact_catmean = dnnact.cat_mean_act()[0][:, :, 0]

    if dataset == 'caltech143':
        # keep only the rows whose 'imagenet1000' label equals the string '0'
        keep = caltech256_label['imagenet1000'] == '0'
        dnnact_catmean = dnnact_catmean[keep, :]

    # z-score along axis 0; NaNs produced by the z-scoring are replaced with 0
    dnnact_catmean_z = np.nan_to_num(stats.zscore(dnnact_catmean, 0))

    # PSI
    sparse_p = sparseness(dnnact_catmean_z.T, type='s', norm=True)
    counts = pd.cut(sparse_p, psi_edges).value_counts().values
    # express the bin counts as a percentage of the number of rows of dnnact_catmean
    sp_bincount.append(counts / dnnact_catmean.shape[0] * 100)

    sp.append(np.squeeze(sparse_p))
    print('{0} done'.format(layer))

    # pdf
    min_max_scaler = MinMaxScaler(feature_range=(0, 1))
    dnnact_catmean_z_norm = min_max_scaler.fit_transform(dnnact_catmean_z.T)

    # one density histogram per column of the scaled matrix, then averaged over columns
    dist_bin = []
    for col in range(dnnact_catmean_z_norm.shape[-1]):
        density = np.histogram(dnnact_catmean_z_norm[:, col], bins=act_edges, density=True)[0]
        dist_bin.append(density)
    pdf_bin.append(np.asarray(dist_bin).mean(0))

# stack per-layer results so that each column corresponds to one layer
sp_bincount = np.asarray(sp_bincount).T
pdf_bin = np.asarray(pdf_bin).T
sp_median = np.array([np.nanmedian(layer_sp) for layer_sp in sp])
np.save(os.path.join(net_dir, 'PSI_{0}.npy'.format(dataset)), sp)

In [None]:
# Plot Fig 1A or SS Fig 2A: PSI histograms, one line per layer.
fig, axes = plt.subplots(nrows=2, ncols=1, figsize=[4,6])
layer_legend = [layer.split('_')[0] for layer in layer_name]
psi_x = np.linspace(0, 1, bins+1)[1:]  # right edge of each PSI bin
n_conv = len(layer_legend[:-2])        # every layer except the last two
# Conv (top panel)
conv_colors = sns.color_palette('Blues', n_colors=n_conv)
for i in range(n_conv):
    axes[0].plot(psi_x, sp_bincount[:, i], c=conv_colors[i])
axes[0].set_xlim((0,0.6))
axes[0].legend(layer_legend[:-2])
# FC (bottom panel): the last two layers
fc_colors = sns.color_palette('Oranges', n_colors=2)
for i in range(2):
    axes[1].plot(psi_x, sp_bincount[:, i+n_conv], c=fc_colors[i])
axes[1].set_xlim((0,0.6))
axes[1].legend(layer_legend[-2:])

In [None]:
# Plot Fig 1B or SS Fig 2B: median PSI per layer.
fig, axes = plt.subplots(nrows=2, ncols=1, figsize=[4,6])
ax_conv, ax_fc = axes
y_range = (0,0.4)  # shared y-axis range for both panels
# Conv: every layer except the last two
ax_conv.plot(layer_legend[:-2], sp_median[:-2], c='tab:blue')
ax_conv.set_ylim(y_range)
# FC: the last two layers
ax_fc.plot(layer_legend[-2:], sp_median[-2:], c='orange')
ax_fc.set_ylim(y_range)

In [None]:
# Trend test: Kendall's tau between layer order and the PSI values pooled over layers.
sp_alllayer = np.asarray(sp).reshape(-1)
# 1-based layer index, repeated once per PSI value of a layer so that it lines up
# with the row-major flattening above (all layers hold sp[0].shape values).
h_index = np.repeat(np.arange(len(sp))+1 , sp[0].shape)
# sp_median is computed with np.nanmedian, so NaN PSI values are tolerated there.
# kendalltau propagates NaN by default, which would turn the whole test into nan;
# drop the NaN entries (and their layer indices) before testing. With no NaNs
# present the result is unchanged.
valid = ~np.isnan(sp_alllayer)
tau = stats.kendalltau(h_index[valid], sp_alllayer[valid])
print(tau)

In [None]:
# Plot SS Fig 1
# Fit a normal and a Weibull (weibull_min) distribution to each layer's column of
# pdf_bin and record the summed log-density of that column under each fit.
dist_model = ['norm','weibull']
norm_col, weib_col = 0, 1  # column order of log_lik, matching dist_model
log_lik = np.zeros((len(layer_name), len(dist_model)))
weib_paras = []
for i in range(len(layer_name)):
    data = pdf_bin[:, i]
    # norm
    norm_para = stats.norm.fit(data)
    log_lik[i, norm_col] = np.sum(stats.norm.logpdf(data, *norm_para))
    # weibull
    weib_para = stats.weibull_min.fit(data)
    log_lik[i, weib_col] = np.sum(stats.weibull_min.logpdf(data, *weib_para))
    weib_paras.append(weib_para)
# first fitted parameter of weibull_min (its shape parameter) for every layer
weib_k = np.asarray(weib_paras)[:,0]

# plot
fig, axes = plt.subplots(nrows=4, ncols=1, figsize=[4,8])
pdf_x = np.arange(0,1,0.01)[1:]   # right edge of each activation-histogram bin
n_conv = len(layer_legend[:-2])   # every layer except the last two
# Conv
for i in range(n_conv):
    axes[0].plot(pdf_x, pdf_bin[:, i], c=conv_colors[i])
axes[0].set_xlim((0,0.6))
axes[0].legend(layer_legend[:-2])
# FC
for i in range(2):
    axes[1].plot(pdf_x, pdf_bin[:, i+n_conv], c=fc_colors[i])
axes[1].set_xlim((0,0.6))
axes[1].legend(layer_legend[-2:])
# comparison of fitting with Norm and Weibull (negative log-likelihood, one line per model)
axes[2].plot(layer_legend, -1* log_lik)
axes[2].legend(dist_model)
# k in Weibull
axes[3].plot(layer_legend, weib_k)