# Time-Series Search

This notebook demonstrates the basic usage of a pre-trained `TSPulse` model for time-series search.
For more comprehensive experiments, please refer to the [TSPulse paper](https://arxiv.org/abs/2505.13033).

The motivation behind time-series search is to understand and interpret a time series of interest by comparing it with others that exhibit similar patterns.
For example, suppose a sudden spike or dip is detected in a machine's sensor reading.
Retrieving similar time series from past observations can help reveal whether such spikes recur periodically and whether anomalies co-occur in other sensors.

When searching for similar time series, the following factors are important to consider:
- Similar patterns may appear at different positions, with different amplitudes, or slightly shifted in time.
- Because time series can be very long, matching should consider overlaps between short segments of the query and short segments of past data.
- Noise and distortions in observations can make matching more difficult.
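
To make the first factor concrete, the following sketch shows how plain Euclidean distance on raw values reacts to an amplitude change and a small time shift. It uses a synthetic sine wave rather than data from this notebook, and the scale factor, the shift, and the helper name `euclidean` are arbitrary choices made for illustration.

```python
import numpy as np

# Synthetic query: four periods of a sine wave over 512 points (illustrative only)
t = np.arange(512)
query = np.sin(2 * np.pi * t / 128)

scaled = 3.0 * query          # same shape, larger amplitude
shifted = np.roll(query, 16)  # same shape, circularly shifted by 16 steps


def euclidean(a, b):
    """Euclidean distance between two equally long series (hypothetical helper)."""
    return float(np.linalg.norm(a - b))


d_same = euclidean(query, query)
d_scaled = euclidean(query, scaled)
d_shifted = euclidean(query, shifted)

print(f"identical: {d_same:.2f}, scaled: {d_scaled:.2f}, shifted: {d_shifted:.2f}")
```

Both variants describe the same pattern, yet their raw distances to the query are far from zero, which is why a representation that tolerates such changes is useful for retrieval.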

TSPulse addresses these challenges and enables rapid search thanks to its tiny model size and compact embeddings.
The embeddings capture similar patterns in a way that is robust to shifts, scaling, and noise, which enables zero-shot similarity search without the need for labeled examples.

As a simple example of time-series search, we use a classification dataset, treating class labels as the ground truth.
The rationale is that time series belonging to the same class may share similar patterns.
Each time series has a length of 512 data points (context length). The index database contains 240-dimensional embeddings of 512-length time-series data.
Given a query time series, we retrieve the most similar embeddings by Euclidean distance.
In the final cell, we visualize the query and the retrieved time series to illustrate the quality of the matches that TSPulse returns.
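
The retrieval step described above reduces to a nearest-neighbor lookup in embedding space. The sketch below illustrates the idea with random vectors standing in for TSPulse embeddings. The database size, the way the query is constructed, and the helper name `top_k_euclidean` are hypothetical and serve only as an illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Random stand-ins for the index database: 100 embeddings of dimension 240
database = rng.normal(size=(100, 240))

# Build a query as a slightly perturbed copy of database entry 7
query = database[7] + 0.01 * rng.normal(size=240)


def top_k_euclidean(query, database, k):
    """Return indices and distances of the k closest rows (hypothetical helper)."""
    distances = np.linalg.norm(database - query, axis=1)
    order = np.argsort(distances)[:k]
    return order, distances[order]


indices, distances = top_k_euclidean(query, database, k=3)
print(indices, distances)
```

The notebook itself delegates this step to scikit-learn's `NearestNeighbors`, which likewise returns neighbor indices together with their distances.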

## Imports

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.neighbors import NearestNeighbors
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import set_seed

from tsfm_public.models.tspulse import TSPulseForReconstruction
from tsfm_public.toolkit.dataset import ClassificationDFDataset
from tsfm_public.toolkit.time_series_classification_preprocessor import (
    TimeSeriesClassificationPreprocessor,
)
from tsfm_public.toolkit.util import convert_tsfile_to_dataframe

## Important arguments

In [None]:
# Set seed for reproducibility
SEED = 42
set_seed(SEED)

CONTEXT_LENGTH = 512

model_path = ""  # set this to the path or model ID of a pre-trained TSPulse model
batch_size = 128

# number of nearest neighbors to retrieve (the visualization cell plots the top three, so keep k >= 3)
k = 3

## Data processing

In [None]:
def get_data(split):
    df_base = convert_tsfile_to_dataframe(
        f"Univariate_ts/ShapesAll/ShapesAll_{split}.ts",
        return_separate_X_and_y=False,
    )
    label_column = "class_vals"
    input_columns = [f"dim_{i}" for i in range(df_base.shape[1] - 1)]

    tsp = TimeSeriesClassificationPreprocessor(
        input_columns=input_columns,
        label_column=label_column,
        scaling=True,
    )

    tsp.train(df_base)
    df_prep = tsp.preprocess(df_base)
    base_dataset = ClassificationDFDataset(
        df_prep,
        id_columns=[],
        timestamp_column=None,
        input_columns=input_columns,
        label_column=label_column,
        context_length=CONTEXT_LENGTH,
        static_categorical_columns=[],
        stride=1,
        enable_padding=False,
        full_series=True,
    )
    return base_dataset


train_dataset = get_data("TRAIN")
test_dataset = get_data("TEST")
len(train_dataset), len(test_dataset)

## Pre-trained Model Loading

In [None]:
model = TSPulseForReconstruction.from_pretrained(model_path, num_input_channels=1, mask_ratio=0)
model.eval();

## Index Database Creation

In [None]:
@torch.no_grad()
def get_embedding(batch, model):
    # similarity search uses register embeddings
    d_model = model.config.d_model
    reg_tokens = model.config.patch_register_tokens
    reg_emb_size = reg_tokens * d_model

    _x = batch["past_values"]  # [B, T, C]
    # instance-wise scaling
    mu = torch.mean(_x, dim=1, keepdims=True)  # [B, 1, C]
    std = torch.std(_x, dim=1, keepdims=True, correction=0)  # [B, 1, C]
    x = (_x - mu) / std
    embedding = model(x)
    embedding = embedding["decoder_hidden_state"]  # [B, T, C] -> [B, C, D]
    embedding = embedding[:, :, -reg_emb_size:]
    return embedding


# compute embeddings for train dataset
dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, drop_last=False)
embeddings = []
for batch in tqdm(dataloader):
    embedding = get_embedding(batch, model)
    embeddings.append(embedding.numpy())
train_embeddings = np.concatenate(embeddings)
print(train_embeddings.shape)

# create index set of embeddings
embs = train_embeddings.squeeze(axis=1)  # [B, 1, D] -> [B, D]; the channel axis has size 1 for univariate series
nn = NearestNeighbors(metric="l2")
nn.fit(embs)
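
The instance-wise scaling inside `get_embedding` divides by the per-series standard deviation, which is zero for a constant series and would then yield NaN values. Below is a minimal NumPy sketch of the same normalization with a small epsilon guard. The function name `instance_scale` and the value of `eps` are assumptions for illustration and are not part of the TSPulse API.

```python
import numpy as np


def instance_scale(x, eps=1e-5):
    """Z-normalize each series along the time axis.

    x has shape [B, T, C]; eps (an assumed value) avoids division by zero.
    """
    mu = x.mean(axis=1, keepdims=True)  # [B, 1, C]
    std = x.std(axis=1, keepdims=True)  # population std, matching correction=0
    return (x - mu) / (std + eps)


rng = np.random.default_rng(0)
batch = rng.normal(loc=5.0, scale=2.0, size=(4, 512, 1))
batch[0] = 1.0  # a constant series would give 0/0 without the guard

scaled = instance_scale(batch)
print(np.isfinite(scaled).all())
print(scaled[1:].mean(axis=1).ravel(), scaled[1:].std(axis=1).ravel())
```

With the guard, the constant series maps to all zeros, while the remaining series end up with approximately zero mean and unit standard deviation.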

## Performance Evaluation

In [None]:
def calc_precision_k(cmp, k):
    cmp_k = np.sum(cmp[:, :k], axis=1) / k
    mean_cmp_k = np.mean(cmp_k)
    return mean_cmp_k


dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False)
cmp, I_all = [], []
for batch in tqdm(dataloader):
    test_embedding = get_embedding(batch, model)
    D, I = nn.kneighbors(test_embedding.squeeze(axis=1).numpy(), n_neighbors=k)
    I_all.append(I)

    label_test = batch["target_values"].numpy()
    retrieved_label = np.array([[train_dataset[i]["target_values"] for i in _I] for _I in I])
    cmp.append(label_test[:, None] == retrieved_label)

I_all = np.concatenate(I_all, axis=0)  # for visualization
cmp = np.concatenate(cmp, axis=0)
prec_k = calc_precision_k(cmp, k)
print(f"PREC@{k}={prec_k:.3f}")
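
As a worked example of the metric, consider three queries with k = 3. Each row of the boolean matrix marks which retrieved neighbors share the query's class. The labels below are made up for illustration and do not come from the dataset.

```python
import numpy as np

# Made-up labels: one query label per row and the labels of its 3 nearest neighbors
query_labels = np.array([0, 1, 2])
retrieved_labels = np.array(
    [
        [0, 0, 1],  # 2 of 3 neighbors match the query class
        [1, 1, 1],  # 3 of 3 match
        [0, 1, 1],  # 0 of 3 match
    ]
)

# Broadcast [3, 1] against [3, 3] to get a per-neighbor match matrix
matches = query_labels[:, None] == retrieved_labels

k = 3
precision_per_query = matches[:, :k].sum(axis=1) / k
precision_at_k = precision_per_query.mean()
print(f"PREC@{k}={precision_at_k:.3f}")  # PREC@3=0.556
```

The per-query precisions are 2/3, 1, and 0, so their mean is 5/9, which rounds to 0.556.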

## Visualization

In [None]:
plotted_list = []
for sample_idx, sample in enumerate(test_dataset):
    if sample["target_values"] in plotted_list:
        continue
    elif len(plotted_list) >= 5:  # plot max five classes
        break
    else:
        plotted_list.append(sample["target_values"])

    label = test_dataset[sample_idx]["target_values"]

    fig, ax = plt.subplots(1, 4, figsize=(6.4 * 3, 3.4))
    ax[0].plot(test_dataset[sample_idx]["past_values"])
    ax[0].set_title(f"Query (Test ID: {sample_idx}, class: {label})")

    # plot the three nearest neighbors; a green background marks a class match
    for rank, name in enumerate(["First", "Second", "Third"]):
        match_idx = I_all[sample_idx][rank]
        match_label = train_dataset[match_idx]["target_values"]
        ax[rank + 1].plot(train_dataset[match_idx]["past_values"])
        ax[rank + 1].set_title(f"{name} match (Train ID: {match_idx}, class: {match_label})")
        if label == match_label:
            ax[rank + 1].set_facecolor("lightgreen")