In [8]:
import json
import numpy as np
import pickle
import xgboost as xgb
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.model_selection import GridSearchCV

### model training

In [2]:
# Load the training set: the file is read line by line and every line is
# parsed as one JSON document, collected in file order.
train_data = []

with open('../data/domain1_train.json', 'r') as file:
    train_data.extend(json.loads(line) for line in file)

In [3]:
# Split each training record into its input (`text`) and its target (`label`),
# keeping both lists aligned by position.
X, y = [], []
for record in train_data:
    X.append(record['text'])
    y.append(record['label'])

In [4]:
# Bag-of-words features: each record's `text` sequence is joined into one
# space-separated string and counted over a vocabulary capped at 5000 terms.
#
# The documents are rebuilt from `train_data` rather than from `X`, so
# re-running this cell never feeds the already-vectorized matrix back into
# the join (the previous `X = f(X)` form only worked on the first execution).
#
# NOTE(review): the vocabulary is fitted on every row before the train/test
# split, so hold-out rows influence which terms are kept; confirm this is
# acceptable for the reported accuracy.
# NOTE(review): CountVectorizer's default token_pattern keeps only tokens of
# two or more characters; if `text` holds integer ids, single-digit ids are
# dropped. TODO confirm against the data.
vectorizer = CountVectorizer(max_features=5000)
X = vectorizer.fit_transform(
    [" ".join(map(str, item['text'])) for item in train_data]
)

In [5]:
# Hyperparameter search space for the grid search: every combination of
# these values (3 x 3 x 3 = 27 candidates) is evaluated.
param_grid = dict(
    n_estimators=[100, 200, 300],
    max_depth=[3, 4, 5],
    learning_rate=[0.1, 0.01, 0.001],
)

In [6]:
# Hold out 20% of the rows for evaluation; the fixed seed makes the split
# repeatable across runs.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
)

In [7]:
# Base XGBoost classifier with library-default hyperparameters. It is the
# `estimator` handed to GridSearchCV, which fits clones of it rather than
# this instance, so `xgboost_model` itself stays unfitted.
xgboost_model = xgb.XGBClassifier()

In [9]:
# Exhaustive search over `param_grid`, scoring each candidate by accuracy
# with 3-fold cross-validation on the training split only.
grid_search = GridSearchCV(
    estimator=xgboost_model,
    param_grid=param_grid,
    scoring='accuracy',
    cv=3,
)
grid_search.fit(X_train, y_train)

GridSearchCV(cv=3,
             estimator=XGBClassifier(base_score=None, booster=None,
                                     callbacks=None, colsample_bylevel=None,
                                     colsample_bynode=None,
                                     colsample_bytree=None,
                                     early_stopping_rounds=None,
                                     enable_categorical=False, eval_metric=None,
                                     feature_types=None, gamma=None,
                                     gpu_id=None, grow_policy=None,
                                     importance_type=None,
                                     interaction_constraints=None,
                                     learning_rate=None,...
                                     max_cat_threshold=None,
                                     max_cat_to_onehot=None,
                                     max_delta_step=None, max_depth=None,
                                     max_leaves=Non

In [11]:
# Report the winning hyperparameter combination and its mean
# cross-validated accuracy.
best_params, best_score = grid_search.best_params_, grid_search.best_score_

print(f'best para: {best_params}', f'best score: {best_score}', sep='\n')

best para: {'learning_rate': 0.1, 'max_depth': 5, 'n_estimators': 300}
best score: 0.8944230769230769


In [12]:
# GridSearchCV is constructed with its default refit=True, so after the search
# `best_estimator_` has already been retrained on all of X_train with the
# best hyperparameters. Fitting it again would only repeat that training,
# so the estimator is used as-is.
best_model = grid_search.best_estimator_
best_model

XGBClassifier(base_score=None, booster=None, callbacks=None,
              colsample_bylevel=None, colsample_bynode=None,
              colsample_bytree=None, early_stopping_rounds=None,
              enable_categorical=False, eval_metric=None, feature_types=None,
              gamma=None, gpu_id=None, grow_policy=None, importance_type=None,
              interaction_constraints=None, learning_rate=0.1, max_bin=None,
              max_cat_threshold=None, max_cat_to_onehot=None,
              max_delta_step=None, max_depth=5, max_leaves=None,
              min_child_weight=None, missing=nan, monotone_constraints=None,
              n_estimators=300, n_jobs=None, num_parallel_tree=None,
              predictor=None, random_state=None, ...)

In [13]:
# Predict labels for the 20% hold-out split produced by train_test_split.
y_pred = best_model.predict(X_test)

In [16]:
# Fraction of hold-out rows whose predicted label matches the true label.
accuracy = accuracy_score(y_true=y_test, y_pred=y_pred)
print('Acc:{}'.format(accuracy))

Acc:0.9007692307692308


In [17]:
# Persist the tuned, fitted classifier.
# `best_model` is the estimator refitted by the grid search; the original
# `xgboost_model` is only the template that GridSearchCV clones and is never
# fitted, so pickling it would store an untrained model.
# NOTE(review): predictions also require the fitted `vectorizer`; it is not
# saved here, so persist it separately if the model is loaded elsewhere.
with open('../models/xgboost_model(grid_search).pkl', 'wb') as model_file:
    pickle.dump(best_model, model_file)

In [19]:
# Score the unlabeled test set and write the predictions as CSV rows "id,class".

# Read the test set first (one JSON document per line), so a read or parse
# failure cannot leave a truncated output file behind.
with open('../data/test_set.json', 'r') as file:
    submission_entries = [json.loads(line) for line in file]

# Convert each text into bag-of-words features with the already-fitted vectorizer.
# A dedicated name is used so the hold-out matrix `X_test` is not overwritten,
# which keeps the evaluation cells valid if they are re-run afterwards.
submission_texts = [" ".join(map(str, entry["text"])) for entry in submission_entries]

# Predict with the tuned XGBoost model in a single batch rather than once per row.
# An empty test file skips the model call and yields a header-only CSV.
if submission_entries:
    submission_features = vectorizer.transform(submission_texts)
    submission_predictions = best_model.predict(submission_features)
else:
    submission_predictions = []

with open('../data/Xgboost_output(grid_search).csv', 'w') as output_file:
    output_file.write('id,class\n')  # CSV header row

    # One row per test entry, in the same order as the input file.
    for entry, prediction in zip(submission_entries, submission_predictions):
        output_file.write(f"{entry['id']},{prediction}\n")