用FastText的監督式學習方法進行文本分類

In [1]:
import re,jieba,random
import numpy as np 
import pandas as pd 
import fasttext

In [6]:
cate_dic = {'Car News':1, 'Entertainment News':2, 'International News':3, 'Technology News':4, 'Society News':5, 'Sports News':6, 'Finance News':7}

In [4]:
#讀取分類的檔案
car_news = pd.read_csv('class_data/car_news.csv',encoding='utf-8')
car_news = car_news.dropna()

entertainment_news = pd.read_csv('class_data/entertainment_news.csv',encoding='utf-8')
entertainment_news = entertainment_news.dropna()

international_news = pd.read_csv('class_data/international_news.csv',encoding='utf-8')
international_news = international_news.dropna()

technology_news = pd.read_csv('class_data/technology_news.csv',encoding='utf-8')
technology_news = technology_news.dropna()

society_news = pd.read_csv('class_data/society_news.csv',encoding='utf-8')
society_news = society_news.dropna()

sports_news = pd.read_csv('class_data/sports_news.csv',encoding='utf-8')
sports_news = sports_news.dropna()

finance_news = pd.read_csv('class_data/finance_news.csv',encoding='utf-8')
finance_news = finance_news.dropna()

print('Car News:{}\nEntertainment News:{}\nInternational News:{}\nTechnology News:{}\nSociety News:{}\nSports News:{}\nFinance News:{}\n'.format(len(car_news),len(entertainment_news),len(international_news),len(technology_news),len(society_news),len(sports_news),len(finance_news)))

#每個新聞取出8000筆
car_news = car_news[:10000]
entertainment_news = entertainment_news[:10000]
entertainment_news = entertainment_news[:10000]
technology_news = technology_news[:10000]
society_news = society_news[:10000]
sports_news = sports_news[:10000]
finance_news = finance_news[:10000]

Car News:11740
Entertainment News:39264
International News:89338
Technology News:25057
Society News:268829
Sports News:32728
Finance News:143141



In [5]:
stop_list=[]
with open('data/stopwords.txt','r',encoding='utf-8') as f:
    for line in f.readlines():
        stop_list.append(line.strip())

In [7]:
def preprocess(data,all_data,category):
    for line in data:
        line = re.sub(r'[^\w]','',line)
        line = re.sub(r'[A-Za-z0-9]','',line)
        line = re.sub(u'[\uFF01-\uFF5A]','',line)
        segment_list = jieba.lcut(line)
        segment_list = filter(lambda x: len(x)>1,segment_list)
        segment_list = filter(lambda x: x not in stop_list,segment_list)
        all_data.append( "__label__"+str(cate_dic[category])+" , "+" ".join(segment_list) )

all_data = []
preprocess(car_news.content.values.tolist(),all_data,'Car News')
preprocess(technology_news.content.values.tolist(),all_data,'Entertainment News')
preprocess(technology_news.content.values.tolist(),all_data,'International News')
preprocess(technology_news.content.values.tolist(),all_data,'Technology News')
preprocess(society_news.content.values.tolist(),all_data,'Society News')
preprocess(sports_news.content.values.tolist(),all_data,'Sports News')
preprocess(finance_news.content.values.tolist(),all_data,'Finance News')

Building prefix dict from the default dictionary ...
Loading model from cache /var/folders/vk/4wfw6yvn67b0gpkhvj8wy0cw0000gn/T/jieba.cache
Loading model cost 1.559 seconds.
Prefix dict has been built succesfully.


In [31]:
# 按照7:3切分訓練與測試資料集

random.shuffle(all_data)
train_data = all_data[:int(len(all_data)*0.8)]
test_data = all_data[int(len(all_data)*0.8):]

In [32]:
# 儲存訓練與測試資料集檔案

print("Writing data to fasttext format...")

with open('data/fasttext_train_data.txt', 'w',encoding='utf-8') as f:
    for data in train_data:
        f.write(data+"\n")
f.close()

with open('data/fasttext_test_data.txt', 'w',encoding='utf-8') as f2:
    for data in test_data:
        f2.write(data+"\n")
f2.close()

print("done!")

Writing data to fasttext format...
done!


In [23]:
help(fasttext.train_supervised)

Help on function train_supervised in module fasttext.FastText:

train_supervised(*kargs, **kwargs)
    Train a supervised model and return a model object.
    
    input must be a filepath. The input text does not need to be tokenized
    as per the tokenize function, but it must be preprocessed and encoded
    as UTF-8. You might want to consult standard preprocessing scripts such
    as tokenizer.perl mentioned here: http://www.statmt.org/wmt07/baseline.html
    
    The input file must must contain at least one label per line. For an
    example consult the example datasets which are part of the fastText
    repository such as the dataset pulled by classification-example.sh.



In [34]:
# 訓練FastText模型
classifier = fasttext.train_supervised('data/fasttext_train_data.txt')
result = classifier.test('data/fasttext_test_data.txt')
print('Precision:{}'.format(result[1]))
print('Recall:{}'.format(result[2]))
print('Number of examples:{}'.format(result[0]))

Precision:0.6879107142857143
Recall:0.6879107142857143
Number of examples:56000


In [42]:
recate_dic = {'1':'Car News', '2':'Entertainment News', '3':'International News', '4':'Technology News', '5':'Society News', '6':'Sports News', '7':'Finance News'}
result = classifier.predict(['新車 好看 美 等等 一次 付清'],k=3)
for n in range(3):
    print('{}\t{}'.format( recate_dic[ result[0][0][n].split('__label__')[1] ],result[1][0][n] ))

Car News	1.0000096559524536
Finance News	1.0137071512872353e-05
Society News	1.0123488209501375e-05
