In [5]:
from contextlib import closing
import io
import numpy
from copy import deepcopy

from matplotlib import pyplot
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.models import load_model
from tensorflow.keras.layers import (
    Add,
    Activation,
    BatchNormalization,
    Conv2D,
    Dense,
    GlobalAveragePooling2D,
    Dropout,
    Flatten,
    concatenate,
    Input,
    InputLayer,
    MaxPooling2D,
    ReLU,
)
from tensorflow.keras.utils import to_categorical, plot_model
from tensorflow.keras.datasets import mnist

In [6]:
# ---------- Prepare data section ---------- #
# Load the MNIST dataset: grayscale digit images without a channel axis,
# plus integer class labels.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Number of distinct class labels (the softmax output width).
num_labels = len(numpy.unique(y_train))

# Image dimensions (assumed square).
image_size = x_train.shape[1]
input_size = image_size * image_size

# Conv2D layers expect an explicit channel axis: (samples, H, W) -> (samples, H, W, 1).
# Without this, the data only matches Input(shape=(H, W, 1)) if Keras silently
# expands dims, which is version-dependent. Also scale pixel intensities from
# [0, 255] integers to [0, 1] floats so the first convolution sees well-conditioned inputs.
x_train = x_train.reshape(-1, image_size, image_size, 1).astype('float32') / 255
x_test = x_test.reshape(-1, image_size, image_size, 1).astype('float32') / 255

# One-hot encode the integer labels (explicit class count keeps train/test aligned).
y_train = to_categorical(y_train, num_labels)
y_test = to_categorical(y_test, num_labels)

In [7]:


# Residual block definition
def residual_block(x, filters, stride=1):
    """Apply a basic two-convolution residual block to ``x``.

    Main path:  3x3 conv (given stride) -> BN -> ReLU -> 3x3 conv (stride 1) -> BN.
    Shortcut:   identity, or a 1x1 conv + BN projection whenever the stride is
                not 1 or the channel count of ``x`` differs from ``filters``,
                so both paths have matching shapes before they are summed.
    The sum is passed through a final ReLU.

    :param x: Input tensor (channels-last; its last dimension is compared to ``filters``).
    :param filters: Number of filters for every convolution in the block.
    :param stride: Stride of the first convolution (and of the projection), used to downsample.
    :return: Output tensor.
    """
    def conv_bn(tensor, size, strides):
        # Bias-free 'same'-padded convolution followed by batch normalization.
        tensor = Conv2D(filters, kernel_size=size, strides=strides,
                        padding='same', use_bias=False)(tensor)
        return BatchNormalization()(tensor)

    identity = x
    projection_needed = stride != 1 or identity.shape[-1] != filters

    # Main path (layers are created before the shortcut layers, as in the original order).
    residual = ReLU()(conv_bn(identity, 3, stride))
    residual = conv_bn(residual, 3, 1)

    # Skip connection: project only when shapes would otherwise mismatch.
    if projection_needed:
        identity = conv_bn(identity, 1, stride)

    # Merge both paths and activate.
    merged = Add()([residual, identity])
    return ReLU()(merged)

In [None]:

# ---------- Network / training hyper-parameters ---------- #
# Training-loop settings.
BATCH_SIZE = 64
EPOCHS = 20

# Layer settings, used as the default arguments of build_model.
HIDDEN_UNITS = 128
DROPOUT = 0.2
KERNEL_SIZE = 3
POOL_SIZE = 2
FILTERS = 64

# Shape of one input sample: square single-channel images, (height, width, channels).
input_shape = (image_size, image_size, 1)

def build_model(units=HIDDEN_UNITS, dropout=DROPOUT, input_size=input_size,
                num_labels=num_labels, kernel_size=KERNEL_SIZE, pool_size=POOL_SIZE,
                filters=FILTERS, input_shape=input_shape):
    """Build a small ResNet-style classifier.

    Architecture: a stem convolution (BN + ReLU), then three stages of two
    residual blocks each with widths ``filters // 4``, ``filters // 2`` and
    ``filters``. The second and third stages downsample with stride 2 in
    their first block. These are followed by global average pooling and a
    softmax layer.

    :param units: Unused by this architecture; kept for interface compatibility.
    :param dropout: Unused by this architecture; kept for interface compatibility.
    :param input_size: Unused (flattened pixel count); kept for interface compatibility.
    :param num_labels: Number of output classes, i.e. width of the softmax layer.
    :param kernel_size: Kernel size of the stem convolution.
    :param pool_size: Unused by this architecture; kept for interface compatibility.
    :param filters: Width of the last residual stage; earlier stages use
        ``filters // 4`` and ``filters // 2``. Must be at least 4.
    :param input_shape: Shape of one input sample, ``(height, width, channels)``.
    :return: An uncompiled ``Model``.
    :raises ValueError: If ``filters`` is too small to derive non-zero stage widths.
    """
    if filters < 4:
        raise ValueError(f"filters must be >= 4 to derive stage widths, got {filters}")

    # Widths double at each downsampling stage (16/32/64 for the default filters=64).
    stage_widths = (filters // 4, filters // 2, filters)

    inputs = Input(shape=input_shape)

    # Stem convolution.
    x = Conv2D(stage_widths[0], kernel_size=kernel_size, strides=1,
               padding='same', use_bias=False)(inputs)
    x = BatchNormalization()(x)
    x = ReLU()(x)

    # Three stages of two residual blocks; every stage except the first
    # downsamples with stride 2 in its first block.
    for stage_index, width in enumerate(stage_widths):
        first_stride = 1 if stage_index == 0 else 2
        x = residual_block(x, filters=width, stride=first_stride)
        x = residual_block(x, filters=width, stride=1)

    # Global average pooling and the final classifier; the output width follows
    # num_labels instead of a hard-coded class count.
    x = GlobalAveragePooling2D()(x)
    outputs = Dense(num_labels, activation='softmax')(x)

    return Model(inputs, outputs)

model = build_model()
# Architecture diagram (requires pydot and graphviz to be installed):
# plot_model(model, to_file='architecture.png', show_shapes=True, show_layer_names=True)

# Capture the text of model.summary() into a string instead of printing it;
# StringIO is its own context manager, so it is closed when the block exits.
with io.StringIO() as fh:
    model.summary(print_fn=lambda x: fh.write(x + "\n"))
    summary_str = fh.getvalue()

# Render the summary as monospace text on an axis-free figure and save it as PNG.
fig, ax = pyplot.subplots(figsize=(12, 6))
ax.text(0, 1, summary_str, fontsize=12, family='monospace', va='top')
ax.axis('off')
fig.savefig('model-summary.png', bbox_inches='tight')
pyplot.close(fig)

You must install pydot (`pip install pydot`) and install graphviz (see instructions at https://graphviz.gitlab.io/download/) for plot_model to work.


In [None]:
# Compile the model: Adam optimizer, categorical cross-entropy (matches the
# one-hot labels from to_categorical), and accuracy as the reported metric.
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# Train using the BATCH_SIZE / EPOCHS constants from the network-parameters cell
# instead of hard-coded literals, holding out 10% of the training data for validation.
# NOTE: re-running this cell keeps training the existing `model`; re-run the
# model-building cell first for a fresh run.
history = model.fit(x_train, y_train,
                    batch_size=BATCH_SIZE,
                    epochs=EPOCHS,
                    validation_split=0.1)

# Evaluate on the test set.
test_loss, test_acc = model.evaluate(x_test, y_test, verbose=2)
print(f"Test accuracy: {test_acc * 100:.2f}%")