In [None]:
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from sklearn import model_selection
from profit.util import quasirand

## Defining a function for testing

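We define a simple test function $f(x) = \cos(t_1)\,\sin(t_2)$ that depends on $x_1, x_2, x_3$ only through two hidden parameters, the linear combinations $t_1 = x_1 + x_2$ and $t_2 = x_1 - x_3$. The volume plot below also shows two line segments along the gradient directions $\nabla t_1 = (1, 1, 0)$ and $\nabla t_2 = (1, 0, -1)$.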

In [None]:
def f(x):
  """Test function cos(t1) * sin(t2) of the hidden parameters t1 = x[0] + x[1], t2 = x[0] - x[2].

  `x` is indexed along its first axis, so either a single point of length 3 or
  an array of shape (3, N) (one column per point) can be passed.
  """
  t1 = x[0] + x[1]
  t2 = x[0] - x[2]
  return np.cos(t1) * np.sin(t2)

# Plot f on a regular 20x20x20 grid, see https://plotly.com/python/3d-volume-plots/
X, Y, Z = np.mgrid[0:1:20j, 0:1:20j, 0:1:20j]
x = np.vstack((X.flatten(), Y.flatten(), Z.flatten()))

values = f(x)  # one value per grid point (x has one column per point)

fig = go.Figure(data=go.Volume(
    x=X.flatten(),
    y=Y.flatten(),
    z=Z.flatten(),
    value=values,
    isomin=-1.0,
    isomax=1.0,
    opacity=0.3, # needs to be small to see through all surfaces
    surface_count=17, # needs to be a large number for good volume rendering
    ))
# Reference segments along the gradients of the hidden parameters:
# (0,0,1)->(1,1,1) points along (1, 1, 0) = grad t1,
# (0,0,1)->(1,0,0) points along (1, 0, -1) = grad t2.
fig.add_scatter3d(x=[0, 1], y=[0, 1], z=[1, 1], name='grad t1')
fig.add_scatter3d(x=[0, 1], y=[0, 0], z=[1, 0], name='grad t2')

fig.show()

## Training and test data

The sample positions in parameter space are taken from a quasi-random (Halton) sequence. We use scikit-learn to split them into training and test sets, and wrap the training inputs and outputs together in a single TensorFlow dataset.

In [None]:
Ndim = 3
N_SAMPLES = 500  # number of quasi-random sample points
SPLIT_SEED = 0   # fixed so the train/test split is reproducible across runs

# indata has one row per sample point, shape (N_SAMPLES, Ndim); f expects one column per point.
indata = quasirand(N_SAMPLES, Ndim)
outdata = f(indata.T).reshape(-1, 1)

train, test = model_selection.train_test_split(
    np.hstack((indata, outdata)), random_state=SPLIT_SEED)
# Keep the labels 2-D, shape (N, 1), so they match the model's (batch, 1) output
# instead of relying on Keras to align ranks inside the loss.
dataset = tf.data.Dataset.from_tensor_slices((train[:, :Ndim], train[:, Ndim:]))

fig = px.scatter_3d(x=indata[:, 0], y=indata[:, 1], z=indata[:, 2])
fig.update_traces(marker={'size': 2})

## Building and training the network

Now we initialize a network in which the *first* layer after the input has its activation set to `None`, i.e. it applies a linear transformation (strictly, an affine one, since the layer also has a bias) down to two dimensions. The weight matrix (kernel) of this layer directly yields the linear projection onto the lower-dimensional subspace of hidden parameters.

In [None]:
# Two-unit linear bottleneck (no activation) followed by a small tanh MLP with a scalar output.
model = keras.Sequential([
    keras.layers.Input(shape=(Ndim,)),
    keras.layers.Dense(2, activation=None),  # linear projection to the hidden parameters
    keras.layers.Dense(32, activation='tanh'),
    keras.layers.Dense(32, activation='tanh'),
    keras.layers.Dense(1, activation=None),
])
model.compile(optimizer=tf.optimizers.Adam(), loss='mse')

In [None]:
# Reshuffle the training samples every epoch. Keras ignores `shuffle=` for tf.data
# inputs, so without this every epoch would visit the samples in the same order.
history = model.fit(dataset.shuffle(buffer_size=len(train)).batch(1), epochs=128)

## Predictions

Now we compare the network output with the original function values. If the network reproduced them exactly, all points would lie on the diagonal, which is drawn as a dashed reference line.

In [None]:
output_eval_train = model.predict(train[:,:Ndim])
output_eval_test = model.predict(test[:,:Ndim])

# Scatter true vs. predicted values for both splits; perfect predictions lie on y = x.
for reference, predicted, label in (
    (train[:, Ndim], output_eval_train[:, 0], 'training output'),
    (test[:, Ndim], output_eval_test[:, 0], 'test output'),
):
    fig = px.scatter(x=reference, y=predicted, labels={'x': label, 'y': 'network output'})
    lo, hi = reference.min(), reference.max()
    fig.add_shape(type='line', x0=lo, y0=lo, x1=hi, y1=hi, line={'dash': 'dash'})
    fig.show()


## Identification of linear embedding

The linear embedding is identified from the weights of the first (linear) layer. The result is not unique: the layer may learn any two vectors that span the relevant subspace. To check whether these vectors lie in the correct plane, we stack each identified vector together with the two reference vectors spanning the plane into a $3 \times 3$ matrix and compute its determinant. The determinant vanishes exactly when the identified vector lies in that plane.

In [None]:
# The first layer of `model` is the linear Dense(2) projection; display its kernel and bias.
linear_layer = model.layers[0]
linear_layer.weights

In [None]:
# The kernel of the first (linear) layer has shape (Ndim, 2); each column is one
# learned projection direction in input space.
kernel = model.layers[0].weights[0].numpy()
v1 = kernel[:, 0]
v2 = kernel[:, 1]

# Gradients of the hidden parameters t1 = x1 + x2 and t2 = x1 - x3; they span
# the plane the learned directions should lie in.
vref1 = np.array([1.0, 1.0, 0.0])
vref2 = np.array([1.0, 0.0, -1.0])

M1 = np.vstack((v1, vref1, vref2))
M2 = np.vstack((v2, vref1, vref2))

# det([v; vref1; vref2]) = v . (vref1 x vref2) vanishes iff v lies in the plane.
# Its raw value scales with |v|, which is arbitrary here, so also report it divided
# by |v| * |vref1 x vref2|: the sine of the angle between v and the plane, in [-1, 1].
plane_normal_norm = np.linalg.norm(np.cross(vref1, vref2))
for label, v, M in (('v1', v1, M1), ('v2', v2, M2)):
    det = np.linalg.det(M)
    sin_angle = det / (np.linalg.norm(v) * plane_normal_norm)
    print(f'{label}: det = {det:.3e}, sin(angle to plane) = {sin_angle:+.3f}')

In [None]:
# Draw each reference and learned direction as a segment starting at the origin.
fig = go.Figure()
for trace_name, vec in [('vref1', vref1), ('vref2', vref2), ('v1', v1), ('v2', v2)]:
    fig.add_scatter3d(x=[0, vec[0]], y=[0, vec[1]], z=[0, vec[2]], name=trace_name)
fig

In [None]:
# Auxiliary model that returns both the output of the linear projection layer
# and the final network prediction for the same input.
projection_layer = model.layers[0]
output_layer = model.layers[-1]
extractor = keras.Model(
    inputs=model.inputs,
    outputs=[projection_layer.output, output_layer.output],
)
features = extractor(indata)

In [None]:
# Network prediction over the two learned features. The learned features are
# (affine) combinations of t1 and t2, not necessarily t1 and t2 themselves.
hidden, prediction = features
fig = go.Figure()
fig.add_scatter3d(
    x=hidden[:, 0],
    y=hidden[:, 1],
    z=prediction[:, 0],
    mode='markers',
    marker=dict(size=2),
)
fig.update_layout(scene=dict(
    xaxis=dict(title='t1'),
    yaxis=dict(title='t2'),
    zaxis=dict(title='g(t1, t2)'),
))