Add compile method to TFBaseModel
shenweichen committed Oct 9, 2017
1 parent e704c5a commit 977ec7c
Showing 5 changed files with 115 additions and 76 deletions.
11 changes: 8 additions & 3 deletions README.md
@@ -6,6 +6,7 @@
- sklearn
## Design Notes

The `base` class mimics the `keras` model API and implements the following public methods:

- compile: compile the model
- save_model: save the model
- load_model: load a saved model
- train_on_batch: train on a mini-batch
@@ -14,19 +15,23 @@
- evaluate: evaluate the model
- predict_on_batch: predict on a mini-batch
- predict: predict on the full dataset

Private methods include:
- _create_optimizer
- _create_metrics

It also defines several abstract methods:
- _get_input_data
- _get_input_target
- _get_output_target
- _get_data_loss
- _get_optimizer
- _get_optimizer_loss
- _build_graph

Subclasses must call `self._build_graph()` at the end of their `__init__` method to build the computation graph; a minimal sketch follows.
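For illustration, here is a minimal, hypothetical subclass (a plain logistic regression). The class name `LogisticRegression` and its `feature_dim` argument are invented for this example, and the sketch assumes the base `__init__` provides `self.graph`, `self.sess` and `self.seed` as in `base.py`:

```python
import tensorflow as tf
from base import TFBaseModel  # adjust the import to your package layout


class LogisticRegression(TFBaseModel):
    def __init__(self, feature_dim, seed=1024, checkpoint_path=None):
        super(LogisticRegression, self).__init__(seed=seed, checkpoint_path=checkpoint_path)
        self.feature_dim = feature_dim
        self._build_graph()  # must be the last statement of __init__

    def _build_graph(self):
        with self.graph.as_default():
            tf.set_random_seed(self.seed)
            self.X = tf.placeholder(tf.float32, [None, self.feature_dim], name='X')
            self.Y = tf.placeholder(tf.float32, [None], name='Y')
            w = tf.get_variable('w', [self.feature_dim, 1])
            self.logit = tf.squeeze(tf.matmul(self.X, w), axis=1)
            self.loss = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(
                labels=self.Y, logits=self.logit))

    def _get_input_data(self):
        return self.X

    def _get_input_target(self):
        return self.Y

    def _get_output_target(self):
        return tf.sigmoid(self.logit)

    def _get_data_loss(self):
        return self.loss

    def _get_optimizer(self):
        return self.optimizer  # the training op created by compile()

    def _get_optimizer_loss(self):
        return self.loss
```

After construction, `model = LogisticRegression(8)` followed by `model.compile('adam')` wires up the optimizer and initializes the variables.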

## Plans

- Add `tf.summary.FileWriter` support
- Add custom metric functions
- Add weighted loss functions

## DeepFM
>DeepFM: A Factorization-Machine based Neural Network for CTR Prediction [arxiv](https://arxiv.org/abs/1703.04247)
98 changes: 82 additions & 16 deletions base.py
@@ -6,10 +6,13 @@
from sklearn.utils import shuffle as sklearn_shuffle
from sklearn.model_selection import train_test_split

from .utils import sigmoid_cross_entropy_with_probs


class TFBaseModel(metaclass=ABCMeta):
def __init__(self, seed=1024, checkpoint_path=None):
self.seed = seed

if checkpoint_path and checkpoint_path.count('/') < 2:
raise ValueError('checkpoint_path must be dir/model_name format')
self.checkpoint_path = checkpoint_path
@@ -21,39 +24,103 @@ def __init__(self, seed=1024, checkpoint_path=None):
self.saver = tf.train.Saver()
self.sess = tf.Session(graph=self.graph)


@abstractmethod
def _get_input_data(self, ):
raise NotImplementedError

@abstractmethod
def _get_input_target(self, ):
raise NotImplementedError

@abstractmethod
def _get_output_target(self, ):
raise NotImplementedError

@abstractmethod
def _get_data_loss(self, ):
raise NotImplementedError

@abstractmethod
def _get_optimizer(self):
raise NotImplementedError

@abstractmethod
def _get_optimizer_loss(self,):
"""
Return the loss tensor that the optimizer should minimize.
:return:
"""
raise NotImplementedError
@abstractmethod
def _build_graph(self):
"""
Must be called at the end of the subclass's __init__ method.
The subclass builds its computation graph in the default graph:
with self.graph.as_default(): # , tf.device('/cpu:0'):
tf.set_random_seed(self.seed)
# build the computation graph
# ...
# ...
# finally run initialization
init = tf.global_variables_initializer() # init
self.sess.run(init) # sess defined in scope
"""
raise NotImplementedError

def compile(self, optimizer='sgd', loss='logloss', metrics=None, loss_weights=None, sample_weight_mode=None):
"""
Compile the model with an optimizer and a loss function.
:param optimizer: str naming a predefined tensorflow optimizer, or a tf.train.Optimizer instance
['sgd','adam','adagrad','rmsprop','moment','ftrl']
:param loss: str
:param metrics: str, one of ['logloss','mse','mean_squared_error','logloss_with_logits']
:param loss_weights:
:param sample_weight_mode:
:return:
"""
with self.graph.as_default(): # , tf.device('/cpu:0'):
# set up training according to the given optimizer and loss function
self.metric_list = self._create_metrics(metrics)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) # needed for BN
with tf.control_dependencies(update_ops): # needed for BN
# minimize the loss tensor that the subclass exposes via _get_optimizer_loss()
self.optimizer = self._create_optimizer(optimizer).minimize(self._get_optimizer_loss(), global_step=self.global_step)
# run variable initialization
init = tf.global_variables_initializer() # init
self.sess.run(init) # sess defined in scope

def _create_metrics(self, metric):
if metric is None: # if no metric is given, fall back to the training loss
return [self._get_optimizer_loss()]

if metric not in ['logloss', 'mse', 'mean_squared_error', 'logloss_with_logits']:
raise ValueError('invalid param metrics')
# TODO: support more metrics and accept callables as arguments
metrics_list = []

if metric == 'logloss':
metrics_list.append(tf.reduce_sum(sigmoid_cross_entropy_with_probs(
labels=self._get_input_target(), probs=self._get_output_target())))
elif metric == 'mse' or metric == 'mean_squared_error':
metrics_list.append(tf.reduce_sum(tf.squared_difference(self._get_input_target(), self._get_output_target())))
elif metric == 'logloss_with_logits':
metrics_list.append(tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(labels=self._get_input_target(), logits=self.logit)))
return metrics_list

def _create_optimizer(self, optimizer='sgd'):
"""
:param optimizer: name of a predefined optimizer or a tf.train.Optimizer instance
:return: optimizer object
"""
optimizer_dict = {'sgd': tf.train.GradientDescentOptimizer(0.01),
'adam': tf.train.AdamOptimizer(0.001),
'adagrad': tf.train.AdagradOptimizer(0.01),
# 'adagradda': tf.train.AdagradDAOptimizer(),
'rmsprop': tf.train.RMSPropOptimizer(0.001),
'moment': tf.train.MomentumOptimizer(0.01, 0.9),
'ftrl': tf.train.FtrlOptimizer(0.01)
# tf.train.ProximalAdagradOptimizer # padagrad
# tf.train.ProximalGradientDescentOptimizer # pgd
}
if isinstance(optimizer, str):
if optimizer in optimizer_dict:
return optimizer_dict[optimizer]
else:
raise ValueError('invalid optimizer name')
elif isinstance(optimizer, tf.train.Optimizer):
return optimizer
else:
raise ValueError('invalid param for optimizer')



def save_model(self, save_path):
self.saver.save(self.sess, save_path + '.ckpt', self.global_step)
@@ -73,8 +140,7 @@ def load_model(self, meta_graph_path, ckpt_dir=None, ckpt_path=None):
def train_on_batch(self, x, y): # fit one batch
feed_dict_ = {self._get_input_data(): x,
self._get_input_target(): y, self.train_flag: True}
self.sess.run((self._get_optimizer_loss(), self.optimizer), feed_dict=feed_dict_)

def fit(self,x, y, batch_size=1024, epochs=50, validation_split = 0.0, validation_data=None,
val_size=2 ** 18, shuffle=True,initial_epoch=0,min_display=50,max_iter=-1):
@@ -106,7 +172,7 @@ def fit(self,x, y, batch_size=1024, epochs=50, validation_split = 0.0, validatio
batch_x = x[j * batch_size:(j + 1) * batch_size]
batch_y = y[j * batch_size:(j + 1) * batch_size]

self.train_on_batch(batch_x, batch_y)
if j % min_display == 0:
tr_loss = self.evaluate(x, y, val_size)
self.tr_loss_list.append(tr_loss)
@@ -140,7 +206,7 @@ def test_on_batch(self,x,y, ):
"""
feed_dict_ = {self._get_input_data(): x,
self._get_input_target(): y, self.train_flag: False}
loss = self.sess.run(self.metric_list, feed_dict=feed_dict_)
return loss[0]

def evaluate(self, x,y, val_size=2 ** 18):
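For context, a hedged usage sketch of the new `compile` API. It uses `DeepCrossNetwork` from this commit with the same arguments as the module's own test; the `AdamOptimizer` learning rate is an arbitrary illustration, and the metric names come from `_create_metrics`:

```python
import tensorflow as tf
from deep_cross_network import DeepCrossNetwork  # adjust to your package layout

model = DeepCrossNetwork(2, 3)  # field_dim=2, feature_dim=3, as in the module's own test

# A string selects one of the predefined optimizers; metrics defaults to the training loss.
model.compile('adam', metrics='logloss')

# A preconfigured tf.train.Optimizer instance also works, via the isinstance branch:
# model.compile(tf.train.AdamOptimizer(learning_rate=0.0005), metrics='mse')
```

Calling `compile` twice on the same model would add a second `minimize` op to the graph and re-run initialization, so it is meant to be called once before `fit`.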
41 changes: 11 additions & 30 deletions deep_cross_network.py
@@ -1,15 +1,17 @@
import tensorflow as tf
from .base import TFBaseModel

class DeepCrossNetwork(TFBaseModel):
def __init__(self, field_dim, feature_dim, embedding_size=4,
cross_layer_num=1, hidden_size=[], use_batchnorm=True, deep_l2_reg=0.0,
init_std=0.01, seed=1024, keep_prob=0.5,
checkpoint_path=None):
super(DeepCrossNetwork, self).__init__(
seed=seed, checkpoint_path=checkpoint_path)

self.field_dim = field_dim
self.feature_dim = feature_dim
self.embedding_size = embedding_size

self.deep_l2_reg = deep_l2_reg
self.init_std = init_std
@@ -21,24 +23,18 @@ def __init__(self, field_dim, feature_dim,embedding_size=4,

#self.feature_list = feature_list
#self.feature_count = feature_count

self.use_batchnorm = use_batchnorm

self._build_graph()

def _get_data_loss(self):
return self.log_loss

def _get_optimizer_loss(self):
return self.loss

def _get_input_data(self, ):
return self.X

def _get_input_target(self, ):
return self.Y

def _get_optimizer(self):
return self.optimizer

def _get_output_target(self, ):
return tf.sigmoid(self.logit)

@@ -51,11 +47,7 @@ def _build_graph(self,):
self._create_variable()
self._forward_pass()
self._create_loss()

def _create_placeholders(self, ):

@@ -164,18 +156,7 @@ def _create_loss(self, ):
#test_writer = tf.summary.FileWriter('../check/DCN/test')
#https://www.tensorflow.org/get_started/summaries_and_tensorboard
self.loss = self.log_loss # + l2_reg_w_loss

if __name__ == '__main__':
model = DeepCrossNetwork(2, 3)
model.compile('adam')
print('DeepCrossNetwork test pass')
33 changes: 6 additions & 27 deletions deepfm.py
@@ -5,36 +5,28 @@

class DeepFM(TFBaseModel):
def __init__(self, field_dim, feature_dim, embedding_size=4,
use_cross=True, hidden_size=[], l2_reg_w=1.0, l2_reg_V=1.0,
init_std=0.01, seed=1024, hidden_unit=100, keep_prob=0.5,
checkpoint_path=None, ):

super(DeepFM, self).__init__(
seed=seed, checkpoint_path=checkpoint_path)
seed=seed, checkpoint_path=checkpoint_path)
self.params = locals()
self._build_graph()

def _get_output_target(self, ):
return tf.sigmoid(self.logit)

def _get_data_loss(self, ):
return self.log_loss

def _get_optimizer(self):
return self.optimizer

def _get_optimizer_loss(self):
return self.loss

def _build_graph(self, ):
with self.graph.as_default(): # , tf.device('/cpu:0'):
tf.set_random_seed(self.seed)
self._create_placeholders()

self._create_variable()
self._forward_pass()
self._create_loss()

def _get_input_data(self, ):
return self.placeholders['X']
@@ -127,20 +119,7 @@ def _create_loss(self, ):
# self.loss += l2_reg_V_loss
pass


if __name__ == '__main__':
model = DeepFM(2, 3)
model.compile('ftrl')
print('DeepFM test pass')
8 changes: 8 additions & 0 deletions utils.py
@@ -0,0 +1,8 @@
import tensorflow as tf


def sigmoid_cross_entropy_with_probs(labels=None, probs=None, name=None):
try:
labels.get_shape().merge_with(probs.get_shape())
except ValueError:
raise ValueError("probs and labels must have the same shape (%s vs %s)" %
(probs.get_shape(), labels.get_shape()))
return -tf.reduce_sum(labels * tf.log(probs) + (1 - labels) * tf.log(1 - probs), name=name)
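As a quick sanity check (a sketch assuming TensorFlow 1.x; the sample values are invented): since -[y·log σ(z) + (1-y)·log(1-σ(z))] is exactly the sigmoid cross entropy, feeding `probs = tf.sigmoid(logits)` should reproduce the summed output of `tf.nn.sigmoid_cross_entropy_with_logits`:

```python
import numpy as np
import tensorflow as tf

from utils import sigmoid_cross_entropy_with_probs  # adjust to your package layout

logits = tf.constant(np.array([0.3, -1.2, 2.0], dtype=np.float32))
labels = tf.constant(np.array([1.0, 0.0, 1.0], dtype=np.float32))

loss_from_probs = sigmoid_cross_entropy_with_probs(labels=labels, probs=tf.sigmoid(logits))
loss_from_logits = tf.reduce_sum(
    tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits))

with tf.Session() as sess:
    a, b = sess.run([loss_from_probs, loss_from_logits])
    print(a, b)  # the two values agree up to floating-point error
```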
