# In [None]:
import os
from concurrent.futures import ProcessPoolExecutor
from functools import partial
from typing import Tuple, List, Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

class ForexSentimentAnalyzer:
    """Measure how resampled keyword sentiment correlates with price returns.

    For each parameter combination the analyzer merges a sentiment frame
    (``df1``) with a price frame (``df2``), derives an exponentially weighted
    sentiment score and a percentage return, and reports their correlation
    before and after a split date. Merged frames are cached as CSV files in
    ``data_dir``.
    """

    # Column order of the frame returned by ``analyze``.
    _RESULT_COLUMNS = ['rt', 'keyword', 'model', 't', 'window', 'train_corr', 'test_corr']

    def __init__(self, data_dir: str):
        """Remember the cache directory and create it if it is missing.

        Args:
            data_dir: Directory in which cached ``df0_*.csv`` files are stored.
        """
        self.data_dir = data_dir
        # exist_ok avoids the check-then-create race between concurrent processes.
        os.makedirs(data_dir, exist_ok=True)

    def _cache_path(self, rt: str, keyword: str, window: int, model: str) -> str:
        """Return the CSV path used to cache the merged frame for these parameters."""
        return os.path.join(self.data_dir, f"df0_{rt}_{keyword}_{window}_{model}.csv")

    def load_df0(self, rt: str, keyword: str, window: int, model: str) -> Optional[pd.DataFrame]:
        """Load a cached merged frame.

        Returns:
            The cached frame with ``time`` parsed as datetimes, or ``None``
            when no cache file exists for the given parameters.
        """
        file_path = self._cache_path(rt, keyword, window, model)
        try:
            return pd.read_csv(file_path, parse_dates=['time'])
        except FileNotFoundError:
            return None

    def save_df0(self, df0: pd.DataFrame, rt: str, keyword: str, window: int, model: str):
        """Write ``df0`` to the cache file for the given parameters (without the index)."""
        file_path = self._cache_path(rt, keyword, window, model)
        # Parameter sets that differ only in ``t`` share one cache file and may
        # run in parallel workers. Writing to a per-process temp file and
        # renaming it keeps readers from ever seeing a half-written CSV.
        tmp_path = f"{file_path}.{os.getpid()}.tmp"
        df0.to_csv(tmp_path, index=False)
        os.replace(tmp_path, file_path)

    def resample_df1(self, df1: pd.DataFrame, window: int, keyword: str) -> pd.DataFrame:
        """Average the rows of one keyword over consecutive ``window``-day bins.

        Args:
            df1: Frame with at least ``time`` and ``keyword`` columns.
            window: Bin width in days.
            keyword: Value of the ``keyword`` column to keep.

        Returns:
            A new frame with ``time`` (bin start) and the mean of every
            numeric column; ``df1`` itself is not modified.
        """
        df11 = df1[df1['keyword'] == keyword].copy()
        df11['time'] = pd.to_datetime(df11['time'])
        # numeric_only=True skips the string ``keyword`` column, which would
        # otherwise make ``mean`` raise on pandas >= 2.0.
        df11 = df11.set_index('time').resample(f'{window}D').mean(numeric_only=True).reset_index()
        return df11

    def resample_df2(self, df2: pd.DataFrame) -> pd.DataFrame:
        """Return a daily copy of ``df2`` keeping the last row of each day."""
        df21 = df2.copy()
        df21['time'] = pd.to_datetime(df21['time'])
        df21 = df21.set_index('time').resample('D').last().reset_index()
        return df21

    def merge_dataframes(self, df11: pd.DataFrame, df21: pd.DataFrame) -> pd.DataFrame:
        """Outer-join the two frames on ``time`` and return them sorted by time."""
        df0 = pd.merge(df11, df21, on='time', how='outer')
        df0.sort_values('time', inplace=True)
        df0.reset_index(drop=True, inplace=True)
        return df0

    def calculate_returns(self, df0: pd.DataFrame, t: int) -> pd.DataFrame:
        """Add a ``returns_{t}`` column holding the ``t``-row percentage change of ``price``.

        The column is added to ``df0`` in place; the same frame is returned.
        """
        df0[f'returns_{t}'] = df0['price'].pct_change(t)
        return df0

    def process_params(self, params: Tuple[str, str, str, int, int], split_date: str, df1: pd.DataFrame, df2: pd.DataFrame) -> Tuple[dict, float, float]:
        """Compute train and test correlations for one parameter combination.

        Args:
            params: ``(rt, keyword, model, t, window)``. ``rt`` is only used
                to name the cache file and is echoed back in the result.
            split_date: Rows before this date form the train set, rows on or
                after it form the test set.
            df1: Sentiment frame (``time``, ``keyword``, ``sentiment``).
            df2: Price frame (``time``, ``price``).

        Returns:
            ``(params_as_dict, train_corr, test_corr)``.
        """
        rt, keyword, model, t, window = params
        returns_column = f'returns_{t}'

        df0 = self.load_df0(rt, keyword, window, model)
        if df0 is None:
            df11 = self.resample_df1(df1, window, keyword)
            df21 = self.resample_df2(df2)
            df0 = self.merge_dataframes(df11, df21)
            df0 = self.calculate_returns(df0, t)

            df0[f'{model}_score'] = df0['sentiment']
            df0[f'{model}_score_ema_{window}'] = df0[f'{model}_score'].ewm(span=window).mean()

            self.save_df0(df0, rt, keyword, window, model)
        elif returns_column not in df0.columns:
            # The cache key does not include ``t``, so a cached frame may only
            # hold the return column of whichever ``t`` created it.
            df0 = self.calculate_returns(df0, t)

        df0['time'] = pd.to_datetime(df0['time'])
        split_datetime = pd.to_datetime(split_date)

        train_df = df0[df0['time'] < split_datetime]
        test_df = df0[df0['time'] >= split_datetime]

        train_corr = self.calculate_correlation(train_df, model, window, t)
        test_corr = self.calculate_correlation(test_df, model, window, t)

        return {'rt': rt, 'keyword': keyword, 'model': model, 't': t, 'window': window}, train_corr, test_corr

    def calculate_correlation(self, df: pd.DataFrame, model: str, window: int, t: int) -> float:
        """Correlate ``{model}_score_ema_{window}`` with ``returns_{t}``.

        Rows with a missing value in either column are ignored.

        Returns:
            The correlation, or ``0.0`` when no complete rows remain or the
            correlation is undefined (NaN).
        """
        column_name = f"{model}_score_ema_{window}"
        returns_column = f"returns_{t}"

        df_clean = df.dropna(subset=[column_name, returns_column])
        if df_clean.empty:
            return 0.0

        correlation = df_clean[column_name].corr(df_clean[returns_column])
        return 0.0 if np.isnan(correlation) else float(correlation)

    def analyze(self, rt_list: List[str], keyword_list: List[str], model_list: List[str], 
                t_list: List[int], window_list: List[int], split_date: str, df1: pd.DataFrame, df2: pd.DataFrame, num_workers: int = 4) -> pd.DataFrame:
        """Evaluate every parameter combination and rank the results.

        Args:
            rt_list, keyword_list, model_list, t_list, window_list: Values
                whose Cartesian product forms the parameter grid.
            split_date: Train/test boundary passed to ``process_params``.
            df1: Sentiment frame.
            df2: Price frame.
            num_workers: Number of worker processes. With ``1`` the grid is
                evaluated in the calling process without a pool.

        Returns:
            One row per combination with the parameters, ``train_corr`` and
            ``test_corr``, sorted by descending absolute ``train_corr``. An
            empty grid yields an empty frame with the same columns.
        """
        params_list = [(rt, keyword, model, t, window) 
                       for rt in rt_list 
                       for keyword in keyword_list 
                       for model in model_list 
                       for t in t_list 
                       for window in window_list]

        if not params_list:
            # Without rows there is no 'train_corr' column to sort on.
            return pd.DataFrame(columns=self._RESULT_COLUMNS)

        # A lambda cannot be pickled and therefore cannot be sent to worker
        # processes; a partial over the bound method can.
        task = partial(self.process_params, split_date=split_date, df1=df1, df2=df2)

        if num_workers == 1:
            results = [task(params) for params in params_list]
        else:
            with ProcessPoolExecutor(max_workers=num_workers) as executor:
                results = list(executor.map(task, params_list))

        df_results = pd.DataFrame([
            {**params, 'train_corr': train_corr, 'test_corr': test_corr}
            for params, train_corr, test_corr in results
        ])

        df_results = df_results.sort_values(by='train_corr', key=abs, ascending=False)

        return df_results

    def plot_top_correlations(self, df_results: pd.DataFrame, top_n: int = 10):
        """Draw a bar chart of train and test correlation for the first rows of ``df_results``.

        Args:
            df_results: Frame as returned by ``analyze``.
            top_n: Maximum number of rows to plot; fewer are drawn when
                ``df_results`` is shorter.
        """
        top_results = df_results.head(top_n)
        # Use the real row count: ``head`` may return fewer than ``top_n`` rows,
        # and ``plt.bar`` needs positions and heights of equal length.
        shown = len(top_results)
        positions = range(shown)
        labels = [f"{row['rt']}-{row['keyword']}-{row['model']}-{row['t']}-{row['window']}" 
                  for _, row in top_results.iterrows()]

        plt.figure(figsize=(12, 6))
        plt.bar(positions, top_results['train_corr'], align='center', alpha=0.8, label='Train')
        plt.bar(positions, top_results['test_corr'], align='center', alpha=0.6, label='Test')
        plt.xlabel('Parameters')
        plt.ylabel('Correlation')
        plt.title(f'Top {shown} Correlations')
        plt.legend()
        plt.xticks(positions, labels, rotation=90)
        plt.tight_layout()
        plt.show()

# Usage example
if __name__ == "__main__":
    # Synthetic inputs covering two calendar years at daily frequency.
    dates = pd.date_range(start='2023-01-01', end='2024-12-31', freq='D')

    # Two sentiment rows per day, alternating between the two keywords.
    df1 = pd.DataFrame(dict(
        time=dates.repeat(2),
        keyword=['economic', 'political'] * len(dates),
        sentiment=np.random.randn(len(dates) * 2),
    ))

    # One random price per day, centred on 1000.
    df2 = pd.DataFrame(dict(
        time=dates,
        price=np.random.randn(len(dates)) * 100 + 1000,
    ))

    # Parameter grid to evaluate.
    rt_list = ['EURUSD']
    keyword_list = ['economic', 'political']
    model_list = ['sentiment']
    t_list = [1, 5]
    window_list = [5, 10]
    split_date = '2024-01-01'

    # Create the analyzer; cached frames are stored under ./data.
    analyzer = ForexSentimentAnalyzer('data')

    # Run the analysis over the whole grid.
    results = analyzer.analyze(
        rt_list=rt_list,
        keyword_list=keyword_list,
        model_list=model_list,
        t_list=t_list,
        window_list=window_list,
        split_date=split_date,
        df1=df1,
        df2=df2,
    )

    # Print the ranked results.
    print(results)

    # Plot the strongest correlations.
    analyzer.plot_top_correlations(results)