# Tutorial 3: Model Training and Evaluation

This tutorial shows how to train a machine learning model that predicts whether a cryptocurrency's price will move up or down.

## 1. Import Libraries

In [None]:
import os
import sys
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split, cross_val_score, learning_curve
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    confusion_matrix, classification_report, roc_curve, auc
)
import joblib

# Configure fonts that can render CJK glyphs, with a widely available fallback
plt.rcParams['font.sans-serif'] = ['SimHei', 'Arial Unicode MS', 'DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False

# Add the project root to the import path
sys.path.append('..')

from utils.binance_client import BinanceUtility
from utils.data_processor import DataProcessor

print("Libraries imported successfully!")

## 2. Prepare Data

In [None]:
# Try to load previously saved feature data
if os.path.exists('../data/X_features.npy') and os.path.exists('../data/y_labels.npy'):
    X = np.load('../data/X_features.npy')
    y = np.load('../data/y_labels.npy')
    print(f"Loaded features from file: X.shape = {X.shape}")
else:
    print("Fetching data from Binance and creating features...")
    client = BinanceUtility()
    df = client.fetch_historical_data('BTCUSDT', '1h', '6 months ago UTC')
    processor = DataProcessor()
    df_features = processor.add_technical_indicators(df)
    X, y = processor.prepare_features_labels(df_features)
    print(f"Created features: X.shape = {X.shape}")

# Count each class once and reuse the counts below
n_up = int(np.sum(y))
n_down = len(y) - n_up
print("\nLabel distribution:")
print(f"UP (1): {n_up} / {len(y)} ({n_up / len(y) * 100:.1f}%)")
print(f"DOWN (0): {n_down} / {len(y)} ({n_down / len(y) * 100:.1f}%)")
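
The label distribution also sets the bar any model has to clear: a classifier that always predicts the more frequent class already reaches that class's share as its accuracy. The sketch below illustrates this with scikit-learn's `DummyClassifier` on synthetic labels (60 UP, 40 DOWN). The names `X_demo`, `y_demo`, and `baseline` are illustrative and not part of the tutorial's pipeline.

```python
import numpy as np
from sklearn.dummy import DummyClassifier

# Synthetic labels for illustration only: 60 UP (1) and 40 DOWN (0)
y_demo = np.array([1] * 60 + [0] * 40)
X_demo = np.zeros((len(y_demo), 1))  # the dummy model ignores the features

# Always predict the most frequent class seen during fit
baseline = DummyClassifier(strategy="most_frequent")
baseline.fit(X_demo, y_demo)
baseline_accuracy = baseline.score(X_demo, y_demo)

print(f"Majority-class baseline accuracy: {baseline_accuracy:.2f}")
```

A trained model is only informative if its test accuracy is meaningfully above this baseline.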

## 3. Understanding Random Forest

**Random Forest** is an ensemble learning algorithm composed of multiple decision trees.

### How it works:

1. **Build decision trees**: Train many decision trees, each on a random bootstrap sample of the training data, considering a random subset of features at each split.

2. **Voting mechanism**: To classify a new sample, every tree makes its own prediction.

3. **Final prediction**: The class with the most support across the trees is the final result.

### Why Random Forest?
- ✅ Less prone to overfitting than a single decision tree
- ✅ Can handle nonlinear relationships
- ✅ Provides feature importance scores
- ✅ Robust to outliers
- ✅ Works reasonably well without extensive hyperparameter tuning
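
The voting steps above can be checked directly on a small forest. The sketch below uses synthetic data from `make_classification`; the names `X_demo`, `y_demo`, and `forest` are illustrative. Note that scikit-learn's `RandomForestClassifier` combines trees by averaging their predicted class probabilities (a "soft" vote) rather than counting hard votes, so that is what the example reproduces.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

# Synthetic data for illustration only
X_demo, y_demo = make_classification(n_samples=200, n_features=5, random_state=0)

forest = RandomForestClassifier(n_estimators=25, random_state=0)
forest.fit(X_demo, y_demo)

# Each fitted tree is available in forest.estimators_
per_tree_proba = np.stack(
    [tree.predict_proba(X_demo) for tree in forest.estimators_]
)  # shape: (n_trees, n_samples, n_classes)

# Average the per-tree class probabilities, then pick the most likely class
mean_proba = per_tree_proba.mean(axis=0)
manual_pred = forest.classes_[np.argmax(mean_proba, axis=1)]

print("Matches forest.predict_proba:",
      np.allclose(mean_proba, forest.predict_proba(X_demo)))
print("Matches forest.predict:",
      np.array_equal(manual_pred, forest.predict(X_demo)))
```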

## 4. Train-Test Split

In [None]:
# Split the data.
# Important: this is time-series data, so we must NOT shuffle it.
# Shuffling would let the model train on samples from after the test period.
X_train, X_test, y_train, y_test = train_test_split(
    X, y,
    test_size=0.2,  # Hold out the most recent 20% for testing
    shuffle=False   # Preserve chronological order
)

print(f"Training set size: {X_train.shape[0]}")
print(f"Test set size: {X_test.shape[0]}")
print("\nTraining label distribution:")
print(f"  UP (1): {np.sum(y_train)} / {len(y_train)} ({np.sum(y_train)/len(y_train)*100:.1f}%)")
print(f"  DOWN (0): {len(y_train)-np.sum(y_train)} / {len(y_train)}")
print("\nTest label distribution:")
print(f"  UP (1): {np.sum(y_test)} / {len(y_test)} ({np.sum(y_test)/len(y_test)*100:.1f}%)")
print(f"  DOWN (0): {len(y_test)-np.sum(y_test)} / {len(y_test)}")
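
To see what `shuffle=False` does, here is a minimal sketch that splits ten hypothetical hourly samples identified only by their time index. The variable names are illustrative. With shuffling disabled, `train_test_split` simply cuts the sequence, so every test sample is later than every training sample.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Ten hypothetical hourly samples, identified by their time index 0..9
time_index = np.arange(10)

train_idx, test_idx = train_test_split(time_index, test_size=0.2, shuffle=False)

print("Train:", train_idx)
print("Test:", test_idx)
print("Test strictly after train:", train_idx.max() < test_idx.min())
```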

## 5. Train Random Forest Model

In [None]:
# Create the model
# n_estimators: number of trees in the forest
# random_state: random seed for reproducibility
model = RandomForestClassifier(
    n_estimators=100,
    random_state=42,
    n_jobs=-1,  # Use all CPU cores
    max_depth=10  # Limit tree depth to reduce overfitting
)

print("Starting model training...")
model.fit(X_train, y_train)
print("Model training completed!")

## 6. Model Evaluation - Basic Metrics

In [None]:
# Make predictions on the test set
y_pred = model.predict(X_test)
y_pred_proba = model.predict_proba(X_test)

# Calculate evaluation metrics (UP = 1 is the positive class)
accuracy = accuracy_score(y_test, y_pred)
precision = precision_score(y_test, y_pred)
recall = recall_score(y_test, y_pred)
f1 = f1_score(y_test, y_pred)

print("=" * 50)
print("Model Evaluation Metrics")
print("=" * 50)
print(f"Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1:.4f}")
print()

# Metrics explanation
print("Metrics Explanation:")
print("- Accuracy: proportion of all predictions that are correct")
print("- Precision: proportion of predicted UP movements that were actually UP")
print("- Recall: proportion of actual UP movements that were predicted as UP")
print("- F1 Score: harmonic mean of precision and recall")
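
The metric definitions above can be worked through by hand on a tiny example. The sketch below uses ten hypothetical labels (not real market data), derives the four confusion-matrix counts, and computes each metric from them. The `*_manual` names are illustrative.

```python
import numpy as np
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score

# Hypothetical labels for illustration: 1 = UP, 0 = DOWN
y_true_demo = np.array([1, 1, 1, 1, 0, 0, 0, 0, 0, 0])
y_pred_demo = np.array([1, 1, 1, 0, 1, 1, 0, 0, 0, 0])

# For binary labels, ravel() yields the counts in the order TN, FP, FN, TP
tn, fp, fn, tp = confusion_matrix(y_true_demo, y_pred_demo).ravel()

accuracy_manual = (tp + tn) / (tp + tn + fp + fn)
precision_manual = tp / (tp + fp)  # of the predicted UPs, how many were UP
recall_manual = tp / (tp + fn)     # of the actual UPs, how many were found
f1_manual = 2 * precision_manual * recall_manual / (precision_manual + recall_manual)

print(f"TN={tn}, FP={fp}, FN={fn}, TP={tp}")
print(f"Accuracy:  {accuracy_manual:.4f}")
print(f"Precision: {precision_manual:.4f}")
print(f"Recall:    {recall_manual:.4f}")
print(f"F1 Score:  {f1_manual:.4f}")
```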

## 7. Confusion Matrix

In [None]:
# Calculate the confusion matrix
cm = confusion_matrix(y_test, y_pred)

# Plot the confusion matrix
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=['DOWN', 'UP'],
            yticklabels=['DOWN', 'UP'],
            cbar_kws={'label': 'Count'})
plt.title('Confusion Matrix', fontsize=14, fontweight='bold')
plt.ylabel('True Label', fontsize=12)
plt.xlabel('Predicted Label', fontsize=12)
plt.tight_layout()
plt.show()

print("\nConfusion Matrix Explanation:")
print(f"True negatives (TN), correctly predicted DOWN: {cm[0,0]}")
print(f"False positives (FP), incorrectly predicted UP: {cm[0,1]}")
print(f"False negatives (FN), incorrectly predicted DOWN: {cm[1,0]}")
print(f"True positives (TP), correctly predicted UP: {cm[1,1]}")

## 8. ROC Curve and AUC Score

In [None]:
# Calculate the ROC curve from the predicted probability of the UP class
fpr, tpr, thresholds = roc_curve(y_test, y_pred_proba[:, 1])
roc_auc = auc(fpr, tpr)

# Plot the ROC curve
plt.figure(figsize=(8, 6))
plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC Curve (AUC = {roc_auc:.4f})')
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Random Guess')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate', fontsize=12)
plt.ylabel('True Positive Rate', fontsize=12)
plt.title('ROC Curve', fontsize=14, fontweight='bold')
plt.legend(loc="lower right")
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

print(f"AUC: {roc_auc:.4f}")
print("AUC Explanation:")
print("- AUC = 0.5: model performs the same as random guessing")
print("- AUC = 1.0: perfect classifier")
print("- AUC > 0.7: generally considered good performance")
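
One way to build intuition for the AUC value: it equals the probability that a randomly chosen UP sample receives a higher score than a randomly chosen DOWN sample. The sketch below checks this on four hypothetical scores by comparing every positive/negative pair and then comparing the result with `roc_auc_score`. The variable names are illustrative.

```python
import numpy as np
from sklearn.metrics import roc_auc_score

# Hypothetical labels and predicted UP probabilities
y_true_demo = np.array([0, 0, 1, 1])
scores_demo = np.array([0.1, 0.4, 0.35, 0.8])

pos = scores_demo[y_true_demo == 1]
neg = scores_demo[y_true_demo == 0]

# Compare every positive score with every negative score; ties count as half
wins = (pos[:, None] > neg[None, :]).sum() + 0.5 * (pos[:, None] == neg[None, :]).sum()
pairwise_auc = wins / (len(pos) * len(neg))

print(f"Pairwise AUC:  {pairwise_auc:.2f}")
print(f"roc_auc_score: {roc_auc_score(y_true_demo, scores_demo):.2f}")
```

Here three of the four positive/negative pairs are ranked correctly, so both values are 0.75.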

## 9. Cross-Validation

In [None]:
from sklearn.model_selection import TimeSeriesSplit

# Perform 5-fold cross-validation.
# TimeSeriesSplit keeps every validation fold later in time than its training
# data. The default K-fold splitter (cv=5) would train on samples from after
# the validation fold, which contradicts the no-shuffle rule from section 4.
tscv = TimeSeriesSplit(n_splits=5)
cv_scores = cross_val_score(model, X, y, cv=tscv)

print("=" * 50)
print("5-Fold Time-Series Cross-Validation Results")
print("=" * 50)
for i, score in enumerate(cv_scores, 1):
    print(f"Fold {i}: {score:.4f} ({score*100:.2f}%)")
print(f"\nMean Accuracy: {cv_scores.mean():.4f} ({cv_scores.mean()*100:.2f}%)")
print(f"Standard Deviation: {cv_scores.std():.4f}")

print("\nCross-Validation Explanation:")
print("- Cross-validation gives a better estimate of the model's generalization ability")
print("- A smaller standard deviation means more stable performance across folds")
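
Because the samples are ordered in time, it helps to see how a time-aware splitter arranges its folds. The sketch below applies scikit-learn's `TimeSeriesSplit` to eight hypothetical samples and prints the index sets of each fold. The names `samples` and `tscv_demo` are illustrative.

```python
import numpy as np
from sklearn.model_selection import TimeSeriesSplit

# Eight hypothetical samples in chronological order
samples = np.arange(8).reshape(-1, 1)

tscv_demo = TimeSeriesSplit(n_splits=3)
folds = list(tscv_demo.split(samples))

for i, (train_idx, val_idx) in enumerate(folds, 1):
    print(f"Fold {i}: train={train_idx.tolist()} validate={val_idx.tolist()}")
```

Each fold trains on an expanding window of past samples and validates on the block that immediately follows, so no fold ever sees data from its own future.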

## 10. Feature Importance Analysis

In [None]:
# Get feature importances.
# These names are assumed to match the column order produced by DataProcessor.
feature_names = ['open', 'high', 'low', 'close', 'volume', 'sma_7', 'sma_25', 'rsi_14', 'roc', 'volatility']
importances = model.feature_importances_
assert len(feature_names) == len(importances), "feature_names must match the model's feature count"

# Create a DataFrame sorted from most to least important
importance_df = pd.DataFrame({
    'Feature': feature_names,
    'Importance': importances
}).sort_values('Importance', ascending=False)

# Plot feature importance
plt.figure(figsize=(10, 6))
colors = plt.cm.viridis(np.linspace(0, 0.8, len(importance_df)))
plt.barh(importance_df['Feature'], importance_df['Importance'], color=colors)
plt.xlabel('Importance', fontsize=12)
plt.ylabel('Feature', fontsize=12)
plt.title('Feature Importance Analysis', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

print("\nFeature Importance Ranking:")
print(importance_df.to_string(index=False))
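
The `feature_importances_` attribute is impurity-based and computed from the training data. As a complementary check, scikit-learn also offers `permutation_importance`, which measures how much the score on held-out data drops when a feature's values are shuffled. The sketch below is an illustration on synthetic data rather than part of the tutorial's pipeline; `X_demo`, `y_demo`, `forest`, and the `feature_{i}` labels are assumptions made for the example.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance

# Synthetic data for illustration only
X_demo, y_demo = make_classification(
    n_samples=300, n_features=6, n_informative=3, random_state=0
)

# Hold out the last 60 samples, mirroring an unshuffled split
split = 240
forest = RandomForestClassifier(n_estimators=50, random_state=0)
forest.fit(X_demo[:split], y_demo[:split])

# Shuffle each feature 5 times on the held-out set and record the score drop
result = permutation_importance(
    forest, X_demo[split:], y_demo[split:], n_repeats=5, random_state=0
)

for i in np.argsort(result.importances_mean)[::-1]:
    print(f"feature_{i}: {result.importances_mean[i]:.4f} "
          f"+/- {result.importances_std[i]:.4f}")
```

Features whose permutation importance is close to zero on held-out data contribute little to predictions there, even if their impurity-based importance looks large.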

## 11. Learning Curve

In [None]:
# Calculate the learning curve.
# Note: cv=5 uses standard K-fold splits, which ignore time order,
# so treat this curve as a rough diagnostic rather than a forecast-quality estimate.
train_sizes, train_scores, test_scores = learning_curve(
    model, X, y, cv=5, n_jobs=-1,
    train_sizes=np.linspace(0.1, 1.0, 10),
    scoring='accuracy'
)

train_mean = np.mean(train_scores, axis=1)
train_std = np.std(train_scores, axis=1)
test_mean = np.mean(test_scores, axis=1)
test_std = np.std(test_scores, axis=1)

# Plot the learning curve
plt.figure(figsize=(10, 6))
plt.plot(train_sizes, train_mean, 'o-', color='r', label='Training Accuracy')
plt.fill_between(train_sizes, train_mean - train_std, train_mean + train_std, alpha=0.1, color='r')
plt.plot(train_sizes, test_mean, 'o-', color='g', label='Validation Accuracy')
plt.fill_between(train_sizes, test_mean - test_std, test_mean + test_std, alpha=0.1, color='g')

plt.xlabel('Training Size', fontsize=12)
plt.ylabel('Accuracy', fontsize=12)
plt.title('Learning Curve', fontsize=14, fontweight='bold')
plt.legend(loc='best')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

print("\nLearning Curve Explanation:")
print("- Large gap between the training and validation lines: possible overfitting")
print("- Both lines low: possible underfitting")
print("- Lines close together and both high: good model performance")

## 12. Visualize Predictions

In [None]:
# Get recent data for visualization
client = BinanceUtility()
df = client.fetch_historical_data('BTCUSDT', '1h', '2 weeks ago UTC')
processor = DataProcessor()
df_features = processor.add_technical_indicators(df)
X_recent, _ = processor.prepare_features_labels(df_features)

# Predict
predictions = model.predict(X_recent)
probabilities = model.predict_proba(X_recent)[:, 1]

# Prepare visualization data
viz_df = df_features.iloc[-len(predictions):].copy()
viz_df['prediction'] = predictions
viz_df['probability'] = probabilities
viz_df['timestamp'] = pd.to_datetime(viz_df['timestamp'])
viz_df = viz_df.set_index('timestamp')

# Plot
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 10), sharex=True)

# Price and predictions
ax1.plot(viz_df.index, viz_df['close'], label='Close', linewidth=1.5, alpha=0.8)
scatter_up = ax1.scatter(
    viz_df[viz_df['prediction'] == 1].index,
    viz_df[viz_df['prediction'] == 1]['close'],
    color='green', marker='^', s=100, label='Predict UP', zorder=5
)
scatter_down = ax1.scatter(
    viz_df[viz_df['prediction'] == 0].index,
    viz_df[viz_df['prediction'] == 0]['close'],
    color='red', marker='v', s=100, label='Predict DOWN', zorder=5
)
ax1.set_ylabel('Price', fontsize=12)
ax1.legend()
ax1.grid(True, alpha=0.3)
ax1.set_title('Price Trend and Predictions', fontsize=14)

# Prediction probabilities
colors = ['red' if p < 0.5 else 'green' for p in viz_df['probability']]
ax2.bar(viz_df.index, viz_df['probability'], color=colors, alpha=0.7)
ax2.axhline(y=0.5, color='black', linestyle='--', linewidth=1.5, label='Decision Boundary')
ax2.set_ylabel('UP Probability', fontsize=12)
ax2.set_xlabel('Time', fontsize=12)
ax2.set_ylim([0, 1])
ax2.legend()
ax2.grid(True, alpha=0.3)
ax2.set_title('Prediction Probability Distribution', fontsize=14)

plt.tight_layout()
plt.show()

## 13. Save Model

In [None]:
# Create the models directory if it does not exist yet
os.makedirs('../models', exist_ok=True)

# Save the model
model_path = '../models/BTCUSDT_price_model.pkl'
joblib.dump(model, model_path)

print(f"Model saved to: {model_path}")

# Save model info, reading hyperparameters from the fitted model
model_info = {
    'model_type': 'RandomForestClassifier',
    'n_estimators': model.n_estimators,
    'max_depth': model.max_depth,
    'accuracy': accuracy,
    'precision': precision,
    'recall': recall,
    'f1': f1,
    'cv_mean': cv_scores.mean(),
    'cv_std': cv_scores.std()
}

info_path = '../models/BTCUSDT_model_info.txt'
with open(info_path, 'w', encoding='utf-8') as f:
    f.write("Model Information:\n")
    f.write("=" * 40 + "\n")
    for key, value in model_info.items():
        f.write(f"{key}: {value}\n")

print(f"Model info saved to: {info_path}")

## 14. Model Usage Example

In [None]:
# Load the model
loaded_model = joblib.load(model_path)

# Get the latest data. One day of hourly candles is only 24 rows, which is
# too few for rolling indicators such as sma_25, so fetch a full week.
df_latest = client.fetch_historical_data('BTCUSDT', '1h', '1 week ago UTC')
df_latest_features = processor.add_technical_indicators(df_latest)
X_latest, _ = processor.prepare_features_labels(df_latest_features)

# Predict the latest price movement.
# Slicing with [-1:] keeps a 2-D shape and works for NumPy arrays and DataFrames.
latest_features = X_latest[-1:]
prediction = loaded_model.predict(latest_features)[0]
probability = loaded_model.predict_proba(latest_features)[0]

print("=" * 50)
print("Real-time Prediction Example")
print("=" * 50)
print(f"Current Price: {df_latest['close'].iloc[-1]:.2f} USDT")
print(f"Prediction: {'UP' if prediction == 1 else 'DOWN'}")
print(f"UP Probability: {probability[1]*100:.2f}%")
print(f"DOWN Probability: {probability[0]*100:.2f}%")
print(f"Prediction Time: {df_latest['timestamp'].iloc[-1]}")

## Summary

In this tutorial, we completed:

1. **Data Preparation**: Loaded feature data and split it into training and test sets
2. **Model Understanding**: Learned the principles and advantages of Random Forest
3. **Model Training**: Trained a Random Forest classifier
4. **Model Evaluation**: Evaluated model performance with multiple metrics
   - Accuracy, Precision, Recall, F1 Score
   - Confusion Matrix
   - ROC Curve and AUC Score
5. **Cross-Validation**: Evaluated the model's generalization ability
6. **Feature Analysis**: Analyzed the importance of each feature
7. **Learning Curve**: Checked for overfitting or underfitting
8. **Result Visualization**: Visualized prediction results
9. **Model Saving**: Saved the model for future use
10. **Real Application**: Demonstrated how to use the model for real-time prediction

Next: Learn backtesting and strategy evaluation!