# The Heavy-Tail Phenomenon in SGD

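This notebook loads the trained neural networks, estimates the tail index $\alpha$ of each layer's weights for every step-size/batch-size pair $(\eta, b)$, and plots the estimates against the ratio $\eta / b$.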


In [None]:
import os
import numpy as np
import torch
import itertools
import matplotlib.pyplot as plt
%matplotlib inline

import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import sys
sys.path.append('./')
from models import alexnet, fc

import math

from tqdm import tqdm


In [None]:
# Corollary 2.4 in Mohammadi 2014 - for 1d
def alpha_estimator_one(m, X):
    """Estimate the tail index alpha of a 1-d sample from block sums.

    The sample is flattened, truncated to the largest multiple of ``m`` and
    split into ``n = len(X) // m`` consecutive blocks of ``m`` values; ``Y``
    holds the block sums. The estimate is

        1 / alpha = (mean(log|Y|) - mean(log|X|)) / log(m)

    (the estimator referenced as Corollary 2.4 in Mohammadi 2014).

    Parameters
    ----------
    m : int
        Block size. Must be at least 2 so that ``log(m) > 0``.
    X : array_like
        Sample values. Any shape is accepted and flattened, so a 1-d array,
        an ``(N, 1)`` column and a plain list all give the same result.

    Returns
    -------
    float
        The alpha estimate, or NaN when ``X`` holds fewer than ``m`` values
        (no complete block can be formed).

    Raises
    ------
    ValueError
        If ``m < 2``.
    """
    if m < 2:
        raise ValueError('block size m must be >= 2, got {}'.format(m))

    X = np.asarray(X).ravel()
    n = X.size // m  # number of complete blocks; exact integer division
    if n == 0:
        # Not enough samples for a single block: report "no estimate"
        # explicitly instead of averaging empty arrays.
        return np.nan

    X = X[:n * m]  # drop the trailing partial block
    Y = X.reshape(n, m).sum(axis=1)
    eps = np.spacing(1)  # keeps log() finite for exact zeros

    Y_log_norm = np.log(np.abs(Y) + eps).mean()
    X_log_norm = np.log(np.abs(X) + eps).mean()
    diff = (Y_log_norm - X_log_norm) / math.log(m)
    return 1 / diff

In [None]:
# Directory that contains one sub-folder per training run. Keep a trailing
# path separator (e.g. '/data/results/'), because the loading cell builds its
# glob pattern by plain string concatenation with this value.
# It can also be supplied through the HEAVY_TAIL_RESULTS_PATH environment
# variable, so the notebook need not be edited on every machine.
path = os.environ.get("HEAVY_TAIL_RESULTS_PATH", "PATH TO THE RESULTS FOLDER")

In [None]:
# Load the trained neural networks

import glob
import inspect

dataset = 'mnist'
# dataset = 'cifar10'

loss = 'NLL'
arch = 'fc'

num_nets = 1000

search_str = path + '*' + dataset + '_' + loss + '_' + arch + '*'

results = {}
hist_tr = {}
hist_te = {}
nets    = {}

# Every file is mapped to the CPU so results saved on a GPU load anywhere.
# The net files are unpickled as whole model objects (the analysis later calls
# `.parameters()` on them), which newer PyTorch releases refuse under their
# `weights_only=True` default; request full unpickling where the argument
# exists. SECURITY: this executes pickled code - only load trusted folders.
load_kwargs = {'map_location': 'cpu'}
if 'weights_only' in inspect.signature(torch.load).parameters:
    load_kwargs['weights_only'] = False

for folder_path in sorted(glob.glob(search_str)):
    # The folder basename is the experiment key looked up by the analysis cell.
    folder = os.path.basename(folder_path)

    # Build file paths from the matched folder itself, so they are correct
    # whether or not `path` ends with a separator.
    hist_file = os.path.join(folder_path, 'training_history.hist')
    if not os.path.isfile(hist_file):
        continue

    print(folder)
    results[folder] = torch.load(hist_file, **load_kwargs)
    hist_tr[folder] = torch.load(os.path.join(folder_path, 'evaluation_history_TRAIN.hist'), **load_kwargs)
    hist_te[folder] = torch.load(os.path.join(folder_path, 'evaluation_history_TEST.hist'), **load_kwargs)
    nets[folder]    = [torch.load(os.path.join(folder_path, 'net_' + str(i) + '.pyT'), **load_kwargs)
                       for i in range(num_nets)]
        

In [None]:
# Compute the tail index for each step-size/batch-size pair

depth = 3
width = 128

etas         = [0.0001, 0.001, 0.01, 0.015, 0.02, 0.025, 0.03, 0.04, 0.045, 0.05, 0.06, 0.070, 0.075, 0.08, 0.09, 0.1]
batch_sizes  = [1, 5, 10]

# Block sizes m handed to alpha_estimator_one; each layer's estimate is the
# median over the block sizes that fit into that layer.
block_sizes  = (2, 5, 10, 20, 50, 100, 500, 1000)

# alphas_mc[ei, bi, layer] = tail-index estimate. -1 marks "not computed"
# (missing experiment, or layer too small for any block size); the plotting
# cell only draws positive entries.
alphas_mc    = np.full((len(etas), len(batch_sizes), depth), -1.0)

for ei, eta in enumerate(tqdm(etas, desc='eta')):
    for bi, bs in enumerate(batch_sizes):
        # Must reproduce the experiment folder names used as keys of `nets`.
        exp_name = '{}_{:04d}_{}_{}_{}_{:E}_{:04d}'.format(depth, width, dataset, loss, arch, eta, bs)

        if exp_name not in nets:
            print(exp_name + " does not exist")
            continue

        net = nets[exp_name]

        # weights[layer] collects one flattened (size, 1) column per saved network.
        weights = [[] for _ in range(depth)]
        for i in range(num_nets):
            params = list(net[i].parameters())
            # Each parameter tensor is treated as one layer; fail with a clear
            # message (instead of an IndexError) if the model has a different
            # number of tensors, e.g. because it also carries biases.
            if len(params) != depth:
                raise ValueError('{}: expected {} parameter tensors per network (one per layer), got {}'
                                 .format(exp_name, depth, len(params)))
            for ix, p in enumerate(params):
                weights[ix].append(p.detach().cpu().numpy().reshape(-1, 1))

        for i in range(depth):
            # (size, num_nets) matrix: average every weight over the saved
            # networks, then center the averaged weights.
            weights[i] = np.concatenate(weights[i], axis=1)
            tmp_weights = np.mean(weights[i], axis=1).reshape(-1, 1)
            tmp_weights = tmp_weights - np.mean(tmp_weights)
            # Block sizes larger than the layer give NaN, which would turn the
            # whole median into NaN; use only the ones that fit.
            usable_ms = [mm for mm in block_sizes if mm <= tmp_weights.size]
            if usable_ms:
                tmp_alphas = [alpha_estimator_one(mm, tmp_weights) for mm in usable_ms]
                alphas_mc[ei, bi, i] = np.median(tmp_alphas)



In [None]:
# Visualization: one figure per layer showing the tail index against eta / b,
# with one marker series per batch size.

etas_arr = np.asarray(etas)
for i in range(depth):
    fig, ax = plt.subplots()
    for bi, bs in enumerate(batch_sizes):
        layer_alphas = alphas_mc[:, bi, i]
        # Keep only computed entries: -1 (not computed) and NaN both fail `> 0`.
        keep = layer_alphas > 0
        if keep.any():
            ax.plot(etas_arr[keep] / bs, layer_alphas[keep], '.', label='b = {}'.format(bs))
    # Labels are set once per figure, even when the layer has no valid points.
    ax.set_xlabel("eta/b")
    ax.set_ylabel("alpha")
    ax.set_title('Layer ' + str(i + 1))
    if ax.lines:
        ax.legend()
plt.show()

