In [1]:
%matplotlib inline

import tensorflow as tf
from time import time
import numpy as np

from tensorflow.keras.layers import Dense, Flatten, Conv2D, InputLayer, Layer, MaxPool2D, AveragePooling2D,\
    BatchNormalization, Dropout, ReLU, LeakyReLU, Activation
from tensorflow.keras import Model
import matplotlib.pyplot as plt

from tqdm import tqdm



In [2]:
# Download (or load from the local Keras cache) the MNIST digit images and labels.
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
print(x_train.shape, 'train samples')
print(x_test.shape[0], 'test samples')


# Scale pixel intensities to [0, 1] in float32.
# Dividing the integer image arrays by 255.0 directly would promote them to
# float64, doubling host memory and forcing a float64 -> float32 cast on every
# batch fed to the (float32) model weights.
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0

# Add a channels dimension
x_train = x_train[..., tf.newaxis]
x_test = x_test[..., tf.newaxis]

(60000, 28, 28) train samples
10000 test samples


In [3]:
# Input pipelines: the training set is shuffled (10k-element buffer) and
# batched by 128; the test set is only batched, by 1024.
train_ds = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(buffer_size=10000)
    .batch(batch_size=128)
)

test_ds = (
    tf.data.Dataset.from_tensor_slices((x_test, y_test))
    .batch(batch_size=1024)
)


# Loss on integer class labels (constructed with its default arguments).
cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy()


# Running metrics, accumulated over the batches of an epoch.
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

test_loss = tf.keras.metrics.Mean(name='test_loss')
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')

In [4]:
class MNIST(Model):
    """Small CNN classifier whose second conv kernel is generated by a hypernetwork.

    Architecture (see ``call``):
      conv1 (learned kernel) -> ReLU -> 2x2 max-pool
      conv2 (kernel produced from the embedding ``z``) -> ReLU -> 2x2 max-pool
      flatten -> dense -> softmax over 10 classes

    Args:
        f_size: Spatial size of both (square) convolution kernels.
        in_size: Number of input channels of the generated conv2 kernel.
            conv2 consumes the ``out_size`` channels produced by conv1, so
            ``in_size`` must equal ``out_size`` for the forward pass to work.
        out_size: Number of output channels of conv1 and conv2.
        z_dim: Dimensionality of the hypernetwork embedding ``z``.
        **kwargs: Forwarded to ``tf.keras.Model``.

    The dense layer is sized for single-channel 28x28 inputs (two stride-1
    'SAME' convolutions and two 2x2 pools leave a 7x7 feature map).
    """

    def __init__(self, f_size=7, in_size=16, out_size=16, z_dim=4, **kwargs):
        super().__init__(**kwargs)

        self.f_size = f_size
        self.in_size = in_size
        self.out_size = out_size
        self.z_dim = z_dim

        def small_normal():
            # A fresh initializer per weight: some Keras releases warn that a
            # single unseeded initializer instance reused across weights
            # returns identical draws.
            return tf.keras.initializers.TruncatedNormal(stddev=0.01)

        def zeros():
            return tf.zeros_initializer()

        # Weight names are passed by keyword: the position of `name` in
        # `add_weight` differs between Keras versions.
        self.conv1_weights = self.add_weight(
            name='conv1', shape=(f_size, f_size, 1, out_size), initializer=small_normal())
        self.conv1_biases = self.add_weight(
            name='bias1', shape=(out_size,), initializer=zeros())

        # hypernetwork: embedding z and two affine maps that expand it into
        # the full conv2 kernel.
        kernel_len = out_size * f_size * f_size
        self.z = self.add_weight(name='z', shape=(1, z_dim), initializer=small_normal())
        self.w1 = self.add_weight(name='w1', shape=(z_dim, kernel_len), initializer=small_normal())
        self.b1 = self.add_weight(name='b1', shape=(kernel_len,), initializer=zeros())
        self.w2 = self.add_weight(name='w2', shape=(z_dim, in_size * z_dim), initializer=small_normal())
        self.b2 = self.add_weight(name='b2', shape=(in_size * z_dim,), initializer=zeros())

        self.conv2_biases = self.add_weight(
            name='bias2', shape=(out_size,), initializer=zeros())

        # 28x28 input halved by two 2x2 pools -> 7x7 spatial map with
        # `out_size` channels (784 features for the default out_size=16).
        flat_dim = 7 * 7 * out_size
        self.dense = self.add_weight(
            name='d', shape=(flat_dim, 10), initializer=tf.keras.initializers.Orthogonal())
        self.dense_biases = self.add_weight(
            name='db', shape=(10,), initializer=zeros())

    def call(self, x, **kwargs):
        """Run the forward pass.

        Args:
            x: Batch of images; the conv1 kernel expects a single input channel.
            **kwargs: Accepted for Keras compatibility (e.g. ``training``); unused.

        Returns:
            Softmax class probabilities of shape (batch, 10).
        """
        # Hypernetwork: z (1, z_dim) -> (in_size, z_dim) -> (in_size, out_size*f*f).
        h_in = tf.matmul(self.z, self.w2) + self.b2
        h_in = tf.reshape(h_in, (self.in_size, self.z_dim))
        h_final = tf.matmul(h_in, self.w1) + self.b1
        kernel2 = tf.reshape(h_final, (self.out_size, self.in_size, self.f_size, self.f_size))
        # Reversing all axes gives the (f, f, in_size, out_size) layout that
        # tf.nn.conv2d expects for its filter argument.
        conv2_weights = tf.transpose(kernel2)

        x = tf.nn.conv2d(x, self.conv1_weights, strides=[1, 1, 1, 1], padding='SAME')
        x = tf.nn.relu(x + self.conv1_biases)
        x = tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='VALID')

        x = tf.nn.conv2d(x, conv2_weights, strides=[1, 1, 1, 1], padding='SAME')
        x = tf.nn.relu(x + self.conv2_biases)
        x = tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='VALID')

        # Flatten every non-batch axis, then apply the dense layer.
        x = tf.reshape(x, (-1, np.prod(x.shape[1:])))
        x = tf.matmul(x, self.dense) + self.dense_biases

        x = tf.nn.softmax(x)

        return x


In [11]:
# @tf.function
def train_step(images, labels, optimizer, trainable_variables):
    """Apply one optimizer update on a single batch and record training metrics.

    Uses the module-level `model`, `cross_entropy`, `train_loss` and
    `train_accuracy` objects.

    Args:
        images: Batch of input images fed to `model`.
        labels: Integer class labels for `images`.
        optimizer: Optimizer whose `apply_gradients` performs the update.
        trainable_variables: Variables to differentiate against and update.
    """
    # Record the forward pass so the loss can be differentiated.
    with tf.GradientTape() as grad_tape:
        probs = model(images, training=True)
        batch_loss = cross_entropy(labels, probs)

    grads = grad_tape.gradient(batch_loss, trainable_variables)
    grads_and_vars = zip(grads, trainable_variables)
    optimizer.apply_gradients(grads_and_vars)

    # Fold this batch into the running epoch metrics.
    train_loss(batch_loss)
    train_accuracy(labels, probs)



# @tf.function
def test_step(images, labels):
    """Evaluate one batch and fold it into the running test metrics.

    Uses the module-level `model`, `cross_entropy`, `test_loss` and
    `test_accuracy` objects; no weights are updated.

    Args:
        images: Batch of input images fed to `model`.
        labels: Integer class labels for `images`.
    """
    probs = model(images, training=False)
    batch_loss = cross_entropy(labels, probs)

    test_loss(batch_loss)
    test_accuracy(labels, probs)

In [6]:
# Adam optimizer with a learning rate of 5e-3.
optimizer = tf.keras.optimizers.Adam(learning_rate=5e-3)

# Number of full passes over the training set.
EPOCHS = 1000

In [7]:
# Build the network with the constructor's default hyper-parameters spelled out.
model = MNIST(f_size=7, in_size=16, out_size=16, z_dim=4)

In [None]:
for epoch in range(EPOCHS):

    epoch_start = time()

    # One optimisation pass over the training batches (with a progress bar).
    for images, labels in tqdm(train_ds):
        train_step(images, labels, optimizer, model.trainable_variables)

    # Evaluate on the test batches.
    for test_images, test_labels in test_ds:
        test_step(test_images, test_labels)

    # Report the metrics accumulated over this epoch.
    template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'
    epoch_summary = (
        epoch+1,
        round(float(train_loss.result()), 4),
        round(float(train_accuracy.result()*100), 3),
        round(float(test_loss.result()), 4),
        round(float(test_accuracy.result()*100), 2),
    )
    print(template.format(*epoch_summary))

    # Reset the metrics for the next epoch
    for metric in (train_loss, train_accuracy, test_loss, test_accuracy):
        metric.reset_states()

    print('T:', time()-epoch_start)

 93%|█████████▎| 435/469 [00:39<00:03, 10.87it/s]