# Time Series Imputation Benchmark

Compare **simulated Annealing** to SOTA methods (BRITS, SAITS, KNN, MICE) on standard datasets loaded via **TSDB**.

In [25]:
# 2. Imports
import tsdb
import torch
import numpy as np
import pandas as pd
from neal import SimulatedAnnealingSampler
from sklearn.preprocessing import RobustScaler
from sklearn.linear_model import Ridge
from sklearn.experimental import enable_iterative_imputer
from sklearn.impute import KNNImputer, IterativeImputer
from sklearn.model_selection import train_test_split
from pypots.imputation import SAITS, BRITS
import matplotlib.pyplot as plt
import seaborn as sns
import time

# reproducibility
np.random.seed(42)

## 3. Load datasets via TSDB

We load four benchmarks: Air Quality (Italy), Electricity Load, PeMS Traffic, PhysioNet 2012 ICU.

In [2]:
# 3. Load datasets via TSDB (robust to missing 'X' key)
# map our friendly names to TSDB dataset keys
dataset_map = {
    "AirQuality": "italy_air_quality",
    "Electricity": "electricity_load_diagrams",
    "Traffic": "pems_traffic",
    "PhysioNet": "physionet_2012"
}

data_dict = {}
for display_name, tsdb_name in dataset_map.items():
    raw = tsdb.load(tsdb_name, use_cache=True)
    # Debug print to see what keys we got:
    print(f"{display_name} raw keys:", list(raw.keys()))
    if "X" in raw:
        df = raw["X"]
    else:
        # fallback: pick the first pd.DataFrame in raw.values()
        dfs = [v for v in raw.values() if isinstance(v, pd.DataFrame)]
        if not dfs:
            raise ValueError(f"No DataFrame found in TSDB output for '{tsdb_name}'. Keys: {list(raw.keys())}")
        df = dfs[0]
    # ensure datetime index if possible
    if isinstance(df.index, pd.DatetimeIndex):
        df = df.sort_index()
    data_dict[display_name] = df
    print(f"  → Loaded {display_name} as DataFrame {df.shape}\n")

2025-05-22 10:06:02 [INFO]: You're using dataset italy_air_quality, please cite it properly in your work. You can find its reference information at the below link: 
https://github.com/WenjieDu/TSDB/tree/main/dataset_profiles/italy_air_quality
2025-05-22 10:06:02 [INFO]: Dataset italy_air_quality has already been downloaded. Processing directly...
2025-05-22 10:06:02 [INFO]: Dataset italy_air_quality has already been cached. Loading from cache directly...
2025-05-22 10:06:02 [INFO]: Loaded successfully!
2025-05-22 10:06:02 [INFO]: You're using dataset electricity_load_diagrams, please cite it properly in your work. You can find its reference information at the below link: 
https://github.com/WenjieDu/TSDB/tree/main/dataset_profiles/electricity_load_diagrams
2025-05-22 10:06:02 [INFO]: Dataset electricity_load_diagrams has already been downloaded. Processing directly...
2025-05-22 10:06:02 [INFO]: Dataset electricity_load_diagrams has already been cached. Loading from cache directly...
2

AirQuality raw keys: ['X']
  → Loaded AirQuality as DataFrame (9357, 15)

Electricity raw keys: ['X']
  → Loaded Electricity as DataFrame (140256, 370)

Traffic raw keys: ['X']
  → Loaded Traffic as DataFrame (17544, 863)

PhysioNet raw keys: ['set-a', 'set-b', 'set-c', 'outcomes-a', 'outcomes-b', 'outcomes-c', 'static_features']
  → Loaded PhysioNet as DataFrame (180552, 43)



## 4. Simulate missingness

- **Random MCAR**: 10% missing uniformly at random.
- **Block gap**: a 7-day contiguous gap where applicable.

In [3]:
# 4. Simulate missingness (contiguous 10% gap)
obs = {}
mask = {}

for name, df in data_dict.items():
    # --- only keep numeric cols ---
    df_num = df.select_dtypes(include='number')
    arr = df_num.values.astype(float)
    n, d = arr.shape

    # determine gap length = 10% of time‐steps
    gap_len = int(np.floor(n * 0.10))
    # choose a start index (you can fix or randomize; here we fix seed for reproducibility)
    rng = np.random.RandomState(42)
    start = rng.randint(0, n - gap_len + 1)

    # build the contiguous block mask
    m_block = np.zeros_like(arr, dtype=bool)
    m_block[start : start + gap_len, :] = True

    # apply mask
    arr_obs = arr.copy()
    arr_obs[m_block] = np.nan

    # store
    obs[name]  = arr_obs
    mask[name] = m_block

    print(f"{name}: rows {start}–{start+gap_len-1} set to NaN (missing={m_block.mean()*100:.1f}%)")

AirQuality: rows 7270–8204 set to NaN (missing=10.0%)
Electricity: rows 121958–135982 set to NaN (missing=10.0%)
Traffic: rows 7270–9023 set to NaN (missing=10.0%)
PhysioNet: rows 121958–140012 set to NaN (missing=10.0%)


## 5. simulated Annealing Imputer

In [4]:
def impute_with_annealing(ts, mask, num_reads=100, prior=None, prior_weight=1.0):
    import numpy as _np
    from neal import SimulatedAnnealingSampler as Sampler

    imp = ts.copy()
    idx = _np.where(mask)[0]
    if idx.size == 0:
        return imp
    s, e = idx[0], idx[-1]
    y0 = ts[s-1] if s>0 else ts[e+1] if e+1<len(ts) else 0.0
    yN = ts[e+1] if e+1<len(ts) else ts[s-1] if s>0 else 0.0
    N = e-s+1
    d = (yN-y0)/(N+1)
    exp = _np.linspace(y0+d, yN-d, N)
    known = ts[~mask]
    minv, maxv = known.min(), known.max()

    Q, var_idx, inv_map, c = {}, {}, {}, 0
    for i,v in enumerate(exp):
        opts = {int(_np.floor(v)), int(_np.ceil(v))}
        if len(opts)==1:
            x0=opts.pop(); opts={x0-1,x0,x0+1}
        for x in sorted(opts):
            var_idx[(i,x)] = c
            inv_map[c] = (i,x)
            c+=1
    P=1e6
    def add(a,b,val):
        key=(min(a,b),max(a,b))
        Q[key]=Q.get(key,0)+val

    # (a) seasonality / range penalties
    for (i,x),qi in var_idx.items():
        add(qi,qi,(x-exp[i])**2)
        if x>maxv: add(qi,qi,1e6*(x-maxv)**4)
        if x<minv: add(qi,qi,1e4*(minv-x)**2)

    # (b) smoothness
    for (ii,x1),qi in var_idx.items():
        if ii>=N-1: continue
        for (jj,x2),qj in var_idx.items():
            if jj!=ii+1: continue
            add(qi,qj,((x2-x1)-d)**2)

    # (c) prior-bias for hybrid (only if provided)
    if prior is not None:
        for (i,x),qi in var_idx.items():
            p_val = prior[s+i]
            add(qi,qi, prior_weight*(x-p_val)**2)

    # (d) one-hot constraint
    for i in range(N):
        opts=[x for (ii,x) in var_idx if ii==i]
        for a,x1 in enumerate(opts):
            qa=var_idx[(i,x1)]
            add(qa,qa,-2*P)
            for x2 in opts[a+1:]:
                qb=var_idx[(i,x2)]
                add(qa,qb,2*P)

    # Solve
    sampler=Sampler()
    sampleset=sampler.sample_qubo(Q,num_reads=num_reads)
    sol=sampleset.first.sample
    for var,bit in sol.items():
        if bit:
            i,x=inv_map[var]
            imp[s+i]=x
    return imp

## 6. SAITS→QA Hybrid Imputer

In [5]:
def impute_saits_then_anneal(arr, mask, saits_epochs=20, qa_reads=100, prior_weight=1.0):
    # first do SAITS
    imp_saits = impute_saits(arr.copy(), mask, epochs=saits_epochs)
    hybrid = imp_saits.copy()
    # then refine ONLY the gap via QA, biasing toward SAITS
    for col in range(arr.shape[1]):
        hybrid[:,col] = impute_with_annealing(
            hybrid[:,col], mask[:,col],
            num_reads=qa_reads,
            prior=imp_saits[:,col],
            prior_weight=prior_weight
        )
    return hybrid

def impute_brits_then_anneal(arr, mask, brits_epochs=20, qa_reads=100, prior_weight=1.0):
    # first do BRITS
    imp_brits = impute_brits(arr.copy(), mask, epochs=brits_epochs)
    hybrid = imp_brits.copy()
    # then refine ONLY the gap via QA, biasing toward BRITS
    for col in range(arr.shape[1]):
        hybrid[:,col] = impute_with_annealing(
            hybrid[:,col], mask[:,col],
            num_reads=qa_reads,
            prior=imp_brits[:,col],
            prior_weight=prior_weight
        )
    return hybrid

## 6b. Baseline Imputers

In [26]:
def impute_brits(arr, mask, epochs=20):
    """
    BRITS imputer running entirely on CPU.
    """
    arr_masked = arr.copy()
    arr_masked[mask] = np.nan
    X = torch.tensor(arr_masked[np.newaxis], dtype=torch.float32)
    data = {"X": X}
    model = BRITS(
        n_steps=arr.shape[0],
        n_features=arr.shape[1],
        rnn_hidden_size=64,
        epochs=epochs
    )
    model.fit(data)
    imp_t = model.impute(data)    # this is already a numpy.ndarray
    return imp_t[0]                # not .numpy()

def impute_saits(arr, mask, epochs=20):
    """
    SAITS imputer running entirely on CPU.
    """
    arr_masked = arr.copy()
    arr_masked[mask] = np.nan
    X = torch.tensor(arr_masked[np.newaxis], dtype=torch.float32)
    data = {"X": X}
    model = SAITS(
        n_steps=arr.shape[0],
        n_features=arr.shape[1],
        n_layers=2,
        d_model=64,
        n_heads=4,
        d_k=16,
        d_v=16,
        d_ffn=256,
        epochs=epochs
    )
    model.fit(data)
    imp_t = model.impute(data)    # already numpy.ndarray
    return imp_t[0]                # not .numpy()

knn_imp = lambda arr, mask: KNNImputer(n_neighbors=5).fit_transform(
    np.where(mask, np.nan, arr)
)

scaler = RobustScaler()
mice_imp = lambda arr, mask: scaler.inverse_transform(
    IterativeImputer(
        estimator=Ridge(alpha=1.0),
        max_iter=20,
        tol=1e-3,
        initial_strategy='median',
        random_state=0
    ).fit_transform(
        scaler.fit_transform(np.where(mask, np.nan, arr))
    )
)

## 7. Evaluation & Comparison

In [28]:
def metrics(true, imp, mask, elapsed_time):
    d = imp - true
    rmse = np.sqrt(np.nanmean((d[mask])**2))
    mae  = np.nanmean(np.abs(d[mask]))
    nz   = true[mask] != 0
    mape = np.nanmean(np.abs(d[mask][nz] / true[mask][nz])) * 100
    return rmse, mae, mape, elapsed_time

results = []
for name, df in data_dict.items():
    if name=="AirQuality":
        arr = df.select_dtypes(include='number').values.astype(float)
        m   = mask[name]
        start_time = time.time()
        # simulated Annealing only
        qa_imp = np.vstack([
            impute_with_annealing(arr[:,c], m[:,c])
            for c in range(arr.shape[1])
        ]).T
        end_time = time.time()
        elapsed_time = end_time - start_time
        results.append((*metrics(arr, qa_imp, m,  elapsed_time), name, 'simulatedAnneal'))
        # BRITS
        start_time = time.time()
        br_imp = impute_brits(arr, m, 150)
        end_time = time.time()
        elapsed_time = end_time - start_time
        results.append((*metrics(arr, br_imp, m, elapsed_time), name, 'BRITS'))
        # SAITS
        start_time = time.time()
        sa_imp = impute_saits(arr, m, 150)
        end_time = time.time()
        elapsed_time = end_time - start_time
        results.append((*metrics(arr, sa_imp, m,  elapsed_time), name, 'SAITS'))
        """
        # SAITS→QA hybrid
        sq_imp = impute_saits_then_anneal(arr, m, saits_epochs=20, qa_reads=100)
        results.append((*metrics(arr, sq_imp, m), name, 'SAITS+QA'))
        # BRITS→QA hybrid
        bq_imp = impute_brits_then_anneal(arr, m, brits_epochs=20, qa_reads=100)
        results.append((*metrics(arr, bq_imp, m), name, 'BRITS+QA'))
        """

df_res = pd.DataFrame(
    results,
    columns=['RMSE','MAE','MAPE','Dataset','Method']
)

2025-05-22 11:50:18 [INFO]: No given device, using default device: cpu
2025-05-22 11:50:18 [INFO]: Using customized MAE as the training loss function.
2025-05-22 11:50:18 [INFO]: Using customized MSE as the validation metric function.
2025-05-22 11:50:18 [INFO]: BRITS initialized with the given hyperparameters, the number of trainable parameters: 52,016
2025-05-22 11:50:28 [INFO]: Epoch 001 - training loss (MAE): 1003.5520
2025-05-22 11:50:37 [INFO]: Epoch 002 - training loss (MAE): 998.2177
2025-05-22 11:50:46 [INFO]: Epoch 003 - training loss (MAE): 992.9183
2025-05-22 11:50:55 [INFO]: Epoch 004 - training loss (MAE): 987.6596
2025-05-22 11:51:04 [INFO]: Epoch 005 - training loss (MAE): 982.4376
2025-05-22 11:51:13 [INFO]: Epoch 006 - training loss (MAE): 977.2526
2025-05-22 11:51:22 [INFO]: Epoch 007 - training loss (MAE): 972.1053
2025-05-22 11:51:32 [INFO]: Epoch 008 - training loss (MAE): 966.9975
2025-05-22 11:51:41 [INFO]: Epoch 009 - training loss (MAE): 961.9261
2025-05-22 11

ValueError: 5 columns passed, passed data had 6 columns

In [None]:
# 8. Plot RMSE comparison
plt.figure(figsize=(10,6))
sns.barplot(data=df_res, x='Dataset', y='RMSE', hue='Method')
plt.title('RMSE by Method and Dataset')
plt.show()

In [29]:
results

[(np.float64(294.5503261751803),
  np.float64(156.3438307610037),
  np.float64(85.08588158095988),
  20.947827100753784,
  'AirQuality',
  'simulatedAnneal'),
 (np.float64(645.5327736062553),
  np.float64(423.10142732354876),
  np.float64(97.36174119679492),
  1407.4743611812592,
  'AirQuality',
  'BRITS'),
 (np.float64(627.9780744529519),
  np.float64(405.6969662733688),
  np.float64(106.14751045906563),
  538.4755589962006,
  'AirQuality',
  'SAITS')]