### Dev's code

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import logging
import wrds
import math
import gym
from gym import spaces
from torch.optim.optimizer import Optimizer

# To this:
import torch.optim as optim

logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

db = wrds.Connection()

class AlphaPortfolioData(Dataset):
    def __init__(self, start_year=2014, end_year=2020, final_year=2016, lookback=12, G=2):
        super().__init__()
        self.lookback = lookback
        self.G = G
        self.merged, self.final_data = self._load_wrds_data(start_year, end_year, final_year)
        self.unique_permnos = sorted(self.final_data['permno'].unique())
        self.global_max_assets = len(self.unique_permnos)
        self.permno_to_idx = {permno: idx for idx, permno in enumerate(self.unique_permnos)}
        self.sequences, self.future_returns, self.masks = self._create_sequences()

    def _load_wrds_data(self, start_year, end_year, final_year):

        permno_list = []
        combined_data = pd.DataFrame()

        for year in range(start_year, end_year+1):
            
            start_date = f'{year}-01-01'
            end_date = f'{year}-12-31'
            
            crsp_query = f"""
                SELECT a.permno, a.date, a.ret, a.prc, a.shrout, 
                    a.vol, a.cfacshr, a.altprc, a.retx
                FROM crsp.msf AS a
                WHERE a.date BETWEEN '{start_date}' AND '{end_date}'
                AND a.permno IN (
                    SELECT permno FROM crsp.msenames 
                    WHERE exchcd BETWEEN 1 AND 3  
                        AND shrcd IN (10, 11)       
                    )
                """
            crsp_data = db.raw_sql(crsp_query)

            query_ticker = """
                SELECT permno, namedt, nameenddt, ticker
                FROM crsp.stocknames
            """
            
            stocknames = db.raw_sql(query_ticker)
            crsp_data = crsp_data.merge(stocknames.drop_duplicates(subset=['permno']), on='permno', how='left')
            crsp_data = crsp_data.dropna(subset=['ticker'])

            crsp_data['mktcap'] = (crsp_data['prc'].abs() * crsp_data['shrout'] * 1000) / 1e6  # In millions
            crsp_data['year'] = pd.to_datetime(crsp_data['date']).dt.year
            crsp_data = crsp_data.dropna(subset=['mktcap'])
            
            top_50_permnos_by_year = crsp_data.groupby('permno')['mktcap'].agg(['max']).reset_index().sort_values(by='max', ascending=False).head(50)['permno'].unique()
            permno_list.extend(top_50_permnos_by_year)
            
            combined_data = pd.concat([combined_data, crsp_data[crsp_data['permno'].isin(permno_list)]], axis=0)

        combined_data = combined_data[['permno', 'ticker', 'date', 'ret', 'prc', 'shrout', 'vol', 'mktcap', 'year']]
        combined_data['date'] = pd.to_datetime(combined_data['date'])

        start_date = f'{start_year}-01-01'
        end_date = f'{end_year}-12-31'

        # Query Compustat quarterly data with release dates (rdq)
        fund_query = f"""
            SELECT gvkey, datadate, rdq, saleq
            FROM comp.fundq
            WHERE indfmt = 'INDL' AND datafmt = 'STD' AND popsrc = 'D' AND consol = 'C'
            AND datadate BETWEEN '{start_date}' AND '{end_date}'
            AND rdq IS NOT NULL
        """
        fund = db.raw_sql(fund_query)
        fund['rdq'] = pd.to_datetime(fund['rdq'])
        fund['datadate'] = pd.to_datetime(fund['datadate'])

        # Link Compustat GVKEY to CRSP PERMNO
        link_query = """
            SELECT lpermno AS permno, gvkey, linkdt, linkenddt
            FROM crsp.ccmxpf_linktable
            WHERE linktype IN ('LU', 'LC') AND linkprim IN ('P', 'C')
        """
        link = db.raw_sql(link_query)
        fund = pd.merge(fund, link, on='gvkey', how='left')
        fund = fund.dropna(subset=['permno'])

        # Sort both datasets by date
        combined_data_sorted = combined_data.sort_values('date')
        fund_sorted = fund.sort_values('rdq')
        fund_sorted['permno'] = fund_sorted['permno'].astype(int)
        combined_data_sorted['permno'] = combined_data_sorted['permno'].astype(int)

        merged = pd.merge_asof(
            combined_data_sorted,
            fund_sorted,
            left_on='date',
            right_on='rdq',
            by='permno',
            direction='backward'
        )
        # merged = merged.dropna(subset=['rdq', 'ticker'])
        merged = merged.sort_values(by='date')
        merged = merged[['permno', 'ticker', 'date', 'ret', 'prc','vol', 'mktcap', 'gvkey', 'rdq', 'saleq']]
        merged = merged.ffill()

        unique_dates = merged['date'].unique()
        date_mapping = {date: i for i, date in enumerate(sorted(unique_dates))}
        merged['date_mapped'] = merged['date'].map(date_mapping)

        merged['year'] = pd.to_datetime(merged['date']).dt.year
        final_data = merged[merged['year'] >= final_year]

        
        return merged, final_data


    def _create_sequences(self):
        data = self.final_data
        lookback = self.lookback
        unique_dates = pd.to_datetime(data['date'].unique())
        unique_dates_sorted = np.sort(unique_dates)
        num_features = 6  # Based on []'permno', 'ret', 'prc', 'vol', 'mktcap', 'saleq']

        sequences = []
        future_returns = []
        masks = []

        for start_idx in tqdm(range(len(unique_dates_sorted) - 2 * lookback+1)):
            hist_start = unique_dates_sorted[start_idx]
            hist_end = unique_dates_sorted[start_idx + lookback - 1]
            future_start = unique_dates_sorted[start_idx + lookback]
            future_end = unique_dates_sorted[start_idx + 2 * lookback-1]

            print(f'Hist start: {hist_start}, Hist end: {hist_end}, Future start: {future_start}, Future end: {future_end}')

            # Initialize batch arrays with zeros
            batch_features = np.zeros((self.global_max_assets, lookback, num_features))
            batch_returns = np.zeros((self.global_max_assets, lookback))
            batch_mask = np.zeros(self.global_max_assets, dtype=bool)

            for permno in self.unique_permnos:
                idx = self.permno_to_idx[permno]

                # Historical data for the current window
                hist_data = data[
                    (data['permno'] == permno) &
                    (data['date'] >= hist_start) &
                    (data['date'] <= hist_end)
                ].sort_values('date')

                # Future returns for the next window
                future_data = data[
                    (data['permno'] == permno) &
                    (data['date'] >= future_start) &
                    (data['date'] <= future_end)
                ]['ret'].values

                # Check if both periods have complete data
                if len(hist_data) == lookback and len(future_data) == lookback:
                    features = hist_data[['permno', 'ret', 'prc', 'vol', 'mktcap', 'saleq']].values
                    batch_features[idx] = features
                    batch_returns[idx] = future_data
                    batch_mask[idx] = True

            sequences.append(batch_features)
            future_returns.append(batch_returns)
            masks.append(batch_mask)

        # Convert to tensors
        sequences_tensor = torch.tensor(np.array(sequences), dtype=torch.float32)

        # NEW
        sequences_tensor = sequences_tensor.view(-1, self.lookback, num_features)
        future_returns_tensor = torch.tensor(np.array(future_returns), dtype=torch.float32)
        masks_tensor = torch.tensor(np.array(masks), dtype=torch.bool)

        return sequences_tensor, future_returns_tensor, masks_tensor

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        return self.sequences[idx], self.future_returns[idx], self.masks[idx]

WRDS recommends setting up a .pgpass file.
Created .pgpass file successfully.
You can create this file yourself at any time with the create_pgpass_file() function.
Loading library list...
Done


### My Code

In [2]:
# Define SREM Module (Transformer Encoder)
class SREM(nn.Module):
    # def __init__(self, input_dim, embed_dim, num_heads, num_layers):
    #     super(SREM, self).__init__()
    #     encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads)
    #     self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
    #     self.fc = nn.Linear(embed_dim, embed_dim)

    # def forward(self, x):
    #     x = self.transformer(x)
    #     x = self.fc(x.mean(dim=1))
    #     return x
    def __init__(self, input_dim, embed_dim, num_heads, num_layers, dropout=0.1):
        super(SREM, self).__init__()
        self.embedding = nn.Linear(input_dim, embed_dim)
        encoder_layers = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dropout=dropout)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers=num_layers)
        self.output_layer = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        x = self.embedding(x)  # Linear transformation to embedding space
        x = self.transformer_encoder(x)  # Apply Transformer Encoder
        x = self.output_layer(x[:, -1, :])  # Extract last time-step representation
        return x

# Define CAAN Module (Cross-Asset Attention Network)
class CAAN(nn.Module):
    def __init__(self, embed_dim):
        super(CAAN, self).__init__()
        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        self.fc = nn.Linear(embed_dim, 1)

    def forward(self, x):
        Q = self.query(x)
        K = self.key(x)
        V = self.value(x)
        attn = torch.softmax(Q @ K.transpose(-2, -1) / np.sqrt(Q.size(-1)), dim=-1)
        attention_output = attn @ V
        winner_scores = torch.tanh(self.fc(attention_output).squeeze(-1))
        return winner_scores

# Portfolio Generator
class PortfolioGenerator:
    def __init__(self, long_short_ratio=0.1):
        self.long_short_ratio = long_short_ratio

    def generate_portfolio(self, scores):
        num_assets = len(scores)
        num_long = num_short = int(self.long_short_ratio * num_assets)
        
        ranked = torch.argsort(scores, descending=True)
        long_assets = ranked[:num_long]
        short_assets = ranked[-num_short:]
        
        long_weights = torch.softmax(scores[long_assets], dim=0)
        short_weights = torch.softmax(-scores[short_assets], dim=0)
        
        portfolio = torch.zeros(num_assets)
        portfolio[long_assets] = long_weights
        portfolio[short_assets] = -short_weights
        
        return portfolio


def train_model(data, model_srem, model_caan, portfolio_generator, optimizer, criterion, num_epochs=10):
    model_srem.train()
    model_caan.train()

    for epoch in range(num_epochs):
        total_loss = 0

        for sequences, future_returns, masks in DataLoader(data, batch_size=32, shuffle=True):
            optimizer.zero_grad()


            # Pass sequences through the SREM (to capture asset trends)
            srem_output = model_srem(sequences)

            # Pass through CAAN (to model interrelationships)
            winner_scores = model_caan(srem_output)


            # Generate portfolio weights from the winner scores
            portfolio = portfolio_generator.generate_portfolio(winner_scores)

            # Compute loss (e.g., portfolio returns)
            portfolio_returns = torch.sum(portfolio * future_returns, dim=1)
            loss = -portfolio_returns.mean()  # Maximize returns (minimize negative returns)
            
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        logger.info(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss:.4f}")

# Hyperparameters and Model Initialization
input_dim = 6  # Features: permno, ret, prc, vol, mktcap, saleq
embed_dim = 64
num_heads = 4
num_layers = 3

model_srem = SREM(input_dim=input_dim, embed_dim=embed_dim, num_heads=num_heads, num_layers=num_layers)
model_caan = CAAN(embed_dim=embed_dim)
portfolio_generator = PortfolioGenerator(long_short_ratio=0.1)

optimizer = optim.Adam(list(model_srem.parameters()) + list(model_caan.parameters()), lr=1e-4)
criterion = nn.MSELoss()  # Placeholder; could be customized based on your goal

# Get Data
data = AlphaPortfolioData(start_year=2014, end_year=2020, final_year=2016, lookback=12, G=2)

# Train the model
train_model(data, model_srem, model_caan, portfolio_generator, optimizer, criterion)

  3%|▎         | 1/37 [00:00<00:03,  9.23it/s]

Hist start: 2016-01-29T00:00:00.000000000, Hist end: 2016-12-30T00:00:00.000000000, Future start: 2017-01-31T00:00:00.000000000, Future end: 2017-12-29T00:00:00.000000000
Hist start: 2016-02-29T00:00:00.000000000, Hist end: 2017-01-31T00:00:00.000000000, Future start: 2017-02-28T00:00:00.000000000, Future end: 2018-01-31T00:00:00.000000000
Hist start: 2016-03-31T00:00:00.000000000, Hist end: 2017-02-28T00:00:00.000000000, Future start: 2017-03-31T00:00:00.000000000, Future end: 2018-02-28T00:00:00.000000000


 14%|█▎        | 5/37 [00:00<00:03,  9.75it/s]

Hist start: 2016-04-29T00:00:00.000000000, Hist end: 2017-03-31T00:00:00.000000000, Future start: 2017-04-28T00:00:00.000000000, Future end: 2018-03-29T00:00:00.000000000
Hist start: 2016-05-31T00:00:00.000000000, Hist end: 2017-04-28T00:00:00.000000000, Future start: 2017-05-31T00:00:00.000000000, Future end: 2018-04-30T00:00:00.000000000
Hist start: 2016-06-30T00:00:00.000000000, Hist end: 2017-05-31T00:00:00.000000000, Future start: 2017-06-30T00:00:00.000000000, Future end: 2018-05-31T00:00:00.000000000


 22%|██▏       | 8/37 [00:00<00:02,  9.94it/s]

Hist start: 2016-07-29T00:00:00.000000000, Hist end: 2017-06-30T00:00:00.000000000, Future start: 2017-07-31T00:00:00.000000000, Future end: 2018-06-29T00:00:00.000000000
Hist start: 2016-08-31T00:00:00.000000000, Hist end: 2017-07-31T00:00:00.000000000, Future start: 2017-08-31T00:00:00.000000000, Future end: 2018-07-31T00:00:00.000000000
Hist start: 2016-09-30T00:00:00.000000000, Hist end: 2017-08-31T00:00:00.000000000, Future start: 2017-09-29T00:00:00.000000000, Future end: 2018-08-31T00:00:00.000000000


 30%|██▉       | 11/37 [00:01<00:02,  9.90it/s]

Hist start: 2016-10-31T00:00:00.000000000, Hist end: 2017-09-29T00:00:00.000000000, Future start: 2017-10-31T00:00:00.000000000, Future end: 2018-09-28T00:00:00.000000000
Hist start: 2016-11-30T00:00:00.000000000, Hist end: 2017-10-31T00:00:00.000000000, Future start: 2017-11-30T00:00:00.000000000, Future end: 2018-10-31T00:00:00.000000000
Hist start: 2016-12-30T00:00:00.000000000, Hist end: 2017-11-30T00:00:00.000000000, Future start: 2017-12-29T00:00:00.000000000, Future end: 2018-11-30T00:00:00.000000000


 35%|███▌      | 13/37 [00:01<00:02,  9.91it/s]

Hist start: 2017-01-31T00:00:00.000000000, Hist end: 2017-12-29T00:00:00.000000000, Future start: 2018-01-31T00:00:00.000000000, Future end: 2018-12-31T00:00:00.000000000
Hist start: 2017-02-28T00:00:00.000000000, Hist end: 2018-01-31T00:00:00.000000000, Future start: 2018-02-28T00:00:00.000000000, Future end: 2019-01-31T00:00:00.000000000
Hist start: 2017-03-31T00:00:00.000000000, Hist end: 2018-02-28T00:00:00.000000000, Future start: 2018-03-29T00:00:00.000000000, Future end: 2019-02-28T00:00:00.000000000


 43%|████▎     | 16/37 [00:01<00:02,  9.94it/s]

Hist start: 2017-04-28T00:00:00.000000000, Hist end: 2018-03-29T00:00:00.000000000, Future start: 2018-04-30T00:00:00.000000000, Future end: 2019-03-29T00:00:00.000000000
Hist start: 2017-05-31T00:00:00.000000000, Hist end: 2018-04-30T00:00:00.000000000, Future start: 2018-05-31T00:00:00.000000000, Future end: 2019-04-30T00:00:00.000000000
Hist start: 2017-06-30T00:00:00.000000000, Hist end: 2018-05-31T00:00:00.000000000, Future start: 2018-06-29T00:00:00.000000000, Future end: 2019-05-31T00:00:00.000000000


 51%|█████▏    | 19/37 [00:01<00:01,  9.54it/s]

Hist start: 2017-07-31T00:00:00.000000000, Hist end: 2018-06-29T00:00:00.000000000, Future start: 2018-07-31T00:00:00.000000000, Future end: 2019-06-28T00:00:00.000000000
Hist start: 2017-08-31T00:00:00.000000000, Hist end: 2018-07-31T00:00:00.000000000, Future start: 2018-08-31T00:00:00.000000000, Future end: 2019-07-31T00:00:00.000000000
Hist start: 2017-09-29T00:00:00.000000000, Hist end: 2018-08-31T00:00:00.000000000, Future start: 2018-09-28T00:00:00.000000000, Future end: 2019-08-30T00:00:00.000000000


 62%|██████▏   | 23/37 [00:02<00:01,  9.86it/s]

Hist start: 2017-10-31T00:00:00.000000000, Hist end: 2018-09-28T00:00:00.000000000, Future start: 2018-10-31T00:00:00.000000000, Future end: 2019-09-30T00:00:00.000000000
Hist start: 2017-11-30T00:00:00.000000000, Hist end: 2018-10-31T00:00:00.000000000, Future start: 2018-11-30T00:00:00.000000000, Future end: 2019-10-31T00:00:00.000000000
Hist start: 2017-12-29T00:00:00.000000000, Hist end: 2018-11-30T00:00:00.000000000, Future start: 2018-12-31T00:00:00.000000000, Future end: 2019-11-29T00:00:00.000000000


 68%|██████▊   | 25/37 [00:02<00:01,  9.82it/s]

Hist start: 2018-01-31T00:00:00.000000000, Hist end: 2018-12-31T00:00:00.000000000, Future start: 2019-01-31T00:00:00.000000000, Future end: 2019-12-31T00:00:00.000000000
Hist start: 2018-02-28T00:00:00.000000000, Hist end: 2019-01-31T00:00:00.000000000, Future start: 2019-02-28T00:00:00.000000000, Future end: 2020-01-31T00:00:00.000000000
Hist start: 2018-03-29T00:00:00.000000000, Hist end: 2019-02-28T00:00:00.000000000, Future start: 2019-03-29T00:00:00.000000000, Future end: 2020-02-28T00:00:00.000000000


 76%|███████▌  | 28/37 [00:02<00:00,  9.86it/s]

Hist start: 2018-04-30T00:00:00.000000000, Hist end: 2019-03-29T00:00:00.000000000, Future start: 2019-04-30T00:00:00.000000000, Future end: 2020-03-31T00:00:00.000000000
Hist start: 2018-05-31T00:00:00.000000000, Hist end: 2019-04-30T00:00:00.000000000, Future start: 2019-05-31T00:00:00.000000000, Future end: 2020-04-30T00:00:00.000000000
Hist start: 2018-06-29T00:00:00.000000000, Hist end: 2019-05-31T00:00:00.000000000, Future start: 2019-06-28T00:00:00.000000000, Future end: 2020-05-29T00:00:00.000000000


 84%|████████▍ | 31/37 [00:03<00:00,  9.85it/s]

Hist start: 2018-07-31T00:00:00.000000000, Hist end: 2019-06-28T00:00:00.000000000, Future start: 2019-07-31T00:00:00.000000000, Future end: 2020-06-30T00:00:00.000000000
Hist start: 2018-08-31T00:00:00.000000000, Hist end: 2019-07-31T00:00:00.000000000, Future start: 2019-08-30T00:00:00.000000000, Future end: 2020-07-31T00:00:00.000000000
Hist start: 2018-09-28T00:00:00.000000000, Hist end: 2019-08-30T00:00:00.000000000, Future start: 2019-09-30T00:00:00.000000000, Future end: 2020-08-31T00:00:00.000000000


 92%|█████████▏| 34/37 [00:03<00:00,  9.84it/s]

Hist start: 2018-10-31T00:00:00.000000000, Hist end: 2019-09-30T00:00:00.000000000, Future start: 2019-10-31T00:00:00.000000000, Future end: 2020-09-30T00:00:00.000000000
Hist start: 2018-11-30T00:00:00.000000000, Hist end: 2019-10-31T00:00:00.000000000, Future start: 2019-11-29T00:00:00.000000000, Future end: 2020-10-30T00:00:00.000000000
Hist start: 2018-12-31T00:00:00.000000000, Hist end: 2019-11-29T00:00:00.000000000, Future start: 2019-12-31T00:00:00.000000000, Future end: 2020-11-30T00:00:00.000000000


100%|██████████| 37/37 [00:03<00:00,  9.87it/s]


Hist start: 2019-01-31T00:00:00.000000000, Hist end: 2019-12-31T00:00:00.000000000, Future start: 2020-01-31T00:00:00.000000000, Future end: 2020-12-31T00:00:00.000000000


IndexError: index 672 is out of bounds for dimension 0 with size 37