In [5]:
import baostock as bs
import akshare as ak
import pandas as pd
import matplotlib.pyplot as plt
from datetime import datetime, timedelta

# 初始化baostock
bs.login()

# 定义初始资金
initial_funds = 1000000
funds = initial_funds

# 定义时间范围
start_date = '2010-01-01'
end_date = '2024-08-31'

# 获取每个月的第一个交易日和最后一个交易日
def get_monthly_trading_days(start_date, end_date):
    df = ak.stock_zh_a_daily(symbol="sz000001", start_date=start_date, end_date=end_date, adjust="qfq")
    df['date'] = pd.to_datetime(df['date'])
    df.set_index('date', inplace=True)
    
    monthly_trading_days = {}
    for year_month, group in df.groupby([df.index.year, df.index.month]):
        first_trading_day = group.index.min().strftime('%Y-%m-%d')
        last_trading_day = group.index.max().strftime('%Y-%m-%d')
        monthly_trading_days[first_trading_day] = (first_trading_day, last_trading_day)
    
    return monthly_trading_days

# 获取每个月初的沪深300成分股列表
def get_monthly_hs300_stocks(monthly_trading_days):
    monthly_hs300_stocks = {}
    
    for first_trading_day in monthly_trading_days.keys():
        if first_trading_day not in monthly_hs300_stocks:
            rs = bs.query_hs300_stocks(first_trading_day)
            hs300_stocks = []
            while (rs.error_code == '0') & rs.next():
                hs300_stocks.append(rs.get_row_data()[1])
            monthly_hs300_stocks[first_trading_day] = hs300_stocks
    
    return monthly_hs300_stocks

# 获取所有股票的总市值数据
def get_all_stock_market_values(stocks):
    market_value_data = {}
    for stock in stocks:
        df = ak.stock_a_indicator_lg(stock)
        market_value_data[stock] = df
    return market_value_data

# 获取所有股票的价格数据
def get_all_stock_prices(stocks, start_date, end_date):
    price_data = {}
    for stock in stocks:
        df = ak.stock_zh_a_daily(stock, start_date, end_date)
        price_data[stock] = df
    return price_data

# 获取股票总市值
def get_stock_market_value(stock_code, date, market_value_data):
    df = market_value_data[stock_code]
    df = df[df['date'] == date]
    if not df.empty:
        return df['total_mv'].values[0]
    return None

# 获取股票开盘价和收盘价
def get_stock_prices(stock_code, start_date, end_date, price_data):
    df = price_data[stock_code]
    df = df[(df['date'] >= start_date) & (df['date'] <= end_date)]
    if not df.empty:
        return df['open'].values[0], df['close'].values[-1]
    return None, None

print(get_monthly_hs300_stocks(get_monthly_trading_days(start_date, end_date)))

login success!
{'2010-01-04': ('2010-01-04', '2010-01-29'), '2010-02-01': ('2010-02-01', '2010-02-26'), '2010-03-02': ('2010-03-02', '2010-03-31'), '2010-04-01': ('2010-04-01', '2010-04-30'), '2010-05-04': ('2010-05-04', '2010-05-31'), '2010-06-01': ('2010-06-01', '2010-06-29'), '2010-09-02': ('2010-09-02', '2010-09-29'), '2010-10-08': ('2010-10-08', '2010-10-29'), '2010-11-01': ('2010-11-01', '2010-11-30'), '2010-12-01': ('2010-12-01', '2010-12-31'), '2011-01-04': ('2011-01-04', '2011-01-31'), '2011-02-01': ('2011-02-01', '2011-02-28'), '2011-03-01': ('2011-03-01', '2011-03-31'), '2011-04-01': ('2011-04-01', '2011-04-29'), '2011-05-03': ('2011-05-03', '2011-05-31'), '2011-06-01': ('2011-06-01', '2011-06-30'), '2011-07-01': ('2011-07-01', '2011-07-29'), '2011-08-01': ('2011-08-01', '2011-08-31'), '2011-09-01': ('2011-09-01', '2011-09-30'), '2011-10-10': ('2011-10-10', '2011-10-31'), '2011-11-01': ('2011-11-01', '2011-11-30'), '2011-12-01': ('2011-12-01', '2011-12-30'), '2012-01-04': ('