# Calculate Baseline Metrics

In [1]:
# Reload imported modules before each cell runs so that edits to the project's
# helper scripts are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
# Switch to the project checkout; presumably this is what lets the later
# `scripts.*` imports and the relative ./checkpoints and ./figures paths resolve.
# NOTE(review): absolute, machine-specific path -- adjust when running elsewhere.
import os
os.chdir("/home/shaksonisaac/CGM/mambatf/")

# Load datasets
import pandas as pd
import io
from google.cloud import storage

_BUCKET_NAME = "cgmproject2025"


def _read_feather_from_gcs(bucket, blob_name):
    """Download `blob_name` from `bucket` and parse it as a Feather DataFrame.

    The downloaded bytes only live for the duration of the call, so the raw
    buffer is not kept alive next to the parsed DataFrame.
    """
    data_bytes = bucket.blob(blob_name).download_as_bytes()
    return pd.read_feather(io.BytesIO(data_bytes))


# A single client/bucket handle serves both downloads.
client = storage.Client()
bucket = client.bucket(_BUCKET_NAME)

# Training split
train = _read_feather_from_gcs(bucket, 'ai-ready/data/train_finaltimeseries_meal.feather')

# Test split
test = _read_feather_from_gcs(bucket, 'ai-ready/data/test_finaltimeseries_meal.feather')

In [3]:
import numpy as np
import pandas as pd
import scipy.stats as stats

def smape(y_true, y_pred):
    """Symmetric mean absolute percentage error, expressed in percent.

    Each term is 2*|pred - true| / (|pred| + |true| + 1e-8); the small constant
    keeps the ratio finite when both values are zero.
    """
    abs_error = np.abs(y_pred - y_true)
    scale = np.abs(y_pred) + np.abs(y_true) + 1e-8
    ratios = 2 * abs_error / scale
    return 100 * np.mean(ratios)

def quantile_loss(y_true, y_pred, q=0.5):
    """Mean pinball (quantile) loss of `y_pred` against `y_true` at quantile `q`.

    Under-predictions are weighted by `q`, over-predictions by `1 - q`.
    """
    residual = y_true - y_pred
    under_penalty = q * residual
    over_penalty = (q - 1) * residual
    return np.mean(np.maximum(under_penalty, over_penalty))

def calculate_metrics(t_grouped, train_df):
    """Compute forecast error metrics separately for every participant.

    Parameters
    ----------
    t_grouped : pandas.DataFrame
        Long-format predictions with the columns "participant_id", "target"
        and "prediction".
    train_df : pandas.DataFrame
        Must contain "participant_id". Only the distinct participant ids are
        joined in, so the merge currently adds no columns; it is kept as the
        place to attach further static variables.

    Returns
    -------
    pandas.DataFrame
        One row per participant with the columns "participant_id", "SMAPE",
        "Quantile_Loss", "MAE" and "RMSE".
    """
    static_vars = train_df[["participant_id"]].drop_duplicates()
    t_grouped = t_grouped.merge(static_vars, on="participant_id", how="left")

    def _participant_metrics(df):
        # `df` holds the target/prediction rows of a single participant.
        error = df["prediction"] - df["target"]
        return pd.Series(
            {
                "SMAPE": smape(df["target"], df["prediction"]),
                "Quantile_Loss": quantile_loss(df["target"], df["prediction"]),
                "MAE": np.mean(np.abs(error)),
                "RMSE": np.sqrt(np.mean(np.square(error))),
            }
        )

    # Select the value columns explicitly so that `apply` never operates on the
    # grouping column; pandas deprecates applying over it and warns otherwise.
    metrics_df = (
        t_grouped.groupby("participant_id")[["target", "prediction"]]
        .apply(_participant_metrics)
        .reset_index()
    )
    return metrics_df

def get_confidence_intervals(df, metric, confidence=0.95):
    """Return (mean, lower, upper) of a two-sided Student-t interval for `df[metric]`.

    The half-width is the standard error of the mean times the t critical
    value with `len(df) - 1` degrees of freedom.
    """
    values = df[metric]
    center = values.mean()
    dof = len(df) - 1
    upper_tail_prob = (1 + confidence) / 2
    half_width = stats.sem(values) * stats.t.ppf(upper_tail_prob, dof)
    return center, center - half_width, center + half_width

def calculate_metrics_CI(records, train):
    """
    Summarise per-participant metrics as a mean with a confidence interval.

    Per-participant metrics come from `calculate_metrics(records, train)`; each
    of SMAPE, MAE, RMSE and Quantile_Loss is then reduced with
    `get_confidence_intervals`. Returns a DataFrame with the columns
    "Metric", "Mean", "Lower CI" and "Upper CI" (one row per metric).
    """
    per_participant = calculate_metrics(records, train)

    stat_names = ("mean", "lower", "upper")
    summary = {
        name: dict(zip(stat_names, get_confidence_intervals(per_participant, name)))
        for name in ("SMAPE", "MAE", "RMSE", "Quantile_Loss")
    }

    # Metrics become rows, statistics become columns.
    table = pd.DataFrame(summary).T.reset_index()
    table.columns = ["Metric", "Mean", "Lower CI", "Upper CI"]
    return table

In [4]:
import gc
def clear_memory():
    """Run Python's garbage collector, then empty PyTorch's CUDA cache."""
    gc.collect()
    # torch is imported lazily, only when the helper is actually called.
    from torch import cuda
    cuda.empty_cache()


clear_memory()

## Checkpointing Method

In [5]:
import os, re
from google.cloud import storage

def _metric_from_filename(filename, pattern):
    """Return the metric value captured by `pattern` in `filename`, or None if absent."""
    match = pattern.search(filename)
    return float(match.group(1)) if match else None


def download_best_ckpt_by_filename(local_dir: str, bucket_name: str, gcs_prefix: str,
                                   run_prefix: str = "tft576-", metric_key: str = "val_loss"):
    """
    Finds the .ckpt with the smallest {metric_key} in its filename under gcs_prefix,
    e.g., 'tft576-epoch=17-val_loss=3.61.ckpt', downloads it, and returns the local path.

    Parameters
    ----------
    local_dir : directory the checkpoint is downloaded into (created if missing).
    bucket_name : name of the GCS bucket to search.
    gcs_prefix : blob-name prefix (folder) to list.
    run_prefix : substring the checkpoint's base filename must contain.
    metric_key : metric whose `{metric_key}=<number>` value in the filename is minimised.

    Returns
    -------
    str : local path of the downloaded checkpoint.

    Raises
    ------
    FileNotFoundError : if no checkpoint under the prefix carries a parsable
        `{metric_key}=<number>` in its filename.
    """
    os.makedirs(local_dir, exist_ok=True)
    client = storage.Client()
    bucket = client.bucket(bucket_name)

    # Escape the key so regex metacharacters in it are matched literally, and
    # accept integers ("3"), decimals ("3.61") and exponents ("1e-05"); the
    # previous pattern required a decimal point and ranked "val_loss=3" as inf.
    rx = re.compile(
        rf"{re.escape(metric_key)}=([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)"
    )

    scored = []
    for blob in bucket.list_blobs(prefix=gcs_prefix):
        filename = os.path.basename(blob.name)
        if not blob.name.endswith(".ckpt") or "last.ckpt" in blob.name:
            continue
        if run_prefix not in filename:
            continue
        value = _metric_from_filename(filename, rx)
        # Files whose metric cannot be parsed (e.g. "val_loss=nan") are skipped
        # rather than being eligible to win as an arbitrary "best".
        if value is not None:
            scored.append((value, blob))

    if not scored:
        raise FileNotFoundError(f"No epoch checkpoints with {metric_key}=... under gs://{bucket_name}/{gcs_prefix}")

    # min() keeps the first-listed blob on ties, as before.
    _, best_blob = min(scored, key=lambda pair: pair[0])
    local_path = os.path.join(local_dir, os.path.basename(best_blob.name))
    best_blob.download_to_filename(local_path)
    print(f"Downloaded best checkpoint: gs://{bucket_name}/{best_blob.name} -> {local_path}")
    return local_path

# NHITS with checkpoint

In [9]:
# NHITS run stored under "checkpoints_nhits_288v4": fetch its best checkpoint,
# rebuild the dataloaders and load the model for evaluation.

# Imports used by this cell
import os
import sys
import torch

# Fetch best ckpt from GCS (use your actual bucket/prefix).
# The helper picks the checkpoint with the lowest val_loss encoded in its filename.
BUCKET = "cgmproject2025"
GCS_PREFIX = "checkpoints_nhits_288v4"   # <- change if your run used a different folder
best_ckpt_path = download_best_ckpt_by_filename("checkpoints", BUCKET, GCS_PREFIX,
                                                run_prefix="nhits_288-", metric_key="val_loss")

# Dataloader factory and model class come from the script module named after this run.
from scripts.NHITS288v4 import create_nhits_dataloaders, NHiTS, load_nhits_from_gcs

# Rebuild the datasets/dataloaders from `train` (same context_length, horizon, etc.)
training, val_dataloader, train_dataloader, validation = create_nhits_dataloaders(train, horizon=12, context_length=288, batchsize=32)

# Load model from the checkpoint onto the GPU when one is available
device = "cuda" if torch.cuda.is_available() else "cpu"
model = NHiTS.load_from_checkpoint(best_ckpt_path, map_location=device).to(device)
model.eval()  # Put model in evaluation mode; as the last expression it also displays the model

Downloaded best checkpoint: gs://cgmproject2025/checkpoints_nhits_288v4/nhits_288-epoch=04-val_loss=7.92.ckpt -> checkpoints/nhits_288-epoch=04-val_loss=7.92.ckpt
[2025-09-02 23:11:52.216827] 🚀 Start of Dataloader Creation
GPU Mem allocated: 0.00 GB | reserved: 0.01 GB


/home/shaksonisaac/miniconda3/envs/cgmall/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'loss' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['loss'])`.
/home/shaksonisaac/miniconda3/envs/cgmall/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'logging_metrics' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['logging_metrics'])`.


NHiTS(
  	"activation":                        ReLU
  	"backcast_loss_ratio":               0.0
  	"batch_normalization":               False
  	"categorical_groups":                {}
  	"context_length":                    288
  	"dataset_parameters":                {'time_idx': 'ds', 'target': 'cgm_glucose', 'group_ids': ['participant_id'], 'weight': None, 'max_encoder_length': 288, 'min_encoder_length': 288, 'min_prediction_idx': np.int64(11), 'min_prediction_length': 12, 'max_prediction_length': 12, 'static_categoricals': ['participant_id', 'clinical_site', 'study_group'], 'static_reals': ['age'], 'time_varying_known_categoricals': ['sleep_stage'], 'time_varying_known_reals': ['ds', 'minute_of_day', 'tod_sin', 'tod_cos', 'activity_steps', 'calories_value', 'heartrate', 'oxygen_saturation', 'respiration_rate', 'stress_level', 'predmeal_flag'], 'time_varying_unknown_categoricals': None, 'time_varying_unknown_reals': ['cgm_glucose', 'cgm_lag_1', 'cgm_lag_3', 'cgm_lag_6', 'cgm_diff_la

In [10]:
# Count the parameters of the NHITS model loaded in the previous cell
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters in NHITS model: {total_params}")

Total parameters in NHITS model: 301492


In [12]:
# Metrics for the loaded NHITS model on the windows served by `val_dataloader`,
# aggregated per participant and summarised with confidence intervals.
raw_preds = model.predict(val_dataloader, mode="raw", return_x=True, return_index=True)
print("Shape of raw_preds:", raw_preds.output["prediction"].shape)
# Point forecast with a trailing dimension of size 1 (see the printed shape);
# for a quantile output one would select the median here, e.g. [:, :, 1].
y_pred = raw_preds.output["prediction"]
y_true = raw_preds.x["decoder_target"]
index_df = raw_preds.index

# One long-format row per (prediction window, horizon step). Iterating over the
# index columns directly avoids two `iloc` row lookups per window.
horizon_steps = y_pred.shape[1]
records = pd.DataFrame([
    {
        "participant_id": uid,
        "ds": int(time_start + t),
        "target": float(y_true[i, t]),
        "prediction": float(y_pred[i, t]),
    }
    for i, (uid, time_start) in enumerate(zip(index_df["participant_id"], index_df["ds"]))
    for t in range(horizon_steps)
])

metricsCI = calculate_metrics_CI(records, train)
print(metricsCI)

# Save metrics locally; create the output folder first so a fresh checkout
# does not fail on the write.
os.makedirs("./figures", exist_ok=True)
metricsCI.to_csv("./figures/NHITS_288v4.csv", index=False)

💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Shape of raw_preds: torch.Size([741, 12, 1])
          Metric       Mean  Lower CI   Upper CI
0          SMAPE   6.404618  6.030764   6.778471
1            MAE   8.657538  8.060709   9.254367
2           RMSE  10.155661  9.484127  10.827196
3  Quantile_Loss   4.328769  4.030355   4.627183


  .apply(


## Best NHITS val loss

In [6]:
# NHITS run stored under "checkpoints_nhits_288v10": fetch its best checkpoint,
# rebuild the dataloaders and load the model for evaluation.

# Imports used by this cell
import os
import sys
import torch

# Fetch best ckpt from GCS (use your actual bucket/prefix).
# The helper picks the checkpoint with the lowest val_loss encoded in its filename.
BUCKET = "cgmproject2025"
GCS_PREFIX = "checkpoints_nhits_288v10"   # <- change if your run used a different folder
best_ckpt_path = download_best_ckpt_by_filename("checkpoints", BUCKET, GCS_PREFIX,
                                                run_prefix="nhits_288-", metric_key="val_loss")

# Dataloader factory and model class come from the script module named after this run.
from scripts.NHITS288v10 import create_nhits_dataloaders, NHiTS, load_nhits_from_gcs

# Rebuild the datasets/dataloaders from `train` (same context_length, horizon, etc.)
training, val_dataloader, train_dataloader, validation = create_nhits_dataloaders(train, horizon=12, context_length=288, batchsize=32)

# Load model from the checkpoint onto the GPU when one is available
device = "cuda" if torch.cuda.is_available() else "cpu"
model = NHiTS.load_from_checkpoint(best_ckpt_path, map_location=device).to(device)
model.eval()  # Put model in evaluation mode; as the last expression it also displays the model

Downloaded best checkpoint: gs://cgmproject2025/checkpoints_nhits_288v10/nhits_288-epoch=06-val_loss=7.75.ckpt -> checkpoints/nhits_288-epoch=06-val_loss=7.75.ckpt


  from .autonotebook import tqdm as notebook_tqdm


[GPU] NVIDIA A100-SXM4-40GB  CC=(8, 0), BF16=OK
[2025-09-03 02:52:43.019302] 🚀 Start of Dataloader Creation
GPU Mem allocated: 0.00 GB | reserved: 0.00 GB


/home/shaksonisaac/miniconda3/envs/cgmall/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'loss' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['loss'])`.
/home/shaksonisaac/miniconda3/envs/cgmall/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'logging_metrics' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['logging_metrics'])`.


NHiTS(
  	"activation":                        ReLU
  	"backcast_loss_ratio":               0.2
  	"batch_normalization":               False
  	"categorical_groups":                {}
  	"context_length":                    288
  	"dataset_parameters":                {'time_idx': 'ds', 'target': 'cgm_glucose', 'group_ids': ['participant_id'], 'weight': None, 'max_encoder_length': 288, 'min_encoder_length': 288, 'min_prediction_idx': np.int64(11), 'min_prediction_length': 12, 'max_prediction_length': 12, 'static_categoricals': ['participant_id', 'clinical_site', 'study_group'], 'static_reals': ['age'], 'time_varying_known_categoricals': ['sleep_stage'], 'time_varying_known_reals': ['ds', 'minute_of_day', 'tod_sin', 'tod_cos', 'activity_steps', 'calories_value', 'heartrate', 'oxygen_saturation', 'respiration_rate', 'stress_level', 'predmeal_flag'], 'time_varying_unknown_categoricals': None, 'time_varying_unknown_reals': ['cgm_glucose', 'cgm_lag_1', 'cgm_lag_3', 'cgm_lag_6', 'cgm_diff_la

In [7]:
# Count the parameters of the NHITS model loaded in the previous cell
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters in NHITS model: {total_params}")

Total parameters in NHITS model: 807098


In [8]:
# Metrics for the loaded NHITS model on the windows served by `val_dataloader`,
# aggregated per participant and summarised with confidence intervals.
raw_preds = model.predict(val_dataloader, mode="raw", return_x=True, return_index=True)
print("Shape of raw_preds:", raw_preds.output["prediction"].shape)
# Point forecast with a trailing dimension of size 1 (see the printed shape);
# for a quantile output one would select the median here, e.g. [:, :, 1].
y_pred = raw_preds.output["prediction"]
y_true = raw_preds.x["decoder_target"]
index_df = raw_preds.index

# One long-format row per (prediction window, horizon step). Iterating over the
# index columns directly avoids two `iloc` row lookups per window.
horizon_steps = y_pred.shape[1]
records = pd.DataFrame([
    {
        "participant_id": uid,
        "ds": int(time_start + t),
        "target": float(y_true[i, t]),
        "prediction": float(y_pred[i, t]),
    }
    for i, (uid, time_start) in enumerate(zip(index_df["participant_id"], index_df["ds"]))
    for t in range(horizon_steps)
])

metricsCI = calculate_metrics_CI(records, train)
print(metricsCI)

# Save metrics locally; create the output folder first so a fresh checkout
# does not fail on the write.
os.makedirs("./figures", exist_ok=True)
metricsCI.to_csv("./figures/NHITS_288best.csv", index=False)

💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/shaksonisaac/miniconda3/envs/cgmall/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Shape of raw_preds: torch.Size([741, 12, 1])
          Metric      Mean  Lower CI   Upper CI
0          SMAPE  6.311799  5.930420   6.693177
1            MAE  8.466096  7.914444   9.017749
2           RMSE  9.933878  9.313917  10.553839
3  Quantile_Loss  4.233048  3.957222   4.508874


  .apply(


# NHITS v2

In [10]:
# NHITS run stored under "checkpoints_nhits_288v2": fetch its best checkpoint,
# rebuild the dataloaders and load the model for evaluation.

# Imports used by this cell
import os
import sys
import torch

# Fetch best ckpt from GCS (use your actual bucket/prefix).
# The helper picks the checkpoint with the lowest val_loss encoded in its filename.
BUCKET = "cgmproject2025"
GCS_PREFIX = "checkpoints_nhits_288v2"   # <- change if your run used a different folder
best_ckpt_path = download_best_ckpt_by_filename("checkpoints", BUCKET, GCS_PREFIX,
                                                run_prefix="nhits_288-", metric_key="val_loss")

# Dataloader factory and model class come from the script module named after this run.
from scripts.NHITS288v2 import create_nhits_dataloaders, NHiTS, load_nhits_from_gcs

# Rebuild the datasets/dataloaders from `train` (same context_length, horizon, etc.)
training, val_dataloader, train_dataloader, validation = create_nhits_dataloaders(train, horizon=12, context_length=288, batchsize=32)

# Load model from the checkpoint onto the GPU when one is available
device = "cuda" if torch.cuda.is_available() else "cpu"
model = NHiTS.load_from_checkpoint(best_ckpt_path, map_location=device).to(device)
model.eval()  # Put model in evaluation mode; as the last expression it also displays the model

Downloaded best checkpoint: gs://cgmproject2025/checkpoints_nhits_288v2/nhits_288-epoch=01-val_loss=8.53.ckpt -> checkpoints/nhits_288-epoch=01-val_loss=8.53.ckpt
[2025-09-02 03:29:25.459498] 🚀 Start of Dataloader Creation
GPU Mem allocated: 0.00 GB | reserved: 0.03 GB


/home/shaksonisaac/miniconda3/envs/cgmall/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'loss' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['loss'])`.
/home/shaksonisaac/miniconda3/envs/cgmall/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'logging_metrics' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['logging_metrics'])`.


NHiTS(
  	"activation":                        ReLU
  	"backcast_loss_ratio":               0.0
  	"batch_normalization":               False
  	"categorical_groups":                {}
  	"context_length":                    288
  	"dataset_parameters":                {'time_idx': 'ds', 'target': 'cgm_glucose', 'group_ids': ['participant_id'], 'weight': None, 'max_encoder_length': 288, 'min_encoder_length': 288, 'min_prediction_idx': np.int64(11), 'min_prediction_length': 12, 'max_prediction_length': 12, 'static_categoricals': ['participant_id', 'clinical_site', 'study_group'], 'static_reals': ['age'], 'time_varying_known_categoricals': ['sleep_stage'], 'time_varying_known_reals': ['ds', 'minute_of_day', 'tod_sin', 'tod_cos', 'activity_steps', 'calories_value', 'heartrate', 'oxygen_saturation', 'respiration_rate', 'stress_level', 'predmeal_flag'], 'time_varying_unknown_categoricals': None, 'time_varying_unknown_reals': ['cgm_glucose', 'cgm_lag_1', 'cgm_lag_3', 'cgm_lag_6', 'cgm_diff_la

In [12]:
# Metrics for the loaded NHITS model on the windows served by `val_dataloader`,
# aggregated per participant and summarised with confidence intervals.
raw_preds = model.predict(val_dataloader, mode="raw", return_x=True, return_index=True)
print("Shape of raw_preds:", raw_preds.output["prediction"].shape)
# Point forecast with a trailing dimension of size 1 (see the printed shape);
# for a quantile output one would select the median here, e.g. [:, :, 1].
y_pred = raw_preds.output["prediction"]
y_true = raw_preds.x["decoder_target"]
index_df = raw_preds.index

# One long-format row per (prediction window, horizon step). Iterating over the
# index columns directly avoids two `iloc` row lookups per window.
horizon_steps = y_pred.shape[1]
records = pd.DataFrame([
    {
        "participant_id": uid,
        "ds": int(time_start + t),
        "target": float(y_true[i, t]),
        "prediction": float(y_pred[i, t]),
    }
    for i, (uid, time_start) in enumerate(zip(index_df["participant_id"], index_df["ds"]))
    for t in range(horizon_steps)
])

metricsCI = calculate_metrics_CI(records, train)
print(metricsCI)

# Save metrics locally; create the output folder first so a fresh checkout
# does not fail on the write.
os.makedirs("./figures", exist_ok=True)
metricsCI.to_csv("./figures/NHITS_288v2.csv", index=False)

💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Shape of raw_preds: torch.Size([741, 12, 1])
          Metric       Mean   Lower CI   Upper CI
0          SMAPE   6.769904   6.368118   7.171690
1            MAE   9.193850   8.573059   9.814642
2           RMSE  10.776200  10.082285  11.470115
3  Quantile_Loss   4.596925   4.286529   4.907321


  .apply(
