In [9]:
import json
import importlib
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import mlflow
import mlflow.sklearn
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error
from math import sqrt
import xgboost as xgb

mlflow.sklearn.autolog()

# 加载数据
data = pd.read_csv('./DATA/exported_data7.csv')
X = data.drop('取引価格（総額）', axis=1)
y = data['取引価格（総額）']

# 标准化数据
X_scaler = StandardScaler()
X_scaled = X_scaler.fit_transform(X)

y_scaler = StandardScaler()
y_scaled = y_scaler.fit_transform(np.array(y).reshape(-1, 1))

X_train, X_test, y_train, y_test = train_test_split(X_scaled, y_scaled, test_size=0.2, random_state=42)

params = {
    'objective': 'reg:squarederror',  # 指定模型目标为平方误差回归
    'n_estimators': 1000,  # 树的数量
    'learning_rate': 0.05,  # 学习率
    'max_depth': 6,  # 树的最大深度
    'subsample': 0.8,  # 样本采样比例
    'colsample_bytree': 0.8,  # 特征采样比例
    'reg_alpha': 0.1,  # L1 正则化参数
    'reg_lambda': 0.1,  # L2 正则化参数
    'verbosity': 1  # 显示训练日志的详细程度
}

with mlflow.start_run():
    xgb_model = xgb.XGBRegressor(**params)
    xgb_model.fit(X_train, y_train, 
                  eval_set=[(X_test, y_test)], 
                  verbose=True)
    
    y_pred = xgb_model.predict(X_test)

    mse = mean_squared_error(y_test, y_pred)
    rmse = sqrt(mse)
    r2 = r2_score(y_test, y_pred)
    mae = mean_absolute_error(y_test, y_pred)
    
    mlflow.log_metric("mse", mse)
    mlflow.log_metric("rmse", rmse)
    mlflow.log_metric("r2", r2)
    mlflow.log_metric("mae", mae)
    

    mlflow.sklearn.log_model(xgb_model, "model")
    mlflow.log_params(params)

print(f"Test MSE: {mse}")
print(f"Test RMSE: {rmse}")
print(f"Test R²: {r2}")
print(f"Test MAE: {mae}")


[0]	validation_0-rmse:0.99797
[1]	validation_0-rmse:0.97850
[2]	validation_0-rmse:0.96053
[3]	validation_0-rmse:0.95067
[4]	validation_0-rmse:0.93435
[5]	validation_0-rmse:0.92316
[6]	validation_0-rmse:0.90905
[7]	validation_0-rmse:0.89948
[8]	validation_0-rmse:0.88682
[9]	validation_0-rmse:0.87849
[10]	validation_0-rmse:0.87091
[11]	validation_0-rmse:0.85993
[12]	validation_0-rmse:0.85036
[13]	validation_0-rmse:0.84186
[14]	validation_0-rmse:0.83596
[15]	validation_0-rmse:0.83376
[16]	validation_0-rmse:0.82555
[17]	validation_0-rmse:0.81760
[18]	validation_0-rmse:0.81022
[19]	validation_0-rmse:0.80390
[20]	validation_0-rmse:0.79907
[21]	validation_0-rmse:0.79325
[22]	validation_0-rmse:0.78805
[23]	validation_0-rmse:0.78337
[24]	validation_0-rmse:0.77895
[25]	validation_0-rmse:0.77479
[26]	validation_0-rmse:0.77049
[27]	validation_0-rmse:0.76684
[28]	validation_0-rmse:0.76340
[29]	validation_0-rmse:0.75973
[30]	validation_0-rmse:0.75677
[31]	validation_0-rmse:0.75403
[32]	validation_0-