1+ """
2+ A CNN model for sentence classification
3+ source: 'https://github.com/dennybritz/cnn-text-classification-tf/blob/master/text_cnn.py'
4+ 2016/12/21
5+ """
6+ import numpy as np
7+ import tensorflow as tf
8+
class TextCNN(object):
    """
    A CNN class for sentence classification.
    The model consists of an embedding layer, a convolutional layer, a max-pooling layer and
    a softmax layer as the output.
    """
    def __init__(self, seq_len, vocab_size, embedding_size, filter_sizes, num_filters,
                 num_classes=2, l2_reg_lambda=0.0):
        """
        :param seq_len: int, the sequence length (i.e. the length of the sentences;
                        all sentences are zero-padded to the same length)
        :param vocab_size: int, the vocabulary size, used to define the embedding layer
        :param embedding_size: int, the dimensionality of the embeddings (word vectors)
        :param filter_sizes: list or tuple, the numbers of words we want our convolutional
                        filters to cover. For example, [3, 4, 5] means we will have filters
                        that slide over 3, 4 and 5 words respectively
        :param num_filters: int, the number of filters per filter size, so we have a total of
                        len(filter_sizes) * num_filters filters
        :param num_classes: int, the number of classes to predict in the output layer, default 2
        :param l2_reg_lambda: float, the L2 regularization coefficient
        """
        # Keep track of all hyperparameters
        self.seq_len = seq_len
        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.filter_sizes = filter_sizes
        self.num_filters = num_filters
        self.num_classes = num_classes
        self.l2_reg_lambda = l2_reg_lambda
        # Define the input and output placeholders
        self.x = tf.placeholder(tf.int32, shape=[None, seq_len], name="x")
        self.y = tf.placeholder(tf.float32, shape=[None, num_classes], name="y")
        # The dropout keep probability
        self.dropout_keep_prob = tf.placeholder(tf.float32, name="dropout_keep_prob")
        # Accumulator for the L2 regularization loss, initially 0.0
        L2_loss = tf.constant(0.0)

        # The embedding layer
        with tf.device("/cpu:0"):  # embedding_lookup has limited GPU support, so pin it to the CPU
            with tf.name_scope("embedding"):
                # The embedding matrix
                self.W_embedding = tf.Variable(tf.random_uniform([vocab_size, embedding_size], -1.0, 1.0),
                                               dtype=tf.float32, name="W_embedding")
                # The embedding results: [None, seq_len, embedding_size]
                self.embedded_chars = tf.nn.embedding_lookup(self.W_embedding, self.x)
                # Add a channel dimension so the conv2d operation can be used:
                # [None, seq_len, embedding_size, 1]
                self.embedded_chars_expanded = tf.expand_dims(self.embedded_chars, axis=-1)

        # The convolution and max-pooling layers
        pooled_outputs = []
        self.Ws_conv = []
        self.bs_conv = []
        # One convolution + max-pooling branch per filter size
        for i, filter_size in enumerate(filter_sizes):
            with tf.name_scope("conv_maxpool_{0}".format(filter_size)):
                # Convolution layer
                filter_shape = [filter_size, embedding_size, 1, num_filters]
                # Convolution parameters
                W_conv = tf.Variable(tf.truncated_normal(filter_shape, stddev=0.1),
                                     dtype=tf.float32, name="W_conv")
                self.Ws_conv.append(W_conv)
                b_conv = tf.Variable(tf.constant(0.1, shape=[num_filters]), dtype=tf.float32,
                                     name="b_conv")
                self.bs_conv.append(b_conv)
                # Convolution result: [None, seq_len - filter_size + 1, 1, num_filters]
                conv_output = tf.nn.conv2d(self.embedded_chars_expanded, W_conv, strides=[1, 1, 1, 1],
                                           padding="VALID", name="conv")
                # Use ReLU as the activation
                conv_h = tf.nn.relu(tf.nn.bias_add(conv_output, b_conv), name="relu")
                # Max-pool over the whole feature map: [None, 1, 1, num_filters]
                pool_output = tf.nn.max_pool(conv_h, ksize=[1, seq_len - filter_size + 1, 1, 1],
                                             strides=[1, 1, 1, 1], padding="VALID", name="max_pooling")
                pooled_outputs.append(pool_output)
        # Combine all the pooled features
        num_filters_total = num_filters * len(filter_sizes)
        self.h_pool = tf.concat(pooled_outputs, 3)  # [None, 1, 1, num_filters_total] (TF >= 1.0 argument order)
        self.h_pool_flat = tf.reshape(self.h_pool, shape=[-1, num_filters_total])  # [None, num_filters_total]

        # The dropout layer
        with tf.name_scope("dropout"):
            self.h_dropout = tf.nn.dropout(self.h_pool_flat, keep_prob=self.dropout_keep_prob, name="dropout")

        # The output layer (softmax)
        with tf.name_scope("output"):
            self.W_fullyconn = tf.get_variable("W_fullyconn", shape=[num_filters_total, num_classes],
                                               initializer=tf.contrib.layers.xavier_initializer())
            self.b_fullyconn = tf.Variable(tf.constant(0.1, shape=[num_classes]), dtype=tf.float32,
                                           name="b_fullyconn")
            # Only the output weights contribute to the L2 loss
            L2_loss += tf.nn.l2_loss(self.W_fullyconn)
            self.scores = tf.nn.xw_plus_b(self.h_dropout, self.W_fullyconn, self.b_fullyconn, name="scores")
            self.preds = tf.argmax(self.scores, axis=1, name="preds")

        # The loss
        with tf.name_scope("loss"):
            losses = tf.nn.softmax_cross_entropy_with_logits(logits=self.scores, labels=self.y)
            self.loss = tf.reduce_mean(losses) + L2_loss * l2_reg_lambda

        # Accuracy
        with tf.name_scope("accuracy"):
            correct_preds = tf.equal(self.preds, tf.argmax(self.y, axis=1))
            self.accuracy = tf.reduce_mean(tf.cast(correct_preds, tf.float32))

    def save_weights(self, sess, filename, name="TextCNN"):
        """Save all model variables to `filename` and return the checkpoint path."""
        save_dicts = {name + "_W_embedding": self.W_embedding}
        for i in range(len(self.Ws_conv)):
            save_dicts.update({name + "_W_conv_" + str(i): self.Ws_conv[i],
                               name + "_b_conv_" + str(i): self.bs_conv[i]})
        save_dicts.update({name + "_W_fullyconn": self.W_fullyconn,
                           name + "_b_fullyconn": self.b_fullyconn})
        saver = tf.train.Saver(save_dicts)
        return saver.save(sess, filename)

    def load_weights(self, sess, filename, name="TextCNN"):
        """Restore all model variables from the checkpoint at `filename`."""
        save_dicts = {name + "_W_embedding": self.W_embedding}
        for i in range(len(self.Ws_conv)):
            save_dicts.update({name + "_W_conv_" + str(i): self.Ws_conv[i],
                               name + "_b_conv_" + str(i): self.bs_conv[i]})
        save_dicts.update({name + "_W_fullyconn": self.W_fullyconn,
                           name + "_b_fullyconn": self.b_fullyconn})
        saver = tf.train.Saver(save_dicts)
        saver.restore(sess, filename)  # restore() requires the checkpoint path
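

# A minimal usage sketch, not part of the original file: the hyperparameters,
# batch size and random data below are illustrative assumptions only, meant to
# show how the placeholders and ops defined above fit together in one train step.
if __name__ == "__main__":
    model = TextCNN(seq_len=56, vocab_size=10000, embedding_size=128,
                    filter_sizes=[3, 4, 5], num_filters=100,
                    num_classes=2, l2_reg_lambda=0.1)
    train_op = tf.train.AdamOptimizer(1e-3).minimize(model.loss)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Stand-in batch: random word ids and one-hot labels
        x_batch = np.random.randint(0, 10000, size=(64, 56))
        y_batch = np.eye(2)[np.random.randint(0, 2, size=64)]
        _, loss, acc = sess.run([train_op, model.loss, model.accuracy],
                                feed_dict={model.x: x_batch, model.y: y_batch,
                                           model.dropout_keep_prob: 0.5})
        print("loss: {0:.4f}, accuracy: {1:.4f}".format(loss, acc))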