# 变长输入的textCNN实现

In [52]:
import sys
import os
import re
import math

import pandas as pd
import numpy as np

from nltk.corpus import stopwords
from nltk.util import ngrams

from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences

import torch as t
from torch import nn
from torch.utils import data
from torch.utils.data import DataLoader
import torch.nn.functional as F

stop_words = set(stopwords.words('english'))

In [53]:
def clean_text(s):
    s = re.sub(r'<[^>]+>',' ',s)
    s = re.sub(r'[^a-zA-Z\']',' ',s)
    s = s.lower()
    s = s.split(" ")
    s = [w for w in s if not w in stop_words]
    return " ".join(s)

In [54]:
class Movie(data.Dataset):
    def __init__(self, x, y=None, train=True):
        self.x = x
        self.y = y
        self.train = train
    def __getitem__(self, index):
        if self.train:
            return self.x[index],self.y[index]
        return self.x[index]
    
    def __len__(self):
        return len(self.x)

In [55]:
max_features = 6000  # 词汇表的大小
embedding_size = 300
import gc
class textCNN(nn.Module):
    def __init__(self,kernel_size_list,conv_out_channel):
        super(textCNN, self).__init__()
        self.num_conv = len(kernel_size_list)
        self.conv_out_channel = conv_out_channel
        self.kernel_size_list = kernel_size_list
        
        self.embd = nn.Embedding(max_features,embedding_size)
        self.conv = []
        for i in range(self.num_conv):
            self.conv.append(nn.Conv2d(1, self.conv_out_channel, (self.kernel_size_list[i],embedding_size),bias=True))
            n = self.conv[i].kernel_size[0] * self.conv[i].kernel_size[1] * self.conv[i].out_channels
            self.conv[i].weight.data.normal_(0, math.sqrt(2. / n))
        self.fc = nn.Sequential(
            nn.Linear(self.conv_out_channel*self.num_conv,1),
            nn.Sigmoid()
        )
    def forward(self, x):
        res = t.Tensor()
        for xi in x:
            sent_len = xi.size()[0]
            mx = max(self.kernel_size_list)
            if sent_len < mx:
                xi = t.cat((xi,t.ones(mx-sent_len).long()))
                sent_len = mx
            xi = self.embd(xi)   # sent_len, em_size
            pooled = []
            xi = t.unsqueeze(xi, 0)
            xi = t.unsqueeze(xi, 0)  # 1,1,sent_len,em_size
            for i in range(self.num_conv):
                tmp = self.conv[i](xi)  # 1,2,sent_len-self.kernel_size[i]+1,1
                m = nn.ReLU()
                tmp = m(tmp)
                m = nn.MaxPool2d((sent_len-self.kernel_size_list[i]+1,1))
                tmp = m(tmp)   # 1,2,1,1
                tmp = tmp.squeeze()
                pooled.append(tmp)
#                 del tmp
            xxi = t.cat(pooled)
            xxi = self.fc(xxi)
            res = t.cat((res,xxi))
#             del xxi,xi,pooled
#             gc.collect()
        return res


In [56]:
# net = textCNN([2,3,4],2)
# inputs = t.rand(50,2).long()
# # for i in inputs:
# #     print(i.size())
# outputs = net(inputs)

In [57]:
def train(model, x, y, criterion, optimizer, save_model_path=r'model.pkl', batch_size=50,epoch=2, display_iter=10, num_threads=8):
#     t.set_num_threads(num_threads)
    
    for epoch_iter in range(epoch):
        running_loss = 0.0
        for i in range(500):

            # 输入数据
            inputs = x[i*batch_size:i*batch_size+batch_size]
            labels = y[i*batch_size:i*batch_size+batch_size]
            inputs = [t.Tensor(i).long() for i in inputs]
            
            labels = t.Tensor(labels).float()

            # 梯度清零
            optimizer.zero_grad()
            outputs = model(inputs)
            outputs = outputs.squeeze()
            loss = criterion(outputs, labels)
            loss.backward()

            # 更新参数
            optimizer.step()

            # 打印log
            running_loss += loss.item()
            if i % display_iter == display_iter-1:
                print('[%d,%5d] loss %.3f' % (epoch_iter+1, i+1, running_loss / 10))
                running_loss = 0.0
    t.save(model.state_dict(), save_model_path) # 保存的是以字典 key - value pair 形式的数据，每一个参数对应着一个值 state_dict 状态字典 
    print('training finished!!!!!')

In [58]:
from tqdm import tqdm
import gc
def predict(model, x,batch_size=100):
    res = t.Tensor()
    with t.no_grad():
        for i in tqdm(range(250)):
            inputs = x[i*batch_size:i*batch_size+batch_size]
            inputs = [t.Tensor(i).long() for i in inputs]
            outputs = model(inputs)
            outputs = outputs.squeeze()
            res = t.cat((res,outputs))
#             print(res.size())
#             del inputs,outputs
#             gc.collect()
    return res

In [59]:
data_path = r'E:\kaggle\movies'
train_data_path = os.path.join(data_path,'labeledTrainData.tsv')
test_data_path = os.path.join(data_path,'testData.tsv')
train_df = pd.read_csv(train_data_path,header=0,sep='\t')
test_df = pd.read_csv(test_data_path,header=0,sep='\t')
test_df['text'] = test_df.review.apply(clean_text)
train_df['text'] = train_df.review.apply(clean_text)

In [60]:
train_df.head()

Unnamed: 0,id,sentiment,review,text
0,5814_8,1,With all this stuff going down at the moment w...,stuff going moment mj i've started listening m...
1,2381_9,1,"\The Classic War of the Worlds\"" by Timothy Hi...",classic war worlds timothy hines entertaini...
2,7759_3,0,The film starts with a manager (Nicholas Bell)...,film starts manager nicholas bell giving wel...
3,3630_4,0,It must be assumed that those who praised this...,must assumed praised film greatest filmed op...
4,9495_8,1,Superbly trashy and wondrously unpretentious 8...,superbly trashy wondrously unpretentious 's ...


In [61]:
tokenizer = Tokenizer(num_words = max_features)
tokenizer.fit_on_texts(train_df['text'])
list_tokenized_train = tokenizer.texts_to_sequences(train_df['text'])
list_tokenized_test = tokenizer.texts_to_sequences(test_df['text'])

train_x = list_tokenized_train
test_x = list_tokenized_test
train_y = np.array(train_df['sentiment'])
print(len(train_x),len(test_x),type(train_y))

25000 25000 <class 'numpy.ndarray'>


In [73]:
train_xy = [(train_x[i],train_y[i]) for i in range(len(train_x))]
import random
random.shuffle(train_xy)
train_x = [x[0] for x in train_xy]
train_y = [x[1] for x in train_xy]

In [74]:
net = textCNN([2,3,4],2)
criterion = nn.BCELoss()
optimizer = t.optim.Adam(net.parameters(),lr=0.003)
epoch = 2
batch_size = 50
display_iter = 10
num_threads = 8
model_path = os.path.join(data_path,r'textCNN.pkl')

train(net, train_x, train_y, criterion=criterion, optimizer=optimizer, save_model_path=model_path,
      epoch=epoch, display_iter=display_iter, num_threads=num_threads)
# result = predict(net, test_x).data.numpy()

[1,   10] loss 1.230
[1,   20] loss 1.106
[1,   30] loss 0.903
[1,   40] loss 0.785
[1,   50] loss 0.713
[1,   60] loss 0.700
[1,   70] loss 0.680
[1,   80] loss 0.678
[1,   90] loss 0.680
[1,  100] loss 0.667
[1,  110] loss 0.667
[1,  120] loss 0.656
[1,  130] loss 0.655
[1,  140] loss 0.646
[1,  150] loss 0.643
[1,  160] loss 0.627
[1,  170] loss 0.628
[1,  180] loss 0.629
[1,  190] loss 0.624
[1,  200] loss 0.613
[1,  210] loss 0.597
[1,  220] loss 0.594
[1,  230] loss 0.587
[1,  240] loss 0.591
[1,  250] loss 0.576
[1,  260] loss 0.558
[1,  270] loss 0.560
[1,  280] loss 0.553
[1,  290] loss 0.514
[1,  300] loss 0.520
[1,  310] loss 0.544
[1,  320] loss 0.525
[1,  330] loss 0.537
[1,  340] loss 0.504
[1,  350] loss 0.483
[1,  360] loss 0.507
[1,  370] loss 0.494
[1,  380] loss 0.466
[1,  390] loss 0.498
[1,  400] loss 0.529
[1,  410] loss 0.485
[1,  420] loss 0.471
[1,  430] loss 0.478
[1,  440] loss 0.447
[1,  450] loss 0.466
[1,  460] loss 0.454
[1,  470] loss 0.438
[1,  480] los

In [77]:
net.load_state_dict(t.load(model_path))
optimizer = t.optim.SGD(net.parameters(),lr=0.0005)
train(net, train_x, train_y, criterion=criterion, optimizer=optimizer, save_model_path=model_path,
      epoch=epoch, display_iter=display_iter, num_threads=num_threads)
result = predict(net, test_x).data.numpy()

[1,   10] loss 0.299
[1,   20] loss 0.333
[1,   30] loss 0.280
[1,   40] loss 0.286
[1,   50] loss 0.297
[1,   60] loss 0.267
[1,   70] loss 0.269
[1,   80] loss 0.246
[1,   90] loss 0.285
[1,  100] loss 0.255
[1,  110] loss 0.275
[1,  120] loss 0.241
[1,  130] loss 0.271
[1,  140] loss 0.261
[1,  150] loss 0.313
[1,  160] loss 0.234
[1,  170] loss 0.242
[1,  180] loss 0.266
[1,  190] loss 0.264
[1,  200] loss 0.231
[1,  210] loss 0.247
[1,  220] loss 0.242
[1,  230] loss 0.225
[1,  240] loss 0.261
[1,  250] loss 0.241
[1,  260] loss 0.253
[1,  270] loss 0.237
[1,  280] loss 0.226
[1,  290] loss 0.218
[1,  300] loss 0.212
[1,  310] loss 0.228
[1,  320] loss 0.220
[1,  330] loss 0.248
[1,  340] loss 0.208
[1,  350] loss 0.207
[1,  360] loss 0.230
[1,  370] loss 0.210
[1,  380] loss 0.204
[1,  390] loss 0.228
[1,  400] loss 0.237
[1,  410] loss 0.207
[1,  420] loss 0.188
[1,  430] loss 0.203
[1,  440] loss 0.178
[1,  450] loss 0.190
[1,  460] loss 0.201
[1,  470] loss 0.180
[1,  480] los





  0%|                                                                                          | 0/250 [00:00<?, ?it/s]



  0%|▎                                                                                 | 1/250 [00:00<01:06,  3.74it/s]



  1%|▋                                                                                 | 2/250 [00:00<01:05,  3.81it/s]



  1%|▉                                                                                 | 3/250 [00:00<00:58,  4.20it/s]



  2%|█▎                                                                                | 4/250 [00:00<00:55,  4.39it/s]



  2%|█▋                                                                                | 5/250 [00:01<00:53,  4.62it/s]



  2%|█▉                                                                                | 6/250 [00:01<00:54,  4.48it/s]



  3%|██▎                                                                               | 7/250 [00:01<00:53,  4.55it/s]



  3%|██▌    

 26%|█████████████████████▍                                                           | 66/250 [00:14<00:41,  4.47it/s]



 27%|█████████████████████▋                                                           | 67/250 [00:14<00:39,  4.62it/s]



 27%|██████████████████████                                                           | 68/250 [00:14<00:41,  4.39it/s]



 28%|██████████████████████▎                                                          | 69/250 [00:15<00:39,  4.57it/s]



 28%|██████████████████████▋                                                          | 70/250 [00:15<00:39,  4.53it/s]



 28%|███████████████████████                                                          | 71/250 [00:15<00:39,  4.55it/s]



 29%|███████████████████████▎                                                         | 72/250 [00:15<00:38,  4.58it/s]



 29%|███████████████████████▋                                                         | 73/250 [00:15<00:39,  4.45it/s]



 30%|███████████

 53%|██████████████████████████████████████████▏                                     | 132/250 [00:28<00:25,  4.61it/s]



 53%|██████████████████████████████████████████▌                                     | 133/250 [00:29<00:27,  4.28it/s]



 54%|██████████████████████████████████████████▉                                     | 134/250 [00:29<00:25,  4.47it/s]



 54%|███████████████████████████████████████████▏                                    | 135/250 [00:29<00:25,  4.48it/s]



 54%|███████████████████████████████████████████▌                                    | 136/250 [00:29<00:25,  4.48it/s]



 55%|███████████████████████████████████████████▊                                    | 137/250 [00:29<00:23,  4.81it/s]



 55%|████████████████████████████████████████████▏                                   | 138/250 [00:30<00:23,  4.82it/s]



 56%|████████████████████████████████████████████▍                                   | 139/250 [00:30<00:23,  4.82it/s]



 56%|███████████

 79%|███████████████████████████████████████████████████████████████▎                | 198/250 [00:42<00:11,  4.59it/s]



 80%|███████████████████████████████████████████████████████████████▋                | 199/250 [00:42<00:10,  4.78it/s]



 80%|████████████████████████████████████████████████████████████████                | 200/250 [00:43<00:10,  4.68it/s]



 80%|████████████████████████████████████████████████████████████████▎               | 201/250 [00:43<00:10,  4.60it/s]



 81%|████████████████████████████████████████████████████████████████▋               | 202/250 [00:43<00:10,  4.45it/s]



 81%|████████████████████████████████████████████████████████████████▉               | 203/250 [00:43<00:10,  4.61it/s]



 82%|█████████████████████████████████████████████████████████████████▎              | 204/250 [00:44<00:10,  4.35it/s]



 82%|█████████████████████████████████████████████████████████████████▌              | 205/250 [00:44<00:09,  4.53it/s]



 82%|███████████

In [78]:
result = np.array(result>0.5, dtype=np.int)
result_df = pd.DataFrame({'id':test_df['id'],'sentiment':result})
result_df.to_csv(os.path.join(data_path,'textCNN_result_1.csv'),index=False)