# Pretraining forecasters on panel data

Pretraining allows a forecaster to learn shared temporal patterns from a
collection of related time series (panel data) before being fine-tuned on a
specific target series.

This is conceptually different from **foundation models** (Chronos, MOIRAI,
TimesFM, etc.) that ship pre-trained on massive external corpora. The
``pretrain`` API lets you train on **your own domain data**, giving you full
control over what the model learns.

## When to use pretraining

| Approach | Use when |
|---|---|
| ``fit`` only | You have a single series with enough history |
| ``pretrain`` + ``fit`` | You have related series and want transfer learning on your own data |
| Foundation model | You want zero-shot or few-shot forecasting from a large pre-trained model |
| Global forecasting | You want a single model that predicts all series jointly (no per-series fine-tuning) |

## Estimator state lifecycle

Forecasters in sktime have three states, accessible via the ``state`` property:

```
new ──pretrain()──> pretrained ──fit()──> fitted ──predict()──> ...
 │                                          ^
 └──────────fit()───────────────────────────┘
```

Calling ``fit`` on a ``"new"`` forecaster resets and trains from scratch (the
standard workflow). Calling ``fit`` on a ``"pretrained"`` forecaster preserves
the pretrained weights and fine-tunes.
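Here is a toy class that mimics the diagram, to make the transition rules concrete. ``ToyLifecycleForecaster`` is hypothetical and is not how sktime implements states. It models only the transitions drawn above, and its "learning" is just averaging.

```python
class ToyLifecycleForecaster:
    """Hypothetical toy mirroring the state diagram; not sktime code."""

    def __init__(self):
        self.state = "new"
        self.weights_ = None

    def pretrain(self, panel):
        # "Learn" one shared weight from every series in the panel.
        values = [v for series in panel for v in series]
        self.weights_ = sum(values) / len(values)
        self.state = "pretrained"
        return self

    def fit(self, y):
        if self.state == "new":
            # Standard workflow: initialise from scratch.
            self.weights_ = 0.0
        # A pretrained weight is kept as the starting point.
        # "Fine-tune": move halfway towards the target series mean.
        self.weights_ += 0.5 * (sum(y) / len(y) - self.weights_)
        self.state = "fitted"
        return self


scratch = ToyLifecycleForecaster().fit([10.0, 12.0])
warm = ToyLifecycleForecaster().pretrain([[0.0, 2.0], [4.0, 6.0]]).fit([10.0, 12.0])
print(scratch.state, scratch.weights_)  # fitted 5.5
print(warm.state, warm.weights_)        # fitted 7.0
```

Both objects end up ``"fitted"``. The warm-started one begins from the pooled panel mean (3.0) rather than zero, so the same fine-tuning step lands somewhere different.
<test>
assert scratch.state == "fitted"
assert scratch.weights_ == 5.5
assert warm.state == "fitted"
assert warm.weights_ == 7.0
assert ToyLifecycleForecaster().pretrain([[1.0]]).state == "pretrained"
assert ToyLifecycleForecaster().state == "new"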

## Setup: loading panel data

We use the hierarchical sales toy dataset, which contains monthly sales for
4 product groups across 2 product lines (240 rows total).
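To pretrain on your own data, the data needs the same panel layout. That means a DataFrame whose index has one or more instance levels followed by the time level last. The sketch below builds a small synthetic panel in that shape with pandas. The level names, series labels, and values are made up for illustration and are not taken from the sales dataset.

```python
import numpy as np
import pandas as pd

# Hypothetical nested hierarchy: two lines with two groups each.
series_keys = [("L1", "G1"), ("L1", "G2"), ("L2", "G3"), ("L2", "G4")]
dates = pd.period_range("2000-01", periods=6, freq="M")

# Instance levels first, time level last.
index = pd.MultiIndex.from_tuples(
    [(line, group, date) for line, group in series_keys for date in dates],
    names=["line", "group", "time"],
)
rng = np.random.default_rng(0)
toy_panel = pd.DataFrame(
    {"Sales": rng.integers(50, 150, size=len(index)).astype(float)}, index=index
)

print(toy_panel.shape)                              # (24, 1)
print(len(toy_panel.index.droplevel(-1).unique()))  # 4 series

# Selecting one series works the same way as with the real dataset below.
one_series = toy_panel.xs(("L1", "G1"))["Sales"]
print(len(one_series))                              # 6 monthly time points
```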

In [None]:
from sktime.datasets import load_hierarchical_sales_toydata
from sktime.forecasting.base import ForecastingHorizon

y_panel = load_hierarchical_sales_toydata()

print(f"Shape: {y_panel.shape}")
print(f"Index levels: {y_panel.index.names}")
print(f"Unique series: {len(y_panel.index.droplevel(-1).unique())}")
y_panel.head(10)

## Basic example with DummyGlobalForecaster

``DummyGlobalForecaster`` is a lightweight baseline that computes summary
statistics during pretraining (no deep learning dependencies required).
It is useful for testing the API and as a comparison baseline.
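The pandas sketch below gives a rough mental model of what a "mean" global baseline computes. It pools every observation in the panel into one mean and counts the distinct series. The helper names are hypothetical, and the sketch assumes the statistic is the pooled mean of all observations. ``DummyGlobalForecaster`` may define it differently.

```python
import pandas as pd


def pretrain_mean_baseline(panel: pd.DataFrame, column: str) -> dict:
    """Hypothetical: pool all observations of ``column`` into one mean."""
    # Drop the time level (last) to get one key per series.
    instance_keys = panel.index.droplevel(-1).unique()
    return {
        "global_mean_": float(panel[column].mean()),
        "n_pretrain_instances_": len(instance_keys),
    }


def predict_constant(value: float, fh: list) -> pd.Series:
    """Hypothetical: repeat one value at every forecasting step in ``fh``."""
    return pd.Series([value] * len(fh), index=fh)


# Two toy series of three points each (values chosen by hand).
idx = pd.MultiIndex.from_product([["s1", "s2"], range(3)], names=["series", "time"])
panel = pd.DataFrame({"y": [1.0, 2.0, 3.0, 10.0, 20.0, 30.0]}, index=idx)

params = pretrain_mean_baseline(panel, "y")
print(params)  # {'global_mean_': 11.0, 'n_pretrain_instances_': 2}
print(predict_constant(params["global_mean_"], [1, 2, 3]).tolist())  # [11.0, 11.0, 11.0]
```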

In [None]:
from sktime.forecasting.dummy_global import DummyGlobalForecaster

forecaster = DummyGlobalForecaster(strategy="mean")

# Step 1: pretrain on the full panel
forecaster.pretrain(y_panel)
print(f"State: {forecaster.state}")
print(f"Global mean learned: {forecaster.global_mean_:.2f}")
print(f"Instances seen: {forecaster.n_pretrain_instances_}")

In [None]:
# Step 2: fit on a single target series
y_target = y_panel.xs(("Food preparation", "Hobs"))["Sales"]
forecaster.fit(y_target, fh=[1, 2, 3])
print(f"State: {forecaster.state}")

# Step 3: predict
y_pred = forecaster.predict()
print(f"\nPredictions (global mean repeated):\n{y_pred}")

## Inspecting pretrained parameters

Attributes set during ``pretrain`` are tracked separately from those set during
``fit``. Use ``get_pretrained_params()`` to inspect them, and
``get_fitted_params()`` to inspect fit-time attributes.
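One way to keep the two groups of attributes apart is to record which attributes each stage creates. The class below is a hypothetical sketch of that idea: it diffs the instance ``__dict__`` before and after each stage. This is not sktime's actual bookkeeping.

```python
class AttributeTracker:
    """Hypothetical sketch: split pretrain-time from fit-time attributes."""

    def __init__(self):
        self._pretrained_keys = set()
        self._fitted_keys = set()

    def _run_stage(self, stage, bucket):
        before = set(vars(self))
        stage()
        # Only newly created attributes are recorded; overwrites are not.
        bucket.update(set(vars(self)) - before)

    def pretrain(self, panel):
        def stage():
            self.global_mean_ = sum(panel) / len(panel)

        self._run_stage(stage, self._pretrained_keys)
        return self

    def fit(self, y):
        def stage():
            self.last_value_ = y[-1]

        self._run_stage(stage, self._fitted_keys)
        return self

    def get_pretrained_params(self):
        return {k: getattr(self, k) for k in sorted(self._pretrained_keys)}

    def get_fitted_params(self):
        return {k: getattr(self, k) for k in sorted(self._fitted_keys)}


est = AttributeTracker().pretrain([1.0, 2.0, 3.0]).fit([4.0, 5.0])
print(est.get_pretrained_params())  # {'global_mean_': 2.0}
print(est.get_fitted_params())      # {'last_value_': 5.0}
```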

In [None]:
print("Pretrained parameters (from pretrain):")
for key, val in sorted(forecaster.get_pretrained_params().items()):
    print(f"  {key}: {val}")

print("\nFitted parameters (from fit):")
for key, val in sorted(forecaster.get_fitted_params().items()):
    print(f"  {key}: {val}")

## Deep learning example with LTSFLinearForecaster

For neural-network-based forecasters, pretraining trains the network weights
on panel data. The subsequent ``fit`` call fine-tunes those weights on the
target series rather than initialising from random weights.

This requires ``torch`` to be installed.
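At its core, LTSF-Linear is a single linear layer that maps the last ``seq_len`` values to the next ``pred_len`` values. That makes the pretrain-then-fine-tune idea easy to reproduce without torch. The numpy sketch below is our own illustration, not sktime code. It uses full-batch gradient descent instead of mini-batch training, and the synthetic sine-wave panel, learning rate, and epoch counts are arbitrary choices. It pools training windows from every panel series, then fine-tunes for only a few epochs on the target, starting either from the pretrained weights or from zeros.

```python
import numpy as np

SEQ_LEN, PRED_LEN = 12, 6  # input window length and horizon (our choice)


def make_windows(series, seq_len=SEQ_LEN, pred_len=PRED_LEN):
    """Slice one series into (input window, target window) pairs."""
    n = len(series) - seq_len - pred_len + 1
    X = np.stack([series[i : i + seq_len] for i in range(n)])
    Y = np.stack([series[i + seq_len : i + seq_len + pred_len] for i in range(n)])
    return X, Y


def train_linear(W, b, X, Y, epochs, lr=0.01):
    """Full-batch gradient descent on the mean per-window squared error of X @ W + b."""
    for _ in range(epochs):
        err = X @ W + b - Y                  # shape (n_windows, pred_len)
        W = W - lr * 2 * X.T @ err / len(X)  # gradient w.r.t. W
        b = b - lr * 2 * err.mean(axis=0)    # gradient w.r.t. b
    return W, b


def mse(W, b, X, Y):
    return float(np.mean((X @ W + b - Y) ** 2))


t = np.arange(60)
# Synthetic panel: sine waves sharing a 12-step period, differing in amplitude and phase.
panel = [a * np.sin(2 * np.pi * t / 12 + p) for a, p in [(0.5, 0.0), (1.0, 1.0), (1.5, 2.0)]]
target = 1.2 * np.sin(2 * np.pi * t / 12 + 0.5)

# Pretrain: pool windows from every series in the panel.
X_panel, Y_panel = map(np.vstack, zip(*(make_windows(s) for s in panel)))
W0, b0 = np.zeros((SEQ_LEN, PRED_LEN)), np.zeros(PRED_LEN)
W_pre, b_pre = train_linear(W0, b0, X_panel, Y_panel, epochs=500)

# Fine-tune briefly on the target, from pretrained weights vs from zeros.
X_t, Y_t = make_windows(target)
W_warm, b_warm = train_linear(W_pre, b_pre, X_t, Y_t, epochs=5)
W_cold, b_cold = train_linear(W0, b0, X_t, Y_t, epochs=5)

print(f"fine-tuned from pretrained: {mse(W_warm, b_warm, X_t, Y_t):.6f}")
print(f"trained from scratch:       {mse(W_cold, b_cold, X_t, Y_t):.6f}")
```

Every series here shares one period, so a single linear map predicts all of them. The pretrained weights therefore already fit the phase-shifted target, and five fine-tuning epochs from them should reach a far lower error than five epochs from zero.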

In [None]:
from sktime.forecasting.ltsf import LTSFLinearForecaster

nn_forecaster = LTSFLinearForecaster(
    seq_len=24,
    pred_len=6,
    num_epochs=3,
    batch_size=16,
    lr=1e-3,
)

# Pretrain on the full panel
nn_forecaster.pretrain(y_panel)
print(f"State: {nn_forecaster.state}")
print(f"Instances seen: {nn_forecaster.n_pretrain_instances_}")

In [None]:
# Fine-tune on a single series and predict
y_target = y_panel.xs(("Food preparation", "Hobs"))["Sales"]
fh = ForecastingHorizon(list(range(1, nn_forecaster.pred_len + 1)), is_relative=True)

nn_forecaster.fit(y_target, fh=fh)
y_pred = nn_forecaster.predict()
print(f"State: {nn_forecaster.state}")
print(f"\nPredictions:\n{y_pred}")

## Incremental pretraining

Calling ``pretrain`` a second time on an already pretrained (or fitted)
forecaster triggers ``_pretrain_update`` instead of ``_pretrain``. This enables
incremental learning from additional data batches without rebuilding the model
from scratch.
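For a mean-type statistic, an incremental update needs only the previous mean and the number of observations behind it. The sketch below folds in each batch as a count-weighted average. It then checks that the result equals the mean computed over all data at once. ``RunningMean`` is a hypothetical class, not sktime's ``_pretrain_update``. It weights by number of observations, whereas ``DummyGlobalForecaster`` may weight differently, for example by number of instances.

```python
import numpy as np


class RunningMean:
    """Hypothetical incremental mean: keeps only the mean and an observation count."""

    def __init__(self):
        self.n_obs_ = 0
        self.global_mean_ = 0.0

    def update(self, values):
        values = np.asarray(values, dtype=float).ravel()
        total = self.n_obs_ + values.size
        # Old mean counts n_obs_ times, the new batch counts values.size times.
        self.global_mean_ = (self.n_obs_ * self.global_mean_ + values.sum()) / total
        self.n_obs_ = total
        return self


rng = np.random.default_rng(0)
batch_1 = rng.normal(size=(3, 24))  # 3 series x 24 time points
batch_2 = rng.normal(size=(2, 24))  # 2 series x 24 time points

running = RunningMean().update(batch_1).update(batch_2)
full = np.concatenate([batch_1.ravel(), batch_2.ravel()]).mean()
print(running.n_obs_, np.isclose(running.global_mean_, full))  # 120 True
```

The two ``_make_hierarchical`` batches below use equal-length series. For such series, weighting by observations and weighting by instances give the same combined mean.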

In [None]:
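# _make_hierarchical is a private sktime testing helper that generates synthetic
# panel data; as an internal utility, its interface may change between releases.
from sktime.utils._testing.hierarchical import _make_hierarchical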

incremental = DummyGlobalForecaster(strategy="mean")

# First batch
batch_1 = _make_hierarchical(
    hierarchy_levels=(3,), min_timepoints=24, max_timepoints=24, random_state=0,
)
incremental.pretrain(batch_1)
print(f"After batch 1: mean={incremental.global_mean_:.4f}, "
      f"instances={incremental.n_pretrain_instances_}")

# Second batch -- pretrain is called again, triggers _pretrain_update
batch_2 = _make_hierarchical(
    hierarchy_levels=(2,), min_timepoints=24, max_timepoints=24, random_state=42,
)
incremental.pretrain(batch_2)
print(f"After batch 2: mean={incremental.global_mean_:.4f}, "
      f"instances={incremental.n_pretrain_instances_}")

## Cloning preserves pretrained state

When a pretrained forecaster is cloned via sktime's ``.clone()`` method, the
pretrained attributes are copied to the clone. This is important for
cross-validation and tuning, where the estimator is cloned internally.

Note: ``sklearn.base.clone`` does **not** preserve pretrained state. Always
use the sktime ``.clone()`` method (which is what sktime's CV and tuning
tools use internally).
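``sklearn.base.clone`` drops learned state because it builds a fresh estimator from ``get_params()``, that is, from constructor arguments only. The sketch below shows this with a plain scikit-learn ``LinearRegression`` rather than a sktime forecaster. It assumes that attributes set by ``pretrain`` are lost the same way.

```python
import numpy as np
from sklearn.base import clone
from sklearn.linear_model import LinearRegression

X = np.arange(10, dtype=float).reshape(-1, 1)
y = 2.0 * X.ravel() + 1.0

model = LinearRegression().fit(X, y)
cloned_model = clone(model)

print(hasattr(model, "coef_"))         # True: learned during fit
print(hasattr(cloned_model, "coef_"))  # False: clone re-instantiated from get_params()
print(cloned_model.get_params() == model.get_params())  # True: hyperparameters are kept
```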

In [None]:
original = DummyGlobalForecaster()
original.pretrain(y_panel)
print(f"Original: state={original.state}, mean={original.global_mean_:.2f}")

cloned = original.clone()
print(f"Clone:    state={cloned.state}, mean={cloned.global_mean_:.2f}")

## Discovering pretrainable forecasters

Forecasters that support pretraining declare the ``capability:pretrain`` tag.
You can find all of them with ``all_estimators``.

In [None]:
from sktime.registry import all_estimators

all_estimators(
    "forecaster",
    filter_tags={"capability:pretrain": True},
    as_dataframe=True,
)

## Summary

- ``pretrain(y_panel)`` learns from panel data and sets state to ``"pretrained"``
- On a pretrained forecaster, ``fit(y_series)`` fine-tunes on a single series, preserving pretrained weights; on a new forecaster it trains from scratch
- ``get_pretrained_params()`` inspects what was learned during pretraining
- Calling ``pretrain`` again triggers incremental updates
- sktime's ``.clone()`` preserves pretrained state (important for CV/tuning)
- Use ``all_estimators`` with ``filter_tags={"capability:pretrain": True}``
  to discover pretrainable forecasters