In [1]:
import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np

# Output classes: the ten digits 0-9 plus one CTC blank. The blank occupies
# the last index (10), which is where tf.keras.backend.ctc_batch_cost expects it.
NUM_CLASSES = 10 + 1
# Number of vertical strips the CNN feature map is divided into.
MAX_COLS = 7


def build_cnn_backbone():
    """Build the convolutional feature extractor.

    The network takes single-channel images of any height and width and
    applies three 3x3 'same'-padded ReLU convolutions with 64, 128 and 256
    filters. The first two convolutions are each followed by 2x2 max pooling,
    and the last one by batch normalization.

    Returns:
        A Keras ``Model`` mapping (B, H, W, 1) to (B, H // 4, W // 4, 256).
    """
    image = layers.Input(shape=(None, None, 1))  # variable height and width

    net = image
    # Two conv + pool stages, each halving the spatial resolution.
    for filters in (64, 128):
        net = layers.Conv2D(filters, 3, padding='same', activation='relu')(net)
        net = layers.MaxPooling2D((2, 2))(net)

    # Final conv stage without pooling.
    net = layers.Conv2D(256, 3, padding='same', activation='relu')(net)
    net = layers.BatchNormalization()(net)

    return models.Model(inputs=image, outputs=net)


def split_columns(x):
    """Average a feature map into ``MAX_COLS`` vertical strips.

    The width axis is partitioned into ``MAX_COLS`` contiguous strips. Each
    strip is averaged over its width, which gives one feature vector per row
    per strip.

    Args:
        x: Feature map of shape (B, H, W, C).

    Returns:
        Tensor of shape (B, H, MAX_COLS, C).

    Note:
        W must be at least ``MAX_COLS``. Otherwise some strips are empty, and
        ``reduce_mean`` over an empty slice yields NaN.
    """
    width = tf.shape(x)[2]

    strips = []
    for i in range(MAX_COLS):
        # The boundaries i*W//N .. (i+1)*W//N tile the whole width and spread
        # the W % N leftover columns across strips. A fixed stride of W // N
        # would silently drop the right-most W % N columns.
        start = (i * width) // MAX_COLS
        end = ((i + 1) * width) // MAX_COLS
        strips.append(tf.reduce_mean(x[:, :, start:end, :], axis=2))  # mean over width

    return tf.stack(strips, axis=2)  # (B, H, MAX_COLS, C)


def build_model():
    """Build the CNN + BiLSTM digit recogniser whose output feeds ``ctc_loss``.

    Data flow (H' = H // 4, W' = W // 4 after the backbone):
        image                         (B, H, W, 1)
        -> CNN backbone               (B, H', W', 256)
        -> MAX_COLS width-avg strips  (B, H', MAX_COLS, 256)
        -> column-major flatten       (B, MAX_COLS * H', 256)
        -> bidirectional LSTM         (B, MAX_COLS * H', 512)
        -> softmax over NUM_CLASSES   (B, MAX_COLS * H', NUM_CLASSES)

    Returns:
        A Keras ``Model`` taking the ``"image"`` input and producing a 3-D
        per-timestep class distribution, the (B, T, classes) layout required
        by ``tf.keras.backend.ctc_batch_cost``.
    """
    image = layers.Input(shape=(None, None, 1), name="image")

    cnn = build_cnn_backbone()
    features = cnn(image)

    col_features = layers.Lambda(split_columns)(features)
    # (B, H, 7, C)

    x = layers.Permute((2, 1, 3))(col_features)
    # (B, 7, H, C)

    # Column-major flatten: all rows of strip 0, then strip 1, and so on.
    x = layers.Reshape((-1, x.shape[-1]))(x)
    # (B, 7*H, C)

    # Run the BiLSTM over the flattened sequence `x`, not over `col_features`.
    # Wrapping it in TimeDistributed over the 4-D col_features would produce a
    # 4-D (B, H, 7, classes) output, which CTC cannot consume.
    x = layers.Bidirectional(layers.LSTM(256, return_sequences=True))(x)

    outputs = layers.Dense(NUM_CLASSES, activation='softmax')(x)

    return models.Model(image, outputs)


def ctc_loss(y_true, y_pred):
    """Per-sample CTC loss for padded dense digit labels.

    Args:
        y_true: (B, max_label_len) labels, left-aligned within each row.
            Digit classes are 0..NUM_CLASSES-2. Any position whose value lies
            outside that range is treated as padding, e.g. -1 or the blank
            index NUM_CLASSES-1. Zero is a real digit, so labels must NOT be
            zero-padded.
        y_pred: (B, T, NUM_CLASSES) softmax probabilities. The last class is
            the CTC blank, per the ``ctc_batch_cost`` convention.

    Returns:
        (B, 1) tensor of CTC losses.
    """
    batch_size = tf.shape(y_pred)[0]
    # Every sample uses the full output sequence length T.
    input_length = tf.fill([batch_size, 1], tf.shape(y_pred)[1])

    # Count real labels by value range. The previous count_nonzero approach
    # treated digit 0 as padding, which shortened the length and truncated
    # the trailing digits of any label containing a 0.
    labels = tf.cast(y_true, tf.int32)
    is_label = tf.logical_and(labels >= 0, labels < NUM_CLASSES - 1)
    label_length = tf.reduce_sum(tf.cast(is_label, tf.int32), axis=1, keepdims=True)

    return tf.keras.backend.ctc_batch_cost(
        y_true, y_pred, input_length, label_length
    )


# Build the network, attach the CTC objective (Adam, learning rate 1e-4),
# and print the layer summary.
model = build_model()
model.compile(
    loss=ctc_loss,
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
)
model.summary()

2025-12-21 07:41:18.941866: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1766302878.957079     455 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1766302878.961386     455 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1766302878.974688     455 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1766302878.974705     455 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1766302878.974707     455 computation_placer.cc:177] computation placer alr