# Multi-Head Diversity: H Scaling & FT-Transformer Baseline

Experiments:
1. **H scaling** (clean data only): Multi-Head Diversity for H ∈ {1, 2, 4, 8, 12} (H = 16 is omitted because d_model = 72 is not divisible by 16)
2. **FT-Transformer baseline**: Modern tabular transformer for comparison

When `USE_SMALL_DATA=True`, a random subset of at most `N_SAMPLES` training rows is used for faster runs (the validation set is kept in full).

In [None]:
USE_SMALL_DATA = False  # True = random training subset for quick runs; False = full data (matches paper)
N_SAMPLES = 5000       # Max training observations kept when USE_SMALL_DATA=True
N_EPOCHS = 100         # Epochs used only when USE_SMALL_DATA=True (e.g. 30 for quick runs); full-data runs always train for 100

In [None]:
# Mount Google Drive first (required when running in Colab).
# Outside Colab the import fails and this cell is a no-op. Any other failure
# (e.g. authorization cancelled) is reported instead of silently swallowed;
# mounting stays best-effort because the repo-root search below also checks
# non-Drive locations.
try:
    from google.colab import drive
except ImportError:
    drive = None  # not running in Colab

if drive is not None:
    try:
        drive.mount('/content/drive', force_remount=False)
        print('Google Drive mounted.')
    except Exception as exc:
        print(f'Google Drive mount failed: {exc}')

In [None]:
import sys
import os
from pathlib import Path
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import random
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error, r2_score
import matplotlib.pyplot as plt
import warnings
# Silence all Python warnings globally to keep notebook output compact.
warnings.filterwarnings('ignore')

# Seed each RNG the notebook relies on (stdlib, NumPy, PyTorch) so runs repeat.
RANDOM_SEED = 42
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(RANDOM_SEED)
del _seed_rng

def _find_repo_root():
    cwd = Path.cwd().resolve()
    # Colab: explicit paths (run Drive mount cell first)
    candidates = [
        Path('/content/drive/MyDrive/multihead-attention-robustness'),
        Path('/content/drive/My Drive/multihead-attention-robustness'),
        Path('/content/repo_run'),
    ]
    for p in candidates:
        if p.exists() and (p / 'src').exists():
            return p
    # Colab: search under Drive for any folder containing 'multihead-attention' with src
    drive_root = Path('/content/drive')
    if drive_root.exists():
        for base in [drive_root / 'MyDrive', drive_root / 'My Drive', drive_root]:
            if base.exists():
                for sub in base.iterdir():
                    if sub.is_dir() and 'multihead-attention' in sub.name.lower() and (sub / 'src').exists():
                        return sub
    # Local: walk up from cwd
    p = cwd
    for _ in range(10):
        if (p / 'src').exists():
            return p
        if p.parent == p:
            break
        p = p.parent
    # Fallback: if we're in notebooks/, parent is repo root
    if cwd.name == 'notebooks' and (cwd.parent / 'src').exists():
        return cwd.parent
    return cwd.parent if cwd.name == 'notebooks' else cwd

# Resolve the project root and refuse to continue without its 'src' package.
repo_root = _find_repo_root()
if not repo_root.joinpath('src').exists():
    raise FileNotFoundError(
        f"Could not find repo root with 'src' folder. repo_root={repo_root}\n"
        "In Colab: run the Drive mount cell above first, then ensure your project folder "
        "(multihead-attention-robustness) is in My Drive."
    )

# Make the repo importable (highest priority on sys.path) and the working
# directory, so that relative paths used later resolve against it.
sys.path[:0] = [str(repo_root)]
os.chdir(str(repo_root))
from src.models.feature_token_transformer import FeatureTokenTransformer

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Device: {device}, Repo: {repo_root}')

## 1. Load Data

In [None]:
# Prefer the cross-sectional copy of the master table; otherwise use data/.
data_path = repo_root.joinpath('data', 'cross_sectional', 'master_table.csv')
data_path = data_path if data_path.exists() else repo_root.joinpath('data', 'master_table.csv')
df = pd.read_csv(data_path)
if 'date' in df.columns:
    # Parse dates and promote them to the index used for date-window slicing.
    df = df.assign(date=pd.to_datetime(df['date'])).set_index('date')

class CrossSectionalDataSplitter:
    """Split a date-indexed master table into train/val windows and extract X / y.

    Windows are inclusive date ranges applied to the table's (datetime) index.
    """

    # Identifier / size columns that are never used as features or targets.
    EXCLUDE_COLS = ('mktcap', 'market_cap', 'date', 'year', 'month', 'ticker', 'permno', 'gvkey')
    # Candidate target names, in priority order (all lowercase).
    TARGET_COLS = ('return', 'returns', 'ret', 'target', 'y', 'next_return', 'forward_return',
                   'ret_1', 'ret_1m', 'ret_12m', 'future_return', 'returns_1d')

    def __init__(self, train_start='2005-01-01', train_end='2017-12-31', val_start='2018-01-01', val_end='2019-12-31'):
        self.train_start, self.train_end = train_start, train_end
        self.val_start, self.val_end = val_start, val_end

    def split(self, master_table):
        """Return ``{'train': ..., 'val': ...}`` row subsets of ``master_table`` by date.

        The input is copied and its index converted with ``pd.to_datetime``.
        """
        master_table = master_table.copy()
        master_table.index = pd.to_datetime(master_table.index)
        if not master_table.index.is_monotonic_increasing:
            # Label slicing on an unsorted DatetimeIndex fails (pandas >= 2.0) for
            # bounds that are not present in the index. A stable sort keeps the
            # original row order within each date.
            master_table = master_table.sort_index(kind='mergesort')
        return {
            'train': master_table.loc[self.train_start:self.train_end],
            'val': master_table.loc[self.val_start:self.val_end],
        }

    def _find_target_column(self, columns):
        """Return the target column among ``columns`` (or None if no name matches).

        Exact (case-insensitive) matches against TARGET_COLS win first, in
        priority order; only then is a substring match tried. This stops a
        feature such as ``returns_lag`` shadowing the real ``next_return``
        label. Single-character names (``'y'``) are matched exactly only,
        because as substrings they hit arbitrary columns like ``volatility``.
        """
        excluded = {ec.lower() for ec in self.EXCLUDE_COLS}
        candidates = [(col, str(col).lower()) for col in columns if str(col).lower() not in excluded]
        for tc in self.TARGET_COLS:
            for col, low in candidates:
                if low == tc:
                    return col
        for tc in self.TARGET_COLS:
            if len(tc) < 2:
                continue
            for col, low in candidates:
                if tc in low:
                    return col
        return None

    def prepare_features_labels(self, data):
        """Split ``data`` into a numeric feature frame and a target series.

        Only numeric columns are considered; EXCLUDE_COLS never become features
        unless nothing else is left. Returns empty objects for empty input.
        """
        if data.empty:
            return pd.DataFrame(), pd.Series(dtype=float)
        numeric_data = data.select_dtypes(include=[np.number])
        if numeric_data.empty:
            return pd.DataFrame(), pd.Series(dtype=float)
        excluded = {ec.lower() for ec in self.EXCLUDE_COLS}

        target_col = self._find_target_column(numeric_data.columns)
        if target_col is None:
            potential = [c for c in numeric_data.columns if str(c).lower() not in excluded]
            # NOTE(review): falls back to the second-to-last usable column; intent unclear.
            target_col = potential[-2] if len(potential) > 1 else (potential[-1] if potential else numeric_data.columns[-1])

        feature_cols = [c for c in numeric_data.columns if c != target_col and str(c).lower() not in excluded]
        if not feature_cols:
            feature_cols = [c for c in numeric_data.columns if c != target_col]
        if not feature_cols:
            feature_cols = numeric_data.columns[:-1].tolist()
            target_col = numeric_data.columns[-1]
        return numeric_data[feature_cols], numeric_data[target_col]

splitter = CrossSectionalDataSplitter()
data_splits = splitter.split(df)
train_df, val_df = data_splits['train'], data_splits['val']
X_train_df, y_train = splitter.prepare_features_labels(train_df)
X_val_df, y_val = splitter.prepare_features_labels(val_df)

# Target/feature detection runs separately on each split; make sure both
# resolved to the same schema before one scaler and model are fit on them.
if y_train.name != y_val.name or list(X_train_df.columns) != list(X_val_df.columns):
    raise ValueError(
        f'Train/val schema mismatch: target {y_train.name!r} vs {y_val.name!r}, '
        f'features {list(X_train_df.columns)} vs {list(X_val_df.columns)}'
    )

# Drop rows whose label is missing rather than imputing 0: a fabricated zero
# target would bias training and the validation RMSE / R². Missing feature
# values are still imputed with 0 (before scaling).
_train_keep = y_train.notna().to_numpy()
_val_keep = y_val.notna().to_numpy()
if not (_train_keep.all() and _val_keep.all()):
    print(f'Dropping rows with missing target: train {(~_train_keep).sum()}, val {(~_val_keep).sum()}')
X_train = X_train_df.loc[_train_keep].fillna(0).values.astype(np.float32)
y_train = y_train.loc[_train_keep].values.astype(np.float32)
X_val = X_val_df.loc[_val_keep].fillna(0).values.astype(np.float32)
y_val = y_val.loc[_val_keep].values.astype(np.float32)

if USE_SMALL_DATA and N_SAMPLES < len(X_train):
    # Only the training set is subsampled; validation stays complete.
    idx = np.random.RandomState(RANDOM_SEED).choice(len(X_train), N_SAMPLES, replace=False)
    X_train, y_train = X_train[idx], y_train[idx]
    print(f'Using subset: {N_SAMPLES} train samples')

# Fit the scaler on training data only to avoid leaking validation statistics.
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_val_scaled = scaler.transform(X_val)
print(f'Data: train {X_train_scaled.shape[0]}, val {X_val_scaled.shape[0]}, features {X_train_scaled.shape[1]}')

## 2. Training Function (with Diversity Loss)

In [None]:
def _predict_in_batches(model, X_t, batch_size):
    """Run ``model`` over ``X_t`` in chunks under no_grad and return flat 1-D predictions.

    NOTE(review): assumes the model emits one output per row, i.e. shape (B,) or (B, 1).
    """
    chunks = []
    with torch.no_grad():
        for start in range(0, len(X_t), batch_size):
            out, _ = model(X_t[start:start + batch_size])
            chunks.append(out.reshape(-1))
    if not chunks:
        return X_t.new_empty(0)
    return torch.cat(chunks)


def train_mhd_transformer(model, model_name, X_train, y_train, X_val, y_val, config, device='cpu'):
    """Train Multi-Head Diversity transformer (includes diversity loss).

    Optimises MSE (plus ``model.compute_diversity_loss`` when
    ``model.use_head_diversity`` is set) with Adam, gradient clipping,
    ReduceLROnPlateau and early stopping on validation MSE.

    Args:
        model: module exposing ``num_features`` and ``use_head_diversity``;
            ``model(x)`` returns ``(pred, attn_dict)`` where ``attn_dict`` maps
            ``'layer_{j}'`` to that layer's attention.
        model_name: label used in progress messages.
        X_train, y_train, X_val, y_val: numpy arrays. Feature columns beyond
            ``model.num_features`` are dropped.
        config: dict with ``learning_rate``, ``batch_size``, ``epochs`` and
            ``patience``; optional ``eval_batch_size`` (default 4096).
        device: torch device (or name) to train on.

    Returns:
        ``(model, val_predictions, train_losses, val_losses)``. The model is
        restored to the weights of the epoch with the lowest validation loss,
        and ``val_predictions`` is a 1-D numpy array produced by those weights.
    """
    model = model.to(device)
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=config['learning_rate'])
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

    # Slicing to num_features is a no-op when the widths already agree.
    num_features = model.num_features
    X_train_t = torch.tensor(X_train[:, :num_features], dtype=torch.float32, device=device)
    X_val_t = torch.tensor(X_val[:, :num_features], dtype=torch.float32, device=device)
    y_train_t = torch.tensor(y_train, dtype=torch.float32, device=device).reshape(-1)
    y_val_t = torch.tensor(y_val, dtype=torch.float32, device=device).reshape(-1)

    batch_size = config['batch_size']
    eval_batch_size = config.get('eval_batch_size', 4096)
    epochs = config['epochs']
    n_train = len(X_train_t)
    n_batches = max(1, (n_train + batch_size - 1) // batch_size)

    best_val_loss = float('inf')
    best_state = None
    patience_counter = 0
    train_losses, val_losses = [], []

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0.0
        # Reshuffle each epoch: fixed-order mini-batches over date-ordered rows
        # would each cover only a narrow time window.
        perm = torch.randperm(n_train, device=X_train_t.device)
        for start in range(0, n_train, batch_size):
            batch_idx = perm[start:start + batch_size]
            batch_X, batch_y = X_train_t[batch_idx], y_train_t[batch_idx]
            optimizer.zero_grad()
            pred, attn_dict = model(batch_X)
            # reshape(-1) (not squeeze) keeps a size-1 final batch 1-D.
            loss = criterion(pred.reshape(-1), batch_y)
            if model.use_head_diversity and attn_dict:
                attn_list = [attn_dict[f'layer_{j}'] for j in range(len(attn_dict))]
                loss = loss + model.compute_diversity_loss(attn_list)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            epoch_loss += loss.item()
        epoch_loss /= n_batches

        model.eval()
        val_loss = criterion(_predict_in_batches(model, X_val_t, eval_batch_size), y_val_t).item()
        train_losses.append(epoch_loss)
        val_losses.append(val_loss)
        scheduler.step(val_loss)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # Snapshot the best weights so they can be restored after early stopping.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= config['patience']:
                print(f'  {model_name}: Early stop at epoch {epoch+1}')
                break
        if (epoch + 1) % 10 == 0:
            print(f'  {model_name} Epoch {epoch+1}/{epochs}: train={epoch_loss:.6f}, val={val_loss:.6f}')

    if best_state is not None:
        model.load_state_dict(best_state)
    model.eval()
    final_pred = _predict_in_batches(model, X_val_t, eval_batch_size).cpu().numpy()
    return model, final_pred, train_losses, val_losses

## 3. H Scaling Experiment (H ∈ {1, 2, 4, 8, 12})

In [None]:
H_VALUES = [1, 2, 4, 8, 12]  # 16 omitted: 72 not divisible by 16
D_MODEL = 72  # Must be divisible by all H (matches TRAINING_CONFIG)
# Explicit check rather than `assert`, which is stripped under `python -O`.
if any(D_MODEL % h for h in H_VALUES):
    raise ValueError('d_model must be divisible by all H')

tr_config = {
    'd_model': D_MODEL,
    'num_layers': 2,
    'd_ff': 512,
    'dropout': 0.1,
    'learning_rate': 0.0001,
    'batch_size': 32,
    'epochs': N_EPOCHS if USE_SMALL_DATA else 100,
    'patience': 20
}

h_scaling_results = []
for h in H_VALUES:
    print(f'\n{"="*60}\nTraining Multi-Head Diversity with H={h}\n{"="*60}')
    # Reset every RNG so each H starts from the same state. Otherwise weight
    # initialisation (and any other random draws during training) for a given
    # H would depend on how much randomness the previous H runs consumed,
    # confounding the comparison across H.
    random.seed(RANDOM_SEED)
    np.random.seed(RANDOM_SEED)
    torch.manual_seed(RANDOM_SEED)
    model = FeatureTokenTransformer(
        num_features=X_train_scaled.shape[1],
        d_model=D_MODEL,
        num_heads=h,
        num_layers=tr_config['num_layers'],
        d_ff=tr_config['d_ff'],
        dropout=tr_config['dropout'],
        use_head_diversity=True,
        diversity_weight=0.01
    )
    model, pred, tl, vl = train_mhd_transformer(
        model, f'MHD-H{h}', X_train_scaled, y_train, X_val_scaled, y_val, tr_config, device
    )
    rmse = np.sqrt(mean_squared_error(y_val, pred))
    r2 = r2_score(y_val, pred)
    h_scaling_results.append({'H': h, 'RMSE': rmse, 'R2': r2})
    print(f'  H={h}: RMSE={rmse:.6f}, R²={r2:.6f}')

h_df = pd.DataFrame(h_scaling_results)
print('\nH Scaling Summary:')
print(h_df.to_string(index=False))

## 4. FT-Transformer Baseline

In [None]:
try:
    from pytorch_tabular import TabularModel
    from pytorch_tabular.models import FTTransformerConfig
    from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
    HAS_PYTORCH_TABULAR = True
except ImportError:
    HAS_PYTORCH_TABULAR = False
    print('pytorch-tabular not installed. Run: pip install pytorch-tabular')

if HAS_PYTORCH_TABULAR:
    # Build DataFrames from the same scaled arrays used for the MHD runs.
    target_name = 'target'
    feature_names = [f'f{i}' for i in range(X_train_scaled.shape[1])]
    train_df_pt = pd.DataFrame(X_train_scaled, columns=feature_names)
    train_df_pt[target_name] = y_train
    val_df_pt = pd.DataFrame(X_val_scaled, columns=feature_names)
    val_df_pt[target_name] = y_val

    data_config = DataConfig(
        target=[target_name],
        continuous_cols=feature_names,
        categorical_cols=[],
        normalize_continuous_features=False,  # already scaled
    )
    trainer_config = TrainerConfig(
        batch_size=32,
        max_epochs=N_EPOCHS if USE_SMALL_DATA else 100,
        early_stopping='valid_loss',
        early_stopping_patience=20,
        accelerator='auto',
        # One device is valid for both CPU and GPU. devices=0 is rejected by
        # PyTorch Lightning when the CPU accelerator is selected.
        devices=1,
    )
    optimizer_config = OptimizerConfig()

    model_config = FTTransformerConfig(
        task='regression',
        learning_rate=1e-4,
        num_heads=8,
        num_attn_blocks=2,
        transformer_activation='GEGLU',  # GEGLU is default; 'gelu' causes KeyError in pytorch-tabular
        embedding_dropout=0.1,
        attn_dropout=0.1,
    )

    ft_model = TabularModel(
        data_config=data_config,
        model_config=model_config,
        optimizer_config=optimizer_config,
        trainer_config=trainer_config,
    )

    print('Training FT-Transformer...')
    ft_model.fit(train=train_df_pt, validation=val_df_pt)
    pred_df = ft_model.predict(val_df_pt)
    # pytorch-tabular names the regression output column '<target>_prediction'.
    # Fall back to 'prediction', then to the last column.
    pred_col = next(
        (c for c in (f'{target_name}_prediction', 'prediction') if c in pred_df.columns),
        pred_df.columns[-1],
    )
    pred_ft = pred_df[pred_col].to_numpy()
    ft_rmse = np.sqrt(mean_squared_error(y_val, pred_ft))
    ft_r2 = r2_score(y_val, pred_ft)
    print(f'FT-Transformer: RMSE={ft_rmse:.6f}, R²={ft_r2:.6f}')
else:
    ft_rmse, ft_r2 = np.nan, np.nan

## 5. Summary & Plots

In [None]:
def _decorate_h_axis(ax, ylabel, title):
    """Apply the shared H-axis labels, ticks, legend and grid to ``ax``."""
    ax.set_xlabel('Number of Heads (H)')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.set_xticks(H_VALUES)
    ax.legend()
    ax.grid(True, alpha=0.3)


# Left panel: R² vs H; right panel: RMSE vs H.
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
ax_r2, ax_rmse = axes

ax_r2.plot(h_df['H'], h_df['R2'], 'o-', label='MHD R²')
_decorate_h_axis(ax_r2, 'R²', 'H Scaling: Multi-Head Diversity (Clean)')

ax_rmse.plot(h_df['H'], h_df['RMSE'], 's-', color='orange', label='MHD RMSE')
_decorate_h_axis(ax_rmse, 'RMSE', 'H Scaling: RMSE')

plt.tight_layout()
plt.show()

# Text summary of the H sweep, plus the FT-Transformer baseline when it ran.
banner = '=' * 60
print('\n' + banner)
print('FINAL SUMMARY')
print(banner)
print(h_df.to_string(index=False))
if HAS_PYTORCH_TABULAR:
    print(f'\nFT-Transformer: RMSE={ft_rmse:.6f}, R²={ft_r2:.6f}')