In [1]:
from Classifications import Classifications
import tensorflow as tf
from tensorflow.keras.layers import InputLayer, Conv1D, MaxPooling1D, GlobalAveragePooling1D, Dense
from tensorflow.keras.models import Sequential

In [2]:
classifications = Classifications()
model_h_path = "../ModelData.h"

In [3]:
slice_indexes = tf.cumsum(tf.concat([[0], [len(v) for v in classifications.classifications.values()]], axis=0))

def custom_softmax(x):
    return tf.concat([
        tf.nn.softmax(x[:, slice_indexes[i]:slice_indexes[i+1]]) 
        for i in range(len(slice_indexes) - 1)
    ], axis=1)

model = Sequential([
    InputLayer((classifications.num_shot_steps, classifications.num_features)),
    Conv1D(16, 3, activation="relu"),
    MaxPooling1D(2),
    Conv1D(32, 3, activation="relu"),
    GlobalAveragePooling1D(),
    Dense(classifications.num_classes, activation=custom_softmax),
])

model.summary()

In [4]:
sample_input = tf.random.uniform(shape=(1, classifications.num_shot_steps, classifications.num_features))
prediction = model.predict(sample_input)

start_idx = end_idx = 0
predictions = []
for classes in classifications.classifications.values():
    end_idx += len(classes)
    pred_idx = tf.argmax(prediction[:, start_idx:end_idx], axis=1)[0]
    start_idx = end_idx
    predictions.append(classes[pred_idx])
print(f"Prediction: {' '.join(reversed(predictions))}")

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 45ms/step
Prediction: topspin forehand groundstroke


In [5]:
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

INFO:tensorflow:Assets written to: /var/folders/1y/nzsqhm41529c176c8k4bw7s80000gn/T/tmp311ddzgz/assets


INFO:tensorflow:Assets written to: /var/folders/1y/nzsqhm41529c176c8k4bw7s80000gn/T/tmp311ddzgz/assets


Saved artifact at '/var/folders/1y/nzsqhm41529c176c8k4bw7s80000gn/T/tmp311ddzgz'. The following endpoints are available:

* Endpoint 'serve'
  args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 64, 6), dtype=tf.float32, name='keras_tensor')
Output Type:
  TensorSpec(shape=(None, 9), dtype=tf.float32, name=None)
Captures:
  5969853904: TensorSpec(shape=(), dtype=tf.resource, name=None)
  5972954960: TensorSpec(shape=(), dtype=tf.resource, name=None)
  5972955344: TensorSpec(shape=(), dtype=tf.resource, name=None)
  5972953232: TensorSpec(shape=(), dtype=tf.resource, name=None)
  5972953808: TensorSpec(shape=(), dtype=tf.resource, name=None)
  5972955152: TensorSpec(shape=(), dtype=tf.resource, name=None)
  5972953616: TensorSpec(shape=(4,), dtype=tf.int32, name=None)


W0000 00:00:1744403125.235595 9904850 tf_tfl_flatbuffer_helpers.cc:365] Ignored output_format.
W0000 00:00:1744403125.235610 9904850 tf_tfl_flatbuffer_helpers.cc:368] Ignored drop_control_dependency.
2025-04-11 16:25:25.235874: I tensorflow/cc/saved_model/reader.cc:83] Reading SavedModel from: /var/folders/1y/nzsqhm41529c176c8k4bw7s80000gn/T/tmp311ddzgz
2025-04-11 16:25:25.236111: I tensorflow/cc/saved_model/reader.cc:52] Reading meta graph with tags { serve }
2025-04-11 16:25:25.236115: I tensorflow/cc/saved_model/reader.cc:147] Reading SavedModel debug info (if present) from: /var/folders/1y/nzsqhm41529c176c8k4bw7s80000gn/T/tmp311ddzgz
I0000 00:00:1744403125.238392 9904850 mlir_graph_optimization_pass.cc:425] MLIR V1 optimization pass is not enabled
2025-04-11 16:25:25.238719: I tensorflow/cc/saved_model/loader.cc:236] Restoring SavedModel bundle.
2025-04-11 16:25:25.250872: I tensorflow/cc/saved_model/loader.cc:220] Running initialization op on SavedModel bundle at path: /var/folder

In [6]:
print(f"Model size: {len(tflite_model)} bytes")

with open(model_h_path, "w") as f:
    f.write("#ifndef _MODELDATA_H_\n#define _MODELDATA_H_\n")
    f.write("const unsigned char model[] = {")
    f.write(",".join(f"0x{b:02x}" for b in tflite_model))
    f.write("};\n#endif\n")

Model size: 14240 bytes
