In [1]:
import torch
import numpy as np
import pandas as pd
import pytorch_lightning as pl
from pathlib import Path
from typing import Dict, List, Any, Generator

from collections import defaultdict

src_path = Path('.').absolute().parent
data_path = src_path / 'data'

print(torch.__version__)
print(torch.cuda.is_available())

1.11.0+cu113
True


# Task

There are many way to construct the tasks.

Here we define a task is sample from a distribution $\tau_t \sim P(\tau \vert s_i, w_j)$ where $s_i$ is a fixed single stock and $w_j$ is a fixed window size. So, task $\tau_{t} = \{ (\mathbf{X}_{(t-w):t}, y_{t}) \}$, where $t > w$.

Our meta-dataset is $\mathcal{D}^{tr}_{\tau_t} = \{ X_{(t-w):t}, y_{t} \}$,  $\mathcal{D}^{val}_{\tau_t} = \{ X_{(t+k-w):t+k}, y_{t+k} \}$ where $t+k$ is the next rise/fall timestep after $t$.

Example inputs $X$ for $\mathcal{D}^{tr}_{\tau_t}$ with 5 samples, 4 fixed window sizes and 3 stocks.

| stock \ window size | $w_1=5$ | $w_2=10$ | $w_3=15$ | $w_4=20$ |
|---|---|---|---|---|
| $s_1$ | $X_{(t_1-w_1):t_1}^{s_1}$ |  $X_{(t_1-w_2):t_1}^{s_1}$ |  $X_{(t_1-w_3):t_1}^{s_1}$ |  $X_{(t_1-w_4):t_1}^{s_1}$ |
|       | $X_{(t_2-w_1):t_2}^{s_1}$ |  $X_{(t_2-w_1):t_2}^{s_1}$ |  $X_{(t_2-w_1):t_2}^{s_1}$ |  $X_{(t_2-w_1):t_2}^{s_1}$ |
|       | $X_{(t_3-w_1):t_3}^{s_1}$ |  $X_{(t_3-w_1):t_3}^{s_1}$ |  $X_{(t_3-w_1):t_3}^{s_1}$ |  $X_{(t_3-w_1):t_3}^{s_1}$ |
|       | $X_{(t_4-w_1):t_4}^{s_1}$ |  $X_{(t_4-w_1):t_4}^{s_1}$ |  $X_{(t_4-w_1):t_4}^{s_1}$ |  $X_{(t_4-w_1):t_4}^{s_1}$ |
|       | $X_{(t_5-w_1):t_5}^{s_1}$ |  $X_{(t_5-w_1):t_5}^{s_1}$ |  $X_{(t_5-w_1):t_5}^{s_1}$ |  $X_{(t_5-w_1):t_5}^{s_1}$ |
| $s_2$ | $X_{(t_1-w_1):t_1}^{s_2}$ |  $X_{(t_1-w_1):t_1}^{s_2}$ |  $X_{(t_1-w_1):t_1}^{s_2}$ |  $X_{(t_1-w_1):t_1}^{s_2}$ |
|       | $X_{(t_2-w_1):t_2}^{s_2}$ |  $X_{(t_2-w_1):t_2}^{s_2}$ |  $X_{(t_2-w_1):t_2}^{s_2}$ |  $X_{(t_2-w_1):t_2}^{s_2}$ |
|       | $X_{(t_3-w_1):t_3}^{s_2}$ |  $X_{(t_3-w_1):t_3}^{s_2}$ |  $X_{(t_3-w_1):t_3}^{s_2}$ |  $X_{(t_3-w_1):t_3}^{s_2}$ |
|       | $X_{(t_4-w_1):t_4}^{s_2}$ |  $X_{(t_4-w_1):t_4}^{s_2}$ |  $X_{(t_4-w_1):t_4}^{s_2}$ |  $X_{(t_4-w_1):t_4}^{s_2}$ |
|       | $X_{(t_5-w_1):t_5}^{s_2}$ |  $X_{(t_5-w_1):t_5}^{s_2}$ |  $X_{(t_5-w_1):t_5}^{s_2}$ |  $X_{(t_5-w_1):t_5}^{s_2}$ |
| $s_3$ | $X_{(t_1-w_1):t_1}^{s_3}$ |  $X_{(t_1-w_1):t_1}^{s_3}$ |  $X_{(t_1-w_1):t_1}^{s_3}$ |  $X_{(t_1-w_1):t_1}^{s_3}$ |
|       | $X_{(t_2-w_1):t_2}^{s_3}$ |  $X_{(t_2-w_1):t_2}^{s_3}$ |  $X_{(t_2-w_1):t_2}^{s_3}$ |  $X_{(t_2-w_1):t_2}^{s_3}$ |
|       | $X_{(t_3-w_1):t_3}^{s_3}$ |  $X_{(t_3-w_1):t_3}^{s_3}$ |  $X_{(t_3-w_1):t_3}^{s_3}$ |  $X_{(t_3-w_1):t_3}^{s_3}$ |
|       | $X_{(t_4-w_1):t_4}^{s_3}$ |  $X_{(t_4-w_1):t_4}^{s_3}$ |  $X_{(t_4-w_1):t_4}^{s_3}$ |  $X_{(t_4-w_1):t_4}^{s_3}$ |
|       | $X_{(t_5-w_1):t_5}^{s_3}$ |  $X_{(t_5-w_1):t_5}^{s_3}$ |  $X_{(t_5-w_1):t_5}^{s_3}$ |  $X_{(t_5-w_1):t_5}^{s_3}$ |
 
However there is another way to think about tha task. Since in some stock market theories, each time stamp for a stock has a different distribution. So based on former definition, we also treat each timestamp to a task $\tau \sim P(\tau \vert s_i, w_j, t_k)$ 

this will make a little difference when we calculate the $\mathbf{z}$
1. each stock has different distribution, but each timestamp has same distribution for the same stock. $\mathbf{z}_s \sim \mathcal{N}(\mu_s, \sigma_s)$
2. each timestamp of stock has different distribution: $\mathbf{z}_{(s, t)} \sim \mathcal{N}(\mu_{(s, t)}, \sigma_{(s, t)})$


# Scenario

pop trading

# Data Loader

Need Python 3.10

In [2]:
def flatten(li: List[Any]) -> Generator:
    """flatten nested list
    ```python
    x = [[[1], 2], [[[[3]], 4, 5], 6], 7, [[8]], [9], 10]
    print(type(flatten(x)))
    # <generator object flatten at 0x00000212BF603CC8>
    print(list(flatten(x)))
    # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    ```
    Args:
        li (List[Any]): any kinds of list
    Yields:
        Generator: flattened list generator
    """
    for ele in li:
        if isinstance(ele, list) or isinstance(ele, tuple):
            yield from flatten(ele)
        else:
            yield ele

In [3]:
import yaml
import inspect

class ARGProcessor():
    def __init__(self, setting_file):
        self.setting_file = setting_file
        self.load()

    def load(self):
        with open(self.setting_file) as file:
            self.kwargs = yaml.load(file, Loader=yaml.FullLoader)
    
    def get_args(self, cls):
        cls_kwargs = {
            k: self.kwargs.get(k) for k in inspect.signature(cls.__init__).parameters.keys() 
            if self.kwargs.get(k) is not None
        }
        return cls_kwargs

class MetaStockDataset(torch.utils.data.Dataset):
    def __init__(
            self, 
            meta_type: str ='train', 
            meta_train_stocks: List[str] | None =None,
            data_dir: Path | str ='', 
            dtype: str ='kdd17', 
            n_train_stock: int =40, 
            n_sample: int =5,
            n_lag: int =1, 
            n_stock: int =5,
            keep_support_history: bool=False,
            show_y_index: bool=False
        ):
        """
        dataset ref: https://arxiv.org/abs/1810.09936
        In this meta learning setting, we have 3 meta-test and 1 meta-train
        vertical = stocks, horizontal = time
                train      |    test
           A               |
           B   meta-train  |   meta-test
           C               |      (1)
           ----------------|-------------
           D   meta-test   |   meta-test
           E     (2)       |      (3)

        meta-test (1) same stock, different time
        meta-test (2) different stock, same time
        meta-test (3) different stock, different time
        use `valid_date` to split the train / test set
        """
        super().__init__()
        # for debugging purpose
        self.labels_dict = {
            'fall': 0, 'rise': 1, 'unchange': 2 
        }
        self.keep_support_history = keep_support_history
        self.show_y_index = show_y_index
        # data config
        self.data_dir = Path(data_dir)
        ds_info = {
            # train: (Jan-01-2007 to Jan-01-2015)
            # val: (Jan-01-2015 to Jan-01-2016)
            # test: (Jan-01-2016 to Jan-01-2017)
            'kdd17': {
                'path': self.data_dir / 'kdd17/price_long_50',
                'date': self.data_dir / 'kdd17/trading_dates.csv',
                'train_date': '2015-01-01', 
                'val_date': '2016-01-01', 
                'test_date': '2017-01-01',
            },
            # train: (Jan-01-2014 to Aug-01-2015)
            # vali: (Aug-01-2015 to Oct-01-2015)
            # test: (Oct-01-2015 to Jan-01-2016)
            'acl18': {
                'path': self.data_dir / 'stocknet-dataset/price/raw',
                'date': self.data_dir / 'stocknet-dataset/price/trading_dates.csv',
                'train_date': '2015-08-01', 
                'val_date': '2015-10-01', 
                'test_date': '2016-01-01',
            }
        }
        
        ds_config = ds_info[dtype]

        self.window_sizes = [5, 10, 15, 20]
        self.n_sample = n_sample
        self.n_lag = n_lag
        self.n_stock = n_stock

        # get data
        self.data = {}
        for i, p in enumerate((data_path / ds_config['path']).glob('*')):
            if meta_type == 'train' and (i == n_train_stock):
                # stop when it reach `n_train_stock`
                break
            
            stock_symbol = p.name.rstrip('.csv')
            df_single = self.load_single_stock(p)
            if meta_type == 'train':
                df_single = df_single.loc[df_single['date'] <= ds_config['val_date']]
            else:
                if meta_type == 'test1':
                    if stock_symbol in meta_train_stocks:
                        df_single = df_single.loc[df_single['date'] > ds_config['val_date']]
                    else:
                        continue
                elif meta_type == 'test2':
                    if stock_symbol not in meta_train_stocks:
                        df_single = df_single.loc[df_single['date'] <= ds_config['val_date']]
                    else:
                        continue
                elif meta_type == 'test3':
                    if stock_symbol not in meta_train_stocks:
                        df_single = df_single.loc[df_single['date'] > ds_config['val_date']]
                    else:
                        continue
                else:
                    raise KeyError('Error argument `meta_type`, should be in (train, test1, test2, test3)')

            self.data[stock_symbol] = df_single.reset_index(drop=True)

    def load_single_stock(self, p: Path | str):
        def longterm_trend(x: pd.Series, k:int):
            return (x.rolling(k).sum().div(k*x) - 1) * 100

        df = pd.read_csv(p)
        df['Date'] = pd.to_datetime(df['Date'])
        df = df.sort_values('Date').reset_index(drop=True)
        if 'Unnamed' in df.columns:
            df.drop(columns=df.columns[7], inplace=True)
        if 'Original_Open' in df.columns:
            df.rename(columns={'Original_Open': 'Open', 'Open': 'Adj Open'}, inplace=True)

        # Open, High, Low
        z1 = (df.loc[:, ['Open', 'High', 'Low']].div(df['Close'], axis=0) - 1).rename(
            columns={'Open': 'open', 'High': 'high', 'Low': 'low'}) * 100
        # Close
        z2 = df[['Close']].pct_change().rename(columns={'Close': 'close'}) * 100
        # Adj Close
        z3 = df[['Adj Close']].pct_change().rename(columns={'Adj Close': 'adj_close'}) * 100

        z4 = []
        for k in [5, 10, 15, 20, 25, 30]:
            z4.append(df[['Adj Close']].apply(longterm_trend, k=k).rename(columns={'Adj Close': f'zd{k}'}))

        df_pct = pd.concat([df['Date'], z1, z2, z3] + z4, axis=1).rename(columns={'Date': 'date'})
        cols_max = df_pct.columns[df_pct.isnull().sum() == df_pct.isnull().sum().max()]
        df_pct = df_pct.loc[~df_pct[cols_max].isnull().values, :]

        # from https://arxiv.org/abs/1810.09936
        # Examples with movement percent ≥ 0.55% and ≤ −0.5% are 
        # identified as positive and negative examples, respectively
        df_pct['label'] = self.labels_dict['unchange']
        df_pct.loc[(df_pct['close'] >= 0.55), 'label'] = self.labels_dict['rise']
        df_pct.loc[(df_pct['close'] <= -0.5), 'label'] = self.labels_dict['fall']
        
        return df_pct

    def check_func(self, x):
        checks = [self.labels_dict['fall'], self.labels_dict['rise']]
        return np.isin(x.values[0], checks) and np.isin(x.values[1], checks)

    @property
    def symbols(self):
        return list(self.data.keys())

    def generate_tasks(self):
        all_tasks = defaultdict()
        for window_size in self.window_sizes:
            tasks = self.generate_tasks_per_window_size(window_size)
            all_tasks[window_size] = tasks
        return all_tasks

    def generate_tasks_per_window_size(self, window_size):
        # tasks: {X: (n_stock, n_sample, window_size, n_in), y: (n_stock, n_sample)}
        tasks = defaultdict(list)
        for i in range(self.n_stock):
            symbol = np.random.choice(self.symbols)
            # data: {X: (n_sample, n_in), y: (n_sample,)}
            data = self.generate_task_per_window_size_and_single_stock(symbol, window_size)
            for k, v in data.items():
                tasks[k].append(v)

        for k, v in tasks.items():
            tasks[k] = list(flatten(v))

        return tasks

    def generate_task_per_window_size_and_single_stock(self, symnbol, window_size):
        df_stock = self.data[symnbol]
        # condition: only continious rise or fall
        condition = df_stock['label'].rolling(2).apply(self.check_func).shift(-meta_train.n_lag).fillna(0.0).astype(bool)
        labels_indices = df_stock.index[condition].to_numpy()
        # code for jumpped tags like [1(support), 0, 0, 1(query)]
        # labels_indices = df_stock.index[df_stock['label'].isin([self.labels_dict['fall'], self.labels_dict['rise']])].to_numpy()
        labels_candidates = labels_indices[labels_indices >= window_size]
        y_s = np.array(sorted(np.random.choice(labels_candidates, size=(self.n_sample,), replace=False)))
        y_ss = y_s-window_size
        support, support_labels = self.generate_data(df_stock, y_start=y_ss, y_end=y_s)
        
        # code for jumpped tags like [1(support), 0, 0, 1(query)]
        # y_q = labels_indices[np.arange(len(labels_indices))[np.isin(labels_indices, y_s)] + self.n_lag]
        y_q = y_s + self.n_lag
        y_qs = y_s - window_size if self.keep_support_history else y_q - window_size
        query, query_labels = self.generate_data(df_stock, y_start=y_qs, y_end=y_q)
        
        return {
            'support': support, 'support_labels': support_labels,
            'query': query, 'query_labels': query_labels
        }

    def generate_data(self, df, y_start, y_end):
        # generate mini task
        inputs = []
        labels = []
        for i, j in zip(y_start, y_end):
            inputs.append(df.loc[i:j-1].to_numpy()[:, 1:-1].astype(np.float64))
            if self.show_y_index:
                labels.append(j)
            else:
                labels.append(df.loc[j].iloc[-1].astype(np.uint8))

        # inputs: (n_sample, y_end-y_start, n_in), labels: (n_sample,)
        return inputs, labels

    def map_to_tensor(self, tasks, device: None | str=None):
        if device is None:
            device = torch.device('cpu')
        else:
            device = torch.device(device)
        tensor_tasks = {}
        for k, v in tasks.items():
            tensor = torch.LongTensor if 'labels' in k else torch.FloatTensor
            tensor_tasks[k] = tensor(np.array(v)).to(device)
        return tensor_tasks

In [4]:
args = ARGProcessor(setting_file='settings.yml')
data_kwargs = args.get_args(cls=MetaStockDataset)
meta_train = MetaStockDataset(meta_type='train', meta_train_stocks=None, **data_kwargs)
tasks = meta_train.generate_tasks()

In [5]:
for window_size, t in tasks.items():
    print(f'Window size = {window_size}')
    for k, v in t.items():
        if 'labels' in k:
            print('  y_index: ', v[:10])
            print('  --------')
        else:
            print(f'  {k}: {np.array(v).shape}')

Window size = 5
  support: (15, 5, 11)
  y_index:  [1, 0, 0, 0, 1, 0, 1, 0, 1, 0]
  --------
  query: (15, 6, 11)
  y_index:  [0, 1, 0, 1, 1, 0, 0, 1, 1, 0]
  --------
Window size = 10
  support: (15, 10, 11)
  y_index:  [0, 0, 1, 1, 1, 0, 1, 0, 1, 0]
  --------
  query: (15, 11, 11)
  y_index:  [0, 1, 0, 1, 0, 0, 1, 1, 1, 0]
  --------
Window size = 15
  support: (15, 15, 11)
  y_index:  [0, 1, 0, 1, 1, 1, 0, 1, 0, 0]
  --------
  query: (15, 16, 11)
  y_index:  [1, 1, 1, 0, 1, 1, 0, 1, 0, 1]
  --------
Window size = 20
  support: (15, 20, 11)
  y_index:  [0, 0, 0, 1, 0, 1, 0, 1, 1, 0]
  --------
  query: (15, 21, 11)
  y_index:  [0, 1, 1, 0, 1, 0, 1, 0, 0, 1]
  --------


In [4]:
# meta_train = MetaStockDataset(meta_type='train', meta_train_stocks=None, **comm_kwargs)
# meta_test1 = MetaStockDataset(meta_type='test1', meta_train_stocks=meta_train.symbols(), **comm_kwargs)
# meta_test2 = MetaStockDataset(meta_type='test2', meta_train_stocks=meta_train.symbols(), **comm_kwargs)
# meta_test3 = MetaStockDataset(meta_type='test3', meta_train_stocks=meta_train.symbols(), **comm_kwargs)

# len(meta_train.symbols()), len(meta_test1.symbols()), len(meta_test2.symbols()), len(meta_test3.symbols())

(40, 40, 10, 10)

---

# Model

In [6]:
import torch
import torch.nn as nn
import pytorch_lightning as pl

class LSTM(nn.Module):
    def __init__(self, input_size: int, hidden_size: int, num_layers: int):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=False)
        self.lnorm = nn.LayerNorm(hidden_size)
    
    def forward(self, x: torch.tensor):
        # x: (B, T, I)
        o, (h, _) = self.lstm(x) # o: (B, T, H) / h: (1, B, H)
        normed_context = self.lnorm(h)
        return normed_context

class LSTMAttention(nn.Module):
    def __init__(self, input_size: int, hidden_size: int, num_layers: int):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=False)
        self.lnorm = nn.LayerNorm(hidden_size)

    def forward(self, x: torch.tensor, rt_attn: bool=False):
        # x: (B, T, I)
        o, (h, _) = self.lstm(x) # o: (B, T, H) / h: (1, B, H)
        score = torch.bmm(o, h.permute(1, 2, 0)) # (B, T, H) x (B, H, 1)
        attn = torch.softmax(score, 1).squeeze(-1)  # (B, T)
        context = torch.bmm(attn.unsqueeze(1), o).squeeze(1)  # (B, 1, T) x (B, T, H)
        normed_context = self.lnorm(context)  # (B, H)
        if rt_attn:
            return normed_context, attn
        else:
            return normed_context, None

class MappingNet(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.rn = nn.Sequential(
            nn.Linear(hidden_size, 2*hidden_size, bias=False),
            nn.ReLU(),
            nn.Linear(2*hidden_size, 2*hidden_size, bias=False),
        )

    def forward(self, x: torch.tensor):
        # x: (B, H)
        outputs = self.rn(x)
        return outputs

class Model(nn.Module):
    def __init__(
            self, 
            feature_size: int, 
            hidden_size: int, 
            output_size: int, 
            num_layers: int, 
            drop_rate: float, 
            n_sample: int,
            inner_lr_init: float,
            finetuning_lr_init: float
        ):
        super().__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.parameter_size = hidden_size*output_size
        self.n_sample = n_sample
        
        self.inner_lr = nn.Parameter(torch.FloatTensor([inner_lr_init]))
        self.finetuning_lr = nn.Parameter(torch.FloatTensor([finetuning_lr_init]))

        self.dropout = nn.Dropout(drop_rate)
        self.feature_transform = nn.Linear(feature_size, hidden_size)
        self.lstm = LSTMAttention(hidden_size, hidden_size, num_layers)  # encode
        self.mapping_net = MappingNet(hidden_size)  # to generate z(latent)
        self.decoder = nn.Linear(hidden_size, 2*self.parameter_size, bias=False)
        self.prob_layer = nn.LogSoftmax(dim=1) if output_size >=2 else nn.LogSigmoid()

    def encode(self, inputs, rt_attn: bool=False):
        # inputs: (B, T, I)
        inputs = self.dropout(inputs)
        inputs = self.feature_transform(inputs)
        encoded, attn = self.lstm(inputs, rt_attn)  # B, H
        return encoded, attn

    def get_z(self, inputs, rt_attn: bool=False):
        # inputs: (B, T, I)
        encoded, attn = self.encode(inputs, rt_attn=rt_attn)
        hs = self.mapping_net(encoded)

        z, dist = self.sample(hs, size=self.hidden_size)
        return encoded, z, dist, attn

    def sample(self, distribution_params, size):
        mean, log_std = distribution_params[:, :size], distribution_params[:, size:]
        std = torch.exp(log_std)
        dist = torch.distributions.Normal(torch.zeros_like(mean), torch.ones_like(std))
        z = dist.rsample()
        return mean + std*z, dist

    def decode(self, z):
        param_hs = self.decoder(z)
        parameters, _ = self.sample(param_hs, size=self.parameter_size)
        return parameters

    def predict(self, encoded, parameters):
        theta = parameters.view(-1, self.hidden_size, self.output_size)
        scores = encoded.unsqueeze(1).bmm(theta).squeeze()
        probs = self.prob_layer(scores)
        return probs
        
    def cal_accuracy(self, log_probs, target):
        if self.output_size >= 2:
            pred = log_probs.argmax(1)
        else:
            pred = (torch.exp(log_probs) >= 0.5).long()
        correct = pred.eq(target).sum()
        acc = correct / len(target)
        return acc
        

    def inner_loop(self, data, n_inner_step: int=5):
        support_X, support_y = data['support'], data['support_labels']
        support_encoded, support_z, support_dist, support_attn = self.get_z(support_X, rt_attn=False)
        kld_loss = self.cal_kl_div(support_dist, support_z)
        
        z_init = z.clone().detach()

        # inner adaptation to z
        for i in range(n_inner_step):
            z.retain_grad()
            parameters = self.decode(z)
            log_probs = self.predict(support_encoded, parameters)
            train_loss = self.loss_fn(log_probs, support_y)
            train_loss.backward(retain_graph=True)

            z = z - self.inner_lr * z.grad.data
            
        z_penalty = torch.mean((z_init - z)**2)
        return support_encoded, support_z, kld_loss, z_penalty, support_attn

    def outer_loop(self, data, support_z, support_encoded, n_finetuning_step: int=0, rt_attn:bool =False):
        # finetuning inner + validation
        records = {'Training Loss': 0.0, 'Training Accuracy': 0.0, 'Inner LR': 0.0, 'Finetuning LR': 0.0}
        # inner loop prediction
        support_y, query_X, query_y = data['support_labels'], data['query'], data['query_labels']
        parameters = self.decode(support_z)
        parameters.retain_grad()
        support_log_probs = self.predict(support_encoded, parameters)
        train_loss = self.loss_fn(support_log_probs, support_y)
        train_acc = self.cal_accuracy(support_log_probs, support_y)

        # logging
        records['Training Loss'] = train_loss.item()
        records['Training Accuracy'] = train_acc.item()
        records['Inner LR'] = float(self.inner_l_rate)
        records['Finetuning LR'] = float(self.inner_l_rate)

        # finetuning adaptation to parameters
        for i in range(n_finetuning_step):
            train_loss.backward(retain_graph=True)
            parameters = parameters - self.finetuning_lr * parameters.grad
            parameters.retain_grad()
            support_log_probs = self.predict(support_encoded, parameters)
            train_loss = self.loss_fn(support_log_probs, support_y)
            
        # meta validation     
        query_encoded, *_, query_attn = self.encode(query_X, rt_attn=rt_attn)
        query_log_probs = self.predict(query_encoded, parameters)
        query_loss = self.loss_fn(query_log_probs, query_y)
        query_acc = self.cal_accuracy(query_log_probs, query_y)

        return query_loss, query_acc, query_attn, records

    def cal_kl_div(self, dist, z):
        normal = torch.distributions.Normal(torch.zeros_like(z), torch.ones_like(z))
        return torch.mean(dist.log_prob(z) - normal.log_prob(z))

    def cal_total_loss(self, query_loss, kld_loss, z_penalty, beta, gamma, lambda2):
        orthogonality_penalty = self.orthgonality_constraint(list(self.decoder.parameters())[0])
        total_loss = query_loss + beta*kld_loss + gamma*z_penalty + lambda2*orthogonality_penalty
        return total_loss

    def orthgonality_constraint(self, params):
        # purpose: encourages the dimensions of the latend code as well as the decoder network to be maximally expressive
        # number of class x hidden_size x 2(mean, std)
        p_dot = params.mm(params.transpose(0, 1))
        p_norm = torch.norm(params, dim=1, keepdim=True) + 1e-15
        corr = p_dot / p_norm.mm(p_norm.transpose(0, 1))
        corr.masked_fill_(corr>1.0, 1.0)
        corr.masked_fill_(corr<-1.0, -1.0)
        I = torch.eye(corr.size(0)).to(corr.device)
        orthogonality_penalty = torch.mean((corr - I)**2)
        return orthogonality_penalty

    def forward(
            self, data, 
            n_inner_step: int=5, 
            n_finetuning_step:int =5, 
            rt_attn: bool=False
        ):
        support_encoded, support_z, kld_loss, z_penalty, support_attn = self.inner_loop(data, n_inner_step, rt_attn)
        query_loss, query_acc, query_attn, records = self.outer_loop(data, support_z, support_encoded, n_finetuning_step, rt_attn)
        
        return query_loss, query_acc, kld_loss, z_penalty, support_attn, query_attn, records

    def meta_run(self, data, 
            beta: float=0.001, 
            gamma: float=1e-9, 
            lambda2: float=0.1,
            n_inner_step: int=5, 
            n_finetuning_step:int =5,
            rt_attn: bool=False
        ):
        query_loss, query_acc, kld_loss, z_penalty, support_attn, query_attn, records = self(data, n_inner_step, n_finetuning_step, rt_attn)
        total_loss = self.cal_total_loss(query_loss, kld_loss, z_penalty, beta, gamma, lambda2)
        # logging
        records['Valid Loss'] = total_loss.item()
        records['Valid Accuracy'] = query_acc.item()
        records['Query Attn'] = None if query_attn is None else query_attn.detach().numpy()
        records['Support Attn'] = None if support_attn is None else support_attn.detach().numpy()

        return total_loss, records


* LEO-Deepmind: https://github.com/deepmind/leo/blob/de9a0c2a77dd7a42c1986b1eef18d184a86e294a/model.py#L256
* LEO-pytorch: https://github.com/timchen0618/pytorch-leo/blob/master/model.py

In [250]:
n_sample = 5
n_stock = 3
window_size = 5
feature_size = 11
output_size = 1
hidden_size, num_layers, drop_rate = 20, 1, 0.2
parameter_size = output_size*hidden_size

lstm = LSTMAttention(feature_size, hidden_size, num_layers)  # encode
mapping_net = MappingNet(hidden_size)  # generate z
decoder = nn.Linear(hidden_size, 2*parameter_size)

sample data

In [333]:
x = torch.randn(n_stock*n_sample, window_size, feature_size)
y = torch.ones(n_stock*n_sample)

print('--- Encoder ---')
# lstm
encoded, attn = lstm(x, rt_attn=False)
# mapping
hs = mapping_net(encoded)
# encoded, hs, attn = encoder(x, rt_attn=False)
print(f'encoder: encoded - {encoded.size()} | mapping_net - {hs.size()}')

# sample
e_mean, e_std = hs[:, :hidden_size], hs[:, hidden_size:]
e_std = torch.exp(e_std)
print(f'mean: {e_mean.size()} | std: {e_std.size()}')
e_dist = torch.distributions.Normal(torch.zeros_like(e_mean), torch.ones_like(e_std))
z = e_dist.rsample()
print(f'z: {z.size()}')
latents = e_mean + e_std*z
print(f'latent: {latents.size()}')

# decoder
print('--- Decoder ---')
param_hs = decoder(latents)
# sample
print(f'param_hs: {param_hs.size()}')
d_mean, d_std = param_hs[:, :parameter_size], param_hs[:, parameter_size:]
d_std = torch.exp(d_std)
print(f'mean: {d_mean.size()} | std: {d_std.size()}')
d_dist = torch.distributions.Normal(torch.zeros_like(d_mean), torch.ones_like(d_std))
z = d_dist.rsample()
print(f'z: {z.size()}')
parameters = d_mean + d_std*z
print(f'parameters: {parameters.size()}')

# score
theta = parameters.view(-1, hidden_size, output_size)
scores = encoded.unsqueeze(1).bmm(theta).squeeze(1)
print(f'scores: {scores.size()}')

--- Encoder ---
encoder: encoded - torch.Size([15, 20]) | mapping_net - torch.Size([15, 40])
mean: torch.Size([15, 20]) | std: torch.Size([15, 20])
z: torch.Size([15, 20])
latent: torch.Size([15, 20])
--- Decoder ---
param_hs: torch.Size([15, 40])
mean: torch.Size([15, 20]) | std: torch.Size([15, 20])
z: torch.Size([15, 20])
parameters: torch.Size([15, 20])
scores: torch.Size([15, 1])


In [345]:
feature_size = 11
hidden_size = 20 
output_size = 1  # 1 == 2 setting
num_layers = 1
drop_rate = 0.2
n_sample = 5

x = torch.randn(n_stock*n_sample, window_size, feature_size)
y = torch.ones(n_stock*n_sample, dtype=torch.long)

loss_fn = nn.NLLLoss()
model = Model(feature_size, hidden_size, output_size, num_layers, drop_rate, n_sample)

# forward?
encoded, z, dist, attn = model.get_z(x, rt_attn=False)
parameters = model.decode(z)
probs = model.predict(encoded, parameters)
print(scores.shape)
kl_loss = model.cal_kl_div(dist, z)
loss = loss_fn(probs, y)
print(f'Loss: {loss:.4f}, KLD: {kl_loss:.4f}')

---

# Trainer

In [14]:
args = ARGProcessor(setting_file='settings.yml')
data_kwargs = args.get_args(cls=MetaStockDataset)
meta_train = MetaStockDataset(meta_type='train', meta_train_stocks=None, **data_kwargs)
meta_test1 = MetaStockDataset(meta_type='test1', meta_train_stocks=meta_train.symbols, **data_kwargs)
meta_test2 = MetaStockDataset(meta_type='test2', meta_train_stocks=meta_train.symbols, **data_kwargs)
meta_test3 = MetaStockDataset(meta_type='test3', meta_train_stocks=meta_train.symbols, **data_kwargs)

print(f'Meta Train: {len(meta_train.symbols)}, Meta Test1: {len(meta_test1.symbols)}, Meta Test2: {len(meta_test2.symbols)}, Meta Test3: {len(meta_test3.symbols)}')
tasks = meta_train.generate_tasks()

Meta Train: 40, Meta Test1: 40, Meta Test2: 10, Meta Test3: 10


In [11]:
tasks = meta_train.generate_tasks()
for window_size, t in tasks.items():
    t_tensor = meta_train.map_to_tensor(t, device='cpu')
    break

In [12]:
model_kwargs = args.get_args(cls=Model)
model = Model(**model_kwargs)

In [13]:
train_X, train_y = t_tensor['support'], t_tensor['support_labels']
valid_X, valid_y = t_tensor['query'], t_tensor['query_labels']

In [143]:
from torch.utils.tensorboard import SummaryWriter

class Trainer():
    def __init__(
            self, 
            log_dir, 
            total_steps,
            n_inner_step, 
            n_finetuning_step, 
            beta,
            gamma,
            lambda1,
            lambda2,
            outer_lr,
            clip_value,
            device: str='cpu',
            print_step: int=5
        ):
        self.device = device
        self.print_step = print_step
        self.total_steps = total_steps
        self.n_inner_step = n_inner_step
        self.n_finetuning_step = n_finetuning_step
        
        self.beta = beta
        self.gamma = gamma
        self.lambda1 = lambda1  # penalty on model(encoder, mapping_net, decoder) parameters
        self.lambda2 = lambda2  # penalty on decoder
        self.outer_lr = outer_lr
        self.clip_value = clip_value
        
        self.loss_fn = nn.NLLLoss()
        self.writer = SummaryWriter(log_dir)

    def map_to_tensor(self, tasks, device: None | str=None):
        if device is None:
            device = torch.device('cpu')
        else:
            device = torch.device(device)
        tensor_tasks = {}
        for k, v in tasks.items():
            tensor = torch.LongTensor if 'labels' in k else torch.FloatTensor
            tensor_tasks[k] = tensor(np.array(v)).to(device)
        return tensor_tasks

    def inner_loop(self, model, data):
        train_X, train_y = data['support'], data['support_labels']
        encoded, z, dist, _ = model.get_z(train_X, rt_attn=False)
        kld_loss = model.cal_kl_div(dist, z)
        
        z_init = z.clone().detach()

        # inner adaptation to z
        for i in range(self.n_inner_step):
            z.retain_grad()
            parameters = model.decode(z)
            log_probs = model.predict(encoded, parameters)
            train_loss = self.loss_fn(log_probs, train_y)
            train_loss.backward(retain_graph=True)

            z = z - model.inner_lr * z.grad.data
            
        z_penalty = torch.mean((z_init - z)**2)
        return encoded, z, kld_loss, z_penalty
    
    def outer_loop(self, model, z, encoded, data, step, window_size):
        # finetuning inner + validation

        # inner loop prediction
        train_y, valid_X, valid_y = data['support_labels'], data['query'], data['query_labels']
        parameters = model.decode(z)
        parameters.retain_grad()
        log_probs = model.predict(encoded, parameters)
        train_loss = self.loss_fn(log_probs, train_y)
        train_acc = model.cal_accuracy(log_probs, train_y)

        # logging
        self.writer.add_scalar(f'{window_size}-Training Loss', train_loss.item(), step=step)
        self.writer.add_scalar(f'{window_size}-Training Accuracy', train_acc.item(), step=step)
        self.writer.add_scalar(f'{window_size}-Inner LR', float(model.inner_l_rate), step=step)
        self.writer.add_scalar(f'{window_size}-Finetuning LR', float(model.inner_l_rate), step=step)

        if (step % self.print_step == 0) or (step == self.total_steps-1):
            print(f'[Meta Train] ({step}/{self.total_steps})')
            print(f'  WinSize={window_size} Loss={train_loss.item():.4f} Accuracy={train_acc.item():.4f} Inner LR={float(model.inner_l_rate):.4f} Finetuning LR={float(model.inner_l_rate):.4f}')
        
        # finetuning adaptation to parameters
        for i in range(self.n_finetuning_step):
            train_loss.backward(retain_graph=True)
            parameters = parameters - model.finetuning_lr * parameters.grad
            parameters.retain_grad()
            log_probs = model.predict(encoded, parameters)
            train_loss = self.loss_fn(log_probs, train_y)
            
        # meta validation     
        valid_encoded, *_ = model.encode(valid_X, rt_attn=False)
        valid_log_probs = model.predict(valid_encoded, parameters)
        valid_loss = self.loss_fn(valid_log_probs, valid_y)
        valid_acc = model.cal_accuracy(valid_log_probs, valid_y)

        return valid_loss, valid_acc

    def orthgonality_constraint(self, params):
        # purpose: encourages the dimensions of the latend code as well as the decoder network to be maximally expressive
        # number of class x hidden_size x 2(mean, std)
        p_dot = params.mm(params.transpose(0, 1))
        p_norm = torch.norm(params, dim=1, keepdim=True) + 1e-15
        corr = p_dot / p_norm.mm(p_norm.transpose(0, 1))
        corr.masked_fill_(corr>1.0, 1.0)
        corr.masked_fill_(corr<-1.0, -1.0)
        I = torch.eye(corr.size(0)).to(corr.device)
        orthogonality_penalty = torch.mean((corr - I)**2)
        return orthogonality_penalty

    def step_batch(self, model, batch_data, step, window_size):
        encoded, z, kld_loss, z_penalty = self.inner_loop(model, batch_data)
        valid_loss, valid_acc = self.outer_loop(model, z, encoded, batch_data, step, window_size)
        orthogonality_penalty = self.orthgonality_constraint(list(model.decoder.parameters())[0])

        total_loss = valid_loss + self.beta * kld_loss + self.gamma * z_penalty + self.lambda2 * orthogonality_penalty
        return {
            'Valid Loss': total_loss, 
            'Valid Accuracy': valid_acc, 
            'KLD Loss': kld_loss, 
            'Z Penalty': z_penalty, 
            'Orthogonality Penalty': orthogonality_penalty
        }

    def main(self, model, meta_train, meta_test1, meta_test2, meta_test3):
        lr_list = ['inner_lr', 'finetuning_lr']
        params = [x[1] for x in list(filter(lambda k: k[0] not in lr_list, model.named_parameters()))]
        lr_params = [x[1] for x in list(filter(lambda k: k[0] in lr_list, model.named_parameters()))]
        optim = torch.optim.Adam(params, lr=self.outer_lr, weight_decay=self.lambda1)
        optim_lr = torch.optim.Adam(lr_params, lr=self.outer_lr)

        for step in range(self.total_steps):
            # Meta Train
            optim.zero_grad()
            optim_lr.zero_grad()
            all_tasks = meta_train.generate_tasks()
            records = {'Valid Loss': [], 'Valid Accuracy': [], 'KLD Loss': [], 'Z Penalty': [], 'Orthogonality Penalty': []}
            for window_size, tasks in all_tasks.items():
                batch_data = self.map_to_tensor(tasks, device=self.device)
                output_dict = self.step_batch(model, batch_data, step, window_size)
                for k, v in output_dict.items():
                    if 'Valid' in k:
                        records[k].append(v.item())
                    else:
                        records[k].append(v)

                output_dict['Valid Loss'].backward()

                nn.utils.clip_grad_value_(model.parameters(), self.clip_value)
                nn.utils.clip_grad_norm_(model.parameters(), self.clip_value)
                optim.step()
                optim_lr.step()
            
            # record summary
            for k, v in records.items():
                # logging
                self.writer.add_scalar(k, np.mean(v), step=step)
                
            if (step % self.print_step == 0) or (step == self.total_steps-1):
                print(f'[Meta Valid]({step}/{self.total_steps})')
                print(f'  Loss={records["Valid Loss"]:.4f} Accuracy={records["Valid Accuracy"]:.4f}')
                print(f'  KLD Loss={records["KLD Loss"]:.4f} Z Penalty={records["Z Penalty"]:.4f} Orthogonality Penalty={records["Orthogonality Penalty"]:.4f}')

            # Meta Test
            test1_tasks = meta_test1.generate_tasks()

                
        return 


In [None]:
print('(Meta-Valid) [Step: %d/%d] Total Loss: %4.4f Valid Accuracy: %4.4f'%(step, self.config['total_steps'], val_loss.item(), val_acc.item()))
print('(Meta-Valid) [Step: %d/%d] KL: %4.4f Encoder Penalty: %4.4f Orthogonality Penalty: %4.4f'%(step, self.config['total_steps'], kl_div, encoder_penalty, orthogonality_penalty))


In [None]:

n_inner_step = 5
n_finetuning_step = 5

In [131]:
loss_fn = nn.NLLLoss()

# inner loop
encoded, z, dist, _ = model.get_z(train_X, rt_attn=False)
z_init = z.clone().detach()
kld_loss = model.cal_kl_div(dist, z)

print('inner adaptation to z')
for i in range(n_inner_step):
    z.retain_grad()
    parameters = model.decode(z)
    log_probs = model.predict(encoded, parameters)
    train_loss = loss_fn(log_probs, train_y)
    train_loss.backward(retain_graph=True)
    
    z = z - model.inner_lr * z.grad.data
    
z_penalty = torch.mean((z_init - z)**2)
print(f'z\': {z_init.sum().item():.4f}, z: {z.sum().item():.4f}')
print(f'penalty: {z_penalty.item():.4f}')

inner adaptation to z
z': 33.8720, z: 33.9034
penalty: 0.0001


In [146]:
# outer loop
# inputs: encoded, z

parameters = model.decode(z)
parameters.retain_grad()
log_probs = model.predict(encoded, parameters)
train_loss = loss_fn(log_probs, train_y)

print('fintuning adaptation to parameters')
for i in range(n_finetuning_step):
    train_loss.backward(retain_graph=True)
    parameters = parameters - model.finetuning_lr * parameters.grad
    parameters.retain_grad()
    log_probs = model.predict(encoded, parameters)
    acc = model.cal_accuracy(log_probs, train_y)
    train_loss = loss_fn(log_probs, train_y)
    print(f'loss: {train_loss.item():.4f}, params sum: {parameters.sum():.4f}, acc: {acc:.4f}')
    
valid_encoded, _ = model.encode(valid_X, rt_attn=False)
valid_log_probs = model.predict(valid_encoded, parameters)
valid_loss = loss_fn(valid_log_probs, valid_y)
valid_acc = model.cal_accuracy(log_probs, valid_y)
print(f'Validation Loss: {valid_loss:.4f}, Validation Acc: {valid_acc:.4f}')

fintuning adaptation to parameters
loss: 0.0005, params sum: -28.0169, acc: 0.4667
loss: 0.0005, params sum: -28.0169, acc: 0.4667
loss: 0.0005, params sum: -28.0169, acc: 0.4667
loss: 0.0005, params sum: -28.0169, acc: 0.4667
loss: 0.0005, params sum: -28.0169, acc: 0.4667
Validation Loss: 0.0076, Validation Acc: 0.3333


$$cos\theta = \dfrac{\mathbf{x_1} \mathbf{x_2}}{\Vert\mathbf{x_1}\Vert \cdot \Vert\mathbf{x_1}\Vert}$$


In [173]:
params = list(model.decoder.parameters())[0]  # number of class x hidden_size x 2(mean, std)
p_dot = params.mm(params.transpose(0, 1))
p_norm = torch.norm(params, dim=1, keepdim=True) + 1e-15
corr = p_dot / p_norm.mm(p_norm.transpose(0, 1))
corr.masked_fill_(corr>1.0, 1.0)
corr.masked_fill_(corr<-1.0, -1.0)
print(f'{corr.min()}, {corr.max()}')

-0.6364535689353943, 1.0


In [170]:
corr.masked_fill_(corr>1.0, 1.0)

tensor([[ 1.0000,  0.6774,  0.3638,  ..., -0.1992, -0.0358, -0.2015],
        [ 0.6774,  1.0000,  0.0734,  ..., -0.2679, -0.2289, -0.1591],
        [ 0.3638,  0.0734,  1.0000,  ..., -0.1065, -0.1207, -0.2172],
        ...,
        [-0.1992, -0.2679, -0.1065,  ...,  1.0000, -0.1412, -0.0251],
        [-0.0358, -0.2289, -0.1207,  ..., -0.1412,  1.0000, -0.1652],
        [-0.2015, -0.1591, -0.2172,  ..., -0.0251, -0.1652,  1.0000]],
       grad_fn=<MaskedFillBackward0>)

In [172]:
corr[corr > 1.0]

tensor([], grad_fn=<IndexBackward0>)

In [None]:
torch._masked_fill()