# Model Prediction & Training
This part covers model training and prediction. The models we use fall into two categories.

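## Single models
### Deep learning models
We combine the news text with the text extracted from the images to build the model. The basic framework is shown in the figure below.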
![img](../img/example.png)

### Machine learning models

* Input features: TFIDF+SVD, Basic Features, etc.
* The OCR output and the news text are combined by simple concatenation
* Models: xgboost, catboost, lightGBM, DNN
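
To make the feature pipeline above more concrete, here is a minimal, dependency-free sketch of the "simple concatenation + TF-IDF" idea. It assumes the plain `tf * log(N / df)` weighting, which may differ from the variant actually used, and it leaves out the SVD step. The function names and toy tokens are hypothetical and are not taken from the project code.

```python
import math
from collections import Counter


def concat_news_and_ocr(news_tokens, ocr_tokens):
    """Simple concatenation: OCR tokens are appended after the news tokens."""
    return list(news_tokens) + list(ocr_tokens)


def tfidf(docs):
    """Return one {term: weight} dict per document using tf * log(N / df)."""
    n_docs = len(docs)
    # Document frequency: in how many documents each term appears.
    df = Counter(term for doc in docs for term in set(doc))
    vectors = []
    for doc in docs:
        counts = Counter(doc)
        total = len(doc)
        vectors.append({
            term: (count / total) * math.log(n_docs / df[term])
            for term, count in counts.items()
        })
    return vectors


# Toy, made-up samples (hypothetical tokens, not from the dataset).
docs = [
    concat_news_and_ocr(["stock", "market", "rises"], ["market", "chart"]),
    concat_news_and_ocr(["celebrity", "gossip"], ["buy", "now"]),
    concat_news_and_ocr(["market", "news"], []),
]
vectors = tfidf(docs)
print(vectors[0])
```

In a real pipeline the sparse TF-IDF vectors would then be reduced to a fixed number of dense dimensions (the `_250` suffix in the feature file names suggests 250, though that is a guess).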


| Model or method | F1-measure |
| ----------- | ----------- |
| catboost | 0.611 |
| xgboost | 0.621 |
| lightgbm | 0.625 |
| dnn | 0.621 |
| textCNN | 0.617 |
| capsule | 0.625 |
| covlstm | 0.630 |
| dpcnn | 0.626 |
| lstm+gru | 0.635 |
| lstm+gru+attention | 0.640 |
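
The table reports F1-measure, but the averaging scheme is not stated here. As a minimal sketch, assuming a macro-averaged F1 over the classes, the metric could be computed as follows. The helper names and the toy labels are hypothetical.

```python
def f1_per_class(y_true, y_pred, label):
    """F1 for a single class, treating it as the positive label."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == label and p == label)
    fp = sum(1 for t, p in pairs if t != label and p == label)
    fn = sum(1 for t, p in pairs if t == label and p != label)
    if tp == 0:
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)


def macro_f1(y_true, y_pred):
    """Unweighted mean of the per-class F1 scores (an assumed averaging)."""
    labels = sorted(set(y_true) | set(y_pred))
    return sum(f1_per_class(y_true, y_pred, l) for l in labels) / len(labels)


# Toy labels for a three-class problem (made up for illustration).
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]
print(round(macro_f1(y_true, y_pred), 3))
```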


In [1]:
# Import the required packages
import pickle
import numpy as np
import os
from config import Config
from sklearn.preprocessing import MinMaxScaler
from keras.preprocessing import sequence
from keras.utils import np_utils



Using TensorFlow backend.


In [2]:
config = Config()
# Load the word-to-id mapping
with open(config.ITEM_TO_ID, 'rb') as f:
    item_to_id = pickle.load(f)
# Use the index just past the last word id as the UNK marker
UNK = len(item_to_id)


#### Let's load the training data for the traditional models (xgboost, catboost, lightgbm, etc.)

In [7]:
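def static_data_prepare():
    """
    Load the labels and the pre-computed features for the traditional models.

    :return: train_x, train_y, test_x as NumPy arrays
    """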
    train_y = []

    with open(config.TRAIN_X, 'r', encoding='utf-8') as f:
        lines = f.readlines()
        for line in lines:
            line = line.strip()
            line = line.split('\t')
            label = int(line[2])
            train_y.append(label)


    with open(config.FEATURES_test_FILE, 'rb') as f:
        test_features = pickle.load(f)

    with open(config.FEATURES_FILE, 'rb') as f:
        features = pickle.load(f)
        
    with open(config.OCR_FEATURES_test_FILE, 'rb') as f:
        ocr_test_features = pickle.load(f)
    with open(config.OCR_FEATURES_FILE, 'rb') as f:
        ocr_features = pickle.load(f)

    # Features after TFIDF + SVD dimensionality reduction
    with open('../data/train_x_250.pkl', 'rb') as f:
        train_x = pickle.load(f)
    with open('../data/test_x_250.pkl', 'rb') as f:
        test_x = pickle.load(f)

    # Features extracted from the OCR text
    with open('../data/ocr_train_x_250.pkl', 'rb') as f:
        train_ocr_x = pickle.load(f)

    with open('../data/ocr_test_x_250.pkl', 'rb') as f:
        test_ocr_x = pickle.load(f)
        
    scaler = MinMaxScaler()
    all_feature = np.concatenate([features, test_features, ocr_test_features, ocr_features], axis=0)
    scaler.fit(all_feature)
    features = scaler.transform(features)
    test_features = scaler.transform(test_features)
    ocr_features = scaler.transform(ocr_features)
    ocr_test_features = scaler.transform(ocr_test_features)
    
    train_x = np.concatenate((train_x, train_ocr_x, features, ocr_features), axis=-1)
    test_x = np.concatenate((test_x, test_ocr_x, test_features, ocr_test_features), axis=-1)
    train_y = np.array(train_y)
    return train_x, train_y, test_x

In [8]:
data_x, data_y, test = static_data_prepare()
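
`static_data_prepare` fits one `MinMaxScaler` on the train and test feature blocks together, scales each block, and then joins the blocks column-wise. The following dependency-free sketch reproduces that arithmetic on toy numbers so the effect is easy to inspect. The helper names and the values are hypothetical, and plain lists stand in for NumPy arrays.

```python
def fit_min_max(rows):
    """Column-wise minimum and maximum over all rows."""
    columns = list(zip(*rows))
    return [min(c) for c in columns], [max(c) for c in columns]


def transform_min_max(rows, mins, maxs):
    """Scale every column to [0, 1]; constant columns map to 0.0."""
    scaled = []
    for row in rows:
        scaled.append([
            (v - lo) / (hi - lo) if hi > lo else 0.0
            for v, lo, hi in zip(row, mins, maxs)
        ])
    return scaled


def concat_columns(*blocks):
    """Join feature blocks side by side (like np.concatenate(..., axis=-1))."""
    return [sum((list(r) for r in rows), []) for rows in zip(*blocks)]


# Toy basic features (made up); the scaler sees train and test together.
train_basic = [[1.0, 10.0], [3.0, 30.0]]
test_basic = [[2.0, 50.0]]
mins, maxs = fit_min_max(train_basic + test_basic)
train_scaled = transform_min_max(train_basic, mins, maxs)
test_scaled = transform_min_max(test_basic, mins, maxs)

# Toy stand-in for the TFIDF + SVD block, which is not rescaled.
train_svd = [[0.1, 0.2], [0.3, 0.4]]
train_x = concat_columns(train_svd, train_scaled)
print(train_x)
```

Only the basic feature blocks go through the scaler in the function above; the TFIDF + SVD blocks are concatenated as loaded.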



In [9]:
# First, instantiate one of the traditional machine learning models
model = config.model['model' + str(8)]()
model.train_predict(data_x[:10], data_y[:10], test[:10])

0:	learn: -0.6834329	total: 61ms	remaining: 1m
1:	learn: -0.6788535	total: 63.2ms	remaining: 31.5s
2:	learn: -0.6703978	total: 65.4ms	remaining: 21.7s
3:	learn: -0.6586300	total: 67ms	remaining: 16.7s
4:	learn: -0.6488000	total: 68.7ms	remaining: 13.7s
5:	learn: -0.6398159	total: 70.5ms	remaining: 11.7s
6:	learn: -0.6320691	total: 72.9ms	remaining: 10.3s
7:	learn: -0.6261224	total: 76ms	remaining: 9.42s
8:	learn: -0.6213667	total: 79.7ms	remaining: 8.77s
9:	learn: -0.6119698	total: 82.6ms	remaining: 8.18s
10:	learn: -0.6028043	total: 84.8ms	remaining: 7.63s
11:	learn: -0.5972627	total: 86.9ms	remaining: 7.15s
12:	learn: -0.5913013	total: 88.9ms	remaining: 6.75s
13:	learn: -0.5864425	total: 91.8ms	remaining: 6.47s
14:	learn: -0.5821057	total: 95.7ms	remaining: 6.28s
15:	learn: -0.5763753	total: 98ms	remaining: 6.03s
16:	learn: -0.5679472	total: 100ms	remaining: 5.79s
17:	learn: -0.5632925	total: 102ms	remaining: 5.57s
18:	learn: -0.5587668	total: 106ms	remaining: 5.47s
19:	learn: -0.553

228:	learn: -0.1085297	total: 632ms	remaining: 2.13s
229:	learn: -0.1078424	total: 634ms	remaining: 2.12s
230:	learn: -0.1071629	total: 636ms	remaining: 2.12s
231:	learn: -0.1065991	total: 638ms	remaining: 2.11s
232:	learn: -0.1063273	total: 643ms	remaining: 2.12s
233:	learn: -0.1060569	total: 646ms	remaining: 2.12s
234:	learn: -0.1053979	total: 648ms	remaining: 2.11s
235:	learn: -0.1047464	total: 651ms	remaining: 2.11s
236:	learn: -0.1043895	total: 654ms	remaining: 2.1s
237:	learn: -0.1037494	total: 656ms	remaining: 2.1s
238:	learn: -0.1031165	total: 658ms	remaining: 2.1s
239:	learn: -0.1028611	total: 662ms	remaining: 2.1s
240:	learn: -0.1026070	total: 665ms	remaining: 2.09s
241:	learn: -0.1020865	total: 667ms	remaining: 2.09s
242:	learn: -0.1018358	total: 671ms	remaining: 2.09s
243:	learn: -0.1015389	total: 673ms	remaining: 2.08s
244:	learn: -0.1009305	total: 674ms	remaining: 2.08s
245:	learn: -0.1005976	total: 677ms	remaining: 2.07s
246:	learn: -0.1000897	total: 679ms	remaining: 2.0

386:	learn: -0.0609581	total: 1.01s	remaining: 1.6s
387:	learn: -0.0607235	total: 1.01s	remaining: 1.6s
388:	learn: -0.0604906	total: 1.02s	remaining: 1.6s
389:	learn: -0.0602593	total: 1.02s	remaining: 1.6s
390:	learn: -0.0601345	total: 1.02s	remaining: 1.6s
391:	learn: -0.0600101	total: 1.03s	remaining: 1.59s
392:	learn: -0.0598220	total: 1.03s	remaining: 1.59s
393:	learn: -0.0595955	total: 1.03s	remaining: 1.59s
394:	learn: -0.0594903	total: 1.04s	remaining: 1.59s
395:	learn: -0.0593995	total: 1.04s	remaining: 1.59s
396:	learn: -0.0592923	total: 1.05s	remaining: 1.59s
397:	learn: -0.0591713	total: 1.05s	remaining: 1.59s
398:	learn: -0.0590262	total: 1.05s	remaining: 1.59s
399:	learn: -0.0588055	total: 1.06s	remaining: 1.59s
400:	learn: -0.0586194	total: 1.06s	remaining: 1.58s
401:	learn: -0.0584394	total: 1.06s	remaining: 1.58s
402:	learn: -0.0583517	total: 1.07s	remaining: 1.58s
403:	learn: -0.0582643	total: 1.07s	remaining: 1.58s
404:	learn: -0.0581772	total: 1.07s	remaining: 1.58

637:	learn: -0.0346788	total: 1.57s	remaining: 890ms
638:	learn: -0.0346362	total: 1.57s	remaining: 888ms
639:	learn: -0.0345710	total: 1.57s	remaining: 885ms
640:	learn: -0.0345144	total: 1.57s	remaining: 882ms
641:	learn: -0.0344827	total: 1.58s	remaining: 881ms
642:	learn: -0.0344511	total: 1.58s	remaining: 879ms
643:	learn: -0.0344195	total: 1.58s	remaining: 877ms
644:	learn: -0.0343823	total: 1.59s	remaining: 874ms
645:	learn: -0.0343041	total: 1.59s	remaining: 872ms
646:	learn: -0.0342401	total: 1.59s	remaining: 869ms
647:	learn: -0.0341875	total: 1.59s	remaining: 866ms
648:	learn: -0.0341450	total: 1.6s	remaining: 864ms
649:	learn: -0.0340816	total: 1.6s	remaining: 861ms
650:	learn: -0.0340047	total: 1.6s	remaining: 858ms
651:	learn: -0.0339418	total: 1.6s	remaining: 854ms
652:	learn: -0.0338792	total: 1.6s	remaining: 851ms
653:	learn: -0.0338032	total: 1.6s	remaining: 848ms
654:	learn: -0.0337382	total: 1.6s	remaining: 846ms
655:	learn: -0.0336832	total: 1.61s	remaining: 843ms


862:	learn: -0.0248551	total: 1.94s	remaining: 308ms
863:	learn: -0.0248210	total: 1.94s	remaining: 305ms
864:	learn: -0.0247968	total: 1.94s	remaining: 303ms
865:	learn: -0.0247613	total: 1.94s	remaining: 301ms
866:	learn: -0.0247274	total: 1.95s	remaining: 299ms
867:	learn: -0.0246937	total: 1.95s	remaining: 296ms
868:	learn: -0.0246747	total: 1.95s	remaining: 294ms
869:	learn: -0.0246411	total: 1.95s	remaining: 292ms
870:	learn: -0.0246002	total: 1.95s	remaining: 290ms
871:	learn: -0.0245838	total: 1.96s	remaining: 287ms
872:	learn: -0.0245431	total: 1.96s	remaining: 285ms
873:	learn: -0.0245194	total: 1.96s	remaining: 283ms
874:	learn: -0.0245002	total: 1.96s	remaining: 281ms
875:	learn: -0.0244671	total: 1.97s	remaining: 278ms
876:	learn: -0.0244456	total: 1.97s	remaining: 276ms
877:	learn: -0.0244126	total: 1.97s	remaining: 274ms
878:	learn: -0.0243724	total: 1.97s	remaining: 271ms
879:	learn: -0.0243453	total: 1.97s	remaining: 269ms
880:	learn: -0.0243183	total: 1.97s	remaining:

65:	learn: -0.3271757	total: 122ms	remaining: 1.73s
66:	learn: -0.3225274	total: 124ms	remaining: 1.73s
67:	learn: -0.3185296	total: 127ms	remaining: 1.73s
68:	learn: -0.3147271	total: 129ms	remaining: 1.74s
69:	learn: -0.3125678	total: 132ms	remaining: 1.75s
70:	learn: -0.3087763	total: 133ms	remaining: 1.75s
71:	learn: -0.3066866	total: 137ms	remaining: 1.76s
72:	learn: -0.3030180	total: 139ms	remaining: 1.77s
73:	learn: -0.2994275	total: 142ms	remaining: 1.77s
74:	learn: -0.2959934	total: 144ms	remaining: 1.77s
75:	learn: -0.2920512	total: 146ms	remaining: 1.77s
76:	learn: -0.2886792	total: 147ms	remaining: 1.76s
77:	learn: -0.2854368	total: 149ms	remaining: 1.76s
78:	learn: -0.2838916	total: 151ms	remaining: 1.76s
79:	learn: -0.2807598	total: 152ms	remaining: 1.75s
80:	learn: -0.2789587	total: 154ms	remaining: 1.75s
81:	learn: -0.2753834	total: 156ms	remaining: 1.74s
82:	learn: -0.2723289	total: 157ms	remaining: 1.74s
83:	learn: -0.2688948	total: 158ms	remaining: 1.73s
84:	learn: -

295:	learn: -0.0771869	total: 483ms	remaining: 1.15s
296:	learn: -0.0768216	total: 485ms	remaining: 1.15s
297:	learn: -0.0764596	total: 486ms	remaining: 1.15s
298:	learn: -0.0762661	total: 489ms	remaining: 1.15s
299:	learn: -0.0759090	total: 491ms	remaining: 1.14s
300:	learn: -0.0757417	total: 494ms	remaining: 1.15s
301:	learn: -0.0754529	total: 496ms	remaining: 1.15s
302:	learn: -0.0751028	total: 498ms	remaining: 1.14s
303:	learn: -0.0749388	total: 502ms	remaining: 1.15s
304:	learn: -0.0745932	total: 504ms	remaining: 1.15s
305:	learn: -0.0744298	total: 507ms	remaining: 1.15s
306:	learn: -0.0740886	total: 509ms	remaining: 1.15s
307:	learn: -0.0737502	total: 510ms	remaining: 1.15s
308:	learn: -0.0735696	total: 512ms	remaining: 1.15s
309:	learn: -0.0732962	total: 514ms	remaining: 1.14s
310:	learn: -0.0731601	total: 516ms	remaining: 1.14s
311:	learn: -0.0730027	total: 518ms	remaining: 1.14s
312:	learn: -0.0727333	total: 520ms	remaining: 1.14s
313:	learn: -0.0725115	total: 521ms	remaining:

533:	learn: -0.0400723	total: 845ms	remaining: 737ms
534:	learn: -0.0400005	total: 847ms	remaining: 737ms
535:	learn: -0.0399515	total: 850ms	remaining: 736ms
536:	learn: -0.0398960	total: 853ms	remaining: 735ms
537:	learn: -0.0398054	total: 855ms	remaining: 734ms
538:	learn: -0.0397152	total: 857ms	remaining: 733ms
539:	learn: -0.0396669	total: 860ms	remaining: 733ms
540:	learn: -0.0395637	total: 862ms	remaining: 731ms
541:	learn: -0.0395092	total: 865ms	remaining: 731ms
542:	learn: -0.0394213	total: 867ms	remaining: 729ms
543:	learn: -0.0393193	total: 869ms	remaining: 728ms
544:	learn: -0.0392780	total: 871ms	remaining: 727ms
545:	learn: -0.0391768	total: 872ms	remaining: 725ms
546:	learn: -0.0391358	total: 875ms	remaining: 724ms
547:	learn: -0.0390352	total: 876ms	remaining: 722ms
548:	learn: -0.0389540	total: 877ms	remaining: 721ms
549:	learn: -0.0389012	total: 879ms	remaining: 719ms
550:	learn: -0.0388018	total: 880ms	remaining: 717ms
551:	learn: -0.0387029	total: 881ms	remaining:

756:	learn: -0.0273902	total: 1.21s	remaining: 389ms
757:	learn: -0.0273495	total: 1.21s	remaining: 387ms
758:	learn: -0.0273261	total: 1.22s	remaining: 386ms
759:	learn: -0.0272856	total: 1.22s	remaining: 385ms
760:	learn: -0.0272452	total: 1.22s	remaining: 383ms
761:	learn: -0.0272050	total: 1.22s	remaining: 381ms
762:	learn: -0.0271684	total: 1.22s	remaining: 380ms
763:	learn: -0.0271453	total: 1.23s	remaining: 379ms
764:	learn: -0.0271089	total: 1.23s	remaining: 378ms
765:	learn: -0.0270859	total: 1.23s	remaining: 376ms
766:	learn: -0.0270461	total: 1.23s	remaining: 375ms
767:	learn: -0.0270064	total: 1.24s	remaining: 373ms
768:	learn: -0.0269669	total: 1.24s	remaining: 371ms
769:	learn: -0.0269274	total: 1.24s	remaining: 370ms
770:	learn: -0.0268787	total: 1.24s	remaining: 368ms
771:	learn: -0.0268395	total: 1.24s	remaining: 366ms
772:	learn: -0.0267911	total: 1.24s	remaining: 365ms
773:	learn: -0.0267555	total: 1.24s	remaining: 363ms
774:	learn: -0.0267201	total: 1.24s	remaining:

990:	learn: -0.0206252	total: 1.58s	remaining: 14.3ms
991:	learn: -0.0206055	total: 1.58s	remaining: 12.8ms
992:	learn: -0.0205904	total: 1.58s	remaining: 11.2ms
993:	learn: -0.0205654	total: 1.59s	remaining: 9.58ms
994:	learn: -0.0205422	total: 1.59s	remaining: 7.99ms
995:	learn: -0.0205135	total: 1.59s	remaining: 6.39ms
996:	learn: -0.0204904	total: 1.59s	remaining: 4.8ms
997:	learn: -0.0204657	total: 1.6s	remaining: 3.2ms
998:	learn: -0.0204427	total: 1.6s	remaining: 1.6ms
999:	learn: -0.0204295	total: 1.6s	remaining: 0us
Test error using softmax = 1.0
0:	learn: -0.6798083	total: 1.13ms	remaining: 1.13s
1:	learn: -0.6713676	total: 2.72ms	remaining: 1.35s
2:	learn: -0.6637866	total: 4.37ms	remaining: 1.45s
3:	learn: -0.6556042	total: 6.62ms	remaining: 1.65s
4:	learn: -0.6487367	total: 8.16ms	remaining: 1.62s
5:	learn: -0.6420018	total: 11ms	remaining: 1.81s
6:	learn: -0.6321792	total: 12.5ms	remaining: 1.78s
7:	learn: -0.6268317	total: 14.1ms	remaining: 1.75s
8:	learn: -0.6190593	tot

168:	learn: -0.1434474	total: 323ms	remaining: 1.59s
169:	learn: -0.1428231	total: 326ms	remaining: 1.59s
170:	learn: -0.1416364	total: 328ms	remaining: 1.59s
171:	learn: -0.1410265	total: 331ms	remaining: 1.59s
172:	learn: -0.1398670	total: 333ms	remaining: 1.59s
173:	learn: -0.1394069	total: 337ms	remaining: 1.6s
174:	learn: -0.1386514	total: 340ms	remaining: 1.6s
175:	learn: -0.1380648	total: 344ms	remaining: 1.61s
176:	learn: -0.1369488	total: 346ms	remaining: 1.61s
177:	learn: -0.1365058	total: 348ms	remaining: 1.61s
178:	learn: -0.1359824	total: 350ms	remaining: 1.6s
179:	learn: -0.1354161	total: 353ms	remaining: 1.61s
180:	learn: -0.1343383	total: 354ms	remaining: 1.6s
181:	learn: -0.1338425	total: 356ms	remaining: 1.6s
182:	learn: -0.1327880	total: 357ms	remaining: 1.59s
183:	learn: -0.1317491	total: 358ms	remaining: 1.59s
184:	learn: -0.1309924	total: 360ms	remaining: 1.58s
185:	learn: -0.1301396	total: 361ms	remaining: 1.58s
186:	learn: -0.1291389	total: 362ms	remaining: 1.57

374:	learn: -0.0621550	total: 689ms	remaining: 1.15s
375:	learn: -0.0620726	total: 691ms	remaining: 1.15s
376:	learn: -0.0618222	total: 693ms	remaining: 1.15s
377:	learn: -0.0615738	total: 695ms	remaining: 1.14s
378:	learn: -0.0614566	total: 698ms	remaining: 1.14s
379:	learn: -0.0612547	total: 700ms	remaining: 1.14s
380:	learn: -0.0610466	total: 702ms	remaining: 1.14s
381:	learn: -0.0609228	total: 707ms	remaining: 1.14s
382:	learn: -0.0607745	total: 709ms	remaining: 1.14s
383:	learn: -0.0606357	total: 712ms	remaining: 1.14s
384:	learn: -0.0605526	total: 715ms	remaining: 1.14s
385:	learn: -0.0604514	total: 719ms	remaining: 1.14s
386:	learn: -0.0603053	total: 721ms	remaining: 1.14s
387:	learn: -0.0601332	total: 723ms	remaining: 1.14s
388:	learn: -0.0599395	total: 724ms	remaining: 1.14s
389:	learn: -0.0598161	total: 727ms	remaining: 1.14s
390:	learn: -0.0596244	total: 728ms	remaining: 1.13s
391:	learn: -0.0595482	total: 730ms	remaining: 1.13s
392:	learn: -0.0593505	total: 732ms	remaining:

573:	learn: -0.0391636	total: 1.06s	remaining: 785ms
574:	learn: -0.0390600	total: 1.06s	remaining: 783ms
575:	learn: -0.0389847	total: 1.06s	remaining: 782ms
576:	learn: -0.0388976	total: 1.06s	remaining: 780ms
577:	learn: -0.0387954	total: 1.07s	remaining: 778ms
578:	learn: -0.0387211	total: 1.07s	remaining: 778ms
579:	learn: -0.0386199	total: 1.07s	remaining: 776ms
580:	learn: -0.0385768	total: 1.08s	remaining: 776ms
581:	learn: -0.0384914	total: 1.08s	remaining: 774ms
582:	learn: -0.0384089	total: 1.08s	remaining: 772ms
583:	learn: -0.0383466	total: 1.08s	remaining: 771ms
584:	learn: -0.0382472	total: 1.08s	remaining: 769ms
585:	learn: -0.0381750	total: 1.08s	remaining: 767ms
586:	learn: -0.0380765	total: 1.09s	remaining: 765ms
587:	learn: -0.0379930	total: 1.09s	remaining: 763ms
588:	learn: -0.0378954	total: 1.09s	remaining: 760ms
589:	learn: -0.0378347	total: 1.09s	remaining: 759ms
590:	learn: -0.0377526	total: 1.09s	remaining: 757ms
591:	learn: -0.0376708	total: 1.09s	remaining:

768:	learn: -0.0279377	total: 1.43s	remaining: 430ms
769:	learn: -0.0279146	total: 1.44s	remaining: 429ms
770:	learn: -0.0278706	total: 1.44s	remaining: 428ms
771:	learn: -0.0278173	total: 1.44s	remaining: 426ms
772:	learn: -0.0277736	total: 1.44s	remaining: 424ms
773:	learn: -0.0277300	total: 1.45s	remaining: 422ms
774:	learn: -0.0276849	total: 1.45s	remaining: 420ms
775:	learn: -0.0276323	total: 1.45s	remaining: 418ms
776:	learn: -0.0275877	total: 1.45s	remaining: 416ms
777:	learn: -0.0275581	total: 1.45s	remaining: 415ms
778:	learn: -0.0275152	total: 1.46s	remaining: 413ms
779:	learn: -0.0274927	total: 1.46s	remaining: 412ms
780:	learn: -0.0274548	total: 1.46s	remaining: 410ms
781:	learn: -0.0274105	total: 1.46s	remaining: 408ms
782:	learn: -0.0273681	total: 1.47s	remaining: 406ms
783:	learn: -0.0273258	total: 1.47s	remaining: 405ms
784:	learn: -0.0272993	total: 1.47s	remaining: 403ms
785:	learn: -0.0272481	total: 1.47s	remaining: 401ms
786:	learn: -0.0271971	total: 1.47s	remaining:

967:	learn: -0.0216640	total: 1.81s	remaining: 59.7ms
968:	learn: -0.0216483	total: 1.81s	remaining: 57.9ms
969:	learn: -0.0216217	total: 1.81s	remaining: 56.1ms
970:	learn: -0.0216021	total: 1.82s	remaining: 54.3ms
971:	learn: -0.0215837	total: 1.82s	remaining: 52.4ms
972:	learn: -0.0215560	total: 1.82s	remaining: 50.6ms
973:	learn: -0.0215359	total: 1.83s	remaining: 48.8ms
974:	learn: -0.0215037	total: 1.83s	remaining: 46.9ms
975:	learn: -0.0214716	total: 1.83s	remaining: 45ms
976:	learn: -0.0214454	total: 1.83s	remaining: 43.2ms
977:	learn: -0.0214273	total: 1.83s	remaining: 41.3ms
978:	learn: -0.0214075	total: 1.84s	remaining: 39.4ms
979:	learn: -0.0213804	total: 1.84s	remaining: 37.5ms
980:	learn: -0.0213544	total: 1.84s	remaining: 35.7ms
981:	learn: -0.0213228	total: 1.84s	remaining: 33.8ms
982:	learn: -0.0213091	total: 1.85s	remaining: 32ms
983:	learn: -0.0212822	total: 1.85s	remaining: 30.1ms
984:	learn: -0.0212508	total: 1.85s	remaining: 28.2ms
985:	learn: -0.0212330	total: 1.

200:	learn: -0.1099258	total: 464ms	remaining: 1.84s
201:	learn: -0.1092235	total: 467ms	remaining: 1.84s
202:	learn: -0.1086632	total: 469ms	remaining: 1.84s
203:	learn: -0.1081084	total: 472ms	remaining: 1.84s
204:	learn: -0.1075591	total: 474ms	remaining: 1.84s
205:	learn: -0.1070152	total: 478ms	remaining: 1.84s
206:	learn: -0.1064767	total: 480ms	remaining: 1.84s
207:	learn: -0.1059434	total: 482ms	remaining: 1.83s
208:	learn: -0.1054154	total: 485ms	remaining: 1.83s
209:	learn: -0.1048924	total: 488ms	remaining: 1.83s
210:	learn: -0.1042471	total: 490ms	remaining: 1.83s
211:	learn: -0.1038183	total: 493ms	remaining: 1.83s
212:	learn: -0.1031849	total: 495ms	remaining: 1.83s
213:	learn: -0.1025587	total: 497ms	remaining: 1.83s
214:	learn: -0.1020616	total: 499ms	remaining: 1.82s
215:	learn: -0.1014477	total: 501ms	remaining: 1.82s
216:	learn: -0.1009151	total: 503ms	remaining: 1.81s
217:	learn: -0.1004326	total: 506ms	remaining: 1.81s
218:	learn: -0.0999546	total: 508ms	remaining:

428:	learn: -0.0482992	total: 821ms	remaining: 1.09s
429:	learn: -0.0481469	total: 823ms	remaining: 1.09s
430:	learn: -0.0480295	total: 825ms	remaining: 1.09s
431:	learn: -0.0478966	total: 827ms	remaining: 1.09s
432:	learn: -0.0477819	total: 830ms	remaining: 1.09s
433:	learn: -0.0476505	total: 832ms	remaining: 1.08s
434:	learn: -0.0475353	total: 833ms	remaining: 1.08s
435:	learn: -0.0474207	total: 835ms	remaining: 1.08s
436:	learn: -0.0473082	total: 837ms	remaining: 1.08s
437:	learn: -0.0471618	total: 840ms	remaining: 1.08s
438:	learn: -0.0470163	total: 841ms	remaining: 1.07s
439:	learn: -0.0469040	total: 844ms	remaining: 1.07s
440:	learn: -0.0467923	total: 846ms	remaining: 1.07s
441:	learn: -0.0466810	total: 848ms	remaining: 1.07s
442:	learn: -0.0465703	total: 849ms	remaining: 1.07s
443:	learn: -0.0464283	total: 851ms	remaining: 1.06s
444:	learn: -0.0463187	total: 852ms	remaining: 1.06s
445:	learn: -0.0462097	total: 853ms	remaining: 1.06s
446:	learn: -0.0461012	total: 855ms	remaining:

636:	learn: -0.0303141	total: 1.18s	remaining: 675ms
637:	learn: -0.0302524	total: 1.19s	remaining: 674ms
638:	learn: -0.0301909	total: 1.19s	remaining: 671ms
639:	learn: -0.0301297	total: 1.19s	remaining: 670ms
640:	learn: -0.0300761	total: 1.19s	remaining: 667ms
641:	learn: -0.0300153	total: 1.19s	remaining: 665ms
642:	learn: -0.0299548	total: 1.2s	remaining: 664ms
643:	learn: -0.0298945	total: 1.2s	remaining: 662ms
644:	learn: -0.0298344	total: 1.2s	remaining: 660ms
645:	learn: -0.0297746	total: 1.2s	remaining: 658ms
646:	learn: -0.0297151	total: 1.2s	remaining: 656ms
647:	learn: -0.0296557	total: 1.2s	remaining: 654ms
648:	learn: -0.0295966	total: 1.21s	remaining: 653ms
649:	learn: -0.0295377	total: 1.21s	remaining: 651ms
650:	learn: -0.0294868	total: 1.21s	remaining: 650ms
651:	learn: -0.0294283	total: 1.21s	remaining: 648ms
652:	learn: -0.0293848	total: 1.22s	remaining: 646ms
653:	learn: -0.0293268	total: 1.22s	remaining: 644ms
654:	learn: -0.0292689	total: 1.22s	remaining: 641ms

885:	learn: -0.0204256	total: 1.54s	remaining: 199ms
886:	learn: -0.0203972	total: 1.55s	remaining: 197ms
887:	learn: -0.0203689	total: 1.55s	remaining: 195ms
888:	learn: -0.0203406	total: 1.55s	remaining: 194ms
889:	learn: -0.0203125	total: 1.55s	remaining: 192ms
890:	learn: -0.0202844	total: 1.55s	remaining: 190ms
891:	learn: -0.0202564	total: 1.56s	remaining: 189ms
892:	learn: -0.0202285	total: 1.56s	remaining: 187ms
893:	learn: -0.0202006	total: 1.56s	remaining: 185ms
894:	learn: -0.0201728	total: 1.56s	remaining: 183ms
895:	learn: -0.0201451	total: 1.56s	remaining: 182ms
896:	learn: -0.0201175	total: 1.57s	remaining: 180ms
897:	learn: -0.0200899	total: 1.57s	remaining: 178ms
898:	learn: -0.0200625	total: 1.57s	remaining: 176ms
899:	learn: -0.0200351	total: 1.57s	remaining: 175ms
900:	learn: -0.0200077	total: 1.57s	remaining: 173ms
901:	learn: -0.0199805	total: 1.57s	remaining: 171ms
902:	learn: -0.0199606	total: 1.57s	remaining: 169ms
903:	learn: -0.0199409	total: 1.58s	remaining:

80:	learn: -0.2488456	total: 186ms	remaining: 2.11s
81:	learn: -0.2456656	total: 188ms	remaining: 2.1s
82:	learn: -0.2432789	total: 190ms	remaining: 2.1s
83:	learn: -0.2402220	total: 191ms	remaining: 2.08s
84:	learn: -0.2386570	total: 193ms	remaining: 2.07s
85:	learn: -0.2357011	total: 194ms	remaining: 2.06s
86:	learn: -0.2328100	total: 197ms	remaining: 2.07s
87:	learn: -0.2306747	total: 200ms	remaining: 2.08s
88:	learn: -0.2278930	total: 202ms	remaining: 2.07s
89:	learn: -0.2251710	total: 205ms	remaining: 2.07s
90:	learn: -0.2241248	total: 208ms	remaining: 2.07s
91:	learn: -0.2214806	total: 209ms	remaining: 2.06s
92:	learn: -0.2188920	total: 212ms	remaining: 2.07s
93:	learn: -0.2163574	total: 214ms	remaining: 2.06s
94:	learn: -0.2138754	total: 216ms	remaining: 2.06s
95:	learn: -0.2114442	total: 217ms	remaining: 2.04s
96:	learn: -0.2090627	total: 219ms	remaining: 2.04s
97:	learn: -0.2067292	total: 220ms	remaining: 2.02s
98:	learn: -0.2044425	total: 221ms	remaining: 2.01s
99:	learn: -0.

290:	learn: -0.0624186	total: 548ms	remaining: 1.33s
291:	learn: -0.0621739	total: 550ms	remaining: 1.33s
292:	learn: -0.0619310	total: 552ms	remaining: 1.33s
293:	learn: -0.0616899	total: 555ms	remaining: 1.33s
294:	learn: -0.0614505	total: 557ms	remaining: 1.33s
295:	learn: -0.0612130	total: 560ms	remaining: 1.33s
296:	learn: -0.0609772	total: 562ms	remaining: 1.33s
297:	learn: -0.0607432	total: 563ms	remaining: 1.33s
298:	learn: -0.0605108	total: 565ms	remaining: 1.32s
299:	learn: -0.0603142	total: 567ms	remaining: 1.32s
300:	learn: -0.0600850	total: 569ms	remaining: 1.32s
301:	learn: -0.0598575	total: 574ms	remaining: 1.33s
302:	learn: -0.0596316	total: 576ms	remaining: 1.32s
303:	learn: -0.0594073	total: 578ms	remaining: 1.32s
304:	learn: -0.0591846	total: 580ms	remaining: 1.32s
305:	learn: -0.0589635	total: 582ms	remaining: 1.32s
306:	learn: -0.0587440	total: 584ms	remaining: 1.32s
307:	learn: -0.0585261	total: 587ms	remaining: 1.32s
308:	learn: -0.0583097	total: 590ms	remaining:

533:	learn: -0.0322499	total: 1.1s	remaining: 963ms
534:	learn: -0.0321811	total: 1.1s	remaining: 961ms
535:	learn: -0.0321127	total: 1.11s	remaining: 958ms
536:	learn: -0.0320445	total: 1.11s	remaining: 956ms
537:	learn: -0.0319867	total: 1.11s	remaining: 954ms
538:	learn: -0.0319190	total: 1.11s	remaining: 952ms
539:	learn: -0.0318516	total: 1.11s	remaining: 950ms
540:	learn: -0.0317845	total: 1.12s	remaining: 948ms
541:	learn: -0.0317277	total: 1.12s	remaining: 945ms
542:	learn: -0.0316611	total: 1.12s	remaining: 943ms
543:	learn: -0.0315947	total: 1.12s	remaining: 943ms
544:	learn: -0.0315286	total: 1.13s	remaining: 940ms
545:	learn: -0.0314628	total: 1.13s	remaining: 938ms
546:	learn: -0.0313973	total: 1.13s	remaining: 936ms
547:	learn: -0.0313320	total: 1.13s	remaining: 933ms
548:	learn: -0.0312767	total: 1.13s	remaining: 931ms
549:	learn: -0.0312119	total: 1.13s	remaining: 928ms
550:	learn: -0.0311474	total: 1.14s	remaining: 925ms
551:	learn: -0.0310832	total: 1.14s	remaining: 9

743:	learn: -0.0225603	total: 1.46s	remaining: 504ms
744:	learn: -0.0225260	total: 1.47s	remaining: 502ms
745:	learn: -0.0224918	total: 1.47s	remaining: 500ms
746:	learn: -0.0224578	total: 1.47s	remaining: 498ms
747:	learn: -0.0224238	total: 1.47s	remaining: 496ms
748:	learn: -0.0223899	total: 1.47s	remaining: 494ms
749:	learn: -0.0223561	total: 1.48s	remaining: 492ms
750:	learn: -0.0223225	total: 1.48s	remaining: 490ms
751:	learn: -0.0222889	total: 1.48s	remaining: 488ms
752:	learn: -0.0222554	total: 1.48s	remaining: 486ms
753:	learn: -0.0222370	total: 1.49s	remaining: 485ms
754:	learn: -0.0222037	total: 1.49s	remaining: 483ms
755:	learn: -0.0221704	total: 1.49s	remaining: 480ms
756:	learn: -0.0221422	total: 1.49s	remaining: 478ms
757:	learn: -0.0221092	total: 1.49s	remaining: 476ms
758:	learn: -0.0220762	total: 1.49s	remaining: 474ms
759:	learn: -0.0220434	total: 1.49s	remaining: 472ms
760:	learn: -0.0220106	total: 1.5s	remaining: 470ms
761:	learn: -0.0219779	total: 1.5s	remaining: 4

951:	learn: -0.0174391	total: 1.83s	remaining: 92.3ms
952:	learn: -0.0174183	total: 1.83s	remaining: 90.4ms
953:	learn: -0.0174050	total: 1.83s	remaining: 88.5ms
954:	learn: -0.0173875	total: 1.84s	remaining: 86.5ms
955:	learn: -0.0173669	total: 1.84s	remaining: 84.6ms
956:	learn: -0.0173494	total: 1.84s	remaining: 82.6ms
957:	learn: -0.0173319	total: 1.84s	remaining: 80.7ms
958:	learn: -0.0173115	total: 1.84s	remaining: 78.8ms
959:	learn: -0.0172911	total: 1.84s	remaining: 76.9ms
960:	learn: -0.0172737	total: 1.85s	remaining: 75ms
961:	learn: -0.0172591	total: 1.85s	remaining: 73.1ms
962:	learn: -0.0172418	total: 1.85s	remaining: 71.2ms
963:	learn: -0.0172216	total: 1.85s	remaining: 69.2ms
964:	learn: -0.0172043	total: 1.86s	remaining: 67.4ms
965:	learn: -0.0171872	total: 1.86s	remaining: 65.5ms
966:	learn: -0.0171700	total: 1.86s	remaining: 63.6ms
967:	learn: -0.0171530	total: 1.86s	remaining: 61.6ms
968:	learn: -0.0171329	total: 1.87s	remaining: 59.7ms
969:	learn: -0.0171159	total: 
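
The call `config.model['model' + str(8)]()` looks a model class up by a string key and instantiates it, and every model exposes the same `train_predict(train_x, train_y, test_x)` method. The actual registry lives in `config.py` and is not shown here, so the following is only a sketch of that pattern. `SketchConfig` and `MajorityClassModel` are hypothetical names invented for illustration.

```python
class MajorityClassModel:
    """Hypothetical stand-in for a real model wrapper (not from config.py)."""

    def train_predict(self, train_x, train_y, test_x):
        # "Train": remember the most frequent label.
        labels = list(train_y)
        majority = max(set(labels), key=labels.count)
        # "Predict": emit that label for every test row.
        return [majority for _ in test_x]


class SketchConfig:
    # Hypothetical registry: string keys map to model classes.
    model = {
        'model8': MajorityClassModel,
    }


sketch_config = SketchConfig()
sketch_model = sketch_config.model['model' + str(8)]()
predictions = sketch_model.train_predict(
    [[0.1], [0.2], [0.3]],  # toy training features
    [1, 1, 0],              # toy training labels
    [[0.5], [0.6]],         # toy test features
)
print(predictions)
```

Keeping the interface identical across models is what lets the same notebook cell drive any of them by changing only the numeric suffix.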

#### Let's load the training data for the deep learning models

In [3]:
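def deep_data_prepare():
    """
    Convert the news text and the OCR text into padded sequences of word ids.

    :return: train dict, one-hot train_y, test dict (dict keys: 'news', 'ocr')
    """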
    train_x = []
    train_y = []
    ocr_train = []
    ocr_test = []
    test_x = []

    STOP_WORD = set()


    # Load the stop words
    with open(config.STOP_WORD_FILE, 'r', encoding='utf-8') as f:
        lines = f.readlines()
        for line in lines:
            line = line.strip()
            STOP_WORD.add(line)


    test_ids = []
    with open(config.TEST_FILE, 'r', encoding='utf-8') as f:
        lines = f.readlines()
        for line in lines:
            line = line.strip().split('\t')
            test_ids.append(line[0])
            line = line[1].split(' ')
            tmp = []
            for l in line:
                if l == '' \
                        or l == ' ' \
                        or l in STOP_WORD\
                        or len(l) < 2:
                    continue


                if l in item_to_id:
                    id = item_to_id[l]
                    tmp.append(id)

            if len(tmp) == 0:
                tmp = [UNK]

            # For sequences longer than the maximum length, keep only the beginning
            if len(tmp) > config.MAX_LEN:
                # print('Length Exceed!')
                tmp = tmp[:config.MAX_LEN]

            test_x.append(tmp)


    with open(config.TRAIN_X, 'r', encoding='utf-8') as f:
        lines = f.readlines()
        for line in lines:
            line = line.strip()
            line = line.split('\t')
            label = line[2]
            line = line[1].split(' ')
            tmp = []
            for l in line:
                if l == '' \
                        or l == ' ' \
                        or l in STOP_WORD\
                        or len(l) < 2:
                    continue


                if l in item_to_id:
                    id = item_to_id[l]
                    tmp.append(id)

            if len(tmp) == 0:
                tmp = [UNK]

            # For sequences longer than the maximum length, keep only the beginning
            if len(tmp) > config.MAX_LEN:
                # print('Length Exceed!')
                tmp = tmp[:config.MAX_LEN]
            train_x.append(tmp)
            train_y.append(label)

            
    # Load the OCR text data
    
    with open(config.OCR_TRAIN_X, 'r', encoding='utf-8') as f:
        lines = f.readlines()
        for line in lines:
            line = line.strip()
            line = line.split('\t')
            line = line[1].split(' ')
            tmp = []
            for l in line:
                if l == '' \
                        or l == ' ' \
                        or l in STOP_WORD\
                        or len(l) < 2:
                    continue


                if l in item_to_id:
                    id = item_to_id[l]
                    tmp.append(id)

            if len(tmp) == 0:
                tmp = [UNK]

            # For sequences longer than the maximum length, keep only the beginning
            if len(tmp) > config.MAX_LEN:
                # print('Length Exceed!')
                tmp = tmp[:config.MAX_LEN]
            ocr_train.append(tmp)
            
    with open(config.OCR_TEST_X, 'r', encoding='utf-8') as f:
        lines = f.readlines()
        for line in lines:
            line = line.strip().split('\t')
            test_ids.append(line[0])
            line = line[1].split(' ')
            tmp = []
            for l in line:
                if l == '' \
                        or l == ' ' \
                        or l in STOP_WORD\
                        or len(l) < 2:
                    continue


                if l in item_to_id:
                    id = item_to_id[l]
                    tmp.append(id)

            if len(tmp) == 0:
                tmp = [UNK]

            # For sequences longer than the maximum length, keep only the beginning
            if len(tmp) > config.MAX_LEN:
                # print('Length Exceed!')
                tmp = tmp[:config.MAX_LEN]

            ocr_test.append(tmp)


    train_x = sequence.pad_sequences(train_x, maxlen=config.MAX_LEN, padding='post',
                                                         truncating='post', value=UNK)

    ocr_train = sequence.pad_sequences(ocr_train, maxlen=config.OCR_LEN, padding='post',
                                                         truncating='post', value=UNK)
    
    test_x = sequence.pad_sequences(test_x, maxlen=config.MAX_LEN, padding='post',
                                                         truncating='post', value=UNK)

    ocr_test = sequence.pad_sequences(ocr_test, maxlen=config.OCR_LEN, padding='post',
                                                         truncating='post', value=UNK)





    train_y = np_utils.to_categorical(train_y)
    print('train_x shape is: ', train_x.shape)
    print('train_y shape is: ', train_y.shape)

    test = {}
    test['news'] = test_x
    test['ocr'] = ocr_test

    train = {}
    train['news'] = train_x
    train['ocr'] = ocr_train
    return train, train_y, test

def init_embedding():
    print('vocabulary size : ', len(item_to_id) + 1)
    print('begin to load word2vec file')
    def get_coefs(word, *arr): return word, np.asarray(arr, dtype='float32')
    embeddings_index = dict(get_coefs(*o.rstrip().rsplit(' ')) for o in open(config.EMBEDDING_FILE))

    # np.stack expects a sequence, so materialize the dict view as a list
    all_embs = np.stack(list(embeddings_index.values()))
    emb_mean, emb_std = all_embs.mean(), all_embs.std()

    print('create embedding matrix')
    embedding_matrix = np.random.normal(emb_mean, emb_std, (len(item_to_id) + 1, config.EMBED_SIZES))
    for word, i in item_to_id.items():
        embedding_vector = embeddings_index.get(word)
        if embedding_vector is not None: embedding_matrix[i] = embedding_vector
    print('end load word2vec')
    return embedding_matrix
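
`init_embedding` draws every row from a normal distribution that matches the mean and standard deviation of the pretrained vectors, then overwrites the rows of words that have a pretrained vector. A small dependency-free sketch of the same idea follows. It assumes word ids run from 0 to `len(item_to_id) - 1`, so the extra last row belongs to `UNK`. The function name and toy vectors are hypothetical.

```python
import random
import statistics


def build_embedding_matrix(item_to_id, pretrained, embed_size, seed=0):
    """Random-normal init matched to the pretrained stats, then overwrite."""
    rng = random.Random(seed)
    values = [v for vec in pretrained.values() for v in vec]
    mean = statistics.mean(values)
    std = statistics.pstdev(values)  # population std, like ndarray.std()
    # One extra row for the UNK id, which equals len(item_to_id).
    matrix = [[rng.gauss(mean, std) for _ in range(embed_size)]
              for _ in range(len(item_to_id) + 1)]
    for word, idx in item_to_id.items():
        vector = pretrained.get(word)
        if vector is not None:
            matrix[idx] = list(vector)
    return matrix


# Toy vocabulary and 2-dimensional "pretrained" vectors (made up).
toy_item_to_id = {'news': 0, 'image': 1, 'rare': 2}
toy_pretrained = {'news': [0.1, 0.2], 'image': [0.3, 0.4]}
toy_matrix = build_embedding_matrix(toy_item_to_id, toy_pretrained, embed_size=2)
print(len(toy_matrix), toy_matrix[0], toy_matrix[1])
```

Words without a pretrained vector, such as `'rare'` here, keep their random row, which gives them a starting point on the same scale as the known words.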

In [4]:
data_x, data_y, test = deep_data_prepare()
init_emb = init_embedding()


train_x shape is:  (48480, 1000)
train_y shape is:  (48480, 3)
vocabulary size :  153325
begin to load word2vec file
create embedding matrix
end load word2vec
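
The four reading loops in `deep_data_prepare` share one encoding rule: drop empty tokens, stop words, and tokens shorter than two characters, drop words that are missing from `item_to_id`, fall back to `[UNK]` for an empty result, truncate, and finally pad at the end with `UNK`. A dependency-free sketch of that rule for a single line of text is shown below. The function name and toy vocabulary are hypothetical, and a plain list stands in for the output of `pad_sequences`.

```python
def encode_tokens(text, item_to_id, stop_words, max_len, unk):
    """Map a space-separated string to a fixed-length list of word ids."""
    ids = []
    for token in text.split(' '):
        # Skip empty tokens, stop words, and tokens shorter than 2 characters.
        if token == '' or token in stop_words or len(token) < 2:
            continue
        # Out-of-vocabulary tokens are dropped rather than mapped to UNK.
        if token in item_to_id:
            ids.append(item_to_id[token])
    if not ids:
        ids = [unk]
    ids = ids[:max_len]                        # keep only the first max_len ids
    return ids + [unk] * (max_len - len(ids))  # pad at the end with UNK


# Toy vocabulary (made up); UNK is the index just past the last word id.
toy_vocab = {'market': 0, 'rises': 1, 'today': 2}
toy_unk = len(toy_vocab)
toy_stop = {'the'}

encoded = encode_tokens('the market rises a lot today', toy_vocab, toy_stop, 5, toy_unk)
print(encoded)
```

Because `UNK` doubles as the padding value, an empty document and a fully padded row look the same to the model, which is worth keeping in mind when reading the shapes printed above.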


#### Here we instantiate the deep learning model; see config.py for the specific models available

In [5]:
model = config.model['model' + str(1)](config.MAX_LEN, config.OCR_LEN, (len(item_to_id) + 1), init_emb)

In [6]:
model.train_predict(data_x, data_y, test)


__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
news (InputLayer)               (None, 1000)         0                                            
__________________________________________________________________________________________________
ocr (InputLayer)                (None, 400)          0                                            
__________________________________________________________________________________________________
embedding (Embedding)           multiple             45997500    news[0][0]                       
                                                                 ocr[0][0]                        
__________________________________________________________________________________________________
spatial_dropout1d_1 (SpatialDro (None, 1000, 300)    0           embedding[0][0]                  
__________

KeyboardInterrupt: 