In [2]:
# data_manager.py
import tushare as ts
import pandas as pd
import os
from datetime import datetime, timedelta

class TushareDataManager:
    def __init__(self, token, cache_dir='./tushare_cache'):
        self.pro = ts.pro_api(token)
        self.cache_dir = cache_dir
        os.makedirs(cache_dir, exist_ok=True)
        print(f"TushareDataManager initialized. Cache directory: {self.cache_dir}")

    def _get_cache_path(self, data_type, ts_code=None, start_date=None, end_date=None):
        if data_type == 'index_weekly':
            return os.path.join(self.cache_dir, f'index_weekly_{ts_code}.parquet')
        elif data_type == 'index_constituents':
            return os.path.join(self.cache_dir, f'index_constituents_{ts_code}.parquet')
        elif data_type == 'daily_basic':
            return os.path.join(self.cache_dir, f'daily_basic_{ts_code}.parquet')
        elif data_type == 'index_daily': # For CSI300 benchmark
            return os.path.join(self.cache_dir, f'index_daily_{ts_code}.parquet')
        return None

    def _load_from_cache(self, file_path, start_date_str, end_date_str):
        if os.path.exists(file_path):
            try:
                df = pd.read_parquet(file_path)
                df['trade_date'] = pd.to_datetime(df['trade_date'])
                
                # Check if the cached data covers the requested range
                if not df.empty and \
                   df['trade_date'].min() <= pd.to_datetime(start_date_str) and \
                   df['trade_date'].max() >= pd.to_datetime(end_date_str):
                    print(f"Loaded from cache: {file_path}")
                    return df
            except Exception as e:
                print(f"Error loading from cache {file_path}: {e}. Re-fetching.")
        return None

    def _save_to_cache(self, df, file_path):
        os.makedirs(os.path.dirname(file_path), exist_ok=True)
        df['trade_date'] = df['trade_date'].dt.strftime('%Y%m%d') # Store as string
        df.to_parquet(file_path, index=False)
        print(f"Saved to cache: {file_path}")

    def get_index_weekly_data(self, ts_code, start_date, end_date):
        start_date_str = start_date.strftime('%Y%m%d')
        end_date_str = end_date.strftime('%Y%m%d')
        file_path = self._get_cache_path('index_weekly', ts_code)

        cached_df = self._load_from_cache(file_path, start_date_str, end_date_str)
        if cached_df is not None:
            return cached_df[(cached_df['trade_date'] >= start_date) & (cached_df['trade_date'] <= end_date)]

        print(f"Fetching weekly data for {ts_code} from Tushare...")
        df = self.pro.pro_bar(ts_code=ts_code, asset='I', freq='W',
                              start_date=start_date_str, end_date=end_date_str,
                              fields='trade_date,close,open,high,low,vol,amount')
        if df is None or df.empty:
            print(f"No data fetched for {ts_code} in range {start_date_str} to {end_date_str}")
            return pd.DataFrame()

        df['trade_date'] = pd.to_datetime(df['trade_date'])
        df = df.sort_values('trade_date').reset_index(drop=True)
        self._save_to_cache(df, file_path)
        return df

    def get_index_daily_data(self, ts_code, start_date, end_date):
        start_date_str = start_date.strftime('%Y%m%d')
        end_date_str = end_date.strftime('%Y%m%d')
        file_path = self._get_cache_path('index_daily', ts_code)

        cached_df = self._load_from_cache(file_path, start_date_str, end_date_str)
        if cached_df is not None:
            return cached_df[(cached_df['trade_date'] >= start_date) & (cached_df['trade_date'] <= end_date)]

        print(f"Fetching daily data for {ts_code} from Tushare...")
        df = self.pro.index_daily(ts_code=ts_code,
                                  start_date=start_date_str, end_date=end_date_str,
                                  fields='trade_date,close,open,high,low,vol,amount')
        if df is None or df.empty:
            print(f"No data fetched for {ts_code} in range {start_date_str} to {end_date_str}")
            return pd.DataFrame()

        df['trade_date'] = pd.to_datetime(df['trade_date'])
        df = df.sort_values('trade_date').reset_index(drop=True)
        self._save_to_cache(df, file_path)
        return df

    def get_index_constituents(self, index_code, trade_date):
        # index_weight is monthly, so we need to find the closest month-end or month-start
        # For simplicity, let's try to get data for the first day of the month
        # or the last day of the previous month if current month's data is not available.
        # Tushare's index_weight 'trade_date' is usually the last trading day of the month.
        
        # Try current month's last day
        current_month_end = trade_date.replace(day=1) + timedelta(days=31)
        current_month_end = current_month_end.replace(day=1) - timedelta(days=1)
        
        # Try previous month's last day
        prev_month_end = trade_date.replace(day=1) - timedelta(days=1)

        dates_to_try = [current_month_end, prev_month_end]
        
        for d in dates_to_try:
            date_str = d.strftime('%Y%m%d')
            file_path = self._get_cache_path('index_constituents', index_code)
            
            # Load all cached constituents for this index and filter by date
            cached_df = self._load_from_cache(file_path, '19900101', '20991231') # Load all to check specific date
            if cached_df is not None and not cached_df.empty:
                specific_date_df = cached_df[cached_df['trade_date'] == d]
                if not specific_date_df.empty:
                    print(f"Loaded constituents for {index_code} on {date_str} from cache.")
                    return specific_date_df
            
            print(f"Fetching constituents for {index_code} on {date_str} from Tushare...")
            df = self.pro.index_weight(index_code=index_code, trade_date=date_str)
            if df is not None and not df.empty:
                df['trade_date'] = pd.to_datetime(df['trade_date'])
                # Append to existing cache or create new
                if cached_df is not None and not cached_df.empty:
                    combined_df = pd.concat([cached_df, df]).drop_duplicates(subset=['index_code', 'con_code', 'trade_date']).reset_index(drop=True)
                    self._save_to_cache(combined_df, file_path)
                else:
                    self._save_to_cache(df, file_path)
                return df
        
        print(f"No constituents data found for {index_code} around {trade_date.strftime('%Y%m%d')}")
        return pd.DataFrame()

    def get_daily_basic(self, ts_code, trade_date):
        date_str = trade_date.strftime('%Y%m%d')
        file_path = self._get_cache_path('daily_basic', ts_code)
        
        # Load all cached daily basic for this stock and filter by date
        cached_df = self._load_from_cache(file_path, '19900101', '20991231') # Load all to check specific date
        if cached_df is not None and not cached_df.empty:
            specific_date_df = cached_df[cached_df['trade_date'] == trade_date]
            if not specific_date_df.empty:
                # print(f"Loaded daily basic for {ts_code} on {date_str} from cache.")
                return specific_date_df
        
        # print(f"Fetching daily basic for {ts_code} on {date_str} from Tushare...")
        df = self.pro.daily_basic(ts_code=ts_code, trade_date=date_str, fields='ts_code,trade_date,pe_ttm,pb')
        if df is None or df.empty:
            # print(f"No daily basic data fetched for {ts_code} on {date_str}")
            return pd.DataFrame()

        df['trade_date'] = pd.to_datetime(df['trade_date'])
        # Append to existing cache or create new
        if cached_df is not None and not cached_df.empty:
            combined_df = pd.concat([cached_df, df]).drop_duplicates(subset=['ts_code', 'trade_date']).reset_index(drop=True)
            self._save_to_cache(combined_df, file_path)
        else:
            self._save_to_cache(df, file_path)
        return df

In [4]:
# factor_calculator.py
import pandas as pd
import numpy as np
from data_manager import TushareDataManager # Assuming data_manager.py is in the same directory

class FactorCalculator:
    def __init__(self, dm: TushareDataManager):
        self.dm = dm
        self.csi100_code = '000903.SH'
        self.csi500_code = '000905.SH'

    def calculate_returns(self, df, period):
        """计算指定周期内的收益率"""
        return df['close'].pct_change(periods=period)

    def calculate_rsi(self, df, period):
        """计算RSI"""
        delta = df['close'].diff()
        gain = (delta.where(delta > 0, 0)).rolling(window=period).mean()
        loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean()

        rs = gain / loss
        rsi = 100 - (100 / (1 + rs))
        return rsi

    def calculate_psy(self, df, period):
        """计算PSY (心理线)"""
        # Count up weeks (close > prev_close)
        up_weeks = (df['close'].diff() > 0).astype(int).rolling(window=period).sum()
        psy = (up_weeks / period) * 100
        return psy

    def calculate_volume_ratio(self, df_csi500, df_csi100, period):
        """计算成交量比值"""
        vol_csi500 = df_csi500['vol'].rolling(window=period).sum()
        vol_csi100 = df_csi100['vol'].rolling(window=period).sum()
        return vol_csi500 / vol_csi100

    def calculate_median_pe_pb(self, index_code, trade_date):
        """
        计算指定指数在特定日期的成分股PE/PB中值。
        忽略PE为负或缺失的值。
        """
        constituents_df = self.dm.get_index_constituents(index_code, trade_date)
        if constituents_df.empty:
            return np.nan, np.nan

        pe_values =
        pb_values =

        # Filter constituents for the specific trade_date if multiple dates are returned
        constituents_on_date = constituents_df[constituents_df['trade_date'] == trade_date]
        if constituents_on_date.empty: # Fallback to latest available if exact date not found
            latest_date = constituents_df['trade_date'].max()
            constituents_on_date = constituents_df[constituents_df['trade_date'] == latest_date]
            if constituents_on_date.empty:
                return np.nan, np.nan

        for _, row in constituents_on_date.iterrows():
            stock_code = row['con_code']
            daily_basic_df = self.dm.get_daily_basic(stock_code, trade_date)
            if not daily_basic_df.empty:
                pe = daily_basic_df['pe_ttm'].iloc
                pb = daily_basic_df['pb'].iloc
                
                if pd.notna(pe) and pe > 0: # Ignore negative or NaN PE
                    pe_values.append(pe)
                if pd.notna(pb): # PB can be negative, but usually positive. Let's include all non-NaN.
                    pb_values.append(pb)
        
        median_pe = np.median(pe_values) if pe_values else np.nan
        median_pb = np.median(pb_values) if pb_values else np.nan
        
        return median_pe, median_pb

    def calculate_all_factors(self, start_date, end_date):
        """
        整合所有因子计算，生成包含所有因子值的DataFrame。
        报告中“月”周期转换为周数：1月=4周, 3月=12周, 6月=24周, 12月=48周。
        """
        print("Calculating all factors...")
        df_csi100 = self.dm.get_index_weekly_data(self.csi100_code, start_date, end_date)
        df_csi500 = self.dm.get_index_weekly_data(self.csi500_code, start_date, end_date)

        if df_csi100.empty or df_csi500.empty:
            print("Error: Could not retrieve index weekly data.")
            return pd.DataFrame()

        # Merge dataframes on trade_date
        df_merged = pd.merge(df_csi100[['trade_date', 'close', 'vol', 'amount']],
                             df_csi500[['trade_date', 'close', 'vol', 'amount']],
                             on='trade_date', suffixes=('_csi100', '_csi500'))
        
        df_merged = df_merged.sort_values('trade_date').reset_index(drop=True)

        # Calculate factors
        # Returns difference (涨幅差值)
        df_merged['ret_diff_4w'] = self.calculate_returns(df_merged, 4) - self.calculate_returns(df_merged, 4).shift(periods=4) # This is not correct based on report.
        # Corrected: CSI500 return - CSI100 return
        df_merged['ret_csi500_4w'] = self.calculate_returns(df_merged.rename(columns={'close_csi500': 'close'}), 4)
        df_merged['ret_csi100_4w'] = self.calculate_returns(df_merged.rename(columns={'close_csi100': 'close'}), 4)
        df_merged['ret_diff_4w'] = df_merged['ret_csi500_4w'] - df_merged['ret_csi100_4w']

        df_merged['ret_csi500_12w'] = self.calculate_returns(df_merged.rename(columns={'close_csi500': 'close'}), 12)
        df_merged['ret_csi100_12w'] = self.calculate_returns(df_merged.rename(columns={'close_csi100': 'close'}), 12)
        df_merged['ret_diff_12w'] = df_merged['ret_csi500_12w'] - df_merged['ret_csi100_12w']

        df_merged['ret_csi500_24w'] = self.calculate_returns(df_merged.rename(columns={'close_csi500': 'close'}), 24)
        df_merged['ret_csi100_24w'] = self.calculate_returns(df_merged.rename(columns={'close_csi100': 'close'}), 24)
        df_merged['ret_diff_24w'] = df_merged['ret_csi500_24w'] - df_merged['ret_csi100_24w']

        df_merged['ret_csi500_48w'] = self.calculate_returns(df_merged.rename(columns={'close_csi500': 'close'}), 48)
        df_merged['ret_csi100_48w'] = self.calculate_returns(df_merged.rename(columns={'close_csi100': 'close'}), 48)
        df_merged['ret_diff_48w'] = df_merged['ret_csi500_48w'] - df_merged['ret_csi100_48w']

        # RSI difference (RSI差值)
        df_merged['rsi_csi500_4w'] = self.calculate_rsi(df_merged.rename(columns={'close_csi500': 'close'}), 4)
        df_merged['rsi_csi100_4w'] = self.calculate_rsi(df_merged.rename(columns={'close_csi100': 'close'}), 4)
        df_merged['rsi_diff_4w'] = df_merged['rsi_csi500_4w'] - df_merged['rsi_csi100_4w']

        df_merged['rsi_csi500_12w'] = self.calculate_rsi(df_merged.rename(columns={'close_csi500': 'close'}), 12)
        df_merged['rsi_csi100_12w'] = self.calculate_rsi(df_merged.rename(columns={'close_csi100': 'close'}), 12)
        df_merged['rsi_diff_12w'] = df_merged['rsi_csi500_12w'] - df_merged['rsi_csi100_12w']

        df_merged['rsi_csi500_24w'] = self.calculate_rsi(df_merged.rename(columns={'close_csi500': 'close'}), 24)
        df_merged['rsi_csi100_24w'] = self.calculate_rsi(df_merged.rename(columns={'close_csi100': 'close'}), 24)
        df_merged['rsi_diff_24w'] = df_merged['rsi_csi500_24w'] - df_merged['rsi_csi100_24w']

        df_merged['rsi_csi500_48w'] = self.calculate_rsi(df_merged.rename(columns={'close_csi500': 'close'}), 48)
        df_merged['rsi_csi100_48w'] = self.calculate_rsi(df_merged.rename(columns={'close_csi100': 'close'}), 48)
        df_merged['rsi_diff_48w'] = df_merged['rsi_csi500_48w'] - df_merged['rsi_csi100_48w']

        # PSY difference (PSY差值)
        df_merged['psy_csi500_4w'] = self.calculate_psy(df_merged.rename(columns={'close_csi500': 'close'}), 4)
        df_merged['psy_csi100_4w'] = self.calculate_psy(df_merged.rename(columns={'close_csi100': 'close'}), 4)
        df_merged['psy_diff_4w'] = df_merged['psy_csi500_4w'] - df_merged['psy_csi100_4w']

        df_merged['psy_csi500_12w'] = self.calculate_psy(df_merged.rename(columns={'close_csi500': 'close'}), 12)
        df_merged['psy_csi100_12w'] = self.calculate_psy(df_merged.rename(columns={'close_csi100': 'close'}), 12)
        df_merged['psy_diff_12w'] = df_merged['psy_csi500_12w'] - df_merged['psy_csi100_12w']

        df_merged['psy_csi500_24w'] = self.calculate_psy(df_merged.rename(columns={'close_csi500': 'close'}), 24)
        df_merged['psy_csi100_24w'] = self.calculate_psy(df_merged.rename(columns={'close_csi100': 'close'}), 24)
        df_merged['psy_diff_24w'] = df_merged['psy_csi500_24w'] - df_merged['psy_csi100_24w']

        df_merged['psy_csi500_48w'] = self.calculate_psy(df_merged.rename(columns={'close_csi500': 'close'}), 48)
        df_merged['psy_csi100_48w'] = self.calculate_psy(df_merged.rename(columns={'close_csi100': 'close'}), 48)
        df_merged['psy_diff_48w'] = df_merged['psy_csi500_48w'] - df_merged['psy_csi100_48w']

        # Volume ratio (成交量比值)
        df_merged['vol_ratio_4w'] = self.calculate_volume_ratio(
            df_merged.rename(columns={'vol_csi500': 'vol'}),
            df_merged.rename(columns={'vol_csi100': 'vol'}),
            4
        )
        df_merged['vol_ratio_12w'] = self.calculate_volume_ratio(
            df_merged.rename(columns={'vol_csi500': 'vol'}),
            df_merged.rename(columns={'vol_csi100': 'vol'}),
            12
        )

        # PE/PB difference (PE/PB差值) - This needs to be calculated per week
        # This part is computationally intensive due to individual stock data fetching
        pe_diffs =
        pb_diffs =
        for i, row in df_merged.iterrows():
            trade_date = row['trade_date']
            median_pe_csi500, median_pb_csi500 = self.calculate_median_pe_pb(self.csi500_code, trade_date)
            median_pe_csi100, median_pb_csi100 = self.calculate_median_pe_pb(self.csi100_code, trade_date)
            
            pe_diffs.append(median_pe_csi500 - median_pe_csi100 if pd.notna(median_pe_csi500) and pd.notna(median_pe_csi100) else np.nan)
            pb_diffs.append(median_pb_csi500 - median_pb_csi100 if pd.notna(median_pb_csi500) and pd.notna(median_pb_csi100) else np.nan)
        
        df_merged['pe_diff'] = pe_diffs
        df_merged['pb_diff'] = pb_diffs

        # Calculate next week's relative return for labeling
        df_merged['next_week_csi500_return'] = df_merged['close_csi500'].pct_change(periods=-1).shift(1) # Shift to align with current week's factors
        df_merged['next_week_csi100_return'] = df_merged['close_csi100'].pct_change(periods=-1).shift(1)
        df_merged['next_week_relative_return'] = df_merged['next_week_csi500_return'] - df_merged['next_week_csi100_return']
        
        # Label: 1 if CSI500 outperforms CSI100, 0 otherwise
        df_merged['label'] = (df_merged['next_week_relative_return'] > 0).astype(int)

        # Drop rows with NaN in factors (due to rolling window or initial PE/PB calculation)
        # The report uses '3月RSI&6月PSY&12月PSY&1月成交' for Model II
        # So, we need 'rsi_diff_12w', 'psy_diff_24w', 'psy_diff_48w', 'vol_ratio_4w'
        # And 'pe_diff' if it's part of the model (Model II does not explicitly list PE, but it's a key factor in the report)
        # Let's include all factors for now and select later in strategy_logic
        factor_columns = [
            'rsi_diff_12w', 'psy_diff_24w', 'psy_diff_48w', 'vol_ratio_4w', # Model II factors
            'ret_diff_12w', 'ret_diff_24w', # Other important return diffs
            'pe_diff', 'pb_diff', # Valuation factors
            'rsi_diff_4w', 'rsi_diff_24w', 'rsi_diff_48w',
            'psy_diff_4w', 'psy_diff_12w',
            'vol_ratio_12w'
        ]
        
        # Ensure all required columns exist before dropping NaNs
        df_merged = df_merged.dropna(subset=factor_columns + ['label']).reset_index(drop=True)
        
        print(f"Factor calculation complete. Total rows with factors: {len(df_merged)}")
        return df_merged

SyntaxError: invalid syntax (2051489982.py, line 48)