# TTM zero-shot and few-shot benchmarking on multiple datasets

**Using TTM-1024-96 model.**

Pre-trained TTM models are fetched from the Hugging Face [Granite-TTM-R2 model repository](https://huggingface.co/ibm-granite/granite-timeseries-ttm-r2); see its model card for details.

1. IBM Granite TTM-R1 pre-trained models can be found here: [Granite-TTM-R1 Model Card](https://huggingface.co/ibm-granite/granite-timeseries-ttm-r1)
2. IBM Granite TTM-R2 pre-trained models can be found here: [Granite-TTM-R2 Model Card](https://huggingface.co/ibm-granite/granite-timeseries-ttm-r2)
3. Research-use (non-commercial use only) TTM-R2 pre-trained models can be found here: [Research-Use-TTM-R2](https://huggingface.co/ibm/ttm-research-r2)

## Imports

In [1]:
import math
import warnings

import matplotlib.pyplot as plt
import pandas as pd
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments, set_seed
from transformers.integrations import INTEGRATION_TO_CALLBACK

from tsfm_public import TrackingCallback, count_parameters, load_dataset
from tsfm_public.toolkit.get_model import get_model
from tsfm_public.toolkit.lr_finder import optimal_lr_finder
from tsfm_public.toolkit.visualization import plot_predictions


warnings.filterwarnings("ignore")



## Important arguments

In [2]:
# Set seed
SEED = 42
set_seed(SEED)

# Specify model parameters
context_length = 1024
forecast_length = 96
freeze_backbone = True

# Other args
EPOCHS = 50
NUM_WORKERS = 16

# Make sure every dataset in `list_datasets` (defined below) is saved
# in the `DATA_ROOT_PATH` folder, or change the path accordingly.
# Refer to the load_datasets() function
# in notebooks/hfdemo/tinytimemixer/utils/ttm_utils.py
# to see how it is used.
DATA_ROOT_PATH = "/dccstor/tsfm23/datasets/"

# This is where results will be saved
OUT_DIR = f"ttm-r2_results_benchmark_{context_length}_{forecast_length}/"
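
The cells below derive every output location from `OUT_DIR` with f-strings. Because `OUT_DIR` already ends in `/`, a path such as `f"{OUT_DIR}/{DATASET}"` contains a doubled slash, which POSIX file systems tolerate. The following is a minimal sketch of the resulting directory layout using `pathlib`; `output_layout` is a hypothetical helper written for illustration and is not part of `tsfm_public`.

```python
from pathlib import PurePosixPath

context_length = 1024
forecast_length = 96
OUT_DIR = f"ttm-r2_results_benchmark_{context_length}_{forecast_length}/"


def output_layout(out_dir, dataset, fewshot_percent=5):
    """Return the paths the notebook writes to for one dataset (illustrative helper)."""
    root = PurePosixPath(out_dir)  # normalizes away the trailing "/"
    subdir = root / dataset        # passed as plot_dir for the prediction plots
    return {
        "subdir": str(subdir),
        "zeroshot": str(subdir / "zeroshot"),
        "fewshot": str(subdir / f"fewshot_{fewshot_percent}"),
        "results_csv": str(root / "results_zero_few.csv"),
    }


layout = output_layout(OUT_DIR, "etth1")
for name, path in layout.items():
    print(f"{name:12s} {path}")
```

Building paths this way is optional; the f-string paths in the notebook resolve to the same locations.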

## List of benchmark datasets (TTM was not pre-trained on any of these)

In [3]:
list_datasets = [
    "etth1",
    "etth2",
    "ettm1",
    "ettm2",
    "weather",
    "electricity",
    "traffic",
]

## Set model path

In [4]:
hf_model_path = "ibm-granite/granite-timeseries-ttm-r2"

## Main benchmarking loop

In [5]:
all_results = {
    "dataset": [],
    "zs_mse": [],
    "fs5_mse": [],
    "zs_eval_time": [],
    "fs5_mean_epoch_time": [],
    "fs5_total_train_time": [],
    "fs5_best_val_metric": [],
}
# Loop over data
for DATASET in list_datasets:
    print()
    print("=" * 100)
    print(
        f"Running zero-shot/few-shot for TTM-{context_length} on dataset = {DATASET}, forecast_len = {forecast_length}"
    )
    print(f"Model will be loaded from {hf_model_path}")
    SUBDIR = f"{OUT_DIR}/{DATASET}"

    # Set batch size
    if DATASET == "traffic":
        BATCH_SIZE = 8
    elif DATASET == "electricity":
        BATCH_SIZE = 32
    else:
        BATCH_SIZE = 64

    # Data prep: Get dataset
    _, _, dset_test = load_dataset(DATASET, context_length, forecast_length, dataset_root_path=DATA_ROOT_PATH)

    #############################################################
    ##### Use the pretrained model in zero-shot forecasting #####
    #############################################################
    # Load model
    zeroshot_model = get_model(hf_model_path, context_length=context_length, prediction_length=forecast_length)

    # zeroshot_trainer
    zeroshot_trainer = Trainer(
        model=zeroshot_model,
        args=TrainingArguments(
            output_dir=f"{SUBDIR}/zeroshot",
            per_device_eval_batch_size=BATCH_SIZE,
            seed=SEED,
        ),
        eval_dataset=dset_test,
    )

    # evaluate = zero-shot performance
    print("+" * 20, "Test MSE zero-shot", "+" * 20)
    zeroshot_output = zeroshot_trainer.evaluate(dset_test)
    print(zeroshot_output)
    print("+" * 60)
    all_results["zs_eval_time"].append(zeroshot_output["eval_runtime"])

    # Plot
    plot_predictions(
        model=zeroshot_trainer.model,
        dset=dset_test,
        plot_dir=SUBDIR,
        num_plots=10,
        plot_prefix="test_zeroshot",
        channel=0,
    )
    plt.close()

    # write results
    all_results["dataset"].append(DATASET)
    all_results["zs_mse"].append(zeroshot_output["eval_loss"])

    #########################################################
    ## Use the pretrained model in few-shot 5% forecasting ##
    #########################################################
    for fewshot_percent in [5]:
        # Set learning rate
        learning_rate = None  # `None` value indicates that the optimal_lr_finder() will be used

        print("-" * 20, f"Running few-shot {fewshot_percent}%", "-" * 20)
        # Data prep: Get dataset
        dset_train, dset_val, dset_test = load_dataset(
            DATASET,
            context_length,
            forecast_length,
            fewshot_fraction=fewshot_percent / 100,
            dataset_root_path=DATA_ROOT_PATH,
        )

        # change head dropout to 0.7 for ett datasets
        if "ett" in DATASET:
            finetune_forecast_model = get_model(
                hf_model_path, context_length=context_length, prediction_length=forecast_length, head_dropout=0.7
            )
        else:
            finetune_forecast_model = get_model(
                hf_model_path, context_length=context_length, prediction_length=forecast_length
            )

        if freeze_backbone:
            print(
                "Number of params before freezing backbone",
                count_parameters(finetune_forecast_model),
            )

            # Freeze the backbone of the model
            for param in finetune_forecast_model.backbone.parameters():
                param.requires_grad = False

            # Count params
            print(
                "Number of params after freezing the backbone",
                count_parameters(finetune_forecast_model),
            )

        if learning_rate is None:
            learning_rate, finetune_forecast_model = optimal_lr_finder(
                finetune_forecast_model,
                dset_train,
                batch_size=BATCH_SIZE,
            )
            print("OPTIMAL SUGGESTED LEARNING RATE =", learning_rate)

        print(f"Using learning rate = {learning_rate}")
        finetune_forecast_args = TrainingArguments(
            output_dir=f"{SUBDIR}/fewshot_{fewshot_percent}",
            overwrite_output_dir=True,
            learning_rate=learning_rate,
            num_train_epochs=EPOCHS,
            do_eval=True,
            evaluation_strategy="epoch",
            per_device_train_batch_size=BATCH_SIZE,
            per_device_eval_batch_size=BATCH_SIZE,
            dataloader_num_workers=NUM_WORKERS,
            report_to=None,
            save_strategy="epoch",
            logging_strategy="epoch",
            save_total_limit=1,
            logging_dir=f"{SUBDIR}/fewshot_{fewshot_percent}",  # Make sure to specify a logging directory
            load_best_model_at_end=True,  # Load the best model when training ends
            metric_for_best_model="eval_loss",  # Metric to monitor for early stopping
            greater_is_better=False,  # For loss
            seed=SEED,
        )

        # Create the early stopping callback
        early_stopping_callback = EarlyStoppingCallback(
            early_stopping_patience=10,  # Number of epochs with no improvement after which to stop
            early_stopping_threshold=0.0,  # Minimum improvement required to consider as improvement
        )
        tracking_callback = TrackingCallback()

        # Optimizer and scheduler
        optimizer = AdamW(finetune_forecast_model.parameters(), lr=learning_rate)
        scheduler = OneCycleLR(
            optimizer,
            learning_rate,
            epochs=EPOCHS,
            steps_per_epoch=math.ceil(len(dset_train) / (BATCH_SIZE)),
        )

        finetune_forecast_trainer = Trainer(
            model=finetune_forecast_model,
            args=finetune_forecast_args,
            train_dataset=dset_train,
            eval_dataset=dset_val,
            callbacks=[early_stopping_callback, tracking_callback],
            optimizers=(optimizer, scheduler),
        )
        finetune_forecast_trainer.remove_callback(INTEGRATION_TO_CALLBACK["codecarbon"])

        # Fine tune
        finetune_forecast_trainer.train()

        # Evaluation
        print(
            "+" * 20,
            f"Test MSE after few-shot {fewshot_percent}% fine-tuning",
            "+" * 20,
        )
        fewshot_output = finetune_forecast_trainer.evaluate(dset_test)
        print(fewshot_output)
        print("+" * 60)

        # Plot
        plot_predictions(
            model=finetune_forecast_trainer.model,
            dset=dset_test,
            plot_dir=SUBDIR,
            num_plots=10,
            plot_prefix=f"test_fewshot_{fewshot_percent}",
            channel=0,
        )
        plt.close()

        # write results
        all_results[f"fs{fewshot_percent}_mse"].append(fewshot_output["eval_loss"])
        all_results[f"fs{fewshot_percent}_mean_epoch_time"].append(tracking_callback.mean_epoch_time)
        all_results[f"fs{fewshot_percent}_total_train_time"].append(tracking_callback.total_train_time)
        all_results[f"fs{fewshot_percent}_best_val_metric"].append(tracking_callback.best_eval_metric)

    df_out = pd.DataFrame(all_results).round(3)
    print(df_out[["dataset", "zs_mse", "fs5_mse"]])
    df_out.to_csv(f"{OUT_DIR}/results_zero_few.csv")




Running zero-shot/few-shot for TTM-1024 on dataset = etth1, forecast_len = 96
Model will be loaded from ibm-granite/granite-timeseries-ttm-r2




++++++++++++++++++++ Test MSE zero-shot ++++++++++++++++++++


{'eval_loss': 0.35859495401382446, 'eval_model_preparation_time': 0.0028, 'eval_runtime': 8.9505, 'eval_samples_per_second': 311.157, 'eval_steps_per_second': 4.916}
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++




-------------------- Running few-shot 5% --------------------
Number of params before freezing backbone 2964960
Number of params after freezing the backbone 955424




OPTIMAL SUGGESTED LEARNING RATE = 0.000298364724028334
Using learning rate = 0.000298364724028334


Epoch,Training Loss,Validation Loss
1,0.9166,0.665669
2,0.8887,0.665982
3,0.8243,0.666453
4,0.8863,0.66717
5,0.7737,0.668418
6,0.6951,0.66992
7,0.525,0.671401
8,0.4757,0.673846
9,0.4047,0.675814
10,0.3744,0.677924


[TrackingCallback] Mean Epoch Time = 0.9498170722614635 seconds, Total Train Time = 21.376437425613403
++++++++++++++++++++ Test MSE after few-shot 5% fine-tuning ++++++++++++++++++++


{'eval_loss': 0.35856103897094727, 'eval_runtime': 0.8932, 'eval_samples_per_second': 3117.893, 'eval_steps_per_second': 49.259, 'epoch': 11.0}
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++




  dataset  zs_mse  fs5_mse
0   etth1   0.359    0.359

Running zero-shot/few-shot for TTM-1024 on dataset = etth2, forecast_len = 96
Model will be loaded from ibm-granite/granite-timeseries-ttm-r2




++++++++++++++++++++ Test MSE zero-shot ++++++++++++++++++++


{'eval_loss': 0.269417405128479, 'eval_model_preparation_time': 0.0024, 'eval_runtime': 0.8529, 'eval_samples_per_second': 3265.37, 'eval_steps_per_second': 51.589}
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++




-------------------- Running few-shot 5% --------------------
Number of params before freezing backbone 2964960
Number of params after freezing the backbone 955424




OPTIMAL SUGGESTED LEARNING RATE = 0.000298364724028334
Using learning rate = 0.000298364724028334


Epoch,Training Loss,Validation Loss
1,0.9452,0.239151
2,0.861,0.239945
3,0.8059,0.241062
4,0.7247,0.242527
5,0.6549,0.244388
6,0.5768,0.246938
7,0.4957,0.250335
8,0.4577,0.256598
9,0.3929,0.267042
10,0.3571,0.283817


[TrackingCallback] Mean Epoch Time = 0.6733335581692782 seconds, Total Train Time = 18.43419575691223
++++++++++++++++++++ Test MSE after few-shot 5% fine-tuning ++++++++++++++++++++


{'eval_loss': 0.26942315697669983, 'eval_runtime': 0.9311, 'eval_samples_per_second': 2991.213, 'eval_steps_per_second': 47.258, 'epoch': 11.0}
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
  dataset  zs_mse  fs5_mse
0   etth1   0.359    0.359
1   etth2   0.269    0.269

Running zero-shot/few-shot for TTM-1024 on dataset = ettm1, forecast_len = 96
Model will be loaded from ibm-granite/granite-timeseries-ttm-r2




++++++++++++++++++++ Test MSE zero-shot ++++++++++++++++++++


{'eval_loss': 0.3369019627571106, 'eval_model_preparation_time': 0.0024, 'eval_runtime': 3.5784, 'eval_samples_per_second': 3192.741, 'eval_steps_per_second': 50.022}
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++




-------------------- Running few-shot 5% --------------------
Number of params before freezing backbone 2964960
Number of params after freezing the backbone 955424




OPTIMAL SUGGESTED LEARNING RATE = 0.0005214008287999684
Using learning rate = 0.0005214008287999684


Epoch,Training Loss,Validation Loss
1,0.8141,0.39455
2,0.6072,0.395544
3,0.4779,0.397824
4,0.3807,0.3973
5,0.3116,0.408491
6,0.2684,0.428093
7,0.2426,0.437327
8,0.2233,0.456643
9,0.2076,0.463043
10,0.1972,0.468228


[TrackingCallback] Mean Epoch Time = 0.9686962907964533 seconds, Total Train Time = 32.147533893585205
++++++++++++++++++++ Test MSE after few-shot 5% fine-tuning ++++++++++++++++++++


{'eval_loss': 0.33640581369400024, 'eval_runtime': 1.9157, 'eval_samples_per_second': 5963.974, 'eval_steps_per_second': 93.44, 'epoch': 11.0}
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
  dataset  zs_mse  fs5_mse
0   etth1   0.359    0.359
1   etth2   0.269    0.269
2   ettm1   0.337    0.336

Running zero-shot/few-shot for TTM-1024 on dataset = ettm2, forecast_len = 96
Model will be loaded from ibm-granite/granite-timeseries-ttm-r2




++++++++++++++++++++ Test MSE zero-shot ++++++++++++++++++++


{'eval_loss': 0.1764754354953766, 'eval_model_preparation_time': 0.0025, 'eval_runtime': 3.4544, 'eval_samples_per_second': 3307.416, 'eval_steps_per_second': 51.819}
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++




-------------------- Running few-shot 5% --------------------
Number of params before freezing backbone 2964960
Number of params after freezing the backbone 955424




OPTIMAL SUGGESTED LEARNING RATE = 0.000298364724028334
Using learning rate = 0.000298364724028334


Epoch,Training Loss,Validation Loss
1,0.4957,0.122071
2,0.3996,0.122304
3,0.3283,0.122963
4,0.2421,0.124153
5,0.1883,0.127375
6,0.1501,0.135246
7,0.1331,0.143912
8,0.1225,0.151637
9,0.1174,0.158312
10,0.1112,0.164967


[TrackingCallback] Mean Epoch Time = 1.0002062320709229 seconds, Total Train Time = 32.82567358016968
++++++++++++++++++++ Test MSE after few-shot 5% fine-tuning ++++++++++++++++++++


{'eval_loss': 0.17645052075386047, 'eval_runtime': 1.9039, 'eval_samples_per_second': 6000.805, 'eval_steps_per_second': 94.017, 'epoch': 11.0}
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
  dataset  zs_mse  fs5_mse
0   etth1   0.359    0.359
1   etth2   0.269    0.269
2   ettm1   0.337    0.336
3   ettm2   0.176    0.176

Running zero-shot/few-shot for TTM-1024 on dataset = weather, forecast_len = 96
Model will be loaded from ibm-granite/granite-timeseries-ttm-r2




++++++++++++++++++++ Test MSE zero-shot ++++++++++++++++++++


{'eval_loss': 0.15011762082576752, 'eval_model_preparation_time': 0.0024, 'eval_runtime': 6.3602, 'eval_samples_per_second': 1642.084, 'eval_steps_per_second': 25.785}
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++




-------------------- Running few-shot 5% --------------------
Number of params before freezing backbone 2964960
Number of params after freezing the backbone 955424




OPTIMAL SUGGESTED LEARNING RATE = 0.00035938136638046257
Using learning rate = 0.00035938136638046257


Epoch,Training Loss,Validation Loss
1,0.1535,0.393854
2,0.1477,0.399079
3,0.1405,0.40777
4,0.1303,0.410832
5,0.1155,0.407429
6,0.1026,0.41183
7,0.0927,0.409271
8,0.0857,0.415379
9,0.0807,0.41457
10,0.0769,0.414594


[TrackingCallback] Mean Epoch Time = 1.4071654189716687 seconds, Total Train Time = 38.4444363117218
++++++++++++++++++++ Test MSE after few-shot 5% fine-tuning ++++++++++++++++++++


{'eval_loss': 0.1500033736228943, 'eval_runtime': 3.4099, 'eval_samples_per_second': 3062.849, 'eval_steps_per_second': 48.095, 'epoch': 11.0}
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
   dataset  zs_mse  fs5_mse
0    etth1   0.359    0.359
1    etth2   0.269    0.269
2    ettm1   0.337    0.336
3    ettm2   0.176    0.176
4  weather   0.150    0.150

Running zero-shot/few-shot for TTM-1024 on dataset = electricity, forecast_len = 96
Model will be loaded from ibm-granite/granite-timeseries-ttm-r2




++++++++++++++++++++ Test MSE zero-shot ++++++++++++++++++++


{'eval_loss': 0.15828542411327362, 'eval_model_preparation_time': 0.0024, 'eval_runtime': 26.7977, 'eval_samples_per_second': 192.74, 'eval_steps_per_second': 6.045}
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
-------------------- Running few-shot 5% --------------------




Number of params before freezing backbone 2964960
Number of params after freezing the backbone 955424




OPTIMAL SUGGESTED LEARNING RATE = 8.111308307896872e-05
Using learning rate = 8.111308307896872e-05


Epoch,Training Loss,Validation Loss
1,0.1541,0.13355
2,0.1504,0.133363
3,0.1481,0.13155
4,0.1471,0.129834
5,0.1446,0.128791
6,0.1435,0.127429
7,0.1405,0.126259
8,0.1396,0.125177
9,0.1374,0.124556
10,0.1348,0.123992


[TrackingCallback] Mean Epoch Time = 4.880673928260803 seconds, Total Train Time = 747.7492995262146
++++++++++++++++++++ Test MSE after few-shot 5% fine-tuning ++++++++++++++++++++


{'eval_loss': 0.14718736708164215, 'eval_runtime': 19.2227, 'eval_samples_per_second': 268.692, 'eval_steps_per_second': 8.428, 'epoch': 50.0}
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
       dataset  zs_mse  fs5_mse
0        etth1   0.359    0.359
1        etth2   0.269    0.269
2        ettm1   0.337    0.336
3        ettm2   0.176    0.176
4      weather   0.150    0.150
5  electricity   0.158    0.147

Running zero-shot/few-shot for TTM-1024 on dataset = traffic, forecast_len = 96
Model will be loaded from ibm-granite/granite-timeseries-ttm-r2




++++++++++++++++++++ Test MSE zero-shot ++++++++++++++++++++


{'eval_loss': 0.4737617075443268, 'eval_model_preparation_time': 0.0024, 'eval_runtime': 46.8323, 'eval_samples_per_second': 72.877, 'eval_steps_per_second': 9.118}
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
-------------------- Running few-shot 5% --------------------




Number of params before freezing backbone 2964960
Number of params after freezing the backbone 955424




OPTIMAL SUGGESTED LEARNING RATE = 0.00020565123083486514
Using learning rate = 0.00020565123083486514


Epoch,Training Loss,Validation Loss
1,0.3063,0.384197
2,0.2906,0.380115
3,0.2831,0.377606
4,0.2754,0.375396
5,0.2678,0.371779
6,0.2621,0.370619
7,0.2576,0.364189
8,0.2535,0.361611
9,0.2477,0.357288
10,0.2449,0.354975


[TrackingCallback] Mean Epoch Time = 7.419389545917511 seconds, Total Train Time = 675.6236307621002
++++++++++++++++++++ Test MSE after few-shot 5% fine-tuning ++++++++++++++++++++


{'eval_loss': 0.4179241955280304, 'eval_runtime': 32.9888, 'eval_samples_per_second': 103.459, 'eval_steps_per_second': 12.944, 'epoch': 28.0}
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
       dataset  zs_mse  fs5_mse
0        etth1   0.359    0.359
1        etth2   0.269    0.269
2        ettm1   0.337    0.336
3        ettm2   0.176    0.176
4      weather   0.150    0.150
5  electricity   0.158    0.147
6      traffic   0.474    0.418
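
The ETT and weather runs above all report `'epoch': 11.0`, while electricity ran the full 50 epochs. This is consistent with the early-stopping settings in the loop (`early_stopping_patience=10`, `early_stopping_threshold=0.0`): when the best validation loss occurs at epoch 1, ten further epochs without improvement end training at epoch 11. The sketch below replays a simplified patience counter over the validation losses logged above. It is an illustration of the rule, not the `EarlyStoppingCallback` implementation, and `early_stopping_trace` is a hypothetical helper.

```python
def early_stopping_trace(val_losses, patience=10, threshold=0.0):
    """Replay a patience counter over per-epoch validation losses.

    Returns (stop_epoch, best_loss, bad_epochs). stop_epoch is None when
    the counter never reaches `patience` within the epochs provided.
    """
    best_loss = None
    bad_epochs = 0
    for epoch, loss in enumerate(val_losses, start=1):
        # Lower is better; an improvement must exceed the threshold.
        improved = best_loss is None or (best_loss - loss) > threshold
        if improved:
            best_loss, bad_epochs = loss, 0
        else:
            bad_epochs += 1
        if bad_epochs >= patience:
            return epoch, best_loss, bad_epochs
    return None, best_loss, bad_epochs


# Validation losses for the first 10 epochs, copied from the logs above.
etth1_val = [0.665669, 0.665982, 0.666453, 0.66717, 0.668418,
             0.66992, 0.671401, 0.673846, 0.675814, 0.677924]
traffic_val = [0.384197, 0.380115, 0.377606, 0.375396, 0.371779,
               0.370619, 0.364189, 0.361611, 0.357288, 0.354975]

etth1_trace = early_stopping_trace(etth1_val)
traffic_trace = early_stopping_trace(traffic_val)
print(etth1_trace)    # (None, 0.665669, 9)
print(traffic_trace)  # (None, 0.354975, 0)
```

For etth1 the counter stands at 9 after the tenth logged epoch, so one more non-improving epoch reaches the patience of 10. For traffic the loss improves at every logged epoch, so the counter stays at 0 over the same span.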


## Benchmarking results*

*Slight differences from the results reported in the TTM paper are possible due to different training environments.

In [6]:
df_out

| | dataset | zs_mse | fs5_mse | zs_eval_time | fs5_mean_epoch_time | fs5_total_train_time | fs5_best_val_metric |
|---|---|---|---|---|---|---|---|
| 0 | etth1 | 0.359 | 0.359 | 8.95 | 0.95 | 21.376 | 0.666 |
| 1 | etth2 | 0.269 | 0.269 | 0.853 | 0.673 | 18.434 | 0.239 |
| 2 | ettm1 | 0.337 | 0.336 | 3.578 | 0.969 | 32.148 | 0.395 |
| 3 | ettm2 | 0.176 | 0.176 | 3.454 | 1.0 | 32.826 | 0.122 |
| 4 | weather | 0.15 | 0.15 | 6.36 | 1.407 | 38.444 | 0.394 |
| 5 | electricity | 0.158 | 0.147 | 26.798 | 4.881 | 747.749 | 0.117 |
| 6 | traffic | 0.474 | 0.418 | 46.832 | 7.419 | 675.624 | 0.345 |
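
To read the table as a relative gain from fine-tuning, the zero-shot and 5% few-shot test MSE can be compared per dataset. The sketch below uses the rounded `zs_mse` and `fs5_mse` values copied from the table, so differences smaller than 0.001 are not visible; `relative_improvement` is a hypothetical helper, not a `tsfm_public` function.

```python
# (dataset, zs_mse, fs5_mse) copied from the results table above
results = [
    ("etth1", 0.359, 0.359),
    ("etth2", 0.269, 0.269),
    ("ettm1", 0.337, 0.336),
    ("ettm2", 0.176, 0.176),
    ("weather", 0.150, 0.150),
    ("electricity", 0.158, 0.147),
    ("traffic", 0.474, 0.418),
]


def relative_improvement(zs_mse, fs_mse):
    """Percent reduction in test MSE of few-shot relative to zero-shot."""
    return 100.0 * (zs_mse - fs_mse) / zs_mse


improvements = {name: relative_improvement(zs, fs) for name, zs, fs in results}
for name, pct in improvements.items():
    print(f"{name:12s} {pct:6.2f}%")
```

With these rounded values, only electricity (about 7%) and traffic (about 12%) show a clear reduction; the ETT and weather runs, which stopped early, stay at or near their zero-shot MSE.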
