-
Notifications
You must be signed in to change notification settings - Fork 41
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add related code for ai challenger 2018 sentiment analysis
- Loading branch information
panyang
committed
Oct 2, 2018
1 parent
702baab
commit a6504d5
Showing
7 changed files
with
239 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,26 @@ | ||
# fastText-for-AI-Challenger-Sentiment-Analysis | ||
AI Challenger 2018 Sentiment Analysis Baseline with fastText | ||
========================================= | ||
功能描述 | ||
--- | ||
本项目主要基于AI Challenger官方[baseline](https://github.com/AIChallenger/AI_Challenger_2018/tree/master/Baselines/sentiment_analysis2018_baseline)修改了一个基于fastText的baseline,方便参赛者快速上手比赛,主要功能涵盖完成比赛的全流程,如数据读取、分词、特征提取、模型定义以及封装、 | ||
模型训练、模型验证、模型存储以及模型预测等。baseline仅是一个简单的参考,希望参赛者能够充分发挥自己的想象,构建在该任务上更加强大的模型。 | ||
|
||
开发环境 | ||
--- | ||
* 主要依赖工具包以及版本,详情见requirements.txt | ||
|
||
项目结构 | ||
--- | ||
* src/config.py 项目配置信息模块,主要包括文件读取或存储路径信息 | ||
* src/util.py 数据处理模块,主要包括数据的读取以及处理等功能 | ||
* src/main_train.py 模型训练模块,模型训练流程包括 数据读取、分词、特征提取、模型训练、模型验证、模型存储等步骤 | ||
* src/main_predict.py 模型预测模块,模型预测流程包括 数据和模型的读取、分词、模型预测、预测结果存储等步骤 | ||
|
||
|
||
使用方法 | ||
--- | ||
* 准备 virtualenv -p python3 venv && source venv/bin/activate && pip install -r requirements.txt | ||
* 配置 在config.py中配置好文件存储路径 | ||
* 训练 运行 python main_train.py -mn your_model_name 训练模型并保存,同时通过日志可以得到验证集的F1_score指标 | ||
* 预测 运行 python main_predict.py -mn your_model_name 通过加载上一步的模型,在测试集上做预测 | ||
* 更多详情请参考我的博客文章:http://www.52nlp.cn/?p=10537 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
|
||
import os | ||
|
||
# Path configuration for the AI Challenger 2018 sentiment-analysis baseline.
# Everything is anchored one directory above src/ (the repository root).
data_path = "{0}/data".format(os.path.abspath('..'))
# NOTE: the trailing slash matters — callers build file names via
# `model_path + model_name`.
model_path = "{0}/model/".format(data_path)
train_data_path = "{0}/train/train.csv".format(data_path)
validate_data_path = "{0}/valid/valid.csv".format(data_path)
test_data_path = "{0}/test/testa.csv".format(data_path)
test_data_predict_output_path = "{0}/predict/testa_predict.csv".format(data_path)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
|
||
import argparse | ||
import config | ||
import logging | ||
|
||
import numpy as np | ||
|
||
from sklearn.externals import joblib | ||
from util import load_data_from_csv, seg_words | ||
|
||
|
||
logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] <%(processName)s> (%(threadName)s) %(message)s') | ||
logger = logging.getLogger(__name__) | ||
|
||
|
||
if __name__ == '__main__':
    # Prediction entry point: load the trained per-aspect classifiers,
    # segment the test texts, predict every aspect column and write the
    # filled-in test DataFrame back out as CSV.
    parser = argparse.ArgumentParser()
    parser.add_argument('-mn', '--model_name', type=str, nargs='?',
                        default='fasttext_model.pkl',
                        help='the name of model')

    args = parser.parse_args()
    model_name = args.model_name

    # Load the raw test set from the path configured in config.py.
    logger.info("start load data")  # fixed: was "start load load"
    test_data_df = load_data_from_csv(config.test_data_path)

    # Load the {aspect column -> classifier} dict dumped by main_train.py.
    logger.info("start load model")
    classifier_dict = joblib.load(config.model_path + model_name)

    # Segment review texts into space-separated tokens for fastText.
    # (Messages fixed: this is the *test* set, not the train set.)
    content_test = test_data_df['content']
    logger.info("start seg test data")
    content_test = seg_words(content_test)
    logger.info("complete seg test data")

    # skift's classifiers expect a 2-D array whose first column holds the
    # text, hence wrapping in a 1-row array and transposing.
    logger.info("prepare predict data format")
    test_data_format = np.asarray([content_test]).T
    logger.info("complete prepare predict format data")

    columns = test_data_df.columns.values.tolist()

    # Columns 0/1 are id and content; every later column is one sentiment
    # aspect with its own trained classifier.
    logger.info("start predict test data")
    for column in columns[2:]:
        test_data_df[column] = classifier_dict[column].predict(
            test_data_format).astype(int)
        logger.info("complete %s predict" % column)

    test_data_df.to_csv(config.test_data_predict_output_path,
                        encoding="utf-8", index=False)
    logger.info("complete predict test data")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
|
||
import argparse | ||
import config | ||
import logging | ||
import os | ||
|
||
import numpy as np | ||
|
||
from skift import FirstColFtClassifier | ||
from sklearn.externals import joblib | ||
from util import load_data_from_csv, seg_words, get_f1_score | ||
|
||
logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] <%(processName)s> (%(threadName)s) %(message)s') | ||
logger = logging.getLogger(__name__) | ||
|
||
if __name__ == '__main__':
    # Training entry point: fit one fastText classifier per sentiment
    # aspect column, persist them as a single pickle, then report the
    # macro-F1 of each classifier on the validation set.
    parser = argparse.ArgumentParser()
    parser.add_argument('-mn', '--model_name', type=str, nargs='?',
                        default='fasttext_model.pkl',
                        help='the name of model')
    parser.add_argument('-lr', '--learning_rate', type=float, nargs='?',
                        default=1.0)
    parser.add_argument('-ep', '--epoch', type=int, nargs='?',
                        default=10)
    parser.add_argument('-wn', '--word_ngrams', type=int, nargs='?',
                        default=1)
    parser.add_argument('-mc', '--min_count', type=int, nargs='?',
                        default=1)

    args = parser.parse_args()
    model_name = args.model_name
    learning_rate = args.learning_rate
    epoch = args.epoch
    word_ngrams = args.word_ngrams
    min_count = args.min_count

    # Load train/validation sets from the paths configured in config.py.
    logger.info("start load data")  # fixed: was "start load load"
    train_data_df = load_data_from_csv(config.train_data_path)
    validate_data_df = load_data_from_csv(config.validate_data_path)

    # Column 1 holds the raw review text.
    content_train = train_data_df.iloc[:, 1]

    logger.info("start seg train data")
    content_train = seg_words(content_train)
    logger.info("complete seg train data")

    # skift's FirstColFtClassifier expects a 2-D array whose first column
    # holds the text, hence wrapping in a 1-row array and transposing.
    logger.info("prepare train format")
    train_data_format = np.asarray([content_train]).T
    logger.info("complete format train data")

    columns = train_data_df.columns.values.tolist()

    # Train one classifier per aspect column (columns 0/1 are id/content).
    logger.info("start train model")
    classifier_dict = dict()
    for column in columns[2:]:
        train_label = train_data_df[column]
        logger.info("start train %s model" % column)
        sk_clf = FirstColFtClassifier(lr=learning_rate, epoch=epoch,
                                      wordNgrams=word_ngrams,
                                      minCount=min_count, verbose=2)
        sk_clf.fit(train_data_format, train_label)
        logger.info("complete train %s model" % column)
        classifier_dict[column] = sk_clf

    logger.info("complete train model")

    # Persist the whole {aspect -> classifier} dict as a single pickle.
    logger.info("start save model")
    model_path = config.model_path
    if not os.path.exists(model_path):
        os.makedirs(model_path)
    joblib.dump(classifier_dict, model_path + model_name)
    logger.info("complete save model")  # fixed: was "complete svae model"

    # Validation: compute the macro F1 score of every aspect classifier.
    content_validate = validate_data_df.iloc[:, 1]

    logger.info("start seg validate data")
    content_validate = seg_words(content_validate)
    logger.info("complete seg validate data")

    logger.info("prepare validate format")
    validate_data_format = np.asarray([content_validate]).T
    logger.info("complete format validate data")

    logger.info("start compute f1 score for validate model")
    f1_score_dict = dict()
    for column in columns[2:]:
        true_label = np.asarray(validate_data_df[column])
        classifier = classifier_dict[column]
        pred_label = classifier.predict(validate_data_format).astype(int)
        f1_score_dict[column] = get_f1_score(true_label, pred_label)

    # Renamed from `f1_score` to avoid shadowing the per-column values
    # conceptually tied to get_f1_score.
    mean_f1_score = np.mean(list(f1_score_dict.values()))
    str_score = "\n"
    for column in columns[2:]:
        str_score += column + ":" + str(f1_score_dict[column]) + "\n"

    logger.info("f1_scores: %s\n" % str_score)
    logger.info("f1_score: %s" % mean_f1_score)
    logger.info("complete compute f1 score for validate model")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
pybind11>=2.2 | ||
Cython==0.28.5 | ||
future==0.16.0 | ||
jieba==0.39 | ||
numpy==1.15.1 | ||
pandas==0.23.4 | ||
python-dateutil==2.7.3 | ||
pytz==2018.5 | ||
scikit-learn==0.20rc1 | ||
scipy==1.1.0 | ||
six==1.11.0 | ||
skift==0.0.11 | ||
#fasttext==0.8.22 | ||
git+https://github.com/facebookresearch/fastText.git@ca8c5face7d5f3a64fff0e4dfaf58d60a691cb7c |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
|
||
import jieba | ||
import pandas as pd | ||
|
||
from sklearn.metrics import f1_score | ||
|
||
stop_words = [] | ||
|
||
|
||
def load_data_from_csv(file_name, header=0, encoding="utf-8"):
    """Read a CSV file into a pandas DataFrame.

    Args:
        file_name: path of the CSV file to read.
        header: row number to use as the column names (default 0).
        encoding: text encoding of the file (default utf-8).

    Returns:
        pandas.DataFrame holding the file's contents.
    """
    return pd.read_csv(file_name, header=header, encoding=encoding)
|
||
|
||
def seg_words(contents):
    """Segment each text with jieba into a space-separated token string.

    Newline variants (``\\r\\n``, ``\\n`` and bare ``\\r``) are replaced by
    spaces first so a record can never span multiple lines in the
    fastText input format; tokens listed in ``stop_words`` are dropped.

    Args:
        contents: iterable of raw text strings (e.g. a DataFrame column).

    Returns:
        list of segmented strings, one per input text.
    """
    contents_segs = list()
    for content in contents:
        # Fixed: the original normalised "\r\n" and "\n" but let a bare
        # "\r" (classic-Mac line ending) leak through into the output.
        rcontent = (content.replace("\r\n", " ")
                           .replace("\n", " ")
                           .replace("\r", " "))
        segs = [word for word in jieba.cut(rcontent) if word not in stop_words]
        contents_segs.append(" ".join(segs))
    return contents_segs
|
||
|
||
def get_f1_score(y_true, y_pred):
    """Macro-averaged F1 over the task's four sentiment label values.

    All four labels (1, 0, -1, -2 — the value set used by this
    competition's aspect columns) are scored even when a label is absent
    from ``y_true``/``y_pred``, so every class weighs equally in the mean.
    """
    sentiment_labels = [1, 0, -1, -2]
    return f1_score(y_true, y_pred, labels=sentiment_labels, average='macro')