In [2]:
import numpy as np
from sklearn.model_selection import train_test_split
import tensorflow as tf
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import (
    Add,
    BatchNormalization,
    Conv2D,
    Dense,
    GlobalAveragePooling2D,
    ReLU,
)

In [14]:
def resnet_block(
    inputs: tf.Tensor,
    filters: int,
    strides: int = 1,
) -> tf.Tensor:
    """Apply one two-convolution residual block to ``inputs``.

    Residual branch: 3x3 conv (``strides``) -> BN -> ReLU -> 3x3 conv
    (stride 1) -> BN. The shortcut is the identity unless the block changes
    the spatial stride or the channel count, in which case a 1x1 conv + BN
    projection is used so both branches can be added. The sum is passed
    through a final ReLU.

    Args:
        inputs: Feature map to transform; its channel count is read from the
            last axis (``inputs.shape[-1]``).
        filters: Number of output channels of the block.
        strides: Stride of the first convolution and of the projection
            shortcut.

    Returns:
        The output tensor of the block, with ``filters`` channels.
    """
    x = Conv2D(
        filters,
        kernel_size=3,
        strides=strides,
        padding="same",
        kernel_initializer="he_normal",
    )(inputs)
    x = BatchNormalization()(x)
    x = ReLU()(x)

    x = Conv2D(
        filters,
        kernel_size=3,
        strides=1,
        padding="same",
        kernel_initializer="he_normal",
    )(x)
    # Normalize the residual branch before the addition (conv-BN-ReLU-conv-BN),
    # so it is on the same footing as the batch-normalized projection shortcut.
    x = BatchNormalization()(x)

    # The identity shortcut is only valid when the branch keeps both the
    # spatial size (stride 1) and the channel count of the input.
    needs_projection = strides != 1 or inputs.shape[-1] != filters
    if needs_projection:
        shortcut = Conv2D(
            filters,
            kernel_size=1,
            strides=strides,
            padding="same",
            kernel_initializer="he_normal",
        )(inputs)
        shortcut = BatchNormalization()(shortcut)
    else:
        shortcut = inputs

    x = Add()([x, shortcut])
    return ReLU()(x)


def create_resnet_logits(
    input_shape: tuple[int, int, int],
    num_classes: int,
    filters: int = 16,
) -> Model:
    """Build a small ResNet-style classifier that outputs raw logits.

    The network is a 3x3 convolutional stem followed by three stages of three
    residual blocks each. Stage widths are ``filters``, ``2 * filters`` and
    ``4 * filters``; the first block of the second and third stages uses
    stride 2. Global average pooling and a dense layer without activation
    produce the logits.

    Args:
        input_shape: Shape of a single input sample, passed to ``Input``.
        num_classes: Number of units in the final dense layer.
        filters: Width of the stem and of the first stage; later stages scale
            from it. The default of 16 yields stage widths 16, 32 and 64.

    Returns:
        An uncompiled Keras ``Model`` mapping inputs to ``num_classes`` logits.
    """
    inputs = Input(shape=input_shape)

    x = Conv2D(
        filters=filters,
        kernel_size=3,
        strides=1,
        padding="same",
        kernel_initializer="he_normal",
    )(inputs)
    x = BatchNormalization()(x)
    x = ReLU()(x)

    # Derive the stage widths from `filters` so the argument scales the whole
    # network rather than only the stem convolution.
    stage_widths = (filters, 2 * filters, 4 * filters)
    stage_strides = (1, 2, 2)
    blocks_per_stage = 3

    for num_filters, stage_stride in zip(stage_widths, stage_strides, strict=True):
        for block_index in range(blocks_per_stage):
            # Only the first block of a stage applies the stage stride.
            block_stride = stage_stride if block_index == 0 else 1
            x = resnet_block(x, filters=num_filters, strides=block_stride)

    x = GlobalAveragePooling2D()(x)
    # No activation: the model returns logits.
    outputs = Dense(num_classes)(x)

    return Model(inputs, outputs)

In [3]:
# Load CIFAR-10 and flatten the training labels into a 1-D vector.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
y_train = np.ravel(y_train)

# Hold out 20% of the training samples for validation; the fixed seed keeps
# the split reproducible across runs.
x_train_new, x_val, y_train_new, y_val = train_test_split(
    x_train, y_train, test_size=0.2, random_state=42
)

# Number of distinct labels present in the training set.
num_classes = np.unique(y_train).size

In [None]:
# Initialize the ResNet-based classification model
model_logits = create_resnet_logits((32, 32, 3), num_classes)

# Compile the model. The network outputs logits, hence from_logits=True.
model_logits.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer="adam",
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

# Create a callback to save the model checkpoint with the highest validation accuracy.
# The model will be stored as 'best_model_logits.keras'.
# The monitored key must match the logged metric name: SparseCategoricalAccuracy is
# logged as "sparse_categorical_accuracy", so its validation value is
# "val_sparse_categorical_accuracy" (not "val_accuracy").
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    "best_model_logits.keras",
    monitor="val_sparse_categorical_accuracy",
    save_best_only=True,
    mode="max",
)

# Train on the training split for 50 epochs and evaluate on the held-out
# validation split after every epoch. Without validation_data no "val_*"
# metrics are produced and the checkpoint callback would never save a model.
history = model_logits.fit(
    x_train_new,
    y_train_new,
    validation_data=(x_val, y_val),
    epochs=50,
    callbacks=[checkpoint],
)