In [1]:
from pathlib import Path
import sys

# Make the repository root (three directories above the notebook's working
# directory) importable so `from utils import ...` resolves to the project package.
# Insert at the FRONT of sys.path rather than appending: with append, any installed
# top-level package that happens to be called `utils` would shadow the project's one.
root_dir = Path().resolve().parent.parent.parent.as_posix()
if root_dir not in sys.path:
    sys.path.insert(0, root_dir)
del root_dir  # keep the notebook namespace clean

from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.applications import resnet_v2
import tensorflow as tf

from utils import config, datasets, models

In [2]:
MODEL_NAME = 'resnet-101-v2'
EPOCHS = 10
SEED = 42

# Seed Python's `random`, NumPy and TensorFlow in one call, before any data is
# loaded or layers are built, so the head's he_normal initialisation (and any
# unseeded shuffling the dataset loader may do) repeats on Restart & Run All.
# NOTE: some GPU kernels can still be nondeterministic even with fixed seeds.
keras.utils.set_random_seed(SEED)

In [3]:
# Load the train / validation / test splits through the project's dataset helper.
# NOTE(review): test_dataset is not consumed by any of the following cells as of
# this revision — confirm whether held-out evaluation happens elsewhere.
(
    train_dataset,
    validation_dataset,
    test_dataset,
) = datasets.load_dataset()

Found 6400 files belonging to 4 classes.


In [4]:
# Input preprocessing that runs inside the model. Despite the variable name, no
# random augmentation is applied — only deterministic resizing and rescaling.
#   * Resizing: bring every image to ResNet's default 224x224 input size.
#   * Rescaling: x / 127.5 - 1, i.e. exactly what `resnet_v2.preprocess_input`
#     does, mapping [0, 255] pixels into the [-1, 1] range the pretrained
#     ResNetV2 weights expect. A plain 1/255 scale would give [0, 1] inputs.
# Assumes the loader yields raw [0, 255] pixel values — TODO confirm in
# utils.datasets.
data_augmentation = keras.Sequential([
    layers.Resizing(224, 224),
    layers.Rescaling(1. / 127.5, offset=-1.),
])

In [5]:
# Classification head stacked on the backbone's 2048-d feature vector:
# a 256-unit LeakyReLU hidden layer, then one softmax unit per class
# (the loader reports 4 classes).
top = keras.Sequential()
top.add(layers.Dense(
    units=256,
    kernel_initializer='he_normal',
    activation=keras.layers.LeakyReLU(alpha=0.05),
))
top.add(layers.Dense(
    units=4,
    kernel_initializer='he_normal',
    activation=keras.activations.softmax,
))

In [6]:
# Assemble preprocessing -> pretrained ResNet101V2 backbone -> classification head.
# The backbone is passed as a class; the helper instantiates it (which triggers the
# weight download). Per the summary, the backbone's parameters end up non-trainable.
model = models.create_model(
    base_model=resnet_v2.ResNet101V2,
    preprocessing_layers=data_augmentation,
    top_layers=top,
    model_name=MODEL_NAME,
)

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet101v2_weights_tf_dim_ordering_tf_kernels_notop.h5


In [7]:
# Train for EPOCHS epochs using the project's fit helper, passing the validation
# split alongside the training split; keep the returned history for inspection.
history = models.fit_model(
    model,
    epochs=EPOCHS,
    train_data=train_dataset,
    validation_data=validation_dataset,
)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [8]:
# Print per-layer output shapes and parameter counts. The backbone's parameters
# are reported as non-trainable, so only the classification head is trained.
model.summary()

Model: "resnet-101-v2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 sequential (Sequential)     (None, 224, 224, 3)       0         
                                                                 
 resnet101v2 (Functional)    (None, 2048)              42626560  
                                                                 
 sequential_1 (Sequential)   (None, 4)                 525572    
                                                                 
Total params: 43,152,132
Trainable params: 525,572
Non-trainable params: 42,626,560
_________________________________________________________________
