In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from gluonts.dataset.multivariate_grouper import MultivariateGrouper
from gluonts.dataset.repository.datasets import dataset_recipes, get_dataset
from gluonts.evaluation.backtest import make_evaluation_predictions
from gluonts.evaluation import MultivariateEvaluator
from pathlib import Path
from gluonts.dataset.common import load_datasets
from Battery_estimator import Battery_Estimator
from trainer import Trainer
# Compute device for training. GPU index 1 is preferred, but a single-GPU machine has no
# "cuda:1", so clamp to the highest available index; use the CPU when CUDA is unavailable.
PREFERRED_GPU_INDEX = 1
if torch.cuda.is_available():
    device = torch.device(f"cuda:{min(PREFERRED_GPU_INDEX, torch.cuda.device_count() - 1)}")
else:
    device = torch.device("cpu")

In [None]:
from pathlib import Path
from gluonts.dataset.common import load_datasets

# Root folder of the prepared GluonTS datasets; the "nasa" dataset lives underneath it,
# with its metadata at the top level and `train/` / `test/` sub-folders.
data_dir = Path("/data")
nasa_dir = data_dir / "nasa"
dataset = load_datasets(
    metadata=nasa_dir,
    train=nasa_dir / "train",
    test=nasa_dir / "test",
)

# Public (non-underscore) attribute names exposed by the loaded dataset object.
features = [attr_name for attr_name in dir(dataset) if not attr_name.startswith("_")]

In [None]:
# Combine the univariate series into multivariate entries (max_target_dim=16).
train_grouper = MultivariateGrouper(max_target_dim=16)

# The test split is expected to contain a whole number of rolling copies of every train
# series. Validate that instead of silently truncating the ratio with int(a / b), which
# would misalign the test groups (and used to raise a bare ZeroDivisionError on an empty
# train split).
n_train_series = len(dataset.train)
n_test_series = len(dataset.test)
if n_train_series == 0 or n_test_series % n_train_series != 0:
    raise ValueError(
        f"test split size ({n_test_series}) must be a positive multiple of the "
        f"train split size ({n_train_series})"
    )
num_test_dates = n_test_series // n_train_series
test_grouper = MultivariateGrouper(num_test_dates=num_test_dates,
                                   max_target_dim=16)

dataset_train = train_grouper(dataset.train)
dataset_test = test_grouper(dataset.test)

# Print a compact per-entry summary (array fields shown by shape) instead of dumping the
# full target arrays. The loop variables stay bound after the loops and later cells may
# read them, so their names are kept.
for group1 in dataset_train:
    print({key: getattr(value, "shape", value) for key, value in group1.items()})

for group2 in dataset_test:
    print({key: getattr(value, "shape", value) for key, value in group2.items()})

In [None]:
# Select the target of the last grouped test entry explicitly, rather than relying on the
# loop variable `group2` leaking out of a print loop in another cell (a hidden-state
# dependency that is unbound when the test set is empty). It yields the same value the
# leaked `group2['target']` held.
test_groups = list(dataset_test)
if not test_groups:
    raise ValueError("dataset_test is empty; cannot build `series` for Battery_Estimator")
last_test_target = test_groups[-1]["target"]
# NOTE(review): this passes test-split data into the estimator's constructor; confirm that
# Battery_Estimator does not fit on `series`, otherwise the evaluation below is leaky.

estimator = Battery_Estimator(
    # NOTE(review): presumably has to match the dimension of the grouped targets
    # (the grouper allows up to max_target_dim=16) — confirm.
    target_dim=3,
    conditioning_length=16,
    prediction_length=32,
    context_length=16,
    cell_type='LSTM',
    input_size=11,
    freq="D",
    loss_type='l2',
    scaling=True,
    diff_steps=40,
    beta_end=0.04,
    beta_schedule="linear",
    trainer=Trainer(device=device, epochs=30, learning_rate=7e-6,
                    num_batches_per_epoch=100, batch_size=128),
    series=last_test_target,
)
predictor = estimator.train(dataset_train, num_workers=10)

In [None]:
# Draw 100 sample paths per test window and materialize forecasts alongside ground truth.
forecast_it, ts_it = make_evaluation_predictions(
    dataset=dataset_test,
    predictor=predictor,
    num_samples=100,
)
forecasts = list(forecast_it)
targets = list(ts_it)

# Quantile levels 0.05, 0.10, ..., 0.95 (nineteen levels, 0.0 excluded).
quantile_levels = np.arange(1, 20) / 20.0
evaluator = MultivariateEvaluator(
    quantiles=quantile_levels,
    target_agg_funcs={'sum': np.sum},
)

agg_metric, item_metrics = evaluator(targets, forecasts, num_series=len(dataset_test))

# Report the headline metrics for target dimension 0.
for metric_name in ("0_MSE", "0_MAPE", "0_mean_wQuantileLoss"):
    print(f"{metric_name}:", agg_metric[metric_name])
