In [None]:
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

from load_data import (load_tourism,
                       create_train_dataloader,
                       create_backtest_dataloader,
                       create_test_dataloader)
from sur_ts import SurrogateTimeSeriesTransformerConfig, SurrogateTimeSeriesTransformer


In [None]:
# Sampling frequency of the tourism series (pandas offset alias) and the forecast horizon.
freq = '1M'
prediction_length = 24

# load_tourism returns the train/test splits plus the lag indices and time features used below.
train_dataset, test_dataset, lags_sequence, time_features = load_tourism()

config = SurrogateTimeSeriesTransformerConfig(
    # Forecast horizon, and a look-back window twice as long.
    prediction_length=prediction_length,
    context_length=2 * prediction_length,
    # Lag indices as returned by load_tourism.
    lags_sequence=lags_sequence,
    # One input per entry of `time_features`, plus one extra feature
    # (presumably the "age" feature the dataloaders append — confirm in load_data).
    num_time_features=1 + len(time_features),
    # A single static categorical feature (presumably the series ID) with one
    # category per training series; each category gets a 2-dim embedding.
    num_static_categorical_features=1,
    cardinality=[len(train_dataset)],
    embedding_dimension=[2],
    # Transformer size.
    encoder_layers=4,
    decoder_layers=4,
    d_model=32,
)

In [None]:
# Training loader: batches of 256, with one "epoch" defined as 100 batches.
train_dataloader = create_train_dataloader(
    data=train_dataset,
    config=config,
    freq=freq,
    num_batches_per_epoch=100,
    batch_size=256,
)

# Backtest loader over the held-out test series, used for generation below.
test_dataloader = create_backtest_dataloader(
    data=test_dataset,
    config=config,
    freq=freq,
    batch_size=64,
)

In [None]:
# Smoke test: push one training batch through the freshly built model and inspect its outputs.
sur_model = SurrogateTimeSeriesTransformer(config)
batch = next(iter(train_dataloader))

# Optional static inputs are only passed when the config declares at least one of them.
static_categorical = (
    batch["static_categorical_features"]
    if config.num_static_categorical_features > 0
    else None
)
static_real = (
    batch["static_real_features"]
    if config.num_static_real_features > 0
    else None
)

sur_outputs = sur_model(
    past_values=batch["past_values"],
    past_time_features=batch["past_time_features"],
    past_observed_mask=batch["past_observed_mask"],
    static_categorical_features=static_categorical,
    static_real_features=static_real,
    future_values=batch["future_values"],
    future_time_features=batch["future_time_features"],
    future_observed_mask=batch["future_observed_mask"],
    output_hidden_states=True,
)

# Shapes of selected outputs, including the last encoder/decoder hidden states.
print("attention_output shape: ", sur_outputs.attention_output.shape)
print("attention_weights shape: ", sur_outputs.attention_weights.shape)
print("last_encoder_state shape: ", sur_outputs.encoder_hidden_states[-1].shape)
print("last_decoder_state shape: ", sur_outputs.decoder_hidden_states[-1].shape)

# Loss components reported by the model on this batch.
print("Surrogate loss: ", sur_outputs.sur_loss)
print("Prediction loss: ", sur_outputs.pred_loss)
print("Total loss: ", sur_outputs.loss)

In [None]:
# # test how many batch per epoch
# for idx, batch in enumerate(train_dataloader):
#     print(idx)

In [None]:
from accelerate import Accelerator
from torch.optim import AdamW
import torch
from transformers import AutoModel

accelerator = Accelerator()
device = accelerator.device  # override with 'cpu' to debug off-accelerator

sur_model.to(device)
optimizer = AdamW(
    sur_model.parameters(),
    lr=6e-4,
    betas=(0.9, 0.95),
    weight_decay=1e-1,
)

# Let Accelerate adapt the model, optimizer and dataloader to the current device setup.
sur_model, optimizer, train_dataloader = accelerator.prepare(
    sur_model, optimizer, train_dataloader
)
sur_model.train()

# Batch tensors that are always fed to the model during training.
_ALWAYS_FED = (
    "past_time_features",
    "past_values",
    "future_time_features",
    "future_values",
    "past_observed_mask",
    "future_observed_mask",
)


def _training_inputs(batch):
    """Build the model keyword arguments for one training batch.

    Every tensor is moved to `device`. The static feature tensors are passed only
    when `config` declares at least one such feature; otherwise they are None.
    """
    inputs = {name: batch[name].to(device) for name in _ALWAYS_FED}
    inputs["static_categorical_features"] = (
        batch["static_categorical_features"].to(device)
        if config.num_static_categorical_features > 0
        else None
    )
    inputs["static_real_features"] = (
        batch["static_real_features"].to(device)
        if config.num_static_real_features > 0
        else None
    )
    return inputs


print(f"Started training: ")
for epoch in range(40):
    for step, batch in enumerate(train_dataloader):
        optimizer.zero_grad()
        outputs = sur_model(**_training_inputs(batch), output_hidden_states=True)
        loss = outputs.loss

        # Backpropagate through Accelerate (handles scaling/distribution), then update.
        accelerator.backward(loss)
        optimizer.step()

        # Report every 100th step of each epoch, starting with the first one.
        if step % 100 == 0:
            pred_loss, sur_loss = outputs.pred_loss, outputs.sur_loss
            print(f"Epoch [{epoch+1}] Step [{step+1}]: loss: {loss.item():.4f},\tpred loss: "
                  f"{pred_loss.item():.4f},\tsur loss: {sur_loss.item():.4f}")


In [None]:
# Directory that receives the saved model config and weights.
model_path = "model2"

# `accelerator.prepare` may have wrapped the model (e.g. DistributedDataParallel or
# mixed-precision hooks). A DDP wrapper has no `save_pretrained`, so unwrap first.
# Wait until every process has finished training, then write from the main process only
# so multiple workers don't race on the same files.
accelerator.wait_for_everyone()
if accelerator.is_main_process:
    accelerator.unwrap_model(sur_model).save_pretrained(model_path)

In [None]:
import numpy as np
sur_model.eval()

# `accelerator.prepare` may have wrapped the model (e.g. in DistributedDataParallel),
# and such a wrapper does not expose `generate`; call it on the underlying module.
inference_model = accelerator.unwrap_model(sur_model)

forecasts = []
# Forecast sampling needs no gradients. Disabling autograd avoids building graphs and
# guarantees the returned tensors don't require grad (which `.numpy()` would reject).
with torch.no_grad():
    for batch in test_dataloader:
        outputs = inference_model.generate(
            static_categorical_features=batch["static_categorical_features"].to(device)
            if config.num_static_categorical_features > 0
            else None,
            static_real_features=batch["static_real_features"].to(device)
            if config.num_static_real_features > 0
            else None,
            past_time_features=batch["past_time_features"].to(device),
            past_values=batch["past_values"].to(device),
            future_time_features=batch["future_time_features"].to(device),
            past_observed_mask=batch["past_observed_mask"].to(device),
        )
        # NOTE(review): `sequences` is presumably (batch, num_samples, prediction_length)
        # — confirm in sur_ts.
        forecasts.append(outputs.sequences.cpu().numpy())

# Stack the per-batch arrays along axis 0 into a single array.
forecasts = np.vstack(forecasts)
print(forecasts.shape)

In [None]:
from evaluate import load
from gluonts.time_feature import get_seasonality

mase_metric = load("evaluate-metric/mase")
smape_metric = load("evaluate-metric/smape")

# Point forecast per series: median of `forecasts` over axis 1.
# NOTE(review): assumes axis 1 is the sample axis of the generated paths — confirm in sur_ts.
forecast_median = np.median(forecasts, 1)

# Seasonal period for MASE's naive baseline. It depends only on `freq`, so compute it once.
periodicity = get_seasonality(freq)

mase_metrics = []
smape_metrics = []
for item_id, ts in enumerate(test_dataset):
    # The last `prediction_length` points are the held-out horizon; everything before
    # them is the in-sample history that MASE scales its error against.
    target = np.asarray(ts["target"])
    training_data = target[:-prediction_length]
    ground_truth = target[-prediction_length:]
    prediction = forecast_median[item_id]

    mase = mase_metric.compute(
        predictions=prediction,
        references=ground_truth,
        training=training_data,
        periodicity=periodicity,
    )
    mase_metrics.append(mase["mase"])

    smape = smape_metric.compute(
        predictions=prediction,
        references=ground_truth,
    )
    smape_metrics.append(smape["smape"])

# Forecasts are paired with test series purely by position. A count mismatch means the
# backtest loader and `test_dataset` disagree, so the pairing above cannot be trusted.
assert len(mase_metrics) == len(forecast_median), (
    f"{len(forecast_median)} forecasts for {len(mase_metrics)} test series"
)

print(f"MASE: {np.mean(mase_metrics)}")
print(f"sMAPE: {np.mean(smape_metrics)}")