In [1]:
import os
import random 


os.environ["NCCL_P2P_DISABLE"] = "1"
os.environ["NCCL_IB_DISABLE"] = "1"

from time import time
import math 
import tempfile 
import torch 
import pickle 
import logging 
import warnings
import json
import torch.nn as nn

import matplotlib.pyplot as plt
from glob import glob
from tqdm import tqdm
from time import time
import numpy as np 
import pandas as pd
from sklearn.metrics import mean_squared_error

import argparse


from transformers import Trainer, TrainingArguments, set_seed, EarlyStoppingCallback
from torch.utils.data import ConcatDataset, Dataset, DataLoader


from tsfm_public.models.tinytimemixer.configuration_tinytimemixer import TinyTimeMixerConfig
# from tinytimemixer.modeling_tinytimemixer import TinyTimeMixerForPrediction
from tsfm_public.models.tinytimemixer import TinyTimeMixerForPrediction
from tsfm_public.toolkit.dataset import PretrainDFDataset, ForecastDFDataset
from tsfm_public.toolkit.time_series_preprocessor import TimeSeriesPreprocessor
from tsfm_public.toolkit.util import select_by_index

warnings.filterwarnings("ignore")
SEED = 42
set_seed(SEED)


In [2]:

# metrics used for evaluation
def cal_cvrmse(pred, true, eps=1e-8):
    pred = np.array(pred)
    true = np.array(true)
    return np.power(np.square(pred - true).sum() / pred.shape[0], 0.5) / (true.sum() / pred.shape[0] + eps)

def cal_mae(pred, true):
    pred = np.array(pred)
    true = np.array(true)
    return np.mean(np.abs(pred - true))

def cal_nrmse(pred, true, eps=1e-8):
    true = np.array(true)
    pred = np.array(pred)

    M = len(true) // 24
    y_bar = np.mean(true)
    NRMSE = 100 * (1/ (y_bar+eps)) * np.sqrt((1 / (24 * M)) * np.sum((true - pred) ** 2))
    return NRMSE


In [3]:


def standardize_series(series, eps=1e-8):
    mean = np.mean(series)
    std = np.std(series)
    standardized_series = (series - mean) / (std+eps)
    return standardized_series, mean, std

def unscale_predictions(predictions, mean, std, eps=1e-8):
    return predictions * (std+eps) + mean


class TimeSeriesDataset(Dataset):
    def __init__(self, data, backcast_length, forecast_length, stride=1):
        # Standardize the time series data
        self.data, self.mean, self.std = standardize_series(data)
        self.backcast_length = backcast_length
        self.forecast_length = forecast_length
        self.stride = stride

    def __len__(self):
        return (len(self.data) - self.backcast_length - self.forecast_length) // self.stride + 1

    def __getitem__(self, index):
        start_index = index * self.stride
        x = self.data[start_index : start_index + self.backcast_length]
        y = self.data[start_index + self.backcast_length : start_index + self.backcast_length + self.forecast_length]
        return torch.tensor(x, dtype=torch.float32), torch.tensor(y, dtype=torch.float32)


In [4]:
def model_config(args):

    config = TinyTimeMixerConfig(
        context_length=args["context_length"],
        patch_length=args["patch_length"],
        num_input_channels=args["num_input_channels"],
        patch_stride=args["patch_stride"],
        d_model=args["d_model"],
        num_layers=args["num_layers"],
        expansion_factor=args["expansion_factor"],
        dropout=args["dropout"],
        head_dropout=args["head_dropout"],
        mode=args["mode"][0],
        scaling=args["scaling"],
        prediction_length=args["prediction_length"],
        is_scaling=args["is_scaling"],
        gated_attn=args["gated_attn"],
        norm_mlp=args["norm_mlp"],
        self_attn=args["self_attn"],
        self_attn_heads=args["self_attn_heads"],
        use_positional_encoding=args["use_positional_encoding"],
        positional_encoding_type=args["positional_encoding_type"],
        loss=args["loss"],
        init_std=args["init_std"],
        post_init=args["post_init"],
        norm_eps=args["norm_eps"],
        adaptive_patching_levels=args["adaptive_patching_levels"],
        resolution_prefix_tuning=args["resolution_prefix_tuning"],
        frequency_token_vocab_size=args["frequency_token_vocab_size"],
        distribution_output=args["distribution_output"],
        num_parallel_samples=args["num_parallel_samples"],
        decoder_num_layers=args["decoder_num_layers"],
        decoder_d_model=args["decoder_d_model"],
        decoder_adaptive_patching_levels=args["decoder_adaptive_patching_levels"],
        decoder_raw_residual=args["decoder_raw_residual"],
        decoder_mode=args["decoder_mode"],
        use_decoder=args["use_decoder"],
        enable_forecast_channel_mixing=args["enable_forecast_channel_mixing"],
        fcm_gated_attn=args["fcm_gated_attn"],
        fcm_context_length=args["fcm_context_length"],
        fcm_use_mixer=args["fcm_use_mixer"],
        fcm_mix_layers=args["fcm_mix_layers"],
        fcm_prepend_past=args["fcm_prepend_past"], 
        init_linear=args["init_linear"],
        init_embed=args["init_embed"],

    )

    pretraining_model = TinyTimeMixerForPrediction(config)
    return pretraining_model

In [5]:
def test(
    args,
    model,
    criterion,
    dataset_path,
    result_path,
    device,
    target_buildings="BR02"   # can be "BR02", ["BR02","BR03"], or "all"
):

    os.makedirs(result_path, exist_ok=True)

    res = []

    # Loop over parquet files
    for file_name in os.listdir(dataset_path):

        if not file_name.endswith(".parquet"):
            continue

        file_id = file_name.replace(".parquet", "")
        file_path = os.path.join(dataset_path, file_name)

        print(f"Testing file: {file_id}")

        # Load parquet
        df = pd.read_parquet(file_path)

        # ---------------------------------------------------
        # Determine target buildings
        # ---------------------------------------------------
        if target_buildings == "all":
            buildings_to_test = list(df.columns)

        elif isinstance(target_buildings, str):
            buildings_to_test = [target_buildings]

        elif isinstance(target_buildings, list):
            buildings_to_test = target_buildings

        else:
            raise ValueError("target_buildings must be 'all', a string, or list")

        # ---------------------------------------------------
        # Test each building column
        # ---------------------------------------------------
        for building_col in buildings_to_test:

            if building_col not in df.columns:
                print(f" {building_col} not found in {file_id}, skipping...")
                continue

            print(f"\n   ▶ Building: {building_col}")

            # Extract series
            energy_data = df[building_col].values.astype(np.float32)

            # ---------------------------------------------------
            # Fill NaNs with Median
            # ---------------------------------------------------
            nan_count = np.isnan(energy_data).sum()

            if nan_count > 0:
                median_val = np.nanmedian(energy_data)
                energy_data = np.where(
                    np.isnan(energy_data),
                    median_val,
                    energy_data
                )

                print(f"  Filled {nan_count} NaNs with median={median_val:.4f}")

            # ---------------------------------------------------
            # Check minimum length
            # ---------------------------------------------------
            min_required = args["context_length"] + args["prediction_length"]

            if len(energy_data) < min_required:
                print("   Too short, skipping...")
                continue

            # Dataset creation
            dataset = TimeSeriesDataset(
                energy_data,
                args["context_length"],
                args["prediction_length"],
                args["patch_stride"]
            )

            if len(dataset) == 0:
                print("   No samples, skipping...")
                continue

            model.eval()

            val_losses = []
            y_true_test = []
            y_pred_test = []

            # ---------------------------------------------------
            # Testing loop
            # ---------------------------------------------------
            for x_test, y_test in DataLoader(dataset, batch_size=1):

                x_test = x_test.unsqueeze(-1).to(device)
                y_test = y_test.to(device)

                with torch.no_grad():
                    output = model(x_test)
                    forecast = output.prediction_outputs.squeeze(-1)

                    loss = criterion(forecast, y_test)

                    if torch.isnan(loss):
                        continue

                    val_losses.append(loss.item())

                    y_true_test.append(y_test.cpu().numpy())
                    y_pred_test.append(forecast.cpu().numpy())

            # ---------------------------------------------------
            #  Skip empty results
            # ---------------------------------------------------
            if len(y_true_test) == 0:
                print("   No predictions collected, skipping...")
                continue

            # Combine predictions
            y_true = np.concatenate(y_true_test, axis=0)
            y_pred = np.concatenate(y_pred_test, axis=0)

            # Unscale
            y_pred_unscaled = unscale_predictions(y_pred, dataset.mean, dataset.std)
            y_true_unscaled = unscale_predictions(y_true, dataset.mean, dataset.std)

            # ---------------------------------------------------
            #  Metrics
            # ---------------------------------------------------
            cvrmse = cal_cvrmse(y_pred_unscaled, y_true_unscaled)
            nrmse  = cal_nrmse(y_pred_unscaled, y_true_unscaled)
            mae    = cal_mae(y_pred_unscaled, y_true_unscaled)

            avg_loss = np.mean(val_losses)

            print(f"  CVRMSE={cvrmse:.4f}, NRMSE={nrmse:.4f}, MAE={mae:.4f}")

            # Save row
            res.append([
                file_id,
                building_col,
                cvrmse,
                nrmse,
                mae,
                avg_loss
            ])

    # ---------------------------------------------------
    # Save Results
    # ---------------------------------------------------
    columns = ["Dataset", "Building", "CVRMSE", "NRMSE", "MAE", "Avg_Test_Loss"]

    result_df = pd.DataFrame(res, columns=columns)

    result_csv = os.path.join(result_path, "test_results.csv")
    result_df.to_csv(result_csv, index=False)

    print("\n Testing complete!")
    print(" Results saved at:", result_csv)


In [8]:
config_file = '../Energy-TTM/config/tinyTimeMixers.json'
with open(config_file, 'r') as f:
    args = json.load(f)

# check device 
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'


# define TTMs model
model = model_config(args).to(device)
model.load_state_dict(torch.load('../Energy-TTM/Weights/energy_ttm.pth'))

# model's parameters
param = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Model's parameter count is:", param)

# Define loss and optimizer
criterion = torch.nn.MSELoss()

# print(args.result_path)
# training the model and save best parameters
test(args=args, model=model, criterion=criterion,dataset_path="../Dataset/Forecasting",result_path="test_results_zeroshot", device=device)




Model's parameter count is: 28858
Testing file: Bareilly-1H

   ▶ Building: BR02
  Filled 9984 NaNs with median=0.1660
  CVRMSE=0.0974, NRMSE=236.3778, MAE=0.0428
Testing file: Mathura-1H
 BR02 not found in Mathura-1H, skipping...

 Testing complete!
 Results saved at: test_results_zeroshot/test_results.csv
