In [1]:
import testing.equivariant as equivariant
# Absolute tolerance used by the permutation checks below: the largest
# element-wise |output difference| allowed between original and permuted inputs.
# It is also passed to equivariant.test_all (its exact use there is defined in
# testing.equivariant — not visible here).
eps = 0.0001




In [2]:
# Tests if all linear equivariant transformations are correct.
# Disabled by default; set the flag to True to run the full suite from
# testing.equivariant with the tolerance defined above.
RUN_LINEAR_EQUIVARIANCE_TESTS = False
if RUN_LINEAR_EQUIVARIANCE_TESTS:
    equivariant.test_all(eps)

In [3]:
# Tests the permutation equivariance/invariance of the layers
from layers import LinEq2v2, Msg, LinEq2v0
from keras.models import Sequential
from keras.layers import Dense, Flatten
import numpy as np
import tensorflow as tf


# Sanity check of tf.concat: two (35, 15) tensors joined on the last axis
# must give a (35, 30) tensor. The prints keep the shapes visible; the assert
# makes the check fail loudly instead of relying on reading the output.
A = tf.zeros(shape=(35, 15))
B = tf.zeros(shape=(35, 15))
print(A.shape, B.shape)
C = tf.concat([A, B], axis=-1)
print(C.shape)
assert C.shape.as_list() == [35, 30], f"unexpected concat shape {C.shape}"

# Number of random particle permutations tried by each permutation test below.
num_tests = 100


(35, 15) (35, 15)
(35, 30)


In [4]:
# Check permutation invariance: the model ends in an invariant readout
# (LinEq2v0 followed by Dense), so relabelling the particles must leave the
# output unchanged up to `eps`.

model_perminv = Sequential(
    layers=
    [
        Msg(30, activation='sigmoid'),
        LinEq2v0(20, activation='sigmoid'),
        Dense(5)
    ]
)

# Seeded generator so a failing particle count / permutation can be reproduced.
# (Keras weight initialisation is not seeded here.)
rng = np.random.default_rng(0)

N = rng.integers(10, 20)  # Number of particles, drawn from [10, 20)

momentum = rng.normal(size=(N, 3))

# Pairwise inner products p_p . p_q -> an (N, N) matrix whose rows and columns
# are both permuted when the particles are permuted.
# Note: "pi, qj->pq" would sum i and j independently and produce a rank-1
# matrix of component sums, which is a degenerate test input.
input_tensor = np.einsum("pi,qi->pq", momentum, momentum)
output = model_perminv(input_tensor.reshape((1, N, N, 1)))

for n in range(num_tests):
    perm = rng.permutation(N)
    model_input = np.einsum("pi,qi->pq", momentum[perm], momentum[perm])
    model_input = np.reshape(model_input, (1, N, N, 1))
    model_output = model_perminv(model_input)
    diff = np.max(np.abs(output - model_output))

    assert diff < eps, f"permutation {n} changed the output by {diff:.3g} (eps={eps})"

print("All tests passed")




All tests passed


In [5]:
# Check permutation equivariance of LinEq2v2 indirectly: the stack ends in an
# invariant readout (LinEq2v0 + Dense), so if every preceding layer is
# permutation-equivariant the model output is invariant under relabelling the
# particles. A failure here disproves equivariance; a pass is evidence for it,
# not a proof.

model_permeq = Sequential(
    layers=
    [
        Msg(30, activation='leaky_relu'),
        LinEq2v2(40, activation='leaky_relu'),
        LinEq2v0(20, activation='leaky_relu'),
        # Linear readout: a final 'relu' can clamp every output to 0, which would
        # make the comparison below pass regardless of equivariance.
        Dense(5)
    ]
)

# Seeded generator so a failing particle count / permutation can be reproduced.
# (Keras weight initialisation is not seeded here.)
rng = np.random.default_rng(1)

N = rng.integers(10, 20)  # Number of particles, drawn from [10, 20)

momentum = rng.normal(size=(N, 3))

# Pairwise inner products p_p . p_q (see the invariance check above for why
# the subscripts share the index i).
input_tensor = np.einsum("pi,qi->pq", momentum, momentum)
output = model_permeq(input_tensor.reshape((1, N, N, 1)))

for n in range(num_tests):
    perm = rng.permutation(N)
    model_input = np.einsum("pi,qi->pq", momentum[perm], momentum[perm])
    model_input = np.reshape(model_input, (1, N, N, 1))

    model_output = model_permeq(model_input)
    diff = np.max(np.abs(output - model_output))

    assert diff < eps, f"permutation {n} changed the output by {diff:.3g} (eps={eps})"

print("All tests passed")


All tests passed
