In [1]:
import pathlib
import tensorflow as tf

train_dir = pathlib.Path('DRAM_train/')
test_dir = pathlib.Path('DRAM_test/')

m = 224
batch = 16

In [2]:
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    train_dir,
    seed=123,
    image_size=(m, m),
    batch_size=batch
)

val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    test_dir,
    seed=123,
    image_size=(m, m),
    batch_size=batch
)
num_classes = 4

Found 5677 files belonging to 4 classes.
Found 583 files belonging to 4 classes.


In [3]:
AUTOTUNE = tf.data.AUTOTUNE

train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

In [4]:
model = tf.keras.models.Sequential([
  tf.keras.layers.Rescaling(1./255),
  tf.keras.applications.ResNet50(include_top=False, pooling='avg', weights=None, input_shape=(m, m, 3)),
  tf.keras.layers.Dense(num_classes, activation='softmax')
])

model.layers[1].trainable = False

model.compile(
  optimizer=tf.optimizers.Adam(),
  loss=tf.keras.losses.SparseCategoricalCrossentropy(),
  metrics=['accuracy']
)

model.fit(
  train_ds,
  validation_data=val_ds,
  epochs=11
)

Epoch 1/11
Epoch 2/11
Epoch 3/11
Epoch 4/11
Epoch 5/11
Epoch 6/11
Epoch 7/11
Epoch 8/11
Epoch 9/11
Epoch 10/11
Epoch 11/11


<keras.callbacks.History at 0x1b3e4c151e0>

In [5]:
model.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 rescaling (Rescaling)       (None, 224, 224, 3)       0         
                                                                 
 resnet50 (Functional)       (None, 2048)              23587712  
                                                                 
 dense (Dense)               (None, 4)                 8196      
                                                                 
Total params: 23,595,908
Trainable params: 8,196
Non-trainable params: 23,587,712
_________________________________________________________________
