# Exploratory Data Analysis

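Analyze the processed features and labels:
- Feature distributions
- Correlation matrix (features and target)
- Excess-return distributions by ticker
- Feature statistics by train/val/test split
- Temporal patterns
- Price performance
- Summary statistics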

In [None]:
import pickle
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Global plotting defaults for every figure in this notebook.
# The 'seaborn-v0_8-*' style names are the matplotlib >= 3.6 spelling.
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams['figure.figsize'] = (12, 6)

# Every plotting cell below saves into this directory, and plt.savefig does
# not create missing parent directories; make sure it exists up front.
Path('../outputs/backtest').mkdir(parents=True, exist_ok=True)

## 1. Load Data

In [None]:
# Load the artifacts written by the preprocessing step.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file —
# only load processed_data.pkl produced by this project's own pipeline.
with open('../data/processed/processed_data.pkl', 'rb') as f:
    data = pickle.load(f)

splits = data['splits']
feature_cols = data['feature_cols']
prices = data['prices']
config = data['config']

# Combine all splits for an overview. ignore_index=True gives the combined
# frame a fresh, unique RangeIndex: if the splits carry overlapping index
# labels, a duplicated index can break label-based operations downstream
# (e.g. seaborn plotting). Nothing below relies on the original index.
all_data = pd.concat([splits.train, splits.val, splits.test], ignore_index=True)

print(f"Total samples: {len(all_data)}")
print(f"Features: {feature_cols}")
print(f"Tickers: {config['tickers']['etfs']}")
print(f"Date range: {all_data['date'].min()} to {all_data['date'].max()}")

## 2. Feature Distributions

In [None]:
# One histogram per feature with its mean marked and its std in the title.
# The grid is sized from feature_cols (3 panels per row) so that changing the
# feature set neither raises IndexError (more than 6 features) nor leaves
# blank panels (fewer than 6).
n_grid_cols = 3
n_grid_rows = max(1, -(-len(feature_cols) // n_grid_cols))  # ceiling division
fig, axes = plt.subplots(n_grid_rows, n_grid_cols, figsize=(15, 4 * n_grid_rows), squeeze=False)
axes = axes.flatten()

for ax, col in zip(axes, feature_cols):
    values = all_data[col].dropna()
    # Compute each statistic once and reuse it for the marker and the labels.
    col_mean, col_std = values.mean(), values.std()
    ax.hist(values, bins=50, alpha=0.7, edgecolor='black')
    ax.axvline(x=col_mean, color='red', linestyle='--', label=f'Mean: {col_mean:.4f}')
    ax.set_xlabel(col)
    ax.set_ylabel('Frequency')
    ax.set_title(f'{col}\nStd: {col_std:.4f}')
    ax.legend()

# Hide grid cells left over when len(feature_cols) is not a multiple of 3.
for ax in axes[len(feature_cols):]:
    ax.set_visible(False)

plt.tight_layout()
plt.savefig('../outputs/backtest/feature_distributions.png', dpi=150)
plt.show()

## 3. Feature Correlation Matrix

In [None]:
# Heatmap of pairwise correlations among the features and the target column
# (excess_return), followed by a ranked list of each column's correlation
# with the target.
target_col = 'excess_return'
corr_matrix = all_data[[*feature_cols, target_col]].corr()

fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(
    corr_matrix,
    ax=ax,
    annot=True,
    fmt='.2f',
    cmap='RdBu_r',
    center=0,  # diverging colormap centred on zero correlation
    square=True,
    linewidths=0.5,
)
ax.set_title('Feature Correlation Matrix (including target)')

plt.tight_layout()
plt.savefig('../outputs/backtest/correlation_matrix.png', dpi=150)
plt.show()

# Strongest positive correlation with the target first.
target_corr = corr_matrix[target_col].sort_values(ascending=False)
print("\nCorrelation with excess_return:")
print(target_corr)

## 4. Return Distributions by Ticker

In [None]:
# Box plot of weekly excess returns (vs SPY) per ticker, with tickers sorted
# left-to-right by ascending median excess return.
order = (
    all_data.groupby('ticker')['excess_return']
    .median()
    .sort_values()
    .index.tolist()
)

fig, ax = plt.subplots(figsize=(14, 6))
sns.boxplot(data=all_data, x='ticker', y='excess_return', order=order, ax=ax)
ax.axhline(y=0, color='red', linestyle='--', alpha=0.5)  # zero-excess reference
ax.set_xlabel('Ticker')
ax.set_ylabel('Excess Return vs SPY')
ax.set_title('Distribution of Weekly Excess Returns by Ticker')
plt.xticks(rotation=45, ha='right')  # long ticker labels would otherwise overlap

plt.tight_layout()
plt.savefig('../outputs/backtest/returns_by_ticker.png', dpi=150)
plt.show()

## 5. Feature Statistics by Split

In [None]:
# Compare feature and target statistics across the train/val/test splits,
# e.g. to spot distribution shift between them. One row per split; columns
# are N plus '<col>_mean' / '<col>_std' for every feature and the target.
splits_summary = []

for split_name, split_data in [('Train', splits.train), ('Val', splits.val), ('Test', splits.test)]:
    row = {'Split': split_name, 'N': len(split_data)}
    for col in feature_cols + ['excess_return']:
        row[f'{col}_mean'] = split_data[col].mean()
        row[f'{col}_std'] = split_data[col].std()
    splits_summary.append(row)

splits_df = pd.DataFrame(splits_summary).set_index('Split')

# Select by suffix, not substring: with `'_mean' in c`, a feature whose name
# merely contains "_mean" (e.g. "dist_from_mean") would drag its "_std"
# column into the means table.
mean_cols = [c for c in splits_df.columns if c.endswith('_mean')]
std_cols = [c for c in splits_df.columns if c.endswith('_std')]

print("Feature Means by Split:")
display(splits_df[['N'] + mean_cols].round(4))

# The standard deviations are computed above too; show them rather than discard them.
print("Feature Std Devs by Split:")
display(splits_df[['N'] + std_cols].round(4))

## 6. Temporal Analysis

In [None]:
# Cross-sectional mean (top panel) and dispersion (bottom panel) of excess
# returns for each date, each overlaid with a 20-period rolling average and
# vertical markers at the train/val and val/test boundaries.
weekly_stats = all_data.groupby('date')['excess_return'].agg(['mean', 'std'])
dates = weekly_stats.index

# Last date of the train and validation splits.
train_end = splits.train['date'].max()
val_end = splits.val['date'].max()

fig, (ax_mean, ax_std) = plt.subplots(2, 1, figsize=(14, 8))

# Top panel: average excess return across tickers, in percent.
ax_mean.plot(dates, weekly_stats['mean'] * 100, alpha=0.5, label='Weekly mean')
ax_mean.plot(dates, weekly_stats['mean'].rolling(20).mean() * 100,
             color='red', linewidth=2, label='20-week MA')
ax_mean.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
ax_mean.axvline(x=train_end, color='green', linestyle='--', alpha=0.7, label='Train/Val split')
ax_mean.axvline(x=val_end, color='blue', linestyle='--', alpha=0.7, label='Val/Test split')
ax_mean.set_xlabel('Date')
ax_mean.set_ylabel('Mean Excess Return (%)')
ax_mean.set_title('Cross-Sectional Mean Excess Return Over Time')
ax_mean.legend(loc='upper left')

# Bottom panel: spread (std) of excess returns across tickers, in percent.
ax_std.plot(dates, weekly_stats['std'] * 100, alpha=0.5, label='Weekly std')
ax_std.plot(dates, weekly_stats['std'].rolling(20).mean() * 100,
            color='red', linewidth=2, label='20-week MA')
for boundary, line_colour in ((train_end, 'green'), (val_end, 'blue')):
    ax_std.axvline(x=boundary, color=line_colour, linestyle='--', alpha=0.7)
ax_std.set_xlabel('Date')
ax_std.set_ylabel('Std of Excess Return (%)')
ax_std.set_title('Cross-Sectional Return Dispersion Over Time')
ax_std.legend(loc='upper left')

plt.tight_layout()
plt.savefig('../outputs/backtest/temporal_analysis.png', dpi=150)
plt.show()

## 7. Price Performance Overview

In [None]:
# Normalized price performance: every ETF's price series rescaled so that its
# first available price equals 1, with SPY drawn prominently on top.
fig, ax = plt.subplots(figsize=(14, 8))

# Divide each column by its own first non-missing price. Dividing by
# prices.iloc[0] turns an entire column into NaN when that ticker's history
# starts after the first row; bfill().iloc[0] yields the first valid value
# per column (identical to iloc[0] for columns with a valid first row).
normalized = prices / prices.bfill().iloc[0]

# Plot SPY prominently (only labelled series, drawn above the others).
ax.plot(normalized.index, normalized['SPY'], color='black', linewidth=2, label='SPY', zorder=10)

# Plot the remaining tickers with transparency, unlabelled.
for ticker in normalized.columns:
    if ticker != 'SPY':
        ax.plot(normalized.index, normalized[ticker], alpha=0.4, linewidth=1)

ax.axhline(y=1, color='gray', linestyle='--', alpha=0.5)
ax.set_xlabel('Date')
ax.set_ylabel('Normalized Price (start = 1)')

# Derive the year span from the data instead of hard-coding it, so the title
# stays correct when the price window changes.
if isinstance(prices.index, pd.DatetimeIndex) and len(prices.index) > 0:
    title_span = f" ({prices.index.min():%Y}-{prices.index.max():%Y})"
else:
    title_span = ""
ax.set_title(f'ETF Price Performance{title_span}')
ax.legend()

plt.tight_layout()
plt.savefig('../outputs/backtest/price_performance.png', dpi=150)
plt.show()

## 8. Summary Statistics

In [None]:
# Text summary of the combined dataset: date coverage, sample counts, split
# proportions, and basic statistics for the target and each feature.
banner = "=" * 60
n_total = len(all_data)
target = all_data['excess_return']

print(banner)
print("DATA SUMMARY")
print(banner)

print(f"\nDate Range: {all_data['date'].min()} to {all_data['date'].max()}")
print(f"Total Samples: {n_total:,}")
print(f"Unique Weeks: {all_data['date'].nunique()}")
print(f"Tickers: {all_data['ticker'].nunique()} ETFs")

# Labels are left-padded to a common width so the counts line up.
print("\nSplit Sizes:")
for split_label, split_frame in (('Train', splits.train), ('Val', splits.val), ('Test', splits.test)):
    n_split = len(split_frame)
    print(f"  {split_label + ':':<6} {n_split:,} ({n_split / n_total * 100:.1f}%)")

# Target statistics reported in percent.
print("\nTarget (excess_return) Statistics:")
target_stats = (
    ('Mean', target.mean()),
    ('Std', target.std()),
    ('Min', target.min()),
    ('Max', target.max()),
)
for stat_label, stat_value in target_stats:
    print(f"  {stat_label + ':':<5} {stat_value * 100:.4f}%")

print("\nFeature Descriptions:")
for col in feature_cols:
    series = all_data[col]
    print(f"  {col}: mean={series.mean():.4f}, std={series.std():.4f}")