In [1]:
import tensorflow as tf
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.models import Model
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Dropout
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import to_categorical
import matplotlib.pyplot as plt

In [2]:
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

In [3]:
x_train = tf.image.resize(x_train, (75, 75))
x_test = tf.image.resize(x_test, (75, 75))

In [4]:
x_train, x_test = x_train / 255.0, x_test / 255.0

In [5]:
y_train, y_test = to_categorical(y_train, 10), to_categorical(y_test, 10)

In [6]:
base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(75, 75, 3))

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5
[1m94765736/94765736[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m22s[0m 0us/step


In [7]:
for layer in base_model.layers:
    layer.trainable = False

x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dropout(0.4)(x)
output = Dense(10, activation='softmax')(x)

model = Model(inputs=base_model.input, outputs=output)

In [8]:
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])


In [None]:
history = model.fit(x_train, y_train, epochs=5, batch_size=32, validation_data=(x_test, y_test))


Epoch 1/5
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m306s[0m 192ms/step - accuracy: 0.1880 - loss: 2.1992 - val_accuracy: 0.2842 - val_loss: 2.0845
Epoch 2/5
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m392s[0m 251ms/step - accuracy: 0.2433 - loss: 2.0854 - val_accuracy: 0.3022 - val_loss: 2.0255
Epoch 3/5
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m351s[0m 224ms/step - accuracy: 0.2575 - loss: 2.0446 - val_accuracy: 0.3095 - val_loss: 1.9958
Epoch 4/5
[1m 265/1563[0m [32m━━━[0m[37m━━━━━━━━━━━━━━━━━[0m [1m3:35[0m 166ms/step - accuracy: 0.2681 - loss: 2.0374

KeyboardInterrupt: 

: 

In [None]:
#plot
plt.figure(figsize=(12,5))

plt.subplot(1,2,1)
plt.plot(history.history['accuracy'], label='Train Accuracy')
plt.plot(history.history['val_accuracy'], label='Val Accuracy')
plt.legend(); plt.title('Accuracy')

plt.subplot(1,2,2)
plt.plot(history.history['loss'], label='Train Loss')
plt.plot(history.history['val_loss'], label='Val Loss')
plt.legend(); plt.title('Loss')

plt.show()