In [1]:
import torch
import numpy as np
import pandas as pd
import pytorch_lightning as pl
from pathlib import Path
from typing import Dict, List, Any
from collections import defaultdict

src_path = Path('.').absolute().parent
data_path = src_path / 'data'

print(torch.__version__)
print(torch.cuda.is_available())

1.11.0+cu113
True


# Task

There are many way to construct the tasks.

Here we define a task is sample from a distribution $\tau_t \sim P(\tau \vert s_i, w_j)$ where $s_i$ is a fixed single stock and $w_j$ is a fixed window size. So, task $\tau_{t} = \{ (\mathbf{X}_{(t-w):t}, y_{t}) \}$, where $t > w$.

Our meta-dataset is $\mathcal{D}^{tr}_{\tau_t} = \{ X_{(t-w):t}, y_{t} \}$,  $\mathcal{D}^{val}_{\tau_t} = \{ X_{(t+k-w):t+k}, y_{t+k} \}$ where $t+k$ is the next rise/fall timestep after $t$.

Example inputs $X$ for $\mathcal{D}^{tr}_{\tau_t}$ with 5 samples, 4 fixed window sizes and 3 stocks.

| stock \ window size | $w_1=5$ | $w_2=10$ | $w_3=15$ | $w_4=20$ |
|---|---|---|---|---|
| $s_1$ | $X_{(t_1-w_1):t_1}^{s_1}$ |  $X_{(t_1-w_2):t_1}^{s_1}$ |  $X_{(t_1-w_3):t_1}^{s_1}$ |  $X_{(t_1-w_4):t_1}^{s_1}$ |
|       | $X_{(t_2-w_1):t_2}^{s_1}$ |  $X_{(t_2-w_1):t_2}^{s_1}$ |  $X_{(t_2-w_1):t_2}^{s_1}$ |  $X_{(t_2-w_1):t_2}^{s_1}$ |
|       | $X_{(t_3-w_1):t_3}^{s_1}$ |  $X_{(t_3-w_1):t_3}^{s_1}$ |  $X_{(t_3-w_1):t_3}^{s_1}$ |  $X_{(t_3-w_1):t_3}^{s_1}$ |
|       | $X_{(t_4-w_1):t_4}^{s_1}$ |  $X_{(t_4-w_1):t_4}^{s_1}$ |  $X_{(t_4-w_1):t_4}^{s_1}$ |  $X_{(t_4-w_1):t_4}^{s_1}$ |
|       | $X_{(t_5-w_1):t_5}^{s_1}$ |  $X_{(t_5-w_1):t_5}^{s_1}$ |  $X_{(t_5-w_1):t_5}^{s_1}$ |  $X_{(t_5-w_1):t_5}^{s_1}$ |
| $s_2$ | $X_{(t_1-w_1):t_1}^{s_2}$ |  $X_{(t_1-w_1):t_1}^{s_2}$ |  $X_{(t_1-w_1):t_1}^{s_2}$ |  $X_{(t_1-w_1):t_1}^{s_2}$ |
|       | $X_{(t_2-w_1):t_2}^{s_2}$ |  $X_{(t_2-w_1):t_2}^{s_2}$ |  $X_{(t_2-w_1):t_2}^{s_2}$ |  $X_{(t_2-w_1):t_2}^{s_2}$ |
|       | $X_{(t_3-w_1):t_3}^{s_2}$ |  $X_{(t_3-w_1):t_3}^{s_2}$ |  $X_{(t_3-w_1):t_3}^{s_2}$ |  $X_{(t_3-w_1):t_3}^{s_2}$ |
|       | $X_{(t_4-w_1):t_4}^{s_2}$ |  $X_{(t_4-w_1):t_4}^{s_2}$ |  $X_{(t_4-w_1):t_4}^{s_2}$ |  $X_{(t_4-w_1):t_4}^{s_2}$ |
|       | $X_{(t_5-w_1):t_5}^{s_2}$ |  $X_{(t_5-w_1):t_5}^{s_2}$ |  $X_{(t_5-w_1):t_5}^{s_2}$ |  $X_{(t_5-w_1):t_5}^{s_2}$ |
| $s_3$ | $X_{(t_1-w_1):t_1}^{s_3}$ |  $X_{(t_1-w_1):t_1}^{s_3}$ |  $X_{(t_1-w_1):t_1}^{s_3}$ |  $X_{(t_1-w_1):t_1}^{s_3}$ |
|       | $X_{(t_2-w_1):t_2}^{s_3}$ |  $X_{(t_2-w_1):t_2}^{s_3}$ |  $X_{(t_2-w_1):t_2}^{s_3}$ |  $X_{(t_2-w_1):t_2}^{s_3}$ |
|       | $X_{(t_3-w_1):t_3}^{s_3}$ |  $X_{(t_3-w_1):t_3}^{s_3}$ |  $X_{(t_3-w_1):t_3}^{s_3}$ |  $X_{(t_3-w_1):t_3}^{s_3}$ |
|       | $X_{(t_4-w_1):t_4}^{s_3}$ |  $X_{(t_4-w_1):t_4}^{s_3}$ |  $X_{(t_4-w_1):t_4}^{s_3}$ |  $X_{(t_4-w_1):t_4}^{s_3}$ |
|       | $X_{(t_5-w_1):t_5}^{s_3}$ |  $X_{(t_5-w_1):t_5}^{s_3}$ |  $X_{(t_5-w_1):t_5}^{s_3}$ |  $X_{(t_5-w_1):t_5}^{s_3}$ |
 
However there is another way to think about tha task. Since in some stock market theories, each time stamp for a stock has a different ditribution. So based on former definition, we also treat each timestamp to a task $\tau \sim P(\tau \vert s_i, w_j, t_k)$ 

this will make a little difference when we calculate the $\mathbf{z}$
1. each stock has different distribution, but each timestamp has same distribution for the same stock. $\mathbf{z}_s \sim \mathcal{N}(\mu_s, \sigma_s)$
2. each timestamp of stock has different distribution: $\mathbf{z}_{(s, t)} \sim \mathcal{N}(\mu_{(s, t)}, \sigma_{(s, t)})$


# Data Loader

Need Python 3.10

In [422]:
class MetaStockDataset(torch.utils.data.Dataset):
    def __init__(
            self, 
            meta_type: str ='train', 
            meta_train_stocks: List[str] | None =None,
            data_dir: Path | str ='', 
            dtype: str ='kdd17', 
            n_train_stock: int =40, 
            n_sample: int =5,
            n_lag: int =1, 
            n_stock: int =5,
            show_y_index: bool=False
        ):
        """
        dataset ref: https://arxiv.org/abs/1810.09936
        In this meta learning setting, we have 3 meta-test and 1 meta-train
        vertical = stocks, horizontal = time
                train      |    test
           A               |
           B   meta-train  |   meta-test
           C               |      (1)
           ----------------|-------------
           D   meta-test   |   meta-test
           E     (2)       |      (3)

        meta-test (1) same stock, different time
        meta-test (2) different stock, same time
        meta-test (3) different stock, different time
        use `valid_date` to split the train / test set
        """
        super().__init__()
        # for debugging purpose
        self.labels_dict = {
            'fall': 0, 'rise': 1, 'nothing': 2 
        }
        self.show_y_index = show_y_index
        # data config
        self.data_dir = Path(data_dir)
        ds_info = {
            # train: (Jan-01-2007 to Jan-01-2015)
            # val: (Jan-01-2015 to Jan-01-2016)
            # test: (Jan-01-2016 to Jan-01-2017)
            'kdd17': {
                'path': self.data_dir / 'kdd17/price_long_50',
                'date': self.data_dir / 'kdd17/trading_dates.csv',
                'train_date': '2015-01-01', 
                'val_date': '2016-01-01', 
                'test_date': '2017-01-01',
            },
            # train: (Jan-01-2014 to Aug-01-2015)
            # vali: (Aug-01-2015 to Oct-01-2015)
            # test: (Oct-01-2015 to Jan-01-2016)
            'acl18': {
                'path': self.data_dir / 'stocknet-dataset/price/raw',
                'date': self.data_dir / 'stocknet-dataset/price/trading_dates.csv',
                'train_date': '2015-08-01', 
                'val_date': '2015-10-01', 
                'test_date': '2016-01-01',
            }
        }
        
        ds_config = ds_info[dtype]

        self.window_sizes = [5, 10, 15, 20]
        self.n_sample = n_sample
        self.n_lag = n_lag
        self.n_stock = n_stock

        # get data
        self.data = {}
        for i, p in enumerate((data_path / ds_config['path']).glob('*')):
            if meta_type == 'train' and (i == n_train_stock):
                # stop when it reach `n_train_stock`
                break
            
            stock_symbol = p.name.rstrip('.csv')
            df_single = self.load_single_stock(p)
            if meta_type == 'train':
                df_single = df_single.loc[df_single['date'] <= ds_config['val_date']]
            else:
                if meta_type == 'test1':
                    if stock_symbol in meta_train_stocks:
                        df_single = df_single.loc[df_single['date'] > ds_config['val_date']]
                    else:
                        continue
                elif meta_type == 'test2':
                    if stock_symbol not in meta_train_stocks:
                        df_single = df_single.loc[df_single['date'] <= ds_config['val_date']]
                    else:
                        continue
                elif meta_type == 'test3':
                    if stock_symbol not in meta_train_stocks:
                        df_single = df_single.loc[df_single['date'] > ds_config['val_date']]
                    else:
                        continue
                else:
                    raise KeyError('Error argument `meta_type`, should be in (train, test1, test2, test3)')

            self.data[stock_symbol] = df_single.reset_index(drop=True)

    def load_single_stock(self, p: Path | str):
        def longterm_trend(x: pd.Series, k:int):
            return (x.rolling(k).sum().div(k*x) - 1) * 100

        df = pd.read_csv(p)
        df['Date'] = pd.to_datetime(df['Date'])
        df = df.sort_values('Date').reset_index(drop=True)
        if 'Unnamed' in df.columns:
            df.drop(columns=df.columns[7], inplace=True)
        if 'Original_Open' in df.columns:
            df.rename(columns={'Original_Open': 'Open', 'Open': 'Adj Open'}, inplace=True)

        # Open, High, Low
        z1 = (df.loc[:, ['Open', 'High', 'Low']].div(df['Close'], axis=0) - 1).rename(
            columns={'Open': 'open', 'High': 'high', 'Low': 'low'}) * 100
        # Close
        z2 = df[['Close']].pct_change().rename(columns={'Close': 'close'}) * 100
        # Adj Close
        z3 = df[['Adj Close']].pct_change().rename(columns={'Adj Close': 'adj_close'}) * 100

        z4 = []
        for k in [5, 10, 15, 20, 25, 30]:
            z4.append(df[['Adj Close']].apply(longterm_trend, k=k).rename(columns={'Adj Close': f'zd{k}'}))

        df_pct = pd.concat([df['Date'], z1, z2, z3] + z4, axis=1).rename(columns={'Date': 'date'})
        cols_max = df_pct.columns[df_pct.isnull().sum() == df_pct.isnull().sum().max()]
        df_pct = df_pct.loc[~df_pct[cols_max].isnull().values, :]

        # from https://arxiv.org/abs/1810.09936
        # Examples with movement percent ≥ 0.55% and ≤ −0.5% are 
        # identified as positive and negative examples, respectively
        df_pct['label'] = self.labels_dict['nothing']
        df_pct.loc[(df_pct['close'] >= 0.55), 'label'] = self.labels_dict['rise']
        df_pct.loc[(df_pct['close'] <= -0.5), 'label'] = self.labels_dict['fall']
        return df_pct

    def symbols(self):
        return list(self.data.keys())

    def generate_tasks(self):
        all_tasks = defaultdict()
        for window_size in self.window_sizes:
            tasks = self.generate_tasks_per_window_size(window_size)
            all_tasks[window_size] = tasks
        return all_tasks

    def generate_tasks_per_window_size(self, window_size):
        # tasks: {X: (n_stock, n_sample, window_size, n_in), y: (n_stock, n_sample)}
        tasks = defaultdict(list)
        for i in range(self.n_stock):
            symbol = np.random.choice(self.symbols())
            # data: {X: (n_sample, n_in), y: (n_sample,)}
            data = self.generate_task_per_window_size_and_single_stock(symbol, window_size)
            for k, v in data.items():
                tasks[k].append(v)

        for k, v in tasks.items():
            tasks[k] = np.concatenate(v, axis=0)
        return tasks

    def generate_task_per_window_size_and_single_stock(self, symnbol, window_size):
        df_stock = self.data[symnbol]
        labels_indices = df_stock.index[df_stock['label'].isin(
            [self.labels_dict['fall'], self.labels_dict['rise']])].to_numpy()
        labels_candidates = labels_indices[window_size:-self.n_lag]
        y_s = np.array(sorted(np.random.choice(labels_candidates, size=(self.n_sample,), replace=False)))
        support, support_labels = self.generate_data(df_stock, y_s, window_size)

        y_q = labels_indices[np.arange(len(labels_indices))[np.isin(labels_indices, y_s)] + self.n_lag]
        query, query_labels = self.generate_data(df_stock, y_q, window_size)
        
        return {
            'support': support, 'support_labels': support_labels,
            'query': query, 'query_labels': query_labels
        }

    def generate_data(self, df, y_index, window_size):
        # generate mini task
        inputs = []
        labels = []
        for i, j in zip(y_index-window_size-1, y_index):
            inputs.append(df.iloc[i:j-1, 1:-1].to_numpy())
            if self.show_y_index:
                labels.append(j)
            else:
                labels.append(df.iloc[j, -1])

        # inputs: (n_sample, window_size, n_in), labels: (n_sample,)
        return np.stack(inputs), np.array(labels)

    def map_to_tensor(self, tasks, device: None | str=None):
        if device is None:
            device = torch.device('cpu')
        else:
            device = torch.device(device)
        tensor_tasks = {}
        for k, v in tasks.items():
            if 'labels' in k:
                tensor_tasks[k] = torch.LongTensor(v).to(device)
            else:
                tensor_tasks[k] = torch.FloatTensor(v).to(device)
        return tensor_tasks

In [394]:
comm_kwargs = {
    'data_dir': data_path,
    'dtype': 'kdd17',
    'n_train_stock': 40,  # number of training stocks to construct the universe
    'n_sample': 5,  # number of samples per window size and per single stock
    'n_lag': 1,
    'n_stock': 3,  # number of iteration(stock) to sample per window size, total will be `n_stocks` * `n_sample`
    'show_y_index': True  # for debug purpose
}

meta_train = MetaStockDataset(meta_type='train', meta_train_stocks=None, **comm_kwargs)
meta_test1 = MetaStockDataset(meta_type='test1', meta_train_stocks=meta_train.symbols(), **comm_kwargs)
meta_test2 = MetaStockDataset(meta_type='test2', meta_train_stocks=meta_train.symbols(), **comm_kwargs)
meta_test3 = MetaStockDataset(meta_type='test3', meta_train_stocks=meta_train.symbols(), **comm_kwargs)

len(meta_train.symbols()), len(meta_test1.symbols()), len(meta_test2.symbols()), len(meta_test3.symbols())

(40, 40, 10, 10)

In [395]:
tasks = meta_train.generate_tasks()

In [396]:
for window_size, t in tasks.items():
    print(f'Window size = {window_size}')
    for k, v in t.items():
        print(f'  {k}: {v.shape}')
        if 'labels' in k:
            print('  y_index: ', v[:10])
            print('  --------')

Window size = 5
  support: (15, 5, 11)
  support_labels: (15,)
  y_index:  [  83  451  610  634  970  442  460  987 1494 1843]
  --------
  query: (15, 5, 11)
  query_labels: (15,)
  y_index:  [  84  454  611  638  971  444  464  988 1499 1848]
  --------
Window size = 10
  support: (15, 10, 11)
  support_labels: (15,)
  y_index:  [ 292  358  556  558 2226  432 1303 1306 1436 1565]
  --------
  query: (15, 10, 11)
  query_labels: (15,)
  y_index:  [ 294  359  557  568 2230  433 1305 1309 1440 1567]
  --------
Window size = 15
  support: (15, 15, 11)
  support_labels: (15,)
  y_index:  [ 450  470  712 1724 1760  343  410 1792 1941 2031]
  --------
  query: (15, 15, 11)
  query_labels: (15,)
  y_index:  [ 451  471  714 1736 1763  346  412 1793 1943 2032]
  --------
Window size = 20
  support: (15, 20, 11)
  support_labels: (15,)
  y_index:  [ 119 1845 1947 2048 2075  789  917 1388 1419 1487]
  --------
  query: (15, 20, 11)
  query_labels: (15,)
  y_index:  [ 121 1846 1952 2052 2076  796

---

# Model

In [332]:
import torch
import torch.nn as nn
import pytorch_lightning as pl

class LSTM(nn.Module):
    def __init__(self, input_size: int, hidden_size: int, num_layers: int):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=False)
        self.lnorm = nn.LayerNorm(hidden_size)
    
    def forward(self, x: torch.tensor):
        # x: (B, T, I)
        o, (h, _) = self.lstm(x) # o: (B, T, H) / h: (1, B, H)
        normed_context = self.lnorm(h)
        return normed_context

class LSTMAttention(nn.Module):
    def __init__(self, input_size: int, hidden_size: int, num_layers: int):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=False)
        self.lnorm = nn.LayerNorm(hidden_size)

    def forward(self, x: torch.tensor, rt_attn: bool=False):
        # x: (B, T, I)
        o, (h, _) = self.lstm(x) # o: (B, T, H) / h: (1, B, H)
        score = torch.bmm(o, h.permute(1, 2, 0)) # (B, T, H) x (B, H, 1)
        attn = torch.softmax(score, 1).squeeze(-1)  # (B, T)
        context = torch.bmm(attn.unsqueeze(1), o).squeeze(1)  # (B, 1, T) x (B, T, H)
        normed_context = self.lnorm(context)  # (B, H)
        if rt_attn:
            return normed_context, attn
        else:
            return normed_context, None

class MappingNet(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.rn = nn.Sequential(
            nn.Linear(hidden_size, 2*hidden_size, bias=False),
            nn.ReLU(),
            nn.Linear(2*hidden_size, 2*hidden_size, bias=False),
        )

    def forward(self, x: torch.tensor):
        # x: (B, H)
        outputs = self.rn(x)
        return outputs

class Model(nn.Module):
    def __init__(
            self, 
            feature_size: int, 
            hidden_size: int, 
            output_size: int, 
            num_layers: int, 
            has_same_dist_on_time: bool,
            drop_rate: float, 
            n_sample: int
        ):
        super().__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.parameter_size = hidden_size*output_size
        self.n_sample = n_sample
        self.has_same_dist_on_time = has_same_dist_on_time
        
        self.dropout = nn.Dropout(drop_rate)
        self.lstm = LSTMAttention(feature_size, hidden_size, num_layers)  # encode
        self.mapping_net = MappingNet(hidden_size)  # generate z(latent)
        self.decoder = nn.Linear(hidden_size, 2*self.parameter_size, bias=False)
        self.prob_layer = nn.LogSoftmax(dim=1) if output_size >=2 else nn.LogSigmoid()

    def encode(self, inputs, rt_attn: bool=False):
        # inputs: (B, T, I)
        inputs = self.dropout(inputs)
        encoded, attn = self.lstm(inputs, rt_attn)  # B, H
        if self.has_same_dist_on_time:
            # average the effect (n_stock, n_sample, H) -> (n_stock, H)
            hs = encoded.view(-1, self.n_sample, self.hidden_size).mean(1)
        else:
            hs = encoded
        hs = self.mapping_net(hs)

        z, dist = self.sample(hs, size=self.hidden_size)
        return encoded, z, dist, attn

    def sample(self, distribution_params, size):
        mean, log_std = distribution_params[:, :size], distribution_params[:, size:]
        std = torch.exp(log_std)
        dist = torch.distributions.Normal(torch.zeros_like(mean), torch.ones_like(std))
        z = dist.rsample()
        return mean + std*z, dist

    def decode(self, z):
        param_hs = self.decoder(z)
        parameters, _ = self.sample(param_hs, size=self.parameter_size)
        return parameters

    def predict(self, encoded, parameters):
        theta = parameters.view(-1, self.hidden_size, self.output_size)
        scores = encoded.unsqueeze(1).bmm(theta).squeeze()
        probs = self.prob_layer(scores)
        return probs
        
    # def forward(self, x, rt_attn: bool=False):
    #     encoded, z, dist, attn = self.encode(x, rt_attn=rt_attn)
    #     parameters = self.decode(z)
    #     theta = parameters.view(-1, self.hidden_size, self.output_size)
    #     scores = encoded.unsqueeze(1).bmm(theta).squeeze()
    #     probs = self.prob_layer(scores)
    #     return probs, z, dist, attn

    def cal_kl_div(self, dist, z):
        normal = torch.distributions.Normal(torch.zeros_like(z), torch.ones_like(z))
        return torch.mean(dist.log_prob(z) - normal.log_prob(z))


* LEO-Deepmind: https://github.com/deepmind/leo/blob/de9a0c2a77dd7a42c1986b1eef18d184a86e294a/model.py#L256
* LEO-pytorch: https://github.com/timchen0618/pytorch-leo/blob/master/model.py

In [250]:
n_sample = 5
n_stock = 3
window_size = 5
feature_size = 11
output_size = 1
hidden_size, num_layers, drop_rate = 20, 1, 0.2
parameter_size = output_size*hidden_size

has_same_dist_on_time = False  # case for each stock has same distribution on time

lstm = LSTMAttention(feature_size, hidden_size, num_layers)  # encode
mapping_net = MappingNet(hidden_size)  # generate z
decoder = nn.Linear(hidden_size, 2*parameter_size)

sample data

In [333]:
x = torch.randn(n_stock*n_sample, window_size, feature_size)
y = torch.ones(n_stock*n_sample)

print('--- Encoder ---')
# lstm
encoded, attn = lstm(x, rt_attn=False)
# mapping
if has_same_dist_on_time:
    # average the effect (n_stock, n_sample, H) -> (n_stock, H)
    hs = encoded.view(-1, n_sample, hidden_size).mean(1)
else:
    hs = encoded
hs = mapping_net(hs)
# encoded, hs, attn = encoder(x, rt_attn=False)
print(f'encoder: encoded - {encoded.size()} | mapping_net - {hs.size()}')

# sample
e_mean, e_std = hs[:, :hidden_size], hs[:, hidden_size:]
e_std = torch.exp(e_std)
print(f'mean: {e_mean.size()} | std: {e_std.size()}')
e_dist = torch.distributions.Normal(torch.zeros_like(e_mean), torch.ones_like(e_std))
z = e_dist.rsample()
print(f'z: {z.size()}')
latents = e_mean + e_std*z
print(f'latent: {latents.size()}')

# decoder
print('--- Decoder ---')
param_hs = decoder(latents)
# sample
print(f'param_hs: {param_hs.size()}')
d_mean, d_std = param_hs[:, :parameter_size], param_hs[:, parameter_size:]
d_std = torch.exp(d_std)
print(f'mean: {d_mean.size()} | std: {d_std.size()}')
d_dist = torch.distributions.Normal(torch.zeros_like(d_mean), torch.ones_like(d_std))
z = d_dist.rsample()
print(f'z: {z.size()}')
parameters = d_mean + d_std*z
if has_same_dist_on_time:
    parameters = torch.repeat_interleave(parameters.unsqueeze(1), n_sample, dim=1).view(-1, feature_size)
print(f'parameters: {parameters.size()}')

# score
theta = parameters.view(-1, hidden_size, output_size)
scores = encoded.unsqueeze(1).bmm(theta).squeeze(1)
print(f'scores: {scores.size()}')

--- Encoder ---
encoder: encoded - torch.Size([15, 20]) | mapping_net - torch.Size([15, 40])
mean: torch.Size([15, 20]) | std: torch.Size([15, 20])
z: torch.Size([15, 20])
latent: torch.Size([15, 20])
--- Decoder ---
param_hs: torch.Size([15, 40])
mean: torch.Size([15, 20]) | std: torch.Size([15, 20])
z: torch.Size([15, 20])
parameters: torch.Size([15, 20])
scores: torch.Size([15, 1])


In [345]:
feature_size = 11
hidden_size = 20 
output_size = 1  # 1 == 2 setting
num_layers = 1
has_same_dist_on_time=False
drop_rate = 0.2
n_sample = 5

x = torch.randn(n_stock*n_sample, window_size, feature_size)
y = torch.ones(n_stock*n_sample, dtype=torch.long)

loss_fn = nn.NLLLoss()
model = Model(feature_size, hidden_size, output_size, num_layers, has_same_dist_on_time, drop_rate, n_sample)

# forward?
encoded, z, dist, attn = model.encode(x, rt_attn=rt_attn)
parameters = model.decode(z)
probs = model.predict(encoded, parameters)
print(scores.shape)
kl_loss = model.cal_kl_div(dist, z)
loss = loss_fn(probs, y)
print(f'Loss: {loss:.4f}, KLD: {kl_loss:.4f}')

---

# Trainer

In [411]:
import yaml
import inspect
with open('settings.yml') as file:
    kwargs = yaml.load(file, Loader=yaml.FullLoader)

comm_kwargs = {
    k: kwargs.get(k) for k in inspect.signature(MetaStockDataset.__init__).parameters.keys() 
    if kwargs.get(k) is not None
}

In [423]:
meta_train = MetaStockDataset(meta_type='train', meta_train_stocks=None, **comm_kwargs)
meta_test1 = MetaStockDataset(meta_type='test1', meta_train_stocks=meta_train.symbols(), **comm_kwargs)
meta_test2 = MetaStockDataset(meta_type='test2', meta_train_stocks=meta_train.symbols(), **comm_kwargs)
meta_test3 = MetaStockDataset(meta_type='test3', meta_train_stocks=meta_train.symbols(), **comm_kwargs)

len(meta_train.symbols()), len(meta_test1.symbols()), len(meta_test2.symbols()), len(meta_test3.symbols())

(40, 40, 10, 10)

In [424]:
tasks = meta_train.generate_tasks()
for window_size, t in tasks.items():
    t_tensor = meta_train.map_to_tensor(t)
    break

In [427]:
n_inner_step = 5

In [None]:
latents, kl_div = self.model.encode(inputs)
latents_init = latents

for i in range(self.config['inner_update_step']):
    latents.retain_grad()
    classifier_weights = self.model.decode(latents)
    train_loss, _ = self.model.cal_target_loss(inputs, classifier_weights, target)
    train_loss.backward(retain_graph=True)

    latents = latents - self.model.inner_l_rate * latents.grad.data

encoder_penalty = torch.mean((latents_init - latents) ** 2)