In [None]:
! pip install "granite-tsfm[notebooks] @ git+https://github.com/ibm-granite/granite-tsfm.git@v0.2.22"

In [2]:
import math
import os
import tempfile

import pandas as pd
import numpy as np
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments, set_seed
from transformers.integrations import INTEGRATION_TO_CALLBACK

from tsfm_public import TimeSeriesPreprocessor, TrackingCallback, count_parameters, get_datasets
from tsfm_public.toolkit.get_model import get_model
from tsfm_public.toolkit.lr_finder import optimal_lr_finder
from tsfm_public.toolkit.visualization import plot_predictions
from sklearn.preprocessing import LabelEncoder, StandardScaler

import warnings
# Suppress all warnings
warnings.filterwarnings("ignore")

In [3]:
os.listdir()

['.config', 'testing_dataset.csv', 'training_dataset.csv', 'sample_data']

In [4]:
# Set seed for reproducibility
SEED = 42
set_seed(SEED)

# TTM Model path. The default model path is Granite-R2. Below, you can choose other TTM releases.
TTM_MODEL_PATH = "ibm-granite/granite-timeseries-ttm-r2"
# TTM_MODEL_PATH = "ibm-granite/granite-timeseries-ttm-r1"
# TTM_MODEL_PATH = "ibm-research/ttm-research-r2"

# Context length, Or Length of the history.
# Currently supported values are: 512/1024/1536 for Granite-TTM-R2 and Research-Use-TTM-R2, and 512/1024 for Granite-TTM-R1
CONTEXT_LENGTH = 512

# Granite-TTM-R2 supports forecast length upto 720 and Granite-TTM-R1 supports forecast length upto 96
PREDICTION_LENGTH = 96

TARGET_DATASET = "etth1"
dataset_path = "https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/ETT-small/ETTh1.csv"


# Results dir
OUT_DIR = "ttm_finetuned_models/"

label_encoder = LabelEncoder()
scaler = StandardScaler()

In [29]:
# Dataset
TARGET_DATASET = "datavidia"
dataset_path = "./training_dataset.csv"
test_dataset_path = "./testing_dataset.csv"
timestamp_column = "Date"
id_columns = ['commodity', 'province']
target_columns = ['price']
split_config = {
    "train": 0.8,
    "test": 0.1
}

test_split_config = {
    "train": 0.8,
    "test": 0.1
}

# Understanding the split config -- slides

feature_to_scale = ['GlobalOpen', 'GlobalHigh', 'GlobalVol.', 'GlobalPrice', 'CE_Close', 'CE_High', 'CE_Low', 'CE_Open']

data = pd.read_csv(
    dataset_path,
    parse_dates=[timestamp_column],
)

test_data = pd.read_csv(
    test_dataset_path,
    parse_dates=[timestamp_column],
)

column_specifiers = {
    "timestamp_column": timestamp_column,
    "id_columns": id_columns,
    "target_columns": target_columns,
    "control_columns": [],
}

In [14]:
def process_dataset(df: pd.DataFrame):
    df['Date'] = pd.to_datetime(df['Date'])
    df['timestamp'] = df['Date'].astype(int)
    df['timestamp'] = df['timestamp'].div(10**9)

    df['province'] = label_encoder.fit_transform(df['province'])
    df['commodity'] = label_encoder.fit_transform(df['commodity'])

    df = df.drop(columns=['Unnamed: 0'])

    for col in feature_to_scale:
        df[col] = scaler.fit_transform(df[[col]])
    return df

In [15]:
data = process_dataset(data)
test_data = process_dataset(test_data)

In [16]:
data.head()

Unnamed: 0,Date,commodity,province,price,GlobalOpen,GlobalHigh,GlobalLow,GlobalVol.,GlobalChange %,GlobalPrice,CE_Close,CE_High,CE_Low,CE_Open,timestamp
0,2022-01-01,0,0,28970.0,0.384917,0.977176,3.5,-1.266524,-0.48,0.316524,-1.932679,-1.862012,1.437183,-1.87819,1640995000.0
1,2022-01-01,1,0,27440.0,0.384917,0.977176,3.5,-1.266524,-0.48,0.316524,-1.932679,-1.862012,1.437183,-1.87819,1640995000.0
2,2022-01-01,2,0,11030.0,0.384917,0.977176,3.5,-1.266524,-0.48,0.316524,-1.932679,-1.862012,1.437183,-1.87819,1640995000.0
3,2022-01-01,3,0,12080.0,0.384917,0.977176,3.5,-1.266524,-0.48,0.316524,-1.932679,-1.862012,1.437183,-1.87819,1640995000.0
4,2022-01-01,4,0,22360.0,0.384917,0.977176,3.5,-1.266524,-0.48,0.316524,-1.932679,-1.862012,1.437183,-1.87819,1640995000.0


In [31]:
test_data['price'] = 0
len(test_data)

40664

In [26]:
def get_model_result(
        model,
        context_length: int = 512,
        forecast_length: int = 96,
        ) -> None:

    tsp = TimeSeriesPreprocessor(
        **column_specifiers,
        context_length=context_length,
        prediction_length=forecast_length,
        scaling=True,
        encode_categorical=False,
        scaler_type="standard",
    )

    dset, dset_val, dset_test = get_datasets(
        tsp, test_data, test_split_config, fewshot_fraction=100 / 100, fewshot_location="first"
    )

    print(f'dset: {len(dset)}')
    print(f'val: {len(dset_val)}')
    print(f'test: {len(dset_test)}')

    return

In [32]:
get_model_result("")

dset: 442
val: 442
test: 442


In [64]:
def fewshot_finetune_eval(
    dataset_name,
    batch_size,
    learning_rate=None,
    context_length=512,
    forecast_length=96,
    fewshot_percent=5,
    freeze_backbone=True,
    num_epochs=50,
    save_dir=OUT_DIR,
    loss="mse",
    quantile=0.5,
):
    out_dir = os.path.join(save_dir, dataset_name)

    print("-" * 20, f"Running few-shot {fewshot_percent}%", "-" * 20)

    # Data prep: Get dataset

    tsp = TimeSeriesPreprocessor(
        **column_specifiers,
        context_length=context_length,
        prediction_length=forecast_length,
        scaling=True,
        encode_categorical=False,
        scaler_type="standard",
    )

    dset_train, dset_val, dset_test = get_datasets(
        tsp, data, split_config, fewshot_fraction=fewshot_percent / 100, fewshot_location="first"
    )

    # change head dropout to 0.7 for ett datasets
    if "ett" in dataset_name:
        finetune_forecast_model = get_model(
            TTM_MODEL_PATH,
            context_length=context_length,
            prediction_length=forecast_length,
            freq_prefix_tuning=False,
            freq=None,
            prefer_l1_loss=False,
            prefer_longer_context=True,
            # Can also provide TTM Config args
            head_dropout=0.7,
            loss=loss,
            quantile=quantile,
        )
    else:
        finetune_forecast_model = get_model(
            TTM_MODEL_PATH,
            context_length=context_length,
            prediction_length=forecast_length,
            freq_prefix_tuning=False,
            freq=None,
            prefer_l1_loss=False,
            prefer_longer_context=True,
            # Can also provide TTM Config args
            head_dropout=1,
            loss=loss,
            quantile=quantile,
        )

    if freeze_backbone:
        print(
            "Number of params before freezing backbone",
            count_parameters(finetune_forecast_model),
        )

        # Freeze the backbone of the model
        for param in finetune_forecast_model.backbone.parameters():
            param.requires_grad = False

        # Count params
        print(
            "Number of params after freezing the backbone",
            count_parameters(finetune_forecast_model),
        )

    # Find optimal learning rate
    # Use with caution: Set it manually if the suggested learning rate is not suitable
    if learning_rate is None:
        learning_rate, finetune_forecast_model = optimal_lr_finder(
            finetune_forecast_model,
            dset_train,
            batch_size=batch_size,
        )
        print("OPTIMAL SUGGESTED LEARNING RATE =", learning_rate)

    print(f"Using learning rate = {learning_rate}")
    finetune_forecast_args = TrainingArguments(
        output_dir=os.path.join(out_dir, "output"),
        overwrite_output_dir=True,
        learning_rate=learning_rate,
        num_train_epochs=num_epochs,
        do_eval=True,
        evaluation_strategy="epoch",
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        dataloader_num_workers=8,
        report_to="none",
        save_strategy="epoch",
        logging_strategy="epoch",
        save_total_limit=1,
        logging_dir=os.path.join(out_dir, "logs"),  # Make sure to specify a logging directory
        load_best_model_at_end=True,  # Load the best model when training ends
        metric_for_best_model="eval_loss",  # Metric to monitor for early stopping
        greater_is_better=False,  # For loss
        seed=SEED,
    )

    # Create the early stopping callback
    early_stopping_callback = EarlyStoppingCallback(
        early_stopping_patience=10,  # Number of epochs with no improvement after which to stop
        early_stopping_threshold=1e-5,  # Minimum improvement required to consider as improvement
    )
    tracking_callback = TrackingCallback()

    # Optimizer and scheduler
    optimizer = AdamW(finetune_forecast_model.parameters(), lr=learning_rate)
    scheduler = OneCycleLR(
        optimizer,
        learning_rate,
        epochs=num_epochs,
        steps_per_epoch=math.ceil(len(dset_train) / (batch_size)),
    )

    finetune_forecast_trainer = Trainer(
        model=finetune_forecast_model,
        args=finetune_forecast_args,
        train_dataset=dset_train,
        eval_dataset=dset_val,
        callbacks=[early_stopping_callback, tracking_callback],
        optimizers=(optimizer, scheduler),
    )
    finetune_forecast_trainer.remove_callback(INTEGRATION_TO_CALLBACK["codecarbon"])

    # Fine tune
    finetune_forecast_trainer.train()

    # Evaluation
    print("+" * 20, f"Test MSE after few-shot {fewshot_percent}% fine-tuning", "+" * 20)

    finetune_forecast_trainer.model.loss = "mse"  # fixing metric to mse for evaluation

    fewshot_output = finetune_forecast_trainer.evaluate(dset_test)
    print(fewshot_output)
    print("+" * 60)

    # get predictions

    predictions_dict = finetune_forecast_trainer.predict(dset_test)

    predictions_np = predictions_dict.predictions[0]

    print(predictions_np.shape)

    # get backbone embeddings (if needed for further analysis)

    backbone_embedding = predictions_dict.predictions[1]

    print(backbone_embedding.shape)

    # plot
    # plot_predictions(
    #     model=finetune_forecast_trainer.model,
    #     dset=dset_test,
    #     plot_dir=os.path.join(OUT_DIR, dataset_name),
    #     plot_prefix="test_fewshot",
    #     channel=0,
    # )
    # view
    view_prediction(dset_test, predictions_dict)

In [65]:
# fewshot_finetune_eval(
#     dataset_name=TARGET_DATASET,
#     context_length=CONTEXT_LENGTH,
#     forecast_length=PREDICTION_LENGTH,
#     batch_size=32,
#     fewshot_percent=30,
#     learning_rate=0.001,
#     num_epochs=20
# )

fewshot_finetune_eval(
    dataset_name=TARGET_DATASET,
    context_length=CONTEXT_LENGTH,
    forecast_length=PREDICTION_LENGTH,
    batch_size=32,
    fewshot_percent=5,
    learning_rate=0.001,
    num_epochs=1
)

-------------------- Running few-shot 5% --------------------


INFO:/usr/local/lib/python3.11/dist-packages/tsfm_public/toolkit/get_model.py:Loading model from: ibm-granite/granite-timeseries-ttm-r2
INFO:/usr/local/lib/python3.11/dist-packages/tsfm_public/toolkit/get_model.py:Model loaded successfully from ibm-granite/granite-timeseries-ttm-r2, revision = main.
INFO:/usr/local/lib/python3.11/dist-packages/tsfm_public/toolkit/get_model.py:[TTM] context_length = 512, prediction_length = 96


Number of params before freezing backbone 805280
Number of params after freezing the backbone 289696
Using learning rate = 0.001


Epoch,Training Loss,Validation Loss
1,1.0041,0.71933


[TrackingCallback] Mean Epoch Time = 1.5764830112457275 seconds, Total Train Time = 3.6706032752990723
++++++++++++++++++++ Test MSE after few-shot 5% fine-tuning ++++++++++++++++++++


{'eval_loss': 0.6005375981330872, 'eval_runtime': 1.2282, 'eval_samples_per_second': 470.611, 'eval_steps_per_second': 15.47, 'epoch': 1.0}
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
(578, 96, 1)
(578, 1, 8, 192)
{'past_values': tensor([[ 0.0000e+00],
        [ 4.9788e-01],
        [ 4.8613e-01],
        [ 5.3316e-01],
        [ 6.3115e-01],
        [ 6.0763e-01],
        [ 5.6060e-01],
        [ 6.5467e-01],
        [ 7.6246e-01],
        [ 8.5261e-01],
        [ 8.8201e-01],
        [ 9.7020e-01],
        [ 9.0945e-01],
        [ 8.8985e-01],
        [ 9.4864e-01],
        [ 8.8397e-01],
        [ 8.1538e-01],
        [ 9.1337e-01],
        [ 8.1929e-01],
        [ 6.3507e-01],
        [ 7.3894e-01],
        [ 5.1944e-01],
        [ 5.8804e-01],
        [ 6.0567e-01],
        [ 5.0180e-01],
        [ 3.8421e-01],
        [ 3.5286e-01],
        [ 3.9793e-01],
        [ 3.6462e-01],
        [ 4.0577e-01],
        [ 2.9210e-01],
        [ 3.1954e-01],
        [ 3.3522e