In [None]:
import tensorflow as tf
import os
from tensorflow.keras.backend import clear_session
import numpy as np
import pandas as pd

In [None]:
# Constants used throughout the model / notebook:
SEED = 1337  # seed for the TensorFlow and NumPy global random generators
IMAGE_SIZE = (32, 32)  # (height, width) every image is resized to
BATCH_SIZE = 32

# Seed the global generators so dataset shuffling, weight initialisation and the
# random augmentation layers are repeatable between runs (GPU kernels may still
# introduce some non-determinism).
tf.random.set_seed(SEED)
np.random.seed(SEED)

In [None]:
# Build the training and validation datasets from the competition's image folders.
DATASET_ROOT = "../input/mds-misis-dl-cifar-10-classificationn/cifar_10/cifar_10"


def load_image_split(split, shuffle=True):
    """Load one split folder (e.g. ``"train"``) as a batched ``tf.data.Dataset``.

    Labels are inferred from the sub-folder names and one-hot encoded
    (``label_mode="categorical"``); images are resized to IMAGE_SIZE and grouped
    into batches of BATCH_SIZE. ``seed=SEED`` makes the shuffle order reproducible.

    Args:
        split: Name of the sub-folder of DATASET_ROOT to read.
        shuffle: Whether to shuffle the files.

    Returns:
        A dataset yielding ``(images, one_hot_labels)`` batches.
    """
    return tf.keras.preprocessing.image_dataset_from_directory(
        directory=f"{DATASET_ROOT}/{split}",
        label_mode="categorical",
        image_size=IMAGE_SIZE,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        seed=SEED,
    )


train_images = load_image_split("train")
# Evaluation data needs no shuffling; a fixed order keeps validation passes comparable.
validation_images = load_image_split("validation", shuffle=False)

In [None]:
# Random augmentations applied to the training batches only.
augmentation_network = tf.keras.Sequential([
    # Mirror images left-right at random.
    tf.keras.layers.experimental.preprocessing.RandomFlip("horizontal"),
    # Rotate by a random angle of up to ±0.1 of a full turn (±36°).
    tf.keras.layers.experimental.preprocessing.RandomRotation(0.1),
])

In [None]:
# Augment the training set (validation data is left untouched).
def augment_batch(image_batch, label_batch):
    """Return the randomly augmented image batch together with its unchanged labels.

    ``training=True`` is required: the random preprocessing layers pass images
    through unchanged in inference mode.
    """
    augmented = augmentation_network(image_batch, training=True)
    return augmented, label_batch


augmented_training_images = train_images.map(augment_batch)

In [None]:
# Prefetch so input preparation overlaps with training. `prefetch` counts dataset
# elements, which here are whole batches, so its buffer is not a batch size;
# let tf.data pick the depth dynamically instead.
augmented_training_images = augmented_training_images.prefetch(
    buffer_size=tf.data.experimental.AUTOTUNE
)
validation_images = validation_images.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

In [None]:
# Delete checkpoints and the saved model left by previous runs, so stale epoch
# files cannot be confused with this run's. Pure Python instead of `!rm -rf`
# keeps the cell portable; missing directories are silently ignored.
import shutil

for stale_dir in ("./xception_network_cifar_10_checkpoints/", "./xception_network_cifar_10/"):
    shutil.rmtree(stale_dir, ignore_errors=True)

In [None]:
# Reset Keras' global state (e.g. the auto-generated layer-name counters) left over
# from earlier runs in this kernel.
clear_session()


def build_xception_network(input_shape, num_classes, block_sizes=(128, 256, 512, 728), dropout_rate=0.5):
    """Build a small Xception-style convolutional classifier.

    Structure: a rescaling + convolutional stem, then one residual block per entry of
    ``block_sizes`` (two ReLU -> SeparableConv2D -> BatchNorm units followed by a
    stride-2 max-pool, added to a stride-2 1x1 projection of the block input), then a
    1024-filter separable convolution, global average pooling, dropout and a softmax
    ``Dense`` layer.

    Args:
        input_shape: Shape of one input image, e.g. ``(32, 32, 3)``. The model rescales
            pixels by 1/255, so inputs are assumed to be in the 0-255 range.
        num_classes: Number of output classes (units of the softmax layer).
        block_sizes: Filter count of each residual block; every block halves the
            spatial resolution (rounding up).
        dropout_rate: Dropout rate applied before the classifier layer.

    Returns:
        An uncompiled ``tf.keras.Model`` mapping images to class probabilities.
    """
    layers = tf.keras.layers

    inputs = tf.keras.Input(shape=input_shape)

    # Stem: rescale pixels, then a stride-2 conv and a regular conv, each with BN + ReLU.
    x = layers.experimental.preprocessing.Rescaling(1.0 / 255)(inputs)
    x = layers.Conv2D(32, 3, strides=2, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)
    x = layers.Conv2D(64, 3, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    previous_block = x

    for size in block_sizes:
        # ReLU comes before each separable conv. For the first block this repeats the
        # stem's ReLU, which is harmless because ReLU is idempotent.
        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(size, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(size, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)
        x = layers.MaxPooling2D(3, strides=2, padding="same")(x)

        # Project the block input to the same channel count and resolution so it can be added.
        residual = layers.Conv2D(size, 1, strides=2, padding="same")(previous_block)
        x = layers.add([x, residual])
        previous_block = x

    x = layers.SeparableConv2D(1024, 3, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dropout(dropout_rate)(x)
    outputs = layers.Dense(num_classes, activation="softmax")(x)
    return tf.keras.Model(inputs, outputs)

In [None]:
# Build the model for 3-channel images of IMAGE_SIZE and the 10 CIFAR-10 classes.
xception_network_cifar_10 = build_xception_network(
    num_classes=10,
    input_shape=(*IMAGE_SIZE, 3),
)

In [None]:
# Train the network built above.
EPOCHS_NUMBER = 50

xception_network_cifar_10.compile(
    loss="categorical_crossentropy",
    optimizer=tf.keras.optimizers.Adam(1e-3),
    metrics=["accuracy"],
)

# Save a checkpoint after every epoch; ModelCheckpoint fills "{epoch}" with the epoch number.
callbacks = [
    tf.keras.callbacks.ModelCheckpoint(
        "xception_network_cifar_10_checkpoints/epoch_{epoch}.h5"
    ),
]

xception_network_cifar_10.fit(
    augmented_training_images,
    validation_data=validation_images,
    epochs=EPOCHS_NUMBER,
    callbacks=callbacks,
)

# Save the fully trained (last-epoch) model.
xception_network_cifar_10.save("xception_network_cifar_10");

In [None]:
# Restore the weights from the epoch with the highest validation accuracy.
# Prefer the history recorded by the training run in this kernel; fall back to
# epoch 48 (the best epoch of an earlier run) when no history is available,
# e.g. when the saved model was loaded instead of trained.
FALLBACK_BEST_EPOCH = 48

_training_history = getattr(xception_network_cifar_10, "history", None)
_val_accuracy = getattr(_training_history, "history", {}).get("val_accuracy", [])
# Checkpoint file names use 1-based epoch numbers, hence the +1.
best_epoch = int(np.argmax(_val_accuracy)) + 1 if _val_accuracy else FALLBACK_BEST_EPOCH

xception_network_cifar_10.load_weights(
    f"./xception_network_cifar_10_checkpoints/epoch_{best_epoch}.h5"
)

In [None]:
# Alternative to training: load the fully trained model saved by the training cell.
# xception_network_cifar_10 = tf.keras.models.load_model("xception_network_cifar_10")

In [None]:
# Write the submission file with a predicted class for every test image.

# Class names in the index order of the model's outputs. Presumably this matches the
# alphabetical sub-folder order image_dataset_from_directory used for the training
# labels — TODO confirm against train_images.class_names.
class_index_to_label = [
    "AIRPLANE",
    "AUTOMOBILE",
    "BIRD",
    "CAT",
    "DEER",
    "DOG",
    "FROG",
    "HORSE",
    "SHIP",
    "TRUCK"
]

TEST_IMAGES_DIR = "../input/mds-misis-dl-cifar-10-classificationn/cifar_10/cifar_10/test"
PREDICT_CHUNK_SIZE = 1024  # number of test images held in memory / passed to predict() at once


def _load_test_image(image_id):
    """Read test image ``<image_id>.jpg``, resized to IMAGE_SIZE, and return it as an array."""
    image = tf.keras.preprocessing.image.load_img(
        f"{TEST_IMAGES_DIR}/{image_id}.jpg", target_size=IMAGE_SIZE
    )
    return tf.keras.preprocessing.image.img_to_array(image)


answer_key = pd.read_csv('../input/mds-misis-dl-cifar-10-classificationn/sample_submission.csv')
test_ids = answer_key['Id'].tolist()

# Predict in chunks instead of one predict() call per image (far faster), while
# keeping memory bounded for large test sets. An empty id list simply skips the loop.
predicted_labels = []
for start in range(0, len(test_ids), PREDICT_CHUNK_SIZE):
    chunk_ids = test_ids[start:start + PREDICT_CHUNK_SIZE]
    chunk_images = np.stack([_load_test_image(image_id) for image_id in chunk_ids])
    chunk_probabilities = xception_network_cifar_10.predict(chunk_images, batch_size=BATCH_SIZE)
    predicted_labels.extend(
        class_index_to_label[class_index] for class_index in np.argmax(chunk_probabilities, axis=1)
    )

# Build the frame once: DataFrame.append was removed in pandas 2.0 and is quadratic in a loop.
answers = pd.DataFrame({'Id': test_ids, 'Category': predicted_labels})
answers = answers.set_index(keys='Id')
answers.index = answers.index.astype(int)
answers.to_csv('sample_submission.csv')