In [60]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import logging
import wrds
import math
import gym
from gym import spaces
from torch.optim.optimizer import Optimizer

# To this:
import torch.optim as optim

# Configure logging
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# db = wrds.Connection()

In [95]:
# --------------------
# 1. Dataset Pipeline
# --------------------
class AlphaPortfolioData(Dataset):
    
    def __init__(self, start_date='2010-01-01', end_date='2019-12-31', lookback=12, G=2):
        super().__init__()
        self.lookback = lookback
        self.G = G  # Number of assets to long/short
        self.data = self._load_wrds_data(start_date, end_date)
        self.sequences, self.future_returns, self.masks = self._create_sequences()
        self._validate_data_shapes()

    def _load_wrds_data(self, start_date, end_date):
        
        # CRSP data
        crsp_query = f"""
        SELECT a.permno, a.date, a.ret, a.prc, a.shrout, 
            a.vol, a.cfacshr, a.altprc, a.retx
        FROM crsp.msf AS a
        WHERE a.date BETWEEN '{start_date}' AND '{end_date}'
        AND a.permno IN (
            SELECT permno FROM crsp.msenames 
            WHERE exchcd BETWEEN 1 AND 3  
                AND shrcd IN (10, 11)       
            )
        """
        crsp_data = db.raw_sql(crsp_query)

        query_ticker = """
            SELECT permno, namedt, nameenddt, ticker
            FROM crsp.stocknames
        """
        stocknames = db.raw_sql(query_ticker)
        crsp_data = crsp_data.merge(stocknames.drop_duplicates(subset=['permno']), on='permno', how='left')
        crsp_data = crsp_data.dropna(subset=['ticker'])

        crsp_data['mktcap'] = (crsp_data['prc'].abs() * crsp_data['shrout'] * 1000) / 1e6  # In millions
        crsp_data['year'] = pd.to_datetime(crsp_data['date']).dt.year
        crsp_data = crsp_data.dropna(subset=['mktcap'])
        mean_mktcap = crsp_data.groupby(['year', 'permno'])['mktcap'].mean().reset_index()
        sorted_mean_mktcap = mean_mktcap.sort_values(by=['year', 'mktcap'], ascending=[True, False])
        top_20_each_year = sorted_mean_mktcap.groupby('year').head(50)

        crsp_data = top_20_each_year.merge(crsp_data, on='permno', how='left')
        crsp_data = crsp_data.drop_duplicates(subset=['mktcap_y'])

        crsp_data = crsp_data[['permno', 'ticker', 'date', 'ret', 'prc', 'shrout', 'vol', 'mktcap_y', 'year_y']]
        crsp_data['date'] = pd.to_datetime(crsp_data['date'])
        crsp_data.sort_values(['permno', 'date'], inplace=True)

        # Query Compustat quarterly data with release dates (rdq)
        fund_query = f"""
            SELECT gvkey, datadate, rdq, saleq
            FROM comp.fundq
            WHERE indfmt = 'INDL' AND datafmt = 'STD' AND popsrc = 'D' AND consol = 'C'
            AND datadate BETWEEN '{start_date}' AND '{end_date}'
            AND rdq IS NOT NULL
        """
        fund = db.raw_sql(fund_query)
        fund['rdq'] = pd.to_datetime(fund['rdq'])
        fund['datadate'] = pd.to_datetime(fund['datadate'])

        # Link Compustat GVKEY to CRSP PERMNO
        link_query = """
            SELECT lpermno AS permno, gvkey, linkdt, linkenddt
            FROM crsp.ccmxpf_linktable
            WHERE linktype IN ('LU', 'LC') AND linkprim IN ('P', 'C')
        """
        link = db.raw_sql(link_query)
        fund = pd.merge(fund, link, on='gvkey', how='left')
        fund = fund.dropna(subset=['permno'])

        # Sort both datasets by date
        crsp_sorted = crsp_data.sort_values('date')
        fund_sorted = fund.sort_values('rdq')
        fund_sorted['permno'] = fund_sorted['permno'].astype(int)
        # Merge fundamentals to CRSP using rdq
        merged = pd.merge_asof(
            crsp_sorted,
            fund_sorted,
            left_on='date',
            right_on='rdq',
            by='permno',
            direction='backward'  # Take the first CRSP date >= rdq
        )
        merged = merged.dropna(subset=['rdq', 'ticker'])
        merged = merged.sort_values(by='date')
        merged = merged[['permno', 'ticker', 'date', 'ret', 'prc','vol', 'mktcap_y', 'gvkey', 'rdq', 'saleq']]
        merged = merged.ffill()
        
        unique_dates = merged['date'].unique()
        date_mapping = {date: i for i, date in enumerate(sorted(unique_dates))}
        merged['date_mapped'] = merged['date'].map(date_mapping)
        
        return merged

    def _create_sequences(self):
        data = self.data
        lookback = self.lookback
        unique_dates = pd.to_datetime(data['date'].unique())
        unique_assets = data['permno'].unique()
        
        sequences = []
        future_returns = []
        masks = []
        min_assets = 2 * self.G
        batch_info = []

        # First pass: collect valid batches
        for date_idx in range(len(unique_dates) - 2 * lookback):
            hist_start = unique_dates[date_idx]
            hist_end = unique_dates[date_idx + lookback - 1]
            future_start = unique_dates[date_idx + lookback]
            future_end = unique_dates[date_idx + 2 * lookback - 1]

            batch_assets = []
            hist_features = []
            fwd_returns = []
            
            for asset in unique_assets:
                asset_hist = data[
                    (data['permno'] == asset) & 
                    (data['date'].between(hist_start, hist_end))
                ].sort_values('date')
                
                asset_future = data[
                    (data['permno'] == asset) & 
                    (data['date'].between(future_start, future_end))
                ]['ret'].values
                
                if len(asset_hist) == lookback and len(asset_future) == lookback:
                    features = asset_hist[['ret', 'prc', 'vol', 'mktcap_y', 'saleq']].values
                    hist_features.append(features)
                    fwd_returns.append(asset_future)
                    batch_assets.append(asset)

            if len(hist_features) >= min_assets:
                batch_info.append({
                    'features': np.stack(hist_features),
                    'returns': np.stack(fwd_returns),
                    'num_assets': len(hist_features)
                })

        # Find global max assets across all valid batches
        if not batch_info:
            return torch.empty(0), torch.empty(0), torch.empty(0)
        
        global_max_assets = max(b['num_assets'] for b in batch_info)
        features_dim = batch_info[0]['features'].shape[-1]

        # Second pass: pad to global max
        for batch in batch_info:
            num_assets = batch['num_assets']
            
            # Features: (assets, lookback, features)
            padded_features = np.zeros((global_max_assets, lookback, features_dim))
            padded_features[:num_assets] = batch['features']
            
            # Returns: (assets, lookback)
            padded_returns = np.zeros((global_max_assets, lookback))  # Fix 1: 2D padding
            padded_returns[:num_assets] = batch['returns']
            
            # Mask: (assets,)
            mask = np.zeros(global_max_assets, dtype=bool)
            mask[:num_assets] = True

            sequences.append(padded_features)
            future_returns.append(padded_returns)
            masks.append(mask)

        return (
            torch.as_tensor(np.array(sequences), dtype=torch.float32),  # (time, assets, lookback, features)
            torch.as_tensor(np.array(future_returns), dtype=torch.float32),  # (time, assets, lookback)
            torch.as_tensor(np.array(masks), dtype=torch.bool)  # (time, assets)
        )

    def _validate_data_shapes(self):
        assert self.sequences.dim() == 4, \
            f"Sequences should be 4D (time, assets, lookback, features). Got {self.sequences.shape}"
        assert self.future_returns.dim() == 3, \
            f"Future returns should be 3D (time, assets, lookback). Got {self.future_returns.shape}"
        assert self.masks.dim() == 2, \
            f"Masks should be 2D (time, assets). Got {self.masks.shape}"

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        return (
        self.sequences[idx],  # (assets, lookback, features)
        self.future_returns[idx],  # (assets, lookback)
        self.masks[idx]         # (assets,)
    )

class AssetTransformer(nn.Module):
    def __init__(self, input_dim, d_model=64, nhead=4, num_layers=2):
        super().__init__()
        self.embedding = nn.Linear(input_dim, d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward=4*d_model)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)

    def forward(self, x):
        batch_size, num_assets, lookback, features = x.shape
        x = x.view(-1, lookback, features)  # Reshape to (batch*assets, lookback, features)
        x = self.embedding(x)  # Project to d_model
        x = x.permute(1, 0, 2)  # (lookback, batch*assets, d_model)
        x = self.transformer(x)  # Apply transformer
        x = x.mean(dim=0).view(batch_size, num_assets, -1)  # Global average pooling
        return x

class CrossAssetAttention(nn.Module):
    def __init__(self, d_model=64, d_k=32, d_v=32):
        super().__init__()
        self.WQ = nn.Linear(d_model, d_k)
        self.WK = nn.Linear(d_model, d_k)
        self.WV = nn.Linear(d_model, d_v)
        self.scale = 1 / math.sqrt(d_k)
        self.score_layer = nn.Sequential(
            nn.Linear(d_v, 1),
            nn.Tanh()
        )

    def forward(self, x, mask=None):
        # Print input shapes for debugging
        # print("CrossAssetAttention Input x shape:", x.shape)
        # print("CrossAssetAttention Input mask shape:", mask.shape)

        # Ensure x is 3D (batch, assets, d_model)
        if x.dim() == 4:
            x = x.squeeze(0)  # Remove extra batch dimension if present
        
        if x.dim() != 3:
            raise ValueError(f"Expected 3D input, got {x.dim()}D tensor. Shape: {x.shape}")

        # Compute query, key, value
        Q = self.WQ(x)  # (batch, assets, d_k)
        K = self.WK(x)  # (batch, assets, d_k)
        V = self.WV(x)  # (batch, assets, d_v)
        
        # Attention computation
        attn = torch.matmul(Q, K.transpose(-2, -1)) * self.scale  # (batch, assets, assets)
        
        # Apply mask if provided
        if mask is not None:
            # Ensure mask is 2D
            if mask.dim() == 3:
                mask = mask.squeeze(0)
            
            if mask.dim() != 2:
                raise ValueError(f"Mask must be 2D, got {mask.dim()}D tensor")
            
            attn = attn.masked_fill(~mask.unsqueeze(1), float('-inf'))
        
        # Softmax and value aggregation
        attn = F.softmax(attn, dim=-1)
        attn_out = torch.matmul(attn, V)  # (batch, assets, d_v)
        
        # Score computation
        scores = self.score_layer(attn_out).squeeze(-1)  # (batch, assets)
        
        # print("CrossAssetAttention Output scores shape:", scores.shape)
        
        return scores

class PortfolioGenerator(nn.Module):
    def __init__(self, G=5):
        super().__init__()
        self.G = G
        self.temperature = nn.Parameter(torch.tensor(1.0))
        
    def forward(self, scores, mask):
        # print("PortfolioGenerator Input scores shape:", scores.shape)
        # print("PortfolioGenerator Input mask shape:", mask.shape)

        # Ensure scores is 2D
        if scores.dim() > 2:
            scores = scores.squeeze()
        
        if scores.dim() != 2:
            raise ValueError(f"Scores must be 2D, got {scores.dim()}D tensor. Shape: {scores.shape}")
        
        # Ensure mask is 2D
        if mask.dim() > 2:
            mask = mask.squeeze()
        
        if mask.dim() != 2:
            raise ValueError(f"Mask must be 2D, got {mask.dim()}D tensor")
        
        # Validate shapes
        batch_size, num_assets = scores.shape
        weights = torch.zeros_like(scores)
        
        for i in range(batch_size):
            # More robust validity check
            valid_assets = mask[i].sum().item()
            
            if valid_assets < 2*self.G:
                print(f"Warning: Batch {i} has insufficient valid assets. Skipping.")
                continue
            
            # Identify top and bottom G assets
            top_indices = torch.topk(scores[i], min(self.G, valid_assets)).indices
            bottom_indices = torch.topk(-scores[i], min(self.G, valid_assets)).indices
            
            # Compute weights
            top_scores = scores[i, top_indices]
            bottom_scores = scores[i, bottom_indices]
            
            top_weights = F.softmax(top_scores / self.temperature.clamp(min=0.1), dim=-1)
            bottom_weights = F.softmax(-bottom_scores / self.temperature.clamp(min=0.1), dim=-1)
            
            weights[i, top_indices] = top_weights
            weights[i, bottom_indices] = -bottom_weights
        
        # print("PortfolioGenerator Output weights shape:", weights.shape)
        return weights.nan_to_num(0.0)

class PolicyNetwork(nn.Module):
    def __init__(self, transformer, attention, portfolio_gen, d_model=64):
        super().__init__()
        self.transformer = transformer
        self.attention = attention
        self.portfolio_gen = portfolio_gen
        self.log_std = nn.Parameter(torch.zeros(1))
        
        self.value_head = nn.Sequential(
            nn.Linear(d_model, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

    def forward(self, x, mask):
        # Print input shapes for debugging
        # print("PolicyNetwork Input x shape:", x.shape)
        # print("PolicyNetwork Input mask shape:", mask.shape)

        # Ensure correct input dimensions
        if x.dim() == 5:  # Extra batch dimension
            x = x.squeeze(0)
        
        # if x.dim() != 4:
            # raise ValueError(f"Expected 4D input tensor, got {x.dim()}D. Shape: {x.shape}")
        
        # Process through transformer
        x = self.transformer(x)  # (batch, assets, d_model)
        
        # Ensure mask is consistent
        if mask.dim() == 3:
            mask = mask.squeeze(0)
        
        # if mask.dim() != 2:
            # raise ValueError(f"Expected 2D mask, got {mask.dim()}D. Shape: {mask.shape}")
        
        # Cross-asset attention scoring
        scores = self.attention(x, mask)  # (batch, assets)
        
        # Portfolio generation
        weights = self.portfolio_gen(scores, mask)  # (batch, assets)
        
        # Policy distribution
        std = torch.exp(self.log_std).expand_as(weights)
        dist = torch.distributions.Normal(weights, std)
        
        # Value estimation
        values = self.value_head(x.mean(dim=1))
        
        return dist, values.squeeze(-1)

class PortfolioEnv(gym.Env):
    def __init__(self, sequences, future_returns, masks, G=5):
        super().__init__()
        self.sequences = sequences
        self.future_returns = future_returns
        self.masks = masks
        self.G = G
        
        self.observation_space = gym.spaces.Box(-np.inf, np.inf, shape=sequences[0].shape)
        self.action_space = gym.spaces.Box(-1, 1, shape=(sequences.shape[1],))

    def step(self, action):
        if np.random.random() < 0.01:  # Random sampling mechanism
            idx = np.random.randint(len(self.sequences))
            
            # Get future returns and mask
            future_returns = self.future_returns[idx]
            mask = self.masks[idx]
            
            # Ensure action is a tensor
            action = torch.as_tensor(action, dtype=torch.float32)
            
            # Filter valid assets
            valid = mask & (action != 0)
            if valid.sum().item() < 2*self.G:
                return None, -1.0, True, {}
            
            # Selected weights and returns
            selected_weights = action[valid]
            selected_returns = future_returns[valid]
            
            # Portfolio return calculation
            portfolio_returns = (selected_weights.unsqueeze(-1) * selected_returns).sum(dim=0)
            
            # Sharpe Ratio calculation
            sharpe = portfolio_returns.mean() / (portfolio_returns.std() + 1e-6)
            
            print(sharpe)
            
            return None, sharpe.item(), True, {}
        
        return None, 0.0, True, {}

class PPOTrainer:
    def __init__(self, policy, env, lr=1e-4, gamma=0.99, clip=0.2):
        self.policy = policy
        self.env = env
        self.optimizer = optim.Adam(policy.parameters(), lr=lr)
        self.gamma = gamma
        self.clip = clip
    
    def collect_trajectory(self):
        states, masks, actions, log_probs, rewards = [], [], [], [], []
        
        # Determine the batch size dynamically
        batch_size = len(self.env.sequences)
        
        for _ in range(batch_size):
            idx = np.random.randint(len(self.env.sequences))
            state = self.env.sequences[idx]
            mask = self.env.masks[idx]
            
            # Ensure consistent dimensions
            if state.dim() == 3:
                state = state.unsqueeze(0)  # Add batch dimension
            elif state.dim() == 4:
                state = state.squeeze(0)  # Remove extra dimension if present
            
            if mask.dim() == 1:
                mask = mask.unsqueeze(0)  # Add batch dimension
            elif mask.dim() == 2:
                mask = mask.squeeze(0)  # Remove extra dimension if present
            
            with torch.no_grad():
                try:
                    # Ensure state and mask have consistent batch dimension
                    dist, values = self.policy(state, mask)
                    action = dist.sample()
                    log_prob = dist.log_prob(action)
                    
                    _, reward, _, _ = self.env.step(action.squeeze())
                    
                    states.append(state.squeeze(0))
                    masks.append(mask.squeeze(0))
                    actions.append(action.squeeze(0))
                    log_probs.append(log_prob.squeeze(0))
                    rewards.append(reward)
                
                except Exception as e:
                    print(f"Error processing trajectory: {e}")
                    # print("State shape:", state.shape)
                    # print("Mask shape:", mask.shape)
                    raise
        
        # Ensure all tensors have consistent shapes
        states = torch.stack(states)
        masks = torch.stack(masks)
        actions = torch.stack(actions)
        log_probs = torch.stack(log_probs)
        rewards = torch.tensor(rewards)
        
        return states, masks, actions, log_probs, rewards

    def train_epoch(self):
        states, masks, actions, old_log_probs, rewards = self.collect_trajectory()
        
        # Ensure rewards have the correct shape for broadcasting
        rewards = rewards.view(-1, 1)
        
        # Normalize rewards
        rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-6)
        
        for _ in range(3):  # Multiple PPO update steps
            self.optimizer.zero_grad()
            
            # Get new policy outputs
            dist, values = self.policy(states, masks)
            new_log_probs = dist.log_prob(actions)
            entropy = dist.entropy().mean()
            
            # Ensure consistent shapes
            values = values.view(-1, 1)
            
            # Policy loss computation
            ratio = torch.exp(new_log_probs - old_log_probs.detach())
            surr1 = ratio * rewards
            surr2 = torch.clamp(ratio, 1-self.clip, 1+self.clip) * rewards
            policy_loss = -torch.min(surr1, surr2).mean()
            
            # Value loss
            value_loss = F.mse_loss(values, rewards)
            
            # Total loss
            loss = policy_loss + 0.5*value_loss - 0.01*entropy
            print(loss)
            loss.backward()
            
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(self.policy.parameters(), 1.0)
            self.optimizer.step()
        
        return loss.item()
    

In [None]:
dataset = AlphaPortfolioData()

In [91]:
# dataset = AlphaPortfolioData()
time_steps, num_assets, lookback, features = dataset.sequences.shape

In [97]:
d_model = 128
G = 4  # Top and bottom K stocks
epochs = 30

# Load your preprocessed data from AlphaPortfolioData


# Network Components
transformer = AssetTransformer(input_dim=features, d_model=d_model)
attention = CrossAssetAttention(d_model=d_model)
portfolio_gen = PortfolioGenerator(G=G)

# Policy Network
policy = PolicyNetwork(transformer, attention, portfolio_gen, d_model=d_model)

# Portfolio Environment
env = PortfolioEnv(
    dataset.sequences, 
    dataset.future_returns, 
    dataset.masks, 
    G=G
)

# PPO Trainer
trainer = PPOTrainer(policy, env)

# Training Loop
for epoch in range(epochs):
    loss = trainer.train_epoch()
    print(f"Epoch {epoch+1}/{epochs} | Loss: {loss:.4f}")

tensor(0.0819)
tensor(0.5421, grad_fn=<SubBackward0>)
tensor(0.5010, grad_fn=<SubBackward0>)
tensor(0.4824, grad_fn=<SubBackward0>)
Epoch 1/30 | Loss: 0.4824
tensor(-0.0379)
tensor(0.4852, grad_fn=<SubBackward0>)
tensor(0.4921, grad_fn=<SubBackward0>)
tensor(0.4914, grad_fn=<SubBackward0>)
Epoch 2/30 | Loss: 0.4914
tensor(-0.1996)
tensor(0.4877, grad_fn=<SubBackward0>)
tensor(0.4818, grad_fn=<SubBackward0>)
tensor(0.4805, grad_fn=<SubBackward0>)
Epoch 3/30 | Loss: 0.4805
tensor(-0.5251)
tensor(0.4823, grad_fn=<SubBackward0>)
tensor(0.4834, grad_fn=<SubBackward0>)
tensor(0.4848, grad_fn=<SubBackward0>)
Epoch 4/30 | Loss: 0.4848
tensor(0.0603)
tensor(0.4852, grad_fn=<SubBackward0>)
tensor(0.4838, grad_fn=<SubBackward0>)
tensor(0.4820, grad_fn=<SubBackward0>)
Epoch 5/30 | Loss: 0.4820
tensor(-0.6513)
tensor(-0.0492)
tensor(0.2825)
tensor(0.4805, grad_fn=<SubBackward0>)
tensor(0.4818, grad_fn=<SubBackward0>)
tensor(0.4842, grad_fn=<SubBackward0>)
Epoch 6/30 | Loss: 0.4842
tensor(-0.0106, g

New Modification

In [100]:
# --------------------
# 1. Dataset Pipeline
# --------------------
class AlphaPortfolioData(Dataset):
    
    def __init__(self, start_date='2010-01-01', end_date='2019-12-31', lookback=12, G=2):
        super().__init__()
        self.lookback = lookback
        self.G = G  # Number of assets to long/short
        self.data = self._load_wrds_data(start_date, end_date)
        self.sequences, self.future_returns, self.masks = self._create_sequences()
        self._validate_data_shapes()

    def _load_wrds_data(self, start_date, end_date):
        
        # CRSP data
        crsp_query = f"""
        SELECT a.permno, a.date, a.ret, a.prc, a.shrout, 
            a.vol, a.cfacshr, a.altprc, a.retx
        FROM crsp.msf AS a
        WHERE a.date BETWEEN '{start_date}' AND '{end_date}'
        AND a.permno IN (
            SELECT permno FROM crsp.msenames 
            WHERE exchcd BETWEEN 1 AND 3  
                AND shrcd IN (10, 11)       
            )
        """
        crsp_data = db.raw_sql(crsp_query)

        query_ticker = """
            SELECT permno, namedt, nameenddt, ticker
            FROM crsp.stocknames
        """
        stocknames = db.raw_sql(query_ticker)
        crsp_data = crsp_data.merge(stocknames.drop_duplicates(subset=['permno']), on='permno', how='left')
        crsp_data = crsp_data.dropna(subset=['ticker'])

        crsp_data['mktcap'] = (crsp_data['prc'].abs() * crsp_data['shrout'] * 1000) / 1e6  # In millions
        crsp_data['year'] = pd.to_datetime(crsp_data['date']).dt.year
        crsp_data = crsp_data.dropna(subset=['mktcap'])
        mean_mktcap = crsp_data.groupby(['year', 'permno'])['mktcap'].mean().reset_index()
        sorted_mean_mktcap = mean_mktcap.sort_values(by=['year', 'mktcap'], ascending=[True, False])
        top_20_each_year = sorted_mean_mktcap.groupby('year').head(50)

        crsp_data = top_20_each_year.merge(crsp_data, on='permno', how='left')
        crsp_data = crsp_data.drop_duplicates(subset=['mktcap_y'])

        crsp_data = crsp_data[['permno', 'ticker', 'date', 'ret', 'prc', 'shrout', 'vol', 'mktcap_y', 'year_y']]
        crsp_data['date'] = pd.to_datetime(crsp_data['date'])
        crsp_data.sort_values(['permno', 'date'], inplace=True)

        # Query Compustat quarterly data with release dates (rdq)
        fund_query = f"""
            SELECT gvkey, datadate, rdq, saleq
            FROM comp.fundq
            WHERE indfmt = 'INDL' AND datafmt = 'STD' AND popsrc = 'D' AND consol = 'C'
            AND datadate BETWEEN '{start_date}' AND '{end_date}'
            AND rdq IS NOT NULL
        """
        fund = db.raw_sql(fund_query)
        fund['rdq'] = pd.to_datetime(fund['rdq'])
        fund['datadate'] = pd.to_datetime(fund['datadate'])

        # Link Compustat GVKEY to CRSP PERMNO
        link_query = """
            SELECT lpermno AS permno, gvkey, linkdt, linkenddt
            FROM crsp.ccmxpf_linktable
            WHERE linktype IN ('LU', 'LC') AND linkprim IN ('P', 'C')
        """
        link = db.raw_sql(link_query)
        fund = pd.merge(fund, link, on='gvkey', how='left')
        fund = fund.dropna(subset=['permno'])

        # Sort both datasets by date
        crsp_sorted = crsp_data.sort_values('date')
        fund_sorted = fund.sort_values('rdq')
        fund_sorted['permno'] = fund_sorted['permno'].astype(int)
        # Merge fundamentals to CRSP using rdq
        merged = pd.merge_asof(
            crsp_sorted,
            fund_sorted,
            left_on='date',
            right_on='rdq',
            by='permno',
            direction='backward'  # Take the first CRSP date >= rdq
        )
        merged = merged.dropna(subset=['rdq', 'ticker'])
        merged = merged.sort_values(by='date')
        merged = merged[['permno', 'ticker', 'date', 'ret', 'prc','vol', 'mktcap_y', 'gvkey', 'rdq', 'saleq']]
        merged = merged.ffill()
        
        unique_dates = merged['date'].unique()
        date_mapping = {date: i for i, date in enumerate(sorted(unique_dates))}
        merged['date_mapped'] = merged['date'].map(date_mapping)
        
        return merged

    def _create_sequences(self):
        data = self.data
        lookback = self.lookback
        unique_dates = pd.to_datetime(data['date'].unique())
        unique_assets = data['permno'].unique()
        
        sequences = []
        future_returns = []
        masks = []
        min_assets = 2 * self.G
        batch_info = []

        # First pass: collect valid batches
        for date_idx in range(len(unique_dates) - 2 * lookback):
            hist_start = unique_dates[date_idx]
            hist_end = unique_dates[date_idx + lookback - 1]
            future_start = unique_dates[date_idx + lookback]
            future_end = unique_dates[date_idx + 2 * lookback - 1]

            batch_assets = []
            hist_features = []
            fwd_returns = []
            
            for asset in unique_assets:
                asset_hist = data[
                    (data['permno'] == asset) & 
                    (data['date'].between(hist_start, hist_end))
                ].sort_values('date')
                
                asset_future = data[
                    (data['permno'] == asset) & 
                    (data['date'].between(future_start, future_end))
                ]['ret'].values
                
                if len(asset_hist) == lookback and len(asset_future) == lookback:
                    features = asset_hist[['ret', 'prc', 'vol', 'mktcap_y', 'saleq']].values
                    hist_features.append(features)
                    fwd_returns.append(asset_future)
                    batch_assets.append(asset)

            if len(hist_features) >= min_assets:
                batch_info.append({
                    'features': np.stack(hist_features),
                    'returns': np.stack(fwd_returns),
                    'num_assets': len(hist_features)
                })

        # Find global max assets across all valid batches
        if not batch_info:
            return torch.empty(0), torch.empty(0), torch.empty(0)
        
        global_max_assets = max(b['num_assets'] for b in batch_info)
        features_dim = batch_info[0]['features'].shape[-1]

        # Second pass: pad to global max
        for batch in batch_info:
            num_assets = batch['num_assets']
            
            # Features: (assets, lookback, features)
            padded_features = np.zeros((global_max_assets, lookback, features_dim))
            padded_features[:num_assets] = batch['features']
            
            # Returns: (assets, lookback)
            padded_returns = np.zeros((global_max_assets, lookback))  # Fix 1: 2D padding
            padded_returns[:num_assets] = batch['returns']
            
            # Mask: (assets,)
            mask = np.zeros(global_max_assets, dtype=bool)
            mask[:num_assets] = True

            sequences.append(padded_features)
            future_returns.append(padded_returns)
            masks.append(mask)

        return (
            torch.as_tensor(np.array(sequences), dtype=torch.float32),  # (time, assets, lookback, features)
            torch.as_tensor(np.array(future_returns), dtype=torch.float32),  # (time, assets, lookback)
            torch.as_tensor(np.array(masks), dtype=torch.bool)  # (time, assets)
        )

    def _validate_data_shapes(self):
        assert self.sequences.dim() == 4, \
            f"Sequences should be 4D (time, assets, lookback, features). Got {self.sequences.shape}"
        assert self.future_returns.dim() == 3, \
            f"Future returns should be 3D (time, assets, lookback). Got {self.future_returns.shape}"
        assert self.masks.dim() == 2, \
            f"Masks should be 2D (time, assets). Got {self.masks.shape}"

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        return (
        self.sequences[idx],  # (assets, lookback, features)
        self.future_returns[idx],  # (assets, lookback)
        self.masks[idx]         # (assets,)
    )

class AssetTransformer(nn.Module):
    def __init__(self, input_dim, d_model=64, nhead=4, num_layers=2):
        super().__init__()
        self.embedding = nn.Linear(input_dim, d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward=4*d_model)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)

    def forward(self, x):
        batch_size, num_assets, lookback, features = x.shape
        x = x.view(-1, lookback, features)  # Reshape to (batch*assets, lookback, features)
        x = self.embedding(x)  # Project to d_model
        x = x.permute(1, 0, 2)  # (lookback, batch*assets, d_model)
        x = self.transformer(x)  # Apply transformer
        x = x.mean(dim=0).view(batch_size, num_assets, -1)  # Global average pooling
        return x

class CrossAssetAttention(nn.Module):
    def __init__(self, d_model=64, d_k=32, d_v=32):
        super().__init__()
        self.WQ = nn.Linear(d_model, d_k)
        self.WK = nn.Linear(d_model, d_k)
        self.WV = nn.Linear(d_model, d_v)
        self.scale = 1 / math.sqrt(d_k)
        self.score_layer = nn.Sequential(
            nn.Linear(d_v, 1),
            nn.Tanh()
        )

    def forward(self, x, mask=None):
        # Print input shapes for debugging
        # print("CrossAssetAttention Input x shape:", x.shape)
        # print("CrossAssetAttention Input mask shape:", mask.shape)

        # Ensure x is 3D (batch, assets, d_model)
        if x.dim() == 4:
            x = x.squeeze(0)  # Remove extra batch dimension if present
        
        if x.dim() != 3:
            raise ValueError(f"Expected 3D input, got {x.dim()}D tensor. Shape: {x.shape}")

        # Compute query, key, value
        Q = self.WQ(x)  # (batch, assets, d_k)
        K = self.WK(x)  # (batch, assets, d_k)
        V = self.WV(x)  # (batch, assets, d_v)
        
        # Attention computation
        attn = torch.matmul(Q, K.transpose(-2, -1)) * self.scale  # (batch, assets, assets)
        
        # Apply mask if provided
        if mask is not None:
            # Ensure mask is 2D
            if mask.dim() == 3:
                mask = mask.squeeze(0)
            
            if mask.dim() != 2:
                raise ValueError(f"Mask must be 2D, got {mask.dim()}D tensor")
            
            attn = attn.masked_fill(~mask.unsqueeze(1), float('-inf'))
        
        # Softmax and value aggregation
        attn = F.softmax(attn, dim=-1)
        attn_out = torch.matmul(attn, V)  # (batch, assets, d_v)
        
        # Score computation
        scores = self.score_layer(attn_out).squeeze(-1)  # (batch, assets)
        
        # print("CrossAssetAttention Output scores shape:", scores.shape)
        
        return scores

class PortfolioGenerator(nn.Module):
    def __init__(self, G=5):
        super().__init__()
        self.G = G
        self.temperature = nn.Parameter(torch.tensor(1.0))
        
    def forward(self, scores, mask):
        # print("PortfolioGenerator Input scores shape:", scores.shape)
        # print("PortfolioGenerator Input mask shape:", mask.shape)

        # Ensure scores is 2D
        if scores.dim() > 2:
            scores = scores.squeeze()
        
        if scores.dim() != 2:
            raise ValueError(f"Scores must be 2D, got {scores.dim()}D tensor. Shape: {scores.shape}")
        
        # Ensure mask is 2D
        if mask.dim() > 2:
            mask = mask.squeeze()
        
        if mask.dim() != 2:
            raise ValueError(f"Mask must be 2D, got {mask.dim()}D tensor")
        
        # Validate shapes
        batch_size, num_assets = scores.shape
        weights = torch.zeros_like(scores)
        
        for i in range(batch_size):
            # More robust validity check
            valid_assets = mask[i].sum().item()
            
            if valid_assets < 2*self.G:
                print(f"Warning: Batch {i} has insufficient valid assets. Skipping.")
                continue
            
            # Identify top and bottom G assets
            top_indices = torch.topk(scores[i], min(self.G, valid_assets)).indices
            bottom_indices = torch.topk(-scores[i], min(self.G, valid_assets)).indices
            
            # Compute weights
            top_scores = scores[i, top_indices]
            bottom_scores = scores[i, bottom_indices]
            
            top_weights = F.softmax(top_scores / self.temperature.clamp(min=0.1), dim=-1)
            bottom_weights = F.softmax(-bottom_scores / self.temperature.clamp(min=0.1), dim=-1)
            
            weights[i, top_indices] = top_weights
            weights[i, bottom_indices] = -bottom_weights
        
        # print("PortfolioGenerator Output weights shape:", weights.shape)
        return weights.nan_to_num(0.0)

class PolicyNetwork(nn.Module):
    def __init__(self, transformer, attention, portfolio_gen, d_model=64):
        super().__init__()
        self.transformer = transformer
        self.attention = attention
        self.portfolio_gen = portfolio_gen
        self.log_std = nn.Parameter(torch.zeros(1))
        
        self.value_head = nn.Sequential(
            nn.Linear(d_model, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

    def forward(self, x, mask):
        # Print input shapes for debugging
        # print("PolicyNetwork Input x shape:", x.shape)
        # print("PolicyNetwork Input mask shape:", mask.shape)

        # Ensure correct input dimensions
        if x.dim() == 5:  # Extra batch dimension
            x = x.squeeze(0)
        
        # if x.dim() != 4:
            # raise ValueError(f"Expected 4D input tensor, got {x.dim()}D. Shape: {x.shape}")
        
        # Process through transformer
        x = self.transformer(x)  # (batch, assets, d_model)
        
        # Ensure mask is consistent
        if mask.dim() == 3:
            mask = mask.squeeze(0)
        
        # if mask.dim() != 2:
            # raise ValueError(f"Expected 2D mask, got {mask.dim()}D. Shape: {mask.shape}")
        
        # Cross-asset attention scoring
        scores = self.attention(x, mask)  # (batch, assets)
        
        # Portfolio generation
        weights = self.portfolio_gen(scores, mask)  # (batch, assets)
        
        # Policy distribution
        std = torch.exp(self.log_std).expand_as(weights)
        dist = torch.distributions.Normal(weights, std)
        
        # Value estimation
        values = self.value_head(x.mean(dim=1))
        
        return dist, values.squeeze(-1)

class PortfolioEnv(gym.Env):
    def __init__(self, sequences, future_returns, masks, trainer, G=5):
        super().__init__()
        self.sequences = sequences
        self.future_returns = future_returns
        self.masks = masks
        self.G = G
        self.trainer = trainer  # Add this line
        
        self.observation_space = gym.spaces.Box(-np.inf, np.inf, shape=sequences[0].shape)
        self.action_space = gym.spaces.Box(-1, 1, shape=(sequences.shape[1],))

    def step(self, action):
    # Get a random batch index
        idx = np.random.randint(len(self.sequences))
        
        # Get future returns and mask
        future_returns = self.future_returns[idx]
        mask = self.masks[idx]
        
        # Ensure action is a tensor
        action = torch.as_tensor(action, dtype=torch.float32)
        
        # Filter valid assets
        valid = mask & (action != 0)
        if valid.sum().item() < 2*self.G:
            return None, -1.0, True, {}
        
        # Selected weights and returns
        selected_weights = action[valid]
        selected_returns = future_returns[valid]
        
        # Portfolio return calculation
        portfolio_returns = (selected_weights.unsqueeze(-1) * selected_returns)
        portfolio_cumulative_return = portfolio_returns.sum(dim=0)
        
        # More robust Sharpe Ratio calculation
        mean_return = portfolio_cumulative_return.mean()
        std_return = portfolio_cumulative_return.std()
        sharpe_ratio = mean_return / (std_return + 1e-8)
        
        return None, sharpe_ratio.item(), True, {}

class PPOTrainer:
    def __init__(self, policy, env, lr=1e-4, gamma=0.99, clip=0.2):
        self.policy = policy
        self.env = env
        self.optimizer = optim.Adam(policy.parameters(), lr=lr)
        self.gamma = gamma
        self.clip = clip
    
    def collect_trajectory(self):
        states, masks, actions, log_probs, rewards = [], [], [], [], []
        
        # Determine the batch size dynamically
        batch_size = len(self.env.sequences)
        
        for _ in range(batch_size):
            idx = np.random.randint(len(self.env.sequences))
            state = self.env.sequences[idx]
            mask = self.env.masks[idx]
            
            # Ensure consistent dimensions
            if state.dim() == 3:
                state = state.unsqueeze(0)  # Add batch dimension
            elif state.dim() == 4:
                state = state.squeeze(0)  # Remove extra dimension if present
            
            if mask.dim() == 1:
                mask = mask.unsqueeze(0)  # Add batch dimension
            elif mask.dim() == 2:
                mask = mask.squeeze(0)  # Remove extra dimension if present
            
            with torch.no_grad():
                try:
                    # Ensure state and mask have consistent batch dimension
                    dist, values = self.policy(state, mask)
                    action = dist.sample()
                    log_prob = dist.log_prob(action)
                    
                    _, reward, _, _ = self.env.step(action.squeeze())
                    
                    states.append(state.squeeze(0))
                    masks.append(mask.squeeze(0))
                    actions.append(action.squeeze(0))
                    log_probs.append(log_prob.squeeze(0))
                    rewards.append(reward)
                
                except Exception as e:
                    print(f"Error processing trajectory: {e}")
                    # print("State shape:", state.shape)
                    # print("Mask shape:", mask.shape)
                    raise
        
        # Ensure all tensors have consistent shapes
        states = torch.stack(states)
        masks = torch.stack(masks)
        actions = torch.stack(actions)
        log_probs = torch.stack(log_probs)
        rewards = torch.tensor(rewards)
        
        return states, masks, actions, log_probs, rewards

    def train_epoch(self):
        states, masks, actions, old_log_probs, rewards = self.collect_trajectory()
        
        # Ensure rewards have the correct shape for broadcasting
        rewards = rewards.view(-1, 1)
        
        # Normalize rewards
        rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-6)
        
        for _ in range(3):  # Multiple PPO update steps
            self.optimizer.zero_grad()
            
            # Get new policy outputs
            dist, values = self.policy(states, masks)
            new_log_probs = dist.log_prob(actions)
            entropy = dist.entropy().mean()
            
            # Ensure consistent shapes
            values = values.view(-1, 1)
            
            # Policy loss computation
            ratio = torch.exp(new_log_probs - old_log_probs.detach())
            surr1 = ratio * rewards
            surr2 = torch.clamp(ratio, 1-self.clip, 1+self.clip) * rewards
            policy_loss = -torch.min(surr1, surr2).mean()
            
            # Value loss
            value_loss = F.mse_loss(values, rewards)
            
            # Total loss
            loss = policy_loss + 0.5*value_loss - 0.01*entropy
            print(loss)
            loss.backward()
            
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(self.policy.parameters(), 1.0)
            self.optimizer.step()
        
        return loss.item()
    
import numpy as np
import matplotlib.pyplot as plt
import torch

def compute_average_sharpe(env, num_samples=100):
    """
    Compute average Sharpe ratio across multiple random samples from the environment.
    
    Args:
        env (PortfolioEnv): The portfolio environment
        num_samples (int): Number of samples to compute Sharpe ratio

    Returns:
        float: Average Sharpe ratio
    """
    sharpe_ratios = []
    
    for _ in range(num_samples):
        # Randomly select a time period
        idx = np.random.randint(len(env.sequences))
        
        # Get state, mask, and future returns
        state = env.sequences[idx]
        mask = env.masks[idx]
        future_returns = env.future_returns[idx]
        
        # Ensure state and mask are in correct dimensions
        if state.dim() == 3:
            state = state.unsqueeze(0)
        if mask.dim() == 1:
            mask = mask.unsqueeze(0)
        
        try:
            # Sample a portfolio weight using the policy in the environment's trainer
            with torch.no_grad():
                dist, _ = env.trainer.policy(state, mask)
                action = dist.sample().squeeze()
            
            # Compute Sharpe ratio
            valid = mask.squeeze() & (action != 0)
            
            if valid.sum().item() < 2*env.G:
                continue
            
            selected_weights = action[valid]
            selected_returns = future_returns[valid]
            
            # Portfolio return calculation
            portfolio_returns = (selected_weights.unsqueeze(-1) * selected_returns)
            portfolio_cumulative_return = portfolio_returns.sum(dim=0)
            
            # Sharpe Ratio calculation
            mean_return = portfolio_cumulative_return.mean()
            std_return = portfolio_cumulative_return.std()
            sharpe_ratio = mean_return / (std_return + 1e-8)
            
            sharpe_ratios.append(sharpe_ratio.item())
        
        except Exception as e:
            print(f"Error computing Sharpe ratio: {e}")
    
    # Return average Sharpe ratio
    return np.mean(sharpe_ratios) if sharpe_ratios else -np.inf

def plot_learning_curves(losses, sharpe_ratios, save_path='learning_curves.png'):
    """
    Plot learning curves for losses and Sharpe ratios.
    
    Args:
        losses (list): List of losses from training
        sharpe_ratios (list): List of Sharpe ratios from evaluation
        save_path (str): Path to save the plot
    """
    plt.figure(figsize=(12, 5))
    
    # Loss subplot
    plt.subplot(1, 2, 1)
    plt.plot(losses, label='Loss', color='blue')
    plt.title('Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    
    # Sharpe Ratio subplot
    plt.subplot(1, 2, 2)
    plt.plot(sharpe_ratios, label='Sharpe Ratio', color='green')
    plt.title('Portfolio Sharpe Ratio')
    plt.xlabel('Epoch')
    plt.ylabel('Sharpe Ratio')
    plt.legend()
    
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()
# Optional: Enhanced logging decorator
def log_training_metrics(func):
    def wrapper(*args, **kwargs):
        # Pre-training logging
        print("Starting training epoch...")
        result = func(*args, **kwargs)
        
        # Post-training logging (placeholder for custom metrics)
        print(f"Epoch completed. Additional metrics will be logged here.")
        
        return result
    return wrapper



In [None]:
scheduler = torch.optim.lr_scheduler.StepLR(trainer.optimizer, step_size=5, gamma=0.9)

# Arrays to track metrics
sharpe_ratios = []
losses = []

d_model = 128
G = 4  # Top and bottom K stocks
epochs = 30

# Load your preprocessed data from AlphaPortfolioData


# Network Components
transformer = AssetTransformer(input_dim=features, d_model=d_model)
attention = CrossAssetAttention(d_model=d_model)
portfolio_gen = PortfolioGenerator(G=G)

# Policy Network
policy = PolicyNetwork(transformer, attention, portfolio_gen, d_model=d_model)

# Portfolio Environment
env = PortfolioEnv(
    dataset.sequences,
    dataset.future_returns,
    dataset.masks,
    trainer,  # Pass the trainer
    G=G
)

# PPO Trainer
trainer = PPOTrainer(policy, env)

sharpe_ratios = []
losses = []

# Learning rate scheduler
scheduler = torch.optim.lr_scheduler.StepLR(trainer.optimizer, step_size=5, gamma=0.9)

for epoch in range(epochs):
    # Train for an epoch
    loss = trainer.train_epoch()
    
    # Compute average Sharpe ratio
    avg_sharpe = compute_average_sharpe(env)
    
    # Track metrics
    sharpe_ratios.append(avg_sharpe)
    losses.append(loss)
    
    # Learning rate decay
    scheduler.step()
    
    print(f"Epoch {epoch+1}/{epochs} | Loss: {loss:.4f} | Avg Sharpe: {avg_sharpe:.4f}")

# Plot learning curves
plot_learning_curves(losses, sharpe_ratios)

tensor(0.5213, grad_fn=<SubBackward0>)
tensor(0.4946, grad_fn=<SubBackward0>)
tensor(0.4841, grad_fn=<SubBackward0>)
Epoch 1/30 | Loss: 0.4841 | Avg Sharpe: -0.0194
tensor(0.4836, grad_fn=<SubBackward0>)
tensor(0.4867, grad_fn=<SubBackward0>)
tensor(0.4874, grad_fn=<SubBackward0>)
Epoch 2/30 | Loss: 0.4874 | Avg Sharpe: -0.0115
tensor(0.4877, grad_fn=<SubBackward0>)
tensor(0.4851, grad_fn=<SubBackward0>)
tensor(0.4835, grad_fn=<SubBackward0>)
Epoch 3/30 | Loss: 0.4835 | Avg Sharpe: 0.0208
tensor(0.4846, grad_fn=<SubBackward0>)
tensor(0.4845, grad_fn=<SubBackward0>)
tensor(0.4860, grad_fn=<SubBackward0>)
Epoch 4/30 | Loss: 0.4860 | Avg Sharpe: -0.0394
tensor(0.4838, grad_fn=<SubBackward0>)
tensor(0.4823, grad_fn=<SubBackward0>)
tensor(0.4814, grad_fn=<SubBackward0>)
Epoch 5/30 | Loss: 0.4814 | Avg Sharpe: 0.0483
tensor(0.4852, grad_fn=<SubBackward0>)
tensor(0.4851, grad_fn=<SubBackward0>)
tensor(0.4847, grad_fn=<SubBackward0>)
Epoch 6/30 | Loss: 0.4847 | Avg Sharpe: 0.0282
tensor(0.4834