In [None]:
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

from transformers import TimeSeriesTransformerConfig

from src import ts_transformer as tsf
from src.inference.wrapper import TFWrapper
from src.inference.monitor import EnsembleForecaster

In [None]:
# This cell assumes a DataFrame named `data` is already defined in the kernel.
# Its rows are split 80/20 into train/test with a fixed seed (42), and each
# split gets a fresh 0..n-1 index so later cells can address series by label
# (e.g. `train_df['target'][0]`).
train_df, test_df = (
    split.reset_index(drop=True)
    for split in train_test_split(data, test_size=0.2, random_state=42)
)
print(train_df.shape)

In [None]:
# Sampling frequency of every series (pandas offset alias).
# NOTE(review): pandas >= 2.2 deprecates the 'H' alias in favour of 'h'; kept
# as-is because downstream time-feature helpers may expect this spelling.
freq = '1H'

# Model hyper-parameters. Key names match transformers.TimeSeriesTransformerConfig
# parameters; presumably tsf.setup_training forwards them to that config — TODO confirm.
transformer_config = {
    'prediction_length': 24,
    'context_length': 48,
    'num_static_categorical_features': 2,
    'cardinality': [3, 4],           # one entry per static categorical feature
    'embedding_dimension': [2, 2],   # one entry per static categorical feature
    'encoder_layers': 4,
    'decoder_layers': 4,
    'd_model': 32,
}

# Fail fast with a readable message if the per-feature lists drift out of sync
# with the declared number of static categorical features.
assert (
    len(transformer_config['cardinality'])
    == len(transformer_config['embedding_dimension'])
    == transformer_config['num_static_categorical_features']
), "cardinality / embedding_dimension must have one entry per static categorical feature"

transformer, train_dataloader = tsf.setup_training(
    train_df=train_df,
    freq=freq,
    batch_size=32,
    num_batches_per_epoch=16,
    # Computed from the FIRST training series only, presumably so that
    # context_length + max_lags equals that series' length.
    # NOTE(review): if series lengths differ, shorter series may not provide
    # this much history — consider the minimum length across train_df.
    max_lags=len(train_df['target'][0]) - transformer_config['context_length'],
    transformer_config=transformer_config,
)

# The third positional argument is presumably the number of epochs — TODO
# confirm against tsf.train's signature.
transformer, list_loss = tsf.train(transformer, train_dataloader, 200)

fig, ax = plt.subplots()
ax.plot(list_loss)
ax.set(title='Training loss', xlabel='Logged iteration', ylabel='Loss');

In [None]:
# Create the ensemble loss monitor and wrap the trained transformer for
# step-by-step inference, then seed the wrapper's buffer with the opening
# `full_context_length` points of the first test series plus its start
# timestamp and static categorical features.
loss_monitor = EnsembleForecaster()
model = TFWrapper(transformer, freq, loss_window=24)

first_target = test_df['target'][0]
model.initialize_buffer(
    context=first_target[:model.full_context_length],
    start=test_df['start'][0],
    static_cat_features=test_df['feat_static_cat'][0],
)

In [None]:
# Walk forward through the first test series one point at a time: forecast,
# reveal the next true value, then score the wrapper's most recent point
# predictions against the true values via the ensemble monitor.
# Author's timing note: ~17 min for 7,000 steps.
MAX_STEPS = 7000
test_target = test_df['target'][0]
# Stop at the end of the series instead of raising IndexError when it is
# shorter than full_context_length + MAX_STEPS.
n_steps = min(MAX_STEPS, len(test_target) - model.full_context_length)

ensemble_losses = []
for i in range(n_steps):
    model.predict()
    model.ingest(test_target[model.full_context_length + i])

    # Materialize so the emptiness check below works for lists, generators
    # and arrays alike.
    last_preds = list(model.get_last_points_predictions())
    last_true = model.get_last_true_points()
    if not last_preds:
        # zip(*[]) cannot be unpacked into three names; record NaN so the loss
        # list keeps exactly one entry per step.
        ensemble_losses.append(float('nan'))
        continue

    # Each entry is unpacked as (index, prediction, uncertainty); the index is unused here.
    _, preds, uncert = zip(*last_preds)
    # The monitor is bound as `loss_monitor` in the setup cell; the previous
    # reference to an undefined `monitor` raised NameError.
    _, loss = loss_monitor.ensemble_loss(preds, uncert, last_true)
    ensemble_losses.append(loss)

fig, ax = plt.subplots()
ax.plot(ensemble_losses)
ax.set(title='Ensemble loss on first test series', xlabel='Step', ylabel='Loss');
    