In [1]:
import numpy as np
import tensorflow as tf

  from ._conv import register_converters as _register_converters


In [2]:
# Example dummy data from Rendle 2010 
# http://www.csie.ntu.edu.tw/~b97053/paper/Rendle2010FM.pdf
# Stolen from https://github.com/coreylynch/pyFM
# Categorical variables (Users, Movies, Last Rated) have been one-hot-encoded 

x_data = np.matrix([
#    Users  |     Movies     |    Movie Ratings   | Time | Last Movies Rated
#   A  B  C | TI  NH  SW  ST | TI   NH   SW   ST  |      | TI  NH  SW  ST
    [1, 0, 0,  1,  0,  0,  0,   0.3, 0.3, 0.3, 0,     13,   0,  0,  0,  0 ],
    [1, 0, 0,  0,  1,  0,  0,   0.3, 0.3, 0.3, 0,     14,   1,  0,  0,  0 ],
    [1, 0, 0,  0,  0,  1,  0,   0.3, 0.3, 0.3, 0,     16,   0,  1,  0,  0 ],
    [0, 1, 0,  0,  0,  1,  0,   0,   0,   0.5, 0.5,   5,    0,  0,  0,  0 ],
    [0, 1, 0,  0,  0,  0,  1,   0,   0,   0.5, 0.5,   8,    0,  0,  1,  0 ],
    [0, 0, 1,  1,  0,  0,  0,   0.5, 0,   0.5, 0,     9,    0,  0,  0,  0 ],
    [0, 0, 1,  0,  0,  1,  0,   0.5, 0,   0.5, 0,     12,   1,  0,  0,  0 ]
])
# ratings
y_data = np.array([5, 3, 1, 4, 5, 1, 5])

# Let's add an axis to make tensoflow happy.
y_data.shape += (1, )

In [3]:
print(x_data.shape)
print(y_data.shape)

(7, 16)
(7, 1)


In [4]:
n, p = x_data.shape

# number of latent factors
k = 5

X = tf.placeholder('float', shape=[n, p])  # design matrix
y = tf.placeholder('float', shape=[n, 1])  # target vector

# bias and weights
w0 = tf.Variable(tf.zeros([1]))
W = tf.Variable(tf.zeros([p]))

# interaction factors, randomly initialized 
V = tf.Variable(tf.random_normal([k, p], stddev=0.01))

# estimate of y, initialized to 0.
y_pred = tf.Variable(tf.zeros([n, 1]))

print('X_shape: {}'.format(X.shape))
print('V_shape: {}'.format(V.shape))

X_shape: (7, 16)
V_shape: (5, 16)


FM算法中预测 $\hat{y}$ 的公式:

$\hat{y} = w_0 + \sum_{i}{w_i x_i} + \sum_{i}{\sum_{i<j}{w_{ij} x_i x_j}}$

In [5]:
linear_terms = tf.add(w0, tf.reduce_sum(tf.multiply(W, X), 1, keepdims=True))

In [6]:
# 变量间交互项
interactions = (tf.multiply(0.5,
                    tf.reduce_sum(tf.subtract(tf.pow(tf.matmul(X, tf.transpose(V)), 2),
                        tf.matmul(tf.pow(X, 2), tf.transpose(tf.pow(V, 2)))), 1, keepdims=True)))

In [7]:
y_pred = tf.add(linear_terms, interactions)

In [8]:
# L2 regularized sum of squares loss function over W and V
lambda_w = tf.constant(0.001, name='lambda_w')
lambda_v = tf.constant(0.001, name='lambda_v')

l2_norm = (tf.reduce_sum(
            tf.add(
                tf.multiply(lambda_w, tf.pow(W, 2)),
                tf.multiply(lambda_v, tf.pow(V, 2)))))

mse = tf.reduce_mean(tf.square(tf.subtract(y, y_pred)))
loss = tf.add(mse, l2_norm)

In [9]:
eta = tf.constant(0.1)
optimizer = tf.train.AdagradOptimizer(eta).minimize(loss)

In [10]:
N_EPOCHS = 1000
init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)

    for epoch in range(N_EPOCHS):
        indices = np.arange(n)
        np.random.shuffle(indices)
        x_data, y_data = x_data[indices], y_data[indices]
        sess.run(optimizer, feed_dict={X: x_data, y: y_data})

    print('MSE: ', sess.run(mse, feed_dict={X: x_data, y: y_data}))
    print('Loss (regularized error): ', sess.run(loss, feed_dict={X: x_data, y: y_data}))
    print('Predictions:\n', sess.run(y_pred, feed_dict={X: x_data, y: y_data}))
    print('Learnt weights:\n', sess.run(W, feed_dict={X: x_data, y: y_data}))
    print('Learnt factors:\n', sess.run(V, feed_dict={X: x_data, y: y_data}))

MSE:  3.4389262e-07
Loss (regularized error):  0.0034227965
Predictions:
 [[3.9990382]
 [4.999921 ]
 [3.00038  ]
 [4.9995203]
 [1.0006304]
 [4.9999886]
 [1.000839 ]]
Learnt weights:
 [ 0.13565649  0.19769716 -0.04497153  0.01938483 -0.03698042  0.2189549
  0.10642584  0.03061904  0.07590548  0.11230227  0.15114953  0.13685484
  0.19490509 -0.17412204  0.10642584  0.        ]
Learnt factors:
 [[ 8.66873190e-02  2.45530605e-01 -2.12830439e-01  1.56533360e-01
  -3.02045465e-01  2.24947050e-01  6.79889023e-02 -4.69868518e-02
   4.50542383e-02  4.89947274e-02  2.16567487e-01  3.04345638e-01
   2.56425649e-01 -4.21109259e-01  8.10839683e-02 -1.56625436e-04]
 [-1.13541625e-01 -8.32601264e-03  1.60201281e-01 -1.19042188e-01
   1.94987461e-01 -1.17759794e-01  4.23520207e-02  5.08476701e-03
  -9.11881998e-02  7.27587321e-04 -1.50686437e-02 -7.85334632e-02
  -8.81031901e-02  2.79943675e-01  3.44421566e-02  3.19958222e-03]
 [-6.06829077e-02 -1.03675030e-01  1.07081302e-01 -2.55021770e-02
   2.4960