In [None]:
import numpy as np
import tensorflow as tf

In [None]:
def my_function(inputs):
  """Target formula the network is asked to learn: ``x1 * x2 + x3``.

  Args:
    inputs: Iterable that unpacks into exactly three operands
      ``(x1, x2, x3)``. Each operand may be a scalar or an array; the
      arithmetic is applied with the operands' own ``*`` and ``+``.

  Returns:
    The value of ``x1 * x2 + x3``.
  """
  def formula(x1, x2, x3):
    return x1 * x2 + x3

  # Wrapper step: spread the packed operands over the formula's arguments.
  return formula(*inputs)

In [None]:
def train_gen(formula, starts: list, steps: list, num_values: int, num_samples: int):
  """Build an evenly spaced training grid and evaluate ``formula`` on it.

  Column ``i`` of the returned inputs is the arithmetic progression
  ``starts[i] + k * steps[i]`` for ``k = 0 .. num_samples - 1``.

  Args:
    formula: Callable that receives the inputs as an array of shape
      ``(num_values, num_samples)`` (one row per operand) and returns the
      target values.
    starts: First value of each input column.
    steps: Spacing between consecutive values of each input column.
    num_values: Number of input columns; the first ``num_values`` entries of
      ``starts`` and ``steps`` are used.
    num_samples: Number of rows to generate.

  Returns:
    Tuple ``(x_t, y_t)`` where ``x_t`` is a float32 array of shape
    ``(num_samples, num_values)`` and ``y_t`` is whatever ``formula`` returns
    for the transposed grid.

  Raises:
    ValueError: If ``starts`` or ``steps`` has fewer than ``num_values``
      entries.
  """
  if num_values > len(starts) or num_values > len(steps):
    raise ValueError(
      "starts and steps must each provide at least num_values "
      f"({num_values}) entries; got {len(starts)} and {len(steps)}"
    )

  # Build every column from the same integer index instead of
  # np.arange(start, stop, step): with fractional steps the latter can yield
  # num_samples + 1 elements, which would make the rows ragged.
  offsets = np.arange(num_samples, dtype=np.float64)
  x_t = np.array(
    [starts[i] + steps[i] * offsets for i in range(num_values)],
    dtype=np.float32,
  )
  y_t = formula(x_t)

  return x_t.T, y_t

# Training set: 1000 evenly spaced rows, one arithmetic progression per operand.
x_train, y_train = train_gen(
  my_function,
  starts=[0, 5, 1000],
  steps=[1, 1, 1024],
  num_values=3,
  num_samples=1000,
)

# Shapes of the inputs and of the targets, one per line.
print(x_train.shape, y_train.shape, sep="\n")

In [None]:
def test_gen(formula, min: int, max: int, num_values: int, num_samples: int):
  # Generate a series of input numbers for testing
  x_t = np.random.randint(min, max, size=(num_samples, num_values,)).astype(np.float32)
  y_t = formula(x_t.T)

  return x_t, y_t

# Main test set: 200 random samples with operands drawn from [0, 10000).
x_test, y_test = test_gen(
  my_function,
  min=0,
  max=10000,
  num_values=3,
  num_samples=200,
)

# Shapes of the inputs and of the targets, one per line.
print(x_test.shape, y_test.shape, sep="\n")

In [None]:
class NAC_Additive(tf.keras.layers.Layer):
  """Additive accumulator layer: ``inputs @ (tanh(W_hat) * sigmoid(M_hat))``.

  The effective weight matrix is the element-wise product of a tanh and a
  sigmoid, so every entry stays strictly inside (-1, 1). There is no bias.

  Args:
    in_features: Size of the last input dimension (rows of the weights).
    out_units: Number of output units (columns of the weights).
    **kwargs: Standard Keras layer options (``name``, ``dtype``, ...),
      forwarded to ``tf.keras.layers.Layer``.
  """

  def __init__(self, in_features=2, out_units=1, **kwargs):
    # Forward the base-layer options instead of silently rejecting them.
    super().__init__(**kwargs)
    self.in_features = in_features
    self.out_units = out_units

  def build(self, input_shape):
    # Both raw parameter matrices share shape and initializer; the weight
    # shape comes from the constructor arguments, not from input_shape.
    weight_shape = [self.in_features, self.out_units]
    self.W_hat = self.add_weight(
      name="W_hat",
      shape=weight_shape,
      initializer=tf.initializers.random_uniform(minval=-2, maxval=2),
      trainable=True,
    )
    self.M_hat = self.add_weight(
      name="M_hat",
      shape=weight_shape,
      initializer=tf.initializers.random_uniform(minval=-2, maxval=2),
      trainable=True,
    )

  def call(self, inputs):
    effective_weights = tf.nn.tanh(self.W_hat) * tf.nn.sigmoid(self.M_hat)
    return tf.matmul(inputs, effective_weights)

  def get_config(self):
    """Return the constructor arguments so the layer can be re-created."""
    config = super().get_config()
    config.update({
      "in_features": self.in_features,
      "out_units": self.out_units,
    })
    return config

In [None]:
class NAC_Multiplicative(tf.keras.layers.Layer):
  """Multiplicative accumulator layer working in log space.

  Computes ``exp(log(|inputs| + epsilon) @ (tanh(W_hat) * sigmoid(M_hat)))``.
  The absolute value means the sign of the inputs is discarded, and the
  output of ``exp`` is always positive.

  Args:
    in_features: Size of the last input dimension (rows of the weights).
    out_units: Number of output units (columns of the weights).
    epsilon: Small constant added before the logarithm so that zero inputs
      do not produce ``log(0)``.
    **kwargs: Standard Keras layer options (``name``, ``dtype``, ...),
      forwarded to ``tf.keras.layers.Layer``.
  """

  def __init__(self, in_features=2, out_units=1, epsilon = 0.000001, **kwargs):
    # Forward the base-layer options instead of silently rejecting them.
    super().__init__(**kwargs)
    self.in_features = in_features
    self.out_units = out_units
    self.epsilon = epsilon

  def build(self, input_shape):
    # Both raw parameter matrices share shape and initializer; the weight
    # shape comes from the constructor arguments, not from input_shape.
    weight_shape = [self.in_features, self.out_units]
    self.W_hat = self.add_weight(
      name="W_hat",
      shape=weight_shape,
      initializer=tf.initializers.random_uniform(minval=-2, maxval=2),
      trainable=True,
    )
    self.M_hat = self.add_weight(
      name="M_hat",
      shape=weight_shape,
      initializer=tf.initializers.random_uniform(minval=-2, maxval=2),
      trainable=True,
    )

  def call(self, inputs):
    effective_weights = tf.nn.tanh(self.W_hat) * tf.nn.sigmoid(self.M_hat)
    # A weighted sum of logs is the log of a product of powers; exp undoes it.
    log_inputs = tf.math.log(tf.abs(inputs) + self.epsilon)
    return tf.exp(tf.matmul(log_inputs, effective_weights))

  def get_config(self):
    """Return the constructor arguments so the layer can be re-created."""
    config = super().get_config()
    config.update({
      "in_features": self.in_features,
      "out_units": self.out_units,
      "epsilon": self.epsilon,
    })
    return config

In [None]:
class NAC_Gate(tf.keras.layers.Layer):
  """Gate layer: ``sigmoid(inputs @ G)``, so every output lies in (0, 1).

  Args:
    in_features: Size of the last input dimension (rows of ``G``).
    out_units: Number of gate outputs (columns of ``G``).
    **kwargs: Standard Keras layer options (``name``, ``dtype``, ...),
      forwarded to ``tf.keras.layers.Layer``.
  """

  def __init__(self, in_features=2, out_units=1, **kwargs):
    # Forward the base-layer options instead of silently rejecting them.
    super().__init__(**kwargs)
    self.in_features = in_features
    self.out_units = out_units

  def build(self, input_shape):
    # The weight shape comes from the constructor arguments, not input_shape.
    self.G = self.add_weight(
      name="Gate_weights",
      shape=[self.in_features, self.out_units],
      initializer=tf.random_normal_initializer(stddev=1.0),
      trainable=True,
    )

  def call(self, inputs):
    gate_logits = tf.matmul(inputs, self.G)
    return tf.nn.sigmoid(gate_logits)

  def get_config(self):
    """Return the constructor arguments so the layer can be re-created."""
    config = super().get_config()
    config.update({
      "in_features": self.in_features,
      "out_units": self.out_units,
    })
    return config

In [None]:
class NALU(tf.keras.layers.Layer):
  """Blend two paths with a gate: ``g * a + (1 - g) * m``.

  The layer has no weights of its own; it only combines its three inputs.
  """

  def call(self, g, a, m):
    # g weights the first path (a); its complement weights the second (m).
    gated_a = g * a
    gated_m = (1 - g) * m
    return gated_a + gated_m

In [None]:
loss_fn = tf.keras.losses.MeanSquaredError()

starter_learning_rate = 0.01
end_learning_rate = 0.0001
epochs = 50000
epsilon = 1e-06

# The learning-rate schedule is indexed by optimizer step (one step per
# batch), not by epoch. To spread the decay over 90% of the training run the
# epoch count therefore has to be converted into a number of steps.
batch_size = 32  # Keras' default; model.fit below is called without batch_size
steps_per_epoch = (len(x_train) + batch_size - 1) // batch_size  # ceil division
decay_steps = int(0.9 * epochs * steps_per_epoch)

# Polynomial (square-root, power=0.5) decay from the starting to the final rate.
alpha = tf.keras.optimizers.schedules.PolynomialDecay(
    starter_learning_rate,
    decay_steps,
    end_learning_rate,
    power=0.5)

optimizer = tf.keras.optimizers.experimental.RMSprop(
    learning_rate=alpha,
)

In [None]:
# Weights-only checkpoint, written only when the monitored 'mse' metric improves.
checkpoint_filepath = '/tmp/checkpoint'
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    monitor='mse',
    save_best_only=True,
    save_weights_only=True,
)

# Early stopping on the same metric; min_delta=epsilon could be added to
# require a minimum improvement.
earlystopping = tf.keras.callbacks.EarlyStopping(
    monitor='mse',
    patience=5000,
)

In [None]:
# Every branch consumes the same raw input vector.
n_features = x_train.shape[1]
inputs = tf.keras.Input(shape=(n_features,))

# Three parallel branches (gate, additive, multiplicative) merged by NALU.
gate = NAC_Gate(n_features, 1)(inputs)
additive = NAC_Additive(n_features, 1)(inputs)
multiplicative = NAC_Multiplicative(n_features, 1, epsilon=epsilon)(inputs)
outputs = NALU()(gate, additive, multiplicative)

model = tf.keras.Model(inputs=inputs, outputs=outputs)

# The metric names ('mae', 'mse', ...) are the keys reported during training;
# 'mse' is the one the callbacks monitor.
model.compile(
  optimizer=optimizer,
  loss=loss_fn,
  metrics=[
    tf.keras.metrics.MeanAbsoluteError(name='mae'),
    tf.keras.metrics.MeanSquaredError(name='mse'),
    tf.keras.metrics.MeanAbsolutePercentageError(name='mape'),
    tf.keras.metrics.MeanSquaredLogarithmicError(name='msle'),
  ],
)

In [None]:
# Print the layer-by-layer summary of the assembled model.
model.summary()

In [None]:
# Train with the checkpoint callback only; `earlystopping` is defined above
# but deliberately not passed here.
# NOTE(review): the checkpoint stores the best weights, but no cell visible
# here loads them back before evaluation — confirm that is intended.
history = model.fit(
  x=x_train,
  y=y_train,
  epochs=epochs,
  verbose=1,
  callbacks=[model_checkpoint_callback],
)

In [None]:
# Automated evaluation on the main test set: returns the loss followed by
# the metrics configured in compile().
model.evaluate(x=x_test, y=y_test, verbose=2)

7/7 - 0s - loss: 1685804417024.0000 - mae: 871786.0625 - mse: 1685804417024.0000 - mape: 87.3741 - msle: 148.3976 - 246ms/epoch - 35ms/step


[1685804417024.0,
 871786.0625,
 1685804417024.0,
 87.37408447265625,
 148.39755249023438]

In [None]:
# Evaluate the model on a test set and pretty print
def pretty_test(model, x_test, y_test, template, final, epsilon=1e-9):
  """Evaluate ``model`` on a test set and return a human-readable report.

  Per-sample accuracy is the ratio of the smaller to the larger magnitude of
  prediction and expected value, so it lies in [0, 1] and ignores the sign.

  Args:
    model: Object exposing ``predict(x, verbose=0)`` that returns one
      prediction per input row.
    x_test: Input rows, one per sample.
    y_test: Expected value for each row; must have the same length.
    template: ``str.format`` template for the per-sample lines. It is filled
      with the row's input values, then the expected value, the prediction,
      the accuracy and the absolute error.
    final: ``str.format`` template for the summary line, filled with the
      mean accuracy and the root-mean-square error.
    epsilon: Magnitude at or below which both the prediction and the
      expected value are treated as zero, i.e. as a perfect match, instead
      of evaluating 0 / 0.

  Returns:
    List of strings: one line per sample followed by the summary line. For
    an empty test set only the summary line is returned, with NaN values.
  """
  assert len(x_test) == len(y_test)
  total = len(y_test)
  if total == 0:
    # Nothing to average: report NaN explicitly rather than dividing by zero.
    return [final.format(float("nan"), float("nan"))]

  # One batched forward pass instead of one predict() call per sample.
  preds = np.asarray(model.predict(np.asarray(x_test), verbose=0))

  lines = []
  accuracies = []
  errors = []
  for x, y, raw_pred in zip(x_test, y_test, preds):
    pred = raw_pred.squeeze()
    mag_pred, mag_true = np.abs(pred), np.abs(y)
    larger = np.maximum(mag_pred, mag_true)
    if larger <= epsilon:
      # Both values are (numerically) zero: a perfect match, not 0 / 0.
      acc = 1.0
    else:
      acc = np.minimum(mag_pred, mag_true) / larger
    err = np.abs(pred - y)
    accuracies.append(acc)
    errors.append(err)
    lines.append(template.format(*x, y, pred, acc, err))

  mean_acc = np.mean(np.asarray(accuracies, dtype=np.float64))
  rms = np.sqrt(np.mean(np.asarray(errors, dtype=np.float64) ** 2))
  lines.append(final.format(mean_acc, rms))
  return lines

In [None]:
# Manual, human-readable spot check on ten small random samples.
x_test_2, y_test_2 = test_gen(
  my_function,
  min=50,
  max=100,
  num_values=3,
  num_samples=10,
)

report_lines = pretty_test(
  model,
  x_test_2,
  y_test_2,
  template="({} * {}) + {} = {}, prediction: {}, accuracy: {}, error: {}",
  final="accuracy: {}, rms: {}",
  epsilon=epsilon,
)
print(*report_lines, sep="\n")