# Training a NN to Improve Raw ORB/RANSAC Outputs

In [26]:
# Read in data from nn_X.csv and nn_Y.csv. Train a neural network on the data.

import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

# Read in data. Ignore the chain_id column and the i column.
X = pd.read_csv('nn_X.csv').drop(columns=['chain_id', 'i']).to_numpy()
Y = pd.read_csv('nn_Y.csv').drop(columns=['chain_id', 'i']).to_numpy()


# DATA NORMALISATION
# Normalise the range column from [20, 600] to [0, 1].
X[:, 2] = (X[:, 2] - 20) / (600 - 20)

# Normalise the x, y, z columns from [-1000, 1000] to [-1, 1].
X[:, 3:6] = (X[:, 3:6] - (-1000)) / (1000 - (-1000))
Y[:, 2:5] = (Y[:, 2:5] - (-1000)) / (1000 - (-1000))

# FEATURE ENGINEERING
# For each of the columns x, y, z, qw, qx, qy, qz, add a column x^2, y^2, z^2, etc.
# This is to help it learn formulae that depend on the square of these values.
X = np.hstack([X, X[:, 3:]**2])

# Split data into training and testing sets
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2)

# Convert data to TensorFlow tensors
X_train = tf.convert_to_tensor(X_train, dtype=tf.float32)
X_test = tf.convert_to_tensor(X_test, dtype=tf.float32)
Y_train = tf.convert_to_tensor(Y_train, dtype=tf.float32)
Y_test = tf.convert_to_tensor(Y_test, dtype=tf.float32)

# Create neural network
model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(64, activation='relu', input_shape=(X.shape[1],)),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(32, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(16, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(Y.shape[1])
])

model.compile(optimizer='adam', loss='mean_squared_error')

# Train neural network
model.fit(X_train, Y_train, epochs=200, batch_size=20, validation_data=(X_test, Y_test), verbose=2)

# Evaluate neural network
Y_pred = model.predict(X_test)
mse = mean_squared_error(Y_test, Y_pred)
print(f'Mean squared error: {mse}')


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/200
40/40 - 3s - 72ms/step - loss: 1313.9128 - val_loss: 1233.8230
Epoch 2/200
40/40 - 0s - 5ms/step - loss: 1296.4750 - val_loss: 1233.4274
Epoch 3/200
40/40 - 0s - 4ms/step - loss: 1295.6123 - val_loss: 1233.5216
Epoch 4/200
40/40 - 0s - 4ms/step - loss: 1296.1058 - val_loss: 1233.3290
Epoch 5/200
40/40 - 0s - 4ms/step - loss: 1300.4453 - val_loss: 1233.3389
Epoch 6/200
40/40 - 0s - 3ms/step - loss: 1292.6761 - val_loss: 1232.8871
Epoch 7/200
40/40 - 0s - 4ms/step - loss: 1297.1234 - val_loss: 1232.7111
Epoch 8/200
40/40 - 0s - 4ms/step - loss: 1302.4094 - val_loss: 1232.9468
Epoch 9/200
40/40 - 0s - 6ms/step - loss: 1282.5576 - val_loss: 1232.4993
Epoch 10/200
40/40 - 0s - 4ms/step - loss: 1294.3613 - val_loss: 1233.0886
Epoch 11/200
40/40 - 0s - 3ms/step - loss: 1291.2968 - val_loss: 1232.7426
Epoch 12/200
40/40 - 0s - 4ms/step - loss: 1289.4254 - val_loss: 1232.2721
Epoch 13/200
40/40 - 0s - 3ms/step - loss: 1298.9469 - val_loss: 1232.6129
Epoch 14/200
40/40 - 0s - 4ms/ste

KeyboardInterrupt: 