In [None]:
import tensorflow as tf
import pathlib
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Dataset root: holds one sub-directory per split ('train', 'val', 'test'),
# each of which is read by flow_from_directory below.
data_dir = pathlib.Path('../data/posters/')

# Loader settings: every image is resized to IMG_HEIGHT x IMG_WIDTH and
# grouped into batches of BATCH_SIZE.
BATCH_SIZE = 32
IMG_HEIGHT, IMG_WIDTH = 224, 224

# Genre labels, indexed by class id.
# NOTE(review): Keras assigns class indices in sorted sub-folder-name order,
# so this list must stay sorted and match the folder names — confirm against
# dataset['train'].class_indices.
CLASS_NAMES = np.array([
    'Action',
    'Animation',
    'Comédie',
    'Comédie-dramatique',
    'Documentaire',
    'Drame',
    'Thriller-Policier',
])

In [None]:
# One generator per split. All of them rescale pixels to [0, 1]; only the
# training generator adds random augmentation, so evaluation sees the
# unmodified posters.
image_generator = {
    'train': tf.keras.preprocessing.image.ImageDataGenerator(
        rescale=1./255,
        width_shift_range=.15,
        height_shift_range=.15,
        horizontal_flip=True,
        # Also tried and currently disabled:
        # rotation_range=30, brightness_range=(0, 1), zoom_range=0.5
    ),
    'val': tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255),
    'test': tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255),
}

# One iterator per split, reading images from <data_dir>/<split>/.
# Only the training split is shuffled: val/test keep a fixed directory order,
# so predictions on them line up with the iterator's `.classes` and
# `.filenames` (needed for confusion matrices / per-image inspection).
dataset = {
    split: image_generator[split].flow_from_directory(
        directory=str(data_dir / split),
        batch_size=BATCH_SIZE,
        shuffle=(split == 'train'),
        target_size=(IMG_HEIGHT, IMG_WIDTH),
        interpolation='bilinear',
    )
    for split in ['train', 'val', 'test']
}

In [None]:
def plotImages(images_arr, max_images=5):
    """Display images side by side, at most five per row, with axes hidden.

    Parameters
    ----------
    images_arr : iterable of array-like
        Images accepted by ``Axes.imshow`` (e.g. HxWx3 floats in [0, 1]).
    max_images : int or None, optional
        Show only the first ``max_images`` images (default 5, matching the
        original fixed 1x5 layout). ``None`` shows every image.
    """
    from itertools import islice

    # islice keeps a lazy/infinite iterable from being consumed past the limit.
    images = list(islice(images_arr, max_images))
    if not images:
        return  # nothing to draw; plt.subplots rejects a 0-column grid

    ncols = min(len(images), 5)
    nrows = (len(images) + ncols - 1) // ncols
    # squeeze=False keeps `axes` 2-D even for a single image, so flatten() works;
    # the figure is sized to the grid instead of a fixed, mostly blank 20x20.
    fig, axes = plt.subplots(nrows, ncols, figsize=(4 * ncols, 4 * nrows), squeeze=False)
    axes = axes.flatten()
    for img, ax in zip(images, axes):
        ax.imshow(img)
    # Hide every axis, including unused cells in the last row.
    for ax in axes:
        ax.axis('off')
    plt.tight_layout()
    plt.show()

def show_batch(image_batch, label_batch, class_names=None):
    """Plot up to 25 images of a batch in a 5x5 grid, titled by class name.

    Parameters
    ----------
    image_batch : array-like
        Batch of images accepted by ``plt.imshow``.
    label_batch : array-like
        Labels aligned with ``image_batch``; assumed one-hot (as produced by
        ``flow_from_directory``'s default ``class_mode``) — confirm if changed.
    class_names : sequence of str, optional
        Name for each class index. Defaults to the notebook's ``CLASS_NAMES``;
        pass ``list(dataset['train'].class_indices)`` to use the loader's order.
    """
    if class_names is None:
        class_names = CLASS_NAMES
    # A short (e.g. last) batch may hold fewer than 25 images.
    n_shown = min(25, len(image_batch))
    plt.figure(figsize=(10, 10))
    for n in range(n_shown):
        plt.subplot(5, 5, n + 1)
        plt.imshow(image_batch[n])
        # argmax of the one-hot row gives the class index.
        plt.title(class_names[int(np.argmax(label_batch[n]))].title())
        plt.axis('off')

In [None]:
# Take the first image of training batch 0 five times. Each indexing call
# presumably re-runs the generator's random augmentation, so the preview
# shows five different augmented variants of the same poster.
augmented_images = [dataset['train'][0][0][0] for _ in range(5)]
plotImages(augmented_images)

In [None]:
# Pull the next (augmented) training batch and preview it with its labels.
(image_batch, label_batch) = next(dataset['train'])
show_batch(image_batch=image_batch, label_batch=label_batch)

In [None]:
# Create the base model from the pre-trained model MobileNet V2
base_model = tf.keras.applications.MobileNetV2(input_shape=(IMG_WIDTH, IMG_HEIGHT, 3),
                                               include_top=False,
                                               weights='imagenet')
base_model.trainable = False
global_average_layer = tf.keras.layers.GlobalAveragePooling2D()
prediction_layer = tf.keras.layers.Dense(7)

model = tf.keras.Sequential([
  base_model,
  global_average_layer,
  prediction_layer
])

In [None]:
base_learning_rate = 0.001
# `learning_rate` is the supported keyword; `lr` is a deprecated alias that
# newer Keras versions warn about or reject outright.
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=base_learning_rate),
              # The Dense head has no activation, so the loss applies softmax.
              loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

In [None]:
# Print the layer stack and parameter counts; with base_model frozen
# (trainable = False), its weights are counted as non-trainable.
model.summary()

In [None]:
initial_epochs = 10
# NOTE(review): defined but not passed to fit() below, so validation runs over
# the whole 'val' iterator every epoch — pass it or remove it if unused later.
validation_steps = 20

# Train the new classification head (base_model is frozen); `history` keeps
# the per-epoch loss/accuracy for the train and validation splits.
history = model.fit(
    dataset['train'],
    validation_data=dataset['val'],
    epochs=initial_epochs,
)