In [None]:
import itertools as it
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Model names analysed in this notebook; a '{name}_layernorm.npz' archive is
# loaded for each of them further down.
MODELS = [
    'vit_tiny_patch16_224',
    'deit_tiny_patch16_224',
    'deit_tiny_distilled_patch16_224',
    'vit_small_patch16_224',
    'deit_small_patch16_224',
    'deit_small_distilled_patch16_224',
]

In [None]:
# Single model whose archive is opened by the following cell.
MODEL = 'vit_tiny_patch16_224'

In [None]:
# Open the '{MODEL}_layernorm.npz' archive from the working directory.
npz_path = MODEL + '_layernorm.npz'
data = np.load(npz_path)

In [None]:
# Key suffix used to pick archive entries named '{layer}/{metric}'.
metric = 'var'

def box_plot(data, edge_color, fill_color, ax=None):
    """Draw a filled box plot with one edge colour and one fill colour.

    Parameters
    ----------
    data : sequence of array-like
        Passed unchanged to ``Axes.boxplot``.
    edge_color : color
        Applied to boxes, whiskers, fliers, means, medians and caps.
    fill_color : color
        Face colour given to every box patch.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on.  When omitted the current axes (``plt.gca()``) are
        used, so three-argument calls made right after ``plt.subplots`` keep
        working without depending on a module-level ``ax`` variable.

    Returns
    -------
    dict
        The artist dictionary returned by ``Axes.boxplot``.
    """
    if ax is None:
        ax = plt.gca()

    # patch_artist=True turns the boxes into patches so they can be filled.
    bp = ax.boxplot(data, patch_artist=True)

    for element in ('boxes', 'whiskers', 'fliers', 'means', 'medians', 'caps'):
        plt.setp(bp[element], color=edge_color)
    # Flier markers also get their marker edge colour set explicitly.
    plt.setp(bp['fliers'], markeredgecolor=edge_color)

    for patch in bp['boxes']:
        patch.set(facecolor=fill_color)

    return bp

# Overlay the '{layer}/{metric}' box plots of two models on a single axes.
distilled_model = 'deit_tiny_distilled_patch16_224'
baseline_model = 'deit_tiny_patch16_224'
data = np.load(f'{distilled_model}_layernorm.npz')
data2 = np.load(f'{baseline_model}_layernorm.npz')
fig, ax = plt.subplots(figsize=(12,8))
# Keys are taken from the first archive and reused to index the second one.
keys = [k for k in data.keys() if k.endswith(f'/{metric}')]
bp_distilled = box_plot([data[k] for k in keys], 'red', 'tan')
bp_baseline = box_plot([data2[k] for k in keys], 'blue', 'cyan')
#ax.set_xticklabels([k.split('/')[0] for k in keys], rotation=90)
if keys:
    # One representative box per series is enough to identify the colours.
    ax.legend([bp_distilled['boxes'][0], bp_baseline['boxes'][0]],
              [distilled_model, baseline_model])
# Name both plotted models explicitly: no `model` variable exists yet when
# the notebook is run top to bottom, so referencing it raised NameError.
ax.set_title(f'{metric} [{distilled_model} vs {baseline_model}]')
plt.show()

In [None]:
def plot_metric(metric, model, data):
    """Box-plot every entry of ``data`` whose key ends in ``/{metric}``.

    ``data`` is any mapping exposing ``keys()`` (for example a loaded ``.npz``
    archive).  Each selected entry becomes one box, labelled on the x-axis
    with the part of its key before the first ``/``.  ``model`` only appears
    in the figure title.  The figure is shown immediately; nothing is returned.
    """
    fig, ax = plt.subplots(figsize=(12,8))

    suffix = f'/{metric}'
    selected = [name for name in data.keys() if name.endswith(suffix)]
    series = [data[name] for name in selected]
    tick_labels = [name.split('/')[0] for name in selected]

    ax.boxplot(series)
    ax.set_xticklabels(tick_labels, rotation=90)
    ax.set_title(f'{metric} [{model}]')
    plt.show()

In [None]:
# One figure per (metric, model) pair: every model for 'var', then for 'mean'.
# NOTE: `metric`, `model`, `npz_path` and `data` remain bound to the last
# iteration afterwards, and later cells read `data` and `metric`.
for metric, model in it.product(('var', 'mean'), MODELS):
    npz_path = '{}_layernorm.npz'.format(model)
    data = np.load(npz_path)
    plot_metric(metric=metric, model=model, data=data)

In [None]:
# Box plot of the first 50 rows of 'blocks.8.norm1/mean_token', transposed
# before plotting, read from whichever archive `data` currently refers to.
fig, ax = plt.subplots(figsize=(12,8))
mean_token_rows = data['blocks.8.norm1/mean_token'][:50]
ax.boxplot(mean_token_rows.T)

In [None]:
# Box plot of the first 50 rows of 'blocks.8.norm1/var_token', transposed
# before plotting, read from whichever archive `data` currently refers to.
fig, ax = plt.subplots(figsize=(12,8))
var_token_rows = data['blocks.8.norm1/var_token'][:50]
ax.boxplot(var_token_rows.T)

In [None]:
# Display the current contents of `keys`.
keys

In [None]:
# (max, min, mean) of the 'blocks.8.norm1/var' entry of `data`.
block8_var = data['blocks.8.norm1/var']
block8_var.max(), block8_var.min(), block8_var.mean()

In [None]:
# Layer names in numeric block order: norm1/norm2 of blocks 0..11, then the
# final 'norm'.  The same 25 names in plain string-sorted order put blocks 10
# and 11 between blocks 1 and 2, which scrambles any x-axis indexed by this
# list (see plot_profile).
keys = [f'blocks.{block}.norm{i}' for block in range(12) for i in (1, 2)]
keys.append('norm')

In [None]:
def plot_profile(metric, models, layers=None):
    """Scatter the per-layer mean of ``metric`` for several models on one axes.

    For every model the archive ``'{model}_layernorm.npz'`` is opened and the
    array stored under ``'{layer}/{metric}'`` is averaged for each layer name.
    Each model becomes one labelled scatter series; the figure is shown
    immediately and nothing is returned.

    Parameters
    ----------
    metric : str
        Key suffix, e.g. ``'var'``.
    models : iterable of str
        Model names; each selects the archive to read and labels one series.
    layers : sequence of str, optional
        Layer names, in x-axis order.  Defaults to the module-level ``keys``
        list, which is what this function used before the parameter existed.
    """
    models = list(models)
    layer_names = list(keys if layers is None else layers)
    positions = range(len(layer_names))

    fig, ax = plt.subplots(figsize=(12,8))
    for model in models:
        # np.load keeps the archive's file open until the NpzFile is closed,
        # so close it as soon as the means have been computed.
        with np.load(f'{model}_layernorm.npz') as data:
            means = [data[f'{name}/{metric}'].mean() for name in layer_names]
        ax.scatter(positions, means, label=model)
    ax.set_xticks(positions)
    ax.set_xticklabels(layer_names, rotation=90)
    # Join the names so the title does not show a Python list repr.
    ax.set_title(f"{metric} [{', '.join(models)}]")
    ax.legend()
    plt.show()

In [None]:
# Per-layer mean 'var': ViT-tiny next to DeiT-tiny.
models = [
    'vit_tiny_patch16_224',
    'deit_tiny_patch16_224',
]
plot_profile('var', models)

In [None]:
# Per-layer mean 'var': DeiT-tiny next to its distilled variant.
models = [
    'deit_tiny_patch16_224',
    'deit_tiny_distilled_patch16_224',
]
plot_profile('var', models)

In [None]:
# Per-layer mean 'var': ViT-small next to DeiT-small.
models = [
    'vit_small_patch16_224',
    'deit_small_patch16_224',
]
plot_profile('var', models)

In [None]:
# Per-layer mean 'var': DeiT-small next to its distilled variant.
models = [
    'deit_small_patch16_224',
    'deit_small_distilled_patch16_224',
]
plot_profile('var', models)

In [None]:
# Per-layer mean 'var': DeiT-tiny next to DeiT-small.
models = [
    'deit_tiny_patch16_224',
    'deit_small_patch16_224',
]
plot_profile('var', models)

### Breakdown by Layer and by Sequence Position

In [None]:
# List the entry names of the archive currently bound to `data`.
data.files

In [None]:
# Shape of the 'blocks.8.norm1/var_token' entry of `data`.
var_token = data['blocks.8.norm1/var_token']
var_token.shape

In [None]:
def plot_token_metric(metric, model, data, max_rows=128, max_cols=10):
    """Draw one box-plot figure per ``'{layer}/{metric}_token'`` entry of ``data``.

    For every matching entry the first ``max_rows`` rows are kept, and the
    first ``max_cols`` columns of that window are plotted with one box per
    column.  The layer name and the shape of the row window are printed before
    each figure; nothing is returned.

    Parameters
    ----------
    metric : str
        Metric name; entries are selected by the key suffix ``/{metric}_token``.
    model : str
        Only used in the figure titles.
    data : mapping
        Anything with ``keys()`` and item access, e.g. a loaded ``.npz``
        archive.  Entries must be at least 2-D.
    max_rows : int, optional
        Number of leading rows to keep (default 128).
    max_cols : int, optional
        Number of leading columns to plot (default 10).

    NOTE(review): presumably rows are samples and columns are sequence
    positions -- confirm against the code that wrote the archives.
    """
    suffix = f'/{metric}_token'
    layers = [k for k in data.keys() if k.endswith(suffix)]
    for key in layers:
        layer_name = key.split('/')[0]
        print('working on layer {}'.format(layer_name))
        vals = data[key][:max_rows]
        print('data shape is {}'.format(vals.shape))

        fig, ax = plt.subplots(figsize=(12,8))
        # Column slice.  This used to be `vals[:][:10]`, which slices the rows
        # twice (identical to `vals[:10]`) and so ignored the row window above.
        ax.boxplot(vals[:, :max_cols])
        ax.set_title(f'{metric} [{model}-{layer_name}]')
        plt.show()

In [None]:
# Collect every '{layer}/{metric}_token' key of `data`, then show the shape of
# the first matching entry.
token_suffix = f'/{metric}_token'
layers = [name for name in data.keys() if name.endswith(token_suffix)]
data[layers[0]].shape

In [None]:
# Per-layer 'var_token' box plots for the distilled DeiT-tiny archive.
plot_token_metric(
    'var',
    'deit_tiny_distilled_patch16_224',
    np.load('deit_tiny_distilled_patch16_224_layernorm.npz'),
)

In [None]:
# Per-layer 'var_token' box plots for the non-distilled DeiT-tiny archive.
# The archive must match the model named in the titles; loading the distilled
# archive here would only repeat the previous cell's data under the wrong label.
plot_token_metric('var', 'deit_tiny_patch16_224', np.load('deit_tiny_patch16_224_layernorm.npz'))

In [None]:
def plot_var_mean(model, data, layer=None):
    """Scatter ``var_token`` against ``mean_token`` for one layer of ``data``.

    The entries ``'{layer}/mean_token'`` and ``'{layer}/var_token'`` are read,
    and each index along their last axis is drawn as its own scatter series
    with a distinct colour from the ``'hsv'`` colormap.  The figure is shown
    immediately; nothing is returned.

    Parameters
    ----------
    model : str
        Only used in the figure title.
    data : mapping
        Anything with ``keys()`` and item access, e.g. a loaded ``.npz``
        archive.
    layer : str, optional
        Layer name to plot.  Defaults to the first layer that has a
        ``'/var_token'`` entry, which is the only layer this function has ever
        plotted.

    Raises
    ------
    ValueError
        If ``layer`` is omitted and ``data`` has no ``'/var_token'`` entry.
    """
    if layer is None:
        layers = [k.split('/')[0] for k in data.keys() if k.endswith('/var_token')]
        if not layers:
            raise ValueError("data contains no '/var_token' entries")
        layer = layers[0]

    mean_data = data[f'{layer}/mean_token']
    var_data = data[f'{layer}/var_token']
    n_series = mean_data.shape[-1]
    # pyplot.get_cmap is used because matplotlib.cm.get_cmap was deprecated
    # and later removed from Matplotlib.
    colors = plt.get_cmap('hsv', n_series)

    fig, ax = plt.subplots(figsize=(12,8))
    for i in range(n_series):
        ax.scatter(mean_data[:, i], var_data[:, i], color=colors(i))
    # Fixed label: the plot always shows var against mean, independent of any
    # module-level `metric` variable.
    ax.set_title(f'var vs mean [{model}-{layer}]')
    ax.set_xlabel('mean')
    ax.set_ylabel('var')
    plt.show()

In [None]:
# var-vs-mean scatter for the distilled DeiT-tiny archive.
plot_var_mean(
    'deit_tiny_distilled_patch16_224',
    np.load('deit_tiny_distilled_patch16_224_layernorm.npz'),
)