In [1]:
import torch
import numpy as np
import argparse
import os
import sys
import time
import datetime
from ts2vec import TS2Vec
import tasks
import datautils
from tasks import _eval_protocols as eval_protocols
from utils import init_dl_program, name_with_datetime, pkl_save, data_dropout

In [2]:
def save_checkpoint_callback(
    save_every=1,
    unit='epoch'
):
    """Build a training callback that periodically checkpoints the model.

    Args:
        save_every: A checkpoint is written whenever the progress counter is
            a multiple of this value.
        unit: Which counter to watch; 'epoch' reads ``model.n_epochs`` and
            'iter' reads ``model.n_iters``.

    Returns:
        A ``callback(model, loss)`` function. ``loss`` is accepted but not
        used. Checkpoints are written with ``model.save`` to
        ``{run_dir}/model_{n}.pkl``.

    Note:
        ``run_dir`` is looked up as a global when the callback fires, so it
        must be defined before training starts.
    """
    assert unit in ('epoch', 'iter')
    counter_attr = 'n_epochs' if unit == 'epoch' else 'n_iters'

    def callback(model, loss):
        step = getattr(model, counter_attr)
        if step % save_every != 0:
            return
        model.save(f'{run_dir}/model_{step}.pkl')

    return callback

In [3]:
# --- Run identification -----------------------------------------------------
dataset = 'electricity'        # passed to the CSV loader and used in the run directory name
run_name = 'forecast_multivar'  # combined with a timestamp to name the run directory
loader = 'forecast_csv'

# --- Hardware / reproducibility --------------------------------------------
gpu = 0
seed = 42
max_threads = 30

# --- Model / optimisation hyper-parameters ---------------------------------
batch_size = 8
lr = 0.001
repr_dims = 320            # forwarded to the model as `output_dims`
max_train_length = 3000

# --- Training schedule (None leaves the choice to `model.fit`) --------------
iters = None
epochs = None
save_every = None          # when not None, a periodic checkpoint callback is registered

# --- Miscellaneous flags ----------------------------------------------------
# NOTE(review): `eval` shadows the Python builtin of the same name for the rest
# of the notebook; the name is kept as-is so existing references keep working.
eval = True
irregular = 0

In [4]:
# Initialise the device / RNG state from the configuration cell instead of
# repeating the literals, so changing `gpu` or `seed` there actually takes effect.
device = init_dl_program(gpu, seed=seed, max_threads=max_threads)

print('Loading data... ')
task_type = 'forecasting'
data, train_slice, valid_slice, test_slice, scaler, pred_lens, n_covariate_cols = datautils.load_forecast_csv(dataset)
# Encoder training uses every column of the training slice.
train_data = data[:, train_slice]
print("Train data shape:", train_data.shape)

# Output directory for this run. Defined before the checkpoint callback is
# registered because that callback reads the global `run_dir` when it fires.
run_dir = 'training/' + dataset + '__' + name_with_datetime(run_name)
os.makedirs(run_dir, exist_ok=True)

# Keyword arguments forwarded to the TS2Vec constructor.
config = dict(
    batch_size=batch_size,
    lr=lr,
    output_dims=repr_dims,
    max_train_length=max_train_length
)

# Optional periodic checkpointing, counted in epochs when `epochs` is set and
# in iterations otherwise.
if save_every is not None:
    unit = 'epoch' if epochs is not None else 'iter'
    config[f'after_{unit}_callback'] = save_checkpoint_callback(save_every, unit)

# Start of the training-time measurement; the training cell subtracts from this.
t = time.time()

model = TS2Vec(
    input_dims=train_data.shape[-1],
    device=device,
    **config
)

Loading data... 
Train data shape: (321, 15782, 8)


In [5]:
# Train the encoder; `epochs` / `iters` come from the configuration cell.
loss_log = model.fit(train_data, n_epochs=epochs, n_iters=iters, verbose=True)

# Persist the final weights inside this run's directory.
model.save(f'{run_dir}/model.pkl')

# `t` held the start timestamp set before the model was built; from here on it
# holds the elapsed wall-clock seconds.
t = time.time() - t
print(f"\nTraining time: {datetime.timedelta(seconds=t)}\n")

Epoch #0: loss=1.4719260054826737
Epoch #1: loss=0.550676117092371
Epoch #2: loss=0.40362279050052163

Training time: 0:02:14.329828



In [6]:
# Overwrite the in-memory model weights with a checkpoint stored on disk.
# NOTE(review): the path is hard-coded to one specific timestamped run rather
# than the current `run_dir`, so a fresh Restart & Run All will load that older
# checkpoint (or fail if it is missing) instead of the model just trained.
model.load('training/electricity__forecast_multivar_20241024_132036/model.pkl')

  state_dict = torch.load(fn, map_location=self.device)


In [6]:
# Passed to the encoder as `sliding_padding`, and later used as the number of
# leading training samples dropped when building the regression set.
padding = 200

# Time the sliding-window encoding of the whole series.
t = time.time()
all_repr = model.encode(data, causal=True, sliding_length=1, sliding_padding=padding, batch_size=256)
ts2vec_infer_time = time.time() - t

# Cache the representations on disk so they can be reloaded without re-encoding.
# NOTE(review): hard-coded to a specific timestamped run directory, not `run_dir`.
with open('training/electricity__forecast_multivar_20241024_132036/all_repr.npy', 'wb') as f:
    np.save(f, all_repr)

100%|██████████| 26304/26304 [11:16<00:00, 38.91it/s]
100%|██████████| 26304/26304 [04:20<00:00, 101.13it/s]


In [17]:
# Reload the cached representations from the same file the encoding cell writes.
# NOTE(review): this cell's execution count is out of sequence with its
# neighbours; it appears to be a shortcut for skipping the slow encoding step.
# `ts2vec_infer_time` is not set by this cell, so it must already exist if the
# encoding cell is skipped.
all_repr = np.load('training/electricity__forecast_multivar_20241024_132036/all_repr.npy')

In [7]:
# Cut the encoded representations into the train / validation / test ranges.
train_repr, valid_repr, test_repr = (
    all_repr[:, part] for part in (train_slice, valid_slice, test_slice)
)

# Regression targets: the same ranges of `data`, skipping the first
# `n_covariate_cols` columns of the last axis.
# Note that `train_data` is re-bound here and no longer contains those columns.
train_data, valid_data, test_data = (
    data[:, part, n_covariate_cols:] for part in (train_slice, valid_slice, test_slice)
)

In [8]:
def generate_pred_samples(features, data, pred_len, drop=0):
    """Pair each representation with the next ``pred_len`` steps of ``data``.

    Args:
        features: Array whose second axis lines up with the second axis of
            ``data`` (presumably (instance, time, repr_dim) -- confirm with caller).
        data: 3-D array; axis 1 is the axis that is windowed.
        pred_len: Number of future steps collected per sample.
        drop: Number of leading positions along axis 1 to discard from both
            outputs.

    Returns:
        A ``(flat_features, flat_labels)`` tuple of 2-D arrays.
        ``flat_features`` keeps the last feature axis. Each ``flat_labels``
        row is the ``pred_len`` steps that follow the matching feature
        position, with the last axis of ``data`` flattened into it.
    """
    series_len = data.shape[1]

    # One shifted view of the series per horizon step.
    shifted_views = []
    for offset in range(pred_len):
        shifted_views.append(data[:, offset:1 + series_len + offset - pred_len])

    # Stacking gives (instance, position, horizon, channel). Dropping the first
    # position makes the window at position p start one step after p.
    targets = np.stack(shifted_views, axis=2)[:, 1:]
    targets = targets[:, drop:]

    # The last `pred_len` positions have no complete future window.
    inputs = features[:, :-pred_len]
    inputs = inputs[:, drop:]

    flat_inputs = inputs.reshape(-1, inputs.shape[-1])
    flat_targets = targets.reshape(-1, targets.shape[2] * targets.shape[3])
    return flat_inputs, flat_targets

def cal_metrics(pred, target):
    """Return mean squared and mean absolute error over all elements.

    Args:
        pred: Predicted values (array supporting ``-`` and ``.mean()``).
        target: Ground-truth values, broadcast-compatible with ``pred``.

    Returns:
        Dict with keys ``'MSE'`` and ``'MAE'``.
    """
    mse = ((pred - target) ** 2).mean()
    mae = np.abs(pred - target).mean()
    return {'MSE': mse, 'MAE': mae}

In [9]:
# Per-horizon results, keyed by prediction length.
ours_result = {}
lr_train_time = {}
lr_infer_time = {}
out_log = {}

# Only the first two entries of `pred_lens` are evaluated in this notebook.
for pred_len in pred_lens[0:2]:
    print("Predicting for length:", pred_len)
    # The first `padding` training positions are dropped; validation and test
    # keep every position.
    train_features, train_labels = generate_pred_samples(train_repr, train_data, pred_len, drop=padding)
    valid_features, valid_labels = generate_pred_samples(valid_repr, valid_data, pred_len)
    test_features, test_labels = generate_pred_samples(test_repr, test_data, pred_len)

    print("Fitting Ridge Regression")
    t = time.time()
    # Deliberately NOT named `lr`: that name holds the learning rate from the
    # configuration cell, and rebinding it here would feed a regressor object
    # into the model config if the setup cell were re-run afterwards.
    ridge_model = eval_protocols.fit_ridge(train_features, train_labels, valid_features, valid_labels)
    lr_train_time[pred_len] = time.time() - t

    print("Predicting with Ridge")
    t = time.time()
    test_pred = ridge_model.predict(test_features)
    lr_infer_time[pred_len] = time.time() - t

    # Undo the flattening done by generate_pred_samples:
    # (instances, positions, horizon, channels).
    ori_shape = test_data.shape[0], -1, pred_len, test_data.shape[2]
    test_pred = test_pred.reshape(ori_shape)
    test_labels = test_labels.reshape(ori_shape)

    # Predictions and ground truth as produced, without applying `scaler`.
    out_log[pred_len] = {
        'norm': test_pred,
        'norm_gt': test_labels,
    }
    ours_result[pred_len] = {
        'norm': cal_metrics(test_pred, test_labels),
    }

# Metrics and timings gathered into one summary dict.
eval_res = {
    'ours': ours_result,
    'ts2vec_infer_time': ts2vec_infer_time,
    'lr_train_time': lr_train_time,
    'lr_infer_time': lr_infer_time
}

Predicting for length: 24
Fitting Ridge Regression


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Predicting with Ridge
Predicting for length: 48
Fitting Ridge Regression


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Predicting with Ridge


In [10]:
# Saving is currently disabled; only the summary is printed.
# NOTE(review): the original disabled line passed `out`, which no visible cell
# defines; the evaluation loop stores its predictions in `out_log`.
# pkl_save(f'{run_dir}/out.pkl', out_log)
# pkl_save(f'{run_dir}/eval_res.pkl', eval_res)
print('Evaluation result:', eval_res)

Evaluation result: {'ours': {24: {'norm': {'MSE': 0.28812696385408204, 'MAE': 0.3758995626158909}}, 48: {'norm': {'MSE': 0.31168743772722285, 'MAE': 0.3922654431728807}}}, 'ts2vec_infer_time': 938.0712258815765, 'lr_train_time': {24: 16.403483629226685, 48: 16.90790867805481}, 'lr_infer_time': {24: 0.5094263553619385, 48: 0.8072500228881836}}


In [11]:
# Display only the per-horizon error metrics, leaving out the timing entries.
eval_res['ours']

{24: {'norm': {'MSE': 0.28812696385408204, 'MAE': 0.3758995626158909}},
 48: {'norm': {'MSE': 0.31168743772722285, 'MAE': 0.3922654431728807}}}