# Machine Learning Nanodegree
## Capstone Project: Natural Language Processing, Document Classification (Part 1: TF-IDF Implementation)

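## Description

[Natural language processing](https://en.wikipedia.org/wiki/Natural_language_processing) (NLP below) is one of the major application areas of machine learning. Intelligent voice assistants on phones such as Siri, the automated voice services of China Mobile and China Unicom, [IBM Watson](http://www.ibm.com/watson/) with its comprehension and reasoning abilities, and Amazon's more recent [Lex](https://aws.amazon.com/cn/lex/), which offers advanced speech recognition and natural language understanding, are all examples of products at the forefront of NLP. If humans one day fully overcome the bottlenecks of NLP and computers can completely understand and analyze natural language, the scenes from science fiction such as *Ex Machina* and *Westworld*, in which robots communicate and exchange emotions with humans without barriers, may well become reality.

In practice, however, NLP still faces many technical challenges, one of which is how to represent words, sentences, and documents. In everyday life, words such as "cat" and "dog" are symbols that stand for meanings. In statistical language processing, symbols are convenient for describing probability models such as the [n-gram model](http://blog.csdn.net/ahmanz/article/details/51273500), which computes the probability of two or more words occurring together, but symbols cannot directly express relationships between words, nor can they serve directly as input vectors for a machine learning model. Sentences and documents can be represented with the [bag-of-words model](http://www.cnblogs.com/platero/archive/2012/12/03/2800251.html), which treats a paragraph or document as a collection of words. For example, from the two sentences "She loves cats." and "He loves cats too." we can build a word-frequency dictionary: {"She": 1, "He": 1, "loves": 2, "cats": 2, "too": 1}. Using this dictionary, the two sentences can be rewritten as the vectors [1, 0, 1, 1, 0] and [0, 1, 1, 1, 1], where each dimension holds the frequency of the corresponding word.

In recent years, with the help of deep learning and powerful hardware, researchers such as Geoffrey Hinton, Tomas Mikolov, and Richard Socher have studied word vectors in depth and run many encouraging experiments, pushing NLP to a new level. With word vectors as a foundation, machine learning models can readily be applied to text classification, sentiment analysis, prediction, machine translation, and more. The simplest word vector is one-hot encoding: given the three words "man", "husband", and "dog", represent them as [0,0,1], [0,1,0], and [1,0,0]. These vectors can be used as numeric input to a machine learning model, but they still cannot express relatedness, and when the vocabulary is large the one-hot dimension becomes huge, which causes problems for computation and storage. Mikolov, Socher, and others proposed word vector models such as [Word2Vec](http://papers.nips.cc/paper/5021-distributed-representations-of-words-and-phrases-and-their-compositionality.pdf) and [GloVe](http://nlp.stanford.edu/pubs/glove.pdf), which address this problem reasonably well by using vectors of much lower dimension to represent words and the relationships between them. For the principles behind these models, read the original papers, which are mostly in English; there are also many good translations and explanations on Chinese websites, such as this [Chinese blog](http://www.52nlp.cn/%e6%96%af%e5%9d%a6%e7%a6%8f%e5%a4%a7%e5%ad%a6%e6%b7%b1%e5%ba%a6%e5%ad%a6%e4%b9%a0%e4%b8%8e%e8%87%aa%e7%84%b6%e8%af%ad%e8%a8%80%e5%a4%84%e7%90%86%e7%ac%ac%e4%ba%8c%e8%ae%b2%e8%af%8d%e5%90%91%e9%87%8f) on NLP.

Similarly, sentences, paragraphs, and documents can also be represented as vectors, an approach known as Doc2Vec; those interested can read the paper [*Distributed Representations of Sentences and Documents*](https://arxiv.org/pdf/1405.4053v2.pdf) by Le and Mikolov.

The goal of this project is to classify documents accurately by combining the NLP techniques above with the machine learning knowledge covered in the course.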
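To make the one-hot limitation described above concrete, here is a minimal sketch in plain Python. The helper names `one_hot` and `dot` are hypothetical, the vocabulary ordering is an assumption chosen so the vectors match the ones in the text, and the dot product is used only as a simple stand-in for a similarity measure.

```python
def one_hot(word, vocabulary):
    """Return the one-hot vector of `word` over `vocabulary` (hypothetical helper)."""
    return [1 if word == entry else 0 for entry in vocabulary]


def dot(u, v):
    """Plain dot product, used here as a crude similarity measure."""
    return sum(a * b for a, b in zip(u, v))


# Ordering chosen (as an assumption) so the vectors match the example in the text.
vocabulary = ["dog", "husband", "man"]

man = one_hot("man", vocabulary)          # [0, 0, 1]
husband = one_hot("husband", vocabulary)  # [0, 1, 0]
dog = one_hot("dog", vocabulary)          # [1, 0, 0]

# Any two different words have dot product 0, so "man" is no closer to
# "husband" than it is to "dog".
print(dot(man, husband), dot(man, dog))
```

The vector length equals the vocabulary size, which is why the representation becomes expensive for large vocabularies.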

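## Data

The classic 20 Newsgroups collection can be used as the text classification data. It contains roughly 20,000 posts split fairly evenly across 20 categories and is one of the most commonly used text datasets. It can be downloaded from the [official website](http://www.qwone.com/~jason/20Newsgroups/) or through *sklearn*; see the [documentation](http://scikit-learn.org/stable/datasets/twenty_newsgroups.html) for details.

In addition, training word vectors requires a large amount of data. If the 20 Newsgroups sample feels too small to train a good word vector model, the [text8](http://mattmahoney.net/dc/text8.zip) dataset that Mikolov used can be used for training instead.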

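## Tasks

- (1) Explore ways of representing text

  - Represent each document with the bag-of-words model. A common approach is to first tokenize the text, that is, split a text file into a collection of words, build a vocabulary, and then represent each document as a vector of term frequencies or of weighted term frequencies ([TF-IDF](http://baike.baidu.com/link?url=toXJqDyZ1smDK2HpzusBzUnWX6YlKffU9bigEa5DHEOHmF0pL6XsDlhbzF10sijRGPeeml5Ze3cOtGAIHLXT0_)). This yields the familiar feature table, on which a machine learning classifier can then be trained, as illustrated below:

    ```
                           She  He  loves  cats  dogs  too
    "She loves cats."       1   0     1     1     0    0
    "He loves cats too."    0   1     1     1     0    1
    "She loves dogs."       1   0     1     0     1    0
    ```
  - Represent each document with Word2Vec, that is, a word vector model. This involves two main pieces of work:

     - Train word vectors on text data so that each word is represented as a vector. After training, the vectors need a simple evaluation, for example checking whether the similarity between certain words makes sense. The figure below shows a word vector model I trained on text8, in which the three similar words "school", "university", and "college" cluster together. ![Word vector illustration](https://raw.githubusercontent.com/nd009/capstone/master/document_classification/w2v.png)

     - Explore how to use the vectors of the words in a document to represent the whole document. Those who have time to spare can also try the Doc2Vec model to train a representation of each document directly.

- (2) On top of both the bag-of-words and the word vector representations, classify the text with models you consider appropriate, tune the models, and analyze their robustness.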
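The bag-of-words table in task (1) and its TF-IDF weighting can be sketched in plain Python as follows. This is a minimal illustration, not the implementation used later: the `tokenize` helper is hypothetical and deliberately simplistic, the vocabulary order is fixed by hand to match the table columns, and the weighting is the textbook form idf = ln(N / df).

```python
import math
from collections import Counter


def tokenize(text):
    """Lowercase, split on whitespace, strip basic punctuation (simplified, hypothetical)."""
    return [token.strip(".,!?") for token in text.lower().split()]


docs = ["She loves cats.", "He loves cats too.", "She loves dogs."]
tokens = [tokenize(doc) for doc in docs]

# Column order fixed by hand to match the table in the task description.
vocabulary = ["she", "he", "loves", "cats", "dogs", "too"]

# Bag of words: one row of term counts per document.
counts = [[Counter(doc_tokens)[term] for term in vocabulary] for doc_tokens in tokens]

# Document frequency and textbook inverse document frequency, idf = ln(N / df).
n_docs = len(docs)
df = [sum(1 for row in counts if row[j] > 0) for j in range(len(vocabulary))]
idf = [math.log(n_docs / d) for d in df]

# TF-IDF: each term count weighted by the idf of its term.
tfidf = [[tf * weight for tf, weight in zip(row, idf)] for row in counts]

for row in counts:
    print(row)
for row in tfidf:
    print([round(value, 3) for value in row])
```

Because "loves" occurs in every document, its idf is ln(3/3) = 0 and it contributes nothing, while rarer words such as "he" receive more weight. Note that sklearn's `TfidfTransformer` defaults to a smoothed variant, ln((1 + N) / (1 + df)) + 1, followed by L2 normalization of each row, so its numbers will differ from this sketch.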



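## Models

The text representation models were covered above and are not repeated here. Note that text preprocessing may affect the final result. For a language such as English: should case be distinguished, should different forms of the same word (such as singular and plural) be unified, and should punctuation be kept? These considerations should all be discussed in the report.

The following classification models can be used for reference:

- Decision tree
- Support vector machine (SVM)
- Naive Bayes
- Neural network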

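## Requirements

For the detailed report requirements, see the Udacity capstone project [template](https://github.com/nd009/capstone/blob/master/capstone_report_template.md) and [rubric](https://review.udacity.com/#!/rubrics/273/view).

## Tools

Recommended packages:

- [gensim](http://radimrehurek.com/gensim/), for training Word2Vec word vectors quickly and conveniently.
- [GloVe](https://github.com/maciejkula/glove-python), for training GloVe word vectors.
- [sklearn](http://scikit-learn.org/), a powerful machine learning package that includes common classification tools.
- [tensorflow](http://www.tensorfly.cn/), which lets you define the word vector training process step by step and also build deep learning models.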

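# References

1. Wikipedia, Natural language processing, https://en.wikipedia.org/wiki/Natural_language_processing
2. 52nlp (我爱自然语言处理), Word2Vec experiments on the Chinese and English Wikipedia corpora, http://www.52nlp.cn/tag/word2vec
3. Udacity, Machine learning capstone project description, https://github.com/nd009/machine-learning
4. University of Oxford, Deep learning for natural language processing course, https://github.com/oxford-cs-deepnlp-2017/lectures
5. Stanford University, Deep learning for natural language processing course, http://cs224d.stanford.edu/
6. Columbia University, Natural language processing course, http://www.cs.columbia.edu/~mcollins/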


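## ⚠️ Notes:
- The runtime environment is Python 3.
- If the NLTK data path has to be added manually in the text preprocessing part, change the path in nltk.data.path.append('nltk_data') to your local path. The data is about 10 MB; download: [Baidu Cloud](https://pan.baidu.com/s/1Tp-NsX9vWDgBVp14P3jNxw)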

---

# 1. Data

## 1.1 Loading the 20 Newsgroups data

In [1]:
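import ssl
import time
from pprint import pprint

from sklearn.datasets import fetch_20newsgroups

# Start a timer for the whole notebook.
total_cost_time_start = time.time()

# Workaround for certificate errors when downloading the dataset.
# Note: this disables HTTPS certificate verification for the whole process.
ssl._create_default_https_context = ssl._create_unverified_context

# Load all posts (train and test subsets), shuffled with a fixed seed.
newsgroups = fetch_20newsgroups(subset='all', shuffle=True, random_state=233)

pprint(list(newsgroups.target_names))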

['alt.atheism',
 'comp.graphics',
 'comp.os.ms-windows.misc',
 'comp.sys.ibm.pc.hardware',
 'comp.sys.mac.hardware',
 'comp.windows.x',
 'misc.forsale',
 'rec.autos',
 'rec.motorcycles',
 'rec.sport.baseball',
 'rec.sport.hockey',
 'sci.crypt',
 'sci.electronics',
 'sci.med',
 'sci.space',
 'soc.religion.christian',
 'talk.politics.guns',
 'talk.politics.mideast',
 'talk.politics.misc',
 'talk.religion.misc']


In [2]:
import plotly.offline as of
from plotly.graph_objs import Bar
of.init_notebook_mode(connected=True)

# Category names, in the same order as printed above.
categories = list(newsgroups.target_names)

# Number of posts in each category.
category_numbers = [len(fetch_20newsgroups(subset='all', categories=[name]).data)
                    for name in categories]

data = [Bar(x=categories,
            y=category_numbers)]

of.iplot(data)
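The per-category counts can also be derived from the label array alone, without fetching the dataset once per category. A minimal sketch with the standard library, using small hypothetical stand-ins for `newsgroups.target_names` and `newsgroups.target`:

```python
from collections import Counter

# Hypothetical stand-ins for newsgroups.target_names and newsgroups.target.
target_names = ["alt.atheism", "comp.graphics", "sci.space"]
target = [2, 0, 1, 2, 2, 1]

# Count how often each integer label occurs, then read the counts in label order.
label_counts = Counter(target)
category_numbers = [label_counts[i] for i in range(len(target_names))]

print(dict(zip(target_names, category_numbers)))
```

With the real data, `Counter(newsgroups.target.tolist())` would play the role of `label_counts`.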

In [3]:
# Average number of lines per post in each category
def cal_lines_numbers_per_news(category):
    """Return the mean number of newline-separated lines per document."""
    total_lines = sum(len(doc.split("\n")) for doc in category)
    return total_lines / len(category)

lines_numbers_per_news = [
    cal_lines_numbers_per_news(fetch_20newsgroups(subset='all', categories=[name]).data)
    for name in categories
]

data = [Bar(x=categories,
            y=lines_numbers_per_news)]

of.iplot(data)

In [4]:
# Average number of words per post in each category
def cal_word_numbers_per_news(category):
    """Return the mean number of space-separated tokens per document.

    Approximate: splitting on single spaces does not split on newlines and
    counts empty strings between consecutive spaces.
    """
    total_words = sum(len(doc.split(" ")) for doc in category)
    return total_words / len(category)

word_numbers_per_news = [
    cal_word_numbers_per_news(fetch_20newsgroups(subset='all', categories=[name]).data)
    for name in categories
]

data = [Bar(x=categories,
            y=word_numbers_per_news)]

of.iplot(data)
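Word counts obtained by splitting on whitespace depend on which form of `split` is used. A small sketch of how `str.split(" ")` and `str.split()` differ on newlines and repeated spaces; the sample strings are made up for illustration:

```python
header = "From: alice\nSubject: hello"
padded = "wide  gap"

# split(" ") cuts only at single spaces: newline-joined words stay together ...
header_by_space = header.split(" ")
# ... while split() cuts at any run of whitespace.
header_by_whitespace = header.split()

# Two consecutive spaces yield an empty string with split(" ") but not with split().
padded_by_space = padded.split(" ")
padded_by_whitespace = padded.split()

print(len(header_by_space), len(header_by_whitespace))
print(len(padded_by_space), len(padded_by_whitespace))
```

On raw newsgroup posts, which contain many line breaks, the two variants can therefore give noticeably different averages.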

## 1.2 Processing the data

In [5]:
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.naive_bayes import MultinomialNB
from sklearn.linear_model import SGDClassifier
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier
from sklearn.pipeline import Pipeline
import numpy as np
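import regex as re  # third-party `regex` package (not the stdlib `re`); needed for \p{P} below

import nltk
from nltk.stem.porter import PorterStemmer
from nltk.stem import WordNetLemmatizer
# Local path to the NLTK data; change this to match your machine.
nltk.data.path.append('/Users/jian/Desktop/nltk_data')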

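### 1.2.1 Basic cleaning (case, digits, punctuation, whitespace)

In [6]:
def clean_text(text):
    text = text.lower()                  # lowercase
    text = re.sub(r"\d+", " ", text)     # remove digits
    text = re.sub(r"\p{P}+", " ", text)  # remove punctuation (Unicode category P)
    text = re.sub(r"[<>]", " ", text)    # '<' and '>' are symbols, not punctuation, so remove them separately
    text = re.sub(r"\s+", " ", text)     # collapse runs of whitespace into a single space
    return text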
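If the third-party `regex` package is not available, roughly the same cleaning can be sketched with the standard library by testing each character's Unicode category. This is an assumption-laden alternative, not the notebook's code: the function name `clean_text_stdlib` is hypothetical, and it is only intended to mirror the steps of `clean_text` above.

```python
import re
import unicodedata


def clean_text_stdlib(text):
    """Lowercase, drop digits and punctuation, collapse whitespace (stdlib-only sketch)."""
    text = text.lower()
    text = re.sub(r"\d+", " ", text)
    # Replace every character whose Unicode category starts with "P"
    # (punctuation), plus '<' and '>', with a space.
    text = "".join(
        " " if unicodedata.category(ch).startswith("P") or ch in "<>" else ch
        for ch in text
    )
    return re.sub(r"\s+", " ", text)


sample = "Hello, World! 123 <a@b.com>"
cleaned = clean_text_stdlib(sample)
print(repr(cleaned))
```

Like the original, this sketch leaves symbol characters such as `$` or `+` in place, because they belong to the Unicode symbol categories rather than to punctuation.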

### 1.2.2 Stemming

In [7]:
porter_stemmer = PorterStemmer()  # create once and reuse for every word

def stemmed_word(word):
    return porter_stemmer.stem(word)

def stemmed_text(text):
    # Split on single spaces and re-join, so the original spacing is preserved.
    return ' '.join(stemmed_word(word) for word in text.split(" "))

### 1.2.3 Lemmatization

In [8]:
lemmatizer = WordNetLemmatizer()  # create once and reuse for every word

def lemmatizer_word(word):
    return lemmatizer.lemmatize(word)

def lemmatizer_text(text):
    # Split on single spaces and re-join, so the original spacing is preserved.
    return ' '.join(lemmatizer_word(word) for word in text.split(" "))

In [9]:
stemmed_word('strawberries')

'strawberri'

In [10]:
lemmatizer_word('strawberries')

'strawberry'

In [11]:
clean_text_4_test = clean_text(newsgroups.data[0])
stemmed_text(clean_text_4_test)

'from psyrobtw ubvmsd cc buffalo edu robert weiss subject apr god s promis in philippian organ univers at buffalo line news softwar vax vm vnew nntp post host ubvmsd cc buffalo edu those thing which ye have both learn and receiv and heard and seen in me do and the god of peac shall be with you philippian '

In [12]:
lemmatizer_text(clean_text_4_test)

'from psyrobtw ubvmsd cc buffalo edu robert wei subject apr god s promise in philippian organization university at buffalo line news software vax vms vnews nntp posting host ubvmsd cc buffalo edu those thing which ye have both learned and received and heard and seen in me do and the god of peace shall be with you philippian '

## 1.3 Building the raw and cleaned datasets

### 1.3.1 Raw data

In [13]:
newsgroups_original = fetch_20newsgroups(subset='all', shuffle=True, random_state=233)
VALIDATION_SPLIT = 0.2
num_validation_samples = int(VALIDATION_SPLIT * len(newsgroups_original.data))

x_train_o = newsgroups_original.data[:-num_validation_samples]
y_train_o = newsgroups_original.target[:-num_validation_samples]
x_val_o = newsgroups_original.data[-num_validation_samples:]
y_val_o = newsgroups_original.target[-num_validation_samples:]
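The same tail split is applied to every dataset variant in this section. As a sketch, it could be expressed once in a small helper; the name `tail_split` and the toy data are hypothetical, and the helper assumes the data has already been shuffled (as `shuffle=True` does above).

```python
def tail_split(data, target, validation_split=0.2):
    """Hold out the last `validation_split` fraction of (data, target) for validation."""
    num_val = int(validation_split * len(data))
    if num_val == 0:
        # Guard: data[:-0] would be empty, so keep everything as training data.
        return data, target, data[:0], target[:0]
    return data[:-num_val], target[:-num_val], data[-num_val:], target[-num_val:]


# Toy stand-ins for newsgroups_original.data and newsgroups_original.target.
docs = ["d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9"]
labels = [0, 1, 0, 1, 0, 1, 0, 1, 0, 1]

x_train, y_train, x_val, y_val = tail_split(docs, labels)
print(len(x_train), len(x_val))
```

Slicing works the same way on Python lists and NumPy arrays, so the helper would apply to both `.data` and `.target`.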

### 1.3.2 Data after basic cleaning

In [14]:
newsgroups_clean = fetch_20newsgroups(subset='all', shuffle=True, random_state=233)

start_time = time.time()

i = 0
length = len(newsgroups_clean.data)

while i < length:
    newsgroups_clean.data[i] = clean_text(newsgroups_clean.data[i])
    i += 1

x_train_c = newsgroups_clean.data[:-num_validation_samples]
y_train_c = newsgroups_clean.target[:-num_validation_samples]
x_val_c = newsgroups_clean.data[-num_validation_samples:]
y_val_c = newsgroups_clean.target[-num_validation_samples:]

end_time = time.time()
print("\nTime cost: {:.2f} seconds".format(end_time - start_time))


Time cost: 5.99 seconds


### 1.3.3 Data after stemming

In [15]:
newsgroups_stemmed = fetch_20newsgroups(subset='all', shuffle=True, random_state=233)

start_time = time.time()

i = 0
length = len(newsgroups_stemmed.data)

while i < length:
    newsgroups_stemmed.data[i] = stemmed_text(newsgroups_stemmed.data[i])
    i += 1

x_train_s = newsgroups_stemmed.data[:-num_validation_samples]
y_train_s = newsgroups_stemmed.target[:-num_validation_samples]
x_val_s = newsgroups_stemmed.data[-num_validation_samples:]
y_val_s = newsgroups_stemmed.target[-num_validation_samples:]

end_time = time.time()
print("\nTime cost: {:.2f} seconds".format(end_time - start_time))


Time cost: 160.74 seconds


### 1.3.4 Data after lemmatization

In [16]:
newsgroups_lemmatizer = fetch_20newsgroups(subset='all', shuffle=True, random_state=233)

start_time = time.time()

i = 0
length = len(newsgroups_lemmatizer.data)

while i < length:
    newsgroups_lemmatizer.data[i] = lemmatizer_text(newsgroups_lemmatizer.data[i])
    i += 1

x_train_l = newsgroups_lemmatizer.data[:-num_validation_samples]
y_train_l = newsgroups_lemmatizer.target[:-num_validation_samples]
x_val_l = newsgroups_lemmatizer.data[-num_validation_samples:]
y_val_l = newsgroups_lemmatizer.target[-num_validation_samples:]

end_time = time.time()
print("\nTime cost: {:.2f} seconds".format(end_time - start_time))


Time cost: 37.54 seconds


### 1.3.5 Data after basic cleaning and stemming

In [17]:
newsgroups_clean_stemmed = fetch_20newsgroups(subset='all', shuffle=True, random_state=233)

start_time = time.time()

i = 0
length = len(newsgroups_clean_stemmed.data)

while i < length:
    newsgroups_clean_stemmed.data[i] = stemmed_text(clean_text(newsgroups_clean_stemmed.data[i]))
    i += 1

x_train_cs = newsgroups_clean_stemmed.data[:-num_validation_samples]
y_train_cs = newsgroups_clean_stemmed.target[:-num_validation_samples]
x_val_cs = newsgroups_clean_stemmed.data[-num_validation_samples:]
y_val_cs = newsgroups_clean_stemmed.target[-num_validation_samples:]

end_time = time.time()
print("\nTime cost: {:.2f} seconds".format(end_time - start_time))


Time cost: 176.25 seconds


### 1.3.6 Data after basic cleaning and lemmatization

In [18]:
newsgroups_clean_lemmatizer = fetch_20newsgroups(subset='all', shuffle=True, random_state=233)

start_time = time.time()

i = 0
length = len(newsgroups_clean_lemmatizer.data)

while i < length:
    newsgroups_clean_lemmatizer.data[i] = lemmatizer_text(clean_text(newsgroups_clean_lemmatizer.data[i]))
    i += 1

x_train_cl = newsgroups_clean_lemmatizer.data[:-num_validation_samples]
y_train_cl = newsgroups_clean_lemmatizer.target[:-num_validation_samples]
x_val_cl = newsgroups_clean_lemmatizer.data[-num_validation_samples:]
y_val_cl = newsgroups_clean_lemmatizer.target[-num_validation_samples:]

end_time = time.time()
print("\nTime cost: {:.2f} seconds".format(end_time - start_time))


Time cost: 45.18 seconds


# 2. Comparing Models and Data

## 2.1 Defining the models

### 2.1.1 Naive Bayes

In [19]:
def NB(X_train, y_train, X_val, y_val):

    start_time = time.time()

    # pipeline (tokenizer => transformer => MultinomialNB classifier)
    text_clf = Pipeline([('vect', CountVectorizer()),('tfidf', TfidfTransformer()),('clf', MultinomialNB())])
    text_clf = text_clf.fit(X_train, y_train)

    # evaluate on the validation set
    predicted = text_clf.predict(X_val)
    acc = np.mean(predicted == y_val) * 100

    # the reported time covers both fitting and prediction
    end_time = time.time()
    time_cost = end_time - start_time
    return acc, time_cost

### 2.1.2 SVM

In [20]:
def SVM(X_train, y_train, X_val, y_val):

    start_time = time.time()

    # pipeline (tokenizer => transformer => linear-kernel SVM classifier)
    text_clf = Pipeline([('vect', CountVectorizer()),('tfidf', TfidfTransformer()),('clf', SVC(kernel="linear"))])
    text_clf = text_clf.fit(X_train, y_train)

    # evaluate on the validation set
    predicted = text_clf.predict(X_val)
    acc = np.mean(predicted == y_val) * 100

    end_time = time.time()
    time_cost = end_time - start_time
    return acc, time_cost

### 2.1.3 Decision tree

In [21]:
def DT(X_train, y_train, X_val, y_val):

    start_time = time.time()

    # pipeline (tokenizer => transformer => decision tree classifier)
    text_clf = Pipeline([('vect', CountVectorizer()),('tfidf', TfidfTransformer()),('clf', DecisionTreeClassifier())])
    text_clf = text_clf.fit(X_train, y_train)

    # evaluate on the validation set
    predicted = text_clf.predict(X_val)
    acc = np.mean(predicted == y_val) * 100

    end_time = time.time()
    time_cost = end_time - start_time
    return acc, time_cost
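The three functions above differ only in the final classifier. As a sketch of how the shared part could be factored out, the hypothetical helper `evaluate` below accepts any object with `fit` and `predict` methods; in the notebook that object would be one of the `Pipeline` instances, while here a toy `MajorityClassifier` (also hypothetical) stands in so the example runs without scikit-learn.

```python
import time


def evaluate(model, x_train, y_train, x_val, y_val):
    """Fit `model`, predict on the validation set, return (accuracy in %, seconds)."""
    start_time = time.time()
    model.fit(x_train, y_train)
    predicted = model.predict(x_val)
    correct = sum(1 for p, y in zip(predicted, y_val) if p == y)
    acc = 100.0 * correct / len(y_val)
    return acc, time.time() - start_time


class MajorityClassifier:
    """Toy stand-in for an estimator: always predicts the most common training label."""

    def fit(self, x, y):
        self.label_ = max(set(y), key=list(y).count)
        return self

    def predict(self, x):
        return [self.label_ for _ in x]


acc, seconds = evaluate(MajorityClassifier(), ["a", "b", "c"], [1, 1, 0], ["d", "e"], [1, 0])
print(acc)
```

As in the functions above, the measured time covers both fitting and prediction.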

## 2.2 Comparing the models

In [22]:
model_data_name = ['NB_o', 'NB_c', 'NB_s', 'NB_l', 'NB_cs', 'NB_cl',
                   'SVM_o', 'SVM_c', 'SVM_s', 'SVM_l', 'SVM_cs', 'SVM_cl',
                   'DT_o', 'DT_c', 'DT_s', 'DT_l', 'DT_cs', 'DT_cl',]

NB_o_acc, NB_o_time = NB(x_train_o, y_train_o, x_val_o, y_val_o)
NB_c_acc, NB_c_time = NB(x_train_c, y_train_c, x_val_c, y_val_c)
NB_s_acc, NB_s_time = NB(x_train_s, y_train_s, x_val_s, y_val_s)
NB_l_acc, NB_l_time = NB(x_train_l, y_train_l, x_val_l, y_val_l)
NB_cs_acc, NB_cs_time = NB(x_train_cs, y_train_cs, x_val_cs, y_val_cs)
NB_cl_acc, NB_cl_time = NB(x_train_cl, y_train_cl, x_val_cl, y_val_cl)

SVM_o_acc, SVM_o_time = SVM(x_train_o, y_train_o, x_val_o, y_val_o)
SVM_c_acc, SVM_c_time = SVM(x_train_c, y_train_c, x_val_c, y_val_c)
SVM_s_acc, SVM_s_time = SVM(x_train_s, y_train_s, x_val_s, y_val_s)
SVM_l_acc, SVM_l_time = SVM(x_train_l, y_train_l, x_val_l, y_val_l)
SVM_cs_acc, SVM_cs_time = SVM(x_train_cs, y_train_cs, x_val_cs, y_val_cs)
SVM_cl_acc, SVM_cl_time = SVM(x_train_cl, y_train_cl, x_val_cl, y_val_cl)

DT_o_acc, DT_o_time = DT(x_train_o, y_train_o, x_val_o, y_val_o)
DT_c_acc, DT_c_time = DT(x_train_c, y_train_c, x_val_c, y_val_c)
DT_s_acc, DT_s_time = DT(x_train_s, y_train_s, x_val_s, y_val_s)
DT_l_acc, DT_l_time = DT(x_train_l, y_train_l, x_val_l, y_val_l)
DT_cs_acc, DT_cs_time = DT(x_train_cs, y_train_cs, x_val_cs, y_val_cs)
DT_cl_acc, DT_cl_time = DT(x_train_cl, y_train_cl, x_val_cl, y_val_cl)

acc_numbers = [NB_o_acc, NB_c_acc, NB_s_acc, NB_l_acc, NB_cs_acc, NB_cl_acc,
               SVM_o_acc, SVM_c_acc, SVM_s_acc, SVM_l_acc, SVM_cs_acc, SVM_cl_acc,
               DT_o_acc, DT_c_acc, DT_s_acc, DT_l_acc, DT_cs_acc, DT_cl_acc]

time_numbers = [NB_o_time, NB_c_time, NB_s_time, NB_l_time, NB_cs_time, NB_cl_time,
               SVM_o_time, SVM_c_time, SVM_s_time, SVM_l_time, SVM_cs_time, SVM_cl_time,
               DT_o_time, DT_c_time, DT_s_time, DT_l_time, DT_cs_time, DT_cl_time]

print("*** Model Accuracy (%) ***")
data = [Bar(x=model_data_name,
            y=acc_numbers)]
of.iplot(data)

print("\n*** Model Time Cost (seconds) ***")
data = [Bar(x=model_data_name,
            y=time_numbers)]
of.iplot(data)

*** Model Accuracy (%) ***



*** Model Time Cost (seconds) ***
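The comparison above runs every model on every dataset variant and collects the results in name order. A minimal sketch of the same bookkeeping with dictionaries and a loop follows; `run_grid`, `fake_model`, and the toy datasets are hypothetical stand-ins, and in the notebook the dictionaries would hold `NB`, `SVM`, `DT` and the six train/validation splits.

```python
def run_grid(models, datasets):
    """Run every model on every dataset; return parallel lists of names, accuracies, times."""
    names, accs, times = [], [], []
    for model_name, model_fn in models.items():
        for data_name, (x_train, y_train, x_val, y_val) in datasets.items():
            acc, time_cost = model_fn(x_train, y_train, x_val, y_val)
            names.append("{}_{}".format(model_name, data_name))
            accs.append(acc)
            times.append(time_cost)
    return names, accs, times


def fake_model(x_train, y_train, x_val, y_val):
    """Hypothetical stand-in with the same signature and return shape as NB/SVM/DT."""
    return 100.0 * len(x_val) / (len(x_train) + len(x_val)), 0.0


models = {"NB": fake_model, "SVM": fake_model}
datasets = {
    "o": (["a", "b", "c"], [0, 1, 0], ["d"], [1]),
    "c": (["a", "b"], [0, 1], ["c", "d"], [1, 0]),
}

names, accs, times = run_grid(models, datasets)
print(names)
print(accs)
```

The returned lists have the same shape as `model_data_name`, `acc_numbers`, and `time_numbers`, so they could be passed to the bar charts unchanged.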


In [23]:
total_cost_time_end = time.time()
total_cost_time = total_cost_time_end - total_cost_time_start
print("\nTotal time cost: {:.2f} minutes".format(total_cost_time / 60))


Total time cost: 42.17 minutes
