In [2]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

In [106]:
def get_dense_model(dim=32):
    """Build an MLP baseline: 784 inputs -> Dense(dim, relu) -> Dense(dim, relu) -> 10 logits.

    Args:
        dim: width of both hidden layers.

    Returns:
        An uncompiled ``keras.Model`` named "dense_mnist_model". The final Dense
        layer has no activation, so the model outputs unnormalized logits.
    """
    inputs = keras.Input(shape=(784,))
    x = layers.Dense(dim, activation="relu")(inputs)
    x = layers.Dense(dim, activation="relu")(x)
    outputs = layers.Dense(10)(x)
    # Give each architecture its own name so summaries/logs of the competing
    # models in this notebook can be told apart.
    model = keras.Model(inputs=inputs, outputs=outputs, name="dense_mnist_model")
    return model

dense_model = get_dense_model()
dense_model.summary()

def get_linear_model():
    """Build a linear baseline: a single Dense(10) layer applied directly to the 784 inputs.

    Returns:
        An uncompiled ``keras.Model`` named "linear_mnist_model" that outputs
        unnormalized logits (no activation on the Dense layer).
    """
    inputs = keras.Input(shape=(784,))
    outputs = layers.Dense(10)(inputs)
    # Distinct name so this model's summary is not confused with the others.
    model = keras.Model(inputs=inputs, outputs=outputs, name="linear_mnist_model")
    return model

linear = get_linear_model()
linear.summary()

Model: "bilinear_mnist_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_87 (InputLayer)        [(None, 784)]             0         
_________________________________________________________________
dense_155 (Dense)            (None, 32)                25120     
_________________________________________________________________
dense_156 (Dense)            (None, 32)                1056      
_________________________________________________________________
dense_157 (Dense)            (None, 10)                330       
Total params: 26,506
Trainable params: 26,506
Non-trainable params: 0
_________________________________________________________________
Model: "bilinear_mnist_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_88 (InputLayer)        [(None, 784)]             0         
____________

In [28]:
def bilinear(input_tensor):
    """Multiply a tensor element-wise by a learned linear projection of itself.

    A new bias-free, activation-free Dense layer whose width equals the input's
    last dimension is created on every call, so output unit ``k`` is
    ``x_k * sum_j x_j * W[j, k]`` -- a learned quadratic interaction.

    Args:
        input_tensor: Keras/TF tensor. Its last dimension is used as the Dense
            width, so it needs to be known statically.

    Returns:
        A tensor with the same shape as ``input_tensor``.
    """
    width = input_tensor.shape[-1]
    projection = layers.Dense(width, activation=None, use_bias=False)
    projected = projection(input_tensor)
    return tf.math.multiply(input_tensor, projected)
    

In [127]:
def get_bilinear_model(dim=32):
    """Build 784 inputs -> Dense(dim, relu) -> bilinear interaction -> 10 logits.

    Args:
        dim: width of the hidden layer (and of the bilinear projection).

    Returns:
        An uncompiled ``keras.Model`` named "bilinear_mnist_model" producing
        unnormalized logits.
    """
    pixels = keras.Input(shape=(784,))
    hidden = layers.Dense(dim, activation="relu")(pixels)
    interactions = bilinear(hidden)
    logits = layers.Dense(10)(interactions)
    return keras.Model(inputs=pixels, outputs=logits, name="bilinear_mnist_model")

bilinear_model = get_bilinear_model()
bilinear_model.summary()

Model: "bilinear_mnist_model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_108 (InputLayer)          [(None, 784)]        0                                            
__________________________________________________________________________________________________
dense_186 (Dense)               (None, 32)           25120       input_108[0][0]                  
__________________________________________________________________________________________________
dense_187 (Dense)               (None, 32)           1024        dense_186[0][0]                  
__________________________________________________________________________________________________
tf.math.multiply_37 (TFOpLambda (None, 32)           0           dense_186[0][0]                  
                                                                 dense_187[0][0

In [71]:
class FactorizationMachine(keras.layers.Layer):
    """Second-order (pairwise-interaction) term of a factorization machine.

    The layer owns one factor tensor ``v`` of shape ``(input_dim, embed_dim, units)``.
    For rank-2 inputs ``x`` of shape ``(batch, input_dim)``, output unit ``o`` is

        sum_e [ (sum_i x_i * v[i, e, o])**2 - sum_i x_i**2 * v[i, e, o]**2 ]

    which equals ``2 * sum_{i<j} <v[i, :, o], v[j, :, o]> * x_i * x_j``, so every
    output unit is an independent FM interaction with its own factors. The
    conventional 1/2 factor of the FM formula is not applied; it can be absorbed
    into the scale of ``v``.

    Only the interaction term is computed: there is no linear term and no bias
    (``get_fm_model`` adds a separate Dense layer for that). The ``kernel_*`` and
    ``bias_*`` arguments are accepted and round-tripped through ``get_config``
    but are not used by this layer.

    Args:
        units: number of output units.
        embed_dim: number of latent factors per input feature.
        kernel_initializer: unused.
        embedding_initializer: initializer for ``v``.
        bias_initializer: unused.
        kernel_regularizer: unused.
        embedding_regularizer: regularizer applied to ``v``.
        bias_regularizer: unused.
        **kwargs: standard ``Layer`` keyword arguments (e.g. ``name``, ``dtype``).

    Output shape: ``(batch, units)``.
    """

    def __init__(self, units=32,
                 embed_dim=5,
                 kernel_initializer='glorot_uniform',
                 embedding_initializer='uniform',
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 embedding_regularizer=None,
                 bias_regularizer=None,
                 **kwargs):
        # Forward name/dtype/etc. to the base Layer so the layer can be named
        # and rebuilt from its config.
        super().__init__(**kwargs)
        self.units = units
        self.embed_dim = embed_dim
        # Resolve string/dict identifiers to objects (as built-in Keras layers
        # do) so that get_config can serialize them.
        self.kernel_initializer = keras.initializers.get(kernel_initializer)
        self.embedding_initializer = keras.initializers.get(embedding_initializer)
        self.bias_initializer = keras.initializers.get(bias_initializer)
        self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)
        self.embedding_regularizer = keras.regularizers.get(embedding_regularizer)
        self.bias_regularizer = keras.regularizers.get(bias_regularizer)

    def build(self, input_shape):
        input_dim = input_shape[-1]
        if input_dim is None:
            raise ValueError(
                "FactorizationMachine requires a known last input dimension, "
                f"got input_shape={input_shape}")
        self.v = self.add_weight(
            shape=(input_dim, self.embed_dim, self.units),
            initializer=self.embedding_initializer,
            regularizer=self.embedding_regularizer,
            trainable=True,
        )
        super().build(input_shape)

    def call(self, inputs):
        # (batch, input_dim) x (input_dim, embed_dim, units) -> (batch, embed_dim, units).
        # Only 'i' is contracted, so einsum reuses each input row for every
        # latent factor; no need to tile the input embed_dim times.
        einsum_equation = 'bi,ieo->beo'

        sum_then_square = tf.math.square(tf.einsum(einsum_equation, inputs, self.v))
        square_then_sum = tf.einsum(
            einsum_equation, tf.math.square(inputs), tf.math.square(self.v))

        # Sum over the latent-factor axis -> (batch, units).
        return tf.reduce_sum(sum_then_square - square_then_sum, axis=1)

    def get_config(self):
        """Return constructor arguments so the layer can be saved and re-created."""
        config = super().get_config()
        config.update({
            'units': self.units,
            'embed_dim': self.embed_dim,
            'kernel_initializer': keras.initializers.serialize(self.kernel_initializer),
            'embedding_initializer': keras.initializers.serialize(self.embedding_initializer),
            'bias_initializer': keras.initializers.serialize(self.bias_initializer),
            'kernel_regularizer': keras.regularizers.serialize(self.kernel_regularizer),
            'embedding_regularizer': keras.regularizers.serialize(self.embedding_regularizer),
            'bias_regularizer': keras.regularizers.serialize(self.bias_regularizer),
        })
        return config


In [131]:
def get_fm_model(embed_dim=2):
    """Build a factorization-machine classifier over 784 inputs with 10 logits.

    The logits are the sum of:
      * a FactorizationMachine interaction term (10 units, L2(1e-4) on its factors),
      * a linear Dense(10) term with bias.

    Args:
        embed_dim: number of latent factors per input feature.

    Returns:
        An uncompiled ``keras.Model`` named "fm_mnist_model".
    """
    pixels = keras.Input(shape=(784,))
    factor_penalty = keras.regularizers.L2(l2=0.0001)
    pairwise_logits = FactorizationMachine(
        units=10, embed_dim=embed_dim, embedding_regularizer=factor_penalty)(pixels)
    linear_logits = layers.Dense(10, activation=None, use_bias=True)(pixels)
    logits = layers.Add()([pairwise_logits, linear_logits])
    return keras.Model(inputs=pixels, outputs=logits, name="fm_mnist_model")

fm_model = get_fm_model(2)
fm_model.summary()

Model: "fm_mnist_model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_112 (InputLayer)          [(None, 784)]        0                                            
__________________________________________________________________________________________________
factorization_machine_57 (Facto (None, 10)           15680       input_112[0][0]                  
__________________________________________________________________________________________________
dense_192 (Dense)               (None, 10)           7850        input_112[0][0]                  
__________________________________________________________________________________________________
add_51 (Add)                    (None, 10)           0           factorization_machine_57[0][0]   
                                                                 dense_192[0][0]     

In [136]:
# Experiment configuration. Pick the architecture by key instead of
# (un)commenting lines; the builders are defined in the cells above.
MODEL_BUILDERS = {
    "fm": lambda: get_fm_model(4),
    "bilinear": lambda: get_bilinear_model(128),
    "dense": get_dense_model,
    "linear": get_linear_model,
}
MODEL_KIND = "fm"
SEED = 42
BATCH_SIZE = 64
EPOCHS = 1

# Seed NumPy and TensorFlow before building the model so that weight
# initialization and shuffling are repeatable across runs.
np.random.seed(SEED)
tf.random.set_seed(SEED)

model = MODEL_BUILDERS[MODEL_KIND]()

(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Flatten each image into a 784-vector (derived from the data rather than a
# hardcoded sample count) and scale uint8 pixels to [0, 1].
x_train = x_train.reshape(len(x_train), -1).astype("float32") / 255
x_test = x_test.reshape(len(x_test), -1).astype("float32") / 255

model.compile(
    # All model builders output raw logits, hence from_logits=True.
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.RMSprop(),
    metrics=["accuracy"],
)

history = model.fit(x_train, y_train, batch_size=BATCH_SIZE, epochs=EPOCHS, validation_split=0.2)

test_scores = model.evaluate(x_test, y_test, verbose=2)
print("Test loss:", test_scores[0])
print("Test accuracy:", test_scores[1])


313/313 - 0s - loss: 0.2192 - accuracy: 0.9403
Test loss: 0.2191684991121292
Test accuracy: 0.9402999877929688
