<a href="https://colab.research.google.com/github/prokorpio/everything_190/blob/master/SNIP_reimplementation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# SNIP Reimplementation

## 1. Import Libraries


In [2]:
import time

import numpy as np
import tensorflow as tf
print('Tensorflow v', tf.__version__, sep='')
from platform import python_version
print('python v',python_version(), sep='')
import matplotlib.pyplot as plt


from tensorflow.keras.datasets import mnist
from sklearn.model_selection import train_test_split
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.layers import Input, Flatten, Dense, Conv2D
from tensorflow.keras.initializers import VarianceScaling
from tensorflow.keras.models import Model, Sequential, load_model
from tensorflow.keras.optimizers import SGD

import keras.backend as K
import keras
print('Keras v',keras.__version__,sep='')

Tensorflow v1.15.0-rc3
python v3.6.8
Keras v2.2.5


## 2. Setup Dataset

In [8]:
# load mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
img_size = x_train.shape[1:] # shape = [m, h, w]

#from sklearn.model_selection import StratifiedKFold
#folds = 7
#cv = StratifiedKFold(n_splits=folds, random_state=42, shuffle=True)
#x = np.concatenate((x, x_test))
#y = np.concatenate((y, y_test))

# normalize
x_train = x_train.astype('float32')/255
x_test = x_test.astype('float32')/255

# split train:validation as 90%:10%
#x_train, x_valid, y_train, y_valid = train_test_split(x, y, test_size=0.1, shuffle=False)

# convert y labels to one-hot vectors
num_classes = len(np.unique(y_train))
#y = to_categorical(y, num_classes) # do this inside cross-validation loop
y_train = to_categorical(y_train, num_classes)
y_test = to_categorical(y_test, num_classes)

# view sample shape
print('xtrain shape:',x_train.shape)
print('ytrain shape:',y_train.shape)
print('xtest shape:',x_test.shape)
print('ytest shape:',y_test.shape)

xtrain shape: (60000, 28, 28)
ytrain shape: (60000, 10)
xtest shape: (10000, 28, 28)
ytest shape: (10000, 10)


## 3. Create Model

### 3.1 Define Custom Layers 

In [0]:
# source: https://stackoverflow.com/questions/50290769/specify-connections-in-nn-in-keras

class PrunableDense(Dense):

    def __init__(self,units,mask,**kwargs):
          
        self.mask = mask         

        #initalize the original Dense with all the usual arguments   
        super(PrunableDense,self).__init__(units,**kwargs)  


    def call(self, inputs):
        output = K.dot(inputs, self.kernel * self.mask)
        if self.use_bias:
            output = K.bias_add(output, self.bias)
        if self.activation is not None:
            output = self.activation(output)
        return output

### 3.2 Construct Custom Model

In [0]:
def LeNet_300_100(input_shape, num_classes,mask=[1,1,1]):
    """ 
    LeNet 3-Layer FC implementation
    """
    
    vs = VarianceScaling()
    
    X_input = Input(shape = input_shape)
    X = Flatten()(X_input) # 28 * 28 = 784
    X = PrunableDense(300, 
                      mask = mask[0],
                      use_bias = False,
                      kernel_initializer = vs,
                      activation='relu')(X) 
    X = PrunableDense(100,
                      mask = mask[1],
                      use_bias = False,
                      kernel_initializer = vs,
                      activation='relu')(X)
    X = PrunableDense(num_classes, 
                      mask = mask[2],
                      use_bias = False,
                      kernel_initializer = vs,
                      activation='softmax')(X)
    
    return Model(inputs=X_input, outputs=X)


## 4. Train Model

In [10]:
# setup Hyperparams
SGDdizer = SGD(lr=0.1,
               momentum = 0.9,
               decay = 0.0005)
batch_size = 100
epochs = 10
    
# fit model using Strat K-fold cross-validation
scores = []
trial = 0
#for train_index, test_index in cv.split(x,y):
for _ in range(5):
    #x_train, y_train = x[train_index], y[train_index] 
    #x_test, y_test = x[test_index], y[test_index]
    #y_train = to_categorical(y_train, num_classes) 
    #y_test = to_categorical(y_test, num_classes) 
    LeNet = LeNet_300_100(img_size, num_classes)
    LeNet.compile(optimizer=SGDdizer,
              loss='categorical_crossentropy',
              metrics=['accuracy'])
#    start_time = time.time()
    LeNet.fit(x_train, y_train,
              batch_size = batch_size,
              epochs = epochs,
              verbose=0, 
              validation_split=0.1)
#    elapsed_time = time.time() - start_time
#    print("Trained split",split,end='')
#    time.strftime(" elapsed time = %H:%M:%S", time.gmtime(elapsed_time))
    preds = LeNet.evaluate(x_test, y_test, batch_size=batch_size)
    trial_error = 1-preds[1]
    print("Error of trial",trial,"=",trial_error)
    scores.append(trial_error)
    trial += 1
print("Average Error of Original Network =",np.mean(scores))

Error of split 0 = 0.01690000295639038
Error of split 1 = 0.01910001039505005
Error of split 2 = 0.021000027656555176
Error of split 3 = 0.021799981594085693
Error of split 4 = 0.024500012397766113
5-Fold Validation Error of Original Network = 0.020660006999969484


## 5. Results

In [0]:
#@title
def plot_trainings(trainings):
    plt.clf()
    acc = []
    val_acc = []
    loss = []
    val_loss = []

    # concatenate all data points for all trainings
    for train in trainings:
        for point in train.history['acc']:
            acc.append(point)
        for point in train.history['val_acc']:
            val_acc.append(point)
        for point in train.history['loss']:
            loss.append(point)
        for point in train.history['val_loss']:
            val_loss.append(point)
    #epochs = range(1, len(mae) + 1)

    # Plot training history for accuracy
    plt.plot(acc)
    plt.plot(val_acc)
    plt.title('Model Accuracy')
    plt.ylabel('accuracy')
    plt.xlabel('epoch')
    plt.legend(['train', 'test'], loc='upper left')
    plt.grid(True)
    plt.show()

In [0]:
#@title
preds = LeNet.evaluate(x_test, y_test, batch_size=batch_size)
plot_trainings(trainings)
print("\nTest Error: {:.2%}\n".format(1-preds[1])) # should get 1.65ish error

# II. Perform SNIP

In [11]:
# create initial mask matrix (all ones), using LeNet ^^
mask = []
for layer in LeNet.trainable_weights:
    #print(layer.shape)
    mask.append(tf.Variable(np.ones(layer.shape),
                            trainable = False,
                            dtype = 'float32'))
for i, _ in enumerate(mask):
    print('Mask for layer ', i+1,':',sep='')
    print(mask[i])

LeNet_to_Prune = LeNet_300_100(img_size, num_classes,mask=mask)

Mask for layer 1:
<tf.Variable 'Variable:0' shape=(784, 300) dtype=float32_ref>
Mask for layer 2:
<tf.Variable 'Variable_1:0' shape=(300, 100) dtype=float32_ref>
Mask for layer 3:
<tf.Variable 'Variable_2:0' shape=(100, 10) dtype=float32_ref>


In [12]:
# Compute Gradient wrt Mask
labels = y_train[:100] # mini-batch of 100 on MNIST from Experiment Setup Section
loss = K.mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels, logits=LeNet_to_Prune.outputs))
grads = K.gradients(loss, mask) # get gradient of loss wrt mask

trainingExample = x_train[:100]
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
evaluated_grads = sess.run(grads,feed_dict={LeNet_to_Prune.input:trainingExample})
sess.close()

print('Grads shape per layer')
for layer in evaluated_grads:
    print(layer.shape)

Grads shape per layer
(784, 300)
(300, 100)
(100, 10)


In [13]:
# convert grads to a dictionary for matrix pruning and reconstruction
# each weight should have an associated key
model_layers = []
for index, layer in enumerate(evaluated_grads):
    print('Shape of layer',index+1, '=', layer.shape)
    #print(type(layer))
    model_layers.append({key:abs(grad) for key, grad in np.ndenumerate(layer)})
                         # key of a weight is its position in tuple form
                         # note, absolute value is already taken here
    print('sample grad value:', model_layers[index][(99,9)]) # sanity check
    print('vectorized? ', np.array(list(model_layers[index].values())).shape)

Shape of layer 1 = (784, 300)
sample grad value: 6.9194857e-07
vectorized?  (235200,)
Shape of layer 2 = (300, 100)
sample grad value: 4.1439953e-06
vectorized?  (30000,)
Shape of layer 3 = (100, 10)
sample grad value: 5.0648437e-06
vectorized?  (1000,)


In [14]:
# I. Normalize and get sensitivity per weight

## 1. get sum of all grads
sum_of_grads = 0
for layer in model_layers: # layer = {(row, col) : grad_of_weight}
    sum_in_layer = sum(layer.values())
    #print(sum_in_layer)
    sum_of_grads += sum_in_layer
print('overall sum =',sum_of_grads)

## 2. normalize each grad
model_layers_normed = []
for index, layer in enumerate(model_layers):
    model_layers_normed.append({key:0 for key in layer.keys()})
    for key in model_layers_normed[index].keys():
        model_layers_normed[index][key] = layer[key]/sum_of_grads

overall sum = 0.7242147353031478


In [27]:
# II. Get position of Top-K weights with highest sensitivity scores
target_sparsity = 0.98 #(m-k)/m, where m is total params and k is non_zero w

## x. Count total parameter values
total_params = 0
for i, layer_weights in enumerate(LeNet_to_Prune.trainable_weights):
    total_params += int(np.prod(layer_weights.shape))
print("# of total params =",total_params)

kappa = int(round(total_params * (1. - target_sparsity)))
print("# of weights to keep =", kappa)

## x. SortDescending
### concatenate the per layer dicts into 1 dict
all_model_grads = {}
for index, layer in enumerate(model_layers_normed):
    all_model_grads.update({(index+1,) + key : grad for key, grad in layer.items()})
                           # new key value = (layer,row,col)
print('num of items =', len(all_model_grads))
print('sample key:', list(all_model_grads.keys())[33] )
print('sample value:', list(all_model_grads.values())[33])  
print('check sum of all normed values = ',sum(all_model_grads.values()))

### sort keys using values
keys = list(all_model_grads.keys())
values = list(all_model_grads.values())
sorted_keys = sorted(all_model_grads.keys(), 
                     key=lambda k: all_model_grads[k],
                     reverse = True)

# get 1st k keys
top_k_keys = sorted_keys[:kappa]

# create mask of zeros
pruning_mask = [] # list of np arrays
for layer in evaluated_grads:
    pruning_mask.append(np.zeros(layer.shape))
    
for i, _ in enumerate(pruning_mask):
    print('Mask for layer ', i+1,':',sep='')
    print(pruning_mask[i].shape)
    
# loop on the keys, and set mask value to 1
for layer,row,col in top_k_keys:
    pruning_mask[layer-1][row,col] = 1

zero_norm = 0 # for sanity check
for layer in pruning_mask:
    zero_norm += np.sum(layer)
print('zero norm =',zero_norm)
print('kappa =', kappa)


# of total params = 266200
# of weights to keep = 5324
num of items = 266200
sample key: (1, 0, 33)
sample value: 0.0
check sum of all normed values =  1.000000000000016
Mask for layer 1:
(784, 300)
Mask for layer 2:
(300, 100)
Mask for layer 3:
(100, 10)
zero norm = 5324.0
kappa = 5324


In [0]:
# more desired to just modify the mask of previous model's layers, 
#    instead of creating new model with new mask.
# if not possible to modify the masks, then create new model and just load
#    weights of the previous model. So the mask is the only difference 
#    (not the initialization etc)

#Pruned_LeNet = LeNet_300_100(img_size, num_classes,mask=pruning_mask)
#Pruned_LeNet.summary()

In [28]:
# setup Hyperparams
SGDdizer = SGD(lr=0.1,
               momentum = 0.9,
               decay = 0.0005)
batch_size = 100
epochs = 10
    
# fit model using Strat K-fold cross-validation
scores = []
trial = 0
#for train_index, test_index in cv.split(x,y):
for _ in range(5):
    #x_train, y_train = x[train_index], y[train_index] 
    #x_test, y_test = x[test_index], y[test_index]
    #y_train = to_categorical(y_train, num_classes) 
    #y_test = to_categorical(y_test, num_classes) 
    Pruned_LeNet = LeNet_300_100(img_size, num_classes, mask=pruning_mask)
    Pruned_LeNet.compile(optimizer=SGDdizer,
              loss='categorical_crossentropy',
              metrics=['accuracy'])
#    start_time = time.time()
    Pruned_LeNet.fit(x_train, y_train,
              batch_size = batch_size,
              epochs = epochs,
              verbose=0, 
              validation_split=0.1)
#    elapsed_time = time.time() - start_time
#    print("Trained split",split,end='')
#    time.strftime(" elapsed time = %H:%M:%S", time.gmtime(elapsed_time))
    preds = Pruned_LeNet.evaluate(x_test, y_test, batch_size=batch_size)
    trial_error = 1-preds[1]
    print("Error of trial",trial,"=",trial_error)
    scores.append(trial_error)
    trial += 1
print("Average Error of", target_sparsity,"Sparse Network =",np.mean(scores))

Error of trial 0 = 0.044499993324279785
Error of trial 1 = 0.06330001354217529
Error of trial 2 = 0.07239997386932373
Error of trial 3 = 0.07660001516342163
Error of trial 4 = 0.08689999580383301
Average Error of 0.98 Sparse Network = 0.06873999834060669


In [0]:
#@title
# setup Hyperparams
SGDdizer = SGD(lr=0.1,
               momentum = 0.9,
               decay = 0.0005)
Pruned_LeNet.compile(optimizer=SGDdizer,
              loss='categorical_crossentropy',
              metrics=['accuracy'])

batch_size = 100
epochs = 10
    
# fit model
trainings = []

start_time = time.time()
trainings.append(Pruned_LeNet.fit(
                  x_train, y_train,
                  batch_size = batch_size,
                  epochs = epochs,
                  validation_data = (x_valid, y_valid)))
elapsed_time = time.time() - start_time

time.strftime("Total training time  = %H:%M:%S", time.gmtime(elapsed_time))

In [0]:
#@title
preds = Pruned_LeNet.evaluate(x_test, y_test, batch_size=batch_size)
plot_trainings(trainings)
print("\nTest Error: {:.2%}\n".format(1-preds[1])) 

# III. Experiment Results (For Pruning LeNet300-100)
> From SNIP Paper
+ Figure 1 in paper shows **start of increase in error @ sparsity = 90.**
+ Table 1 shows **@ sparsity = 95% and 98%,** they got **1.6% and 2.4% errors** respectively. Their **reference error** (error from unpruned/original network) is **1.7%.**
+ Errors from reimplementation were taken from the average of 5 evaluations.
+ Note that from [Yann LeCun's implementation](http://yann.lecun.com/exdb/mnist/) himself, the error was 3.05%.


Sparsity (%) | SNIP Error (%) | Reimplementation Error (%)
--- | --- | ---
orig | 1.7 | 2.07
80 | ~1.73 | 2.52
90 | ~1.78 | 3.26
95 | 1.6 | 4.26
98 | 2.4 | 6.87