In [59]:
"""
Parameters to set at run time.
"""

# Informer training parameters.
# Assign here the same argument string that was used when training Informer;
# it is parsed by parse_args() further down in the notebook.
# Alternative configuration (ETTh1), kept for reference:
# ARG_STR = "--model informer --data ETTh1 --attn prob --freq h --features S  --e_layers 1  --d_layers 1 --dropout 0.3 --learning_rate 0.0001 --embed timeF --use_y_pred_cache --proposed_lmda 0.5 --proposed_moe_lr=0.1 --proposed_moe_weight_decay=0.01 --proposed_moe_epochs=10"
ARG_STR = "--model informer --data NaturalGas --root_path './Informer2020/data/NaturalGas/' --data_path combined_data.csv --features S --attn prob --freq h --e_layers 1  --d_layers 1 --dropout 0.3 --learning_rate 0.0001 --embed timeF --use_y_pred_cache --proposed_lmda 0.2 --proposed_moe_lr=0.01 --proposed_moe_weight_decay=0.01 --proposed_moe_epochs=10"

# Template for the pickled y_pred file paths; filled in with
# str.format(method=..., data=...) where the base models are constructed.
Y_PRED_PATH_PLACEHOLDER = "checkpoints/{method}_{data}_y_pred.pkl"

In [60]:
import sys

sys.path.append("Informer2020")

from typing import Optional
import pickle
import argparse
import random
import tqdm
import numpy as np
import torch
import pandas as pd
import os

import dataset
from model.informer_model import InformerModel
from model.model import Model
from model.moment_model import MomentModel
from propose import ProposedModelWithMoe, ProposedModel
from evaluation import evaluate_mse, evaluate_nll

from main_informer import parse_args

In [61]:
def set_seed(seed: int) -> None:
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch
    (CPU, all CUDA devices, and MPS when the installed PyTorch exposes
    ``torch.mps``), then configures cuDNN for reproducible runs.

    Args:
        seed: The seed value applied to every generator.
    """
    # Python standard library
    random.seed(seed)

    # NumPy (global RNG)
    np.random.seed(seed)

    # PyTorch: CPU and CUDA generators. manual_seed_all is safe to call
    # when CUDA is unavailable.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # torch.mps does not exist on older PyTorch releases; only seed it when
    # the installed build actually provides it instead of raising
    # AttributeError.
    mps_module = getattr(torch, "mps", None)
    if mps_module is not None and hasattr(mps_module, "manual_seed"):
        mps_module.manual_seed(seed)

    # Deterministic cuDNN kernels; benchmark mode is disabled as well because
    # its auto-tuner may pick different algorithms from run to run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def load_moment_model(
    args: argparse.Namespace, y_pred_path: Optional[str] = None
) -> Model:
    """Construct a ``MomentModel`` for the configured prediction length.

    Args:
        args: Parsed arguments; only ``args.pred_len`` is read here.
        y_pred_path: Optional path forwarded unchanged to ``MomentModel``.

    Returns:
        The newly constructed ``MomentModel`` instance.
    """
    moment_param = "AutonLab/MOMENT-1-large"
    return MomentModel(
        y_pred_path=y_pred_path,
        pred_len=args.pred_len,
        param=moment_param,
    )


def load_moment_model_finetuned(
    args: argparse.Namespace,
    train_dataset: torch.utils.data.Dataset,
    valid_dataset: torch.utils.data.Dataset,
) -> Model:
    """Build a MOMENT model and fine-tune it on the given datasets.

    The model is created through ``load_moment_model`` (without a
    ``y_pred_path``) and its ``fine_tuning`` method is then invoked.

    Args:
        args: Parsed arguments, passed to both the loader and ``fine_tuning``.
        train_dataset: Dataset handed to ``fine_tuning`` as training data.
        valid_dataset: Dataset handed to ``fine_tuning`` as validation data.

    Returns:
        The model object after ``fine_tuning`` has been called on it.
    """
    finetuned: MomentModel = load_moment_model(args=args)
    finetuned.fine_tuning(
        args=args,
        valid_dataset=valid_dataset,
        train_dataset=train_dataset,
    )
    return finetuned


def load_informer_model(
    args: argparse.Namespace, use_saved_model: bool = False, y_pred_path: Optional[str] = None
) -> Model:
    """Construct an ``InformerModel`` from the parsed arguments.

    Args:
        args: Parsed arguments, forwarded unchanged to ``InformerModel``.
        use_saved_model: Flag forwarded unchanged to ``InformerModel``.
        y_pred_path: Optional path forwarded unchanged to ``InformerModel``.

    Returns:
        The newly constructed ``InformerModel`` instance.
    """
    informer = InformerModel(
        y_pred_path=y_pred_path,
        use_saved_model=use_saved_model,
        args=args,
    )
    return informer


def load_proposed_model(moment_model: Model, informer_model: Model) -> Model:
    """Wrap the two base models in a ``ProposedModel``.

    Args:
        moment_model: Model passed as ``moment_model`` to ``ProposedModel``.
        informer_model: Model passed as ``informer_model`` to ``ProposedModel``.

    Returns:
        The newly constructed ``ProposedModel`` instance.
    """
    return ProposedModel(informer_model=informer_model, moment_model=moment_model)


def load_proposed_model_with_moe(
    moment_model: Model,
    informer_model: Model,
    input_size: int,
    train_dataset: torch.utils.data.Dataset,
    valid_dataset: torch.utils.data.Dataset,
    args: argparse.Namespace,
    lr: float,
    weight_decay: float,
) -> Model:
    """Build a ``ProposedModelWithMoe`` and run its ``train`` method.

    Args:
        moment_model: Model passed as ``moment_model`` to the constructor.
        informer_model: Model passed as ``informer_model`` to the constructor.
        input_size: Value passed as ``input_size`` to the constructor.
        train_dataset: Dataset forwarded to ``train`` as training data.
        valid_dataset: Dataset forwarded to ``train`` as validation data.
        args: Parsed arguments forwarded to ``train``.
        lr: Learning rate forwarded to ``train``.
        weight_decay: Weight decay forwarded to ``train``.

    Returns:
        The model object after ``train`` has been called on it.
    """
    moe_model = ProposedModelWithMoe(
        input_size=input_size,
        informer_model=informer_model,
        moment_model=moment_model,
    )
    # Here `train` is invoked with datasets and hyperparameters, i.e. it is
    # the project's own training entry point.
    moe_model.train(
        args=args,
        lr=lr,
        weight_decay=weight_decay,
        train_dataset=train_dataset,
        valid_dataset=valid_dataset,
    )
    return moe_model

In [62]:
# Seed all RNGs before any model or dataset is created.
set_seed(0)

# Parse the configured argument string and echo the resulting Namespace.
args = parse_args(ARG_STR)
print("args:", args)

# Load the three dataset splits; input_size is taken from args.seq_len.
train_dataset, valid_dataset, test_dataset = dataset.load_dataset(args=args)
input_size = args.seq_len

# Each base model receives a y_pred path only when use_y_pred_cache is set;
# otherwise None is passed.
moment_model = load_moment_model(
    args=args,
    y_pred_path=(
        None
        if not args.use_y_pred_cache
        else Y_PRED_PATH_PLACEHOLDER.format(method="moment", data=args.data)
    ),
)
informer_model = load_informer_model(
    args=args,
    use_saved_model=False,
    y_pred_path=(
        None
        if not args.use_y_pred_cache
        else Y_PRED_PATH_PLACEHOLDER.format(method="informer", data=args.data)
    ),
)

# Combine both base models into the proposed model.
proposed_model = load_proposed_model(
    moment_model=moment_model, informer_model=informer_model
)

args: Namespace(model='informer', data='NaturalGas', root_path='./Informer2020/data/NaturalGas/', data_path='combined_data.csv', features='S', target='actual_wdl_gj', freq='h', checkpoints='checkpoints/NaturalGas_sample30_window30', seq_len=96, label_len=48, pred_len=24, enc_in=1, dec_in=1, c_out=1, d_model=512, n_heads=8, e_layers=1, d_layers=1, s_layers=[3, 2, 1], d_ff=2048, factor=5, padding=0, distil=True, dropout=0.3, attn='prob', embed='timeF', activation='gelu', output_attention=False, do_predict=False, mix=True, cols=None, num_workers=0, itr=2, train_epochs=6, batch_size=32, patience=3, learning_rate=0.0001, des='test', loss='mse', lradj='type1', use_amp=False, inverse=False, use_gpu=False, gpu=0, use_multi_gpu=False, devices='0,1,2,3', use_saved_informer=False, use_y_pred_cache=True, detail_freq='h')
[test] self.target actual_wdl_gj
[test] cols ['date', 'schedule_interval', 'transmission_id', 'sched_inj_gj', 'sched_wdl_gj', 'price_value', 'administered_price', 'actual_wdl_gj',

In [63]:
# Build and train the MoE variant on top of the two base models.
# NOTE(review): ARG_STR also carries --proposed_moe_lr and
# --proposed_moe_weight_decay flags, but this call uses the hard-coded values
# below rather than reading them from args - confirm which is intended.
proposed_model_with_moe = load_proposed_model_with_moe(
    moment_model=moment_model,
    informer_model=informer_model,
    input_size=input_size,
    train_dataset=train_dataset,
    valid_dataset=valid_dataset,
    args=args,
    # 1e-1 for the ETTh1 dataset, 1e-2 for every other dataset.
    lr=1e-2 if args.data != "ETTh1" else 1e-1,
    weight_decay=1e-2,
)

100%|██████████| 196/196 [00:00<00:00, 1092.23it/s]
100%|██████████| 27/27 [00:00<00:00, 1944.51it/s]


Epoch [1/10], Train Loss: 0.2113, Valid Loss: 0.3311


100%|██████████| 196/196 [00:00<00:00, 1161.61it/s]
100%|██████████| 27/27 [00:00<00:00, 1967.89it/s]


Epoch [2/10], Train Loss: 0.2069, Valid Loss: 0.3276


100%|██████████| 196/196 [00:00<00:00, 1126.70it/s]
100%|██████████| 27/27 [00:00<00:00, 1936.46it/s]


Epoch [3/10], Train Loss: 0.2056, Valid Loss: 0.3284


100%|██████████| 196/196 [00:00<00:00, 1110.34it/s]
100%|██████████| 27/27 [00:00<00:00, 1898.99it/s]


Epoch [4/10], Train Loss: 0.2061, Valid Loss: 0.3268


100%|██████████| 196/196 [00:00<00:00, 1160.88it/s]
100%|██████████| 27/27 [00:00<00:00, 2016.17it/s]


Epoch [5/10], Train Loss: 0.2053, Valid Loss: 0.3265


100%|██████████| 196/196 [00:00<00:00, 1185.90it/s]
100%|██████████| 27/27 [00:00<00:00, 1907.30it/s]


Epoch [6/10], Train Loss: 0.2051, Valid Loss: 0.3264


100%|██████████| 196/196 [00:00<00:00, 1071.60it/s]
100%|██████████| 27/27 [00:00<00:00, 1977.31it/s]


Epoch [7/10], Train Loss: 0.2050, Valid Loss: 0.3264


100%|██████████| 196/196 [00:00<00:00, 1198.44it/s]
100%|██████████| 27/27 [00:00<00:00, 1929.96it/s]


Epoch [8/10], Train Loss: 0.2050, Valid Loss: 0.3263


100%|██████████| 196/196 [00:00<00:00, 1148.55it/s]
100%|██████████| 27/27 [00:00<00:00, 1966.66it/s]


Epoch [9/10], Train Loss: 0.2049, Valid Loss: 0.3263


100%|██████████| 196/196 [00:00<00:00, 1104.12it/s]
100%|██████████| 27/27 [00:00<00:00, 1926.28it/s]

Epoch [10/10], Train Loss: 0.2049, Valid Loss: 0.3263





args: Namespace(model='informer', data='NaturalGas', root_path='./Informer2020/data/NaturalGas/', data_path='combined_data.csv', features='S', target='actual_wdl_gj', freq='h', checkpoints='checkpoints/NaturalGas_sample30_window30', seq_len=96, label_len=48, pred_len=24, enc_in=1, dec_in=1, c_out=1, d_model=512, n_heads=8, e_layers=1, d_layers=1, s_layers=[3, 2, 1], d_ff=2048, factor=5, padding=0, distil=True, dropout=0.3, attn='prob', embed='timeF', activation='gelu', output_attention=False, do_predict=False, mix=True, cols=None, num_workers=0, itr=2, train_epochs=6, batch_size=32, patience=3, learning_rate=0.0001, des='test', loss='mse', lradj='type1', use_amp=False, inverse=False, use_gpu=False, gpu=0, use_multi_gpu=False, devices='0,1,2,3', use_saved_informer=False, use_y_pred_cache=True, detail_freq='h')
testing: informer


100%|██████████| 56/56 [00:00<00:00, 2705.04it/s]


{'mse': 0.2425619520486699, 'nll': 3838.6947371467613}
testing: moment


100%|██████████| 56/56 [00:00<00:00, 2723.20it/s]


{'mse': 0.32092184517548183, 'nll': 504.0507907688885}
testing: proposed


100%|██████████| 56/56 [00:00<00:00, 2497.78it/s]


{'mse': 0.2424583544034962, 'nll': 291.27698335871935}
testing: proposed+moe


100%|██████████| 56/56 [00:00<00:00, 2042.30it/s]


{'mse': 0.23301477171899582, 'nll': 380.05546410874524}


{'informer': {'mse': 0.2425619520486699, 'nll': 3838.6947371467613},
 'moment': {'mse': 0.32092184517548183, 'nll': 504.0507907688885},
 'proposed': {'mse': 0.2424583544034962, 'nll': 291.27698335871935},
 'proposed+moe': {'mse': 0.23301477171899582, 'nll': 380.05546410874524}}

In [65]:
# Tabulate the collected metrics, write them to a per-dataset CSV file, and
# leave the frame as the cell's last expression so it is displayed.
df = pd.DataFrame(results)
df.to_csv("data/results_{}.csv".format(args.data))
df

Unnamed: 0,informer,moment,proposed,proposed+moe
mse,0.242562,0.320922,0.242458,0.233015
nll,3838.694737,504.050791,291.276983,380.055464


In [66]:
# Pickle each base model's y_pred attribute.
# The target paths are derived from Y_PRED_PATH_PLACEHOLDER - the same
# template used when the models are constructed - so the files written here
# cannot drift from the paths the models are given on the next run.
for method_name, cached_model in (
    ("informer", informer_model),
    ("moment", moment_model),
):
    y_pred_path = Y_PRED_PATH_PLACEHOLDER.format(method=method_name, data=args.data)
    # Create the target directory on first use instead of failing in open().
    os.makedirs(os.path.dirname(y_pred_path) or ".", exist_ok=True)
    with open(y_pred_path, "wb") as f:
        pickle.dump(cached_model.y_pred, f)