In [1]:
import keras
import numpy as np
from layers import Msg, LinEq2v2, LinEq2v0
from keras.layers import Dense, Flatten, Dropout
from keras.models import Model





In [2]:
class PELICAN(Model):
    """Keras model over batches of stacked N x N matrices.

    Forward pipeline: `depth` rounds of Msg -> Dropout -> LinEq2v2, then one
    LinEq2v0, a Flatten, and a Dense head with no activation (so the outputs
    are logits).

    Msg / LinEq2v2 / LinEq2v0 come from the local `layers` module and their
    internals are not visible here. From their names and from the permutation
    check later in this notebook, they are presumably permutation-equivariant
    rank-2 -> rank-2 and rank-2 -> rank-0 maps (TODO confirm against `layers`).

    Args:
        depth: number of Msg/LinEq2v2 rounds.
        dropout: rate of the single Dropout layer reused after every Msg layer.
        activation: activation forwarded to every Msg, LinEq2v2 and LinEq2v0 layer.
        msg_outputs: `outputs` argument of each Msg layer.
        agg_outputs: `outputs` argument of each LinEq2v2 layer.
        dense_output: number of units in the final Dense layer.
        scal_outputs: `outputs` argument of the LinEq2v0 layer.
    """

    def __init__(self, depth=1, dropout=0.0, activation=None, msg_outputs=10, agg_outputs=10, dense_output=10, scal_outputs=2):
        super().__init__()

        self.msg_layers = [Msg(outputs=msg_outputs, activation=activation) for _ in range(depth)]
        # A single Dropout instance is shared by all rounds; it has no weights.
        self.dropout    = Dropout(rate=dropout)
        self.agg_layers = [LinEq2v2(outputs=agg_outputs, activation=activation) for _ in range(depth)]
        self.scal_layer = LinEq2v0(outputs=scal_outputs, activation=activation)
        # Created once here rather than inside call(). Sublayers should be
        # owned by the model, not re-instantiated on every forward pass.
        self.flatten_layer = Flatten()

        self.dense = Dense(units=dense_output)

    def call(self, inputs, training=None):
        """Run the forward pass.

        Args:
            inputs: tensor assumed to be (batch, N, N, L), i.e. L stacked
                N x N matrices per sample. This notebook passes (100, 20, 20, 1).
            training: forwarded to Dropout so it is only active in training
                mode. With the default of None, Keras infers the mode from the
                call context.

        Returns:
            Tensor of shape (batch, dense_output) holding unnormalized logits.
        """
        x = inputs
        for msg, agg in zip(self.msg_layers, self.agg_layers):
            x = msg(x)
            x = self.dropout(x, training=training)
            x = agg(x)
        x = self.scal_layer(x)
        x = self.flatten_layer(x)
        return self.dense(x)

In [3]:
# Depth-4 PELICAN with sigmoid activations and a 10-unit output head.
model = PELICAN(
    depth=4,
    activation='sigmoid',
    msg_outputs=37,
    agg_outputs=59,
    scal_outputs=21,
    dense_output=10
)

# Use the `keras` namespace imported in the first cell. `tf` is never
# imported in this notebook, so `tf.keras...` fails on a fresh kernel.
# from_logits=True because the final Dense layer has no activation.
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.001),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)






In [4]:
# Fix the global NumPy seed so that `a`, and the row permutations drawn from
# the global generator in the next cell, are reproducible across
# Restart & Run All.
np.random.seed(0)

# 20 random points in 3 dimensions. Gram matrices of row permutations of
# this array form the model input.
a = np.random.normal(size=(20, 3))


In [5]:
# Model input: for each of 100 independent row permutations `perm` of `a`,
# form the Gram matrix perm @ perm.T (pairwise row dot products, N x N).
# Stack the matrices along a new batch axis and append a trailing channel
# axis, giving shape (100, N, N, 1) with N = len(a).
A = np.stack([
    np.einsum('ij, kj->ik', perm, perm)
    for perm in [np.random.permutation(a) for _ in range(100)]
])[..., np.newaxis]


In [6]:
# Forward the whole batch of permuted inputs. No training has run in the
# cells above.
y = model(A)

# Batch mean of the outputs, one value per output unit. Every row of y should
# match it if the output is invariant to the row permutations used to build A.
yavg = np.asanyarray(y).mean(axis=0)





In [7]:
# Spread of the outputs around their batch mean. It is ~0 (float32 round-off)
# when the model is invariant to the row permutations used to build A.
print(np.std(y-yavg))
# NOTE: this mean is zero by construction, since yavg is the mean of y.
# It is a sanity print, not an invariance test.
print(np.mean(y-yavg))

# Actual invariance check: every output row must match the batch mean within
# float32 round-off. The bound is loose relative to the observed std (~6e-7).
INVARIANCE_ATOL = 1e-4
max_dev = np.max(np.abs(np.asarray(y) - yavg))
assert max_dev < INVARIANCE_ATOL, f"outputs vary across permutations (max |y - mean| = {max_dev})"

6.113182e-07
2.3841858e-08


In [8]:
# Expected (batch, dense_output): one row of outputs per permuted input.
print(tuple(y.shape))


(100, 10)


In [9]:
# Per-layer parameter counts. The forward pass `model(A)` above has already
# built the weights; the recorded output shows non-zero counts.
model.summary()

Model: "pelican"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 msg (Msg)                   multiple                  185       
                                                                 
 msg_1 (Msg)                 multiple                  2331      
                                                                 
 msg_2 (Msg)                 multiple                  2331      
                                                                 
 msg_3 (Msg)                 multiple                  2331      
                                                                 
 dropout (Dropout)           multiple                  0         
                                                                 
 lin_eq2v2 (LinEq2v2)        multiple                  32745     
                                                                 
 lin_eq2v2_1 (LinEq2v2)      multiple                  3274