In [23]:
import nltk

# 导入pandas用于读取表格数据
import pandas as pd

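# Import the text vectorizer. TfidfVectorizer (TF-IDF, term frequency-inverse document frequency) is used here in place of CountVectorizer (bag-of-words); if you switch between them, update every later use as well. The author reports that TF-IDF performed better in their own tests.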
from sklearn.feature_extraction.text import TfidfVectorizer

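# Import the classifier. RandomForestClassifier is the model actually used; the LogisticRegression import on the next line is an earlier alternative left commented out.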
#from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV

# 过滤警告消息
from warnings import simplefilter
from sklearn.exceptions import ConvergenceWarning

# Suppress scikit-learn ConvergenceWarning messages for the rest of the session.
simplefilter("ignore", category=ConvergenceWarning)

# Load the training set and replace missing titles/abstracts with empty
# strings so the columns can be concatenated as plain text later on.
train = pd.read_csv('./data/train.csv')
train = train.assign(
    title=train['title'].fillna(''),
    abstract=train['abstract'].fillna(''),
)

# Same treatment for the test set.
test = pd.read_csv('./data/testB.csv')
test = test.assign(
    title=test['title'].fillna(''),
    abstract=test['abstract'].fillna(''),
)

In [25]:
# Build one free-text field per paper; it is the only input to the model.
# NOTE(review): the training text includes `Keywords` while the test text does
# not, so the model is fitted on tokens it will not see at prediction time.
# Presumably the test file has no `Keywords` column (a placeholder is written
# below) -- confirm, and consider dropping `Keywords` from the training text.
train['text'] = (
    train['title'].fillna('') + ' ' + train['author'].fillna('') + ' '
    + train['abstract'].fillna('') + ' ' + train['Keywords'].fillna('')
)
test['text'] = (
    test['title'].fillna('') + ' ' + test['author'].fillna('') + ' '
    + test['abstract'].fillna('')
)

# Fit the TF-IDF vocabulary on the training text only, then reuse the fitted
# vectorizer to transform both splits into feature matrices.
vector = TfidfVectorizer().fit(train['text'])
train_vector = vector.transform(train['text'])
test_vector = vector.transform(test['text'])

# A random forest is stochastic (bootstrap samples, random feature subsets).
# Fixing the seed makes the predictions and the submission file reproducible
# under Restart Kernel -> Run All.
RANDOM_STATE = 42
model = RandomForestClassifier(random_state=RANDOM_STATE)

# Optional hyperparameter search, kept disabled. Uncomment to tune the forest.
# param_grid = {
#     'n_estimators': [100, 200, 300],
#     'max_depth': [None, 10, 20],
#     'min_samples_split': [2, 5, 10],
#     'min_samples_leaf': [1, 2, 4]
# }
# grid_search = GridSearchCV(
#     RandomForestClassifier(random_state=RANDOM_STATE), param_grid, cv=5
# )
# grid_search.fit(train_vector, train['label'])
# print("Best Hyperparameters:", grid_search.best_params_)
# # With the default refit=True, best_estimator_ is already refitted on the
# # full training data, so no extra .fit() call is needed.
# best_model = grid_search.best_estimator_

# Train the classifier. A random forest has no batch size or epochs; tune
# n_estimators, max_depth, min_samples_split, etc. instead (see grid above).
model.fit(train_vector, train['label'])

# Predict the label for every test row.
test['label'] = model.predict(test_vector)

# Use the title as a placeholder for the `Keywords` column written to the
# submission file (presumably required by the submission format).
test['Keywords'] = test['title'].fillna('')

# Write the submission without the DataFrame index column.
test[['uuid','Keywords','label']].to_csv('submit_task1.csv', index=False)

In [26]:
# Read the submission file back from disk and display it as a sanity check.
pd.read_csv(filepath_or_buffer='./submit_task1.csv')

Unnamed: 0,uuid,Keywords,label
0,0,Tobacco Consumption and High-Sensitivity Cardi...,1
1,1,Approaching towards sustainable supply chain u...,1
2,2,Does globalization matter for ecological footp...,0
3,3,Myths and Misconceptions About University Stud...,1
4,4,Antioxidant Status of Rat Liver Mitochondria u...,1
...,...,...,...
1995,1995,The treatment of veterinary antibiotics in swi...,1
1996,1996,Socio-political efficacy explains increase in ...,1
1997,1997,Investigation of early puberty prevalence and ...,1
1998,1998,From 3D printing to 3D bioprinting: the materi...,1


In [27]:
# Display the `Keywords` column of the test frame.
test.loc[:, 'Keywords']

0       Tobacco Consumption and High-Sensitivity Cardi...
1       Approaching towards sustainable supply chain u...
2       Does globalization matter for ecological footp...
3       Myths and Misconceptions About University Stud...
4       Antioxidant Status of Rat Liver Mitochondria u...
                              ...                        
1995    The treatment of veterinary antibiotics in swi...
1996    Socio-political efficacy explains increase in ...
1997    Investigation of early puberty prevalence and ...
1998    From 3D printing to 3D bioprinting: the materi...
1999    Effect of Processing on the Structure and Alle...
Name: Keywords, Length: 2000, dtype: object