In [None]:
import pandas as pd
import numpy as np

def calculate_stats(returns, positions, initial_capital=1000):
    """Summarise a backtest as round-trip counts and dollar P&L statistics.

    Every non-zero, non-NaN cell of ``positions`` counts as one round trip
    (trade).  The portfolio return for a period is the equal-weight mean of
    ``returns * positions`` across columns (NaN cells skipped) and capital
    compounds from ``initial_capital``; a period with no valid cell leaves
    capital unchanged.  The dollar P&L of a trade is its return times its
    equal share of the period's starting capital, so per-trade dollars add up
    to the change in capital: ``Gross profit - Gross loss`` and
    ``Long trades profit + Short trades profit`` both equal ``Total profit``.

    Parameters
    ----------
    returns : pandas.DataFrame
        Per-period returns, one column per instrument.  Combined with
        ``positions`` by pandas label alignment, so the two are expected to
        share index and columns.
    positions : pandas.DataFrame
        Position per period and instrument: ``> 0`` long, ``< 0`` short,
        ``0`` or NaN flat.
    initial_capital : float, default 1000
        Capital at the start of the first period.

    Returns
    -------
    dict
        Metric name -> value.  Counts are ``int``; profits and losses are in
        the same currency units as ``initial_capital``; ``Percent profitable``
        and ``Total return`` are fractions.  ``Profit factor`` and
        ``Ratio Avg. Win:Avg. Loss`` are ``inf`` when there are no losses.
    """

    def total(frame):
        """Sum every cell of a DataFrame, skipping NaN."""
        return frame.sum().sum()

    trades = returns * positions
    in_market = (positions != 0) & positions.notna()
    is_win = trades > 0
    is_loss = trades < 0
    is_long = positions > 0
    is_short = positions < 0

    # Equal-weight portfolio return per period.  A period with no valid trade
    # (e.g. an all-NaN leading returns row) contributes 0 instead of putting a
    # NaN into the capital curve.
    portfolio_returns = trades.mean(axis=1).fillna(0)
    capital = initial_capital * (1 + portfolio_returns).cumprod()

    # Capital at the start of each period.  The first period starts from
    # initial_capital; a bare shift(1) would leave NaN there and silently drop
    # the first period from every dollar metric.
    start_capital = capital.shift(1, fill_value=initial_capital)

    # Dollar P&L per trade: return times a 1/n share of the starting capital,
    # where n is the number of non-NaN cells that mean(axis=1) averaged over.
    active_columns = trades.count(axis=1)
    dollar_pnl = trades.div(active_columns, axis=0).mul(start_capital, axis=0)

    total_trades = int(total(in_market))
    profitable_trades = int(total(is_win))
    losing_trades = int(total(is_loss))
    # Only open positions with a zero return are "even"; flat cells are not trades.
    even_trades = int(total((trades == 0) & in_market))

    final_capital = capital.iloc[-1] if len(capital) else initial_capital
    total_profit = final_capital - initial_capital
    total_return = total_profit / initial_capital

    gross_profit = total(dollar_pnl.where(is_win))
    gross_loss = abs(total(dollar_pnl.where(is_loss)))

    avg_trade_net_profit = total_profit / total_trades if total_trades > 0 else 0
    avg_winning_trade = gross_profit / profitable_trades if profitable_trades > 0 else 0
    avg_losing_trade = -gross_loss / losing_trades if losing_trades > 0 else 0

    largest_winning_trade = dollar_pnl.max().max()
    largest_losing_trade = dollar_pnl.min().min()

    profit_factor = gross_profit / gross_loss if gross_loss != 0 else np.inf
    ratio_avg_win_loss = (
        avg_winning_trade / abs(avg_losing_trade) if avg_losing_trade != 0 else np.inf
    )

    # Long and short specific metrics
    long_profit = total(dollar_pnl.where(is_long))
    short_profit = total(dollar_pnl.where(is_short))
    long_trades_count = int(total(is_long))
    short_trades_count = int(total(is_short))

    long_winning_trades = int(total(is_win & is_long))
    long_losing_trades = int(total(is_loss & is_long))
    short_winning_trades = int(total(is_win & is_short))
    short_losing_trades = int(total(is_loss & is_short))

    return {
        'Total number of round_trips': total_trades,
        'Percent profitable': profitable_trades / total_trades if total_trades > 0 else 0,
        'Winning round_trips': profitable_trades,
        'Losing round_trips': losing_trades,
        'Even round_trips': even_trades,
        'Total profit': total_profit,
        'Gross profit': gross_profit,
        'Gross loss': gross_loss,
        'Profit factor': profit_factor,
        'Avg. trade net profit': avg_trade_net_profit,
        'Avg. winning trade': avg_winning_trade,
        'Avg. losing trade': avg_losing_trade,
        'Ratio Avg. Win:Avg. Loss': ratio_avg_win_loss,
        'Largest winning trade': largest_winning_trade,
        'Largest losing trade': largest_losing_trade,
        'Initial capital': initial_capital,
        'Final capital': final_capital,
        'Total return': total_return,
        'Long trades profit': long_profit,
        'Short trades profit': short_profit,
        'Number of long trades': long_trades_count,
        'Number of short trades': short_trades_count,
        'Long winning trades': long_winning_trades,
        'Long losing trades': long_losing_trades,
        'Short winning trades': short_winning_trades,
        'Short losing trades': short_losing_trades,
    }

# Inputs built in earlier notebook cells: the per-period returns, and the rows
# of df_balance selected by rebalance_mask used as the position matrix.
price_changes = resampled_returns
df_rebalanced = df_balance[rebalance_mask]

# Calculate stats for all trades
all_trades_stats = calculate_stats(
    price_changes,
    df_rebalanced,
    initial_capital=1000,
)

def _money(value):
    """Format *value* as dollars with the sign in front, e.g. ``-$1,234.50``."""
    return f"-${-value:,.2f}" if value < 0 else f"${value:,.2f}"


_count = '{:,.0f}'.format      # whole-number counts, e.g. "1,234"
_percent = '{:.2%}'.format     # fractions shown as percentages
_ratio = '{:.2f}'.format       # dimensionless ratios (may print "inf")


def _format_rows(table, row_formats):
    """Return an object-dtype copy of *table* with each row rendered as text.

    ``row_formats`` maps a row label to a one-argument formatter.  Formatting
    starts from the unrounded numbers (rounding first would turn 0.5678 into
    "57.00%"), and the copy is object dtype so writing strings into it is not
    an incompatible-dtype setitem on a float64 frame.
    """
    formatted = table.astype(object)
    for label, fmt in row_formats.items():
        formatted.loc[label] = table.loc[label].map(fmt)
    return formatted


# Row order and per-row display format of each table.
summary_formats = {
    'Total number of round_trips': _count,
    'Percent profitable': _percent,
    'Winning round_trips': _count,
    'Losing round_trips': _count,
    'Even round_trips': _count,
    'Initial capital': _money,
    'Final capital': _money,
    'Total return': _percent,
    'Number of long trades': _count,
    'Number of short trades': _count,
    'Long trades profit': _money,
    'Short trades profit': _money,
}
pnl_formats = {
    'Total profit': _money,
    'Gross profit': _money,
    'Gross loss': _money,
    'Profit factor': _ratio,
    'Avg. trade net profit': _money,
    'Avg. winning trade': _money,
    'Avg. losing trade': _money,
    'Ratio Avg. Win:Avg. Loss': _ratio,  # dimensionless, not a dollar amount
    'Largest winning trade': _money,
    'Largest losing trade': _money,
}
breakdown_formats = {
    'Number of trades': _count,
    'Winning trades': _count,
    'Losing trades': _count,
    'Total profit': _money,
}

# Create summary stats table
summary_stats = _format_rows(
    pd.DataFrame({'Strategy': all_trades_stats}, index=list(summary_formats)),
    summary_formats,
)

# Create PnL stats table
pnl_stats = _format_rows(
    pd.DataFrame({'Strategy': all_trades_stats}, index=list(pnl_formats)),
    pnl_formats,
)

# Create long/short breakdown table (one column per side)
long_short_breakdown = _format_rows(
    pd.DataFrame({
        'Long trades': {
            'Number of trades': all_trades_stats['Number of long trades'],
            'Winning trades': all_trades_stats['Long winning trades'],
            'Losing trades': all_trades_stats['Long losing trades'],
            'Total profit': all_trades_stats['Long trades profit'],
        },
        'Short trades': {
            'Number of trades': all_trades_stats['Number of short trades'],
            'Winning trades': all_trades_stats['Short winning trades'],
            'Losing trades': all_trades_stats['Short losing trades'],
            'Total profit': all_trades_stats['Short trades profit'],
        },
    }),
    breakdown_formats,
)

print("Summary stats")
print(summary_stats)
print("\nPnL stats")
print(pnl_stats)
print("\nLong/Short Breakdown")
print(long_short_breakdown)