# Gaussian feedforward
Ro Jefferson<br>
Last updated 2021-05-19 

This notebook grew out of some explorations of criticality in random neural nets, based primarily on the works by [Schoenholz et al.](https://arxiv.org/abs/1611.01232) and [Poole et al.](https://arxiv.org/abs/1606.05340); see also my [blog article](https://rojefferson.blog/2020/06/19/criticality-in-deep-neural-nets/) for a pedagogical treatment of the underlying idea. The titular "Gaussian" refers to the fact that we work with a *random* feedforward neural network here, in which the weights and biases are randomly initialized following some Gaussian distribution(s); in the large-$N$ limit, each layer (as well as the network as a whole) behaves like a Gaussian distribution, which simplifies the analysis considerably. 

This notebook constructs and trains basic feedforward networks of arbitrary depth on the MNIST database, using the built-in `cross_entropy` as the loss function and $\tanh$ for the non-linearity. In particular, it is designed to facilitate comparing a range of different depths for a given set of hyperparameters---especially the variances of the distributions of weights and biases, which control the phase (ordered vs. chaotic). The data -- accuracies, hooks, model parameters -- are optionally written as HDF5 files to the specified directory. The data are then deleted from the kernel in order to free sufficient memory for the next model. **The user must specify** the `PATH_TO_MNIST` and the `PATH_TO_DATA` below. 
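
For orientation, the ordered and chaotic phases can be located with the mean-field criterion from the two papers cited above: iterate the variance map $q \mapsto \sigma_w^2\,\mathbb{E}_z[\tanh^2(\sqrt{q}\,z)] + \sigma_b^2$ (with $z\sim\mathcal{N}(0,1)$) to its fixed point $q^*$, then evaluate $\chi = \sigma_w^2\,\mathbb{E}_z[\mathrm{sech}^4(\sqrt{q^*}\,z)]$; $\chi<1$ corresponds to the ordered phase and $\chi>1$ to the chaotic one. The following is a minimal standalone sketch and not part of this notebook's pipeline; the helper names, the trapezoidal quadrature, and the iteration count are illustrative assumptions.

```python
import math

def gauss_expect(f, n=801, zmax=8.0):
    """Approximate E[f(z)] for z ~ N(0,1) with the trapezoidal rule on [-zmax, zmax]."""
    h = 2 * zmax / (n - 1)
    total = 0.0
    for k in range(n):
        z = -zmax + k * h
        w = 0.5 if k in (0, n - 1) else 1.0   # trapezoid end-point weights
        total += w * f(z) * math.exp(-0.5 * z * z)
    return total * h / math.sqrt(2 * math.pi)

def fixed_point_variance(var_w, var_b, q0=1.0, iters=200):
    """Iterate q -> var_w * E[tanh^2(sqrt(q) z)] + var_b until (approximately) converged."""
    q = q0
    for _ in range(iters):
        s = math.sqrt(q)
        q = var_w * gauss_expect(lambda z: math.tanh(s * z) ** 2) + var_b
    return q

def chi(var_w, var_b):
    """chi = var_w * E[tanh'(sqrt(q*) z)^2], using tanh' = sech^2 = 1 - tanh^2."""
    s = math.sqrt(fixed_point_variance(var_w, var_b))
    return var_w * gauss_expect(lambda z: (1 - math.tanh(s * z) ** 2) ** 2)

# scan a few weight variances at the bias variance used later in this notebook:
for var_w in (1.0, 2.0, 3.0):
    print(var_w, chi(var_w, 0.05))
```

Values of $\sigma_w^2$ for which `chi` crosses 1 mark the approximate location of the transition at the given $\sigma_b^2$.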

Hooks are computationally intensive and are thus disabled by default. One must pass `hooks=True` when calling `train_models()` to record the layer inputs/outputs, in which case they will be stored *only* for the beginning and end of each run (to minimize computation time while allowing before vs. after analysis). Similarly, passing `write_params=True` stores the parameters (weights, biases) at the beginning and end of each run, which we may use in another notebook to compute the KL divergence.

The companion notebook "Gaussian_Feedforward_Analysis.ipynb" is designed to read the aforementioned HDF5 files and perform some analysis, while "RelativeEntropy_Nonsymbolic.ipynb" reads them to compute the KL divergence.

In [1]:
# PyTorch packages:
import torch
import torch.nn as nn                       # neural net package
import torch.nn.functional as F             # useful functions, e.g., convolutions & loss functions
from torch import optim                     # optimizers (torch.optim)
from torch.utils.data import TensorDataset  # for wrapping tensors
from torch.utils.data import DataLoader     # for managing batches

# Numpy, scipy, and plotting:
import numpy as np
from scipy.stats import norm         # Gaussian fitting
import scipy.integrate as integrate  # integration
import matplotlib.pyplot as plt      # plotting
import seaborn as sns; sns.set()     # nicer plotting
import pandas as pd                  # dataframe for use with seaborn

# File i/o:
import pickle  # for unpickling MNIST data
import gzip    # for opening pickled MNIST data file
import h5py    # HDF5

# Miscellaneous:
import math
import random  # random number generators
import re      # regular expressions
import gc      # garbage collection

In [2]:
# Memory tracking (optional/unused):
import os, psutil
process = psutil.Process(os.getpid())
        
# Example usage:
#print('RSS = ', process.memory_info().rss/10**6, 'MB')  # resident set size (RAM)
#print('VMS = ', process.memory_info().vms/10**6, 'MB')  # virtual memory (RAM + swap)

## Import and pre-process MNIST data
Since our focus is on the structure/dynamics of the network rather than state-of-the-art optimizations, we'll just use the vanilla MNIST dataset for this notebook. We first unzip and unpickle the dataset, and load it into a training and validation set:

In [3]:
PATH_TO_MNIST = '/full/path/to/local/MNIST/gzip/file/'
FILENAME = 'mnist.pkl.gz'

# open (and automatically close) gzip file in mode for reading binary (`rb`) data:
with gzip.open(PATH_TO_MNIST + FILENAME, 'rb') as file:
    ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(file, encoding="latin-1")

Optionally, in the case of memory limitations or testing, we can opt to work with a small subset of the data:

In [4]:
#truncate = 5000  # max 50,000 training and 10,000 validation images
#x_train, y_train, x_valid, y_valid = x_train[:truncate], y_train[:truncate], x_valid[:truncate], y_valid[:truncate]

Each image consists of $28\times28$ pixels (where each pixel value is a float between 0 and 1), flattened into a row of length 784. Currently however, each image is a numpy array; to use PyTorch, we need to convert this to a `torch.tensor`:

In [5]:
x_train, y_train, x_valid, y_valid = map(torch.from_numpy, (x_train, y_train, x_valid, y_valid))

While we're on the subject of file i/o, let's choose a location to store any data files we create below (n.b., must end with '/'):

In [6]:
PATH_TO_DATA = '/full/path/to/desired/write/directory/'

We'll also need to specify whether to create a wide or decimated model (see below):

In [7]:
WIDE = True   # True avoids normalization issues in separate KL divergence computation

## Construct the model(s)
While PyTorch's built-in `Linear` layer seems to exhibit better performance out-of-the-box, it doesn't quite suffice for our purposes, since it uses a uniform distribution for the weight & bias initialization; so instead, we'll define a custom layer in which the parameters are initialized along a Gaussian:

In [8]:
# linear layer z=Wx+b, with W,b drawn from normal distributions:
class GaussianLinear(nn.Module):
    def __init__(self, size_in, size_out, var_W, var_b):
        super().__init__()
        self.size_in, self.size_out = size_in, size_out
        self.var_W, self.var_b = var_W/size_in, var_b  # n.b., must scale var_W by layer width!

        # normally distributed weights with mean=0 and variance=var_W/N:
        norm_vec = torch.normal(mean=torch.zeros(size_out*size_in), std=math.sqrt(self.var_W))
        self.weights = nn.Parameter(norm_vec.view(size_out, size_in))
        
        # normally distributed biases with mean=0 and variance=var_b:
        self.bias = nn.Parameter(torch.normal(mean=torch.zeros(size_out), std=math.sqrt(var_b)))

    def forward(self, x):
        prod = torch.mm(x, self.weights.t())  # Wx
        return torch.add(prod, self.bias)     # Wx+b
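
The `var_W/size_in` rescaling keeps the pre-activation variance independent of the layer width: for a fixed input $x$ of length $N$, $\mathrm{Var}(z_i) = \sigma_w^2\,\frac{1}{N}\sum_j x_j^2 + \sigma_b^2$. A quick standalone check of this in plain Python rather than PyTorch (the helper name, the random stand-in input, and the sample sizes are illustrative assumptions, not part of the notebook):

```python
import math
import random

def preactivation_variance(x, var_w, var_b, n_out, rng):
    """Empirical variance of z_i = sum_j W_ij x_j + b_i over n_out output units."""
    n_in = len(x)
    std_w = math.sqrt(var_w / n_in)   # weights ~ N(0, var_w / n_in)
    std_b = math.sqrt(var_b)          # biases  ~ N(0, var_b)
    z = []
    for _ in range(n_out):
        z_i = sum(rng.gauss(0.0, std_w) * x_j for x_j in x) + rng.gauss(0.0, std_b)
        z.append(z_i)
    mean = sum(z) / n_out
    return sum((z_i - mean) ** 2 for z_i in z) / (n_out - 1)

rng = random.Random(0)                        # fixed seed for reproducibility
x = [rng.random() for _ in range(100)]        # stand-in for one input vector
var_w, var_b = 2.0, 0.05

predicted = var_w * sum(x_j ** 2 for x_j in x) / len(x) + var_b
measured = preactivation_variance(x, var_w, var_b, n_out=2000, rng=rng)
print(predicted, measured)  # should agree up to sampling noise
```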

We also need functions to compute the gradients and update the parameters -- i.e., to train the model -- subject to our choice of loss function. We'll just use the built-in SGD optimizer, with the built-in cross-entropy as our loss function:

In [9]:
# compute gradients & update parameters for a single batch, given loss function & optimizer:
def loss_batch(model, loss_func, xb, yb, opt=None):
    loss = loss_func(model(xb), yb)  # compute specified loss function for the model

    if opt is not None:
        loss.backward()  # compute gradients
        opt.step()       # update parameters
        opt.zero_grad()  # zero gradients in preparation for next iteration

    # n.b., detaching returns the value of the loss; without, returns entire computation graph!
    return loss.detach().item(), len(xb)

# compute accuracy; predicted digit corresponds to index with maximum value:
def accuracy(preds, yb):
    preds = torch.argmax(preds, 1)       # max argument along axis 1 (n.b., batch size must be > 1, else error)
    return (preds == yb).float().mean()  # for each element: 1 if prediction matches target, 0 otherwise

# train & evaluate the model, given loss function, optimizer, and DataLoaders:
def fit(epochs, model, depth, hooks, file_hook, write_params, file_params, var_w, var_b, loss_func, opt, train_dl, valid_dl, acc_list=-1, loss_list=-1):
    for epoch in range(epochs):
        
        # register hooks only on first and last epoch:
        with torch.no_grad():
            if hooks and (epoch == 0 or epoch == epochs-1):
                inputs, outputs = [], []
                hook_layers(model, inputs, outputs)
            
        model.train()  # ensure training mode (e.g., dropout)
        for xb, yb in train_dl:
            loss_batch(model, loss_func, xb, yb, opt)

        model.eval()   # ensure evaluation mode (e.g., no dropout)
        with torch.no_grad():
            # compute loss:
            losses, nums = zip(*[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl])  # * unzips
            val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)
            
            # compute accuracy:
            accuracies = np.array([accuracy(model(xb), yb) for xb, yb in valid_dl])
            val_acc = np.sum(accuracies)/len(accuracies)
            
        print(epoch, val_loss, val_acc)  # monitor progress
        
        # save progress only if user passed lists (for speed if not):
        if isinstance(loss_list, list):
            loss_list.append(val_loss)
        if isinstance(acc_list, list):
            acc_list.append(val_acc)
        
        with torch.no_grad():
            # optionally write initial & final hooks, parameters:
            if epoch == 0 or epoch == epochs-1:
                if write_params and isinstance(file_params, str):    # check valid filename
                    write_parameters('e{}-'.format(epoch) + file_params, model, depth, var_w, var_b)
            
                if hooks and isinstance(file_hook, str):    # check valid filename
                    write_hooks('e{}-'.format(epoch) + file_hook, depth, inputs, outputs)
                    # clear hooks:
                    inputs, outputs= -1, -1
                    gc.collect()
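
In `fit`, the validation loss is averaged with batch-size weights (the `nums` returned by `loss_batch`), whereas the accuracy is a plain mean over batches. The two conventions agree only when every batch has the same size; with a short final batch they differ slightly. A toy illustration with made-up numbers (hypothetical values, not actual results):

```python
def weighted_mean(values, counts):
    """Mean of per-batch values, weighted by the number of samples in each batch."""
    return sum(v * n for v, n in zip(values, counts)) / sum(counts)

# hypothetical per-batch losses and batch sizes; the last batch is short:
losses = [0.50, 0.50, 2.00]
nums = [128, 128, 16]

weighted = weighted_mean(losses, nums)    # convention used for val_loss
unweighted = sum(losses) / len(losses)    # convention used for val_acc
print(weighted, unweighted)
```

The short batch pulls the unweighted mean much further than the weighted one, which is why the weighted form is the safer default when the dataset size is not a multiple of the batch size.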

In order to extract intermediate inputs & activations for later analysis, we'll create a function that adds forward hooks to the `nn.Tanh` layers:

In [10]:
# simple class to store layer inputs & outputs:
class Hook():
    def __init__(self, module, input=None, output=None):
        self.hook = module.register_forward_hook(self.hook_fn)
        self.input = input
        #self.output = output

    def hook_fn(self, module, input, output):
        self.input.append(input[0].detach())
        #self.output.append(output.detach())
        
    def close(self):
        self.hook.remove()

# function that recursively registers hooks on Tanh layers:
def hook_layers(net, inputs, outputs):
    for name, layer in net._modules.items():
        # if nn.Sequential, register recursively on constituent modules:
        if isinstance(layer, nn.Sequential):
            hook_layers(layer, inputs, outputs)
        # individual module, register hook only on Tanh
        # (n.b., the Hook object is not kept, so it is never closed and stays registered):
        elif isinstance(layer, nn.Tanh):
            Hook(layer, inputs, outputs)
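
Because every forward pass fires the `Tanh` hooks in layer order, the shared `inputs` list ends up interleaved: batch 0 layer 0, batch 0 layer 1, and so on, then batch 1 layer 0, etc. A strided slice `inputs[k::depth]` (as used in `write_hooks` further down) therefore recovers all batches for layer `k`. A toy sketch, with labelled tuples standing in for the tensors:

```python
depth = 3       # number of hooked Tanh layers
n_batches = 4   # number of forward passes

# mimic the order in which the hooks append to the shared list:
inputs = []
for batch in range(n_batches):
    for layer in range(depth):
        inputs.append((batch, layer))   # stand-in for that layer's input tensor

# group by layer with a strided slice:
by_layer = {layer: inputs[layer::depth] for layer in range(depth)}
print(by_layer[1])
```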

Lastly, it is convenient to use PyTorch's `DataLoader` utility to handle batch management, so we'll define a function to load our training & validation data into that form:

In [11]:
# return DataLoaders for training and validation sets, for batch management:
def get_data(x_train, y_train, x_valid, y_valid, batch_size):
    return (DataLoader(TensorDataset(x_train, y_train), batch_size, shuffle=False),
            DataLoader(TensorDataset(x_valid, y_valid), batch_size*2))

Now we're ready to actually build the model (i.e., the network). To facilitate playing with different depths, let's create a function that constructs a network of arbitrary depth consisting of `GaussianLinear` layers followed by `Tanh` layers, and which steadily reduces the number of neurons per layer in steps of (784-10)/num_layers, rounded down (n.b., "arbitrary" up to a maximum depth of 774, given the monotonic reduction constraint). 

In [12]:
# ************************************ used only if `WIDE=False` *************************************
# Construct a Gaussian neural network consisting of GaussianLinear layers followed by Tanh layers.
# The layer widths are steadily reduced from input_dim to output_dim
# in step sizes of floor((input_dim - output_dim)/num_layers) (n.b., implies max depth).
def build_network(num_layers, input_dim, output_dim, var_w, var_b): 
    # determine how much to shrink each layer:
    diff = input_dim - output_dim   
    if num_layers > diff:
        raise Exception('Specified number of layers exceeds maximum value consistent\n'
                        'with monotonic reduction in layer widths. Max allowed depth is {}'.format(diff))
            
    shrink = math.floor(diff/num_layers)  # n.b., rounding up can over-decimate in deep networks!
    
    # compute layer widths:
    widths = []
    for i in range(num_layers):
        widths.append(input_dim - shrink*i)      
    
    # output layer:
    widths.append(output_dim)
    
    # construct and add layers to list (no need to use nn.ModuleList):
    mlist = []
    for i in range(num_layers):
        mlist.append(GaussianLinear(widths[i], widths[i+1], var_w, var_b))
        mlist.append(nn.Tanh())
    
    return nn.Sequential(*mlist)

Alternatively, to test my hypothesis that pathological behaviour in the KL divergence is due to dimensional reduction (i.e., normalization), we can experiment with constant-width networks (at least up until the very end, where we must shrink down to 10): 

In [13]:
# ************************************ used only if `WIDE=True` *************************************
# Construct a Gaussian neural network consisting of GaussianLinear layers followed by Tanh layers.
# Layer widths are kept at 784 until the second-from-last layer, at which point we reduce to
# 400, and then 10 in the output layer.
def build_wide_network(num_layers, input_dim, output_dim, var_w, var_b): 
    # check num_layers >= 3:
    if num_layers < 3:
        raise Exception('Too few layers; minimum allowed depth is 3.')
    
    # compute layer widths:
    widths = [input_dim]*(num_layers-1)
    widths.append(400)
    widths.append(output_dim)

    # construct and add layers to list (no need to use nn.ModuleList):
    mlist = []
    for i in range(num_layers):
        mlist.append(GaussianLinear(widths[i], widths[i+1], var_w, var_b))
        mlist.append(nn.Tanh())
    
    return nn.Sequential(*mlist)
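
To compare the two width schedules above side by side without building any networks, the width computations can be pulled out on their own. This is a sketch that mirrors the logic of `build_network` and `build_wide_network`; the function names `decimated_widths` and `wide_widths` and the `penultimate` argument are hypothetical, introduced only for illustration.

```python
def decimated_widths(num_layers, input_dim=784, output_dim=10):
    """Widths shrinking by a fixed (floored) step per layer, then the output width."""
    diff = input_dim - output_dim
    if num_layers > diff:
        raise ValueError('Max allowed depth is {}'.format(diff))
    shrink = diff // num_layers
    return [input_dim - shrink * i for i in range(num_layers)] + [output_dim]

def wide_widths(num_layers, input_dim=784, output_dim=10, penultimate=400):
    """Constant width until the last two layers, which reduce to 400 and then the output."""
    if num_layers < 3:
        raise ValueError('Minimum allowed depth is 3.')
    return [input_dim] * (num_layers - 1) + [penultimate, output_dim]

print(decimated_widths(5))
print(wide_widths(5))
```

Each list has `num_layers + 1` entries, since consecutive pairs give the input and output size of one `GaussianLinear` layer. Note that because the step is rounded down, the final layer of the decimated network generally absorbs a larger drop than the others.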

Next, let's write a function that encapsulates creating and training a list of models of different depths:

In [14]:
# construct and train models with a range of depths;
# pass -1 (or any non-str) for file names to avoid writing:
def train_models(depth_min, depth_max, depth_step=1, file_acc='accuracies.hdf5',
                 write_params=False, file_params='parameters.hdf5',
                 hooks=False, file_hook='hooks.hdf5', 
                 save_model=False, file_model='model.hdf5'):
    depth = np.arange(depth_min, depth_max, depth_step)
    print('Depth list: ', depth)
            
    # construct new set of models & associated optimizers:
    model = []
    opt = []
    for i,d in enumerate(depth):
        if not WIDE:
            model.append(build_network(d, 784, 10, var_weight, var_bias))
        else:
            model.append(build_wide_network(d, 784, 10, var_weight, var_bias))  # alternative: wide network
        opt.append(optim.SGD(model[i].parameters(), rate, momentum))
                        
    # train models, optionally write data:
    for i in range(len(model)):
        accuracies = []  # store accuracies
        
        print('\nTraining model ', i, ' with depth ', depth[i], '...')

        fit(epochs, model[i], depth[i], hooks, file_hook, write_params, file_params, var_weight, var_bias, loss_func, opt[i], train_dl, valid_dl, accuracies)
        
        if save_model:
            model_name = PATH_TO_DATA + re.sub(r'\.(.*)$','',file_model) + '-{}.hdf5'.format(depth[i])
            torch.save(model[i].state_dict(), model_name)
        
        # optionally write accuracies in hdf5 format:
        if isinstance(file_acc, str):
            write_accuracies(file_acc, depth[i], accuracies, var_weight, var_bias)

        # optionally write final weights, biases in hdf5 format:
        #if write_params and isinstance(file_params, str):
        #    write_parameters('e{}-'.format(epochs-1) + file_params, model[i], depth[i], var_weight, var_bias)
        
    print('\nTraining complete.\n')

We'll also need functions to write and read the data created by the `train_models` function:

In [15]:
# write file of accuracies:
def write_accuracies(file_name, depth, accuracies, var_weight, var_bias):
    with h5py.File(PATH_TO_DATA + re.sub(r'\.(.*)$','',file_name) + '-{}.hdf5'.format(depth), 'w') as file:
        file.create_dataset('var_weight', data=var_weight)
        file.create_dataset('var_bias', data=var_bias)
        file.create_dataset('depth', data=depth) 
        file.create_dataset('accuracies', data=accuracies)

        
# read file of accuracies, return dataset as dictionary:
def read_accuracies(file_name):
    with h5py.File(PATH_TO_DATA + file_name, 'r') as file:
        # cast elements as np.array, else returns closed file datasets:
        acc_dict = {key : np.array(file[key]) for key in file.keys()}  
        
    return acc_dict


# write file of inputs/outputs (n.b., create_dataset tries to turn the data into
#   a numpy array, which fails for a list of unevenly-sized tensors; must first
#   pre-process and create one dataset for each layer, combining all batches):
def write_hooks(file_name, depth, inputs, outputs):
    # group data from all batches by layer, by constructing dictionaries of layers;
    # keys = layer number, elements = list of batches of inputs/outputs for that layer:
    layers_in, layers_out = {i : [] for i in range(depth)}, {i : [] for i in range(depth)}
    for key in layers_in.keys():  # same key list for both dicts
        layers_in[key].extend(inputs[key::depth])
        #layers_out[key].extend(outputs[key::depth])  # optional/unused
       
    # concatenate each list of tensors (dict element) into a single tensor (to enable conversion to numpy array):
    for key in layers_in.keys():
        layers_in[key] = torch.cat(layers_in[key])
        #layers_out[key] = torch.cat(layers_out[key])

    # write each layer as a dataset:
    with h5py.File(PATH_TO_DATA + re.sub(r'\.(.*)$','',file_name) + '-{}.hdf5'.format(depth), 'w') as file:
        # n.b., var_weight & var_bias are not arguments here; they are read from the global scope:
        file.create_dataset('var_weight', data=var_weight)
        file.create_dataset('var_bias', data=var_bias)
        file.create_dataset('depth', data=depth)
        for key in layers_in.keys():
            # encode whether input or output in key; elements = all inputs/outputs for that layer:
            file.create_dataset('in-{}'.format(key), data=layers_in[key]) 
            #file.create_dataset('out-{}'.format(key), data=layers_out[key])

            
# read file of inputs/outputs, return dataset as dictionary:
def read_hooks(file_name):    
    with h5py.File(PATH_TO_DATA + file_name, 'r') as file:
        # cast elements as np.array, else returns closed file datasets:
        hook_dict = {key : np.array(file[key]) for key in file.keys()}
    
    return hook_dict


# Write weights and biases for entire network in hdf5 format.
# Note that last three parameters (depth, var_weight, var_bias)
#   are just meta-data, to aid in identifying run upon reading file.
def write_parameters(file_name, model, depth, var_weight, var_bias):
    with h5py.File(PATH_TO_DATA + re.sub(r'\.(.*)$','',file_name) + '-{}.hdf5'.format(depth), 'w') as file:
        file.create_dataset('var_weight', data=var_weight)
        file.create_dataset('var_bias', data=var_bias)
        file.create_dataset('depth', data=depth) 
        
        for key in model.state_dict():
            # get correct layer index (instead of x2):
            layer_num = int(re.findall(r'\d+', key)[0]) // 2
                
            # write layer's weights/biases as dictionary entry:
            if key.endswith('weights'):
                file.create_dataset('W{}'.format(layer_num), data=model.state_dict()[key].numpy())     
            elif key.endswith('bias'):
                file.create_dataset('B{}'.format(layer_num), data=model.state_dict()[key].numpy())   

                
# read file of weights, biases; return as dictionary:             
def read_parameters(file_name):
    with h5py.File(PATH_TO_DATA + file_name, 'r') as file:
        # cast elements as np.array, else returns closed file datasets:
        para_dict = {key : np.array(file[key]) for key in file.keys()}
        
    return para_dict

## Generate datasets (training/testing)
Now, let's train some models! First, set the hyperparameters and whatnot used by all models we wish to compare:

In [18]:
# set hyperparameters:
rate = 0.005
epochs = 5
momentum = 0.8
batch_size = 64

# load training & validation data into DataLoaders:
train_dl, valid_dl = get_data(x_train, y_train, x_valid, y_valid, batch_size)

# set loss function:
loss_func = F.cross_entropy

As an example, the following cell sequentially trains three models of depth 10, 20, and 30, all with fixed $\sigma_w^2=2.0$ and $\sigma_b^2=0.05$.

**Note on file name conventions**: when writing data, the depth of each network is appended to the given filenames, e.g., passing "accuracies-20.hdf5" will result in "accuracies-20-10.hdf5", "accuracies-20-20.hdf5", and "accuracies-20-30.hdf5". The "-20" (as in $\sigma_w^2=2.0$) is my naming convention for keeping runs with different variances straight (though the relevant data to identify them is also written internally). When writing hooks or parameters, then -- for the present example with 5 epochs -- an "e0-" and "e4-" ("e" as in "epoch") will be prepended to the hook/parameter filenames, to distinguish pre- vs. post-training results. 
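
The naming rule can be reproduced in isolation. The sketch below mirrors the `re.sub` expression used by the write functions; the helper name `output_name` is hypothetical. Since the pattern strips everything from the *first* dot onward, a variance written with a decimal point in the base filename would be truncated, which is one reason to write "20" rather than "2.0".

```python
import re

def output_name(file_name, depth, epoch=None):
    """Strip the extension, append the depth, and optionally prepend an epoch tag."""
    stem = re.sub(r'\.(.*)$', '', file_name)   # removes everything from the first '.'
    prefix = '' if epoch is None else 'e{}-'.format(epoch)
    return prefix + stem + '-{}.hdf5'.format(depth)

print(output_name('accuracies-20.hdf5', 10))
print(output_name('parameters-20.hdf5', 30, epoch=4))
print(output_name('acc-2.0.hdf5', 10))   # the '.0' is lost along with the extension
```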

In [None]:
# variances for Gaussian parameter initialization:
var_weight = 2.0
var_bias = 0.05

train_models(10,31,10, file_acc='accuracies-20.hdf5',
             write_params=True, file_params='parameters-20.hdf5',
             hooks=True, file_hook='hooks-20.hdf5', save_model=True, file_model='models-20.hdf5')

## Grid search
To perform a more systematic search of parameter space as in fig. 6 of the companion paper, the following cell trains a list of networks with depths $L\in\{10,13,16,\ldots,67,70\}$ with $\sigma_w^2\in\{1.00, 1.05, 1.10, \ldots, 2.95,3.00\}$, and fixed $\sigma_b^2=0.05$. Here we only care about the accuracies, so we'll run with `hooks=False` for speed, and not bother writing the parameters or the models themselves either (to save memory). I don't recommend doing this on a standard desktop, unless you are inordinately patient.

The same naming convention as in the previous example is used when writing the accuracies here: at $\sigma_w^2=1.00$, we'll have "acc-100-10.hdf5", "acc-100-13.hdf5", and so on; at $\sigma_w^2=1.05$, we'll have "acc-105-10.hdf5", "acc-105-13.hdf5", and so on; etc.

In [None]:
var_bias = 0.05

# iterate over a range of var_w from 1.0 to 3.0, in steps of 0.05:
for i in range(100,301,5):
    var_weight = i/100
    file_acc = 'acc-{}.hdf5'.format(i)  # base filename
    
    print('Training models with variance {}...\n'.format(var_weight))

    train_models(10,73,3, file_acc=file_acc, write_params=False, file_params='dummy_para_name.hdf5', 
                 hooks=False, file_hook='dummy_hook_name.hdf5', save_model=False, file_model='dummy_model_name.hdf5')
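
Before launching the scan, it can help to count how many networks it involves. The following standalone sketch uses the same ranges as the cell above (plain `range` in place of `np.arange`; the variable names are illustrative):

```python
depths = list(range(10, 73, 3))                      # same values as np.arange(10, 73, 3)
variances = [i / 100 for i in range(100, 301, 5)]    # sigma_w^2 from 1.00 to 3.00

n_models = len(depths) * len(variances)
print(len(depths), 'depths x', len(variances), 'variances =', n_models, 'models')
```

With these ranges that is 21 depths times 41 variances, i.e. 861 networks, each trained for the number of epochs set above.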