In [None]:
import os

import clean
import matplotlib.pyplot as plt
import numpy as np

os.environ["KERAS_BACKEND"] = "jax"
import keras
from keras import layers

In [None]:
# Fetch the prepared dataset partitions from the project-local `clean` module.
# NOTE(review): the `i_*` / `t_*` prefixes look like inputs / targets, and
# `tune` like a validation split — confirm against `clean.get_base_splits`.
splits = clean.get_base_splits()

# Bind every partition to its own notebook-level name. The tuple on the right
# is built in full before any name is assigned.
i_train, i_tune, i_test, t_train, t_tune, t_test = (
    splits["i_train"],
    splits["i_tune"],
    splits["i_test"],
    splits["t_train"],
    splits["t_tune"],
    splits["t_test"],
)

In [None]:
# Fix the global seed so weight initialisation is repeatable.
keras.utils.set_random_seed(42)

# Standardise the inputs with per-feature statistics taken from the training
# split only.
input_norm = layers.Normalization(
    axis=-1, mean=i_train.mean(axis=0), variance=i_train.var(axis=0)
)
# Inverse normalisation (`invert=True`): maps the network's output back using
# the training-target mean and variance, so the model's final output (and
# therefore the loss) is on the scale of `t_train`.
output_denorm = layers.Normalization(
    axis=-1,
    invert=True,
    mean=t_train.mean(axis=0),
    variance=t_train.var(axis=0),
)

model = keras.Sequential(
    [
        # Four input features. The batch dimension is left dynamic: it was
        # previously pinned to 50, which disagreed with the batch size used
        # for training and cannot hold for a smaller trailing batch.
        layers.Input((4,)),
        input_norm,
        layers.Dense(4, activation="relu", name="layer1"),
        layers.Dense(1, name="layer2"),
        output_denorm,
    ]
)

model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.001),
    loss=keras.losses.MeanSquaredError(),
    # Use the Metric class rather than the Loss class of the same name: a loss
    # object reduces each batch to a single scalar, so wrapping it as a metric
    # averages per-batch means and mis-weights a smaller final batch. The
    # metric accumulates over samples instead.
    metrics=[keras.metrics.MeanAbsolutePercentageError()],
)
model.summary()

In [None]:
# Re-seed right before training (same seed value as used when the model was
# built).
keras.utils.set_random_seed(42)

# Train on the training split and report loss/metrics on the tune split after
# each epoch. The returned History object is kept for the loss plot.
history = model.fit(
    x=i_train,
    y=t_train,
    validation_data=(i_tune, t_tune),
    epochs=30,
    batch_size=500,
    # Early stopping is currently disabled; to enable it, pass:
    # callbacks=[keras.callbacks.EarlyStopping(restore_best_weights=True)],
)

In [None]:
# Persist the trained model in the native `.keras` format. Create the target
# directory first so the save does not fail when `../models` is missing
# (e.g. on a fresh checkout).
os.makedirs("../models", exist_ok=True)
model.save("../models/poc.keras")

In [None]:
# Rebind `model` to the copy deserialised from disk, so the cells below run
# against the saved artefact rather than the in-memory object.
model = keras.models.load_model(filepath="../models/poc.keras")

In [None]:
# Plot the per-epoch training and validation loss recorded by `model.fit`.
train_loss = history.history["loss"]
val_loss = history.history["val_loss"]

# Derive the epoch axis from the recorded history instead of hard-coding 30,
# so the plot stays valid if the epoch count changes or training stops early.
x = range(1, len(train_loss) + 1)

plt.plot(x, train_loss, color="red", label="Training MSE")
plt.plot(x, val_loss, color="blue", label="Validation MSE")
plt.legend()
plt.title("Losses during training")
plt.xlabel("Epochs trained")
plt.ylabel("Error")
# Fixed y-range. NOTE(review): presumably chosen so large early-epoch losses
# do not flatten the rest of the curve — confirm the cap suits the data.
plt.ylim((0, 1000))
plt.show()

In [None]:
# keras.utils.plot_model(model, show_shapes=True, show_layer_activations=True)

In [None]:
# Evaluate on the `i_test` / `t_test` split. `results` is indexed below with
# the loss as its first entry; the model is compiled with a mean-squared-error
# loss, so the square root of that entry is the RMSE.
results = model.evaluate(x=i_test, y=t_test, batch_size=50)

test_mse = results[0]
print("RMSE:", np.sqrt(test_mse))