In [1]:
import tensorflow as tf

In [2]:
# Training configuration.
BATCH_SIZE = 32
SEED = 42
# CIFAR-10 label names; assumed to be in label-index order (0 = airplane ... 9 = truck) — TODO confirm.
classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

# Seed Python's `random`, NumPy and TensorFlow in one call so dataset shuffling and the
# classifier-head initialisation repeat across runs (GPU kernels may still be nondeterministic).
tf.keras.utils.set_random_seed(SEED)

In [3]:
# load_data() returns (train, test) splits; the test split is used as the validation set below.
train_split, val_split = tf.keras.datasets.cifar10.load_data()
train_img, y_train = train_split
val_img, y_val = val_split

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz


In [4]:
def preprocess_images(input_images):
  """Cast an image array to float32 and apply ResNet50 input preprocessing.

  Args:
    input_images: array-like exposing ``.astype`` (here the arrays returned by
      ``cifar10.load_data()``); assumed RGB, channels-last — TODO confirm.

  Returns:
    The output of ``tf.keras.applications.resnet50.preprocess_input`` applied
    to a float32 copy of ``input_images``.
  """
  # astype() returns a new array, so the caller's array is never overwritten
  # (Keras documents that preprocess_input may modify float input in place).
  as_float32 = input_images.astype("float32")
  return tf.keras.applications.resnet50.preprocess_input(as_float32)

In [6]:
def _to_dataset(images, labels, shuffle_buffer=None):
  """Wrap in-memory (images, labels) arrays as a batched, prefetched tf.data pipeline.

  When ``shuffle_buffer`` is given, examples are shuffled before batching so that
  individual examples (not whole batches) are mixed.
  """
  dataset = tf.data.Dataset.from_tensor_slices((images, labels))
  if shuffle_buffer is not None:
    dataset = dataset.shuffle(buffer_size=shuffle_buffer)
  return dataset.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)


X_train = preprocess_images(train_img)
X_val = preprocess_images(val_img)
# Only the training split is shuffled; validation order is left as loaded.
train = _to_dataset(X_train, y_train, shuffle_buffer=1024)
val = _to_dataset(X_val, y_val)


In [18]:
# ImageNet-pretrained ResNet50 backbone without its classification head.
BACKBONE_INPUT_SIZE = 224
CIFAR_IMAGE_SIZE = 32
feature_extractor_layer = tf.keras.applications.resnet.ResNet50(
    input_shape=(BACKBONE_INPUT_SIZE, BACKBONE_INPUT_SIZE, 3),
    include_top=False,
    weights="imagenet",
)

inputs = tf.keras.layers.Input(shape=(CIFAR_IMAGE_SIZE, CIFAR_IMAGE_SIZE, 3))
# Upsample 32x32 -> 224x224 (factor 224 // 32 = 7) so inputs match the backbone's input_shape.
upsample_factor = BACKBONE_INPUT_SIZE // CIFAR_IMAGE_SIZE
resize = tf.keras.layers.UpSampling2D(size=(upsample_factor, upsample_factor))(inputs)
# training=False keeps the backbone's BatchNorm layers in inference mode, so their
# ImageNet statistics are preserved even if the backbone is later unfrozen for fine-tuning.
x = feature_extractor_layer(resize, training=False)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
# One softmax unit per CIFAR-10 class.
outputs = tf.keras.layers.Dense(len(classes), activation='softmax')(x)

model = tf.keras.Model(inputs, outputs)


In [19]:

# Freeze the whole pretrained backbone. Setting `trainable` on the container propagates
# to every sub-layer, so only the new classification head is updated during fit().
feature_extractor_layer.trainable = False
# Sanity check: no backbone weight may remain trainable after freezing.
assert not feature_extractor_layer.trainable_weights, "backbone still has trainable weights"
model.summary()

Model: "model_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_6 (InputLayer)        [(None, 32, 32, 3)]       0         
                                                                 
 up_sampling2d_2 (UpSamplin  (None, 224, 224, 3)       0         
 g2D)                                                            
                                                                 
 resnet50 (Functional)       (None, 7, 7, 2048)        23587712  
                                                                 
 global_average_pooling2d_2  (None, 2048)              0         
  (GlobalAveragePooling2D)                                       
                                                                 
 dense_4 (Dense)             (None, 10)                20490     
                                                                 
Total params: 23608202 (90.06 MB)
Trainable params: 20490 (80.04 KB)
Non-trainable params: 23587712 (89.98 MB)

In [20]:
# Labels are integer class ids rather than one-hot vectors, hence the "sparse" loss.
model.compile(
    loss="sparse_categorical_crossentropy",
    optimizer="adam",
    metrics=["accuracy"],
)


In [21]:
EPOCHS = 4
# Keep the History object so per-epoch loss/accuracy (train and validation) can be
# inspected or plotted later; assigning it also stops its bare repr being the cell output.
history = model.fit(
    train,
    epochs=EPOCHS,
    validation_data=val,
)

Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


<keras.src.callbacks.History at 0x7fbc94a591b0>