-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtrain_cv.py
103 lines (88 loc) · 3.87 KB
/
train_cv.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
# -*- coding: utf-8 -*-
# @Time : 2018/7/26 23:08
# @Author : Xiaoyu Liu
# @Email : liuxiaoyu16@fudan.edu.
from __future__ import unicode_literals
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
import utils
import models
import random
# Command-line flags: training hyper-parameters and model/checkpoint selection.
tf.flags.DEFINE_integer('batch_size',128,'batch_size')  # minibatch size for both train and dev passes
tf.flags.DEFINE_integer("max_epoch",100,'max_epoch')  # maximum number of training epochs per fold
tf.flags.DEFINE_string("model_name","HybridCNN","model_name")  # class name looked up on the `models` module
tf.flags.DEFINE_string("model_file","HybridCNN","model_file")  # checkpoint base name (fold index is appended when saving)
tf.flags.DEFINE_integer('restore',0,'restore')  # 1 = warm-start from a saved checkpoint instead of fresh init
FLAGS=tf.flags.FLAGS
def train():
source_data,target_data,test_data,word2id=utils.load_data()
embeddings=utils.load_embeddings(word2id)
random.seed(1)
random.shuffle(target_data)
cv_losses=[]
for k in range(1,11):
train_data, dev_data = utils.train_dev_split(target_data, k)
model_file=FLAGS.model_file+str(k)
print model_file
print "训练集1数据大小:%d" % len(source_data)
print "训练集2数据大小:%d" % len(train_data)
print "验证集数据大小:%d" % len(dev_data)
print "embedding大小:(%d,%d)"%(embeddings.shape[0],embeddings.shape[1])
model_dir='../model'
graph=tf.Graph()
sess=tf.Session(graph=graph)
with graph.as_default():
model=getattr(models,FLAGS.model_name)(embeddings)
saver = tf.train.Saver(tf.global_variables())
if FLAGS.restore==1:
saver.restore(sess, os.path.join(model_dir, FLAGS.model_file))
print "Restore from pre-trained model"
else:
sess.run(tf.global_variables_initializer())
print "Train start!"
best_loss=1e6
best_epoch=0
not_improved=0
for epoch in range(FLAGS.max_epoch):
print epoch,"================================================"
train_loss=[]
ground_trues=[]
predicts=[]
for batch_data in utils.minibatches(train_data,FLAGS.batch_size,mode='train'):
loss,predict=model.train_step(sess,batch_data)
train_loss.extend(loss)
predicts.extend(predict)
ground_trues.extend(batch_data[2])
train_loss=utils.loss(ground_trues,train_loss)
p,r,f1=utils.score(ground_trues,predicts)
print "%d-fold Train epoch %d finished. loss:%.4f p:%.4f r:%.4f f1:%.4f" % (k,epoch,train_loss,p,r,f1)
valid_loss=[]
ground_trues=[]
predicts=[]
for batch_data in utils.minibatches(dev_data,FLAGS.batch_size,mode='dev'):
loss,predict= model.valid_step(sess, batch_data)
valid_loss.extend(loss)
predicts.extend(predict)
ground_trues.extend(batch_data[2])
valid_loss=utils.loss(ground_trues,valid_loss)
p, r, f1=utils.score(ground_trues, predicts)
print "%d-fold,Valid epoch %d finished. loss:%.4f p:%.4f r:%.4f f1:%.4f" % (k,epoch,valid_loss,p,r,f1)
if valid_loss<best_loss:
best_loss=valid_loss
best_epoch=epoch
not_improved=0
print "save model!"
saver.save(sess, os.path.join(model_dir, model_file))
else:
not_improved+=1
if not_improved>4:
print "停止训练!"
break
print
print "Best epoch %d best loss %.4f" % (best_epoch,best_loss)
print "#########################################################"
cv_losses.append(best_loss)
print "final cv loss: %.4f" % (sum(cv_losses) / len(cv_losses))
# Script entry point: run the full 10-fold cross-validation training.
if __name__ == '__main__':
    train()