In [None]:
%load_ext autoreload
%autoreload 2
from Utils import *
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession
import numpy as np
import tensorflow.keras as K
from tensorflow.keras import losses, optimizers, metrics

# Configure the TF1-compat session: per_process_gpu_memory_fraction is set to
# 0.5 (half of the GPU memory for this process) and allow_growth is enabled.
config=ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction=0.5
config.gpu_options.allow_growth=True
# Create the session from the configuration above; `session` is kept as a
# notebook-level name.
session=InteractiveSession(config=config)
# NOTE(review): project-local module; it is not referenced by the cells shown
# in this notebook section -- confirm whether the import is still needed.
import ResNet
# NOTE(review): `plt` is not imported explicitly here; presumably it is exposed
# by `from Utils import *` -- verify.
plt.style.use('classic')

In [None]:
import random
from sklearn.model_selection import train_test_split

batch_size=128
num_classes=100 # For CIFAR 10 use 10
SEED=12345 # shared by the train/validation split and the validation subsample
SAMPLE_SIZE=2000 # number of validation examples kept in sample_ds


def make_dataset(x, y):
    """Return a tf.data pipeline over (x, y).

    The pipeline shuffles with a buffer as large as the number of examples
    in ``y`` and then groups the examples into batches of ``batch_size``
    (read from the notebook namespace at call time).
    """
    return tf.data.Dataset.from_tensor_slices((x,y)).shuffle(y.shape[0]).batch(batch_size)


x_train, y_train, x_test, y_test, input_shape=load_cifar100(num_classes) # For CIFAR 10 use load_cifar10(num_classes)
# Hold out 10% of the training data as a validation set (seeded split).
x_train,x_val,y_train,y_val=train_test_split(x_train,y_train,test_size=0.1,random_state=SEED)
# NOTE(review): presumably the downstream helpers treat the two-class case as a
# single output unit -- confirm against Utils.
if num_classes==2:
    num_classes=1

# To experiment with a 12k-example subset, slice x_train[:12000] / y_train[:12000]
# and pass the slices to make_dataset instead.
train_ds=make_dataset(x_train,y_train)
test_ds=make_dataset(x_test,y_test)
val_ds=make_dataset(x_val,y_val)

# Fixed-size random subset of the validation set. A dedicated seeded generator
# keeps the chosen indices identical across kernel restarts, matching the
# seeded split above; min() guards against a validation set smaller than
# SAMPLE_SIZE (random.sample raises ValueError in that case).
sample_index=random.Random(SEED).sample(range(x_val.shape[0]),min(SAMPLE_SIZE,x_val.shape[0]))
sample_ds=make_dataset(x_val[sample_index],y_val[sample_index])

# NOTE(review): `gc` and `tf` are not imported explicitly in this notebook;
# presumably they come from `from Utils import *` -- verify.
gc.collect()

In [None]:
from tensorflow.keras.layers import Dense, BatchNormalization, Activation , Flatten
# Height and width of the square, 3-channel input tensor declared by resnet().
input_size=32
def resnet_layer(inputs,
                 num_filters=16,
                 kernel_size=3,
                 strides=1,
                 activation='relu',
                 batch_normalization=True,
                 conv_first=True):
    """Apply one convolution / batch-norm / activation unit to ``inputs``.

    Args:
        inputs: Tensor to transform.
        num_filters: Number of filters of the 'same'-padded Conv2D.
        kernel_size: Kernel size of the convolution.
        strides: Stride of the convolution.
        activation: Activation name handed to ``K.layers.Activation``;
            ``None`` skips the activation.
        batch_normalization: Whether a BatchNormalization layer is inserted.
        conv_first: ``True`` gives conv -> BN -> activation, ``False`` gives
            BN -> activation -> conv.

    Returns:
        The transformed tensor.
    """

    def convolve(tensor):
        # The Conv2D layer is created at the moment it is applied, so layers
        # come into existence in the same order as the data flows.
        return K.layers.Conv2D(num_filters,
                               kernel_size=kernel_size,
                               strides=strides,
                               padding='same')(tensor)

    def normalize_and_activate(tensor):
        # Both steps are optional and always run in this order.
        if batch_normalization:
            tensor = K.layers.BatchNormalization()(tensor)
        if activation is not None:
            tensor = K.layers.Activation(activation)(tensor)
        return tensor

    if conv_first:
        return normalize_and_activate(convolve(inputs))
    return convolve(normalize_and_activate(inputs))

In [None]:
def resnet(k=16, depth=20, num_classes=None): #16 is the Standard ResNet20
    """Build a ResNet-v1 style classifier whose first stack uses ``k`` filters.

    The network is a stem ``resnet_layer`` followed by three stacks of
    residual blocks. The filter count doubles after every stack, and the first
    block of the second and third stack downsamples with stride 2 and uses a
    1x1 projection on the shortcut. The head is average pooling
    (``pool_size=3``), flatten and a softmax ``Dense`` layer.

    Args:
        k: Number of filters of the stem and of the first stack (16 gives the
            standard ResNet-20 widths). Converted with ``int()``, so numpy
            integer scalars are accepted.
        depth: Network depth; must have the form 6n+2 (the default 20 yields
            3 residual blocks per stack).
        num_classes: Number of units of the softmax output. When ``None``
            the notebook-level ``num_classes`` is used if it exists (so the
            head matches the loaded dataset), otherwise 10. The head is
            always a softmax; a single-unit head is not special-cased.

    Returns:
        An uncompiled ``K.models.Model`` taking inputs of shape
        ``(input_size, input_size, 3)``.

    Raises:
        ValueError: If ``depth`` is not of the form 6n+2 or ``k`` is < 1.
        TypeError: If ``k`` cannot be converted to a single integer (for
            example a multi-element array).
    """
    if (depth - 2) % 6 != 0:
        raise ValueError('depth should be 6n+2 (e.g. 20, 32, 44), got {}'.format(depth))
    if num_classes is None:
        # Follow the dataset configured in the notebook instead of a fixed 10.
        num_classes = globals().get('num_classes', 10)
    num_filters = int(k)
    if num_filters < 1:
        raise ValueError('k must be a positive integer, got {}'.format(k))

    num_res_blocks = (depth - 2) // 6
    inputs = K.layers.Input(shape=(input_size, input_size, 3))
    x = resnet_layer(inputs=inputs,
                     num_filters=num_filters)
    for stack in range(3):
        for res_block in range(num_res_blocks):
            # Only the first block of the second and third stack downsamples.
            downsample = stack > 0 and res_block == 0
            strides = 2 if downsample else 1
            y = resnet_layer(inputs=x,
                             num_filters=num_filters,
                             strides=strides)
            y = resnet_layer(inputs=y,
                             num_filters=num_filters,
                             activation=None)
            if downsample:
                # Linear projection of the shortcut so that it matches the
                # changed spatial size and channel count. It is created after
                # the two residual layers, which fixes the layer (and weight)
                # order of the model -- keep it here.
                x = resnet_layer(inputs=x,
                                 num_filters=num_filters,
                                 kernel_size=1,
                                 strides=strides,
                                 activation=None,
                                 batch_normalization=False)
            x = K.layers.Add()([x, y])
            x = K.layers.Activation('relu')(x)
        num_filters *= 2
    # Add classifier on top.
    # v1 does not use BN after last shortcut connection-ReLU
    x = K.layers.AveragePooling2D(pool_size=3)(x)
    y = K.layers.Flatten()(x)
    outputs = K.layers.Dense(num_classes,
                             activation='softmax')(y)

    # Instantiate model.
    model = K.models.Model(inputs=inputs, outputs=outputs)
    return model

In [None]:
# Build the default-width network and print its layer/parameter table as a
# quick sanity check of the architecture.
net = resnet(k=16)
net.summary()

# k=16 corresponds to the standard ResNet-20 widths; expect roughly 270k
# parameters with a 10-class head.

In [None]:
# from keras.utils.layer_utils import count_params
# Num_count=30
# trainable_cnt=np.zeros(Num_count)
# non_train_cnt=np.zeros(Num_count)
# total_cnt=np.zeros(Num_count)
# for k in range(Num_count):
#     model=resnet(k=k+1)
#     trainable_cnt[k]=count_params(model.trainable_weights)
#     non_train_cnt[k]=count_params(model.non_trainable_weights)
#     total_cnt[k]=trainable_cnt[k]+non_train_cnt[k]



In [None]:

# plt.plot(total_cnt,label='total params',linewidth=1,color='black')
# plt.plot(trainable_cnt,label='trainable',linewidth=1,color='pink')
# plt.plot(non_train_cnt,label='non-trainable',linewidth=1)
# plt.legend()

In [None]:
import os

# Base widths (filters in the first ResNet stack) to sweep. This must be a
# flat 1-D sequence: iterating a nested [[...]] array yields the whole row as
# one "width" instead of one integer per iteration.
k=np.array([1,5,8,10,12,14,16,18,20,25,30,35,40,45,50])

# np.save does not create missing directories, so make sure the output
# folder exists before hours of training are spent.
os.makedirs('result',exist_ok=True)

train_acc=[]
test_acc=[]
for width in k:
    width=int(width) # plain Python int for the layer sizes and the file name
    print('-------------k=',width,'---------------')
    pretrain_model=resnet(k=width)
    pretrain_model.build(input_shape)
    history=regular_training(pretrain_model,train_ds,test_ds,lr=0.001,
                             epoch=30,num_classes=num_classes,
                             opt='ADAM',reduceLROnPlateau=True,
                             augmentation=True)
    # Keep the last recorded accuracies of this width for the summary file.
    train_acc.append(history.train_acc[-1])
    test_acc.append(history.test_acc[-1])
    # Clear the dataset attributes before saving -- presumably so the datasets
    # are not stored together with the history; confirm against
    # TrainHistory.save.
    history.train_ds=[]
    history.test_ds=[]
    history.save('result/pretrain_{}.history'.format(width))
np.save('result/pretrain_acc',[train_acc,test_acc])

In [None]:
# Base widths to sweep. Kept as a flat 1-D array so that `for width in k`
# yields one integer width per iteration (a nested [[...]] array would yield
# the whole row at once).
k=np.array([1,5,8,10,12,14,16,18,20,25,30,35,40,45,50])

In [None]:
sparse_train_acc=[]
sparse_test_acc=[]
s=10
# np.ravel flattens k, so the loop sees one integer width per iteration even
# when k was built as a nested (1, N) array; iterating such an array directly
# would hand the whole row to float() and to the `<` comparison below.
for width in np.ravel(k):
    width=int(width) # plain Python int for the layer sizes and the file names
    print('-------------k=',width,'---------------,s/k=',float(s)/float(width))
    # Widths below s have no saved sparse counterpart in this sweep.
    if width<s:
        continue
    history_training_full=TrainHistory.load('result/pretrain_{}.history'.format(width))
    # Rebuild the architecture and load the fully trained weights, so that the
    # mask is computed from the trained network.
    pruned_model=resnet(k=width)
    pruned_model.build(input_shape)
    pruned_model.set_weights(history_training_full.final_weights)
    # NOTE(review): s/k is passed as the pruning ratio -- presumably the
    # fraction of weights that is kept; confirm against Utils.compute_mask.
    masks=compute_mask(pruned_model,float(s)/float(width),global_pruning=True,exclude=[])
    # Rewind to the stored initial weights before retraining under the mask.
    pruned_model.set_weights(history_training_full.init_weights)
    sparse_history=mask_training(pruned_model,masks,train_ds,test_ds,lr=0.001,
                         epoch=100,num_classes=num_classes,opt='ADAM',reduceLROnPlateau=True)
    sparse_train_acc.append(sparse_history.train_acc[-1])
    sparse_test_acc.append(sparse_history.test_acc[-1])
    # Clear the dataset attributes before saving -- presumably so the datasets
    # are not stored together with the history; confirm against
    # TrainHistory.save.
    sparse_history.train_ds=[]
    sparse_history.test_ds=[]
    sparse_history.save('result/sparse_{}_s={}.history'.format(width,s))
# NOTE(review): this file name does not contain s, so runs with different s
# overwrite each other.
np.save('result/sparse_acc',[sparse_train_acc,sparse_test_acc])
  