### 1. Зависимости

In [86]:
# %pip install tensorflow tf2onnx onnx onnxruntime numpy pillow scikit-learn

In [87]:
import gzip
import os
import random

import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, Model
import onnx
import tf2onnx
import onnxruntime as ort
from PIL import Image
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

### 2. Загрузчик данных

In [88]:
class MNISTDataLoader:
    """Load an MNIST-format dataset (gzip-compressed IDX files) into memory.

    ``data_dir`` must contain ``train-images-idx3-ubyte.gz``,
    ``train-labels-idx1-ubyte.gz``, ``t10k-images-idx3-ubyte.gz`` and
    ``t10k-labels-idx1-ubyte.gz``.

    Attributes:
        train_X, test_X: float32 arrays of shape (N, rows, cols, 1) scaled to
            [0, 1]; rows/cols are read from the IDX header (28x28 for MNIST).
        train_y, test_y: one-hot label matrices of shape (N, class_num).
        num_train: number of training samples.
    """

    # IDX magic numbers: 0x00000803 (unsigned byte, 3 dims) for image files and
    # 0x00000801 (unsigned byte, 1 dim) for label files.
    _IMAGES_MAGIC = 2051
    _LABELS_MAGIC = 2049

    def __init__(self, data_dir, class_num=20):
        self.data_dir  = data_dir
        self.class_num = class_num
        self.train_X, train_labels = self._load_split('train')
        self.test_X,  test_labels  = self._load_split('t10k')
        # Reject out-of-range labels up front with a readable message instead of
        # leaving it to to_categorical.
        for prefix, labels in (('train', train_labels), ('t10k', test_labels)):
            if labels.size and int(labels.max()) >= class_num:
                raise ValueError(
                    f"{prefix} labels contain class {int(labels.max())}, "
                    f"but class_num={class_num}"
                )
        self.train_y = tf.keras.utils.to_categorical(train_labels, class_num)
        self.test_y  = tf.keras.utils.to_categorical(test_labels,  class_num)
        self.num_train = len(self.train_X)
        print(f"Loaded: train {self.train_X.shape}, test {self.test_X.shape}")

    @staticmethod
    def _read_idx(path, expected_magic, ndim):
        """Read a gzip-compressed unsigned-byte IDX file.

        Args:
            path: path to the ``.gz`` file.
            expected_magic: magic number the header must start with.
            ndim: number of dimension sizes that follow the magic number.

        Returns:
            ``(dims, data)``: a tuple of ``ndim`` ints and a flat read-only
            uint8 array with exactly ``prod(dims)`` elements.

        Raises:
            ValueError: if the header is too short, the magic number is wrong,
                or the payload holds fewer bytes than the header announces.
        """
        with gzip.open(path, 'rb') as f:
            buf = f.read()
        header_len = 4 * (ndim + 1)
        if len(buf) < header_len:
            raise ValueError(f"{path}: too short for an IDX header")
        # The header is a sequence of big-endian uint32 values.
        header = np.frombuffer(buf, dtype='>u4', count=ndim + 1)
        magic = int(header[0])
        dims = tuple(int(d) for d in header[1:])
        if magic != expected_magic:
            raise ValueError(f"{path}: IDX magic {magic}, expected {expected_magic}")
        size = int(np.prod(dims))
        available = len(buf) - header_len
        if available < size:
            raise ValueError(
                f"{path}: header announces {size} bytes of data, file has {available}"
            )
        return dims, np.frombuffer(buf, np.uint8, count=size, offset=header_len)

    def _load_split(self, prefix):
        """Load one split (``'train'`` or ``'t10k'``) and return ``(X, y)``.

        X is float32 with shape (N, rows, cols, 1) scaled to [0, 1]; y holds the
        raw uint8 class ids.

        Raises:
            ValueError: on malformed files or an image/label count mismatch.
        """
        images_path = os.path.join(self.data_dir, f'{prefix}-images-idx3-ubyte.gz')
        labels_path = os.path.join(self.data_dir, f'{prefix}-labels-idx1-ubyte.gz')
        (num_images, rows, cols), pixels = self._read_idx(
            images_path, self._IMAGES_MAGIC, 3
        )
        (num_labels,), y = self._read_idx(labels_path, self._LABELS_MAGIC, 1)
        if num_images != num_labels:
            raise ValueError(f"{prefix}: {num_images} images but {num_labels} labels")
        X = (pixels.astype(np.float32) / 255.0).reshape(num_images, rows, cols, 1)
        return X, y

#### Параметры запуска

In [89]:
# Run configuration: everything a reader may want to tune lives here.
DATA_DIR   = '../data/dataset/20class/FlowAllLayers'  # folder with the *-ubyte.gz files
CLASS_NUM  = 20
BATCH_SIZE = 500
EPOCHS     = 1
ONNX_PATH  = '../model/model.onnx'
SEED       = 42

# Seed Python's `random`, NumPy and TensorFlow in one call so that weight
# initialisation and batch shuffling repeat across Restart & Run All.
tf.keras.utils.set_random_seed(SEED)

In [90]:
# Load both splits into memory once; the training and evaluation cells reuse `dl`.
dl = MNISTDataLoader(data_dir=DATA_DIR, class_num=CLASS_NUM)

Loaded: train (245437, 28, 28, 1), test (27271, 28, 28, 1)


### 3. Определение модели

In [92]:
def build_simple_cnn(class_num):
    """Build a small fully-convolutional classifier for 28x28x1 images.

    Two conv + max-pool stages reduce the 28x28 input to a 7x7x64 feature map.
    A 7x7 'valid' convolution collapses it to 1x1x128, which acts like a dense
    layer. A 1x1 convolution then projects to ``class_num`` channels. The
    exported ONNX graph therefore uses only Conv/MaxPool/Relu/Reshape ops
    (presumably the reason for avoiding Flatten/Dense; confirm with the target runtime).

    Args:
        class_num: number of output classes (width of the logits vector).

    Returns:
        An uncompiled ``Model`` named 'SimpleCNN' mapping (batch, 28, 28, 1)
        inputs to (batch, class_num) raw logits (no softmax).
    """
    inp = tf.keras.Input(shape=(28, 28, 1), name='input')
    x = layers.Conv2D(32, 3, padding='same', activation='relu')(inp)
    x = layers.MaxPool2D(2)(x)                                              # 28x28 -> 14x14
    x = layers.Conv2D(64, 3, padding='same', activation='relu')(x)
    x = layers.MaxPool2D(2)(x)                                              # 14x14 -> 7x7
    x = layers.Conv2D(128, (7, 7), activation='relu', padding='valid')(x)  # 7x7 -> 1x1
    # Per-class logits; previously hard-coded to 20, ignoring `class_num`.
    x = layers.Conv2D(class_num, (1, 1), activation=None, padding='valid')(x)
    logits = layers.Reshape((class_num,))(x)                               # (1, 1, C) -> (C,)
    return Model(inp, logits, name='SimpleCNN')

# Instantiate the network and attach the training configuration.
model = build_simple_cnn(CLASS_NUM)
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    # The network ends in raw logits (last conv has no activation), hence from_logits=True.
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
model.summary()

#### Обучение

In [93]:
# NOTE(review): the test split is also passed as validation data, so val_* metrics
# are not an independent held-out estimate if they are used for model selection.
history = model.fit(
    x=dl.train_X,
    y=dl.train_y,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    validation_data=(dl.test_X, dl.test_y),
)

[1m491/491[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m331s[0m 669ms/step - accuracy: 0.6906 - loss: 0.9902 - val_accuracy: 0.9467 - val_loss: 0.1481


### 4. Проверка точности на тестовой выборке

In [94]:
# Ground-truth class ids come from the one-hot test labels.
y_true = dl.test_y.argmax(axis=1)

# The model emits raw logits; the argmax of logits equals the argmax of softmax.
preds = model.predict(dl.test_X)
y_pred = preds.argmax(axis=1)

acc = accuracy_score(y_true, y_pred)
# Macro averaging weights each label appearing in y_true or y_pred equally;
# zero_division=0 scores undefined precision/recall as 0.
prec, rec, f1, _ = precision_recall_fscore_support(
    y_true, y_pred, average='macro', zero_division=0
)
print(f"TF Accuracy = {acc:.4f}, Precision={prec:.4f}, Recall={rec:.4f}, F1={f1:.4f}")


[1m853/853[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 11ms/step
TF Accuracy = 0.9467, Precision=0.9529, Recall=0.9483, F1=0.9497


In [95]:
# Inspect the Keras model's class and I/O names before export; the input name is
# what the export cell passes to `inputs_as_nchw`. Under Keras 3 the class prints
# as keras.src.models.functional.Functional, and the output here is named after
# the final Reshape layer ('reshape').
print(type(model), model.input.name, model.output_names, sep="\n")

<class 'keras.src.models.functional.Functional'>
input
ListWrapper(['reshape'])


### 5. Экспорт в ONNX

In [96]:
# tf2onnx and onnx are imported in the dependencies cell at the top of the notebook.

# Static batch of 1. The per-sample shape and the tensor name are taken from the
# Keras model, so the signature and `inputs_as_nchw` cannot drift apart.
ONNX_OPSET = 13
input_name = model.input.name
spec = (tf.TensorSpec((1, *model.input.shape[1:]), tf.float32, name=input_name),)

# inputs_as_nchw exposes the ONNX input in NCHW layout while the Keras graph is NHWC.
# NOTE(review): with NumPy 2.x tf2onnx logs "rewriter rewrite_constant_fold:
# exception `np.cast` was removed". That optimisation pass is skipped but the
# export still completes; presumably pinning numpy<2 or a newer tf2onnx fixes it.
model_proto, _ = tf2onnx.convert.from_keras(
    model,
    input_signature=spec,
    opset=ONNX_OPSET,
    inputs_as_nchw=[input_name],
    output_path=ONNX_PATH,
)
# Validate the produced graph structurally before anything consumes the file.
onnx.checker.check_model(model_proto)
print("ONNX saved to", ONNX_PATH)


I0000 00:00:1746887065.332224  233320 devices.cc:67] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 0
I0000 00:00:1746887065.332435  233320 single_machine.cc:374] Starting new session
I0000 00:00:1746887065.538571  233320 devices.cc:67] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 0
I0000 00:00:1746887065.538705  233320 single_machine.cc:374] Starting new session
rewriter <function rewrite_constant_fold at 0x72a4137a0a40>: exception `np.cast` was removed in the NumPy 2.0 release. Use `np.asarray(arr, dtype=dtype)` instead.


ONNX saved to ../model/model.onnx


In [97]:
# List the exported graph inputs with their shapes. Static dims live in
# `dim_value`. Symbolic (dynamic) dims live in `dim_param` and read as 0 through
# `dim_value`, so show the symbol (or '?') instead of a misleading 0.
m = onnx.load(ONNX_PATH)
for inp in m.graph.input:
    shape = [
        d.dim_value if d.HasField("dim_value") else (d.dim_param or "?")
        for d in inp.type.tensor_type.shape.dim
    ]
    print(inp.name, shape)

input [1, 1, 28, 28]


In [98]:
# Operator types used by the exported graph, useful to confirm the target runtime
# supports all of them. Uses the configured ONNX_PATH instead of a second
# hard-coded copy of the path, and sorts for stable output.
m = onnx.load(ONNX_PATH)
ops = {n.op_type for n in m.graph.node}
print("Ops in graph:", sorted(ops))


Ops in graph: {'Reshape', 'Conv', 'MaxPool', 'Relu'}


In [None]:
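# NOTE: this cell is disabled and superseded by `inputs_as_nchw` in the export cell,
# which already yields an NCHW input of shape [1, 1, 28, 28]. The
# "Saved static_nchw_model" output shown below is stale: commented-out code prints
# nothing, so it is left over from an earlier run. Known issues if re-enabled:
#   - `model = onnx.load(...)` overwrites the Keras `model`; use another name.
#   - The renamed "input_nhwc" stays in graph.input while also being the Transpose
#     output, which breaks ONNX SSA form; remove it from graph.input.
#   - Running it on the current export would transpose an already-NCHW input again.
# import onnx
# from onnx.tools import update_model_dims
# from onnx import shape_inference, checker, helper, TensorProto

# # 1) Загрузить NHWC-модель
# model = onnx.load(ONNX_PATH)

# # 2) Зафиксировать вход [1,28,28,1]
# fixed = update_model_dims.update_inputs_outputs_dims(
#     model,
#     {"input": [1, 28, 28, 1]},  # вход NHWC
#     { "output": [1, CLASS_NUM]}                         # выходы пусть shape_inference подтянет сами
# )

# # 3) Переименовать старый вход и вставить Transpose
# graph = fixed.graph
# # старый input → input_nhwc
# graph.input[0].name = "input_nhwc"
# # новый NCHW-вход
# new_in = helper.make_tensor_value_info(
#     "input", TensorProto.FLOAT, [1, 1, 28, 28]
# )
# graph.input.insert(0, new_in)
# # вставляем Transpose: input (NCHW) → input_nhwc (NHWC), поэтому perm=[0,2,3,1]
# transpose_node = helper.make_node(
#     "Transpose", ["input"], ["input_nhwc"], perm=[0,2,3,1]
# )
# graph.node.insert(0, transpose_node)

# # # 4) Shape inference и проверка
# inferred = shape_inference.infer_shapes(fixed)
# # checker.check_model(inferred)

# # 5) Сохранить окончательный NCHW-ONNX
# onnx.save(inferred, ONNX_PATH)
# print("Saved static_nchw_model")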


Saved static_nchw_model


### 6. Локальная проверка ONNX


In [100]:
import onnxruntime as rt, numpy as np
from PIL import Image

# Open the exported model and read its I/O metadata.
sess = rt.InferenceSession(ONNX_PATH)
inp_name  = sess.get_inputs()[0].name
out_name  = sess.get_outputs()[0].name
in_shape  = sess.get_inputs()[0].shape

print("Input shape:", in_shape)   # expected [1, 1, 28, 28]
print("Output shape:", sess.get_outputs()[0].shape)
# Everything below feeds single 28x28 NCHW images; fail early on any other layout.
assert in_shape == [1, 1, 28, 28], f"unexpected ONNX input shape: {in_shape}"

# Prepare one sample with the same [0, 1] scaling MNISTDataLoader uses (uint8 / 255).
with Image.open("../data/sample.bmp") as src:
    img = src.convert("L")
assert img.size == (28, 28), f"sample image must be 28x28, got {img.size}"
arr = np.array(img, dtype=np.float32) / 255.0
arr = arr[np.newaxis, np.newaxis, :, :]  # [1,1,28,28]
out = sess.run([out_name], {inp_name: arr})[0]
print("ONNX logits:", out, "predicted class:", np.argmax(out))

# Parity check: the ONNX graph (NCHW input) must reproduce the Keras model (NHWC
# input) on identical pixels. This catches conversion errors in the export.
N_PARITY = 8
np.testing.assert_allclose(
    out, model.predict(arr.transpose(0, 2, 3, 1), verbose=0), rtol=1e-3, atol=1e-3
)
keras_logits = model.predict(dl.test_X[:N_PARITY], verbose=0)
for x_nhwc, ref in zip(dl.test_X[:N_PARITY], keras_logits):
    # The exported graph has batch size 1, so feed test images one at a time.
    x_nchw = np.ascontiguousarray(x_nhwc[np.newaxis].transpose(0, 3, 1, 2))
    onnx_logits = sess.run([out_name], {inp_name: x_nchw})[0]
    np.testing.assert_allclose(onnx_logits[0], ref, rtol=1e-3, atol=1e-3)
print(f"ONNX matches Keras on the sample and {N_PARITY} test images")

Input shape: [1, 1, 28, 28]
Output shape: [1, 20]
ONNX logits: [[  3.681105     5.278968     3.8235843    1.3284911   -4.9846506
    3.013633    -6.269668    10.897063    13.320679    -5.1513815
  -10.515752   -15.496873    -7.702152   -14.793939     0.29847774
    2.4666004  -14.060369    -8.991285    -2.0131204   -0.31099385]] predicted class: 8
