In [1]:
import tensorflow as tf
import os
os.chdir("/Users/sweaterr/PycharmProjects/TF-recomm")
import dataio
import numpy as np
from collections import deque
from six import next

np.random.seed(13575)

BATCH_SIZE = 1000
USER_NUM = 6040
ITEM_NUM = 3952
DIM = 15
EPOCH_MAX = 100
DEVICE = "/cpu:0"

import time
def get_data():
    df = dataio.read_process("/tmp/movielens/ml-1m/ratings.dat", sep="::")
    rows = len(df)
    df = df.iloc[np.random.permutation(rows)].reset_index(drop=True)
    split_index = int(rows * 0.9)
    df_train = df[0:split_index]
    df_test = df[split_index:].reset_index(drop=True)
    return df_train, df_test

In [2]:
df_train, df_test = get_data()

In [4]:
print(df_train.first)

<bound method DataFrame.first of         user  item  rate          st
0       1893  1692   4.0   974695176
1       5947  2312   4.0   957190990
2        162   365   2.0   977323187
3       5117   456   3.0   962294766
4       2029   315   1.0   974929369
5       2220  1844   5.0   974603135
6       4385  3385   3.0   965172804
7       5779   282   4.0   958156569
8       3617  3740   2.0   966600773
9       3640   584   4.0   966482594
10       515   110   4.0   976205508
11      5138   160   3.0   962060976
12      3032  1290   5.0   970291018
13       521  1199   3.0   976196943
14      3409  1273   5.0   967416389
15      3733  2430   5.0   966194170
16        76  2540   2.0   977813753
17      1299  1320   5.0   974786901
18      5874  3750   4.0   965274403
19      5538   925   5.0   986573601
20      3361   140   4.0   967672860
21      3649   378   4.0   966460630
22      1828  1196   3.0   974696861
23      4276  2659   5.0   983696191
24      5538  2301   4.0  1027814481
25   

In [5]:
train = df_train
test = df_test
samples_per_batch = len(train) // BATCH_SIZE

iter_train = dataio.ShuffleIterator([train["user"],
                                     train["item"],
                                     train["rate"]],
                                    batch_size=BATCH_SIZE)

iter_test = dataio.OneEpochIterator([test["user"],
                                     test["item"],
                                     test["rate"]],
                                    batch_size=-1)

user_batch = tf.placeholder(tf.int32, shape=[None], name="id_user")
item_batch = tf.placeholder(tf.int32, shape=[None], name="id_item")
rate_batch = tf.placeholder(tf.float32, shape=[None])

In [6]:
def inference_svd(user_batch, item_batch, user_num, item_num, dim=5, device="/cpu:0"):
    with tf.device("/cpu:0"):
        bias_global = tf.get_variable("bias_global", shape=[])
        w_bias_user = tf.get_variable("embd_bias_user", shape=[user_num])
        w_bias_item = tf.get_variable("embd_bias_item", shape=[item_num])
        bias_user = tf.nn.embedding_lookup(w_bias_user, user_batch, name="bias_user")
        bias_item = tf.nn.embedding_lookup(w_bias_item, item_batch, name="bias_item")
        w_user = tf.get_variable("embd_user", shape=[user_num, dim],
                                 initializer=tf.truncated_normal_initializer(stddev=0.02))
        w_item = tf.get_variable("embd_item", shape=[item_num, dim],
                                 initializer=tf.truncated_normal_initializer(stddev=0.02))
        embd_user = tf.nn.embedding_lookup(w_user, user_batch, name="embedding_user")
        embd_item = tf.nn.embedding_lookup(w_item, item_batch, name="embedding_item")
    with tf.device(device):
        infer = tf.reduce_sum(tf.mul(embd_user, embd_item), 1)
        infer = tf.add(infer, bias_global)
        infer = tf.add(infer, bias_user)
        infer = tf.add(infer, bias_item, name="svd_inference")
        regularizer = tf.add(tf.nn.l2_loss(embd_user), tf.nn.l2_loss(embd_item), name="svd_regularizer")
    return infer, regularizer


def optimiaztion(infer, regularizer, rate_batch, learning_rate=0.001, reg=0.1, device="/cpu:0"):
    with tf.device(device):
        cost_l2 = tf.nn.l2_loss(tf.sub(infer, rate_batch))
        panelty = tf.constant(reg, dtype=tf.float32, shape=[], name="l2")
        cost = tf.add(cost_l2, tf.mul(regularizer, panelty))
        train_op = tf.train.AdamOptimizer(learning_rate).minimize(cost)
    return cost, train_op

In [7]:
infer, regularizer = inference_svd(user_batch, item_batch, user_num=USER_NUM, item_num=ITEM_NUM, dim=DIM,device=DEVICE)

In [10]:
_, train_op = optimiaztion(infer, regularizer, rate_batch, learning_rate=0.001, reg=0.05, device=DEVICE)

In [14]:
init_op = tf.initialize_all_variables()

In [16]:
def clip(x):
    return np.clip(x, 1.0, 5.0)

with tf.Session() as sess:
    sess.run(init_op)
    print("{} {} {} {}".format("epoch", "train_error", "val_error", "elapsed_time"))
    errors = deque(maxlen=samples_per_batch)
    start = time.time()
    for i in range(EPOCH_MAX * samples_per_batch):
        users, items, rates = next(iter_train)
        _, pred_batch = sess.run([train_op, infer], feed_dict={user_batch: users,
                                                               item_batch: items,
                                                               rate_batch: rates})
        pred_batch = clip(pred_batch)
        errors.append(np.power(pred_batch - rates, 2))
        if i % samples_per_batch == 0:
            train_err = np.sqrt(np.mean(errors))
            test_err2 = np.array([])
            for users, items, rates in iter_test:
                pred_batch = sess.run(infer, feed_dict={user_batch: users,
                                                        item_batch: items})
                pred_batch = clip(pred_batch)
                test_err2 = np.append(test_err2, np.power(pred_batch - rates, 2))
            end = time.time()
            print("{:3d} {:f} {:f} {:f}(s)".format(i // samples_per_batch, train_err, np.sqrt(np.mean(test_err2)),
                                                   end - start))
            start = end

epoch train_error val_error elapsed_time
  0 2.750929 2.808489 0.042135(s)


  1 2.392757 1.593639 2.804247(s)


  2 1.295163 1.120907 3.156988(s)


  3 1.043279 1.003782 2.875520(s)


  4 0.973007 0.966610 2.865396(s)


  5 0.949090 0.951925 2.832754(s)


  6 0.938572 0.944010 2.790420(s)


  7 0.932506 0.939473 2.912340(s)


  8 0.926861 0.935612 2.941052(s)


  9 0.924736 0.933764 2.816258(s)


 10 0.921685 0.930834 2.858970(s)


 11 0.919263 0.927844 2.819722(s)


 12 0.914427 0.924231 2.851686(s)


 13 0.909600 0.920653 2.851846(s)


 14 0.906170 0.915703 3.031541(s)


 15 0.899097 0.911714 3.325228(s)


 16 0.893975 0.907230 3.381106(s)


 17 0.890003 0.903642 3.385030(s)


 18 0.884446 0.900158 2.985412(s)


 19 0.880148 0.897024 3.202554(s)


 20 0.875763 0.894310 3.459687(s)


 21 0.871165 0.890725 3.005499(s)


 22 0.866267 0.887871 2.774388(s)


 23 0.862605 0.884098 2.820320(s)


 24 0.854612 0.881228 2.943208(s)


 25 0.850481 0.877738 3.120777(s)


 26 0.844719 0.874508 2.887946(s)


 27 0.840553 0.871964 3.230954(s)


 28 0.834767 0.868973 3.232065(s)


 29 0.828570 0.865896 3.373367(s)


 30 0.824304 0.863891 2.992586(s)


 31 0.819710 0.861975 2.794684(s)


 32 0.815083 0.859797 3.246930(s)


 33 0.810191 0.858379 3.087175(s)


 34 0.805083 0.856913 2.983240(s)


 35 0.801635 0.855988 2.810281(s)


 36 0.797722 0.854614 3.107162(s)


 37 0.793769 0.853767 3.186988(s)


 38 0.791471 0.852962 3.221489(s)


 39 0.786951 0.852427 3.010188(s)


 40 0.783443 0.851690 3.111211(s)


 41 0.782623 0.851063 3.112257(s)


 42 0.777292 0.850991 2.972740(s)


 43 0.776155 0.850644 3.199054(s)


 44 0.773689 0.850325 2.961411(s)


 45 0.771790 0.850546 2.732660(s)


 46 0.769457 0.850368 2.862627(s)


 47 0.768457 0.850764 2.852869(s)


 48 0.767026 0.850498 2.883381(s)


 49 0.765974 0.850021 2.988509(s)


 50 0.762757 0.849993 3.296206(s)


 51 0.763582 0.850043 3.017561(s)


 52 0.759786 0.850300 3.118379(s)


 53 0.760361 0.850594 2.984945(s)


 54 0.760191 0.850387 3.105119(s)


 55 0.757535 0.850542 2.955032(s)


 56 0.758480 0.849947 3.001484(s)


 57 0.757258 0.850322 2.898693(s)


 58 0.757459 0.850100 2.738280(s)


 59 0.754655 0.850306 2.775728(s)


 60 0.754919 0.850397 2.993973(s)


 61 0.754648 0.850198 3.279388(s)


 62 0.753412 0.850481 3.250277(s)


 63 0.755543 0.850802 2.945834(s)


 64 0.753325 0.850581 2.794458(s)


 65 0.752627 0.850752 2.874487(s)


 66 0.752809 0.850512 2.797553(s)


 67 0.753122 0.850619 2.845801(s)


 68 0.752085 0.850374 2.797071(s)


 69 0.752312 0.850558 2.891723(s)


 70 0.751104 0.850246 2.759296(s)


 71 0.750661 0.850155 2.813679(s)


 72 0.750167 0.850240 2.874317(s)


 73 0.750473 0.850259 2.819639(s)


 74 0.750394 0.850517 2.761451(s)


 75 0.749670 0.850407 2.898341(s)


 76 0.750262 0.850588 2.823478(s)


 77 0.750184 0.850718 2.891350(s)


 78 0.748674 0.850629 2.843940(s)


 79 0.749167 0.850766 2.809509(s)


 80 0.749397 0.850962 2.907408(s)


 81 0.749253 0.850777 2.953213(s)


 82 0.748503 0.851019 2.831474(s)


 83 0.749512 0.850952 3.138992(s)


 84 0.749581 0.851054 3.044079(s)


 85 0.748381 0.851004 2.800919(s)


 86 0.748768 0.850711 2.812873(s)


 87 0.747818 0.850829 2.808074(s)


 88 0.746986 0.850675 2.774953(s)


 89 0.747092 0.851005 2.889211(s)


 90 0.747351 0.850725 2.813415(s)


 91 0.747962 0.850796 2.906871(s)


 92 0.748919 0.850743 2.872600(s)


 93 0.747552 0.850851 2.834741(s)


 94 0.748066 0.850615 2.824153(s)


 95 0.746216 0.850498 2.959233(s)


 96 0.746390 0.850617 3.253878(s)


 97 0.748425 0.850655 3.144680(s)


 98 0.747552 0.850696 3.139214(s)


 99 0.746431 0.850440 2.808393(s)


