# Three Ways to Prune a Neural Network

In [None]:
### imports

import os
import copy

import autograd
from autograd import numpy as np
from autograd import elementwise_grad as egrad

import sklearn
import sklearn.datasets

# Runtime types used only in annotations below: a plain Python function and
# the array class produced by the `np` module imported above.
fn_type = (lambda arg: arg).__class__
nparray = np.random.rand(1, 2).__class__


## Defining the forward pass and SGD updates

In [None]:
def relu(x):
    """Elementwise max(x, 0), written as ``x`` times the boolean mask ``x > 0``."""
    return x * (x > 0.0)


def forward(x: nparray, layers: list) -> nparray:
    """Run the bias-free MLP defined by ``layers`` on the batch ``x``.

    Every weight matrix except the last is followed by a ReLU. The last matrix
    produces the raw network outputs, which the losses in this notebook feed
    into a softmax.

    Args:
        x: input rows, multiplied on the right by ``layers[0]``.
        layers: weight matrices, ``layers[i]`` mapping width i to width i+1.

    Returns:
        ``relu(...relu(x @ layers[0])...) @ layers[-1]``.
    """
    hidden, output = layers[:-1], layers[-1]
    for weights in hidden:
        x = relu(x @ weights)
    return x @ output

def initialize_layer_weights(in_dim: int, out_dim: int) -> nparray:
    """Draw an (in_dim, out_dim) Glorot-normal weight matrix.

    Entries are standard normals scaled by sqrt(2 / (in_dim + out_dim)), using
    NumPy's global random state.
    """
    fan_total = in_dim + out_dim
    scale = np.sqrt(2 / fan_total)
    draws = np.random.randn(in_dim, out_dim)
    return scale * draws

def initialize_weights(dimensions: list) -> list:
    """Create one weight matrix per entry of ``dimensions``.

    Only the first two items of each entry (rows, columns) are read, so lists,
    tuples and ndarray ``.shape`` values are all accepted.
    """
    return [initialize_layer_weights(shape[0], shape[1]) for shape in dimensions]

def initialize_model(my_seed: int, in_dim: int,
        h_dim: int, out_dim: int, number_hidden: int) -> list:
    """Seed NumPy's global RNG and build the MLP's weight matrices.

    The network is in_dim -> h_dim, then ``number_hidden`` extra
    h_dim -> h_dim layers, then h_dim -> out_dim.
    """
    hidden_shapes = [[h_dim, h_dim] for _ in range(number_hidden)]
    dims = [[in_dim, h_dim], *hidden_shapes, [h_dim, out_dim]]

    np.random.seed(my_seed)
    return initialize_weights(dims)

# NOTE: `sm` (softmax) is defined only in the training-loop cell further down,
# so this loss can be called only after that cell has run; that cell also
# re-binds `nll_loss_fn`.
def nll_loss_fn(y, pred):
    """Mean binary cross-entropy between targets ``y`` and ``sm(pred)``."""
    probs = sm(pred)
    return -np.mean(y * np.log(probs) + (1 - y) * np.log(1. - probs))


def compute_loss(x: nparray,
        y: nparray,
        loss_function: fn_type,
        layers: list) -> np.float64:
    """Evaluate ``loss_function(y, forward(x, layers))``.

    ``layers`` is the fourth positional argument, so it is the argument that
    ``grad_loss`` below differentiates with respect to.
    """
    return loss_function(y, forward(x, layers))


# Gradient of compute_loss w.r.t. argument index 3 (`layers`); returns a list
# with one gradient array per weight matrix, consumed by sgd_update.
grad_loss = autograd.grad(compute_loss, argnum=3)

def sgd_update(layers: list, grad_layers: list, lr: float) -> list:
    """Take one SGD step, leaving exactly-zero (pruned) weights frozen.

    Each gradient is multiplied by the boolean mask ``|w| > 0`` before being
    applied, so weights that are exactly zero get no update and stay pruned.

    Args:
        layers: current weight matrices (not modified).
        grad_layers: gradients, one per entry of ``layers``.
        lr: learning rate.

    Returns:
        A new list of updated weight matrices.
    """
    return [
        layers[idx] - lr * grad * (np.abs(layers[idx]) > 0.0)
        for idx, grad in enumerate(grad_layers)
    ]



# Sanity Check: Fitting random vectors

In [None]:
# Sanity check: a small MLP trained with plain SGD should be able to fit
# random regression targets; the printed MSE is expected to decrease.
x = np.random.rand(128, 64)
targets = np.random.rand(128, 10)


def mse_loss_fn(y, p):
    """Mean squared error between targets ``y`` and predictions ``p``."""
    return np.mean((y - p) ** 2)


lr = 1e-3
dims = [[64, 32], [32, 32], [32, 10]]
layers = initialize_weights(dims)

for step in range(1000):
    grad_layers = grad_loss(x, targets, mse_loss_fn, layers)

    if not step % 100:
        # Loss of the weights *before* this step's update.
        loss = compute_loss(x, targets, mse_loss_fn, layers)
        print(f"loss at step {step} = {loss:.3f}")

    layers = sgd_update(layers, grad_layers, lr=lr)

## Pruning nodes by gradient magnitude

In [None]:
def forward_skeletonize(x: nparray, nodes: list, layers: list):
    """Forward pass with a multiplicative gate on the inputs of every layer.

    ``nodes[i]`` scales the inputs of ``layers[i]`` elementwise. With all gates
    equal to one this matches ``forward``; differentiating w.r.t. the gates
    gives dLoss/d(gate) for every node (skeletonization).
    """
    hidden, output = layers[:-1], layers[-1]
    for weights, gate in zip(hidden, nodes):
        x = relu((x * gate) @ weights)

    return (x * nodes[-1]) @ output


def compute_skeleton_loss(x: nparray,
        y: nparray,
        loss_function: fn_type,
        nodes: list,
        layers: list) -> np.float64:
    """Loss of the gated network; ``nodes`` is positional argument 3."""
    return loss_function(y, forward_skeletonize(x, nodes, layers))


# Gradient of the gated loss w.r.t. the gates (argument index 3 = `nodes`).
grad_skeleton_loss = autograd.grad(compute_skeleton_loss, argnum=3)

def prune_node(layers: list,
        grad_nodes: list) -> list:
    """Remove the single least important hidden node (skeletonization).

    ``grad_nodes[i]`` holds dLoss/d(gate) for the inputs of ``layers[i]``.
    Entry 0 (the network inputs) is never pruned. Among the remaining gates,
    the one with the smallest |gradient| is removed. A hidden unit is both an
    output of ``layers[i - 1]`` and an input of ``layers[i]``, so its row is
    deleted from ``layers[i]`` and its column from ``layers[i - 1]``.

    Layers with a single node left are skipped, so no layer is reduced to width
    zero. NaN gradients are ignored. If nothing can be pruned, the layers are
    returned unchanged.

    Args:
        layers: weight matrices; neither the list nor the arrays are modified.
        grad_nodes: one 1-D gradient array per layer (as from grad_skeleton_loss).

    Returns:
        A new list in which (at most) two matrices are replaced by narrower copies.
    """
    # Copy the list so the caller's list is not mutated in place.
    new_layers = list(layers)

    best_layer, best_node, lowest_grad = None, None, float("inf")
    for layer_index, grad_layer in enumerate(grad_nodes[1:], start=1):
        if len(grad_layer) <= 1:
            continue
        magnitudes = np.abs(grad_layer)
        magnitudes = np.where(np.isnan(magnitudes), np.inf, magnitudes)
        node = int(np.argmin(magnitudes))
        # Strict < keeps the earliest layer on ties.
        if magnitudes[node] < lowest_grad:
            best_layer, best_node, lowest_grad = layer_index, node, magnitudes[node]

    if best_layer is None:
        return new_layers

    new_layers[best_layer] = np.delete(layers[best_layer], best_node, axis=0)
    new_layers[best_layer - 1] = np.delete(layers[best_layer - 1], best_node, axis=1)

    return new_layers



In [None]:
## Pruning weights by the second derivative of the loss

In [None]:
from autograd import elementwise_grad as egrad

# "Second derivative" of the loss for every weight (argument index 3 = layers).
# NOTE(review): egrad computes a vector-Jacobian product with a vector of ones,
# so egrad(egrad(loss)) yields Hessian row sums (sum_j d2L / dw_i dw_j over all
# weights), which equal the Hessian diagonal used by Optimal Brain Damage only
# when the off-diagonal terms vanish — confirm this approximation is intended.
_loss_grad_wrt_layers = egrad(compute_loss, argnum=3)
grad2_loss = egrad(_loss_grad_wrt_layers, argnum=3)

def prune_weights_by_grad2(layers: list,
        grad2_layers: list,
        prune_per_layer: int=10,
        initial_threshold: float=1e-5) -> list:
    """Zero out weights whose |score| falls below an adaptively grown threshold.

    For each layer the threshold starts at ``initial_threshold`` and doubles
    until at least ``prune_per_layer`` still-active (non-zero) weights have a
    finite |score| strictly below it. Every weight with |score| <= threshold is
    then zeroed. Weights zeroed by earlier calls are not counted, so repeated
    calls keep pruning new weights. If a layer has fewer candidates than
    ``prune_per_layer``, all of them are pruned and the loop still terminates.

    NOTE: Optimal Brain Damage (LeCun et al. 1990) ranks weights by
    h_kk * w_k**2 / 2; this function ranks by |score| alone.

    Args:
        layers: weight matrices (not modified).
        grad2_layers: per-weight scores with the same shapes as ``layers``.
        prune_per_layer: minimum number of active weights to prune per layer.
        initial_threshold: starting threshold; must be positive whenever the
            search has to run.

    Returns:
        A new list of masked weight matrices.

    Raises:
        ValueError: if ``initial_threshold`` is not positive and pruning is needed.
    """
    new_layers = []

    for layer, grad2_layer in zip(layers, grad2_layers):
        saliency = np.abs(grad2_layer)
        # Only live weights with a finite score can fall below a finite threshold.
        candidates = (np.abs(layer) > 0.0) & np.isfinite(saliency)
        target = min(prune_per_layer, int(np.sum(candidates)))
        threshold = 1.0 * initial_threshold

        if target > 0:
            if threshold <= 0.0:
                raise ValueError("initial_threshold must be positive")
            while np.sum((saliency < threshold) & candidates) < target:
                threshold *= 2.

        new_layers.append(layer * (saliency > threshold))

    return new_layers



## Weight magnitude pruning

In [None]:
def prune_weights_by_magnitude(layers: list,
        prune_per_layer: int=10,
        initial_threshold: float=1e-3) -> list:
    """Magnitude pruning: reuse the threshold pruner with |w| as the score.

    The weights themselves are passed as the per-weight scores, so the
    smallest-magnitude weights in each layer are zeroed.
    """
    scores = layers  # the weights double as their own saliency
    return prune_weights_by_grad2(
        layers,
        scores,
        prune_per_layer,
        initial_threshold,
    )



## The training loop

In [None]:
def sm(x):
    """Softmax over the last axis.

    The per-row maximum is subtracted before exponentiating. Softmax is
    invariant to that shift, and it keeps ``np.exp`` from overflowing to inf,
    which would turn large logits into inf / inf = NaN.
    """
    shifted = x - np.max(x, axis=-1, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=-1, keepdims=True)


# Added inside the logs so probabilities that saturate to exactly 0 or 1 in
# float64 give a finite loss/gradient instead of 0 * -inf = NaN.
_LOG_EPS = 1e-12


def nll_loss_fn(y, pred):
    """Mean binary cross-entropy between targets ``y`` and ``sm(pred)``."""
    probs = sm(pred)
    return -np.mean(y * np.log(probs + _LOG_EPS)
                    + (1 - y) * np.log(1. - probs + _LOG_EPS))

def indices_to_one_hot(y: nparray, max_index=None) -> nparray:
    """Convert a vector of class indices into a float one-hot matrix.

    Args:
        y: integer class index per row.
        max_index: largest class index; defaults to ``np.max(y)``.

    Returns:
        Array of shape (len(y), max_index + 1) with a single 1.0 per row.
    """
    top = np.max(y) if max_index is None else max_index

    one_hot = np.zeros((y.shape[0], top + 1))
    for row, label in enumerate(y):
        one_hot[row, label] = 1.0

    return one_hot

def compute_accuracy(y: nparray, predicted: nparray) -> float:
    """Fraction of rows whose arg-max along the last axis agrees between y and predicted."""
    hits = np.argmax(y, axis=-1) == np.argmax(predicted, axis=-1)
    return np.mean(hits)

def prune_mode_0(layers: list,
        train_x: nparray=None,
        train_y: nparray=None) -> list:
    """No-op pruning strategy: hand back ``layers`` itself (same object)."""
    unchanged = layers
    return unchanged

def prune_mode_1(layers: list, train_x: nparray, train_y: nparray) -> list:
    """One round of node pruning driven by gate gradients on the training set.

    A gate of ones is placed on the inputs of every layer. The gradient of the
    NLL loss w.r.t. those gates is taken, and prune_node chooses which node to
    remove.
    """
    gates = [np.ones(weights.shape[0]) for weights in layers]
    gate_grads = grad_skeleton_loss(train_x, train_y, nll_loss_fn, gates, layers)
    return prune_node(layers, gate_grads)

def prune_mode_2(layers: list, train_x: nparray, train_y: nparray) -> list:
    """One round of weight pruning scored by grad2_loss on the training set."""
    curvature = grad2_loss(train_x, train_y, nll_loss_fn, layers)
    return prune_weights_by_grad2(layers, curvature)

def prune_mode_3(layers: list,
        train_x: nparray=None,
        train_y: nparray=None) -> list:
    """One round of magnitude pruning.

    The data arguments are unused; they exist only so all prune_mode_*
    functions share one call signature.
    """
    pruned = prune_weights_by_magnitude(layers)
    return pruned

def retrieve_prune_fn(mode: int=0):
    """Return the pruning function for ``mode``.

    0 - no pruning, 1 - node pruning, 2 - second-derivative weight pruning,
    3 - magnitude weight pruning. Every returned function is called as
    ``fn(layers, train_x, train_y)`` and returns the new layers.

    Raises:
        ValueError: if ``mode`` is not one of 0, 1, 2, 3. Raising here, rather
            than returning None, reports the problem immediately instead of
            when the function is first called.
    """
    prune_fns = {
        0: prune_mode_0,
        1: prune_mode_1,
        2: prune_mode_2,
        3: prune_mode_3,
    }
    try:
        return prune_fns[mode]
    except KeyError:
        raise ValueError(
            f"unknown pruning mode {mode!r}; expected one of {sorted(prune_fns)}"
        ) from None
def split_digits(my_seed: int=13) -> tuple:
    """Load sklearn's digits data, shuffle it and hold out 10% for validation.

    Features are divided by their global maximum and labels become one-hot
    rows. Features and labels are each shuffled after re-seeding with
    ``my_seed``, so both receive the same row permutation and stay paired.

    Returns:
        (train_x, train_y, val_x, val_y)
    """
    features, labels = sklearn.datasets.load_digits(return_X_y = True)
    features = features / np.max(features)
    targets = indices_to_one_hot(labels)

    # Re-seeding before each in-place shuffle gives both arrays the same permutation.
    for arr in (features, targets):
        np.random.seed(my_seed)
        np.random.shuffle(arr)

    n_val = int(0.1*features.shape[0])

    return (
        features[n_val:, :],
        targets[n_val:, :],
        features[:n_val, :],
        targets[:n_val, :],
    )

def print_progress(layers: list=None,
        batch_x: nparray=None,
        batch_y: nparray=None,
        tag: str="train",
        step: int=0,
        verbose: bool=True):
    """Format one CSV row of progress (loss and accuracy) and optionally print it.

    Called without ``layers`` it returns the CSV header instead. Every returned
    string ends with a newline, so rows can be concatenated directly into the
    log that ``train`` writes to disk.

    Args:
        layers: model weights; None requests the header row.
        batch_x, batch_y: data to evaluate on (one-hot targets).
        tag: value of the "split" column.
        step: training step, zero-padded to five digits.
        verbose: print the row as well as returning it.

    Returns:
        The formatted row (or header), newline-terminated.
    """
    if layers is None:
        # Header row. The trailing newline keeps the first data row on its own
        # line when rows are appended.
        msg = "split, step, loss, accuracy\n"
    else:
        loss = compute_loss(batch_x, batch_y, nll_loss_fn, layers)
        predicted = forward(batch_x, layers)
        accuracy = compute_accuracy(batch_y, predicted)
        msg = f"{tag}, {step:05}, {loss:.4f}, {accuracy:.4f}\n"

    if verbose:
        print(msg)

    return msg

def _sample_batch(train_x: nparray, train_y: nparray, batch_size: int) -> tuple:
    """Draw ``batch_size`` rows uniformly at random, with replacement."""
    batch_indices = np.random.randint(train_x.shape[0], size=(batch_size,))
    return train_x[batch_indices], train_y[batch_indices]


def _log_train_valid(layers: list, train_x: nparray, train_y: nparray,
        val_x: nparray, val_y: nparray, suffix: str, step: int,
        verbose: bool) -> str:
    """Return (and optionally print) the train row, then the valid row, for ``layers``."""
    train_row = print_progress(layers, train_x, train_y,
            tag=f"train{suffix}", verbose=verbose, step=step)
    valid_row = print_progress(layers, val_x, val_y,
            tag=f"valid{suffix}", verbose=verbose, step=step)
    return train_row + valid_row


def train(my_seed: int=13,
        number_epochs: int=100,
        mode: int=0,
        lr: float=1e-3,
        verbose: bool=True):
    """Train, prune, retrain and run a lottery-ticket comparison on sklearn digits.

    mode 0 - no pruning
    mode 1 - pruning nodes (Mozer and Smolensky 1989)
    mode 2 - pruning w by 2nd derivative (LeCun et al. 1990)
    mode 3 - by magnitude (e.g. Han et al. 2015 and others)

    Steps:
      1. Train an MLP for ``number_epochs`` minibatch SGD steps.
      2. Apply the mode's pruning function ``number_prunes`` times, both to the
         trained weights and to a copy of the initial weights (the "ticket").
      3. Fine-tune the pruned network with a few full-batch SGD steps.
      4. Save its weights to parameters/mode_{mode}/layer{i}.npy.
      5. Train the ticket alongside a freshly initialised network with the same
         shapes and zero pattern (the "noticket" baseline).

    All progress rows are written to mode_{mode}_log.csv in the current working
    directory.

    Raises:
        ValueError: if ``mode`` does not name a pruning function.
    """
    my_prune_fn = retrieve_prune_fn(mode)
    if my_prune_fn is None:
        raise ValueError(f"unknown pruning mode {mode!r}; expected 0, 1, 2 or 3")

    batch_size = 1024
    number_prunes = 28
    h_dim = 16
    number_hidden = 2
    # At least 1; otherwise `step % display_every` divides by zero when
    # number_epochs < 10.
    display_every = max(number_epochs // 10, 1)
    # Step label for rows logged after the main training loop.
    last_step = max(number_epochs - 1, 0)

    train_x, train_y, val_x, val_y = split_digits(my_seed)

    in_dim = train_x.shape[-1]
    out_dim = train_y.shape[-1]

    layers = initialize_model(my_seed, in_dim, h_dim, out_dim, number_hidden)
    # Snapshot of the initial weights for the lottery-ticket run below.
    ticket_layers = copy.deepcopy(layers)

    progress = print_progress()
    for step in range(number_epochs):
        if step % display_every == 0:
            progress += _log_train_valid(layers, train_x, train_y,
                    val_x, val_y, "", step, verbose)

        batch_x, batch_y = _sample_batch(train_x, train_y, batch_size)
        grad_layers = grad_loss(batch_x, batch_y, nll_loss_fn, layers)
        layers = sgd_update(layers, grad_layers, lr=lr)

    # NOTE(review): the ticket is pruned by running the pruning function on the
    # *initial* weights, not by reusing the mask found on the trained network
    # (as in Frankle & Carbin 2019) — confirm this is intended.
    for _ in range(number_prunes):
        layers = my_prune_fn(layers, train_x, train_y)
        ticket_layers = my_prune_fn(ticket_layers, train_x, train_y)

    progress += _log_train_valid(layers, train_x, train_y,
            val_x, val_y, "_post_prune", last_step, verbose)

    # Short full-batch fine-tuning of the pruned network.
    for _ in range(display_every):
        grad_layers = grad_loss(train_x, train_y, nll_loss_fn, layers)
        layers = sgd_update(layers, grad_layers, lr=lr)

    progress += _log_train_valid(layers, train_x, train_y,
            val_x, val_y, "_retrained", last_step, verbose)

    save_dir = os.path.join("parameters", f"mode_{mode}")
    os.makedirs(save_dir, exist_ok=True)

    print(f"model shape with mode {mode} pruning")

    for ii, layer in enumerate(layers):
        # Shape and number of surviving (non-zero) weights.
        print(layer.shape, np.sum(np.abs(layer) > 0))
        np.save(os.path.join(save_dir, f"layer{ii}.npy"), layer)

    # Lottery-ticket baseline: fresh random weights with the ticket's shapes AND
    # its zero pattern, so weight-pruning modes compare networks of equal
    # sparsity. sgd_update keeps zero weights frozen, so the mask persists.
    pruned_dims = [layer.shape for layer in ticket_layers]
    noticket_layers = [
        fresh * (np.abs(ticket) > 0.0)
        for fresh, ticket in zip(initialize_weights(pruned_dims), ticket_layers)
    ]

    for step in range(number_epochs):
        batch_x, batch_y = _sample_batch(train_x, train_y, batch_size)
        ticket_grad_layers = grad_loss(batch_x, batch_y, nll_loss_fn, ticket_layers)
        noticket_grad_layers = grad_loss(batch_x, batch_y, nll_loss_fn, noticket_layers)

        ticket_layers = sgd_update(ticket_layers, ticket_grad_layers, lr=lr)
        noticket_layers = sgd_update(noticket_layers, noticket_grad_layers, lr=lr)

        if step % display_every == 0 or step == (number_epochs - 1):
            progress += _log_train_valid(ticket_layers, train_x, train_y,
                    val_x, val_y, "_ticket", step, verbose)
            progress += _log_train_valid(noticket_layers, train_x, train_y,
                    val_x, val_y, "_noticket", step, verbose)

    progress_filepath = f"mode_{mode}_log.csv"
    with open(progress_filepath, "w", encoding="utf-8") as f:
        f.write(progress)


In [None]:
# Run every pruning strategy (modes 0-3) with the same hyper-parameters.
number_of_epochs = 10000
lr = 1e-1

for mode in (0, 1, 2, 3):
    train(number_epochs=number_of_epochs, mode=mode, lr=lr)