In [1]:
import random
import os
import numpy as np
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
from scipy.stats import norm

In [2]:
# Limit GPU usage: enable memory growth on every visible GPU so TensorFlow
# allocates device memory as needed instead of reserving it all up front.
# TensorFlow raises RuntimeError if memory growth is configured after the GPUs
# have already been initialized (e.g. when this cell is re-run later in a
# session); that case is reported instead of stopping the notebook.
gpus = tf.config.experimental.list_physical_devices('GPU')
try:
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
except RuntimeError as err:
    print(err)

In [3]:
# Arguments identifying which saved model checkpoint to load; both values are
# embedded in the model file name built further down.
model_seed = 52233264  # seed tag in the saved model's file name
layer_width = 512      # hidden-layer width tag in the file name ("w512x512")

In [4]:
# Load MNIST, rescale pixels from [0, 255] to float32 in [0, 1], and stack the
# train and test splits into a single pool (all_images / all_labels).
# train_images / test_images keep their normalized per-split values as well.
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

train_images, test_images = (
    split.astype(np.float32) / 255.0 for split in (train_images, test_images)
)

all_images = np.concatenate((train_images, test_images), axis=0)
all_labels = np.concatenate((train_labels, test_labels), axis=0)

In [5]:
# Load the previously trained dense MNIST model from ./models (relative to the
# current working directory). The file name encodes the layer width twice and
# the seed, e.g. "mnist_dense-w512x512-52233264.h5".
cur_folder = os.getcwd()
model_folder = os.path.join(cur_folder, "models")
model_name = f"mnist_dense-w{layer_width}x{layer_width}-{model_seed}.h5"
model_file = os.path.join(model_folder, model_name)

original_model = tf.keras.models.load_model(model_file)
original_model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
flatten (Flatten)            (None, 784)               0         
_________________________________________________________________
dense (Dense)                (None, 512)               401920    
_________________________________________________________________
dense_1 (Dense)              (None, 512)               262656    
_________________________________________________________________
dense_2 (Dense)              (None, 10)                5130      
Total params: 669,706
Trainable params: 669,706
Non-trainable params: 0
_________________________________________________________________


In [6]:
# Evaluate original model.
# NOTE: all_images / all_labels combine the MNIST *training* and test splits,
# so this first score includes images the model was presumably trained on and
# overstates how well it generalizes. The names test_loss / test_acc are kept
# unchanged in case later cells rely on them.
test_loss, test_acc = original_model.evaluate(all_images, all_labels, verbose=0)

# Held-out estimate: score only on the official MNIST test split.
heldout_loss, heldout_acc = original_model.evaluate(test_images, test_labels, verbose=0)

print("Original Model Name: ", model_name)
print('Original Model Accuracy (train+test combined):', test_acc)
print('Original Model Accuracy (held-out test split):', heldout_acc)

Original Model Name:  mnist_dense-w512x512-52233264.h5
Original Model Accuracy: 0.9974428415298462
