In [2]:
import tensorflow as tf
import time
from sklearn.metrics import roc_auc_score
import pandas as pd
from get_data import get_data

In [6]:
x = tf.placeholder(tf.float32,shape=[None,108])
y = tf.placeholder(tf.float32,shape=[None])

**MLR算法**

$$p(y=1|x)=g(\sum_{j=1}^m \sigma(u_j^T x) \mu(w_j^T x))$$
常用softmax函数作为分片函数：
$$p(y=1|x)=\sum_{i=1}^m \frac{exp(u_i^Tx)}{\sum_{j=1}^m exp(u_j^Tx)} \cdot \frac{1}{1 + exp(-w_i^Tx)}$$

**FtrlOptimizer (Follow The Regularized Leader)**
FtrlOptimizer更新方法主要用于广告点击预测,广告点击预测通常千万级别的维度,因此有巨量的稀疏权重.其主要特点是将接近0 的权重直接置0,这样计算时可以直接跳过,从而简化计算.这个方法已经验证过在股票数据上较有效。

In [4]:
m = 2
learning_rate = 0.3
u = tf.Variable(tf.random_normal([108,m],0.0,0.5),name='u')
w = tf.Variable(tf.random_normal([108,m],0.0,0.5),name='w')

In [7]:
U = tf.matmul(x,u)
p1 = tf.nn.softmax(U)

In [8]:
W = tf.matmul(x,w)
p2 = tf.nn.sigmoid(W)

In [9]:
pred = tf.reduce_sum(tf.multiply(p1,p2),1)

In [10]:
cost1=tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=pred, labels=y))
cost=tf.add_n([cost1])
train_op = tf.train.FtrlOptimizer(learning_rate).minimize(cost)

Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


In [11]:
train_x,train_y,test_x,test_y = get_data()

In [None]:
time_s=time.time()
result = []
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(0, 10000):
        f_dict = {x: train_x, y: train_y}

        _, cost_, predict_ = sess.run([train_op, cost, pred], feed_dict=f_dict)

        auc = roc_auc_score(train_y, predict_)
        time_t = time.time()
        if epoch % 100 == 0:
            f_dict = {x: test_x, y: test_y}
            _, cost_, predict_test = sess.run([train_op, cost, pred], feed_dict=f_dict)
            test_auc = roc_auc_score(test_y, predict_test)
            print("%d %ld cost:%f,train_auc:%f,test_auc:%f" % (epoch, (time_t - time_s), cost_, auc, test_auc))
            result.append([epoch,(time_t - time_s),auc,test_auc])

pd.DataFrame(result,columns=['epoch','time','train_auc','test_auc']).to_csv("mlr_"+str(m)+'.csv')

0 1 cost:0.844245,train_auc:0.546987,test_auc:0.758541
100 22 cost:0.694244,train_auc:0.826318,test_auc:0.824048
200 44 cost:0.688313,train_auc:0.847975,test_auc:0.845930
300 67 cost:0.684391,train_auc:0.861430,test_auc:0.859873
400 88 cost:0.681675,train_auc:0.869072,test_auc:0.868000
500 112 cost:0.679850,train_auc:0.873585,test_auc:0.872885
600 144 cost:0.678534,train_auc:0.876737,test_auc:0.876306
700 174 cost:0.677526,train_auc:0.879168,test_auc:0.878936
800 207 cost:0.676715,train_auc:0.881167,test_auc:0.881084
900 233 cost:0.676040,train_auc:0.882862,test_auc:0.882867
1000 254 cost:0.675464,train_auc:0.884308,test_auc:0.884401
1100 274 cost:0.674962,train_auc:0.885587,test_auc:0.885733
1200 295 cost:0.674517,train_auc:0.886717,test_auc:0.886907
1300 315 cost:0.674117,train_auc:0.887724,test_auc:0.887965
1400 335 cost:0.673756,train_auc:0.888624,test_auc:0.888899
1500 355 cost:0.673427,train_auc:0.889443,test_auc:0.889743
1600 375 cost:0.673126,train_auc:0.890180,test_auc:0.89048