# PatchTSMixer in HuggingFace - Getting Started


`PatchTSMixer` is a lightweight time-series modeling approach based on the MLP-Mixer architecture. It is proposed in [TSMixer: Lightweight MLP-Mixer Model for Multivariate Time Series Forecasting](https://arxiv.org/pdf/2306.09364.pdf) by IBM Research authors `Vijay Ekambaram`, `Arindam Jati`, `Nam Nguyen`, `Phanwadee Sinthong` and `Jayant Kalagnanam`.

For effective mindshare and to promote opensourcing - IBM Research join hands with the HuggingFace team to opensource this model in HF.

In this [HuggingFace implementation](https://huggingface.co/docs/transformers/main/en/model_doc/patchtsmixer), we provide PatchTSMixer’s capabilities to effortlessly facilitate lightweight mixing across patches, channels, and hidden features for effective multivariate time-series modeling. It also supports various attention mechanisms starting from simple gated attention to more complex self-attention blocks that can be customized accordingly. The model can be pretrained and subsequently used for various downstream tasks such as forecasting, classification, and regression.

`PatchTSMixer` outperforms state-of-the-art MLP and Transformer models in forecasting by a considerable margin of 8-60%. It also outperforms the latest strong benchmarks of Patch-Transformer models (by 1-2%) with a significant reduction in memory and runtime (2-3X). For more details, refer to the [paper](https://arxiv.org/pdf/2306.09364.pdf)

In this blog, we will demonstrate examples of getting started with PatchTSMixer. We will first demonstrate the forecasting capability of `PatchTSMixer` on the Electricity data. We will then demonstrate the transfer learning capability of PatchTSMixer by using the model trained on the Electricity to do zero-shot forecasting on the ETTH2 dataset.


`Blog authors`: Arindam Jati, Vijay Ekambaram, Nam Ngugen, Wesley Gifford and Kashif Rasul


## Installation
This demo needs Huggingface [`transformers`](https://github.com/huggingface/transformers) for main modeling tasks, and IBM `tsfm` for auxiliary data pre-processing.
We can install both by cloning the `tsfm` repository and following the below steps.

1. Clone IBM Time Series Foundation Model Repository [`tsfm`](https://github.com/ibm/tsfm).
    ```
    git clone git@github.com:IBM/tsfm.git
    cd tsfm
    ```
2. Install `tsfm`. This will also install Huggingface `transformers`.
    ```
    pip install .
    ```
3. Test it with the following commands in a `python` terminal.
    ```
    from transformers import PatchTSMixerConfig
    from tsfm_public.toolkit.dataset import ForecastDFDataset
    ```

## Part 1: Forecasting on Electricity dataset

In [3]:
# Clone the IBM tsfm repository
!git clone https://github.com/IBM/tsfm.git
%cd tsfm

# Install the package and its dependencies
!pip install .

Cloning into 'tsfm'...
remote: Enumerating objects: 8459, done.[K
remote: Counting objects: 100% (2487/2487), done.[K
remote: Compressing objects: 100% (606/606), done.[K
remote: Total 8459 (delta 2032), reused 1903 (delta 1881), pack-reused 5972 (from 2)[K
Receiving objects: 100% (8459/8459), 48.90 MiB | 15.04 MiB/s, done.
Resolving deltas: 100% (5633/5633), done.
/content/tsfm
Processing /content/tsfm
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting datasets (from granite-tsfm==2rc3.dev14+gf6cfbf1)
  Downloading datasets-3.4.1-py3-none-any.whl.metadata (19 kB)
Collecting numpy<2 (from granite-tsfm==2rc3.dev14+gf6cfbf1)
  Downloading numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.0/61.0 kB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m
Collecting 

In [4]:
!pip uninstall numpy -y
!pip install numpy==1.24.3

Found existing installation: numpy 1.26.4
Uninstalling numpy-1.26.4:
  Successfully uninstalled numpy-1.26.4
Collecting numpy==1.24.3
  Downloading numpy-1.24.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (5.6 kB)
Downloading numpy-1.24.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (17.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.3/17.3 MB[0m [31m35.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: numpy
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
albumentations 2.0.5 requires numpy>=1.24.4, but you have numpy 1.24.3 which is incompatible.
albucore 0.0.23 requires numpy>=1.24.4, but you have numpy 1.24.3 which is incompatible.
treescope 0.1.9 requires numpy>=1.25.2, but you have numpy 1.24.3 which is incompatible.
pymc 5.21.1 requires numpy>=1.25.0, but you have numpy

In [1]:
!pip install --upgrade numpy==1.24.3 jax tensorflow

Collecting tensorflow
  Downloading tensorflow-2.19.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.1 kB)
INFO: pip is looking at multiple versions of jax to determine which version is compatible with other requirements. This could take a while.
Collecting jax
  Downloading jax-0.5.2-py3-none-any.whl.metadata (22 kB)
  Downloading jax-0.5.1-py3-none-any.whl.metadata (22 kB)
  Downloading jax-0.5.0-py3-none-any.whl.metadata (22 kB)
Collecting jaxlib<=0.5.0,>=0.5.0 (from jax)
  Downloading jaxlib-0.5.0-cp311-cp311-manylinux2014_x86_64.whl.metadata (978 bytes)
Collecting jax
  Downloading jax-0.4.38-py3-none-any.whl.metadata (22 kB)
Collecting jaxlib<=0.4.38,>=0.4.38 (from jax)
  Downloading jaxlib-0.4.38-cp311-cp311-manylinux2014_x86_64.whl.metadata (1.0 kB)
Collecting tensorboard~=2.19.0 (from tensorflow)
  Downloading tensorboard-2.19.0-py3-none-any.whl.metadata (1.8 kB)
INFO: pip is looking at multiple versions of tensorflow to determine which version is compa

In [2]:
from transformers import PatchTSMixerConfig
from tsfm_public.toolkit.dataset import ForecastDFDataset


In [3]:
# Standard
import os
import random

import numpy as np
import pandas as pd
import torch

# Third Party
from transformers import (
    EarlyStoppingCallback,
    PatchTSMixerConfig,
    PatchTSMixerForPrediction,
    Trainer,
    TrainingArguments,
)

# First Party
from tsfm_public.toolkit.dataset import ForecastDFDataset
from tsfm_public.toolkit.time_series_preprocessor import TimeSeriesPreprocessor
from tsfm_public.toolkit.util import select_by_index

 ### Set seed

In [4]:
SEED = 42
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)

### Load and prepare datasets

In the next cell, please adjust the following parameters to suit your application:
- `PRETRAIN_AGAIN`: Set this to `True` if you want to perform pretraining again. Note that this might take some time depending on the GPU availability. Otherwise, the already pretrained model will be used.
- `dataset_path`: path to local .csv file, or web address to a csv file for the data of interest. Data is loaded with pandas, so anything supported by
`pd.read_csv` is supported: (https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html).
- `timestamp_column`: column name containing timestamp information, use None if there is no such column
- `id_columns`: List of column names specifying the IDs of different time series. If no ID column exists, use []
- `forecast_columns`: List of columns to be modeled
- `context_length`: The amount of historical data used as input to the model. Windows of the input time series data with length equal to
`context_length` will be extracted from the input dataframe. In the case of a multi-time series dataset, the context windows will be created
so that they are contained within a single time series (i.e., a single ID).
- `forecast_horizon`: Number of timestamps to forecast in future.
- `train_start_index`, `train_end_index`: the start and end indices in the loaded data which delineate the training data.
- `valid_start_index`, `valid_end_index`: the start and end indices in the loaded data which delineate the validation data.
- `test_start_index`, `test_end_index`: the start and end indices in the loaded data which delineate the test data.
- `patch_length`: The patch length for the `PatchTSMixer` model. It is recommended to choose a value that evenly divides `context_length`.
- `num_workers`: Number of dataloder workers in pytorch dataloader.
- `batch_size`: Batch size.
The data is first loaded into a Pandas dataframe and split into training, validation, and test parts. Then the pandas dataframes are converted
to the appropriate torch dataset needed for training.

In [9]:
PRETRAIN_AGAIN = True
# Download ECL data from https://github.com/zhouhaoyi/Informer2020
dataset_path = "ECL.csv"
timestamp_column = "date"
id_columns = []

context_length = 512
forecast_horizon = 96
patch_length = 8
num_workers = 16  # Reduce this if you have low number of CPU cores
batch_size = 64  # Adjust according to GPU memory

In [None]:
# import numpy as np
# # Parameters
# beta = 0.2
# gamma = 0.1
# n = 10
# tau = 17
# t_max = 11000
# dt = 1  # time step
# steps = int(t_max / dt)

# # Delay in steps
# tau_steps = int(tau / dt)

# # Initialize P with initial values (e.g., 1.2 for all)
# P = np.zeros(steps)
# P[:tau_steps+1] = 1.2  # history for the delay

# # Time vector
# t = np.arange(0, t_max, dt)

# # Numerical integration using Euler's method
# for i in range(tau_steps, steps - 1):
#     delayed = P[i - tau_steps]
#     dP = beta * delayed / (1 + delayed**n) - gamma * P[i]
#     P[i + 1] = P[i] + dP * dt

# MGdata = pd.DataFrame({
#     'time': t,
#     'P': P
# })

In [10]:
if PRETRAIN_AGAIN:
    data = pd.read_csv(
        dataset_path,
        parse_dates=[timestamp_column],
    )
    forecast_columns = list(data.columns[1:])

    # get split
    num_train = int(len(data) * 0.7)
    num_test = int(len(data) * 0.2)
    num_valid = len(data) - num_train - num_test
    border1s = [
        0,
        num_train - context_length,
        len(data) - num_test - context_length,
    ]
    border2s = [num_train, num_train + num_valid, len(data)]

    train_start_index = border1s[0]  # None indicates beginning of dataset
    train_end_index = border2s[0]

    # we shift the start of the evaluation period back by context length so that
    # the first evaluation timestamp is immediately following the training data
    valid_start_index = border1s[1]
    valid_end_index = border2s[1]

    test_start_index = border1s[2]
    test_end_index = border2s[2]

    train_data = select_by_index(
        data,
        id_columns=id_columns,
        start_index=train_start_index,
        end_index=train_end_index,
    )
    valid_data = select_by_index(
        data,
        id_columns=id_columns,
        start_index=valid_start_index,
        end_index=valid_end_index,
    )
    test_data = select_by_index(
        data,
        id_columns=id_columns,
        start_index=test_start_index,
        end_index=test_end_index,
    )

    tsp = TimeSeriesPreprocessor(
        timestamp_column=timestamp_column,
        id_columns=id_columns,
        input_columns=forecast_columns,
        output_columns=forecast_columns,
        scaling=True,
    )
    tsp.train(train_data)

In [11]:
if PRETRAIN_AGAIN:
    train_dataset = ForecastDFDataset(
        tsp.preprocess(train_data),
        id_columns=id_columns,
        timestamp_column="date",
        input_columns=forecast_columns,
        output_columns=forecast_columns,
        context_length=context_length,
        prediction_length=forecast_horizon,
    )
    valid_dataset = ForecastDFDataset(
        tsp.preprocess(valid_data),
        id_columns=id_columns,
        timestamp_column="date",
        input_columns=forecast_columns,
        output_columns=forecast_columns,
        context_length=context_length,
        prediction_length=forecast_horizon,
    )
    test_dataset = ForecastDFDataset(
        tsp.preprocess(test_data),
        id_columns=id_columns,
        timestamp_column="date",
        input_columns=forecast_columns,
        output_columns=forecast_columns,
        context_length=context_length,
        prediction_length=forecast_horizon,
    )

TypeError: ForecastDFDataset.__init__() got an unexpected keyword argument 'input_columns'

 ## Configure the PatchTSMixer model

 The settings below control the different components in the PatchTSMixer model.
  - `num_input_channels`: the number of input channels (or dimensions) in the time series data. This is
    automatically set to the number for forecast columns.
  - `context_length`: As described above, the amount of historical data used as input to the model.
  - `prediction_length`: This is same as the forecast horizon as decribed above.
  - `patch_length`: The length of the patches extracted from the context window (of length `context_length``).
  - `patch_stride`: The stride used when extracting patches from the context window.
  - `d_model`: Hidden feature dimension of the model.
  - `num_layers`: The number of model layers.
  - `dropout`: Dropout probability for all fully connected layers in the encoder.
  - `head_dropout`: Dropout probability used in the head of the model.
  - `mode`: PatchTSMixer operating mode. "common_channel"/"mix_channel". Common-channel works in channel-independent mode. For pretraining, use "common_channel".
  - `scaling`: Per-widow standard scaling. Recommended value: "std".

For full details on the parameters - refer [here](https://huggingface.co/docs/transformers/main/en/model_doc/patchtsmixer)

We recommend that you only adjust the values in the next cell.

In [None]:
if PRETRAIN_AGAIN:
    config = PatchTSMixerConfig(
        context_length=context_length,
        prediction_length=forecast_horizon,
        patch_length=patch_length,
        num_input_channels=len(forecast_columns),
        patch_stride=patch_length,
        d_model=16,
        num_layers=8,
        expansion_factor=2,
        dropout=0.2,
        head_dropout=0.2,
        mode="common_channel",
        scaling="std",
    )
    model = PatchTSMixerForPrediction(config)

 ## Train model

 Trains the PatchTSMixer model based on the direct forecasting strategy.

In [None]:
if PRETRAIN_AGAIN:
    training_args = TrainingArguments(
        output_dir="./checkpoint/patchtsmixer/electricity/pretrain/output/",
        overwrite_output_dir=True,
        learning_rate=0.001,
        num_train_epochs=100,  # For a quick test of this notebook, set it to 1
        do_eval=True,
        evaluation_strategy="epoch",
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        dataloader_num_workers=num_workers,
        report_to="tensorboard",
        save_strategy="epoch",
        logging_strategy="epoch",
        save_total_limit=3,
        logging_dir="./checkpoint/patchtsmixer/electricity/pretrain/logs/",  # Make sure to specify a logging directory
        load_best_model_at_end=True,  # Load the best model when training ends
        metric_for_best_model="eval_loss",  # Metric to monitor for early stopping
        greater_is_better=False,  # For loss
        label_names=["future_values"],
        # max_steps=20,
    )

    # Create the early stopping callback
    early_stopping_callback = EarlyStoppingCallback(
        early_stopping_patience=10,  # Number of epochs with no improvement after which to stop
        early_stopping_threshold=0.0001,  # Minimum improvement required to consider as improvement
    )

    # define trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=valid_dataset,
        callbacks=[early_stopping_callback],
    )

    # pretrain
    trainer.train()



Epoch,Training Loss,Validation Loss
1,0.2471,0.141067
2,0.1686,0.127757
3,0.1565,0.122327
4,0.1503,0.118918
5,0.146,0.116496
6,0.1431,0.114968
7,0.1408,0.113678
8,0.1392,0.113057
9,0.1379,0.112405
10,0.1369,0.112225




 ## Evaluate model on the test set.


In [None]:
if PRETRAIN_AGAIN:
    results = trainer.evaluate(test_dataset)
    print("Test result:")
    print(results)



Test result:
{'eval_loss': 0.12884521484375, 'eval_runtime': 5.7532, 'eval_samples_per_second': 897.763, 'eval_steps_per_second': 3.65, 'epoch': 35.0}


We get MSE score of 0.128 which is the SOTA result on the Electricity data.

 ## Save model

In [None]:
if PRETRAIN_AGAIN:
    save_dir = "patchtsmixer/electricity/model/pretrain/"
    os.makedirs(save_dir, exist_ok=True)
    trainer.save_model(save_dir)

# Part 2: Transfer Learning from Electicity to ETTH2

In this section, we will demonstrate the transfer learning capability of the `PatchTSMixer` model.
We use the model pretrained on Electricity dataset to do zeroshot testing on ETTH2 dataset.


In Transfer Learning,  we will pretrain the model for a forecasting task on a `source` dataset. Then, we will use the
 pretrained model for zero-shot forecasting on a `target` dataset. The zero-shot forecasting
 performance will denote the `test` performance of the model in the `target` domain, without any
 training on the target domain. Subsequently, we will do linear probing and (then) finetuning of
 the pretrained model on the `train` part of the target data, and will validate the forecasting
 performance on the `test` part of the target data. In this example, the source dataset is the Electricity dataset and the target dataset is ETTH2

## Transfer Learing on `ETTh2` data. All evaluations are on the `test` part of the `ETTh2` data.
Step 1: Directly evaluate the electricity-pretrained model. This is the zero-shot performance.  
Step 2: Evalute after doing linear probing.  
Step 3: Evaluate after doing full finetuning.  

### Load ETTh2 data

In [None]:
dataset = "ETTh2"

In [None]:
print(f"Loading target dataset: {dataset}")
dataset_path = f"https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/ETT-small/{dataset}.csv"
timestamp_column = "date"
id_columns = []
forecast_columns = ["HUFL", "HULL", "MUFL", "MULL", "LUFL", "LULL", "OT"]
train_start_index = None  # None indicates beginning of dataset
train_end_index = 12 * 30 * 24

# we shift the start of the evaluation period back by context length so that
# the first evaluation timestamp is immediately following the training data
valid_start_index = 12 * 30 * 24 - context_length
valid_end_index = 12 * 30 * 24 + 4 * 30 * 24

test_start_index = 12 * 30 * 24 + 4 * 30 * 24 - context_length
test_end_index = 12 * 30 * 24 + 8 * 30 * 24

Loading target dataset: ETTh2


In [None]:
data = pd.read_csv(
    dataset_path,
    parse_dates=[timestamp_column],
)

train_data = select_by_index(
    data,
    id_columns=id_columns,
    start_index=train_start_index,
    end_index=train_end_index,
)
valid_data = select_by_index(
    data,
    id_columns=id_columns,
    start_index=valid_start_index,
    end_index=valid_end_index,
)
test_data = select_by_index(
    data,
    id_columns=id_columns,
    start_index=test_start_index,
    end_index=test_end_index,
)

tsp = TimeSeriesPreprocessor(
    timestamp_column=timestamp_column,
    id_columns=id_columns,
    input_columns=forecast_columns,
    output_columns=forecast_columns,
    scaling=True,
)
tsp.train(train_data)

TimeSeriesPreprocessor {
  "context_length": 64,
  "feature_extractor_type": "TimeSeriesPreprocessor",
  "id_columns": [],
  "input_columns": [
    "HUFL",
    "HULL",
    "MUFL",
    "MULL",
    "LUFL",
    "LULL",
    "OT"
  ],
  "output_columns": [
    "HUFL",
    "HULL",
    "MUFL",
    "MULL",
    "LUFL",
    "LULL",
    "OT"
  ],
  "prediction_length": null,
  "processor_class": "TimeSeriesPreprocessor",
  "scaler_dict": {
    "0": {
      "copy": true,
      "feature_names_in_": [
        "HUFL",
        "HULL",
        "MUFL",
        "MULL",
        "LUFL",
        "LULL",
        "OT"
      ],
      "mean_": [
        41.53683496078959,
        12.273452896210882,
        46.60977329964991,
        10.526153112865156,
        1.1869920139097505,
        -2.373217913729173,
        26.872023494265697
      ],
      "n_features_in_": 7,
      "n_samples_seen_": 8640,
      "scale_": [
        10.448841072588488,
        4.587112566531959,
        16.858190332598408,
        3.0

In [None]:
train_dataset = ForecastDFDataset(
    tsp.preprocess(train_data),
    id_columns=id_columns,
    input_columns=forecast_columns,
    output_columns=forecast_columns,
    context_length=context_length,
    prediction_length=forecast_horizon,
)
valid_dataset = ForecastDFDataset(
    tsp.preprocess(valid_data),
    id_columns=id_columns,
    input_columns=forecast_columns,
    output_columns=forecast_columns,
    context_length=context_length,
    prediction_length=forecast_horizon,
)
test_dataset = ForecastDFDataset(
    tsp.preprocess(test_data),
    id_columns=id_columns,
    input_columns=forecast_columns,
    output_columns=forecast_columns,
    context_length=context_length,
    prediction_length=forecast_horizon,
)

## Zero-shot forecasting on `ETTh2`

In [None]:
print("Loading pretrained model")
finetune_forecast_model = PatchTSMixerForPrediction.from_pretrained("patchtsmixer/electricity/model/pretrain/")
print("Done")

Loading pretrained model
Done


In [None]:
finetune_forecast_args = TrainingArguments(
    output_dir="./checkpoint/patchtsmixer/transfer/finetune/output/",
    overwrite_output_dir=True,
    learning_rate=0.0001,
    num_train_epochs=100,
    do_eval=True,
    evaluation_strategy="epoch",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    dataloader_num_workers=num_workers,
    report_to="tensorboard",
    save_strategy="epoch",
    logging_strategy="epoch",
    save_total_limit=3,
    logging_dir="./checkpoint/patchtsmixer/transfer/finetune/logs/",  # Make sure to specify a logging directory
    load_best_model_at_end=True,  # Load the best model when training ends
    metric_for_best_model="eval_loss",  # Metric to monitor for early stopping
    greater_is_better=False,  # For loss
)

# Create a new early stopping callback with faster convergence properties
early_stopping_callback = EarlyStoppingCallback(
    early_stopping_patience=5,  # Number of epochs with no improvement after which to stop
    early_stopping_threshold=0.001,  # Minimum improvement required to consider as improvement
)

finetune_forecast_trainer = Trainer(
    model=finetune_forecast_model,
    args=finetune_forecast_args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    callbacks=[early_stopping_callback],
)

print("\n\nDoing zero-shot forecasting on target data")
result = finetune_forecast_trainer.evaluate(test_dataset)
print("Target data zero-shot forecasting result:")
print(result)



Doing zero-shot forecasting on target data




Target data zero-shot forecasting result:
{'eval_loss': 0.3038313388824463, 'eval_runtime': 1.8364, 'eval_samples_per_second': 1516.562, 'eval_steps_per_second': 5.99}


By a direct zeroshot, we get MSE of 0.3 which is near to the SOTA result. Lets see, how we can do a simple linear probing to match the SOTA results.

## Target data `ETTh2` linear probing
We can do a quick linear probing on the `train` part of the target data to see any possible `test` performance improvement.

In [None]:
# Freeze the backbone of the model
for param in finetune_forecast_trainer.model.model.parameters():
    param.requires_grad = False

print("\n\nLinear probing on the target data")
finetune_forecast_trainer.train()
print("Evaluating")
result = finetune_forecast_trainer.evaluate(test_dataset)
print("Target data head/linear probing result:")
print(result)



Linear probing on the target data


Epoch,Training Loss,Validation Loss
1,0.447,0.216436
2,0.4386,0.215667
3,0.4294,0.215104
4,0.4225,0.21382
5,0.4185,0.213585
6,0.415,0.213016
7,0.412,0.213067
8,0.4124,0.211993
9,0.4059,0.21246
10,0.4053,0.211772




Evaluating




Target data head/linear probing result:
{'eval_loss': 0.27119266986846924, 'eval_runtime': 1.7621, 'eval_samples_per_second': 1580.478, 'eval_steps_per_second': 6.242, 'epoch': 13.0}


By doing a simple linear probing, MSE decreased from 0.3 to 0.271 achiving the SOTA results.

In [None]:
save_dir = f"patchtsmixer/electricity/model/transfer/{dataset}/model/linear_probe/"
os.makedirs(save_dir, exist_ok=True)
finetune_forecast_trainer.save_model(save_dir)

save_dir = f"patchtsmixer/electricity/model/transfer/{dataset}/preprocessor/"
os.makedirs(save_dir, exist_ok=True)
tsp.save_pretrained(save_dir)

['patchtsmixer/electricity/model/transfer/ETTh2/preprocessor/preprocessor_config.json']

Lets now see, if we get any more improvements by doing a full finetune.

## Target data `ETTh2` full finetune

We can do a full model finetune (instead of probing the last linear layer as shown above) on the `train` part of the target data to see a possible `test` performance improvement.

In [None]:
# Reload the model
finetune_forecast_model = PatchTSMixerForPrediction.from_pretrained("patchtsmixer/electricity/model/pretrain/")
finetune_forecast_trainer = Trainer(
    model=finetune_forecast_model,
    args=finetune_forecast_args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    callbacks=[early_stopping_callback],
)
print("\n\nFinetuning on the target data")
finetune_forecast_trainer.train()
print("Evaluating")
result = finetune_forecast_trainer.evaluate(test_dataset)
print("Target data full finetune result:")
print(result)



Finetuning on the target data




Epoch,Training Loss,Validation Loss
1,0.4329,0.2152
2,0.4167,0.210919
3,0.4014,0.209932
4,0.3929,0.208808
5,0.3881,0.209692
6,0.3759,0.209546
7,0.37,0.210207
8,0.367,0.211601
9,0.3594,0.211405




Evaluating




Target data full finetune result:
{'eval_loss': 0.2734043300151825, 'eval_runtime': 1.5853, 'eval_samples_per_second': 1756.725, 'eval_steps_per_second': 6.939, 'epoch': 9.0}


There is not much improvement with ETTH2 dataset with full finetuning. Lets save the model anyway.

In [None]:
save_dir = f"patchtsmixer/electricity/model/transfer/{dataset}/model/fine_tuning/"
os.makedirs(save_dir, exist_ok=True)
finetune_forecast_trainer.save_model(save_dir)


Summary: In this blog, we presented a step-by-step guide on leveraging PatchTSMixer for tasks related to forecasting and transfer learning. We intend to facilitate the seamless integration of the PatchTSMixer HF model for your forecasting use cases. We trust that this content serves as a useful resource to expedite your adoption of PatchTSMixer. Thank you for tuning in to our blog, and we hope you find this information beneficial for your projects.
