In [1]:
import argparse
import random

import numpy as np
import torch

import config
import dataset
from informer_model import InformerModel
from model import Model
from moment_model import MomentModel
from propose import ProposedModel
from evaluation import evaluate_mse, evaluate_nll

In [2]:
def set_seed(seed: int) -> None:
    # random
    random.seed(seed)

    # numpy
    np.random.seed(seed)

    # pytorch
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.mps.manual_seed(seed)
    torch.backends.cudnn.deterministic = True


def load_moment_model(args: argparse.Namespace) -> Model:
    return MomentModel(param="AutonLab/MOMENT-1-large", pred_len=args.pred_len)


def load_informer_model(args: argparse.Namespace) -> Model:
    return InformerModel(args, checkpoint_path="checkpoints/informer.pth")


def load_proposed_model(moment_model: Model, informer_model: Model,input_size: int,train_dataset: torch.utils.data.Dataset, args: argparse.Namespace) -> Model:
    model = ProposedModel(moment_model=moment_model, informer_model=informer_model,input_size=input_size)
    model.train(train_dataset=train_dataset,args=args)
    torch.save(model, "checkpoints/proposed_model.pkt")
    return model

In [3]:
set_seed(0)

args = config.ARGS
train_dataset, test_dataset = dataset.load_dataset(args=args)
input_size = args.seq_len
moment_model = load_moment_model(args=args)
informer_model = load_informer_model(args=args)

In [4]:
proposed_model = load_proposed_model(moment_model, informer_model,input_size,test_dataset,args)

100%|██████████| 89/89 [03:31<00:00,  2.38s/it]


Epoch [1/5], Loss: 0.0439


100%|██████████| 89/89 [03:34<00:00,  2.41s/it]


Epoch [2/5], Loss: 0.0791


100%|██████████| 89/89 [03:33<00:00,  2.40s/it]


Epoch [3/5], Loss: 0.0366


100%|██████████| 89/89 [03:33<00:00,  2.40s/it]


Epoch [4/5], Loss: 0.0415


100%|██████████| 89/89 [03:33<00:00,  2.40s/it]


Epoch [5/5], Loss: 0.0560


In [5]:
# 予測結果の評価
results = {}

for method, model in {
    "informer": informer_model,
    "moment": moment_model,
    "proposed": proposed_model,
}.items():
    print(f"testing: {method}")
    test_dataloader = dataset.to_dataloader(test_dataset, args, "test")
    
    mse = 0
    nll = 0

    k = 10 # 評価するバッチ数の上限、ちゃんと評価するときは十分大きい値に設定する

    for i, batch in enumerate(test_dataloader):
        _, batch_y, _, _ = batch
        Y_pred = model.predict_distr(batch).detach().numpy()
        y_true = batch_y[:, -1].squeeze().detach().numpy()

        mse += evaluate_mse(Y_pred, y_true)
        nll += evaluate_nll(Y_pred, y_true)

        print(f"{i}/{len(test_dataloader)}", {"MSE": mse/(i+1), "NLL": nll/(i+1)})

        if i >= k:
            break

    mse /= min(k, len(test_dataloader))
    nll /= min(k, len(test_dataloader))

    results[method] = {"MSE": mse, "NLL": nll}
    print(results[method])

results

testing: informer
0/89 {'MSE': 0.017435673018273753, 'NLL': 26.554210545591097}
1/89 {'MSE': 0.08738881193743193, 'NLL': 443.7915241770397}
2/89 {'MSE': 0.10775600637984722, 'NLL': 508.95950680923016}
3/89 {'MSE': 0.1803269386162652, 'NLL': 756.2014225955968}
4/89 {'MSE': 0.2184513515748902, 'NLL': 786.8236348559341}
5/89 {'MSE': 0.3618925743051604, 'NLL': 930.5890707085526}
6/89 {'MSE': 0.3704380375519001, 'NLL': 935.6550566002317}
7/89 {'MSE': 0.6408608834926025, 'NLL': 1462.7270418057924}
8/89 {'MSE': 0.6597925686844112, 'NLL': 1470.7574268235842}
9/89 {'MSE': 0.8096525242477083, 'NLL': 1766.2344990566226}
10/89 {'MSE': 0.8759173990604965, 'NLL': 1797.1450130382336}
{'MSE': 0.08759173990604965, 'NLL': 179.71450130382337}
testing: moment
0/89 {'MSE': 0.03261865224844104, 'NLL': 97.3054442213755}
1/89 {'MSE': 0.053694206379893306, 'NLL': 217.6669622268032}
2/89 {'MSE': 0.07836171865290008, 'NLL': 282.5721716812789}
3/89 {'MSE': 0.13927535059647367, 'NLL': 495.16878638207214}
4/89 {'MS

{'informer': {'MSE': 0.08759173990604965, 'NLL': 179.71450130382337},
 'moment': {'MSE': 0.06591766627712266, 'NLL': 191.3535589814859},
 'proposed': {'MSE': 0.049105082946133936, 'NLL': 81.01831573282499}}

In [6]:
import pandas as pd

pd.DataFrame(results).to_csv("results.csv", index=False)