# PyTorch LSTM with minimum features

系列を扱うニューラルネットワークをPyTorchで実装しました。以下の4種の特徴量を入力としています。

- CATEGORICAL_COLS: 系列として扱うカテゴリ変数
- NUMERICAL_COLS: 系列として扱う数値変数
- META_C_COLS: 付随するカテゴリ変数
- META_N_COLS: 付随する数値変数

```python
CATEGORICAL_COLS = [
    'value_1',
    'display_action_id',
    'name_1',
    'kind_1'
]
NUMERICAL_COLS = [
    'unit_price',
    'n_items',
    'spend_time',
    'hour'
]
META_C_COLS = [
    'gender',
    'age',
    'dow',
]
META_N_COLS = [
    'n_seq'    # 予測時点でのsessionの長さ
]
```

1 fold 分を 1 epoch 学習させたところ、検証用データセットの auc が 0.54 でした。

```bash
Training fold0...
1/1 * Epoch (train): 100% 1586/1586 [50:45<00:00,  1.92s/it]
1/1 * Epoch (valid): 100% 397/397 [06:18<00:00,  1.05it/s]
[2021-01-31 23:12:17,948] 
1/1 * Epoch 1 (train): auc=0.5298 | loss=0.2419
1/1 * Epoch 1 (valid): auc=0.5418 | loss=0.2296
Top best models:
logdir_nn000/fold0/checkpoints/train.1.pth	0.2296
```

In [None]:
!pip install iterative-stratification

In [None]:
from keras.preprocessing import sequence
import numpy as np
import pandas as pd
import pathlib
import torch


# https://www.guruguru.science/competitions/14/discussions/9768ddc3-b6f4-440d-830d-cd5330fe1611/
class RetailDataset:
    """Load the raw competition tables and derive train/test logs and labels.

    The constructor reads five CSV files from ``file_path`` (cart log, product
    master, session meta, test sessions, user master).  The helper methods
    split the cart log into the part visible at prediction time (input) and
    the part that happens afterwards (output), and build the multi-label
    target from the output part.

    Args:
        file_path: Directory containing the CSV files.
        thres_sec: Value used for ``time_elapsed_sec`` on sessions where it is
            missing in ``meta.csv``.
    """

    def __init__(self, file_path: pathlib.Path, thres_sec: int) -> None:
        self.file_path = file_path
        self.thres_sec = thres_sec
        self.cartlog: pd.DataFrame = pd.read_csv(file_path / "carlog.csv",
                                                 dtype={'value_1': str},
                                                 parse_dates=['date'])
        self.product_master: pd.DataFrame = pd.read_csv(
            file_path / "product_master.csv"
        )
        self.meta: pd.DataFrame = pd.read_csv(file_path / "meta.csv")
        self.meta['time_elapsed_sec'] = self.meta['time_elapsed'] * 60
        self.test: pd.DataFrame = pd.read_csv(file_path / "test.csv")
        self.user_master: pd.DataFrame = pd.read_csv(file_path / "user_master.csv")
        self.target_category_ids = [
            38,  # Ice cream / novelty
            110,  # Snacks & candy / gum
            113,  # Snacks & candy / cereal
            114,  # Snacks & candy / snacks
            134,  # Chocolate & biscuits / chocolate
            171,  # Beer-type / RTD
            172,  # Beer-type / non-alcoholic
            173,  # Beer-type / beer-type
            376,  # Japanese sweets / rice crackers
            435,  # Large PET bottle / unsweetened tea (large PET)
            467,  # Small PET bottle / coffee (small PET)
            537,  # Water & sparkling water / large PET (sparkling water)
            539,  # Water & sparkling water / small PET (sparkling water)
            629,  # Canned drinks / coffee (can)
            768,  # Noodles / cup noodles
        ]
        # `time_elapsed_sec` is missing for the training sessions, so it is
        # filled with `thres_sec` in order to be able to build labels.
        # Note kept from the original author (presumably the value counts of
        # `time_elapsed` in the data -- TODO confirm):
        #   10.0    16833
        #   0.0     14277
        #   5.0     14072
        #   3.0     11304
        self.meta.loc[
            self.meta["time_elapsed_sec"].isnull(), "time_elapsed_sec"
        ] = thres_sec

    def get_test_sessions(self) -> set:
        """Return the set of session ids listed in ``test.csv``."""
        return set(self.test["session_id"].unique())

    def get_test_input_log(self) -> pd.DataFrame:
        """Return the cart-log rows that belong to test sessions.

        Some test sessions have no log rows at all, so the result may cover
        fewer sessions than ``get_test_sessions()``.
        """
        test_sessions = self.get_test_sessions()
        return self.cartlog[self.cartlog["session_id"].isin(test_sessions)]

    def get_log_first_half(self) -> pd.DataFrame:
        """Return cart-log rows of sessions dated before 2020-08-01 in ``meta``."""
        first_half_sessions = set(
            self.meta.query("date < '2020-08-01'")["session_id"].unique()
        )
        return self.cartlog[self.cartlog["session_id"].isin(first_half_sessions)]

    def get_train_output_log(self) -> pd.DataFrame:
        """Return first-half log rows that occur after the session threshold.

        A row is kept when its ``spend_time`` is strictly greater than the
        session's ``time_elapsed_sec``.
        """
        return pd.merge(
            self.get_log_first_half(),
            self.meta[["session_id", "time_elapsed_sec"]],
            on=["session_id"],
            how="inner",
        ).query("spend_time > time_elapsed_sec")

    def get_train_sessions(self) -> set:
        """Return first-half sessions that have at least one row after the threshold."""
        return set(self.get_train_output_log()["session_id"].unique())

    def get_train_input_log(self) -> pd.DataFrame:
        """Return the rows of training sessions observed up to the threshold.

        Only sessions from ``get_train_sessions()`` are used, and only rows
        with ``spend_time <= time_elapsed_sec``.  The helper column
        ``time_elapsed_sec`` is dropped from the result.
        """
        train_sessions = self.get_train_sessions()
        # Compute the first-half log once instead of twice.
        first_half_log = self.get_log_first_half()
        merged = pd.merge(
            first_half_log[first_half_log["session_id"].isin(train_sessions)],
            self.meta[["session_id", "time_elapsed_sec"]],
            on=["session_id"],
            how="inner",
        )
        return merged.query("spend_time <= time_elapsed_sec").drop(
            'time_elapsed_sec', axis=1
        )

    def get_payment_sessions(self) -> set:
        """Return the set of sessions that contain a row with ``is_payment == 1``."""
        return set(self.cartlog.query("is_payment == 1")["session_id"].unique())

    def agg_payment(self, cartlog) -> pd.DataFrame:
        """Sum ``n_items`` per (session, product) over product rows of ``cartlog``.

        Returns a frame with columns ``session_id``, ``JAN`` (``value_1`` cast
        to int) and ``n_items``.
        """
        # Purchase information only exists on rows whose kind is "product".
        target_index = (cartlog["kind_1"] == "商品")

        # Sum the purchased quantity (n_items) per JAN code (value_1).
        agg = (
            cartlog.loc[target_index]
            .groupby(["session_id", "value_1"])["n_items"]
            .sum()
            .reset_index()
        )
        agg = agg.rename(columns={"value_1": "JAN"})
        agg = agg.astype({"JAN": int})
        return agg

    def get_train_target(self) -> pd.DataFrame:
        """Build the 0/1 target matrix for the training sessions.

        Returns:
            A frame indexed by ``session_id`` (sorted, so the row order is
            reproducible) with one column per entry of
            ``target_category_ids``; a cell is 1 when the session's output
            log contains at least one item of that category.
        """
        # Start from an all-zero frame.  The session set is sorted because a
        # set has no defined order and is not a reliable DataFrame index.
        train_sessions = sorted(self.get_train_sessions())
        train_target = pd.DataFrame(
            np.zeros((len(train_sessions), len(self.target_category_ids))),
            index=train_sessions,
            columns=self.target_category_ids,
        ).astype(int)
        train_target.index.name = "session_id"

        # Aggregate purchases that happen after the threshold.
        train_output_log = self.get_train_output_log()
        train_items_per_session_jan = self.agg_payment(train_output_log)
        merged = pd.merge(
            train_items_per_session_jan,
            self.product_master[["JAN", "category_id"]],
            on="JAN",
            how="inner",
        )
        train_items_per_session_target_jan = merged[
            merged["category_id"].isin(self.target_category_ids)
        ]
        if train_items_per_session_target_jan.empty:
            # No target category was purchased at all: everything stays 0.
            return train_target[self.target_category_ids]

        train_target_pos = (
            train_items_per_session_target_jan.groupby(["session_id", "category_id"])[
                "n_items"
            ]
            .sum()
            .unstack()
            # Align to the full target column list: categories that never
            # occur would otherwise be missing and the positional assignment
            # below would hit the wrong columns (or broadcast a single
            # column into all of them).
            .reindex(columns=self.target_category_ids)
            .fillna(0)
        )
        train_target_pos = (train_target_pos > 0).astype(int)

        train_target.loc[
            train_target_pos.index, self.target_category_ids
        ] = train_target_pos.values
        return train_target[self.target_category_ids]


class RetailNNDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding one row of ``input_log_df`` per item.

    Each item is a tuple laid out as
    ``(*categorical, *numerical, meta_c, meta_n[, targets])`` where the
    per-feature entries are whatever is stored in the corresponding cell of
    ``input_log_df``, ``meta_c`` / ``meta_n`` are arrays of the meta columns
    for that row, and ``targets`` (only when ``is_train``) is the row of
    ``target_df`` whose index label equals the row's ``session_id``.

    Args:
        input_log_df: One row per sample; must contain ``session_id`` when
            ``is_train`` is true, plus every listed feature column.
        target_df: Label frame indexed by session id (needed for training).
        categorical_features: Names of the categorical sequence columns.
        numerical_features: Names of the numerical sequence columns.
        meta_c_features: Names of the per-row categorical columns.
        meta_n_features: Names of the per-row numerical columns.
        is_train: Whether to append the target row to every item.
    """

    def __init__(self,
                 input_log_df: pd.DataFrame,
                 target_df=None,
                 categorical_features=None,
                 numerical_features=None,
                 meta_c_features=None,
                 meta_n_features=None,
                 is_train: bool = True) -> None:
        # The original wrote `super().__init__` without parentheses, which
        # never ran the parent initialiser.
        super().__init__()
        self.is_train = is_train
        self.input_log_df = input_log_df
        self.target_df = target_df
        # `None` defaults avoid sharing one mutable list between instances.
        self.categorical_features = [] if categorical_features is None else categorical_features
        self.numerical_features = [] if numerical_features is None else numerical_features
        self.meta_c_features = [] if meta_c_features is None else meta_c_features
        self.meta_n_features = [] if meta_n_features is None else meta_n_features

    def __len__(self) -> int:
        """Return the number of rows in ``input_log_df``."""
        return len(self.input_log_df)

    def __getitem__(self, index: int):
        """Return the feature tuple (and targets when training) for row ``index``.

        ``index`` is positional.  Raises ``KeyError`` when training and the
        row's session id is not present in ``target_df``.
        """
        items = [
            self.input_log_df[c].values[index]
            for c in self.categorical_features + self.numerical_features
        ]
        items.append(self.input_log_df[self.meta_c_features].iloc[index].values)
        items.append(self.input_log_df[self.meta_n_features].iloc[index].values)
        if self.is_train:
            session_id = self.input_log_df['session_id'].values[index]
            # Label lookup through the index instead of building a boolean
            # mask over the whole target frame for every single item.
            target_row = self.target_df.loc[session_id]
            if isinstance(target_row, pd.DataFrame):
                # Duplicate labels: keep the first match, as before.
                target_row = target_row.iloc[0]
            items.append(target_row.values)
        return tuple(items)


class MyCollator(object):
    """Collate ``RetailNNDataset`` items into padded batch tensors.

    Every item is expected to be laid out as
    ``(*categorical_seqs, *numerical_seqs, meta_c, meta_n[, targets])``.
    Sequences are left-padded with zeros to the longest first-feature
    sequence in the batch (longer ones keep their last ``maxlen`` steps),
    which is the same layout the previous keras ``pad_sequences`` call
    produced with its default ``padding='pre'`` / ``truncating='pre'``.

    Padding is done with NumPy directly in the target dtype.  The earlier
    ``pad_sequences(data, maxlen=maxlen)`` call used its default
    ``dtype='int32'``, so the numerical sequences were truncated to integers
    before being converted to float tensors.

    Args:
        is_train: When true, the last element of every item is treated as
            the target row and appended to the returned tuple.
        n_categorical: Number of leading categorical sequence features.
        n_numerical: Number of numerical sequence features that follow.
    """

    def __init__(self, is_train=True, n_categorical=4, n_numerical=4):
        self.is_train = is_train
        self.n_categorical = n_categorical
        self.n_numerical = n_numerical

    @staticmethod
    def _pad_sequences(data, maxlen: int, dtype=torch.long) -> torch.Tensor:
        """Left-pad (and left-truncate) ``data`` to ``maxlen`` and return a tensor."""
        np_dtype = np.float32 if dtype.is_floating_point else np.int64
        padded = np.zeros((len(data), maxlen), dtype=np_dtype)
        for row, seq in zip(padded, data):
            seq = np.asarray(seq, dtype=np_dtype)
            if len(seq) > maxlen:
                seq = seq[len(seq) - maxlen:]
            if len(seq):
                row[maxlen - len(seq):] = seq
        return torch.from_numpy(padded).to(dtype)

    def __call__(self, batch):
        """Return a tuple of tensors in the same order as the item layout.

        Categorical sequences become ``torch.long`` tensors, numerical
        sequences ``torch.float``; ``meta_c`` is ``torch.long`` and
        ``meta_n`` / ``targets`` are ``torch.float``.
        """
        columns = list(zip(*batch))
        n_seq_features = self.n_categorical + self.n_numerical

        # The first sequence feature defines the padded length of the batch.
        maxlen = max(len(s) for s in columns[0])

        tensors = []
        for position in range(n_seq_features):
            dtype = torch.long if position < self.n_categorical else torch.float
            tensors.append(self._pad_sequences(columns[position], maxlen, dtype=dtype))

        # Stack the per-row arrays first; building a tensor from a list of
        # arrays goes through a slow element-wise path.
        tensors.append(
            torch.from_numpy(np.asarray(columns[n_seq_features], dtype=np.int64))
        )
        tensors.append(
            torch.from_numpy(np.asarray(columns[n_seq_features + 1], dtype=np.float32))
        )
        if self.is_train:
            tensors.append(
                torch.from_numpy(np.asarray(columns[-1], dtype=np.float32))
            )
        return tuple(tensors)


In [None]:
import torch
from torch import nn


class RetailNN(nn.Module):
    """LSTM over per-step session features, pooled and combined with meta features.

    Per time step, four categorical embeddings are concatenated and projected
    to ``hidden_size // 2``; the four numerical features are projected to
    ``hidden_size // 2`` as well.  Both halves are concatenated, fed to a
    unidirectional LSTM, mean- and max-pooled over the time axis, joined with
    the meta embeddings (gender, age, dow) and the numerical meta features,
    and passed through a two-layer feed-forward head that outputs
    ``n_targets`` raw scores.

    Args:
        categorical_features: Names of the categorical sequence features
            (only the count is used here).
        numerical_features: Names of the numerical sequence features
            (only the count is used here).
        meta_c_features: Names of the categorical meta features (stored only).
        meta_n_features: Names of the numerical meta features; the count
            sizes the feed-forward input.
        n_targets: Number of output scores.
        n_value_1, n_display_action_id, n_name_1, n_kind_1: Vocabulary sizes
            of the four categorical embeddings (index 0 is the padding id).
        emb_dim: Width of each categorical sequence embedding.
        rnn_dim: Stored but not used by any layer.
        hidden_size: LSTM hidden width and feed-forward hidden width.
        num_layers: Number of stacked LSTM layers.
        dropout: Probability of the ``drop`` module (not applied in
            ``forward``).
        rnn_dropout: Dropout between LSTM layers.
    """

    def __init__(
        self,
        categorical_features,
        numerical_features,
        meta_c_features,
        meta_n_features,
        n_targets,
        n_value_1,
        n_display_action_id,
        n_name_1,
        n_kind_1,
        emb_dim=128,
        rnn_dim=128,
        hidden_size=128,
        num_layers=2,
        dropout=0.3,
        rnn_dropout=0.3,
    ):
        super().__init__()

        self.n_value_1 = n_value_1
        self.n_targets = n_targets
        self.emb_dim = emb_dim
        self.rnn_dim = rnn_dim
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dropout = dropout
        self.rnn_dropout = rnn_dropout

        self.categorical_features = categorical_features
        self.numerical_features = numerical_features
        self.meta_c_features = meta_c_features
        self.meta_n_features = meta_n_features

        self.drop = nn.Dropout(dropout)
        self.value_1_embedding = nn.Embedding(n_value_1, emb_dim, padding_idx=0)
        self.display_action_id_embedding = nn.Embedding(n_display_action_id, emb_dim, padding_idx=0)
        self.name_1_embedding = nn.Embedding(n_name_1, emb_dim, padding_idx=0)
        self.kind_1_embedding = nn.Embedding(n_kind_1, emb_dim, padding_idx=0)

        self.gender_embedding = nn.Embedding(4, 5)
        self.age_embedding = nn.Embedding(12, 10)
        self.dow_embedding = nn.Embedding(7, 10)

        half_dim = hidden_size // 2
        self.cate_proj = nn.Sequential(
            nn.Linear(emb_dim * len(self.categorical_features), half_dim),
            nn.LayerNorm(half_dim),
        )
        self.cont_emb = nn.Sequential(
            nn.Linear(len(self.numerical_features), half_dim),
            nn.LayerNorm(half_dim),
        )

        self.lstm = nn.LSTM(
            # Width of [cate_proj, cont_emb]; equals hidden_size when it is
            # even and stays correct when it is odd.
            input_size=2 * half_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=rnn_dropout,
            bidirectional=False,
            batch_first=True,
        )

        # Derive the head's input width instead of hard-coding "+ 26"
        # (5 + 10 + 10 meta embedding dims plus one numerical meta feature),
        # so adding a meta feature does not break the first Linear layer.
        meta_c_dim = (
            self.gender_embedding.embedding_dim
            + self.age_embedding.embedding_dim
            + self.dow_embedding.embedding_dim
        )
        ffn_in_dim = hidden_size * 2 + meta_c_dim + len(self.meta_n_features)
        self.ffn = nn.Sequential(
            nn.Linear(ffn_in_dim, hidden_size),
#             nn.LayerNorm(hidden_size),
#             nn.Dropout(dropout),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_size, self.n_targets),
        )

    def forward(
        self,
        value_1_tensor,
        display_action_id_tensor,
        name_1_tensor,
        kind_1_tensor,
        unit_price_tensor,
        spend_time_tensor,
        hour_tensor,
        n_items_tensor,
        meta_c_tensor,
        meta_n_tensor,
    ):
        """Return raw scores of shape ``(batch, n_targets)``.

        The four categorical and four numerical inputs are indexed as
        ``(batch, seq)``; ``meta_c_tensor`` is indexed as ``(batch, 3)`` with
        columns gender, age, dow; ``meta_n_tensor`` is concatenated along
        dim 1 and must have ``len(meta_n_features)`` columns.

        NOTE(review): the numerical parameters are declared in the order
        (unit_price, spend_time, hour, n_items), while ``MyCollator`` and
        ``CustomRunner`` pass (unit_price, n_items, spend_time, hour)
        positionally.  The parameter names therefore do not match the data
        they receive.  All four are only concatenated into one feature
        vector, so training and prediction stay consistent; the signature is
        left untouched to keep the interface and existing checkpoints valid.

        The mean/max pooling runs over every time step, including padded
        positions.
        """
        mc = torch.cat([
            self.gender_embedding(meta_c_tensor[:, 0]),
            self.age_embedding(meta_c_tensor[:, 1]),
            self.dow_embedding(meta_c_tensor[:, 2]),
        ], dim=1)

        cate_emb = torch.cat(
            [
                self.value_1_embedding(value_1_tensor),
                self.display_action_id_embedding(display_action_id_tensor),
                self.name_1_embedding(name_1_tensor),
                self.kind_1_embedding(kind_1_tensor),
            ],
            dim=2,
        )
        cate_emb = self.cate_proj(cate_emb)

        # Add a trailing feature axis to each numerical sequence and stack
        # them into (batch, seq, 4).
        cont_emb = torch.cat(
            [
                unit_price_tensor.unsqueeze(2),
                n_items_tensor.unsqueeze(2),
                spend_time_tensor.unsqueeze(2),
                hour_tensor.unsqueeze(2),
            ],
            dim=2,
        )
        cont_emb = self.cont_emb(cont_emb)

        out = torch.cat([cate_emb, cont_emb], dim=2)
        out, _ = self.lstm(out)
        avg_pool = torch.mean(out, 1)
        max_pool, _ = torch.max(out, 1)
        conc = torch.cat([avg_pool, max_pool, mc, meta_n_tensor], dim=1)
        return self.ffn(conc)


In [None]:
import torch
from sklearn.metrics import roc_auc_score
from catalyst.dl import Runner
from catalyst.dl.utils import any2device


class CustomRunner(Runner):
    """Catalyst runner that feeds the collated 10-tensor batch to the model.

    A batch is the tuple produced by ``MyCollator``: ten input tensors,
    optionally followed by the target tensor.
    """

    # Number of model inputs at the front of every batch.
    _N_INPUTS = 10

    def _handle_batch(self, batch):
        """Run one train/valid step and record ``loss`` (and ``auc``) metrics.

        The loss is always recorded.  ``roc_auc_score`` raises ``ValueError``
        when it cannot be computed for the batch (for example when a target
        column contains a single class); in that case only the AUC is
        skipped, whereas previously the loss was dropped for that batch too.
        The AUC is computed on the raw model outputs.
        """
        if len(batch) != self._N_INPUTS + 1:
            raise ValueError(
                f"expected {self._N_INPUTS + 1} tensors per batch, got {len(batch)}"
            )
        inputs, y = batch[:self._N_INPUTS], batch[self._N_INPUTS]
        out = self.model(*inputs)
        loss = self.criterion(out, y)

        metrics = {"loss": loss}
        try:
            metrics["auc"] = roc_auc_score(
                y.detach().cpu().numpy(),
                out.detach().cpu().numpy(),
                average='macro',
            )
        except ValueError:
            # AUC is undefined for this batch; keep the loss only.
            pass
        self.batch_metrics.update(metrics)

        if self.is_train_loader:
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()

    @torch.no_grad()
    def predict_batch(self, batch):
        """Return the model outputs for a batch with or without targets.

        Accepts 10 tensors (inputs only) or 11 tensors (inputs plus target,
        which is ignored).  Raises ``RuntimeError`` for any other length.
        The returned values are the raw model outputs; no sigmoid is applied.
        """
        batch = any2device(batch, self.device)
        if len(batch) not in (self._N_INPUTS, self._N_INPUTS + 1):
            raise RuntimeError(
                f"expected {self._N_INPUTS} or {self._N_INPUTS + 1} tensors "
                f"per batch, got {len(batch)}"
            )
        return self.model(*batch[:self._N_INPUTS])


In [None]:
import os
import random

import numpy as np
import torch


def seed_everything(seed=0):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: Seed applied to ``random``, ``numpy.random`` and ``torch``
            (CPU and every CUDA device).

    ``PYTHONHASHSEED`` is exported as well; it only influences interpreters
    started after this call, not the hashing of the running process.
    """
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Cover every visible GPU, not only the current one (no-op without CUDA).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may pick different kernels from run to run, so it
    # is switched off in case it was enabled earlier in the session.
    torch.backends.cudnn.benchmark = False

In [None]:
import os
import sys

from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader
import torch

# sys.path.append('../')
# from src.datasets import RetailNNDataset, MyCollator
# from src.models import RetailNN
# from src.utils import seed_everything
# from src.runner import CustomRunner

# Feature groups consumed by RetailNNDataset / MyCollator / RetailNN.
# Categorical columns handled as sequences.
CATEGORICAL_COLS = ['value_1', 'display_action_id', 'name_1', 'kind_1']
# Numerical columns handled as sequences.
NUMERICAL_COLS = ['unit_price', 'n_items', 'spend_time', 'hour']
# Per-session categorical columns.
META_C_COLS = ['gender', 'age', 'dow']
# Per-session numerical columns (n_seq: session length at prediction time).
META_N_COLS = ['n_seq']

# Run configuration.
run_name = 'nn000'
batch_size = 256
seed_everything(0)
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')


if __name__ == '__main__':
    # Train the selected CV folds, then store out-of-fold and test predictions.
    # The stored predictions are the raw model outputs (the loss is
    # BCEWithLogitsLoss), i.e. they are not passed through a sigmoid.
    print('File loading...')
    X_train = pd.read_pickle('../input/atmacup9-features/X_train.pickle')
    y_train = pd.read_pickle('../input/atmacup9-features/y_train.pickle')
    X_test = pd.read_pickle('../input/atmacup9-features/X_test.pickle')

    n_targets = 15
    # Folds that are actually trained in this run.
    target_folds = (0,)
    # os.cpu_count() may return None, which DataLoader does not accept.
    num_workers = os.cpu_count() or 0

    cv = MultilabelStratifiedKFold(n_splits=5, shuffle=False)
    oof_preds = np.zeros((len(X_train), n_targets), dtype=np.float32)
    test_preds = np.zeros((len(X_test), n_targets), dtype=np.float32)
    cv_scores = []

    test_dataset = RetailNNDataset(input_log_df=X_test,
                                   target_df=None,
                                   categorical_features=CATEGORICAL_COLS,
                                   numerical_features=NUMERICAL_COLS,
                                   meta_c_features=META_C_COLS,
                                   meta_n_features=META_N_COLS,
                                   is_train=False)
    test_loader = DataLoader(test_dataset,
                             shuffle=False,
                             collate_fn=MyCollator(is_train=False),
                             batch_size=batch_size,
                             num_workers=num_workers,
                             pin_memory=True)

    # Labels used for stratification must follow the row order of X_train;
    # filtering y_train with `isin` kept y_train's own order instead.
    y_for_split = y_train.loc[X_train["session_id"].to_list()]

    for fold_id, (tr_idx, va_idx) in enumerate(cv.split(X_train, y_for_split)):
        if fold_id not in target_folds:
            continue

        print(f'Training fold{fold_id}...')
        # The splitter yields positional indices, so select with iloc
        # (loc is label-based and only matched by accident on a RangeIndex).
        X_tr = X_train.iloc[tr_idx]
        X_val = X_train.iloc[va_idx]

        train_dataset = RetailNNDataset(input_log_df=X_tr,
                                        target_df=y_train,
                                        categorical_features=CATEGORICAL_COLS,
                                        numerical_features=NUMERICAL_COLS,
                                        meta_c_features=META_C_COLS,
                                        meta_n_features=META_N_COLS)
        valid_dataset = RetailNNDataset(input_log_df=X_val,
                                        target_df=y_train,
                                        categorical_features=CATEGORICAL_COLS,
                                        numerical_features=NUMERICAL_COLS,
                                        meta_c_features=META_C_COLS,
                                        meta_n_features=META_N_COLS)

        # Shuffle the training data every epoch; the validation loader keeps
        # its order because the predictions are written back by position.
        train_loader = DataLoader(train_dataset,
                                  shuffle=True,
                                  collate_fn=MyCollator(is_train=True),
                                  batch_size=batch_size,
                                  num_workers=num_workers,
                                  pin_memory=True)
        valid_loader = DataLoader(valid_dataset,
                                  shuffle=False,
                                  collate_fn=MyCollator(is_train=True),
                                  batch_size=batch_size,
                                  num_workers=num_workers,
                                  pin_memory=True)

        loaders = {'train': train_loader, 'valid': valid_loader}

        runner = CustomRunner(device=device)
        model = RetailNN(
            categorical_features=CATEGORICAL_COLS,
            numerical_features=NUMERICAL_COLS,
            meta_c_features=META_C_COLS,
            meta_n_features=META_N_COLS,
            n_targets=n_targets,
            n_value_1=67590 + 1 + 1,     # len(classes_) + padding_id + null_id
            n_display_action_id=41 + 1 + 1,     # len(classes_) + padding_id + null_id
            n_name_1=42070 + 1 + 1,     # len(classes_) + padding_id + null_id
            n_kind_1=14 + 1 + 1,     # len(classes_) + padding_id + null_id
        )
        criterion = torch.nn.BCEWithLogitsLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30, eta_min=1e-6)
        logdir = f'logdir_{run_name}/fold{fold_id}'
        runner.train(
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            loaders=loaders,
            logdir=logdir,
            num_epochs=1,
            verbose=True,
        )

        val_batches = runner.predict_loader(
            loader=valid_loader,
            resume=f'{logdir}/checkpoints/best.pth',
            model=model,
        )
        oof_preds[va_idx, :] = np.concatenate([b.cpu().numpy() for b in val_batches])
        np.save(f"{logdir}/y_val_pred_fold{fold_id}", oof_preds[va_idx, :])

        test_batches = runner.predict_loader(
            loader=test_loader,
            resume=f'{logdir}/checkpoints/best.pth',
            model=model,
        )
        test_preds_ = np.concatenate([b.cpu().numpy() for b in test_batches])
        # Average over the folds that were trained, not over cv.n_splits;
        # otherwise a single-fold run scales the predictions by 1/5.
        test_preds += test_preds_ / len(target_folds)
        np.save(f"{logdir}/y_test_pred_fold{fold_id}", test_preds_)

    # As before, the aggregated files go to the logdir of the last trained fold.
    np.save(f"{logdir}/y_oof_pred", oof_preds)
    np.save(f"{logdir}/y_test_pred", test_preds)


In [None]:
# for batch in train_loader:
#     _b = batch
#     model = RetailNN(
#         categorical_features=CATEGORICAL_COLS,
#         numerical_features=NUMERICAL_COLS,
#         meta_c_features=META_C_COLS,
#         meta_n_features=META_N_COLS,
#         n_targets=15,
#         n_value_1=67590 + 1 + 1,     # len(classes_) + padding_id + null_id
#         n_display_action_id=41 + 1 + 1,     # len(classes_) + padding_id + null_id
#         n_name_1=42070 + 1 + 1,     # len(classes_) + padding_id + null_id
#         n_kind_1=14 + 1 + 1,     # len(classes_) + padding_id + null_id
#     )
#     out = model(_b[0], _b[1], _b[2], _b[3], _b[4], _b[5], _b[6], _b[7], _b[8], _b[9])
#     print(out.shape)

In [None]:
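# print(model)
# NOTE: the dump below was captured while the LayerNorm/Dropout layers inside
# `ffn` were still enabled; they are commented out in the RetailNN code above.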

# RetailNN(
#   (drop): Dropout(p=0.3, inplace=False)
#   (value_1_embedding): Embedding(67592, 128, padding_idx=0)
#   (display_action_id_embedding): Embedding(43, 128, padding_idx=0)
#   (name_1_embedding): Embedding(42072, 128, padding_idx=0)
#   (kind_1_embedding): Embedding(16, 128, padding_idx=0)
#   (gender_embedding): Embedding(4, 5)
#   (age_embedding): Embedding(12, 10)
#   (dow_embedding): Embedding(7, 10)
#   (cate_proj): Sequential(
#     (0): Linear(in_features=512, out_features=64, bias=True)
#     (1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
#   )
#   (cont_emb): Sequential(
#     (0): Linear(in_features=4, out_features=64, bias=True)
#     (1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
#   )
#   (lstm): LSTM(128, 128, num_layers=2, batch_first=True, dropout=0.3)
#   (ffn): Sequential(
#     (0): Linear(in_features=282, out_features=128, bias=True)
#     (1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
#     (2): Dropout(p=0.3, inplace=False)
#     (3): ReLU(inplace=True)
#     (4): Linear(in_features=128, out_features=15, bias=True)
#   )
# )