
# Getting Started with TSPulse Classification

This notebook demonstrates the usage of a pre-trained TSPulse model for time-series classification task. Refer to [TSPulse](https://arxiv.org/abs/2505.13033) paper for architecture and other details.

Backbone of the pre-trained model is freezed and the classifier head along with the input patch embedding layer is finetuned on the classification dataset.

The pre-trained TSPulse model can be accessed from the [Hugging Face TSPulse Model Repository](https://huggingface.co/ibm-granite/granite-timeseries-tspulse-r1).


## Imports

In [16]:
import math
import os
import tempfile

import numpy as np
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader, random_split
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments, set_seed
from transformers.data.data_collator import default_data_collator
from transformers.trainer_utils import RemoveColumnsCollator

In [17]:
import warnings


warnings.filterwarnings("ignore")

In [18]:
from tsfm_public.models.tspulse import TSPulseForClassification
from tsfm_public.toolkit.dataset import ClassificationDFDataset
from tsfm_public.toolkit.lr_finder import optimal_lr_finder
from tsfm_public.toolkit.time_series_classification_preprocessor import TimeSeriesClassificationPreprocessor
from tsfm_public.toolkit.util import convert_tsfile_to_dataframe

## Data Preprocessing

In [19]:
seed = 42
set_seed(seed)

In [20]:
dataset_name = "BasicMotions"

In [21]:
# --- Use Relative Paths for Portability ---

# The dataset directory is relative to the current notebook's location.
dataset_dir = "Multivariate_ts"

# Construct the relative paths to the train and test files
path = os.path.join(dataset_dir, dataset_name, f"{dataset_name}_TRAIN.ts")

df_base = convert_tsfile_to_dataframe(
    path,
    return_separate_X_and_y=False,
)

label_column = "class_vals"
input_columns = [f"dim_{i}" for i in range(df_base.shape[1] - 1)]

tsp = TimeSeriesClassificationPreprocessor(
    input_columns=input_columns,
    label_column=label_column,
    scaling=True,
)

tsp.train(df_base)
df_train_prep = tsp.preprocess(df_base)

base_dataset = ClassificationDFDataset(
    df_train_prep,
    id_columns=[],
    timestamp_column=None,
    input_columns=input_columns,
    label_column=label_column,
    context_length=512,
    static_categorical_columns=[],
    stride=1,
    enable_padding=False,
    full_series=True,
)

path = os.path.join(dataset_dir, dataset_name, f"{dataset_name}_TEST.ts")

df_test = convert_tsfile_to_dataframe(
    path,
    return_separate_X_and_y=False,
)

label_column = "class_vals"
input_columns = [f"dim_{i}" for i in range(df_test.shape[1] - 1)]

df_test_prep = tsp.preprocess(df_test)

test_dataset = ClassificationDFDataset(
    df_test_prep,
    id_columns=[],
    timestamp_column=None,
    input_columns=input_columns,
    label_column=label_column,
    context_length=512,
    static_categorical_columns=[],
    stride=1,
    enable_padding=False,
    full_series=True,
)


# creating a validation set

dataset_size = len(base_dataset)
print(dataset_size)
split_valid_ratio = 0.1
val_size = int(split_valid_ratio * dataset_size)  # 10% valid split
train_size = dataset_size - val_size
train_dataset, valid_dataset = random_split(base_dataset, [train_size, val_size])

40


## Configs for the TSPulse model
### Hyperparameters to Optimize and suggested values :
#### head_reduce_d_model = 1, 2
#### decoder_mode = mix_channel, common_channel
#### head_gated_attention_activation = softmax, sigmoid
#### mask_ratio = 0, 0.3
#### channel_virtual_expand_scale = 1, 2

In [22]:
config_dict = {
    "head_gated_attention_activation": "softmax",
    "channel_virtual_expand_scale": 2,
    "mask_ratio": 0.3,
    "head_reduce_d_model": 1,
    "disable_mask_in_classification_eval": True,
    "fft_time_consistent_masking": True,
    "decoder_mode": "mix_channel",
    "head_aggregation_dim": "patch",
    "head_aggregation": None,
    "loss": "cross_entropy",
    "ignore_mismatched_sizes": True,
}

config_dict["num_input_channels"] = tsp.num_input_channels
config_dict["num_targets"] = df_base["class_vals"].nunique()

## Getting the Pretrained Model with above configs

In [23]:
model = TSPulseForClassification.from_pretrained(
    "ibm-granite/granite-timeseries-tspulse-r1", revision="tspulse-block-dualhead-512-p16-r1", **config_dict
)

INFO:p-1652385:t-140272554137408:modeling_tspulse.py:_init_weights:Initializing Linear layers with method: pytorch
INFO:p-1652385:t-140272554137408:modeling_tspulse.py:_init_weights:Initializing Linear layers with method: pytorch
INFO:p-1652385:t-140272554137408:modeling_tspulse.py:_init_weights:Initializing Linear layers with method: pytorch
INFO:p-1652385:t-140272554137408:modeling_tspulse.py:_init_weights:Initializing Linear layers with method: pytorch
INFO:p-1652385:t-140272554137408:modeling_tspulse.py:_init_weights:Initializing Linear layers with method: pytorch
INFO:p-1652385:t-140272554137408:modeling_tspulse.py:_init_weights:Initializing Linear layers with method: pytorch
INFO:p-1652385:t-140272554137408:modeling_tspulse.py:_init_weights:Initializing Linear layers with method: pytorch
INFO:p-1652385:t-140272554137408:modeling_tspulse.py:_init_weights:Initializing Linear layers with method: pytorch
INFO:p-1652385:t-140272554137408:modeling_tspulse.py:_init_weights:Initializing 

In [24]:
device = "cuda" if torch.cuda.is_available() else "mps" if torch.mps.is_available() else "cpu"
# model = model.to(device).float()
print(device)

cuda


In [25]:
# Freezing Backbone except patch embedding layer....

for param in model.backbone.parameters():
    param.requires_grad = False

for param in model.backbone.time_encoding.parameters():
    param.requires_grad = True
for param in model.backbone.fft_encoding.parameters():
    param.requires_grad = True

## Finetuning the classifier head and patch embedding layer

In [26]:
OUT_DIR = "tspulse_finetuned_models/"

In [27]:
temp_dir = tempfile.mkdtemp()

suggested_lr = None

train_dict = {"per_device_train_batch_size": 32, "num_train_epochs": 200, "eval_accumulation_steps": None}

EPOCHS = train_dict["num_train_epochs"]
BATCH_SIZE = train_dict["per_device_train_batch_size"]
eval_accumulation_steps = train_dict["eval_accumulation_steps"]
NUM_WORKERS = 1
NUM_GPUS = 1

set_seed(42)
if suggested_lr is None:
    lr, model = optimal_lr_finder(
        model,
        train_dataset,
        batch_size=BATCH_SIZE,
    )
    suggested_lr = lr
print("Suggested LR : ", suggested_lr)
finetune_args = TrainingArguments(
    output_dir=temp_dir,
    overwrite_output_dir=True,
    learning_rate=suggested_lr,
    num_train_epochs=EPOCHS,
    do_eval=True,
    eval_strategy="epoch",
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    eval_accumulation_steps=eval_accumulation_steps,
    dataloader_num_workers=NUM_WORKERS,
    report_to="tensorboard",
    save_strategy="epoch",
    logging_strategy="epoch",
    save_total_limit=1,
    logging_dir=os.path.join(OUT_DIR, "output"),  # Make sure to specify a logging directory
    load_best_model_at_end=True,  # Load the best model when training ends
    metric_for_best_model="eval_loss",  # Metric to monitor for early stopping
    greater_is_better=False,  # For loss
)

# Create the early stopping callback
early_stopping_callback = EarlyStoppingCallback(
    early_stopping_patience=100,  # Number of epochs with no improvement after which to stop
    early_stopping_threshold=0.0001,  # Minimum improvement required to consider as improvement
)

# Optimizer and scheduler
optimizer = AdamW(model.parameters(), lr=suggested_lr)
scheduler = OneCycleLR(
    optimizer,
    suggested_lr,
    epochs=EPOCHS,
    steps_per_epoch=math.ceil(len(train_dataset) / (BATCH_SIZE * NUM_GPUS)),
)

finetune_trainer = Trainer(
    model=model,
    args=finetune_args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    callbacks=[early_stopping_callback],
    optimizers=(optimizer, scheduler),
)

# Fine tune
finetune_trainer.train()

INFO:p-1652385:t-140272554137408:lr_finder.py:optimal_lr_finder:LR Finder: Running learning rate (LR) finder algorithm. If the suggested LR is very low, we suggest setting the LR manually.
INFO:p-1652385:t-140272554137408:lr_finder.py:optimal_lr_finder:LR Finder: Using cuda:0.


INFO:p-1652385:t-140272554137408:lr_finder.py:optimal_lr_finder:LR Finder: Suggested learning rate = 0.004037017258596558


Suggested LR :  0.004037017258596558


Epoch,Training Loss,Validation Loss
1,1.5949,1.455229
2,1.4529,1.44989
3,1.4799,1.444723
4,1.457,1.44
5,1.4587,1.436479
6,1.5208,1.433226
7,1.4791,1.430662
8,1.4951,1.427992
9,1.4466,1.425089
10,1.4764,1.421724


TrainOutput(global_step=200, training_loss=0.311656161108549, metrics={'train_runtime': 168.3002, 'train_samples_per_second': 42.781, 'train_steps_per_second': 1.188, 'total_flos': 49402375372800.0, 'train_loss': 0.311656161108549, 'epoch': 200.0})

## Classification Scores

In [28]:
predictions_dict = finetune_trainer.predict(test_dataset)
preds_np = predictions_dict.predictions[0]

remove_columns_collator = RemoveColumnsCollator(
    data_collator=default_data_collator,
    signature_columns=["target_values"],
    logger=None,
    description=None,
    model_name="temp",
)

test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=remove_columns_collator)
target_list = []
for batch in test_dataloader:
    batch_labels = batch["target_values"].numpy()
    target_list.append(batch_labels)
targets_np = np.concatenate(target_list, axis=0)
test_accuracy = np.mean(targets_np == np.argmax(preds_np, axis=1))
print("test_accuracy : ", test_accuracy)

test_accuracy :  1.0


## Using Classification Pipeline for inference with the finetuned model

In [29]:
from tsfm_public.toolkit.time_series_classification_pipeline import TimeSeriesClassificationPipeline


pipe = TimeSeriesClassificationPipeline(finetune_trainer.model, feature_extractor=tsp, device=device)

Device set to use cuda


In [30]:
pipe(df_test)

Unnamed: 0,dim_0,dim_1,dim_2,dim_3,dim_4,dim_5,class_vals,class_vals_prediction
0,0 -0.740653 1 -0.740653 2 10.20844...,0 0.756509 1 0.756509 2 -9.216970 3...,0 -0.275809 1 -0.275809 2 -12.37890...,0 -0.423476 1 -0.423476 2 -14.69915...,0 0.013317 1 0.013317 2 4.578337 3...,0 0.013317 1 0.013317 2 -5.055081 3...,standing,standing
1,0 -0.247409 1 -0.247409 2 -0.771290 3...,0 -0.060459 1 -0.060459 2 -0.047618 3...,0 -0.608565 1 -0.608565 2 -0.294411 3...,0 -0.023970 1 -0.023970 2 -0.269001 3...,0 0.101208 1 0.101208 2 0.111862 3...,0 0.071911 1 0.071911 2 0.135832 3...,standing,standing
2,0 -0.663284 1 -0.663284 2 5.393924 3...,0 0.273010 1 0.273010 2 -3.079673 3...,0 -0.160963 1 -0.160963 2 -3.175911 3...,0 -0.245030 1 -0.245030 2 -6.408074 3...,0 -0.077238 1 -0.077238 2 0.471417 3...,0 -0.018644 1 -0.018644 2 -3.592890 3...,standing,standing
3,0 -1.088052 1 -1.088052 2 -0.683620 3...,0 0.183832 1 0.183832 2 -2.909047 3...,0 -0.260871 1 -0.260871 2 1.507042 3...,0 -0.284981 1 -0.284981 2 0.415486 3...,0 0.487397 1 0.487397 2 0.013317 3...,0 1.081329 1 1.081329 2 0.820319 3...,standing,standing
4,0 0.354481 1 0.354481 2 0.449142 3...,0 -0.567671 1 -0.567671 2 -1.899854 3...,0 -0.084270 1 -0.084270 2 0.913056 3...,0 -0.223723 1 -0.223723 2 0.692477 3...,0 -0.247694 1 -0.247694 2 0.149149 3...,0 0.050604 1 0.050604 2 0.849616 3...,standing,standing
5,0 -1.182602 1 -0.765368 2 -0.519464 3...,0 -0.612973 1 -2.759566 2 -3.213704 3...,0 0.167450 1 0.414760 2 0.907956 3...,0 -0.276991 1 -0.508704 2 -0.077238 3...,0 -0.082565 1 -0.114525 2 -0.261010 3...,0 -0.213070 1 -0.426140 2 0.215733 3...,standing,standing
6,0 1.275129 1 1.275129 2 -0.273185 3...,0 -1.024406 1 -1.024406 2 0.095152 3...,0 -0.545722 1 -0.545722 2 0.023203 3...,0 -0.463427 1 -0.463427 2 0.042614 3...,0 -0.367545 1 -0.367545 2 -0.109198 3...,0 -0.159802 1 -0.159802 2 0.183773 3...,standing,standing
7,0 -0.352746 1 -0.352746 2 -1.354561 3...,0 0.316845 1 0.316845 2 0.490525 3...,0 -0.473779 1 -0.473779 2 1.454261 3...,0 -0.327595 1 -0.327595 2 -0.269001 3...,0 0.106535 1 0.106535 2 0.021307 3...,0 0.197090 1 0.197090 2 0.460763 3...,standing,standing
8,0 0.498121 1 0.498121 2 0.196889 3...,0 0.031305 1 0.031305 2 -3.122323 3...,0 -0.358509 1 -0.358509 2 0.258171 3...,0 0.047941 1 0.047941 2 0.143822 3...,0 -0.119852 1 -0.119852 2 0.015980 3...,0 0.005327 1 0.005327 2 0.010653 3...,standing,standing
9,0 0.126160 1 0.126160 2 1.771871 3...,0 0.102733 1 0.102733 2 -3.798484 3...,0 0.308964 1 0.308964 2 0.141369 3...,0 0.002663 1 0.002663 2 -1.427568 3...,0 0.000000 1 0.000000 2 -0.167792 3...,0 -0.007990 1 -0.007990 2 -1.643301 3...,standing,standing
