In [1]:
import tensorflow as tf
import numpy as np

In [2]:
# Floating-point type shared by the data arrays and every model parameter below.
dtype = 'float32'
# Make Keras use the same type as its default float type.
tf.keras.backend.set_floatx(dtype)

In [3]:
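# Alternative dataset (disabled): Fashion-MNIST, flattened to 784 features per image.
# To use it instead of CIFAR-10, uncomment this cell, skip the CIFAR-10 cell, and
# change the first entry of the model's layer_sizes from 3072 to 784.
# fashion_mnist = tf.keras.datasets.fashion_mnist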
# (X_train, y_train), (X_test, y_test) = fashion_mnist.load_data()

# X_train = X_train.astype(dtype) / 255.0
# y_train = y_train.astype(dtype)
# X_test = X_test.astype(dtype)  / 255.0
# y_test = y_test.astype(dtype)

# X_train = np.reshape(X_train, (-1, 784))
# X_test = np.reshape(X_test, (-1, 784))

In [4]:
# Load CIFAR-10 as (train, test) pairs of image and label arrays.
cifar10 = tf.keras.datasets.cifar10
(X_train, y_train), (X_test, y_test) = cifar10.load_data()

# Cast the images to the working float type, scale them by 1/255 and flatten
# each image into a single 3072-value row.
X_train = np.reshape(X_train.astype(dtype) / 255.0, (-1, 3072))
X_test = np.reshape(X_test.astype(dtype) / 255.0, (-1, 3072))

# Labels are cast to the same float type as the images.
y_train = y_train.astype(dtype)
y_test = y_test.astype(dtype)

In [98]:
class SSRegularizer(tf.keras.regularizers.Regularizer):
    """L1 penalty whose coefficient increases with the unit index.

    Along the last axis of the regularized tensor, the unit at position ``i``
    is penalized with a coefficient of ``i`` steps of ``l1`` (0, l1, 2*l1, ...),
    so the first unit is unpenalized and later units cost progressively more.
    """

    def __init__(self, l1):
        self.l1 = l1

    def __call__(self, x):
        """Return the sum of ``|x|`` weighted by the per-unit coefficients."""
        n_units = x.shape[-1]
        # Cumulative sum of a constant-l1 vector yields l1, 2*l1, ...;
        # subtracting l1 shifts that ramp so it starts at zero.
        constant_l1 = tf.constant(self.l1, shape=(n_units,), dtype=dtype)
        coefficients = tf.cumsum(constant_l1, axis=0) - self.l1
        weighted_magnitudes = coefficients * tf.abs(x)
        return tf.reduce_sum(weighted_magnitudes)

    def get_config(self):
        """Return the constructor arguments as a plain dict."""
        return {'l1': float(self.l1)}


class SSLayer(tf.keras.Model):
    """Dense layer whose kernel and bias carry an ``SSRegularizer`` penalty.

    ``W`` has shape ``(input_units, units)`` and ``b`` has shape ``(units,)``.
    Both are plain ``tf.Variable`` attributes. ``SSModel.prune`` and
    ``SSModel.grow`` replace them with differently shaped variables, so the
    registered penalties look the attributes up each time they are evaluated.
    """

    def __init__(self, input_units, units, activation, l1, kernel_initializer, bias_initializer):
        super().__init__()

        self.l1 = l1
        self.regularizer = SSRegularizer(self.l1)
        self.activation = tf.keras.activations.get(activation)
        self.kernel_initializer = tf.keras.initializers.get(kernel_initializer)
        self.bias_initializer = tf.keras.initializers.get(bias_initializer)

        # Draw the initial values first, then wrap them in trainable variables.
        initial_kernel = self.kernel_initializer(shape=(input_units, units), dtype=dtype)
        initial_bias = self.bias_initializer(shape=(units,), dtype=dtype)
        self.W = tf.Variable(initial_value=initial_kernel, trainable=True, name='W')
        self.b = tf.Variable(initial_value=initial_bias, trainable=True, name='b')

        # Zero-argument callables: evaluated whenever the model's losses are read.
        self.add_loss(self._kernel_penalty)
        self.add_loss(self._bias_penalty)

    def _kernel_penalty(self):
        """Regularization term for whatever variable ``self.W`` currently is."""
        return self.regularizer(self.W)

    def _bias_penalty(self):
        """Regularization term for whatever variable ``self.b`` currently is."""
        return self.regularizer(self.b)

    def call(self, inputs):
        """Apply the affine map ``inputs @ W + b`` followed by the activation."""
        pre_activation = tf.matmul(inputs, self.W) + self.b
        return self.activation(pre_activation)


class SSModel(tf.keras.Model):
    """Feed-forward classifier whose hidden layers can be grown and pruned.

    The network is a stack of ``SSLayer`` objects built from ``layer_sizes``
    (input size, hidden sizes..., output size). Hidden layers use the given
    ``activation`` and ``l1`` strength; the last layer always uses softmax and
    an ``l1`` of 0.
    """

    def __init__(self, layer_sizes, activation=None, l1=0.01, kernel_initializer='glorot_uniform', bias_initializer='zeros'):
        """Build one ``SSLayer`` per consecutive pair of entries in ``layer_sizes``."""
        super().__init__()

        self.sslayers = list()
        last_index = len(layer_sizes) - 2
        for index in range(len(layer_sizes) - 1):
            input_units = layer_sizes[index]
            units = layer_sizes[index + 1]
            if index < last_index:
                layer = SSLayer(input_units, units, activation, l1, kernel_initializer, bias_initializer)
            else:
                # Output layer: softmax activation and no regularization strength.
                layer = SSLayer(input_units, units, 'softmax', 0., kernel_initializer, bias_initializer)
            self.sslayers.append(layer)

    def call(self, inputs):
        """Run ``inputs`` through every layer in order and return the result."""
        x = inputs
        for layer in self.sslayers:
            x = layer(x)
        return x

    def _adjacent_layer_pairs(self):
        """Return ``(layer, next_layer)`` tuples for every pair of neighbouring layers."""
        layers = list(self.sslayers)
        return list(zip(layers[:-1], layers[1:]))

    def get_layer_sizes(self):
        """Return the current sizes as ``[input, hidden..., output]``.

        Sizes are read from the kernel shapes, so they reflect any pruning or
        growing that has happened. Returns an empty list if there are no layers.
        """
        layers = list(self.sslayers)
        if not layers:
            return []
        layer_sizes = [layer.W.shape[0] for layer in layers]
        layer_sizes.append(layers[-1].W.shape[1])
        return layer_sizes

    def print_neurons(self):
        """Print one magnitude string per layer (see ``get_param_string``)."""
        for layer in self.sslayers:
            print(get_param_string(layer.W, layer.b))

    def prune(self, threshold=0.001):
        """Remove neurons whose incoming weights and bias are all below ``threshold``.

        For every layer except the last, a neuron is kept only if the largest
        absolute value among its kernel column and its bias is at least
        ``threshold``. The matching rows of the next layer's kernel are removed
        as well. The affected variables are replaced by new ``tf.Variable``
        objects, so any optimizer state tied to the old variables is not
        carried over.
        """
        for layer1, layer2 in self._adjacent_layer_pairs():
            # Read the variables inside the loop: layer2.W from the previous
            # iteration is this iteration's (already pruned) layer1.W.
            W1 = layer1.W.value()
            b1 = layer1.b.value()
            W2 = layer2.W.value()

            weights_with_biases = tf.concat([W1, tf.reshape(b1, (1, -1))], axis=0)
            neurons_are_active = tf.math.reduce_max(tf.abs(weights_with_biases), axis=0) >= threshold
            active_neurons_indices = tf.reshape(tf.where(neurons_are_active), (-1,))

            new_W1 = tf.gather(W1, active_neurons_indices, axis=1)
            new_b1 = tf.gather(b1, active_neurons_indices, axis=0)
            new_W2 = tf.gather(W2, active_neurons_indices, axis=0)

            layer1.W = tf.Variable(name='W', initial_value=new_W1, trainable=True)
            layer1.b = tf.Variable(name='b', initial_value=new_b1, trainable=True)
            layer2.W = tf.Variable(name='W', initial_value=new_W2, trainable=True)

    def grow(self, min_new_neurons=5, scaling_factor=0.001, growth_ratio=0.2):
        """Append freshly initialized neurons to every layer except the last.

        Each layer receives ``max(min_new_neurons, int(width * growth_ratio))``
        new neurons, where ``width`` is its current number of units. Layers for
        which that number is not positive are left untouched.

        Args:
            min_new_neurons: Lower bound on the number of neurons added per layer.
            scaling_factor: Multiplier applied to the new neurons' incoming
                weights. The outgoing weights in the next layer are not scaled.
            growth_ratio: Fraction of the current width to add.
        """
        for layer1, layer2 in self._adjacent_layer_pairs():
            W1 = layer1.W.value()
            b1 = layer1.b.value()
            W2 = layer2.W.value()

            old_columns = W1.shape[1]
            old_rows = W2.shape[0]
            n_new_neurons = max(min_new_neurons, int(old_columns * growth_ratio))
            if n_new_neurons <= 0:
                # Nothing to add. This also avoids a "-0:" style slice, which
                # would select the whole matrix instead of an empty one.
                continue

            # The initializers are called with the full grown shape and only
            # the trailing slice is kept, so they see the layer's new size.
            # Explicit start offsets are used rather than negative indices.
            W1_growth = layer1.kernel_initializer(
                shape=(W1.shape[0], old_columns + n_new_neurons), dtype=dtype)[:, old_columns:] * scaling_factor
            b1_growth = layer1.bias_initializer(shape=(n_new_neurons,), dtype=dtype)
            W2_growth = layer2.kernel_initializer(
                shape=(old_rows + n_new_neurons, W2.shape[1]), dtype=dtype)[old_rows:, :]

            new_W1 = tf.concat([W1, W1_growth], axis=1)
            new_b1 = tf.concat([b1, b1_growth], axis=0)
            new_W2 = tf.concat([W2, W2_growth], axis=0)

            # Same variable names as SSLayer.__init__ and prune().
            layer1.W = tf.Variable(name='W', initial_value=new_W1, trainable=True)
            layer1.b = tf.Variable(name='b', initial_value=new_b1, trainable=True)
            layer2.W = tf.Variable(name='W', initial_value=new_W2, trainable=True)

In [99]:
def get_param_string(weights, bias):
    """Summarize each neuron's largest parameter magnitude as one digit.

    For every column of ``weights`` (one column per neuron) the largest
    absolute value among that column and the matching ``bias`` entry is found.
    The neuron's character is ``-floor(log10(max))`` limited to the range 0-9:
    ``0`` for a maximum of at least 1, ``1`` for [0.1, 1), ``2`` for
    [0.01, 0.1), and so on. A maximum that is exactly zero, or smaller than
    1e-9, is shown as ``9`` so every neuron occupies exactly one character.

    Args:
        weights: 2-D kernel of shape (inputs, neurons); a tensor/variable
            exposing ``.numpy()`` or anything NumPy can convert.
        bias: 1-D bias of shape (neurons,), same accepted types.

    Returns:
        A string with one digit per neuron.
    """
    def _as_array(values):
        # Tensors and variables expose .numpy(); everything else goes to NumPy as-is.
        return values.numpy() if hasattr(values, "numpy") else np.asarray(values)

    bias_row = np.reshape(_as_array(bias), (1, -1))
    weights_with_bias = np.concatenate([_as_array(weights), bias_row], axis=0)
    max_parameters = np.max(np.abs(weights_with_bias), axis=0)

    # log10(0) is -inf; silence the warning and let the clip below handle it.
    with np.errstate(divide="ignore"):
        magnitudes = np.floor(np.log10(max_parameters))

    digits = np.clip(-magnitudes, 0, 9)
    return "".join(str(int(digit)) for digit in digits)


def print_epoch_statistics(model):
    """Print loss and accuracy on the train and test sets plus layer diagnostics.

    Reads the notebook-level ``X_train``/``y_train`` and ``X_test``/``y_test``
    arrays, evaluates ``model`` on each in a single forward pass, and then
    prints the model's layer sizes and per-neuron magnitude strings.
    """
    def evaluate(features, labels):
        # Mean sparse categorical cross-entropy and accuracy over one split.
        predictions = model(features)
        mean_loss = tf.reduce_mean(
            tf.keras.losses.sparse_categorical_crossentropy(labels, predictions))
        mean_accuracy = tf.reduce_mean(
            tf.keras.metrics.sparse_categorical_accuracy(labels, predictions))
        return mean_loss, mean_accuracy

    loss, accuracy = evaluate(X_train, y_train)
    val_loss, val_accuracy = evaluate(X_test, y_test)

    print(f"loss: {loss} - accuracy: {accuracy} - val_loss: {val_loss} - val_accuracy: {val_accuracy}")
    print(f"layer sizes: {model.get_layer_sizes()}")
    model.print_neurons()
    

def train_model(model, optimizer, epochs, batch_size, train_dataset,
                min_new_neurons=10, scaling_factor=0.001, prune_threshold=0.001):
    """Train ``model`` with a grow -> fit -> prune cycle every epoch.

    Each epoch grows the model, runs one pass of gradient updates over
    ``train_dataset``, and then prunes the model. Statistics are printed
    before and after both the growing and the pruning step.

    Args:
        model: Model exposing ``grow``, ``prune``, ``losses`` and
            ``trainable_variables``.
        optimizer: Optimizer exposing ``apply_gradients``.
        epochs: Number of grow/fit/prune cycles to run.
        batch_size: Not used by this function; ``train_dataset`` is iterated
            as-is and yields ``(x_batch, y_batch)`` pairs.
        train_dataset: Iterable of ``(x_batch, y_batch)`` pairs.
        min_new_neurons: Forwarded to ``model.grow``.
        scaling_factor: Forwarded to ``model.grow``.
        prune_threshold: Forwarded to ``model.prune`` as ``threshold``.
    """
    def report(title):
        # One labelled statistics block.
        print(title)
        print_epoch_statistics(model)

    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}/{epochs}")

        report("Before growing:")
        model.grow(min_new_neurons=min_new_neurons, scaling_factor=scaling_factor)
        report("After growing:")

        for x_batch, y_batch in train_dataset:
            with tf.GradientTape() as tape:
                y_pred = model(x_batch, training=True)
                loss_value = tf.reduce_mean(tf.keras.losses.sparse_categorical_crossentropy(y_batch, y_pred))
                # Add the regularization terms registered by the layers.
                loss_value += sum(model.losses)

            # Read the variable list once per step so gradients and variables
            # are paired from the same snapshot.
            variables = model.trainable_variables
            grads = tape.gradient(loss_value, variables)
            optimizer.apply_gradients(zip(grads, variables))

        report("Before pruning:")
        model.prune(threshold=prune_threshold)
        report("After pruning:")

In [101]:
epochs = 20
batch_size = 32

# Input pipeline: slice the arrays into examples, shuffle through a
# 1024-element buffer, then group into batches.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((X_train, y_train))
    .shuffle(buffer_size=1024)
    .batch(batch_size)
)

# 3072 inputs, two hidden layers starting at 100 units each, 10 output classes.
model = SSModel(
    layer_sizes=[3072, 100, 100, 10],
    activation='relu',
    l1=0.0001,
    kernel_initializer='he_normal',
)
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

train_model(model, optimizer, epochs, batch_size, train_dataset)

Epoch 1/20
Before growing:
loss: 2.5018150806427 - accuracy: 0.10633999854326248 - val_loss: 2.501587390899658 - val_accuracy: 0.10719999670982361
layer sizes: [3072, 100, 100, 10]
2222222222222222222222222222222222222222222222222222222222222222222222222222222222222222222222222222
1111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111
1111111111
After growing:
loss: 2.501856565475464 - accuracy: 0.10639999806880951 - val_loss: 2.501626968383789 - val_accuracy: 0.10750000178813934
layer sizes: [3072, 120, 120, 10]
222222222222222222222222222222222222222222222222222222222222222222222222222222222222222222222222222255555555555555555555
111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111144444444444444444444
1111111111
Before pruning:
loss: 1.9681110382080078 - accuracy: 0.25519999861717224 - val_loss: 1.9679940938949585 - val_accuracy: 0.2513999938964844
layer sizes: [3072, 120, 120, 10]
233134331

loss: 1.8649706840515137 - accuracy: 0.3234800100326538 - val_loss: 1.861268401145935 - val_accuracy: 0.32120001316070557
layer sizes: [3072, 31, 116, 10]
2331113333333331333335555555555
11111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111114444444444444444444
1111111111
Before pruning:
loss: 1.8825409412384033 - accuracy: 0.30928000807762146 - val_loss: 1.8808683156967163 - val_accuracy: 0.3037000000476837
layer sizes: [3072, 31, 116, 10]
2331113433434413343343334333333
11111111111112111211121333111212112133132331312333332332334434331334333334344434333333342432333344323433333433444333
1111111111
After pruning:
loss: 1.8827027082443237 - accuracy: 0.3091199994087219 - val_loss: 1.881025791168213 - val_accuracy: 0.30410000681877136
layer sizes: [3072, 24, 98, 10]
233111333313333333333333
11111111111112111211121333111212112133132331312333332332333331333333333333333323233333233333333333
1111111111
Epoch 8/20
Before growing:
loss: 1.88270270824432

loss: 1.8060723543167114 - accuracy: 0.3462800085544586 - val_loss: 1.8064124584197998 - val_accuracy: 0.33959999680519104
layer sizes: [3072, 29, 96, 10]
20113133333333333333333333333
011111111111221223111323332122132113333333333233333333333333333323333233333333333333333333233333
1111111111
Epoch 14/20
Before growing:
loss: 1.8060723543167114 - accuracy: 0.3462800085544586 - val_loss: 1.8064124584197998 - val_accuracy: 0.33959999680519104
layer sizes: [3072, 29, 96, 10]
20113133333333333333333333333
011111111111221223111323332122132113333333333233333333333333333323333233333333333333333333233333
1111111111
After growing:
loss: 1.806084394454956 - accuracy: 0.34615999460220337 - val_loss: 1.8064199686050415 - val_accuracy: 0.33959999680519104
layer sizes: [3072, 39, 115, 10]
201131333333333333333333333335555555555
0111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111114444444444444444444
1111111111
Before pruning:
loss: 1.7975218296051025 - accur

Before pruning:
loss: 1.8113493919372559 - accuracy: 0.34536001086235046 - val_loss: 1.8152164220809937 - val_accuracy: 0.3402000069618225
layer sizes: [3072, 36, 109, 10]
200113333333433334333333443343443343
0101111011113124231233341133133331333333333333333334333333333333433333333333333333333443333343433333334334334
0111111111
After pruning:
loss: 1.8110742568969727 - accuracy: 0.34544000029563904 - val_loss: 1.8149316310882568 - val_accuracy: 0.3400999903678894
layer sizes: [3072, 28, 98, 10]
2001133333333333333333333333
01011110111131223123331133133331333333333333333333333333333333333333333333333333333333333333333333
0111111111
