In [10]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

In [11]:
# Loading the fashion mnist dataset
fashion_data = tf.keras.datasets.cifar10.load_data()

In [12]:
# Preprocess step because the imagenet training images used for the Xception was of size 224x224
def preprocess(image, label):
    resized_image = tf.image.resize(image, [224, 224])
    final_image = tf.keras.applications.xception.preprocess_input(resized_image)
    return final_image, label

In [13]:
# Normalizing the data
fashion_scaled_train = fashion_data[0][0] / 255
fashion_scaled_test = fashion_data[1][0] / 255

# Adding the channel layer for this to work with a cnn
#fashion_scaled_train = fashion_scaled_train[..., np.newaxis]
#fashion_scaled_test = fashion_scaled_test[..., np.newaxis]

# Splitting up the dat into train, valid, test
fashion_train = tf.data.Dataset.from_tensor_slices((fashion_scaled_train[:10000], fashion_data[0][1][:10000]))
fashion_test = tf.data.Dataset.from_tensor_slices((fashion_scaled_test, fashion_data[1][1]))
fashion_valid = tf.data.Dataset.from_tensor_slices((fashion_scaled_train[40000:], fashion_data[0][1][40000:]))

# Shuffling the dataset
fashion_train = fashion_train.map(preprocess).shuffle(1000).batch(32)
fashion_test = fashion_test.map(preprocess).shuffle(1000).batch(32)
fashion_valid = fashion_valid.map(preprocess).shuffle(1000).batch(32)

In [6]:
# Sanity Check because aurelien is a dolt
for something in fashion_test:
    print(something[0].shape)
    break

(32, 224, 224, 3)


In [25]:
# Building the top of the model
n_classes = 10
base_model = tf.keras.applications.resnet50.ResNet50(weights="imagenet",
include_top=False)
avg = tf.keras.layers.GlobalAveragePooling2D()(base_model.output)
output = tf.keras.layers.Dense(n_classes, activation="softmax")(avg)
model = tf.keras.Model(inputs=base_model.input, outputs=output)

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5


In [26]:
# Freezing the weights so I don't ruin the good lower layers
for layer in base_model.layers:
    layer.trainable = False

In [28]:
model.compile(loss="sparse_categorical_crossentropy",optimizer="nadam", metrics=["accuracy"])
model.fit(fashion_train, epochs=5, validation_data=fashion_valid)

Epoch 1/5

KeyboardInterrupt: 

In [46]:
batch_output = model(next(fashion_train.take(1).as_numpy_iterator())[0])

In [54]:
np.argmax(batch_output, axis=1)

array([4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
       4, 4, 4, 4, 4, 4, 4, 4, 4, 4])

In [22]:
len(next(fashion_train.take(1).as_numpy_iterator())[0])

28