MODEL SENSITIVITY

In [None]:
import numpy as np
from tqdm import tqdm

def filter_spectrum(spec):
    """Keep the left half-plane of a 2-D spectrum and flatten it row-major.

    With ``n = spec.shape[0]``, the first ``n // 2 + 1`` columns of every row are
    kept. When the input has been ``fftshift``-ed (as ``get_index`` does), DC sits
    at column ``n // 2``, so the kept columns are the non-positive horizontal
    frequencies plus DC.

    Parameters
    ----------
    spec : np.ndarray
        2-D array with ``n`` rows and at least ``n // 2 + 1`` columns.

    Returns
    -------
    np.ndarray
        1-D array of length ``n * (n // 2 + 1)``.
    """
    n_rows = spec.shape[0]
    n_cols = n_rows // 2 + 1
    half_plane = spec[:, :n_cols]
    return half_plane.reshape(n_rows * n_cols)

def get_index(grad_map):
    """Rank the half-plane Fourier magnitudes of a channel-averaged map.

    Parameters
    ----------
    grad_map : np.ndarray
        Array indexed as (row, col, channel); channels are averaged over
        ``axis=2`` before the 2-D FFT. Presumably a square (H, W, C) gradient
        map -- TODO confirm with the caller.

    Returns
    -------
    _idx : np.ndarray
        Flat indices into the ``filter_spectrum`` output (row-major over
        ``(H, H // 2 + 1)``), ordered from largest to smallest magnitude.
    _val : np.ndarray
        The magnitudes at those indices, i.e. the spectrum sorted descending.
    """
    epsilon = 1e-8
    data = np.mean(grad_map, axis=2)
    # Shift so the zero-frequency term sits in the middle of the spectrum.
    f1 = np.fft.fftshift(np.fft.fft2(data))
    # Adds epsilon to the real part only, nudging exact-zero magnitudes off 0.
    f1 += epsilon
    magnitude_spectrum = filter_spectrum(np.abs(f1))

    _idx = np.argsort(magnitude_spectrum)[::-1]
    # Gather instead of sorting a second time: same descending values, one
    # sort instead of two, and _val is aligned element-wise with _idx.
    _val = magnitude_spectrum[_idx]
    return _idx, _val

def obtain_freq_pos(grad_map, line=5):
    """Per-sample radii of ranked frequencies, plus the rank-summed magnitudes.

    Parameters
    ----------
    grad_map : np.ndarray
        4-D array, presumably (N, C, H, W) with H == W -- TODO confirm. It is
        transposed to (N, H, W, C) so ``get_index`` can average channels.
    line : int, optional
        Unused; kept only so existing calls keep working.

    Returns
    -------
    _count : np.ndarray, shape (N, H * (H // 2 + 1))
        ``_count[i, k]`` is the distance from the shifted-spectrum centre of
        the k-th largest frequency of sample i.
    _power : np.ndarray, shape (H * (H // 2 + 1),)
        Element-wise sum over samples of each sample's descending magnitudes.
    """
    grad_map = grad_map.transpose(0, 2, 3, 1)
    width = grad_map.shape[1]
    h_width = width // 2 + 1
    # After fftshift the zero-frequency term is at (width // 2, width // 2).
    center = width // 2
    _power = np.zeros(width * h_width)
    _count = []
    for sample in tqdm(grad_map):
        normal_index, normal_val = get_index(sample)
        _power += normal_val
        # Recover integer (row, col) in the (width, h_width) half-spectrum that
        # filter_spectrum flattened row-major. The previous
        # `normal_index / h_width` used true division under Python 3, producing
        # fractional rows and therefore wrong radii.
        u1, v1 = np.unravel_index(normal_index, (width, h_width))
        radius = np.sqrt((u1 - center) ** 2 + (v1 - center) ** 2)
        _count.append(radius)

    return np.array(_count), _power
        
# Load the stored gradient maps and reduce them to per-sample radii of the
# ranked frequencies (freq_cnt) and the rank-summed magnitudes (freq_pow).
grad_map_path = 'resnet18_cifar10_best.npy'
rn18 = np.load(grad_map_path)
print(np.shape(rn18))
freq_cnt, freq_pow = obtain_freq_pos(grad_map=rn18)

In [None]:
import matplotlib.pyplot as plt

# Rank curve: summed magnitude at each rank position (rank 1 = largest),
# shaded down to zero.
ranks = np.arange(1, len(freq_pow) + 1)
fig, ax = plt.subplots(figsize=(6, 4), dpi=300)
ax.plot(ranks, freq_pow, linewidth=2, color='black')
ax.fill_between(ranks, 0, freq_pow, color='grey', alpha=0.3)

ax.set_xlabel('Rank Index', fontsize=16)
ax.set_ylabel('Gradient Magnitude', fontsize=16)

ax.set_title("CIFAR10 ResNet-18", fontsize=15)
for tick_label in ax.get_xticklabels() + ax.get_yticklabels():
    tick_label.set_fontsize(13)
ax.ticklabel_format(style='sci', scilimits=(-1, 2), axis='y')
plt.show()

In [None]:
from matplotlib.font_manager import FontProperties
import seaborn as sns

plt.rcParams['mathtext.fontset'] = 'stix'
plt.figure(figsize=(6, 4), dpi=300)

# Rank positions to compare: the top-ranked frequency, the entries at
# 20/40/60/80% of the ranking, and the last-ranked one. Positions are derived
# from the actual ranking length so they match the legend labels; the old
# hard-coded [100, 200, 300, 400] only approximated them (e.g. a 32x32 map ranks
# 32 * 17 = 544 frequencies, where index 100 is ~18%, not 20%).
n_ranks = freq_cnt.shape[1]
rank_fracs = [0.2, 0.4, 0.6, 0.8]
_index = [0] + [int(q * n_ranks) for q in rank_fracs] + [-1]
_index_name = ['Top one', '20%', '40%', '60%', '80%', 'Last one']

camp = plt.get_cmap('Reds')
color = [camp(v) for v in np.linspace(0.8, 0.2, len(_index))]

for rank_pos, rank_name, rank_color in zip(_index, _index_name, color):
    # Column `rank_pos` holds, for every sample, the radius of the frequency
    # at that rank. `fill` supersedes `shade`, which seaborn deprecated in 0.11.
    sns.kdeplot(freq_cnt[:, rank_pos], fill=False, linewidth=5,
                color=rank_color, label=rank_name)

# NOTE(review): reference line at radius 4.98 -- where this constant comes from
# is not visible here; confirm what it marks before reusing the figure.
plt.vlines([4.98], 0, 0.4, linestyles='dashed', colors='grey', alpha=0.6, linewidth=2)
plt.legend(fontsize=10, title="Rank Index")
plt.ylim(0, 0.25)

plt.xlabel('Frequency Radius', fontsize=16)
plt.ylabel('Density', fontsize=16)
plt.xticks(fontsize=13)
plt.yticks(fontsize=13)
plt.ticklabel_format(style='sci', scilimits=(-1, 2), axis='y')
plt.title('CIFAR10 ResNet-18', fontsize=15)
plt.show()

LOSS LANDSCAPE

In [None]:
import h5py
import numpy as np
surf_file = 'resnet18_baseline_high12_[-2,2]'
surf_name = 'test_loss'
# Crop window (grid indices) applied to both coordinate axes and the surface.
left = 10
right = 90

# Read everything into memory inside a context manager so the HDF5 file is
# closed afterwards (the original left the handle open).
with h5py.File(surf_file, 'r') as f:
    x = np.array(f['xcoordinates'][:])[left:right]
    y = np.array(f['ycoordinates'][:])[left:right]
    if surf_name in f.keys():
        Z = np.array(f[surf_name][:])
    elif surf_name in ('train_err', 'test_err'):
        # The old branch read f[surf_name], which cannot exist here (that case
        # is handled above), so it always raised KeyError. Derive the error
        # surface from the matching accuracy dataset instead.
        # NOTE(review): assumes '<split>_acc' is stored in percent -- confirm.
        acc_name = surf_name.replace('_err', '_acc')
        Z = 100 - np.array(f[acc_name][:])
    else:
        # Previously fell through and failed later with a NameError on Z.
        raise KeyError('surface %r not found in %s; available: %s'
                       % (surf_name, surf_file, list(f.keys())))

X, Y = np.meshgrid(x, y)
Z = Z[left:right, left:right]

In [None]:
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import pyplot as plt 

fig = plt.figure(dpi=300)
# Shared levels for the filled map and the labelled iso-lines drawn on top.
levels = np.arange(2, 7, 0.3)
# Keep the filled set under its own name instead of overwriting it with the
# line set, so it remains available (e.g. for a colorbar).
cs_filled = plt.contourf(X, Y, Z, cmap='RdYlBu', levels=levels)
CS2 = plt.contour(X, Y, Z, levels=levels, colors="grey", linewidths=0.5, alpha=0.8)
# Previously indented by one stray space, which made the cell an IndentationError.
plt.clabel(CS2, inline=True, fontsize=9)
plt.show()
