In [1]:
%matplotlib inline

import tensorflow as tf
from time import time
import numpy as np

from tensorflow.keras.layers import Dense, Flatten, Conv2D, InputLayer, Layer, MaxPool2D, AveragePooling2D,\
    BatchNormalization, Dropout, ReLU, LeakyReLU, Activation
from tensorflow.keras import Model
import matplotlib.pyplot as plt

from tqdm import tqdm



In [2]:
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
print(x_train.shape, 'train samples')
print(x_test.shape[0], 'test samples')


x_train, x_test = x_train / 255.0, x_test / 255.0

# Add a channels dimension
x_train = x_train[..., tf.newaxis]
x_test = x_test[..., tf.newaxis]

(60000, 28, 28) train samples
10000 test samples


In [3]:
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(128)

test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(1024)



cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy()


train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

test_loss = tf.keras.metrics.Mean(name='test_loss')
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')

In [4]:
class MNIST(Model):

    def __init__(self, f_size=7, in_size=16, out_size=16, z_dim=4, z_dim2=32, **kwargs):
        super().__init__(**kwargs)
        
        self.f_size = f_size
        self.in_size = in_size
        self.out_size = out_size
        self.z_dim = z_dim
        
        init = tf.keras.initializers.TruncatedNormal(stddev=0.01)
        
        # self.conv1_weights = self.add_weight('conv1', shape=(f_size, f_size, 1, in_size), 
        #                                      initializer=init)
        self.conv1_biases = self.add_weight('bias1', shape=(out_size), initializer=tf.zeros_initializer())
        
        # hypernetwork
        self.z = self.add_weight('z', shape=(1, z_dim), initializer=init)
        self.w2 = self.add_weight('w2', shape=(z_dim, in_size*z_dim), initializer=init)
        self.b2 = self.add_weight('b2', shape=(in_size*z_dim), initializer=tf.zeros_initializer())
        
        # self.w1 = self.add_weight('w1', shape=(z_dim, out_size*f_size*f_size), initializer=init, trainable=False)
        
        self.h2z = self.add_weight('h2z', shape=(1, z_dim2), initializer=init)
        self.h2w = self.add_weight('h2w', shape=(z_dim2, z_dim*out_size*f_size*f_size), initializer=init, trainable=False)
        
        self.b1 = self.add_weight('b1', shape=(out_size*f_size*f_size), initializer=tf.zeros_initializer(), trainable=False)

        self.zf = self.add_weight('zf', shape=(1, z_dim), initializer=init) #
        
        
        self.conv2_biases = self.add_weight('bias2', shape=(out_size), initializer=tf.zeros_initializer())
        
        
        self.dense = self.add_weight('d', shape=(784, 10), initializer=tf.keras.initializers.Orthogonal(), trainable=False)
        self.dense_biases = self.add_weight('db', shape=(10), initializer=tf.zeros_initializer(), trainable=False)
        

    def call(self, x, **kwargs):
        
        h_in = tf.matmul(self.z, self.w2) + self.b2
        # print(h_in.shape) # (1, 64)

        h_in = tf.reshape(h_in, (self.in_size, self.z_dim))
        # print(h_in.shape) # (16, 4)

        w1 = tf.keras.backend.dot(self.h2z, self.h2w)
        w1 = tf.reshape(w1, (self.z_dim, self.out_size*self.f_size*self.f_size))

        h_final = tf.matmul(h_in, w1) + self.b1
        # print(h_final.shape) # (16, 784)

        kernel2 = tf.reshape(h_final, (self.out_size, self.in_size, self.f_size, self.f_size))
        # print(kernel2.shape) # (16, 16, 7, 7)

        conv2_weights = tf.transpose(kernel2)
        # print(conv2_weights.shape) # (7, 7, 16, 16)

        # first conv layer weights
        hf_final = tf.matmul(self.zf, w1) + self.b1
        conv1_weights = tf.transpose(tf.reshape(hf_final, (self.in_size, 1, self.f_size, self.f_size)))


        x = tf.nn.conv2d(x, conv1_weights, strides=[1,1,1,1], padding='SAME')
        x = tf.nn.relu(x+self.conv1_biases)
        x = tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='VALID')
        
        x = tf.nn.conv2d(x, conv2_weights, strides=[1, 1, 1, 1], padding='SAME')
        x = tf.nn.relu(x+self.conv2_biases)
        x = tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='VALID')
        
        x = tf.reshape(x, (-1, np.prod(x.shape[1:])))
        x = tf.keras.backend.dot(x, self.dense) + self.dense_biases
        
        x = tf.nn.softmax(x)
        
        return x


In [5]:
class FCMNIST(Model):

    def __init__(self, hidden_size=256, z_dim=4, **kwargs):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        init = tf.keras.initializers.TruncatedNormal(stddev=0.01)

        self.z_dim = z_dim
        z_dim2 = 4
        self.z_dim2 = z_dim2
        
        self.z = self.add_weight('z', shape=(1, z_dim), initializer=init)
        self.h_w2 = self.add_weight('hw2', shape=(z_dim, 56*z_dim), initializer=init)
        self.h_b2 = self.add_weight('hb2', shape=(56*z_dim), initializer=tf.zeros_initializer())
        
        
        self.z2 = self.add_weight('z2', shape=(1, z_dim2), initializer=init)

        self.h2_w2 = self.add_weight('hw2', shape=(z_dim2, 14*z_dim2), initializer=init)
        self.h2_b2 = self.add_weight('hb2', shape=(14*z_dim2), initializer=tf.zeros_initializer())
      
        self.h2_w1 = self.add_weight('hw1', shape=(z_dim2, hidden_size*4), initializer=init, trainable=True)
        
        # self.h_w1 = self.add_weight('hw1', shape=(z_dim, 14*hidden_size), initializer=init, trainable=True)
        # self.h_b1 = self.add_weight('hb1', shape=(14*hidden_size), initializer=tf.zeros_initializer(), trainable=True)

        
        self.w2 = self.add_weight('hw2', shape=(self.hidden_size, 10), initializer=tf.keras.initializers.GlorotUniform())

        
    def call(self, x, **kwargs):
        
        h_in = tf.matmul(self.z, self.h_w2) + self.h_b2
        h_in = tf.reshape(h_in, (56, self.z_dim))

        h2_in = tf.matmul(self.z2, self.h2_w2) + self.h2_b2
        h2_in = tf.reshape(h2_in, (14, self.z_dim2))
        
        h_w1 = tf.matmul(h2_in, self.h2_w1) 
        h_w1 = tf.reshape(h_w1, (self.z_dim, self.hidden_size*14))
        
        h_final = tf.matmul(h_in, h_w1)
        
        w1 = tf.reshape(h_final, (784, self.hidden_size))

        x = tf.reshape(x, (-1, 784))

        x = tf.keras.backend.dot(x, w1)
        
        x = tf.nn.relu(x)
        x = tf.keras.backend.dot(x, self.w2)
        x = tf.nn.softmax(x)
        
        # print(h_final.shape) # (16, 784)
        
        return x


In [6]:
class FCMNIST_NORM(Model):

    def __init__(self, hidden_size=20, **kwargs):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        
        self.w1 = self.add_weight('hw1', shape=(784, hidden_size), initializer=tf.keras.initializers.GlorotUniform(), trainable=True)
        self.w2 = self.add_weight('hw2', shape=(hidden_size, 10), initializer=tf.keras.initializers.GlorotUniform(), trainable=True)

        
    def call(self, x, **kwargs):
        
        x = tf.reshape(x, (-1, 784))

        x = tf.keras.backend.dot(x, self.w1)
        
        x = tf.nn.relu(x)
        x = tf.keras.backend.dot(x, self.w2)
        x = tf.nn.softmax(x)
                
        return x

In [7]:
# @tf.function
def train_step(images, labels, optimizer, trainable_variables):
    with tf.GradientTape() as tape:
        predictions = model(images, training=True)

        loss = cross_entropy(labels, predictions)


    gradients = tape.gradient(loss, trainable_variables)
    optimizer.apply_gradients(zip(gradients, trainable_variables))

    train_loss(loss)
    train_accuracy(labels, predictions)



# @tf.function
def test_step(images, labels):
    predictions = model(images, training=False)
    t_loss = cross_entropy(labels, predictions)

    test_loss(t_loss)
    test_accuracy(labels, predictions)

In [8]:
optimizer = tf.keras.optimizers.Adam(3e-4)

EPOCHS = 1000

In [9]:
model = FCMNIST()
# model = FCMNIST_NORM()

In [10]:
for epoch in range(EPOCHS):

    st = time()

    for i, (images, labels) in enumerate(tqdm(train_ds)):
        # print('im,', images[0].shape)
        train_step(images, labels, optimizer, model.trainable_variables)


    for test_images, test_labels in test_ds:
        test_step(test_images, test_labels)


    template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'
    print(template.format(epoch+1,
            round(float(train_loss.result()), 4),
            round(float(train_accuracy.result()*100), 3),
            round(float(test_loss.result()), 4),
            round(float(test_accuracy.result()*100), 2)))

    # Reset the metrics for the next epoch
    train_loss.reset_states()
    train_accuracy.reset_states()
    test_loss.reset_states()
    test_accuracy.reset_states()

    print('T:', time()-st)

  0%|          | 0/469 [00:00<?, ?it/s]



To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.



100%|██████████| 469/469 [00:04<00:00, 99.90it/s] 
  0%|          | 1/469 [00:00<01:11,  6.58it/s]

Epoch 1, Loss: 1.6018, Accuracy: 45.117, Test Loss: 1.1474, Test Accuracy: 58.3
T: 4.822418689727783


100%|██████████| 469/469 [00:04<00:00, 106.01it/s]
  0%|          | 1/469 [00:00<01:12,  6.43it/s]

Epoch 2, Loss: 0.9754, Accuracy: 66.075, Test Loss: 0.8102, Test Accuracy: 72.53
T: 4.550597667694092


100%|██████████| 469/469 [00:04<00:00, 100.75it/s]
  0%|          | 1/469 [00:00<01:20,  5.78it/s]

Epoch 3, Loss: 0.7639, Accuracy: 74.607, Test Loss: 0.6938, Test Accuracy: 76.98
T: 4.80423641204834


100%|██████████| 469/469 [00:04<00:00, 106.19it/s]
  0%|          | 1/469 [00:00<01:16,  6.15it/s]

Epoch 4, Loss: 0.6658, Accuracy: 78.363, Test Loss: 0.5907, Test Accuracy: 81.05
T: 4.547563076019287


100%|██████████| 469/469 [00:04<00:00, 101.55it/s]
  0%|          | 1/469 [00:00<01:11,  6.58it/s]

Epoch 5, Loss: 0.5716, Accuracy: 81.702, Test Loss: 0.5101, Test Accuracy: 83.67
T: 4.797450065612793


  9%|▉         | 44/469 [00:00<00:05, 76.58it/s]


KeyboardInterrupt: 