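# CIFAR-10 Wide ResNet-106 from TensorFlow Hub

Measures the clean (unperturbed) test-set accuracy of DeepMind's Unsupervised Adversarial Training (UAT) Wide ResNet-106 CIFAR-10 model on the 10,000 CIFAR-10 test images. The model is loaded from TensorFlow Hub as a TF1 `hub.Module`. Result: **86.46 %** top-1 accuracy (8646 / 10000).

- https://tfhub.dev/deepmind/unsupervised-adversarial-training/cifar10/wrn_106/1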

In [1]:
import tensorflow as tf
import tensorflow_hub as hub
import numpy as np
import collections

## Functions

In [2]:
Sample = collections.namedtuple('Sample', ['image', 'label'])


def _build_dataset(raw_data, batch_size=32, shuffle=False, shuffle_buffer_size=1000):
    """Build a repeating, batched tf.data pipeline and return its next-element tensors.

    Args:
        raw_data: ``(images, labels)`` pair of array-likes, e.g. one split returned by
            ``tf.keras.datasets.cifar10.load_data()``. Images are divided by 255, so they
            are presumably 0-255 pixel values; labels may carry trailing singleton axes.
        batch_size: number of examples per batch.
        shuffle: whether to shuffle examples before batching.
        shuffle_buffer_size: buffer size passed to ``Dataset.shuffle``; ignored unless
            ``shuffle`` is true.

    Returns:
        A ``Sample(image, label)`` of graph tensors (float32 images scaled by 1/255,
        int64 labels). Each ``session.run`` that depends on them advances the one-shot
        iterator by one batch. Because ``repeat()`` is applied before ``batch()``, the
        stream is endless and every batch is full (batches wrap around the data).
    """
    images, labels = raw_data
    labels = np.asarray(labels)
    # Drop trailing singleton axes (CIFAR labels load with an extra axis) but never the
    # example axis: a bare np.squeeze turns a one-example split into a 0-d array, which
    # from_tensor_slices cannot slice.
    labels = labels.reshape(labels.shape[:1] + tuple(d for d in labels.shape[1:] if d != 1))
    samples = Sample(images.astype(np.float32) / 255., labels.astype(np.int64))
    data = tf.data.Dataset.from_tensor_slices(samples)
    if shuffle:
        data = data.shuffle(shuffle_buffer_size)
    # Dataset.make_one_shot_iterator() is deprecated; the TF deprecation notice recommends
    # tf.compat.v1.data.make_one_shot_iterator(dataset) as the drop-in replacement.
    iterator = tf.compat.v1.data.make_one_shot_iterator(data.repeat().batch(batch_size))
    return iterator.get_next()

In [3]:
def _cifar_meanstd_normalize(image):
    # Channel-wise means and std devs calculated from the CIFAR-10 training set
    cifar_means = [125.3, 123.0, 113.9]
    cifar_devs = [63.0, 62.1, 66.7]
    rescaled_means = [x / 255. for x in cifar_means]
    rescaled_devs = [x / 255. for x in cifar_devs]
    image = (image - rescaled_means) / rescaled_devs
    return image

## Settings

In [4]:
# Images per session step. Keep this a divisor of the test-set size: the inference loop
# runs len(test images) / batch_size (truncated) batches, so a remainder would go unevaluated.
batch_size: int = 100

## Load dataset

In [5]:
# load_data() returns (train_split, test_split). Only the test split is evaluated, so index
# it directly: binding the unused training split to `_` keeps it referenced in the user
# namespace for the rest of the session.
data_test = tf.keras.datasets.cifar10.load_data()[1]
# Unshuffled, so batches come out in the same order as data_test's labels.
data = _build_dataset(data_test, batch_size=batch_size, shuffle=False)

Instructions for updating:
Use `for ... in dataset:` to iterate over a dataset. If using `tf.estimator`, return the `Dataset` object directly from your input function. As a last resort, you can use `tf.compat.v1.data.make_one_shot_iterator(dataset)`.


## Load the UAT (Unsupervised Adversarial Training) module

In [6]:
UAT_HUB_URL = 'https://tfhub.dev/deepmind/unsupervised-adversarial-training/cifar10/wrn_106/1'


def make_classifier():
    """Load the UAT WRN-106 TF-Hub module (frozen) and wrap it as an image classifier.

    Returns:
        ``classifier(x)``: applies CIFAR mean/std normalization to the [0, 1]-scaled
        image tensor ``x`` and returns whatever the hub module outputs for it (used
        downstream as logits).
    """
    uat_module = hub.Module(UAT_HUB_URL, trainable=False)

    def classifier(x):
        # Input keys/values expected by the hub module's default signature call.
        module_inputs = {
            'x': _cifar_meanstd_normalize(x),
            'decay_rate': 0.1,
            'prefix': 'default',
        }
        return uat_module(module_inputs)

    return classifier

In [7]:
# Note that a `classifier` is a function mapping [0,1]-scaled image Tensors to a logit Tensor.
# In particular, it includes *both* the preprocessing function, and the neural network,
# so it must be given raw [0, 1] images rather than already-normalized ones.
classifier = make_classifier()
# Graph construction only: `logits` is evaluated later inside a tf.compat.v1.Session,
# pulling its input batch from the dataset iterator's image tensor.
logits = classifier(data.image)

INFO:tensorflow:Saver not created because there are no variables in the graph to restore


INFO:tensorflow:Saver not created because there are no variables in the graph to restore


## Run inference

In [8]:
num_test_images = len(data_test[0])
# The loop below runs whole batches only. Any remainder images would never be predicted,
# so fail loudly instead of silently reporting accuracy on a partial test set.
assert num_test_images % batch_size == 0, (
    'batch_size (%d) must divide the number of test images (%d)'
    % (batch_size, num_test_images))
iterations = num_test_images // batch_size
# Integer predictions. The sentinel -1 never equals a class label, so a slot that was
# never filled cannot be counted as correct (np.zeros would match every label-0 image).
result = np.full(num_test_images, -1, dtype=np.int64)

In [9]:
with tf.compat.v1.Session() as sess:
    sess.run([tf.compat.v1.global_variables_initializer(), tf.compat.v1.tables_initializer()])
    num_images = len(result)
    for i in range(iterations):
        start = batch_size * i
        if start >= num_images:
            break
        end = min(start + batch_size, num_images)
        # Evaluating `logits` pulls exactly one batch from the iterator. There is no need
        # to fetch the batch into NumPy first and feed it back through feed_dict, which
        # cost a second session run and a round trip of every image batch.
        batch_logits = sess.run(logits)
        # The repeating dataset always yields full batches; keep only the rows that fit.
        result[start:end] = np.argmax(batch_logits, axis=1)[:end - start]

## Show result

In [10]:
# Ground-truth labels of the test split, in the same (unshuffled) order the input
# pipeline yields images; they are flattened before comparison with `result`.
y = data_test[1]

In [11]:
# Count predictions that match the ground truth (both flattened to 1-D) and report top-1 accuracy.
num_correct = np.sum(np.equal(result.ravel(), y.ravel()))
num_total = len(y)
print('Accuracy: %f [%%] (%d / %d)' % (num_correct / num_total * 100., num_correct, num_total))

Accuracy: 86.460000 [%] (8646 / 10000)
