In [6]:
import baostock as bs
import akshare as ak
import pandas as pd
import matplotlib.pyplot as plt
from datetime import datetime, timedelta

# Log in to baostock; the bs.query_* calls below need an active session.
lg = bs.login()
if lg.error_code != '0':
    # Fail fast instead of letting every later baostock query error out.
    raise RuntimeError(f"baostock login failed: {lg.error_code} {lg.error_msg}")

# Initial capital; `funds` presumably tracks the current cash balance (TODO confirm in later cells).
initial_funds = 1000000
funds = initial_funds

# Date range for the analysis, as 'YYYY-MM-DD' strings.
start_date = '2010-01-01'
end_date = '2024-08-31'

# Get the first and last trading day of every month.
def get_monthly_trading_days(start_date, end_date):
    """Return the first and last A-share trading day of each month in a range.

    The calendar comes from baostock's exchange trading calendar
    (``bs.query_trade_dates``) rather than from a single stock's price
    history: one stock has no bars on days it is suspended, which would shift
    or drop month boundaries. Requires an active baostock session
    (``bs.login()``).

    Args:
        start_date: range start, 'YYYY-MM-DD'.
        end_date: range end, 'YYYY-MM-DD'.

    Returns:
        dict, in chronological order, mapping each month's first trading day
        ('YYYY-MM-DD') to the tuple (first_trading_day, last_trading_day).
        Months at the edges are clipped to the requested range.

    Raises:
        RuntimeError: if the baostock query reports an error.
    """
    rs = bs.query_trade_dates(start_date=start_date, end_date=end_date)
    trading_dates = []
    # Check error_code on every iteration: next() may fetch further pages.
    while rs.error_code == '0' and rs.next():
        row = dict(zip(rs.fields, rs.get_row_data()))
        if row['is_trading_day'] == '1':
            trading_dates.append(row['calendar_date'])
    if rs.error_code != '0':
        raise RuntimeError(
            f"query_trade_dates({start_date}, {end_date}) failed: "
            f"{rs.error_code} {rs.error_msg}"
        )

    monthly_trading_days = {}
    if not trading_dates:
        return monthly_trading_days

    dates = pd.to_datetime(pd.Series(trading_dates)).sort_values()
    # groupby sorts the monthly periods, so the dict is built chronologically.
    for _, group in dates.groupby(dates.dt.to_period('M')):
        first_trading_day = group.min().strftime('%Y-%m-%d')
        last_trading_day = group.max().strftime('%Y-%m-%d')
        monthly_trading_days[first_trading_day] = (first_trading_day, last_trading_day)

    return monthly_trading_days

# Get the CSI 300 (HS300) constituent list at the start of each month.
def get_monthly_hs300_stocks(monthly_trading_days):
    """Fetch the HS300 constituents as of each month's first trading day.

    Requires an active baostock session (``bs.login()``).

    Args:
        monthly_trading_days: mapping whose keys are each month's first
            trading day as 'YYYY-MM-DD' (e.g. the output of
            get_monthly_trading_days). Only the keys are used.

    Returns:
        dict mapping each of those dates to the list of constituent codes
        reported by baostock (e.g. 'sh.600000').

    Raises:
        RuntimeError: if a baostock query reports an error. Previously that
            case silently produced an empty constituent list for the month.
    """
    monthly_hs300_stocks = {}

    for first_trading_day in monthly_trading_days:
        rs = bs.query_hs300_stocks(first_trading_day)
        hs300_stocks = []
        # `and` short-circuits, so next() is not called once an error is seen;
        # error_code is rechecked because next() may fetch further pages.
        while rs.error_code == '0' and rs.next():
            # Column 1 holds the stock code (e.g. 'sh.600000').
            hs300_stocks.append(rs.get_row_data()[1])
        if rs.error_code != '0':
            raise RuntimeError(
                f"query_hs300_stocks({first_trading_day}) failed: "
                f"{rs.error_code} {rs.error_msg}"
            )
        monthly_hs300_stocks[first_trading_day] = hs300_stocks

    return monthly_hs300_stocks

# Get the total-market-value data for every stock.
def get_all_stock_market_values(stocks):
    """Download akshare's Legulegu indicator table for each stock.

    Args:
        stocks: iterable of stock codes. Baostock-style ('sh.600000'),
            Sina-style ('sh600000') and bare six-digit ('600000') codes are
            all accepted; only the digits are sent to akshare.

    Returns:
        dict mapping each code exactly as given to the DataFrame returned by
        ``ak.stock_a_indicator_lg``. get_stock_market_value reads its
        'total_mv' column.
    """
    market_value_data = {}
    for stock in stocks:
        # Per akshare's docs, stock_a_indicator_lg takes the bare numeric
        # code, while the HS300 lists from baostock look like 'sh.600000'.
        symbol = ''.join(ch for ch in stock if ch.isdigit())
        market_value_data[stock] = ak.stock_a_indicator_lg(symbol=symbol)
    return market_value_data

# Get the daily price data for every stock.
def get_all_stock_prices(stocks, start_date, end_date, adjust="qfq"):
    """Download daily bars from Sina (via akshare) for each stock.

    Args:
        stocks: iterable of stock codes, either Sina-style ('sh600000') or
            baostock-style ('sh.600000'); the dot is removed before the call.
        start_date: range start, passed through to ``ak.stock_zh_a_daily``.
        end_date: range end, passed through to ``ak.stock_zh_a_daily``.
        adjust: price adjustment passed to ``ak.stock_zh_a_daily``. Defaults
            to "qfq" (forward-adjusted), so that open-to-close returns are not
            distorted by dividends or splits inside the window. Pass "" for
            raw prices.

    Returns:
        dict mapping each code exactly as given to its daily-bar DataFrame.
    """
    price_data = {}
    for stock in stocks:
        symbol = stock.replace('.', '')  # 'sh.600000' -> 'sh600000'
        price_data[stock] = ak.stock_zh_a_daily(
            symbol=symbol,
            start_date=start_date,
            end_date=end_date,
            adjust=adjust,
        )
    return price_data

# Get a stock's total market value on a given date.
def get_stock_market_value(stock_code, date, market_value_data):
    """Look up a stock's 'total_mv' on a given date.

    Args:
        stock_code: key into market_value_data.
        date: the day to look up; anything ``pd.Timestamp`` accepts
            ('YYYY-MM-DD', date, datetime).
        market_value_data: dict of DataFrames with a 'total_mv' column and a
            date column named 'trade_date' (akshare's indicator table) or
            'date'.

    Returns:
        The first 'total_mv' value on that date, or None if there is no row
        for it.

    Raises:
        KeyError: if stock_code is not in market_value_data.
    """
    df = market_value_data[stock_code]
    # Per akshare's docs the indicator table calls its date column
    # 'trade_date'; fall back to 'date' for frames that use that name.
    date_col = 'trade_date' if 'trade_date' in df.columns else 'date'
    # Compare normalized Timestamps: the column may hold strings or
    # datetime.date objects, and a date object never equals a string.
    dates = pd.to_datetime(df[date_col], errors='coerce').dt.normalize()
    matches = df.loc[dates == pd.Timestamp(date).normalize(), 'total_mv']
    if matches.empty:
        return None
    return matches.iloc[0]

# Get a stock's opening and closing prices over a window.
def get_stock_prices(stock_code, start_date, end_date, price_data):
    """Return the first open and the last close of a stock within a window.

    Args:
        stock_code: key into price_data.
        start_date: window start (inclusive); anything ``pd.Timestamp``
            accepts ('YYYY-MM-DD', date, datetime).
        end_date: window end (inclusive), same accepted types.
        price_data: dict of DataFrames with 'date', 'open' and 'close'
            columns (e.g. built by get_all_stock_prices).

    Returns:
        (open, close): the open of the earliest row and the close of the
        latest row inside the window, or (None, None) if the window is empty.

    Raises:
        KeyError: if stock_code is not in price_data.
    """
    df = price_data[stock_code]
    # Normalize to Timestamps: the 'date' column may hold strings or
    # datetime.date objects, and comparing a date object with a str raises.
    dates = pd.to_datetime(df['date'], errors='coerce')
    in_window = (dates >= pd.Timestamp(start_date)) & (dates <= pd.Timestamp(end_date))
    # Sort by date so "first open" / "last close" hold regardless of row order.
    window = (
        df.loc[in_window]
        .assign(_ts=dates[in_window].to_numpy())
        .sort_values('_ts', kind='stable')
    )
    if window.empty:
        return None, None
    return window['open'].iloc[0], window['close'].iloc[-1]

# Show the HS300 constituent list keyed by each month's first trading day.
print(
    get_monthly_hs300_stocks(
        monthly_trading_days=get_monthly_trading_days(
            start_date=start_date,
            end_date=end_date,
        )
    )
)

login success!
{'2010-01-04': ['sh.600000', 'sh.600004', 'sh.600005', 'sh.600006', 'sh.600008', 'sh.600009', 'sh.600010', 'sh.600011', 'sh.600015', 'sh.600016', 'sh.600017', 'sh.600018', 'sh.600019', 'sh.600022', 'sh.600026', 'sh.600027', 'sh.600028', 'sh.600029', 'sh.600030', 'sh.600031', 'sh.600033', 'sh.600036', 'sh.600037', 'sh.600048', 'sh.600050', 'sh.600058', 'sh.600062', 'sh.600066', 'sh.600068', 'sh.600085', 'sh.600087', 'sh.600089', 'sh.600096', 'sh.600100', 'sh.600102', 'sh.600104', 'sh.600108', 'sh.600109', 'sh.600111', 'sh.600118', 'sh.600123', 'sh.600125', 'sh.600132', 'sh.600143', 'sh.600150', 'sh.600151', 'sh.600153', 'sh.600158', 'sh.600161', 'sh.600166', 'sh.600169', 'sh.600170', 'sh.600176', 'sh.600177', 'sh.600183', 'sh.600188', 'sh.600196', 'sh.600208', 'sh.600210', 'sh.600216', 'sh.600219', 'sh.600220', 'sh.600221', 'sh.600236', 'sh.600239', 'sh.600246', 'sh.600251', 'sh.600256', 'sh.600266', 'sh.600269', 'sh.600270', 'sh.600271', 'sh.600276', 'sh.600282', 'sh.600