In [None]:
from Classifications import Classifications
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
from scipy.signal import find_peaks
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.utils import compute_class_weight
import glob
import numpy as np
import pandas as pd
import tensorflow as tf

In [None]:
# Project-specific description of the classes and window geometry used below.
classifications = Classifications()

# Output locations: two generated C header files and the Keras checkpoint.
model_h_path = "../ModelData.h"
scaler_h_path = "../Scaler.h"
best_model_path = "best_model.keras"

# glob returns matches in filesystem-dependent order. The rows are concatenated
# and then split with a fixed seed, so the file order must be stable for the
# train/val/test split to be reproducible across machines.
data_csvs = sorted(glob.glob("csvs/*.csv"))

# Training hyper-parameters.
batch_size = 128
num_epochs = 512
patience = 256

# Fractions passed to train_test_split (validation is taken from the non-test part).
test_split = 0.2
val_split = 0.2

# Randomness: seed NumPy's generator and TensorFlow's global generator so that
# weight initialisation, dropout and unseeded tf.random ops start from a known state.
seed = 42
rng = np.random.default_rng(seed=seed)
tf.random.set_seed(seed)

In [None]:
# One input window: `num_steps` time steps with `num_features` channels each.
input_shape = (classifications.num_steps, classifications.num_features)
input_layer = layers.InputLayer(input_shape, dtype=tf.float32)

# Two Conv1D -> BatchNorm -> ReLU stages; only the second stage halves the
# time axis with max pooling. Both end in light dropout.
conv_stages = [
    layers.Conv1D(32, 3, padding="same"),
    layers.BatchNormalization(),
    layers.Activation("relu"),
    layers.Dropout(0.1),

    layers.Conv1D(64, 3, padding="same"),
    layers.BatchNormalization(),
    layers.Activation("relu"),
    layers.MaxPooling1D(pool_size=2),
    layers.Dropout(0.1),
]

# Fully connected head with heavier dropout, ending in a softmax over all classes.
dense_head = [
    layers.Flatten(),
    layers.Dense(128, activation="relu"),
    layers.Dropout(0.4),
    layers.Dense(64, activation="relu"),
    layers.Dropout(0.4),
    layers.Dense(classifications.num_classes, activation="softmax"),
]

model = Sequential([input_layer, *conv_stages, *dense_head])

model.summary()

In [None]:
# Smoke test: run a single random window through the network to confirm that
# the input and output shapes line up.
sample_shape = (1, classifications.num_steps, classifications.num_features)
sample_input = tf.random.uniform(shape=sample_shape, dtype=tf.float32, seed=seed)
# The model's last layer is a softmax, so despite the variable name these are
# per-class probabilities rather than raw logits.
logits = model.predict(sample_input)

In [None]:
# Index of the highest-scoring class for each row, then keep the only row.
top_class_per_sample = tf.argmax(logits, axis=1)
prediction = top_class_per_sample[0]
print(prediction)

In [None]:
# Look up the class entry for the predicted index and join its parts in
# reverse order into a single human-readable label.
predicted_class = classifications.classes[prediction]
prediction_string = ' '.join(reversed(predicted_class))
print(f"Prediction: {prediction_string}")

In [None]:
# Stack every recording into one frame with a fresh 0..N-1 integer index.
# NOTE(review): windows are cut from this concatenated frame, so a window that
# sits close to a file boundary can contain rows from two different CSVs —
# confirm that this is acceptable for the recordings in use.
recordings = [pd.read_csv(data_csv) for data_csv in data_csvs]
df = pd.concat(recordings, ignore_index=True)

sensor_columns = ["ax", "ay", "az", "gx", "gy", "gz"]

# Peak detection on the squared acceleration magnitude: a peak must reach the
# configured threshold and be at least one full window away from its neighbours.
squared_acceleration = (df[["ax", "ay", "az"]].values ** 2).sum(axis=1)
peaks, _ = find_peaks(
    squared_acceleration,
    height=classifications.squared_acceleration_threshold,
    distance=classifications.num_steps,
)

X = []
y = []

for peak in peaks:
    start_idx = peak - classifications.steps_before_peak
    end_idx = peak + classifications.steps_after_peak
    # Skip peaks whose window would run off either end of the data.
    if not (0 <= start_idx and end_idx < len(df)):
        continue

    # Label-based slicing includes end_idx, so the window holds
    # steps_before_peak + steps_after_peak + 1 rows; the assert checks that
    # this matches the configured window length.
    shot_df = df.loc[start_idx:end_idx]
    assert len(shot_df) == classifications.num_steps

    shot_data = shot_df[sensor_columns].to_numpy()

    # The label is read from the row at the peak itself, case-insensitively.
    stroke, side, spin = (
        df.loc[peak, column].lower() for column in ("stroke", "side", "spin")
    )
    label_key = (stroke, side, spin)

    # Windows whose label is not a known class are dropped silently.
    if label_key not in classifications.classes:
        continue

    label = classifications.class_to_idx[label_key]
    X.append(shot_data)
    y.append(label)

X = np.array(X).astype(np.float32)
y = np.array(y)

# Hold out the test set first, then carve the validation set out of what is
# left; val_split is therefore a fraction of the non-test data, not of the
# whole data set. Both splits are stratified by label and share the same seed.
X_trainval, X_test, y_trainval, y_test = train_test_split(
    X,
    y,
    test_size=test_split,
    random_state=seed,
    stratify=y,
)
X_train, X_val, y_train, y_val = train_test_split(
    X_trainval,
    y_trainval,
    test_size=val_split,
    random_state=seed,
    stratify=y_trainval,
)

# Fit per-feature statistics on the training split only: every time step of
# every training window is treated as one row of shape (num_features,).
scaler = StandardScaler()
scaler.fit(X_train.reshape(-1, X_train.shape[-1]))

# Export the fitted mean and scale as C arrays inside an include-guarded header.
mean_values = ",".join(f"{x}" for x in scaler.mean_)
scale_values = ",".join(f"{x}" for x in scaler.scale_)
scaler_header = "".join([
    "#ifndef _SCALER_H_\n#define _SCALER_H_\n\n#include \"SensorData.h\"\n\n",
    "constexpr float mean_[NUM_FEATURES] = {",
    mean_values,
    "};\n",
    "constexpr float scale_[NUM_FEATURES] = {",
    scale_values,
    "};\n",
    "\n#endif\n",
])

with open(scaler_h_path, "w") as f:
    f.write(scaler_header)

def transform_data(X, scaler):
    """Apply a fitted per-feature scaler to a 3-D batch of windows.

    Args:
        X: array of shape (num_samples, num_steps, num_features). Any other
            rank makes the shape unpacking below raise ValueError.
        scaler: object exposing ``transform`` for 2-D input of shape
            (rows, num_features).

    Returns:
        The scaled data, reshaped back to the original 3-D shape.
    """
    n_windows, n_steps, n_channels = X.shape
    # Collapse windows and time steps so each row is one feature vector.
    rows = X.reshape(-1, n_channels)
    return scaler.transform(rows).reshape(n_windows, n_steps, n_channels)

# Scale all three splits with the statistics fitted on the training split.
X_train, X_val, X_test = (
    transform_data(split, scaler) for split in (X_train, X_val, X_test)
)

def random_rotate_sample_tf(sample, label):
    """Augment one window by applying a single random 3-D rotation to it.

    The first three columns of ``sample`` (accelerometer axes) and the
    remaining columns (gyroscope axes; must be exactly three for the matrix
    product to be valid) are rotated by the same matrix.

    Args:
        sample: 2-D tensor of shape (num_steps, 6).
        label: passed through unchanged.

    Returns:
        Tuple ``(rotated_sample, label)`` with ``rotated_sample`` having the
        same shape and dtype as ``sample``.

    NOTE(review): in this notebook the function receives data that has already
    been standardised per feature, so the rotation mixes axes after scaling —
    confirm that this is intended.
    """
    # Three independent uniforms in [0, 1) parameterise a random unit quaternion.
    u1 = tf.random.uniform([], 0, 1)
    u2 = tf.random.uniform([], 0, 1)
    u3 = tf.random.uniform([], 0, 1)

    two_pi = 2 * np.pi
    qx = tf.sqrt(1 - u1) * tf.sin(two_pi * u2)
    qy = tf.sqrt(1 - u1) * tf.cos(two_pi * u2)
    qz = tf.sqrt(u1) * tf.sin(two_pi * u3)
    qw = tf.sqrt(u1) * tf.cos(two_pi * u3)

    # Quaternion (x, y, z, w) -> 3x3 rotation matrix.
    rotation = tf.stack([
        [1 - 2*qy*qy - 2*qz*qz,     2*qx*qy - 2*qz*qw,     2*qx*qz + 2*qy*qw],
        [2*qx*qy + 2*qz*qw,     1 - 2*qx*qx - 2*qz*qz,     2*qy*qz - 2*qx*qw],
        [2*qx*qz - 2*qy*qw,         2*qy*qz + 2*qx*qw, 1 - 2*qx*qx - 2*qy*qy]
    ])
    rotation = tf.cast(rotation, sample.dtype)

    # Rows are vectors, so multiplying by the transpose rotates every time step.
    rotated_parts = [
        tf.linalg.matmul(sample[:, :3], rotation, transpose_b=True),
        tf.linalg.matmul(sample[:, 3:], rotation, transpose_b=True),
    ]
    return tf.concat(rotated_parts, axis=1), label

# Training pipeline: shuffle with a buffer covering the whole split, apply the
# random-rotation augmentation per sample, then batch and prefetch.
train_ds = tf.data.Dataset.from_tensor_slices((X_train, y_train))
train_ds = train_ds.shuffle(len(X_train), seed=seed)
train_ds = train_ds.map(random_rotate_sample_tf, num_parallel_calls=tf.data.AUTOTUNE)
train_ds = train_ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)

# Validation and test pipelines: no shuffling and no augmentation.
val_ds, test_ds = (
    tf.data.Dataset.from_tensor_slices(split)
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
    for split in ((X_val, y_val), (X_test, y_test))
)

In [None]:
# Both callbacks watch the same validation metric.
monitored_metric = "val_accuracy"

# Save the model to disk whenever the monitored metric reaches a new best.
checkpoint_callback = ModelCheckpoint(
    filepath=best_model_path,
    monitor=monitored_metric,
    save_best_only=True,
    verbose=1,
)

# Stop once the monitored metric has not improved for `patience` epochs and
# ask Keras to restore the best weights seen.
early_stopping = EarlyStopping(
    monitor=monitored_metric,
    patience=patience,
    restore_best_weights=True,
    verbose=1,
)

# Labels are integer class indices, hence the sparse cross-entropy loss.
model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

# Balanced weights can only be computed for classes that actually occur in the
# training labels, so compute them for those classes once ...
present_classes = np.unique(y_train)
balanced_weights = compute_class_weight('balanced', classes=present_classes, y=y_train)
weight_by_class = dict(zip(present_classes.tolist(), balanced_weights))

# ... and give every class index that is absent from the training split a
# weight of 0 so the mapping covers all `num_classes` outputs.
class_weights = {
    class_idx: weight_by_class.get(class_idx, 0)
    for class_idx in range(classifications.num_classes)
}

# Train with checkpointing and early stopping, weighting the loss per class.
training_callbacks = [checkpoint_callback, early_stopping]

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=num_epochs,
    class_weight=class_weights,
    callbacks=training_callbacks,
    verbose=1,
)

In [None]:
# Evaluate the float (non-quantized) model on the held-out test split.
# NOTE(review): this uses the in-memory model, not the checkpoint on disk;
# whether it holds the best or the last-epoch weights depends on how
# EarlyStopping restored them — confirm if the two numbers need to match.
float_metrics = model.evaluate(test_ds)
test_loss, test_accuracy = float_metrics
print(f"Non-Quantized Test Loss: {test_loss}")
print(f"Non-Quantized Test Accuracy: {test_accuracy}")

In [None]:
# Reload the best checkpoint from the same path ModelCheckpoint writes to,
# instead of repeating the file name as a separate literal.
model = tf.keras.models.load_model(best_model_path)


def representative_dataset():
    """Yield calibration inputs for post-training quantization.

    Produces at most 100 single-sample float32 batches taken from the start of
    X_train, each wrapped in a list as the converter expects one entry per
    model input.
    """
    calibration_ds = (
        tf.data.Dataset.from_tensor_slices(X_train.astype(np.float32))
        .batch(1)
        .take(100)
    )
    for input_value in calibration_ds:
        yield [input_value]


converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset

# Restrict the converter to int8 builtin ops and make the model's input and
# output tensors int8 as well.
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
tflite_model = converter.convert()

In [None]:
# Run the converted model through the TFLite interpreter on the test split.
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()[0]
output_details = interpreter.get_output_details()[0]

input_scale, input_zero_point = input_details['quantization']
output_scale, output_zero_point = output_details['quantization']

# Representable range of the input tensor's dtype. Quantized values outside it
# must be saturated before the cast; otherwise an outlier wraps around (or is
# undefined) when converted to a narrow integer type.
input_dtype = input_details["dtype"]
if np.issubdtype(input_dtype, np.integer):
    input_min = np.iinfo(input_dtype).min
    input_max = np.iinfo(input_dtype).max
else:
    input_min, input_max = -np.inf, np.inf

# Evaluate accuracy
correct = 0
total = 0

for x, y_true in test_ds:
    for i in range(x.shape[0]):
        # Quantize one float sample: q = round(x / scale + zero_point), saturated.
        x_input = x[i].numpy()
        x_input = np.round(x_input / input_scale + input_zero_point)
        x_input = np.clip(x_input, input_min, input_max).astype(input_dtype)
        x_input = np.expand_dims(x_input, axis=0)

        interpreter.set_tensor(input_details['index'], x_input)
        interpreter.invoke()

        # The arg-max is taken directly on the quantized output values.
        output = interpreter.get_tensor(output_details['index'])[0]
        y_pred = np.argmax(output)

        if y_pred == y_true[i].numpy():
            correct += 1
        total += 1

test_accuracy = correct / total
print(f"Quantized Test Accuracy: {test_accuracy}")

In [None]:
print(f"Model size: {len(tflite_model)} bytes")

# Render every byte of the converted model as a two-digit hex literal and wrap
# the list in an include-guarded C array definition.
hex_bytes = ",".join(f"0x{b:02x}" for b in tflite_model)

with open(model_h_path, "w") as f:
    f.writelines([
        "#ifndef _MODELDATA_H_\n#define _MODELDATA_H_\n",
        "const unsigned char model[] = {",
        hex_bytes,
        "};\n#endif\n",
    ])