In [1]:
import pandas as pd
import joblib
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.preprocessing import LabelEncoder
from tqdm import tqdm

# Load the train/validation splits: tab-separated files without a header row,
# read into two columns named 'Category' and 'Title'.
train_df = pd.read_csv('train.txt', sep='\t', header=None, names=['Category', 'Title'])
valid_df = pd.read_csv('valid.txt', sep='\t', header=None, names=['Category', 'Title'])

# Build the bag-of-words vocabulary from the training titles only, then reuse
# the fitted vectorizer for the validation titles.
# NOTE(review): the vocabulary is fitted on the full training set before
# cross-validation, so every CV fold sees tokens from its own held-out part.
# Putting the vectorizer and classifier in a Pipeline would avoid that.
vectorizer = CountVectorizer()
X_train = vectorizer.fit_transform(train_df['Title'])
X_valid = vectorizer.transform(valid_df['Title'])

# Encode the string category labels as integers; the encoder is fitted on the
# training labels and reused for the validation labels.
label_encoder = LabelEncoder()
y_train = label_encoder.fit_transform(train_df['Category'])
y_valid = label_encoder.transform(valid_df['Category'])

# Hyperparameter candidates for each model family.
param_grid_logistic = {
    'C': [0.01, 0.1, 1, 10, 100]
}
param_grid_svm = {
    'C': [0.01, 0.1, 1, 10, 100],
    'kernel': ['linear', 'rbf']
}
param_grid_rf = {
    'n_estimators': [10, 50, 100],
    'max_depth': [None, 10, 20, 30]
}

# (estimator, parameter grid) pairs to search, in order.
# NOTE(review): probability=True makes each SVC fit run an extra internal
# cross-validation for probability calibration, which multiplies the cost of
# the SVC grid search. Drop it if predict_proba is not needed on the saved
# model -- confirm with downstream usage first.
models = [
    (LogisticRegression(max_iter=1000, random_state=42), param_grid_logistic),
    (SVC(probability=True, random_state=42), param_grid_svm),
    (RandomForestClassifier(random_state=42), param_grid_rf)
]

# Track the best model across all families by mean cross-validated score.
# The sentinel is -inf (not 0) so that any finite score, including 0.0, still
# selects a model instead of leaving best_model as None.
best_model = None
best_score = float('-inf')
best_params = None

# Run a 5-fold grid search per model family, with a progress bar over families.
for model, param_grid in tqdm(models, desc="Hyperparameter Search"):
    grid_search = GridSearchCV(estimator=model, param_grid=param_grid, cv=5, n_jobs=-1)
    grid_search.fit(X_train, y_train)
    if grid_search.best_score_ > best_score:
        best_score = grid_search.best_score_
        best_model = grid_search.best_estimator_
        best_params = grid_search.best_params_

# Fail with a clear message rather than an AttributeError on None below
# (reachable only if every search reported a non-comparable score such as NaN).
if best_model is None:
    raise RuntimeError('Hyperparameter search did not produce a usable model.')

# Accuracy of the selected (refitted) model on the held-out validation data.
valid_accuracy = best_model.score(X_valid, y_valid)
print(f'Best Model: {best_model}')
print(f'Validation Accuracy: {valid_accuracy:.4f}')
print(f'Best Parameters: {best_params}')

# Persist the model together with the fitted vectorizer and label encoder,
# which are both required to run predictions on raw titles later.
joblib.dump(best_model, 'best_model.pkl')
joblib.dump(vectorizer, 'vectorizer.pkl')
joblib.dump(label_encoder, 'label_encoder.pkl')


Hyperparameter Search:  33%|██████▋             | 1/3 [17:07<34:15, 1027.59s/it]


KeyboardInterrupt: 