## Introduction

This notebook evaluates the TimesFM model on the ambient dataset for benchmarking purposes.

GitHub: https://github.com/google-research/timesfm

arXiv: https://arxiv.org/abs/2310.10688

Frequency definitions (TimesFM frequency category, passed as `freq_type`, for each pandas frequency alias):

0: T, MIN, H, D, B, U

1: W, M

2: Q, Y

In [8]:
from typing import Optional, Tuple
from os import path
import numpy as np
import pandas as pd

import torch
from torch.utils.data import DataLoader
from dataloader.dataloader import UnivariateMethaneHourly

from finetuning.finetuning_torch import FinetuningConfig, TimesFMFinetuner
from huggingface_hub import snapshot_download
import wandb

from timesfm import TimesFm, TimesFmCheckpoint, TimesFmHparams
from timesfm.pytorch_patched_decoder import PatchedTimeSeriesDecoder
import plotly.graph_objects as go
import argparse
# Authenticate with Weights & Biases up front; the fine-tuning config later in
# this notebook is created with use_wandb=True.
wandb.login()

# Last expression of the cell, so its value is displayed: whether PyTorch can
# see a CUDA device in this environment.
torch.cuda.is_available()

wandb: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
wandb: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: No netrc file found, creating one.
wandb: Appending key for api.wandb.ai to your netrc file: C:\Users\Windows\_netrc
wandb: Currently logged in as: ranluo87 (ranluo87-university-of-calgary-in-alberta) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin


False

In [9]:
# Default dataset location. Written as a raw string because the Windows path
# contains backslash sequences ("\P", "\T", "\d", "\s") that are invalid
# escapes in a normal literal and raise a SyntaxWarning on recent Python
# versions. The resulting value is unchanged.
data_dir = r'C:\Python Projects\TimeSeries_Benchmarking\datasets\select'
data_file = 'Anzac.csv'


def _str2bool(value):
    """Convert a command-line token to a bool.

    ``type=bool`` is not usable with argparse: ``bool('False')`` is ``True``
    because any non-empty string is truthy.

    Args:
        value: The raw argument (a string, or already a bool).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


parser = argparse.ArgumentParser()
parser.add_argument('--data_dir', type=str, default=data_dir)
parser.add_argument('--data_file', type=str, default=data_file)
# TimesFM configurations
parser.add_argument('--timesfm', type=_str2bool, default=True)
parser.add_argument('--freq_type', type=int, default=0)

parser.add_argument('--seq_len', type=int, default=512)
parser.add_argument('--pred_len', type=int, default=128)
# Optimization Hyperparams
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--learning_rate', type=float, default=1e-4)

# Parse an empty argument list so the notebook always runs with the defaults
# above instead of picking up the Jupyter kernel's own sys.argv.
args = parser.parse_args([])

In [10]:
# Hugging Face repository holding the pretrained TimesFM checkpoint.
repo_id = "google/timesfm-2.0-500m-pytorch"

# Context / horizon lengths follow the notebook-level arguments so the model
# and the dataset windows agree.
hparams = TimesFmHparams(
    backend='cpu',
    per_core_batch_size=32,
    num_layers=50,
    horizon_len=args.pred_len,
    context_len=args.seq_len,
    use_positional_embedding=False,
    output_patch_len=128
)

# High-level TimesFM wrapper, constructed with the Hugging Face checkpoint
# reference; it is the object used for forecast_on_df later in the notebook.
tfm = TimesFm(
    hparams=hparams,
    checkpoint=TimesFmCheckpoint(
        huggingface_repo_id=repo_id
    )
)

# Bare decoder module built from the wrapper's model config; this is the
# module passed to the fine-tuner.
model = PatchedTimeSeriesDecoder(tfm._model_config)

checkpoint_path = path.join(snapshot_download(repo_id), 'torch_model.ckpt')
# map_location='cpu' matches backend='cpu' above and keeps the load working
# on machines without CUDA regardless of the device the tensors were saved from.
loaded_checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
# Last expression of the cell: displays the load_state_dict result.
model.load_state_dict(loaded_checkpoint)

Fetching 5 files:   0%|          | 0/5 [00:00<?, ?it/s]

Fetching 5 files:   0%|          | 0/5 [00:00<?, ?it/s]

<All keys matched successfully>

In [12]:
# Fine-tuning options: optimisation settings come from the notebook-level
# arguments, logging / validation cadence and loss choice are fixed here.
config = FinetuningConfig(
    freq_type=args.freq_type,
    learning_rate=args.learning_rate,
    num_epochs=args.epochs,
    batch_size=args.batch_size,
    use_wandb=True,
    use_quantile_loss=False,
    val_check_interval=0.2,
    log_every_n_steps=10,
)

# Build both splits from the same arguments. The generator is consumed in
# order, so the 'train' split is constructed before the 'val' split.
train_dataset, val_dataset = (
    UnivariateMethaneHourly(args, flag=split) for split in ('train', 'val')
)

finetuner = TimesFMFinetuner(model, config)
# Last expression of the cell, so the value returned by finetune() is displayed.
finetuner.finetune(
    train_dataset=train_dataset,
    val_dataset=val_dataset,
)

100%|██████████| 48024/48024 [00:00<00:00, 72042.17it/s] 
100%|██████████| 9789/9789 [00:00<00:00, 64537.59it/s]


{'history': {'train_loss': [], 'val_loss': [], 'learning_rate': []}}

In [None]:
# Load the raw series from the same location the argument parser points at.
# reset_index() turns the default row index into an extra leading column,
# which is the column renamed to 'unique_id' below.
raw_df = pd.read_csv(
    path.join(args.data_dir, args.data_file), parse_dates=True
).reset_index()

# Hold out the last 20% of rows for testing. .copy() makes the slice an
# independent frame, so the column rename and the 'ds' assignment below do not
# write into a slice of raw_df (SettingWithCopyWarning / chained assignment).
test_df = raw_df.iloc[int(len(raw_df) * 0.8):].copy()
test_df.columns = ['unique_id', 'ds', 'values']

test_df['ds'] = pd.to_datetime(test_df['ds'])

# NOTE(review): 'unique_id' holds the former row index, so it differs on every
# row. Presumably forecast_on_df forecasts per unique_id; if so, each row is
# treated as its own one-point series - confirm this is intended.
# NOTE(review): `tfm` is the wrapper created from the pretrained checkpoint;
# the fine-tuned weights live in `model`. Confirm which one should be evaluated.
forecast_df = tfm.forecast_on_df(
    inputs=test_df,
    freq='1H'
)

# Average all point forecasts that land on the same timestamp.
forecast_df = (
    forecast_df[['ds', 'timesfm']]
    .groupby('ds', as_index=False)
    .mean()
)

fig = go.Figure()

fig.add_trace(go.Scatter(x=forecast_df['ds'], y=forecast_df['timesfm'], mode='lines+markers', name='Forecast'))
fig.add_trace(go.Scatter(x=test_df['ds'], y=test_df['values'], mode='lines+markers', name='True'))

fig.write_html("./timesfm.html")