In [186]:
import torch
import numpy as np
import pandas as pd
import pytorch_lightning as pl
from pathlib import Path
from typing import Dict, List, Any
from collections import defaultdict

# Project layout: assumes the kernel's working directory is one level below the
# project root, with datasets stored under <root>/data.
src_path = Path.cwd().parent
data_path = src_path.joinpath('data')

# Environment check: PyTorch build and whether a CUDA device is visible.
print(torch.__version__, torch.cuda.is_available(), sep='\n')

1.11.0+cu113
True


# Data Loader

In [276]:
class MetaStockDataset(torch.utils.data.Dataset):
    def __init__(
            self, 
            meta_type: str ='train', 
            meta_train_stocks: List[str] | None =None,
            data_dir: Path | str ='', 
            dtype: str ='kdd17', 
            n_stocks: int =40, 
            n_window_tasks: int =5,
            n_lag: int =1, 
            n_tasks_per_window: int =100,
            show_y_index: bool=False
        ):
        """
        dataset ref: https://arxiv.org/abs/1810.09936
        In this meta learning setting, we have 3 meta-test and 1 meta-train
        vertical = stocks, horizontal = time
                train      |    test
           A               |
           B   meta-train  |   meta-test
           C               |      (1)
           ----------------|-------------
           D   meta-test   |   meta-test
           E     (2)       |      (3)

        meta-test (1) same stock, different time
        meta-test (2) different stock, same time
        meta-test (3) different stock, different time
        use `valid_date` to split the train / test set
        """
        super().__init__()
        # for debugging purpose
        self.show_y_index = show_y_index
        # data config
        self.data_dir = Path(data_dir)
        ds_info = {
            # train: (Jan-01-2007 to Jan-01-2015)
            # val: (Jan-01-2015 to Jan-01-2016)
            # test: (Jan-01-2016 to Jan-01-2017)
            'kdd17': {
                'path': self.data_dir / 'kdd17/price_long_50',
                'date': self.data_dir / 'kdd17/trading_dates.csv',
                'train_date': '2015-01-01', 
                'val_date': '2016-01-01', 
                'test_date': '2017-01-01',
            },
            # train: (Jan-01-2014 to Aug-01-2015)
            # vali: (Aug-01-2015 to Oct-01-2015)
            # test: (Oct-01-2015 to Jan-01-2016)
            'acl18': {
                'path': self.data_dir / 'stocknet-dataset/price/raw',
                'date': self.data_dir / 'stocknet-dataset/price/trading_dates.csv',
                'train_date': '2015-08-01', 
                'val_date': '2015-10-01', 
                'test_date': '2016-01-01',
            }
        }
        
        ds_config = ds_info[dtype]

        self.window_sizes = [5, 10, 15, 20]
        self.n_window_tasks = n_window_tasks
        self.n_lag = n_lag
        self.n_tasks_per_window = n_tasks_per_window

        # get data
        self.data = {}
        for i, p in enumerate((data_path / ds_config['path']).glob('*')):
            if meta_type == 'train' and (i == n_stocks):
                # stop when it reach `n_stocks`
                break
            
            stock_symbol = p.name.rstrip('.csv')
            df_single = self.load_single_stock(p)
            if meta_type == 'train':
                df_single = df_single.loc[df_single['date'] <= ds_config['val_date']]
            else:
                if meta_type == 'test1':
                    if stock_symbol in meta_train_stocks:
                        df_single = df_single.loc[df_single['date'] > ds_config['val_date']]
                    else:
                        continue
                elif meta_type == 'test2':
                    if stock_symbol not in meta_train_stocks:
                        df_single = df_single.loc[df_single['date'] <= ds_config['val_date']]
                    else:
                        continue
                elif meta_type == 'test3':
                    if stock_symbol not in meta_train_stocks:
                        df_single = df_single.loc[df_single['date'] > ds_config['val_date']]
                    else:
                        continue
                else:
                    raise KeyError('Error argument `meta_type`, should be in (train, test1, test2, test3)')

            self.data[stock_symbol] = df_single.reset_index(drop=True)

    def load_single_stock(self, p: Path | str):
        def longterm_trend(x: pd.Series, k:int):
            return (x.rolling(k).sum().div(k*x) - 1) * 100

        df = pd.read_csv(p)
        df['Date'] = pd.to_datetime(df['Date'])
        df = df.sort_values('Date').reset_index(drop=True)
        if 'Unnamed' in df.columns:
            df.drop(columns=df.columns[7], inplace=True)
        if 'Original_Open' in df.columns:
            df.rename(columns={'Original_Open': 'Open', 'Open': 'Adj Open'}, inplace=True)

        # Open, High, Low
        z1 = (df.loc[:, ['Open', 'High', 'Low']].div(df['Close'], axis=0) - 1).rename(
            columns={'Open': 'open', 'High': 'high', 'Low': 'low'}) * 100
        # Close
        z2 = df[['Close']].pct_change().rename(columns={'Close': 'close'}) * 100
        # Adj Close
        z3 = df[['Adj Close']].pct_change().rename(columns={'Adj Close': 'adj_close'}) * 100

        z4 = []
        for k in [5, 10, 15, 20, 25, 30]:
            z4.append(df[['Adj Close']].apply(longterm_trend, k=k).rename(columns={'Adj Close': f'zd{k}'}))

        df_pct = pd.concat([df['Date'], z1, z2, z3] + z4, axis=1).rename(columns={'Date': 'date'})
        cols_max = df_pct.columns[df_pct.isnull().sum() == df_pct.isnull().sum().max()]
        df_pct = df_pct.loc[~df_pct[cols_max].isnull().values, :]

        # from https://arxiv.org/abs/1810.09936
        # Examples with movement percent ≥ 0.55% and ≤ −0.5% are 
        # identified as positive and negative examples, respectively
        df_pct['label'] = 0
        df_pct.loc[(df_pct['close'] >= 0.55), 'label'] = 1
        df_pct.loc[(df_pct['close'] <= -0.5), 'label'] = -1
        return df_pct

    def symbols(self):
        return list(self.data.keys())

    def generate_tasks(self):
        all_tasks = defaultdict()
        for window_size in self.window_sizes:
            tasks = self.generate_tasks_per_window_size(window_size)
            all_tasks[window_size] = tasks
        return all_tasks

    def generate_tasks_per_window_size(self, window_size):
        tasks = defaultdict(list)
        for i in range(self.n_tasks_per_window):
            symbol = np.random.choice(self.symbols())
            data = self.generate_task_per_window_size_and_single_stock(symbol, window_size)
            for k, v in data.items():
                tasks[k].append(v)
                    
        for k, v in tasks.items():
            tasks[k] = np.concatenate(v, axis=0)
        return tasks

    def generate_task_per_window_size_and_single_stock(self, symnbol, window_size):
        df_stock = self.data[symnbol]
        labels_indices = df_stock.index[df_stock['label'].isin([-1, 1])].to_numpy()
        labels_candidates = labels_indices[window_size:-self.n_lag]
        y_s = np.array(sorted(np.random.choice(labels_candidates, size=(self.n_window_tasks,), replace=False)))
        support, support_labels = self.generate_data(df_stock, y_s, window_size)

        y_q = labels_indices[np.arange(len(labels_indices))[np.isin(labels_indices, y_s)] + self.n_lag]
        query, query_labels = self.generate_data(df_stock, y_q, window_size)
        return {
            'support': support, 'support_labels': support_labels,
            'query': query, 'query_labels': query_labels
        }

    def generate_data(self, df, y_index, window_size):
        # generate mini task
        inputs = []
        labels = []
        for i, j in zip(y_index-window_size-1, y_index):
            inputs.append(df.iloc[i:j-1, 1:-1].to_numpy())
            if self.show_y_index:
                labels.append(j)
            else:
                labels.append(df.iloc[j, -1])
        return np.stack(inputs), np.array(labels)

In [277]:
# Constructor arguments shared by every split; only `meta_type` (and, for the
# test splits, the meta-train symbol list) differs between them.
comm_kwargs = dict(
    data_dir=data_path,
    dtype='kdd17',
    n_stocks=40,
    n_window_tasks=5,
    n_lag=1,
    n_tasks_per_window=50,
    show_y_index=True,  # debugging: labels are returned as row indices
)

meta_train = MetaStockDataset(meta_type='train', meta_train_stocks=None, **comm_kwargs)
meta_test1, meta_test2, meta_test3 = (
    MetaStockDataset(meta_type=split, meta_train_stocks=meta_train.symbols(), **comm_kwargs)
    for split in ('test1', 'test2', 'test3')
)

# number of stocks in each split
tuple(len(split_ds.symbols()) for split_ds in (meta_train, meta_test1, meta_test2, meta_test3))

(40, 40, 10, 10)

In [278]:
# generate_tasks samples stocks and target days via np.random.choice, so seed
# NumPy's global RNG here to make re-running this cell reproduce the same tasks.
TASK_SEED = 0
np.random.seed(TASK_SEED)
tasks = meta_train.generate_tasks()

In [279]:
# Show the array shapes per window size, plus the first few target row
# indices for the label arrays (show_y_index=True above).
for window_size in tasks:
    print(f'Window size = {window_size}')
    for name, arr in tasks[window_size].items():
        print(f'  {name}: {arr.shape}')
        if 'labels' not in name:
            continue
        print('  y_index: ', arr[:5])

Window size = 5
  support: (250, 5, 11)
  support_labels: (250,)
   [ 554  604  850 1021 1696]
  query: (250, 5, 11)
  query_labels: (250,)
   [ 555  606  851 1022 1697]
Window size = 10
  support: (250, 10, 11)
  support_labels: (250,)
   [ 110  452  486 1248 2057]
  query: (250, 10, 11)
  query_labels: (250,)
   [ 112  453  487 1249 2058]
Window size = 15
  support: (250, 15, 11)
  support_labels: (250,)
   [ 459  894 1209 1630 2050]
  query: (250, 15, 11)
  query_labels: (250,)
   [ 460  895 1210 1634 2054]
Window size = 20
  support: (250, 20, 11)
  support_labels: (250,)
   [ 490 1148 1582 1933 2217]
  query: (250, 20, 11)
  query_labels: (250,)
   [ 491 1149 1584 1935 2218]


---

# Model

In [None]:
import torch
import torch.nn as nn
import pytorch_lightning as pl

class TimeAxisAttention(nn.Module):
    """LSTM encoder followed by attention over the time axis.

    Each sequence is summarised as an attention-weighted sum of its LSTM
    outputs, scored against the top layer's final hidden state, and the
    result is layer-normalised.
    """
    def __init__(self, input_size: int, hidden_size: int, num_layers: int):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=False)
        self.lnorm = nn.LayerNorm(hidden_size)

    def forward(self, x: torch.Tensor, rt_attn=False):
        """
        Args:
            x: (D, W, L) — D sequences of length W with L features (batch_first).
            rt_attn: also return the attention weights.

        Returns:
            (context, attn): context is (D, H); attn is (D, W) when `rt_attn`
            is set, otherwise None.
        """
        o, (h, _) = self.lstm(x)  # o: (D, W, H) / h: (num_layers, D, H)
        # Query with the top layer's final state only; using every layer of `h`
        # would give one score column per layer and break the bmm below
        # whenever num_layers > 1.
        query = h[-1].unsqueeze(-1)  # (D, H, 1)
        score = torch.bmm(o, query)  # (D, W, 1)
        tx_attn = torch.softmax(score, 1).squeeze(-1)  # (D, W)
        context = torch.bmm(tx_attn.unsqueeze(1), o).squeeze(1)  # (D, 1, W) x (D, W, H) -> (D, H)
        normed_context = self.lnorm(context)
        return normed_context, (tx_attn if rt_attn else None)
            
class DataAxisAttention(nn.Module):
    """Multi-head self-attention across stocks followed by a residual MLP block."""
    def __init__(self, hidden_size, n_heads, drop_rate=0.1):
        super().__init__()
        self.multi_attn = nn.MultiheadAttention(hidden_size, n_heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, 4*hidden_size),
            nn.ReLU(),
            nn.Linear(4*hidden_size, hidden_size)
        )
        self.lnorm1 = nn.LayerNorm(hidden_size)
        self.lnorm2 = nn.LayerNorm(hidden_size)
        self.drop_out = nn.Dropout(drop_rate)

    def forward(self, hm: torch.Tensor, rt_attn=False):
        """
        Args:
            hm: (D, H) multi-level contexts, one row per stock. A 2-D input is
                treated by nn.MultiheadAttention as one unbatched sequence of D tokens.
            rt_attn: also return the (D, D) attention weights.

        Returns:
            (hp, attn): hp is (D, H); attn is (D, D) when `rt_attn` is set,
            otherwise None.
        """
        # Forward Multi-head Attention
        residual = hm
        # Only ask for the attention map when the caller wants it back.
        hm_hat, dx_attn = self.multi_attn(hm, hm, hm, need_weights=rt_attn)
        hm_hat = self.lnorm1(residual + self.drop_out(hm_hat))

        # Forward FFN
        residual = hm_hat
        mixed = hm + hm_hat
        hp = torch.tanh(mixed + self.mlp(mixed))  # (D, H)
        hp = self.lnorm2(residual + self.drop_out(hp))

        return hp, (dx_attn if rt_attn else None)

class DTML(nn.Module):
    """DTML model head.

    Encodes every stock and the market index with one shared TimeAxisAttention,
    adds the index context (weighted by `beta`) to each stock context, applies
    DataAxisAttention across stocks, and maps each stock to one output value.
    """
    def __init__(self, input_size, hidden_size, num_layers, n_heads, beta=0.1, drop_rate=0.1):
        super().__init__()
        self.beta = beta
        self.txattention = TimeAxisAttention(input_size, hidden_size, num_layers)
        self.dxattention = DataAxisAttention(hidden_size, n_heads, drop_rate)
        self.linear = nn.Linear(hidden_size, 1)

    def forward(self, stocks, index, rt_attn=False):
        """
        Args:
            stocks: (W, D, L) for a single time stamp.
            index: (W, 1, L) for a single time stamp.
                W: length of observations, D: number of stocks, L: number of features.
            rt_attn: also return the attention maps (otherwise they are None).

        Returns:
            dict with 'output' (D, 1), 'tx_attn_stocks' (D, W), 'tx_attn_index'
            (1, W), 'dx_attn_stocks' (D, D) and 'effect' (D, D).
        """
        # Time-Axis Attention
        # c_stocks: (D, H) / tx_attn_stocks: (D, W)
        c_stocks, tx_attn_stocks = self.txattention(stocks.transpose(1, 0), rt_attn=rt_attn)
        # c_index: (1, H) / tx_attn_index: (1, W)
        c_index, tx_attn_index = self.txattention(index.transpose(1, 0), rt_attn=rt_attn)

        # Context Aggregation
        # Multi-level Context
        # hm: (D, H)
        hm = c_stocks + self.beta * c_index
        # The Effect of Global Contexts: pairwise dot products of the
        # multi-level contexts,
        #   (c_u + b*c_i)·(c_v + b*c_i) = c_u·c_v + b*(c_i·c_v + c_u·c_i) + b^2 * c_i·c_i,
        # computed directly as hm @ hm.T so both cross terms are included.
        # effect: (D, D), symmetric
        effect = hm.mm(hm.transpose(0, 1))

        # Data-Axis Attention
        # hp: (D, H) / dx_attn: (D, D)
        hp, dx_attn_stocks = self.dxattention(hm, rt_attn=rt_attn)
        # output: (D, 1)
        output = self.linear(hp)

        return {
            'output': output,
            'tx_attn_stocks': tx_attn_stocks,
            'tx_attn_index': tx_attn_index,
            'dx_attn_stocks': dx_attn_stocks,
            'effect': effect
        }