In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import timedelta
from scipy import stats

# Assumes df (news rows) and dfm (price bars) are already loaded.

# 1. Time alignment
# 'min' is the supported spelling of the minute alias; 'T' is deprecated
# since pandas 2.2.
df['time'] = pd.to_datetime(df['createDate']).dt.floor('min')
dfm['asoftime'] = pd.to_datetime(dfm['asoftime'])

# Snap every news minute to the first bar at or after it.  A binary search
# over the sorted, de-duplicated bar times replaces the per-row scan of the
# whole bar table and yields the same timestamps.
bar_times = pd.DatetimeIndex(dfm['asoftime'].dropna()).unique().sort_values()
bar_slots = bar_times.searchsorted(pd.DatetimeIndex(df['time']), side='left')
# News rows without a timestamp have no bar: point them past the last slot.
bar_slots[df['time'].isna().to_numpy()] = len(bar_times)
# Slot == len(bar_times) is not a label of the lookup Series, so reindex
# turns "no later bar" into NaT.
df['time'] = pd.Series(bar_times).reindex(bar_slots).array

# Attach the close price of the matched bar (left join keeps unmatched news)
df = pd.merge(df, dfm[['asoftime', 'close']], left_on='time', right_on='asoftime', how='left')
df = df.drop('asoftime', axis=1)

# 2. Backtest logic
def backtest(df, dfm, thres1, thres2, T, r_pt, r_sl):
    """Simulate a one-position-at-a-time sentiment strategy over news rows.

    Rows of ``df`` are visited in their existing order; rows whose ``close``
    is null are skipped.  While flat, a row opens a long position when
    ``FinBERT_sentiment_title > thres1`` and a short position when it is
    ``< thres2``.  While in a position, exits are checked in this order:
    time exit, profit exit, stop-loss exit, reversal exit.  Only a reversal
    exit may re-open a position on the same row.  Exits are evaluated only
    on rows of ``df`` (not on every bar of ``dfm``).

    Args:
        df: DataFrame with ``time``, ``close`` and
            ``FinBERT_sentiment_title`` columns.
        dfm: DataFrame with ``asoftime`` and ``close`` columns; its last row
            is used to close a position still open after the last row of
            ``df``.
        thres1: Sentiment level above which a long position is opened (and
            at or above which a short position is reversed).
        thres2: Sentiment level below which a short position is opened (and
            at or below which a long position is reversed).
        T: Maximum holding time in minutes.
        r_pt: Take-profit threshold on the position's return.
        r_sl: Stop-loss threshold on the position's return (the script
            passes a negative number).

    Returns:
        DataFrame with one row per closed trade and the columns
        ``entry_time``, ``exit_time``, ``entry_price``, ``exit_price``,
        ``position`` (1 long, -1 short) and ``exit_reason``.  The columns
        are present even when no trade was made.
    """
    trade_columns = ['entry_time', 'exit_time', 'entry_price',
                     'exit_price', 'position', 'exit_reason']
    max_holding = timedelta(minutes=T)  # loop-invariant, built once

    trades = []
    position = 0
    entry_price = 0
    entry_time = None

    def signal_for(sentiment):
        # Direction requested by a sentiment score: 1 long, -1 short, 0 none.
        if sentiment > thres1:
            return 1
        if sentiment < thres2:
            return -1
        return 0

    def close_position(exit_time, exit_price, reason):
        # Record the currently open position as a finished trade.
        trades.append({
            'entry_time': entry_time,
            'exit_time': exit_time,
            'entry_price': entry_price,
            'exit_price': exit_price,
            'position': position,
            'exit_reason': reason,
        })

    for _, row in df.iterrows():
        price = row['close']
        if pd.isnull(price):
            continue
        now = row['time']
        sentiment = row['FinBERT_sentiment_title']

        if position == 0:
            direction = signal_for(sentiment)
            if direction:
                position, entry_price, entry_time = direction, price, now
            continue

        # Exit checks, in priority order.  The return is only computed when
        # the time exit did not fire, as before.
        if now >= entry_time + max_holding:
            reason = 'Time exit'
        else:
            if position == 1:
                move = price / entry_price - 1
            else:
                move = 1 - price / entry_price
            if move >= r_pt:
                reason = 'Profit exit'
            elif move <= r_sl:
                reason = 'Stop-loss exit'
            elif (position == 1 and sentiment <= thres2) or \
                 (position == -1 and sentiment >= thres1):
                reason = 'Reversal exit'
            else:
                continue

        close_position(now, price, reason)
        position = 0

        # A reversal flips straight into the opposite trade when the score
        # also satisfies the (strict) opening rule.
        if reason == 'Reversal exit':
            direction = signal_for(sentiment)
            if direction:
                position, entry_price, entry_time = direction, price, now

    # Close whatever is still open at the last available bar
    if position != 0:
        close_position(dfm['asoftime'].iloc[-1], dfm['close'].iloc[-1],
                       'End of backtest')

    # Fixed column list keeps the schema when the trade list is empty.
    return pd.DataFrame(trades, columns=trade_columns)

# Run the backtest: open long above 0.5 / short below -0.5 sentiment, hold at
# most 60 minutes, take profit at +2 %, stop out at -1 %.
trades = backtest(df, dfm, thres1=0.5, thres2=-0.5, T=60, r_pt=0.02, r_sl=-0.01)

# 3. Compute backtest metrics
def calculate_metrics(trades):
    """Summarise a trade list into headline performance numbers.

    A ``return`` column is written into ``trades`` in place (long:
    ``exit/entry - 1``; otherwise ``1 - exit/entry``); the plotting helpers
    and the CSV export in this file read that column afterwards.

    Args:
        trades: DataFrame with ``position``, ``entry_price`` and
            ``exit_price`` columns, one row per closed trade.

    Returns:
        dict with the keys ``Total Return``, ``Sharpe Ratio``,
        ``Max Drawdown``, ``Win Rate`` and ``Profit Factor``.

        * ``Sharpe Ratio`` is ``sqrt(252) * mean / std`` of the per-trade
          returns, and NaN when the standard deviation is zero or undefined
          (fewer than two trades).
        * ``Max Drawdown`` is the largest peak-to-trough decline of the
          compounded equity curve as a fraction of the peak, where the
          starting equity of 1.0 counts as the first peak.
        * ``Profit Factor`` is gross profit over gross loss: ``inf`` when
          there are profits but no losses, NaN when there are neither.
    """
    price_ratio = trades['exit_price'].astype(float) / trades['entry_price'].astype(float)
    trades['return'] = np.where(trades['position'] == 1,
                                price_ratio - 1,
                                1 - price_ratio)
    returns = trades['return']

    total_return = (1 + returns).prod() - 1

    # NaN std (0 or 1 trades) compares False, so both degenerate cases land
    # in the else branch instead of dividing by zero/NaN.
    volatility = returns.std()
    if volatility > 0:
        sharpe_ratio = np.sqrt(252) * returns.mean() / volatility
    else:
        sharpe_ratio = np.nan

    # Drawdown is measured against the running peak, and the peak can never
    # be below the initial equity of 1.0 - otherwise a losing first trade
    # would not register as a drawdown at all.
    equity = (1 + returns).cumprod()
    running_peak = equity.cummax().clip(lower=1.0)
    max_drawdown = ((running_peak - equity) / running_peak).max()

    win_rate = (returns > 0).mean()

    gross_profit = returns[returns > 0].sum()
    gross_loss = abs(returns[returns < 0].sum())
    if gross_loss > 0:
        profit_factor = gross_profit / gross_loss
    elif gross_profit > 0:
        profit_factor = np.inf
    else:
        profit_factor = np.nan

    return {
        'Total Return': total_return,
        'Sharpe Ratio': sharpe_ratio,
        'Max Drawdown': max_drawdown,
        'Win Rate': win_rate,
        'Profit Factor': profit_factor
    }

# Overall metrics; this call also adds the 'return' column to trades, which
# the plotting and export steps below read.
metrics = calculate_metrics(trades)
print(metrics)

# 4. Analyse long and short trades separately
# calculate_metrics writes a 'return' column into the frame it receives, so
# hand it independent copies rather than filtered views of trades (writing
# into a view triggers pandas' SettingWithCopyWarning).
long_trades = trades[trades['position'] == 1].copy()
short_trades = trades[trades['position'] == -1].copy()

long_metrics = calculate_metrics(long_trades)
short_metrics = calculate_metrics(short_trades)

print("Long trades metrics:", long_metrics)
print("Short trades metrics:", short_metrics)

# 5. Plot the cumulative return curve
def plot_cumulative_returns(trades):
    """Draw the compounded return curve with per-trade markers and save it.

    Adds a ``cumulative_return`` column to ``trades`` in place (running
    product of ``1 + return``), plots it against ``exit_time`` and marks
    each trade's entry ('^') and exit ('v') at that trade's cumulative
    return, green for long and red for short.  The figure is written to
    ``cumulative_returns.png`` and then closed.

    Args:
        trades: DataFrame with ``return``, ``position``, ``entry_time`` and
            ``exit_time`` columns.
    """
    trades['cumulative_return'] = trades['return'].add(1).cumprod()

    figure, axes = plt.subplots(figsize=(12, 6))
    sns.lineplot(data=trades, x='exit_time', y='cumulative_return', ax=axes)
    axes.set_title('Cumulative Returns')
    axes.set_xlabel('Date')
    axes.set_ylabel('Cumulative Return')

    trade_points = zip(trades['entry_time'], trades['exit_time'],
                       trades['position'], trades['cumulative_return'])
    for opened_at, closed_at, side, level in trade_points:
        shade = 'g' if side == 1 else 'r'
        axes.scatter(opened_at, level, color=shade, marker='^')
        axes.scatter(closed_at, level, color=shade, marker='v')

    figure.tight_layout()
    figure.savefig('cumulative_returns.png')
    plt.close(figure)

plot_cumulative_returns(trades)

# 6. Save results: the trade list plus one single-row CSV per metrics dict
trades.to_csv('trades.csv', index=False)
for summary, csv_name in (
    (metrics, 'metrics.csv'),
    (long_metrics, 'long_metrics.csv'),
    (short_metrics, 'short_metrics.csv'),
):
    pd.DataFrame(summary, index=[0]).to_csv(csv_name, index=False)

In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

def plot_interactive_cumulative_returns(trades):
    """Write an interactive Plotly chart of cumulative returns to HTML.

    Adds a ``cumulative_return`` column to ``trades`` in place, draws it
    against ``exit_time`` and adds one marker trace per trade (entry and
    exit points at that trade's cumulative return) with hover text showing
    times, prices, return and exit reason.  Output goes to
    ``interactive_cumulative_returns.html``.

    Args:
        trades: DataFrame with ``return``, ``position``, ``entry_time``,
            ``exit_time``, ``entry_price``, ``exit_price`` and
            ``exit_reason`` columns.  The index labels are used to number
            the trades, so they must support ``label + 1``.
    """
    trades['cumulative_return'] = trades['return'].add(1).cumprod()

    figure = make_subplots(specs=[[{"secondary_y": True}]])

    equity_line = go.Scatter(
        x=trades['exit_time'],
        y=trades['cumulative_return'],
        name="Cumulative Return",
    )
    figure.add_trace(equity_line, secondary_y=False)

    long_look = ('green', 'triangle-up')
    short_look = ('red', 'triangle-down')

    for label, record in trades.iterrows():
        shade, glyph = long_look if record['position'] == 1 else short_look
        level = record['cumulative_return']

        entry_note = f"Entry: {record['entry_time']}<br>Price: {record['entry_price']:.2f}"
        exit_note = (
            f"Exit: {record['exit_time']}<br>Price: {record['exit_price']:.2f}"
            f"<br>Return: {record['return']:.2%}<br>Reason: {record['exit_reason']}"
        )

        trade_markers = go.Scatter(
            x=[record['entry_time'], record['exit_time']],
            y=[level, level],
            mode='markers',
            marker=dict(color=shade, symbol=glyph, size=10),
            name=f"Trade {label+1}",
            text=[entry_note, exit_note],
            hoverinfo='text',
        )
        figure.add_trace(trade_markers, secondary_y=False)

    figure.update_layout(
        title_text="Cumulative Returns with Trade Points",
        xaxis_title="Date",
        yaxis_title="Cumulative Return",
        hovermode="closest",
    )

    figure.write_html("interactive_cumulative_returns.html")

# Render the interactive chart; trades must already carry the 'return' column
# (added by calculate_metrics).
plot_interactive_cumulative_returns(trades)

In [None]:
class SentimentBacktest:
    """Align news sentiment with price bars, backtest it and report metrics.

    Construction runs the whole pipeline: time alignment, the trade
    simulation and the metric computation.  Results are exposed as
    ``trades`` (DataFrame), ``metrics``, ``long_metrics`` and
    ``short_metrics`` (dicts).  The class is self-contained and does not
    call the module-level helper functions.

    Args:
        df: News DataFrame with ``createDate`` and
            ``FinBERT_sentiment_title`` columns.
        dfm: Price-bar DataFrame with ``asoftime`` and ``close`` columns.
        thres1: Sentiment level above which a long position is opened.
        thres2: Sentiment level below which a short position is opened.
        T: Maximum holding time in minutes.
        r_pt: Take-profit threshold on the position's return.
        r_sl: Stop-loss threshold on the position's return.
    """

    _TRADE_COLUMNS = ['entry_time', 'exit_time', 'entry_price',
                      'exit_price', 'position', 'exit_reason']

    def __init__(self, df, dfm, thres1, thres2, T, r_pt, r_sl):
        # Private copies: alignment adds/overwrites columns and must not
        # leak into the caller's frames.
        self.df = df.copy()
        self.dfm = dfm.copy()
        self.thres1 = thres1
        self.thres2 = thres2
        self.T = T
        self.r_pt = r_pt
        self.r_sl = r_sl

        self._align_time()
        self.trades = self._backtest()
        self.metrics = self._calculate_metrics(self.trades)
        # _calculate_metrics writes a column into its argument, so the
        # long/short subsets are passed as independent copies.
        self.long_metrics = self._calculate_metrics(
            self.trades[self.trades['position'] == 1].copy())
        self.short_metrics = self._calculate_metrics(
            self.trades[self.trades['position'] == -1].copy())

    def _align_time(self):
        """Map each news row to the first bar at or after it and attach its close."""
        news_minute = pd.to_datetime(self.df['createDate']).dt.floor('min')
        self.dfm['asoftime'] = pd.to_datetime(self.dfm['asoftime'])

        # Binary search over the sorted, de-duplicated bar times.
        bar_times = pd.DatetimeIndex(self.dfm['asoftime'].dropna()).unique().sort_values()
        slots = bar_times.searchsorted(pd.DatetimeIndex(news_minute), side='left')
        # Undated news has no bar: point it past the last slot.
        slots[news_minute.isna().to_numpy()] = len(bar_times)
        # Slot == len(bar_times) is not a label of the lookup Series, so
        # reindex turns "no later bar" into NaT.
        self.df['time'] = pd.Series(bar_times).reindex(slots).array

        # A 'close' column left over from an earlier alignment would make the
        # merge emit close_x/close_y instead of 'close'.
        news = self.df.drop(columns=['close'], errors='ignore')
        merged = pd.merge(news, self.dfm[['asoftime', 'close']],
                          left_on='time', right_on='asoftime', how='left')
        self.df = merged.drop('asoftime', axis=1)

    def _signal_for(self, sentiment):
        """Return 1 (long), -1 (short) or 0 (no trade) for a sentiment score."""
        if sentiment > self.thres1:
            return 1
        if sentiment < self.thres2:
            return -1
        return 0

    def _backtest(self):
        """Run the trade simulation over ``self.df`` and return the trade list.

        Rows with a null ``close`` are skipped.  Exits are checked in the
        order time, profit, stop-loss, reversal; only a reversal exit may
        re-open a position on the same row.  A position still open after
        the last row is closed at the last bar of ``self.dfm``.
        """
        max_holding = timedelta(minutes=self.T)
        trades = []
        position = 0
        entry_price = 0
        entry_time = None

        def close_position(exit_time, exit_price, reason):
            # Record the currently open position as a finished trade.
            trades.append({
                'entry_time': entry_time,
                'exit_time': exit_time,
                'entry_price': entry_price,
                'exit_price': exit_price,
                'position': position,
                'exit_reason': reason,
            })

        for _, row in self.df.iterrows():
            price = row['close']
            if pd.isnull(price):
                continue
            now = row['time']
            sentiment = row['FinBERT_sentiment_title']

            if position == 0:
                direction = self._signal_for(sentiment)
                if direction:
                    position, entry_price, entry_time = direction, price, now
                continue

            if now >= entry_time + max_holding:
                reason = 'Time exit'
            else:
                if position == 1:
                    move = price / entry_price - 1
                else:
                    move = 1 - price / entry_price
                if move >= self.r_pt:
                    reason = 'Profit exit'
                elif move <= self.r_sl:
                    reason = 'Stop-loss exit'
                elif (position == 1 and sentiment <= self.thres2) or \
                     (position == -1 and sentiment >= self.thres1):
                    reason = 'Reversal exit'
                else:
                    continue

            close_position(now, price, reason)
            position = 0

            if reason == 'Reversal exit':
                direction = self._signal_for(sentiment)
                if direction:
                    position, entry_price, entry_time = direction, price, now

        if position != 0:
            close_position(self.dfm['asoftime'].iloc[-1],
                           self.dfm['close'].iloc[-1],
                           'End of backtest')

        # Fixed column list keeps the schema when no trade was made.
        return pd.DataFrame(trades, columns=self._TRADE_COLUMNS)

    def _calculate_metrics(self, trades):
        """Add a ``return`` column to ``trades`` and return summary metrics.

        Returns a dict with ``Total Return``, ``Sharpe Ratio`` (NaN when the
        standard deviation is zero or undefined), ``Max Drawdown`` (largest
        peak-relative decline of the compounded equity, starting equity 1.0
        counted as the first peak), ``Win Rate`` and ``Profit Factor``
        (``inf`` with profits but no losses, NaN with neither).
        """
        price_ratio = trades['exit_price'].astype(float) / trades['entry_price'].astype(float)
        trades['return'] = np.where(trades['position'] == 1,
                                    price_ratio - 1,
                                    1 - price_ratio)
        returns = trades['return']

        total_return = (1 + returns).prod() - 1

        volatility = returns.std()
        if volatility > 0:  # False for NaN as well
            sharpe_ratio = np.sqrt(252) * returns.mean() / volatility
        else:
            sharpe_ratio = np.nan

        equity = (1 + returns).cumprod()
        running_peak = equity.cummax().clip(lower=1.0)
        max_drawdown = ((running_peak - equity) / running_peak).max()

        win_rate = (returns > 0).mean()

        gross_profit = returns[returns > 0].sum()
        gross_loss = abs(returns[returns < 0].sum())
        if gross_loss > 0:
            profit_factor = gross_profit / gross_loss
        elif gross_profit > 0:
            profit_factor = np.inf
        else:
            profit_factor = np.nan

        return {
            'Total Return': total_return,
            'Sharpe Ratio': sharpe_ratio,
            'Max Drawdown': max_drawdown,
            'Win Rate': win_rate,
            'Profit Factor': profit_factor
        }

    def plot_cumulative_returns(self):
        """Save the cumulative return curve with trade markers to ``cumulative_returns.png``."""
        trades = self.trades
        trades['cumulative_return'] = (1 + trades['return']).cumprod()

        fig, ax = plt.subplots(figsize=(12, 6))
        sns.lineplot(x='exit_time', y='cumulative_return', data=trades, ax=ax)
        ax.set_title('Cumulative Returns')
        ax.set_xlabel('Date')
        ax.set_ylabel('Cumulative Return')

        for _, trade in trades.iterrows():
            color = 'g' if trade['position'] == 1 else 'r'
            ax.scatter(trade['entry_time'], trade['cumulative_return'], color=color, marker='^')
            ax.scatter(trade['exit_time'], trade['cumulative_return'], color=color, marker='v')

        fig.tight_layout()
        fig.savefig('cumulative_returns.png')
        plt.close(fig)

    def plot_interactive_cumulative_returns(self):
        """Write the interactive Plotly chart to ``interactive_cumulative_returns.html``."""
        trades = self.trades
        trades['cumulative_return'] = (1 + trades['return']).cumprod()

        fig = make_subplots(specs=[[{"secondary_y": True}]])
        fig.add_trace(
            go.Scatter(x=trades['exit_time'], y=trades['cumulative_return'], name="Cumulative Return"),
            secondary_y=False,
        )

        for i, trade in trades.iterrows():
            color = 'green' if trade['position'] == 1 else 'red'
            symbol = 'triangle-up' if trade['position'] == 1 else 'triangle-down'
            fig.add_trace(
                go.Scatter(
                    x=[trade['entry_time'], trade['exit_time']],
                    y=[trade['cumulative_return'], trade['cumulative_return']],
                    mode='markers',
                    marker=dict(color=color, symbol=symbol, size=10),
                    name=f"Trade {i+1}",
                    text=[f"Entry: {trade['entry_time']}<br>Price: {trade['entry_price']:.2f}",
                          f"Exit: {trade['exit_time']}<br>Price: {trade['exit_price']:.2f}<br>Return: {trade['return']:.2%}<br>Reason: {trade['exit_reason']}"],
                    hoverinfo='text'
                ),
                secondary_y=False,
            )

        fig.update_layout(
            title_text="Cumulative Returns with Trade Points",
            xaxis_title="Date",
            yaxis_title="Cumulative Return",
            hovermode="closest"
        )

        fig.write_html("interactive_cumulative_returns.html")

    def save_results(self):
        """Write the trade list, the three metrics tables and both charts to disk."""
        self.trades.to_csv('trades.csv', index=False)
        pd.DataFrame(self.metrics, index=[0]).to_csv('metrics.csv', index=False)
        pd.DataFrame(self.long_metrics, index=[0]).to_csv('long_metrics.csv', index=False)
        pd.DataFrame(self.short_metrics, index=[0]).to_csv('short_metrics.csv', index=False)
        self.plot_cumulative_returns()
        self.plot_interactive_cumulative_returns()

# Usage example
# The instance gets its own name: binding it to `backtest` would replace the
# module-level backtest() function and break re-running the first cell.
sentiment_backtest = SentimentBacktest(df, dfm, thres1=0.5, thres2=-0.5, T=60, r_pt=0.02, r_sl=-0.01)
sentiment_backtest.save_results()
print(sentiment_backtest.metrics)
print("Long trades metrics:", sentiment_backtest.long_metrics)
print("Short trades metrics:", sentiment_backtest.short_metrics)