In [None]:
from tensorflow.keras.applications import VGG16
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Flatten, Dropout, GlobalAveragePooling2D
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import to_categorical
import tensorflow as tf

In [None]:
# Load CIFAR-10 as two (images, labels) pairs: one for training, one for testing.
train_split, test_split = cifar10.load_data()
x_train, y_train = train_split
x_test, y_test = test_split
del train_split, test_split  # only the unpacked arrays are needed downstream

In [None]:
# Upsample both splits to 100x100 so they match the input_shape given to the
# VGG16 base. tf.image.resize returns float32 tensors.
# NOTE(review): this materialises the whole training set at 100x100x3 float32
# (roughly 6 GB for CIFAR-10's 50,000 training images) in memory at once;
# consider resizing lazily in a tf.data pipeline if memory is tight.
x_train, x_test = [tf.image.resize(images, (100, 100)) for images in (x_train, x_test)]

In [None]:
# Prepare the images the way VGG16's ImageNet weights expect:
# `vgg16.preprocess_input` swaps RGB->BGR and subtracts the per-channel
# ImageNet mean, without scaling, from inputs in the 0-255 range.
# The previous `/ 255.0` scaling fed the frozen convolutional base pixels in
# [0, 1], a distribution it never saw during pretraining, which degrades
# the transferred features.
x_train = tf.keras.applications.vgg16.preprocess_input(x_train)
x_test = tf.keras.applications.vgg16.preprocess_input(x_test)

In [None]:
# One-hot encode the integer class labels into 10-wide vectors, as required
# by the categorical cross-entropy loss used at compile time.
y_train = to_categorical(y_train, num_classes=10)
y_test = to_categorical(y_test, num_classes=10)

In [None]:
# ImageNet-pretrained VGG16 convolutional base without its fully connected
# classifier. input_shape must match the 100x100 RGB images from the resize step.
base_model = VGG16(
    include_top=False,
    weights='imagenet',
    input_shape=(100, 100, 3),
)

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/vgg16/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5
[1m58889256/58889256[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step


In [None]:
# Freeze every layer of the pretrained VGG16 base so that only the new
# classification head below receives weight updates during training.
for vgg_layer in base_model.layers:
    vgg_layer.trainable = False

# Classification head: pool the conv feature maps to one vector per image,
# then a 256-unit ReLU layer with 50% dropout and a 10-way softmax output.
x = GlobalAveragePooling2D()(base_model.output)
x = Dense(256, activation='relu')(x)
x = Dropout(0.5)(x)
output = Dense(10, activation='softmax')(x)

# Wire the frozen base and the new head into a single functional model.
model = Model(inputs=base_model.input, outputs=output)

In [None]:
# Categorical cross-entropy matches the one-hot labels from to_categorical;
# the 'adam' string selects Adam with its default settings.
model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)
model.summary()

In [None]:
# Train for 5 epochs in mini-batches of 64.
# NOTE(review): the test split doubles as validation data, so the reported
# val_* metrics are test-set metrics rather than a separate validation set.
# NOTE(review): the recorded run below estimated ~16 min for the remaining
# 424 of 782 steps at ~2 s/step, so a full epoch takes on the order of
# 25+ minutes, and it was interrupted before finishing; `history` is only
# defined if this cell completes.
history = model.fit(
    x=x_train,
    y=y_train,
    batch_size=64,
    epochs=5,
    validation_data=(x_test, y_test),
)

Epoch 1/5
[1m358/782[0m [32m━━━━━━━━━[0m[37m━━━━━━━━━━━[0m [1m16:21[0m 2s/step - accuracy: 0.3550 - loss: 1.8410

KeyboardInterrupt: 

In [None]:
# Optional fine-tuning (run after the new head has been trained):
# for layer in base_model.layers[-4:]:  # Unfreeze the last 4 layers (block5_conv1-3 and block5_pool)
#     layer.trainable = True

# model.compile(optimizer=tf.keras.optimizers.Adam(1e-5),
#               loss='categorical_crossentropy',
#               metrics=['accuracy'])

# model.fit(x_train, y_train, epochs=5, validation_data=(x_test, y_test))

In [None]:
import matplotlib.pyplot as plt

# Plot the per-epoch curves recorded by `model.fit` side by side:
# accuracy on the left, loss on the right. `history` comes from the training
# cell, which must have run to completion (not been interrupted).
fig, (acc_ax, loss_ax) = plt.subplots(1, 2, figsize=(12, 5))

# Accuracy plot
acc_ax.plot(history.history['accuracy'], label='Train Accuracy')
acc_ax.plot(history.history['val_accuracy'], label='Validation Accuracy')
acc_ax.set(title='Model Accuracy', xlabel='Epoch', ylabel='Accuracy')
acc_ax.legend()  # the curves are labelled; without legend() the labels are never shown

# Loss plot
loss_ax.plot(history.history['loss'], label='Train Loss')
loss_ax.plot(history.history['val_loss'], label='Validation Loss')
loss_ax.set(title='Model Loss', xlabel='Epoch', ylabel='Loss')
loss_ax.legend()

fig.tight_layout()  # keep the two panels' titles and axis labels from overlapping
plt.show()
