Commit 48368b2

text_cnn
1 parent 9fec142 commit 48368b2

Lines changed: 131 additions & 0 deletions
@@ -0,0 +1,131 @@
"""
A CNN model for sentence classification
source: 'https://github.com/dennybritz/cnn-text-classification-tf/blob/master/text_cnn.py'
2016/12/21
"""
import numpy as np
import tensorflow as tf

class TextCNN(object):
    """
    A CNN class for sentence classification.
    The model consists of an embedding layer, a convolutional layer, a max-pooling
    layer and a softmax output layer.
    """
    def __init__(self, seq_len, vocab_size, embedding_size, filter_sizes, num_filters,
                 num_classes=2, l2_reg_lambda=0.0):
        """
        :param seq_len: int, the sequence length (i.e. the length of the sentences;
            all sentences are zero-padded to the same length)
        :param vocab_size: int, the size of the vocabulary, used to define the embedding layer
        :param embedding_size: int, the dimensionality of the embeddings (word vectors)
        :param filter_sizes: list or tuple, the numbers of words the convolutional filters
            should cover. For example, [3, 4, 5] means filters that slide over 3, 4
            and 5 words respectively
        :param num_filters: int, the number of filters per filter size; in total the model
            has len(filter_sizes) * num_filters filters
        :param num_classes: int, the number of classes to predict in the output layer, default 2
        :param l2_reg_lambda: float, the weight of the L2 regularization loss
        """
        # Keep track of all hyperparameters
        self.seq_len = seq_len
        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.filter_sizes = filter_sizes
        self.num_filters = num_filters
        self.num_classes = num_classes
        self.l2_reg_lambda = l2_reg_lambda
        # Define the input and output placeholders
        self.x = tf.placeholder(tf.int32, shape=[None, seq_len], name="x")
        self.y = tf.placeholder(tf.float32, shape=[None, num_classes], name="y")
        # The dropout keep probability
        self.dropout_keep_prob = tf.placeholder(tf.float32, name="dropout_keep_prob")
        # Accumulator for the L2 regularization loss
        L2_loss = tf.constant(0.0)  # initial value 0.0

        # The embedding layer
        with tf.device("/cpu:0"):  # the embedding lookup is not supported on GPU
            with tf.name_scope("embedding"):
                # The embedding matrix
                self.W_embedding = tf.Variable(tf.random_uniform([vocab_size, embedding_size], -1.0, 1.0),
                                               dtype=tf.float32, name="W_embedding")
                # The embedding results
                self.embedded_chars = tf.nn.embedding_lookup(self.W_embedding, self.x)  # [None, seq_len, embedding_size]
                # Add a channel dimension so conv2d can be used
                self.embedded_chars_expanded = tf.expand_dims(self.embedded_chars, axis=-1)  # [None, seq_len, embedding_size, 1]

        # The convolution and max-pooling layers
        pooled_outputs = []
        self.Ws_conv = []
        self.bs_conv = []
        # One conv + max-pool branch per filter size
        for i, filter_size in enumerate(filter_sizes):
            with tf.name_scope("conv_maxpool_{0}".format(filter_size)):
                # Convolution layer
                filter_shape = [filter_size, embedding_size, 1, num_filters]
                # Convolution parameters
                W_conv = tf.Variable(tf.truncated_normal(filter_shape, stddev=0.1),
                                     dtype=tf.float32, name="W_conv")
                self.Ws_conv.append(W_conv)
                b_conv = tf.Variable(tf.constant(0.1, shape=[num_filters]), dtype=tf.float32,
                                     name="b_conv")
                self.bs_conv.append(b_conv)
                # Convolution result
                conv_output = tf.nn.conv2d(self.embedded_chars_expanded, W_conv, strides=[1, 1, 1, 1],
                                           padding="VALID", name="conv")  # [None, seq_len-filter_size+1, 1, num_filters]
                # Use ReLU as the activation
                conv_h = tf.nn.relu(tf.nn.bias_add(conv_output, b_conv), name="relu")
                # Max-pool over the whole feature map
                pool_output = tf.nn.max_pool(conv_h, ksize=[1, seq_len - filter_size + 1, 1, 1],
                                             strides=[1, 1, 1, 1], padding="VALID", name="max_pooling")
                pooled_outputs.append(pool_output)  # [None, 1, 1, num_filters]
        # Combine all pooled features
        num_filters_total = num_filters * len(filter_sizes)
        self.h_pool = tf.concat(pooled_outputs, axis=3)  # [None, 1, 1, num_filters_total]
        self.h_pool_flat = tf.reshape(self.h_pool, shape=[-1, num_filters_total])  # [None, num_filters_total]

        # The dropout layer
        with tf.name_scope("dropout"):
            self.h_dropout = tf.nn.dropout(self.h_pool_flat, keep_prob=self.dropout_keep_prob, name="dropout")

        # The output layer (softmax)
        with tf.name_scope("output"):
            self.W_fullyconn = tf.get_variable("W_fullyconn", shape=[num_filters_total, num_classes],
                                               initializer=tf.contrib.layers.xavier_initializer())
            self.b_fullyconn = tf.Variable(tf.constant(0.1, shape=[num_classes]), dtype=tf.float32, name="b_fullyconn")
            # Only the output weights contribute to the L2 loss
            L2_loss += tf.nn.l2_loss(self.W_fullyconn)
            self.scores = tf.nn.xw_plus_b(self.h_dropout, self.W_fullyconn, self.b_fullyconn, name="scores")
            self.preds = tf.argmax(self.scores, axis=1, name="preds")

        # The loss
        with tf.name_scope("loss"):
            losses = tf.nn.softmax_cross_entropy_with_logits(logits=self.scores, labels=self.y)
            self.loss = tf.reduce_mean(losses) + L2_loss * l2_reg_lambda

        # Accuracy
        with tf.name_scope("accuracy"):
            correct_preds = tf.equal(self.preds, tf.argmax(self.y, axis=1))
            self.accuracy = tf.reduce_mean(tf.cast(correct_preds, tf.float32))

    def save_weights(self, sess, filename, name="TextCNN"):
        """Save the named model weights to `filename`."""
        save_dicts = {name + "_W_embedding": self.W_embedding}
        for i in range(len(self.Ws_conv)):
            save_dicts.update({name + "_W_conv_" + str(i): self.Ws_conv[i],
                               name + "_b_conv_" + str(i): self.bs_conv[i]})
        save_dicts.update({name + "_W_fullyconn": self.W_fullyconn,
                           name + "_b_fullyconn": self.b_fullyconn})
        saver = tf.train.Saver(save_dicts)
        return saver.save(sess, filename)

    def load_weights(self, sess, filename, name="TextCNN"):
        """Restore the named model weights from `filename`."""
        save_dicts = {name + "_W_embedding": self.W_embedding}
        for i in range(len(self.Ws_conv)):
            save_dicts.update({name + "_W_conv_" + str(i): self.Ws_conv[i],
                               name + "_b_conv_" + str(i): self.bs_conv[i]})
        save_dicts.update({name + "_W_fullyconn": self.W_fullyconn,
                           name + "_b_fullyconn": self.b_fullyconn})
        saver = tf.train.Saver(save_dicts)
        saver.restore(sess, filename)
