In [1]:
import tensorflow as tf
import numpy as np

In [112]:
mnist = tf.keras.datasets.mnist
(X_train, y_train), (X_test, y_test) = mnist.load_data()

# Images: cast to float32, divide by 255.0, then flatten every sample into a
# 784-element row so it can be fed to a dense layer.
X_train = (X_train.astype('float32') / 255.0).reshape((-1, 784))
X_test = (X_test.astype('float32') / 255.0).reshape((-1, 784))

# Labels: stored as float32 class indices.
y_train = y_train.astype('float32')
y_test = y_test.astype('float32')

In [169]:
class SSRegularizer(tf.keras.regularizers.Regularizer):
    """Column-weighted L1 penalty for rank-2 kernels.

    For a kernel ``x`` the penalty is::

        l1 * sum_{i, j} (j * l1) * |x[i, j]|

    where ``j`` is the (0-based) index along axis 1. Column 0 is therefore
    not penalized and the weight grows linearly with the column index.

    NOTE(review): ``l1`` enters twice (per-column scale and outer factor), so
    the effective strength is ``l1 ** 2 * j``. Behavior is kept as-is here;
    confirm that the squared coefficient is intended.

    Args:
        l1: Scalar regularization coefficient.
    """

    def __init__(self, l1):
        self.l1 = l1

    def __call__(self, x):
        """Return the scalar penalty for ``x``.

        ``x`` must have a fully defined static shape and at least two axes,
        because the per-column weights are accumulated along ``axis=1``.
        """
        # Build the step tensor in x's dtype so that an integer `l1` or a
        # non-float32 kernel does not cause a dtype mismatch in tf.multiply.
        step = tf.constant(self.l1, dtype=x.dtype, shape=x.shape)
        # Inclusive cumsum gives l1, 2*l1, ...; subtracting one step shifts
        # it to 0, l1, 2*l1, ... so the first column is unpenalized.
        scaling_matrix = tf.cumsum(step, axis=1) - step
        return self.l1 * tf.reduce_sum(tf.multiply(scaling_matrix, tf.abs(x)))

    def get_config(self):
        """Return the constructor arguments needed to re-create this regularizer."""
        return {'l1': float(self.l1)}


class SSModel(tf.keras.Model):
    """Two-layer dense classifier whose kernels carry an ``SSRegularizer`` penalty.

    The forward pass is ``softmax(activation(inputs @ W1 + b1) @ W2 + b2)``.
    Both kernels (``W1`` and ``W2``) are registered with ``add_loss`` so the
    penalties show up in ``model.losses``; biases are not regularized.

    Args:
        units: Width of the hidden layer.
        activation: Hidden-layer activation, resolved with
            ``tf.keras.activations.get`` (``None`` means linear).
        kernel_initializer: Initializer identifier for ``W1`` and ``W2``.
        bias_initializer: Initializer identifier for ``b1`` and ``b2``.
        input_dim: Number of input features (default 784, flattened MNIST).
        num_classes: Number of output classes (default 10).
        l1: Coefficient passed to ``SSRegularizer`` (default 0.01).
    """

    def __init__(self, units, activation=None, kernel_initializer='glorot_uniform', bias_initializer='zeros',
                 input_dim=784, num_classes=10, l1=0.01):
        super().__init__()
        self.units = units
        self.activation1 = tf.keras.activations.get(activation)
        self.activation2 = tf.keras.activations.get('softmax')
        self.kernel_initializer = tf.keras.initializers.get(kernel_initializer)
        self.bias_initializer = tf.keras.initializers.get(bias_initializer)
        self.kernel_regularizer = SSRegularizer(l1)

        # Hidden layer parameters.
        self.W1 = tf.Variable(
            name='W1',
            initial_value=self.kernel_initializer(shape=(input_dim, self.units), dtype='float32'),
            trainable=True)

        self.b1 = tf.Variable(
            name='b1',
            initial_value=self.bias_initializer(shape=(self.units,), dtype='float32'),
            trainable=True)

        # Output layer parameters.
        self.W2 = tf.Variable(
            name='W2',
            initial_value=self.kernel_initializer(shape=(self.units, num_classes), dtype='float32'),
            trainable=True)

        self.b2 = tf.Variable(
            name='b2',
            initial_value=self.bias_initializer(shape=(num_classes,), dtype='float32'),
            trainable=True)

        # Callables (not tensors) so the penalty is re-evaluated against the
        # current kernel values every time `model.losses` is read.
        self.add_loss(lambda: self.kernel_regularizer(self.W1))
        self.add_loss(lambda: self.kernel_regularizer(self.W2))

    def call(self, inputs, training=None):
        """Return class probabilities for a batch of flattened inputs.

        ``training`` is accepted for Keras API compatibility; the forward
        pass is identical in training and inference mode.
        """
        A1 = self.activation1(tf.matmul(inputs, self.W1) + self.b1)
        A2 = self.activation2(tf.matmul(A1, self.W2) + self.b2)

        return A2

In [172]:
epochs = 5
batch_size = 32

# Use a shuffle buffer as large as the dataset: a smaller buffer only mixes
# samples within a sliding window, so batches would stay close to the
# original storage order. The dataset is reshuffled on every iteration.
train_dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=len(X_train)).batch(batch_size)

model = SSModel(units=100, activation='relu')
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)

for epoch in range(epochs):
    print(f"Epoch {epoch + 1}/{epochs}")

    for step, (x_batch, y_batch) in enumerate(train_dataset):
        with tf.GradientTape() as tape:
            y_pred = model(x_batch, training=True)
            # Data term: mean cross-entropy over the mini-batch.
            loss_value = tf.reduce_mean(tf.keras.losses.sparse_categorical_crossentropy(y_batch, y_pred))
            # Penalty term: the regularization losses registered by the model.
            # Evaluated inside the tape so their gradients reach W1 and W2.
            loss_value += sum(model.losses)

        grads = tape.gradient(loss_value, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

    # End-of-epoch report on the full training set. The printed loss is the
    # cross-entropy only; the regularization penalty is not included.
    y_pred = model(X_train)
    loss_value = tf.reduce_mean(tf.keras.losses.sparse_categorical_crossentropy(y_train, y_pred))
    accuracy = tf.reduce_mean(tf.keras.metrics.sparse_categorical_accuracy(y_train, y_pred))
    print(f"loss: {loss_value} - accuracy: {accuracy}")

Epoch 1/5
loss: 0.5553401112556458 - accuracy: 0.8698999881744385
Epoch 2/5
loss: 0.4396760165691376 - accuracy: 0.887583315372467
Epoch 3/5
loss: 0.3905361592769623 - accuracy: 0.8967666625976562
Epoch 4/5
loss: 0.36550816893577576 - accuracy: 0.9010166525840759
Epoch 5/5
loss: 0.3454333245754242 - accuracy: 0.9063000082969666


In [173]:
# Inspect the first-layer kernel W1 of the SSModel trained in the loop above.
model.W1

<tf.Variable 'W1:0' shape=(784, 100) dtype=float32, numpy=
array([[-5.4783776e-02, -6.3836873e-02,  1.0264431e-02, ...,
        -4.0428204e-05, -7.2806259e-05,  8.1714134e-05],
       [-6.9456547e-04,  6.7252368e-02, -1.1140723e-06, ...,
         3.2782016e-05,  9.2355651e-05,  7.1384347e-05],
       [-6.0269188e-02,  4.8960105e-02,  3.8840002e-03, ...,
        -9.4548159e-05,  2.2228342e-05, -3.6981524e-05],
       ...,
       [ 4.2940959e-02, -5.4830378e-03,  9.2584165e-03, ...,
        -5.9602287e-05,  4.2661559e-05,  7.4062365e-05],
       [ 2.9998869e-02,  4.8098937e-02,  4.8972670e-02, ...,
         8.4394094e-05, -1.0140822e-05, -4.1611347e-06],
       [-6.9319934e-02,  3.8508490e-02, -6.0509924e-02, ...,
         8.4811531e-05, -1.3004523e-05, -2.8084971e-05]], dtype=float32)>

In [174]:
# Largest (signed) weight in each column of W1, i.e. one value per hidden
# unit, obtained by reducing over the input axis (axis 0).
np.amax(model.W1, axis=0)

array([3.56070369e-01, 2.10372388e-01, 1.42277285e-01, 2.13853776e-01,
       1.56535193e-01, 3.27035159e-01, 1.62230641e-01, 1.62828580e-01,
       2.18302608e-01, 2.21107543e-01, 1.79352477e-01, 7.54457340e-02,
       2.88171005e-02, 1.92572489e-01, 1.70845777e-01, 7.25833774e-02,
       1.62880704e-01, 7.63008967e-02, 2.06182316e-01, 2.30065927e-01,
       1.35879621e-01, 1.77269727e-01, 2.19682970e-05, 9.94047672e-02,
       1.09607264e-01, 2.18581870e-01, 1.30011663e-01, 1.01828218e-01,
       1.75623310e-04, 2.88958490e-05, 3.18506323e-02, 1.03400074e-01,
       7.54134890e-05, 4.14117537e-02, 1.42391858e-04, 3.49181719e-05,
       2.54248567e-02, 3.69264380e-05, 1.17008691e-04, 3.89895322e-05,
       3.99304699e-05, 1.80518359e-01, 6.69397786e-02, 3.06429894e-04,
       2.59649532e-04, 1.67136692e-04, 9.36674696e-05, 8.62652669e-05,
       4.79294613e-05, 4.88564219e-05, 3.22741078e-04, 5.08092198e-05,
       1.59270858e-04, 2.32335587e-04, 1.78564136e-04, 5.44189425e-05,
      

In [119]:
# Baseline: the same 100-unit ReLU / 10-way softmax architecture built from
# stock Dense layers, with no kernel regularizer.
# NOTE: this rebinds `model`, replacing the SSModel instance created earlier.
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Dense(100, activation='relu'))
model.add(tf.keras.layers.Dense(10, activation='softmax'))

In [120]:
# Same optimizer settings and loss as the custom training loop, so the two
# models are trained under comparable conditions.
model.compile(
    optimizer=tf.keras.optimizers.SGD(learning_rate=0.01),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

In [121]:
# Train the baseline for 10 epochs, evaluating on the test split after each one.
model.fit(
    x=X_train,
    y=y_train,
    epochs=10,
    validation_data=(X_test, y_test),
)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<tensorflow.python.keras.callbacks.History at 0x7fa3c2899880>