In [2]:
import tensorflow as tf
import numpy as np

In [3]:
dtype = 'float32'
tf.keras.backend.set_floatx(dtype)

In [27]:
fashion_mnist = tf.keras.datasets.fashion_mnist
(X_train, y_train), (X_test, y_test) = fashion_mnist.load_data()

X_train = X_train.astype(dtype) / 255.0
y_train = y_train.astype(dtype)
X_test = X_test.astype(dtype)  / 255.0
y_test = y_test.astype(dtype)

X_train = np.reshape(X_train, (-1, 784))
X_test = np.reshape(X_test, (-1, 784))

In [28]:
class SSRegularizer(tf.keras.regularizers.Regularizer):
    def __init__(self, l1):
        self.l1 = l1

    def __call__(self, x):
        scaling_vector = tf.cumsum(tf.constant(self.l1, shape=(x.shape[-1],), dtype=dtype), axis=0) - self.l1
        return tf.reduce_sum(scaling_vector * tf.abs(x))

    def get_config(self):
        return {'l1': float(self.l1)}


class SSModel(tf.keras.Model):
    def __init__(self, units, activation=None, l1=0.01, kernel_initializer='glorot_uniform', bias_initializer='zeros'):
        super().__init__()
        self.activation1 = tf.keras.activations.get(activation)
        self.activation2 = tf.keras.activations.get('softmax')
        self.l1 = l1
        self.kernel_initializer = tf.keras.initializers.get(kernel_initializer)
        self.bias_initializer = tf.keras.initializers.get(bias_initializer)
        self.regularizer = SSRegularizer(self.l1)
        
        self.W1 = tf.Variable(
            name='W1',
            initial_value=self.kernel_initializer(shape=(784, units), dtype=dtype),
            trainable=True)
        
        self.b1 = tf.Variable(
            name='b1',
            initial_value=self.bias_initializer(shape=(units,), dtype=dtype),
            trainable=True)
        
        self.W2 = tf.Variable(
            name='W2',
            initial_value=self.kernel_initializer(shape=(units, 10), dtype=dtype),
            trainable=True)
        
        self.b2 = tf.Variable(
            name='b2',
            initial_value=self.bias_initializer(shape=(10,), dtype=dtype),
            trainable=True)
        
        self.add_loss(lambda: self.regularizer(self.W1))
        self.add_loss(lambda: self.regularizer(self.b1))

    def call(self, inputs):
        S1 = tf.matmul(inputs, self.W1)
        A1 = self.activation1(S1 + self.b1)
        A2 = self.activation2(tf.matmul(A1, self.W2) + self.b2)

        return A2
    
    def prune(self, threshold=0.001):
        W1 = self.W1.value()
        b1 = self.b1.value()
        W2 = self.W2.value()
        
        weights_with_biases = tf.concat([W1, tf.reshape(b1, (1, -1))], axis=0)
        neurons_are_active = tf.math.reduce_max(weights_with_biases, axis=0) >= threshold
        active_neurons_indices = tf.reshape(tf.where(neurons_are_active), (-1,))
        
        new_W1 = tf.gather(W1, active_neurons_indices, axis=1)
        new_b1 = tf.gather(b1, active_neurons_indices, axis=0)
        new_W2 = tf.gather(W2, active_neurons_indices, axis=0)
        
        self.W1 = tf.Variable(name='W1', initial_value=new_W1, trainable=True)
        self.b1 = tf.Variable(name='b1', initial_value=new_b1, trainable=True)
        self.W2 = tf.Variable(name='W2', initial_value=new_W2, trainable=True)
    
    def grow(self, min_new_neurons=5, scaling_factor=0.001):      
        W1 = self.W1.value()
        b1 = self.b1.value()
        W2 = self.W2.value()
        
        n_new_neurons = max(min_new_neurons, int(W1.shape[1] * 0.2))
        
        W1_growth = self.kernel_initializer(shape=(W1.shape[0], W1.shape[1] + n_new_neurons), dtype=dtype)[:, -n_new_neurons:] * scaling_factor
        b1_growth = self.kernel_initializer(shape=(n_new_neurons,), dtype=dtype)
        W2_growth = self.kernel_initializer(shape=(W2.shape[0] + n_new_neurons, W2.shape[1]), dtype=dtype)[-n_new_neurons:, :]
        
        new_W1 = tf.concat([W1, W1_growth], axis=1)
        new_b1 = tf.concat([b1, b1_growth], axis=0)
        new_W2 = tf.concat([W2, W2_growth], axis=0)
        
        self.W1 = tf.Variable(name='W1', initial_value=new_W1, trainable=True)
        self.b1 = tf.Variable(name='b1', initial_value=new_b1, trainable=True)
        self.W2 = tf.Variable(name='W2', initial_value=new_W2, trainable=True)

In [53]:
def get_param_string(weights):
    param_string = ""
    max_parameters = tf.math.reduce_max(weights, axis=0).numpy()
    magnitudes = np.floor(np.log10(max_parameters))
    for m in magnitudes:
        if m > 0:
            m = 0
        param_string += str(int(-m))
    return param_string


def print_epoch_statistics(model):
    y_pred = model(X_train)
    loss = tf.reduce_mean(tf.keras.losses.sparse_categorical_crossentropy(y_train, y_pred))
    accuracy = tf.reduce_mean(tf.keras.metrics.sparse_categorical_accuracy(y_train, y_pred))
    
    y_pred_val = model(X_test)
    val_loss = tf.reduce_mean(tf.keras.losses.sparse_categorical_crossentropy(y_test, y_pred_val))
    val_accuracy = tf.reduce_mean(tf.keras.metrics.sparse_categorical_accuracy(y_test, y_pred_val))
    print(f"loss: {loss} - accuracy: {accuracy} - val_loss: {val_loss} - val_accuracy: {val_accuracy}")
    print(f"units: {model.W1.shape[1]} - {get_param_string(model.W1)}")
    

def train_model(model, optimizer, epochs, batch_size, train_dataset):
    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}/{epochs}")
        
        print("Before growing:")
        print_epoch_statistics(model)
        model.grow(min_new_neurons=5, scaling_factor=0.001)
        print("After growing:")
        print_epoch_statistics(model)

        for step, (x_batch, y_batch) in enumerate(train_dataset):
            with tf.GradientTape() as tape:
                y_pred = model(x_batch, training=True)
                loss_value = tf.reduce_mean(tf.keras.losses.sparse_categorical_crossentropy(y_batch, y_pred))
                loss_value += sum(model.losses)

            grads = tape.gradient(loss_value, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))
        
        print("Before pruning:")
        print_epoch_statistics(model)
        model.prune(threshold=0.001)
        print("After pruning:")
        print_epoch_statistics(model)

In [57]:
epochs = 20
batch_size = 32

train_dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)

model = SSModel(units=200, activation='relu', l1=0.000001, kernel_initializer='he_normal')
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

train_model(model, optimizer, epochs, batch_size, train_dataset)

Epoch 1/20
Before growing:
loss: 2.3557889461517334 - accuracy: 0.16066665947437286 - val_loss: 2.3568131923675537 - val_accuracy: 0.16110000014305115
units: 200 - 11111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111
After growing:
loss: 2.366809368133545 - accuracy: 0.1540333330631256 - val_loss: 2.3678812980651855 - val_accuracy: 0.15639999508857727
units: 240 - 111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111114444444444444444444444444444444444444444
Before pruning:
loss: 0.46726271510124207 - accuracy: 0.8273666501045227 - val_loss: 0.5071299076080322 - val_accuracy: 0.8126999735832214
units: 240 - 11111111111111111111121111111121111111111111111111111131113112141311112141411111311113111111

loss: 0.29783010482788086 - accuracy: 0.8907833099365234 - val_loss: 0.3593212366104126 - val_accuracy: 0.8690999746322632
units: 142 - 1011111111111111111111111111111111212121112111122232133231333333233333333333322333333333222333333233333223323332333323333333333333333333333333
Epoch 7/20
Before growing:
loss: 0.29783010482788086 - accuracy: 0.8907833099365234 - val_loss: 0.3593212366104126 - val_accuracy: 0.8690999746322632
units: 142 - 1011111111111111111111111111111111212121112111122232133231333333233333333333322333333333222333333233333223323332333323333333333333333333333333
After growing:
loss: 0.29767563939094543 - accuracy: 0.8903833627700806 - val_loss: 0.3596833348274231 - val_accuracy: 0.8672000169754028
units: 170 - 10111111111111111111111111111111112121211121111222321332313333332333333333333223333333332223333332333332233233323333233333333333333333333333334444444444444444444444444444
Before pruning:
loss: 0.2979072332382202 - accuracy: 0.8901833295822144 - val_loss: 0.3673929

Before pruning:
loss: 0.27381640672683716 - accuracy: 0.8989666700363159 - val_loss: 0.3640250563621521 - val_accuracy: 0.8736000061035156
units: 80 - 00000111101101111111112111411111111111324423342422443244233224242442323334232423
After pruning:
loss: 0.27381640672683716 - accuracy: 0.8989666700363159 - val_loss: 0.3640250563621521 - val_accuracy: 0.8736000061035156
units: 67 - 0000011110110111111111211111111111111322332223223322224423233323223
Epoch 15/20
Before growing:
loss: 0.27381640672683716 - accuracy: 0.8989666700363159 - val_loss: 0.3640250563621521 - val_accuracy: 0.8736000061035156
units: 67 - 0000011110110111111111211111111111111322332223223322224423233323223
After growing:
loss: 0.27821439504623413 - accuracy: 0.8970666527748108 - val_loss: 0.36789897084236145 - val_accuracy: 0.870199978351593
units: 80 - 00000111101101111111112111111111111113223322232233222244232333232234444444444444
Before pruning:
loss: 0.2658196985721588 - accuracy: 0.9023000001907349 - val_loss: 0.35

In [50]:
model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(50, activation='relu', kernel_initializer='he_normal'),
    tf.keras.layers.Dense(10, activation='softmax', kernel_initializer='he_normal')
])

In [51]:
model.compile(loss='sparse_categorical_crossentropy', optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), metrics=['accuracy'])

In [52]:
model.fit(X_train, y_train, epochs=20, validation_data=(X_test, y_test))

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


<tensorflow.python.keras.callbacks.History at 0x7f5317706a60>