In [1]:
import importlib
import warnings

warnings.filterwarnings("ignore")

from pathlib import Path
from typing import List, Optional, Union

import numpy as np
import pandas as pd

from tsururu.dataset import Pipeline, TSDataset
from tsururu.model_training.trainer import DLTrainer
from tsururu.model_training.validator import KFoldCrossValidator
from tsururu.models.torch_based.dlinear import DLinear_NN
from tsururu.strategies import (
    RecursiveStrategy,
    MIMOStrategy,
    DirectStrategy,
    FlatWideMIMOStrategy,
)
from tsururu.transformers import (
    LagTransformer,
    SequentialTransformer,
    TargetGenerator,
    UnionTransformer,
)

In [2]:
def get_results(
    cv: int,
    regime: str,
    y_true: Optional[List[np.ndarray]] = None,
    y_pred: Optional[List[np.ndarray]] = None,
    ids: Optional[List[Union[float, str]]] = None,
) -> pd.DataFrame:
    """Collect per-fold targets, predictions and ids into one wide DataFrame.

    For every fold ``k`` in ``1..cv`` the result gets the columns ``y_true_k``
    and ``y_pred_k`` and, unless ``regime == "local"``, also ``id_k``.

    Args:
        cv: Number of folds. Every list that is not None is indexed with
            ``0..cv-1``, so it must hold at least ``cv`` entries.
        regime: Only the value ``"local"`` is special-cased: it suppresses the
            ``id_k`` columns.
        y_true: Per-fold ground truth; each entry is a float scalar or an
            ndarray (flattened to 1-D).
        y_pred: Per-fold predictions, same layout as ``y_true``.
        ids: Per-fold identifiers, same layout as ``y_true``.

    Returns:
        A DataFrame with one column per (name, fold) pair. An argument that is
        None yields a single-``None`` column for every fold.

    Raises:
        TypeError: If a fold entry is neither a float scalar nor an ndarray.
    """

    def _get_fold_value(
        value: Optional[List[Union[float, np.ndarray]]], idx: int
    ) -> Union[List[None], float, np.ndarray]:
        """Return the column content for fold ``idx`` of ``value``."""
        if value is None:
            return [None]
        fold_value = value[idx]
        # np.floating covers float32/float16 scalars, which are not subclasses
        # of the built-in float and used to fall through to the TypeError.
        if isinstance(fold_value, (float, np.floating)):
            return fold_value
        if isinstance(fold_value, np.ndarray):
            return fold_value.reshape(-1)
        raise TypeError(f"Unexpected value type. Value: {value}")

    df_res_dict = {}

    for idx_fold in range(cv):
        fold_no = idx_fold + 1
        for name, value in [("y_true", y_true), ("y_pred", y_pred)]:
            df_res_dict[f"{name}_{fold_no}"] = _get_fold_value(value, idx_fold)
        if regime != "local":
            df_res_dict[f"id_{fold_no}"] = _get_fold_value(ids, idx_fold)

    # pandas refuses to build a frame from scalars only unless an index is
    # given; supply a one-row index in exactly that case.
    only_scalars = bool(df_res_dict) and all(
        np.ndim(column) == 0 for column in df_res_dict.values()
    )
    if only_scalars:
        return pd.DataFrame(df_res_dict, index=[0])
    return pd.DataFrame(df_res_dict)

## Initialize TSDataset, Pipeline, Model, Validator, Strategy

The main components are initialized exactly as they are for ML models. The only difference is that `DLTrainer` accepts many more parameters than `MLTrainer`.

### TSDataset

In [3]:
# CSV with the simulated series, given relative to the working directory.
df_path = Path("datasets/global/simulated_data_to_check.csv")

# Column roles handed to TSDataset: for each role, the source column names and
# a type tag. The spelling "continious" is kept exactly as in the original —
# presumably the literal the library expects; verify before "fixing" it.
dataset_params = dict(
    target=dict(columns=["value"], type="continious"),
    date=dict(columns=["date"], type="datetime"),
    id=dict(columns=["id"], type="categorical"),
)

In [4]:
# Read the CSV and wrap it in a TSDataset using the column roles defined above.
# With print_freq_period_info=True the constructor prints the detected
# frequency and period (see the "freq: Day; period: 1" line in the output).
dataset = TSDataset(
    data=pd.read_csv(df_path),
    columns_params=dataset_params,
    print_freq_period_info=True,
)

freq: Day; period: 1


### Pipeline

In [5]:
# Feature branch: lagged values of the series
# (presumably the 7 most recent lags — confirm against LagTransformer).
lag = LagTransformer(lags=7)
# Target branch (presumably produces the values to be predicted — confirm).
target_generator = TargetGenerator()

# Combine both branches, then run that union as a chain fed from the "value" column.
union_1 = UnionTransformer(transformers_list=[lag, target_generator])
seq_1 = SequentialTransformer(transformers_list=[union_1], input_features=["value"])
# Top-level union wrapping the single sequential chain.
union = UnionTransformer(transformers_list=[seq_1])

# NOTE(review): multivariate=True is passed to Pipeline as-is; confirm its exact
# effect in the tsururu documentation.
pipeline = Pipeline(union, multivariate=True)

### Trainer

In [6]:
# Configure the model parameters
# `model` is the class itself, not an instance; it is handed to DLTrainer
# together with `model_params`.
# NOTE(review): "enc_in" is None here — presumably filled in by the library
# at fit time; confirm.
model = DLinear_NN
model_params = {"moving_avg": 7, "individual": False, "enc_in": None}

# Configure the validation parameters
# As with the model, the validator is passed as a class plus its parameters.
validation = KFoldCrossValidator
validation_params = {
    "n_splits": 2,
}

# Trainer options for this first run: CPU only, no extra data-loading workers,
# and save_to_dir disabled (the checkpoint section below enables it).
trainer_params = {
    "device": "cpu",
    "num_workers": 0,
    "best_by_metric": True,
    "save_to_dir": False,
}

# Positional order: model class, model params, validator class, validator params.
trainer = DLTrainer(
    model, 
    model_params, 
    validation, 
    validation_params, 
    **trainer_params
)

### Strategy

In [7]:
# Strategy window sizes, all set to 7 (the data is daily, so one week).
# Presumably the number of past observations used as input — confirm.
history = 7
# Presumably the number of future steps to forecast — confirm.
horizon = 7
# Not referenced by any other cell visible in this notebook.
model_horizon = 7

In [8]:
# Tie the pipeline and trainer together in a recursive forecasting strategy,
# using the horizon/history values set in the previous cell.
strategy = RecursiveStrategy(
    pipeline=pipeline,
    trainer=trainer,
    horizon=horizon,
    history=history,
)

In [9]:
# Train the strategy. The displayed return value is a 2-tuple whose second item
# is the strategy object itself; the first is presumably the elapsed fit time
# in seconds — confirm.
strategy.fit(dataset)

length of train dataset: 496
length of val dataset: 497
Epoch 1/10, cost time: 0.71s
train loss: 397.4165
Validation, Loss: 162.1953, Metric: -162.1953
val loss: 162.1953
Epoch 2/10, cost time: 0.55s
train loss: 64.1882
Validation, Loss: 47.4241, Metric: -47.4241
val loss: 47.4241
Epoch 3/10, cost time: 0.54s
train loss: 17.5566
Validation, Loss: 8.8561, Metric: -8.8561
val loss: 8.8561
Epoch 4/10, cost time: 0.54s
train loss: 6.4756
Validation, Loss: 4.4630, Metric: -4.4630
val loss: 4.4630
Epoch 5/10, cost time: 0.54s
train loss: 4.3727
Validation, Loss: 3.9053, Metric: -3.9053
val loss: 3.9053
Epoch 6/10, cost time: 0.56s
train loss: 3.8692
Validation, Loss: 3.8264, Metric: -3.8264
val loss: 3.8264
Removing worst model snapshot: from epoch 0
Epoch 7/10, cost time: 0.55s
train loss: 3.7675
Validation, Loss: 3.7656, Metric: -3.7656
val loss: 3.7656
Removing worst model snapshot: from epoch 1
Epoch 8/10, cost time: 0.61s
train loss: 3.7314
Validation, Loss: 3.7414, Metric: -3.7414
val 

(22.96847891807556,
 <tsururu.strategies.recursive.RecursiveStrategy at 0x320b16c20>)

In [10]:
# Forecast with the fitted strategy; predict returns a pair that is unpacked
# into the forecast time and the predictions.
forecast_time, current_pred = strategy.predict(dataset)

length of test dataset: 1
Validation, Loss: nan, Metric: nan
Validation, Loss: nan, Metric: nan
length of test dataset: 1
Validation, Loss: nan, Metric: nan
Validation, Loss: nan, Metric: nan
length of test dataset: 1
Validation, Loss: nan, Metric: nan
Validation, Loss: nan, Metric: nan
length of test dataset: 1
Validation, Loss: nan, Metric: nan
Validation, Loss: nan, Metric: nan
length of test dataset: 1
Validation, Loss: nan, Metric: nan
Validation, Loss: nan, Metric: nan
length of test dataset: 1
Validation, Loss: nan, Metric: nan
Validation, Loss: nan, Metric: nan
length of test dataset: 1
Validation, Loss: nan, Metric: nan
Validation, Loss: nan, Metric: nan


In [11]:
# Display the forecast frame (columns: id, date, value).
current_pred

Unnamed: 0,id,date,value
0,0,2022-09-27,1997.029297
1,0,2022-09-28,1997.616943
2,0,2022-09-29,1998.106445
3,0,2022-09-30,1998.513306
4,0,2022-10-01,1998.86499
...,...,...,...
65,9,2022-09-29,11004.607422
66,9,2022-09-30,11006.004883
67,9,2022-10-01,11007.439453
68,9,2022-10-02,11008.916016


Saving and loading checkpoints is an essential practice in training DL models. 

Let's explore how to save checkpoints to disk, what structure the saved files have, and how to restore the model from a checkpoint for either fine-tuning or inference.

## Save and load checkpoints

Let’s consider working with checkpoints using the Direct strategy as an example.

In [9]:
# Trainer options for the checkpoint-saving run. This cell rebinds
# `trainer_params`, `trainer` and `strategy` defined in earlier cells.
trainer_params = {
    "device": "cpu",
    "num_workers": 0,
    "best_by_metric": True,
    # Let's enable save_to_dir (by the way, default value is True)
    "save_to_dir": True,
    # Directory that receives the checkpoint files
    "checkpoint_path": "checkpoints/",
    # Save checkpoints for the 3 best models
    "save_k_best": 3,
    # Average checkpoints for the final model
    "average_snapshots": True,
}

# Same model and validation setup as before; only the trainer options differ.
trainer = DLTrainer(
    model, 
    model_params, 
    validation, 
    validation_params, 
    **trainer_params
)

# Use the Direct strategy for the checkpoint examples.
strategy = DirectStrategy(
    pipeline=pipeline,
    trainer=trainer,
    horizon=horizon,
    history=history,
)

### Save checkpoint

In [10]:
# Train with saving enabled. The log below lists the files written each epoch
# under checkpoints/trainer_<i>/fold_<j>/: model_<epoch>.pth, opt_<epoch>.pth
# and es_checkpoint_manager.pth.
strategy.fit(dataset)

length of train dataset: 496
length of val dataset: 497
Epoch 1/10, cost time: 0.67s
train loss: 410.0893
Validation, Loss: 175.5516, Metric: -175.5516
val loss: 175.5516
Last epoch model saved to checkpoints/trainer_0/fold_0/model_0.pth
Last epoch optimizer saved to checkpoints/trainer_0/fold_0/opt_0.pth
Best model snapshot saved to checkpoints/trainer_0/fold_0/model_0.pth
Checkpoint manager saved to checkpoints/trainer_0/fold_0/es_checkpoint_manager.pth
Epoch 2/10, cost time: 0.58s
train loss: 66.1765
Validation, Loss: 52.3162, Metric: -52.3162
val loss: 52.3162
Last epoch model saved to checkpoints/trainer_0/fold_0/model_1.pth
Last epoch optimizer saved to checkpoints/trainer_0/fold_0/opt_1.pth
Best model snapshot saved to checkpoints/trainer_0/fold_0/model_1.pth
Checkpoint manager saved to checkpoints/trainer_0/fold_0/es_checkpoint_manager.pth
Epoch 3/10, cost time: 0.56s
train loss: 17.1039
Validation, Loss: 9.6560, Metric: -9.6560
val loss: 9.6560
Last epoch model saved to checkp

(166.6328408718109, <tsururu.strategies.direct.DirectStrategy at 0x311467640>)

### Load checkpoint for finetune

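Once the checkpoints are saved, we can continue training by passing two paths in the trainer's parameters: `pretrained_path`, which points at the existing checkpoints to load, and a different `checkpoint_path`, where the fine-tuned checkpoints will be written. All other parameters remain the same.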

In [9]:
# Trainer options for fine-tuning. This cell rebinds `trainer_params`,
# `trainer` and `strategy` again.
trainer_params = {
    "device": "cpu",
    "num_workers": 0,
    "best_by_metric": True,
    # Let's enable save_to_dir (by the way, default value is True)
    "save_to_dir": True,
    # Load from the directory the previous run used as its checkpoint_path ...
    "pretrained_path": "checkpoints/",
    # ... and write the fine-tuned checkpoints to a separate directory
    "checkpoint_path": "checkpoints_finetuned/",
    # Save checkpoints for the 3 best models
    "save_k_best": 3,
    # Average checkpoints for the final model
    "average_snapshots": True,
}

# Model and validation configuration are unchanged from the earlier cells.
trainer = DLTrainer(
    model, 
    model_params, 
    validation, 
    validation_params, 
    **trainer_params
)

# Same strategy type as in the saving run.
strategy = DirectStrategy(
    pipeline=pipeline,
    trainer=trainer,
    horizon=horizon,
    history=history,
)

In [10]:
# Continue training from the saved checkpoints. In the log below the first-epoch
# train loss is about 2.28, versus about 410 in the from-scratch run, which is
# consistent with the weights having been loaded rather than re-initialized.
strategy.fit(dataset)

length of train dataset: 496
length of val dataset: 497
Epoch 1/10, cost time: 0.69s
train loss: 2.2838
Validation, Loss: 2.2947, Metric: -2.2947
val loss: 2.2947
Last epoch model saved to checkpoints_finetuned/trainer_0/fold_0/model_0.pth
Last epoch optimizer saved to checkpoints_finetuned/trainer_0/fold_0/opt_0.pth
Best model snapshot saved to checkpoints_finetuned/trainer_0/fold_0/model_0.pth
Checkpoint manager saved to checkpoints_finetuned/trainer_0/fold_0/es_checkpoint_manager.pth
Epoch 2/10, cost time: 0.56s
train loss: 2.2745
Validation, Loss: 2.2801, Metric: -2.2801
val loss: 2.2801
Last epoch model saved to checkpoints_finetuned/trainer_0/fold_0/model_1.pth
Last epoch optimizer saved to checkpoints_finetuned/trainer_0/fold_0/opt_1.pth
Best model snapshot saved to checkpoints_finetuned/trainer_0/fold_0/model_1.pth
Checkpoint manager saved to checkpoints_finetuned/trainer_0/fold_0/es_checkpoint_manager.pth
Epoch 3/10, cost time: 0.61s
train loss: 2.2617
Validation, Loss: 2.2676

(165.66884803771973, <tsururu.strategies.direct.DirectStrategy at 0x17fa0f0d0>)

### Load checkpoint for inference

In [11]:
# Trainer options for inference only. This cell rebinds `trainer_params`,
# `trainer` and `strategy` again.
# NOTE(review): save_to_dir is not set here, and an earlier cell states that
# its default is True — confirm that nothing is written when n_epochs is 0.
trainer_params = {
    "device": "cpu",
    "num_workers": 0,
    # Zero epochs: the fit log shows no epoch lines, so the loaded weights are
    # presumably used unchanged — confirm against DLTrainer
    "n_epochs": 0,
    # Load the checkpoints written by the fine-tuning run
    "pretrained_path": "checkpoints_finetuned/",
    # Average checkpoints for the final model
    "average_snapshots": True,
}

# Model and validation configuration are unchanged from the earlier cells.
trainer = DLTrainer(
    model, 
    model_params, 
    validation, 
    validation_params, 
    **trainer_params
)

# Same strategy type as the one the checkpoints were produced with.
strategy = DirectStrategy(
    pipeline=pipeline,
    trainer=trainer,
    horizon=horizon,
    history=history,
)

In [12]:
# fit is still called so the strategy is set up from the checkpoints; with
# n_epochs=0 it returns in a fraction of a second. The nan fold scores in the
# log are presumably because no validation epoch ran — confirm.
strategy.fit(dataset)

length of train dataset: 496
length of val dataset: 497
Training finished.
Fold 0. Score: nan
length of train dataset: 497
length of val dataset: 496
Training finished.
Fold 1. Score: nan
Mean score: nan
Std: nan
length of train dataset: 496
length of val dataset: 496
Training finished.
Fold 0. Score: nan
length of train dataset: 496
length of val dataset: 496
Training finished.
Fold 1. Score: nan
Mean score: nan
Std: nan
length of train dataset: 495
length of val dataset: 496
Training finished.
Fold 0. Score: nan
length of train dataset: 496
length of val dataset: 495
Training finished.
Fold 1. Score: nan
Mean score: nan
Std: nan
length of train dataset: 495
length of val dataset: 495
Training finished.
Fold 0. Score: nan
length of train dataset: 495
length of val dataset: 495
Training finished.
Fold 1. Score: nan
Mean score: nan
Std: nan
length of train dataset: 494
length of val dataset: 495
Training finished.
Fold 0. Score: nan
length of train dataset: 495
length of val dataset: 49

(0.1153402328491211, <tsururu.strategies.direct.DirectStrategy at 0x3238a6f20>)

In [13]:
# Forecast with the checkpoint-restored strategy; this overwrites the
# `forecast_time` and `current_pred` produced by the Recursive strategy earlier.
forecast_time, current_pred = strategy.predict(dataset)

length of test dataset: 1
Validation, Loss: nan, Metric: nan
Validation, Loss: nan, Metric: nan
length of test dataset: 1
Validation, Loss: nan, Metric: nan
Validation, Loss: nan, Metric: nan
length of test dataset: 1
Validation, Loss: nan, Metric: nan
Validation, Loss: nan, Metric: nan
length of test dataset: 1
Validation, Loss: nan, Metric: nan
Validation, Loss: nan, Metric: nan
length of test dataset: 1
Validation, Loss: nan, Metric: nan
Validation, Loss: nan, Metric: nan
length of test dataset: 1
Validation, Loss: nan, Metric: nan
Validation, Loss: nan, Metric: nan
length of test dataset: 1
Validation, Loss: nan, Metric: nan
Validation, Loss: nan, Metric: nan


In [14]:
# Display the forecast from the restored model (columns: id, date, value).
current_pred

Unnamed: 0,id,date,value
0,0,2022-09-27,1997.256836
1,0,2022-09-28,1997.116211
2,0,2022-09-29,1997.849365
3,0,2022-09-30,1998.221558
4,0,2022-10-01,1998.371094
...,...,...,...
65,9,2022-09-29,11004.716797
66,9,2022-09-30,11006.404297
67,9,2022-10-01,11007.738281
68,9,2022-10-02,11009.052734
