<a href="https://colab.research.google.com/github/zyz314/100-Days-Of-ML-Code/blob/master/Project_8_StrategyLearner_1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# 导入必要的库
import pandas as pd
import datetime as dt
import matplotlib.pyplot as plt
from qlearning import QLearning  # 假设有一个QLearning类用于强化学习

# 定义技术指标计算函数
def compute_sma(prices, window=20):
    """Return the simple moving average (SMA) of ``prices``.

    Args:
        prices: pandas Series or DataFrame of prices.
        window: Number of consecutive observations averaged at each step.

    Returns:
        The rolling mean of ``prices`` over ``window`` observations, indexed
        like ``prices``. With pandas' default ``min_periods`` a position is
        NaN until a full window of non-NaN values is available.
    """
    return prices.rolling(window=window).mean()

def compute_bollinger_bands(prices, window=20, num_std=2):
    """Compute Bollinger Bands and the %B position of each price.

    Args:
        prices: pandas Series or DataFrame of prices.
        window: Number of observations used for the rolling mean and the
            rolling standard deviation.
        num_std: Distance of each band from the rolling mean, expressed in
            rolling standard deviations. Defaults to 2, the value that used
            to be hard-coded.

    Returns:
        A tuple ``(upper_band, lower_band, bb_percent)``. ``bb_percent`` is
        ``(prices - lower_band) / (upper_band - lower_band)``: 0 when the
        price sits on the lower band and 1 when it sits on the upper band.
        All three are NaN until a full window of non-NaN prices is available
        (pandas' default ``min_periods``).
    """
    # Build the rolling window once; the mean is the same value that
    # compute_sma(prices, window) produces.
    rolling = prices.rolling(window=window)
    sma = rolling.mean()
    band_offset = num_std * rolling.std()

    upper_band = sma + band_offset
    lower_band = sma - band_offset
    bb_percent = (prices - lower_band) / (upper_band - lower_band)
    return upper_band, lower_band, bb_percent

def compute_momentum(prices, window=10):
    """Return the momentum of ``prices`` over ``window`` observations.

    Args:
        prices: pandas Series or DataFrame of prices.
        window: Number of rows to look back.

    Returns:
        ``prices / prices.shift(window) - 1``, i.e. the fractional change
        relative to the value ``window`` rows earlier. The first ``window``
        rows are NaN because they have no earlier value to compare with.
    """
    lagged = prices.shift(window)
    return prices / lagged - 1

def get_data(symbols, dates):
    """Load adjusted closing prices for ``symbols`` over ``dates``.

    Each symbol is read from ``data/<symbol>.csv`` (columns ``Date`` and
    ``Adj Close``) and left-joined onto a frame indexed by ``dates``, so the
    result has one column per symbol, named after the symbol.

    Args:
        symbols: Iterable of ticker symbols to load.
        dates: Index (e.g. a ``pd.date_range``) the result is aligned to.

    Returns:
        DataFrame indexed by ``dates``. Dates missing from a symbol's file
        are NaN in that column. Rows where ``SPY`` is NaN are dropped, but
        only when ``'SPY'`` is one of the requested symbols.
    """
    frame = pd.DataFrame(index=dates)
    for symbol in symbols:
        csv_path = f"data/{symbol}.csv"
        column = pd.read_csv(
            csv_path,
            index_col='Date',
            parse_dates=True,
            usecols=['Date', 'Adj Close'],
            na_values=['nan'],
        ).rename(columns={'Adj Close': symbol})
        frame = frame.join(column)

        if symbol != 'SPY':
            continue
        # SPY serves as the trading-day reference: keep only days it has data for.
        frame = frame.dropna(subset=["SPY"])
    return frame

def plot_results(portvals, df_trades, symbol='JPM'):
    """Plot the strategy learner's portfolio value curve and show it.

    Args:
        portvals: Portfolio values to plot (the y-axis label describes them
            as normalized).
        df_trades: Accepted for interface compatibility; not used here.
        symbol: Accepted for interface compatibility; not used here.
    """
    plt.figure(figsize=(10, 6))

    # Only the learner's curve is drawn, in red.
    plt.plot(portvals, label='Strategy Learner', color='red')

    # Axis labels and title.
    plt.xlabel('Date')
    plt.ylabel('Normalized Portfolio Value')
    plt.title('Strategy Learner vs Benchmark')

    plt.legend(loc='best')
    plt.show()

# 定义策略学习器类
class StrategyLearner:
    """Strategy learner that feeds technical indicators to a Q-learner.

    ``add_evidence`` trains the wrapped ``QLearning`` instance on one date
    range and ``testPolicy`` asks it for trades on another. Both build the
    same indicator table (SMA, Bollinger %B, momentum and the raw price).
    """

    def __init__(self, verbose=False, impact=0.0, commission=0.0):
        """Store the configuration and create the underlying learner.

        Args:
            verbose: Whether to print debugging information.
            impact: Market-impact setting.
            commission: Commission setting.
        """
        self.verbose = verbose
        self.impact = impact
        self.commission = commission
        # NOTE(review): verbose/impact/commission are stored but never passed
        # to the learner in this class - confirm whether QLearning needs them.
        self.learner = QLearning()

    def author(self):
        """Return the author's GT username."""
        return 'your_gt_username'

    def study_group(self):
        """Return the study group's GT usernames, comma-separated."""
        return 'your_study_group'

    def _build_indicators(self, symbol, sd, ed):
        """Build the indicator table for ``symbol`` between ``sd`` and ``ed``.

        Returns a DataFrame indexed by the dates that have a price, with the
        columns ``SMA``, ``BB``, ``Momentum`` and ``Prices`` (in that order).
        """
        dates = pd.date_range(sd, ed)
        prices_all = get_data([symbol], dates)

        # pd.date_range yields every calendar day, and get_data only removes
        # non-trading days when the symbol is SPY. Left in place, the NaN
        # rows for weekends/holidays fall inside every 20-row rolling window
        # and turn the SMA and Bollinger columns into all-NaN, so keep only
        # the days that actually have a price.
        prices = prices_all[symbol].dropna()

        # Only %B is used as a feature; the band levels themselves are not.
        _, _, bb_percent = compute_bollinger_bands(prices, window=20)

        indicators = pd.DataFrame(index=prices.index)
        indicators['SMA'] = compute_sma(prices, window=20)
        indicators['BB'] = bb_percent
        indicators['Momentum'] = compute_momentum(prices, window=10)
        indicators['Prices'] = prices
        return indicators

    def add_evidence(self, symbol='IBM', sd=dt.datetime(2008, 1, 1), ed=dt.datetime(2009, 1, 1), sv=100000):
        """Train the learner on ``symbol`` over the period ``sd`` to ``ed``.

        Args:
            symbol: Ticker to train on.
            sd: Start date of the training period.
            ed: End date of the training period.
            sv: Starting portfolio value handed to the learner.
        """
        indicators = self._build_indicators(symbol, sd, ed)
        self.learner.train(indicators, sv)

    def testPolicy(self, symbol='IBM', sd=dt.datetime(2009, 1, 1), ed=dt.datetime(2010, 1, 1), sv=100000):
        """Ask the trained learner for trades on ``symbol`` from ``sd`` to ``ed``.

        Args:
            symbol: Ticker to trade.
            sd: Start date of the test period.
            ed: End date of the test period.
            sv: Starting portfolio value handed to the learner.

        Returns:
            Whatever ``self.learner.test`` returns for the indicator table
            (the caller treats it as a trades DataFrame).
        """
        indicators = self._build_indicators(symbol, sd, ed)
        trades = self.learner.test(indicators, sv)
        return trades

# 示例用法
if __name__ == "__main__":
    # Example run: train on 2008, trade 2009, then plot the normalized result.
    demo_symbol = 'JPM'
    start_value = 100000
    trade_costs = {'impact': 0.005, 'commission': 9.95}

    sl = StrategyLearner(verbose=True, **trade_costs)

    # Train on the in-sample year.
    sl.add_evidence(
        symbol=demo_symbol,
        sd=dt.datetime(2008, 1, 1),
        ed=dt.datetime(2009, 1, 1),
        sv=start_value,
    )

    # Generate trades for the following year.
    df_trades = sl.testPolicy(
        symbol=demo_symbol,
        sd=dt.datetime(2009, 1, 1),
        ed=dt.datetime(2010, 1, 1),
        sv=start_value,
    )

    # NOTE(review): compute_portvals is neither defined nor imported in the
    # visible part of this file - confirm it is provided elsewhere.
    portvals = compute_portvals(df_trades, start_val=start_value, **trade_costs)

    # Normalize so the curve starts at 1.0.
    portvals = portvals / portvals.iloc[0]

    plot_results(portvals, df_trades, symbol=demo_symbol)
