In [1]:
from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.metrics import RMSE
from pytorch_lightning import Trainer
import torch

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
import glob
from concurrent.futures import ThreadPoolExecutor

In [2]:
# Load the feature-engineered dataset from its parquet file into a DataFrame.
df3 = pd.read_parquet(path="/home/data/processed/FeatureEng.parquet")

In [3]:
# Expose the VM identifier under the `vm_id` name that the training cell uses as
# the TimeSeriesDataSet group id. This has to happen BEFORE the dropna below:
# dropna returns a new DataFrame, so a column added to df3 afterwards would
# never reach tft_df.
df3['vm_id'] = df3['VM']

# ✅ Drop NA values for all required target variables
tft_df = df3.dropna(
    subset=[
        'cpu_utilization_ratio',
        'memory_utilization_ratio',
        'disk_total_throughput',
        'network_total_throughput',
    ]
)

In [None]:
# Forecast targets. Only CPU utilisation is enabled here; the other candidates
# present in the data are memory_utilization_ratio, disk_total_throughput and
# network_total_throughput.
targets = [
    'cpu_utilization_ratio',
]

# Covariates whose future values are known in advance: the time index plus the
# cyclical (sin/cos) encodings of hour, day of week and month.
time_varying_known_reals = [
    'time_idx',
    'hour_sin',
    'hour_cos',
    'dayofweek_sin',
    'dayofweek_cos',
    'month_sin',
    'month_cos',
]

In [None]:
import os
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import shutil

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.metrics import RMSE

# 🔧 Single source of truth for the training run — edit values here only.
train_config = dict(
    # Data / dataset definition
    targets=['cpu_utilization_ratio'],  # extend with further target columns if needed
    time_varying_known_reals=[
        'time_idx',
        'hour_sin',
        'hour_cos',
        'dayofweek_sin',
        'dayofweek_cos',
        'month_sin',
        'month_cos',
    ],
    group_ids=['vm_id'],
    max_encoder_length=8,
    max_prediction_length=2,
    # Model / optimisation hyper-parameters
    hidden_size=4,
    dropout=0.1,
    learning_rate=0.03,
    batch_size=4,
    epochs=1,
    loss_fn=RMSE(),
    # Output locations and hardware selection
    output_base_dir="/home/output",
    log_dir="/home/output/logs",
    accelerator_opt="cpu",
)

# 🚀 Run for each target
# For every configured target this loop:
#   1. removes non-finite target rows and splits off the first 80% of the
#      time range as the training frame,
#   2. builds the training / validation TimeSeriesDataSets and dataloaders,
#   3. resumes a Temporal Fusion Transformer from the last checkpoint in the
#      run directory, or creates a new one,
#   4. trains, predicts on the validation loader, and writes plot.png,
#      predictions.csv, loss_log.csv, params.json and notes.txt to the run dir.
for target in train_config["targets"]:
    print(f"\n🔁 Training for target: {target}")

    print(f"💾 Torch using device: {'cuda' if torch.cuda.is_available() else 'cpu'}")
    print(f"🧠 Total rows in full tft_df: {len(tft_df)}")

    run_dir = os.path.join(train_config["output_base_dir"], f"{target}_run")
    os.makedirs(run_dir, exist_ok=True)

    # np.isfinite is False for NaN as well as +/-inf, so one mask covers both.
    # The cleaned frame feeds BOTH the training and the validation dataset so
    # the "removed" message below is true for everything the model sees.
    finite_mask = np.isfinite(tft_df[target])
    clean_df = tft_df[finite_mask]
    print(f"⚠️ {int((~finite_mask).sum())} bad rows removed for target: {target}")

    # Training data: everything up to 80% of the largest time index.
    train_cutoff = tft_df['time_idx'].max() * 0.8
    train_df = clean_df[clean_df.time_idx <= train_cutoff].copy()

    # Dataset
    dataset = TimeSeriesDataSet(
        train_df,
        time_idx='time_idx',
        target=target,
        group_ids=train_config["group_ids"],
        max_encoder_length=train_config["max_encoder_length"],
        max_prediction_length=train_config["max_prediction_length"],
        time_varying_known_reals=train_config["time_varying_known_reals"],
        time_varying_unknown_reals=train_config["targets"],
        target_normalizer=GroupNormalizer(groups=train_config["group_ids"]),
        add_relative_time_idx=True,
        add_target_scales=True,
        add_encoder_length=True,
        allow_missing_timesteps=True
    )

    # Validation set reuses the training dataset's parameters; predict=True
    # is passed so the dataset is built in prediction mode.
    val_dataset = TimeSeriesDataSet.from_dataset(
        dataset,
        clean_df,
        predict=True,
        stop_randomization=True,
        allow_missing_timesteps=True,
    )

    train_dataloader = dataset.to_dataloader(train=True, batch_size=train_config["batch_size"], num_workers=0)
    val_dataloader = val_dataset.to_dataloader(train=False, batch_size=train_config["batch_size"], num_workers=0)

    # Logger and checkpoint
    logger = CSVLogger(save_dir=train_config["log_dir"], name=f"{target}_log")
    checkpoint_callback = ModelCheckpoint(
        monitor="val_loss",
        dirpath=run_dir,
        filename="tft-{epoch:02d}-{val_loss:.2f}",
        save_top_k=1,
        save_last=True,
        mode="min"
    )

    # Load or create model
    # save_last=True writes "<CHECKPOINT_NAME_LAST><FILE_EXTENSION>" (the
    # `filename` template above only applies to the top-k checkpoints), so the
    # resume path is derived from the callback rather than hard-coded.
    # NOTE(review): some Lightning releases version an already-existing last
    # checkpoint ("last-v1.ckpt") instead of overwriting it — confirm against
    # the installed version if repeated resumes are expected.
    ckpt_path = os.path.join(
        run_dir,
        f"{checkpoint_callback.CHECKPOINT_NAME_LAST}{checkpoint_callback.FILE_EXTENSION}",
    )
    if os.path.exists(ckpt_path):
        print(f"📦 Resuming from checkpoint: {ckpt_path}")
        # Only hyper-parameter overrides are passed here. The former `dataset=`
        # keyword is dropped because it does not appear to be a model
        # constructor argument (the dataset parameters are expected to come
        # from the checkpoint's saved hyper-parameters) — TODO confirm against
        # the installed pytorch-forecasting version.
        model = TemporalFusionTransformer.load_from_checkpoint(
            checkpoint_path=ckpt_path,
            loss=train_config["loss_fn"],
        )
    else:
        print("🆕 Starting new model")
        model = TemporalFusionTransformer.from_dataset(
            dataset,
            learning_rate=train_config["learning_rate"],
            hidden_size=train_config["hidden_size"],
            dropout=train_config["dropout"],
            loss=train_config["loss_fn"],
            log_interval=10,
            reduce_on_plateau_patience=4,
        )

    # Trainer
    trainer = Trainer(
        max_epochs=train_config["epochs"],
        # Read the accelerator from the config (previously a literal list
        # ["accelerator_opt"] was passed, which is not a valid accelerator).
        accelerator=train_config["accelerator_opt"],
        # One device of whichever accelerator is configured; an explicit count
        # avoids passing None, which not every Lightning release accepts.
        devices=1,
        logger=logger,
        callbacks=[checkpoint_callback],
        enable_checkpointing=True
    )

    # Train
    trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)

    # Predict
    predictions, x = model.predict(val_dataloader, mode='raw', return_x=True)
    # Forecast of the first series in the validation batch only.
    forecast = predictions['prediction'][0].detach().cpu().numpy()

    # Plot and save
    # Draw on an explicitly created axes so the title/save/close all act on the
    # same figure; a separate plt.figure() would be left open on every iteration.
    fig, ax = plt.subplots(figsize=(10, 6))
    model.plot_prediction(x, predictions, idx=0, show_future_observed=True, ax=ax)
    ax.set_title(f"Prediction Plot for {target}")
    fig.savefig(f"{run_dir}/plot.png")
    plt.close(fig)

    # Save predictions
    pd.DataFrame(forecast, columns=[f'{target}_forecast']).to_csv(f"{run_dir}/predictions.csv", index=False)

    # Save loss log
    log_csv_path = os.path.join(logger.log_dir, "metrics.csv")
    if os.path.exists(log_csv_path):
        shutil.copy(log_csv_path, f"{run_dir}/loss_log.csv")

    # Save parameters
    # default=str lets non-JSON values in the config (the loss metric object)
    # be written as their string form instead of raising TypeError.
    with open(f"{run_dir}/params.json", "w") as f:
        json.dump(train_config, f, indent=2, default=str)

    # Save notes with spikes info
    spikes = forecast > np.percentile(forecast, 95)
    with open(f"{run_dir}/notes.txt", "w") as f:
        f.write(f"Target: {target}\n")
        f.write(f"Spikes > 95th percentile: {int(spikes.sum())}\n")
        f.write("Review plot.png and predictions.csv for further insights.\n")

    print(f"✅ Run complete — outputs saved at: {run_dir}")