In [None]:
from pathlib import Path

from fusiontimeseries.lib.get_next_path import get_next_path

# Each run gets its own fresh sub-directory under ./results (name chosen by
# get_next_path); creation must not silently reuse an existing run directory.
base_dir = Path("./results")
base_dir.mkdir(exist_ok=True, parents=True)

output_dir = get_next_path(base_dir=base_dir, base_fname="chronos2-rss-bilinear")
output_dir.mkdir(exist_ok=False, parents=True)
print("Output directory created at: " + str(output_dir))

In [None]:
import torch

import random

import numpy as np

# Seed the RNGs consumed before training starts: the LoRA factors and the condition
# embedding are randomly initialised (kaiming_uniform) when the model is built, and the
# train/val split and noise augmentation may also draw random numbers. Without this the
# notebook is not reproducible across Restart & Run All.
SEED: int = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU generator and all CUDA devices

# Release cached, unused blocks held by PyTorch's CUDA caching allocator.
torch.cuda.empty_cache()

In [None]:
from fusiontimeseries.lib.config import FTSConfig

# Central experiment configuration, grouped by concern (keyword order is irrelevant).
fts_config = FTSConfig(
    # operating-parameter conditioning
    num_ops=4,
    op_embedding_dim=512,
    # sequence windows
    context_length=512,
    pred_tail_timestamps=80,
    # data splitting, sampling and augmentation
    val_size=0.1,
    stratification="opc_pca",
    stratification_bins=5,
    sampling_strategy="linear",
    sampling_bins=5,
    data_augmentation="white_noise",
    # padding conventions for model inputs (value for padded slots; mask values
    # for padded vs. observed positions)
    padding_value=torch.nan,
    padding_mask_default=0.0,
    padding_mask_indicator=1.0,
    # optimisation and schedule
    batch_size=128,
    learning_rate=1e-4,
    optimizer_type="adamw_torch_fused",
    lr_scheduler_type="linear",
    lr_scheduler_warmup_ratio=0.0,
    max_grad_norm=1.0,
    gradient_accumulation_steps=1,
    max_steps=4000,
    eval_steps=200,
)

In [None]:
from chronos import Chronos2Model

# Pretrained Chronos-2 backbone, moved to the configured device, with its context
# window capped to the fine-tuning context length.
model = Chronos2Model.from_pretrained("amazon/chronos-2").to(fts_config.device)  # type: ignore
model.chronos_config.context_length = fts_config.context_length
model.device

In [None]:
from fusiontimeseries.modules import ContinuousConditionEmbed

# One condition embedding shared by every RSSBilinearLoRA layer; it turns the
# `num_ops` operating parameters into the `op_embedding_dim`-sized representation
# the adapters consume.
shared_p_projection = ContinuousConditionEmbed(
    n_cond=fts_config.num_ops,
    embedding_dim=fts_config.op_embedding_dim,
    init_weights="kaiming_uniform",
    max_wavelength=10000,
)

In [None]:
import math
from torch import nn
import torch.nn.functional as F

from fusiontimeseries.lib.conditioning import ConditionRegistry
from fusiontimeseries.loralib.layers import LoRALayer, layer_dict
from fusiontimeseries.loralib.utils import expand_like

OP_PARAM_KEY: str = "op_params"  # ConditionRegistry key holding the operating parameters


class RSSBilinearLoRA(nn.Linear, LoRALayer):
    """Linear layer with a LoRA update modulated by operating parameters.

    The frozen base projection ``W x + b`` is augmented with a low-rank term whose
    rank-space activations are scaled (FiLM-style) and shifted by quantities derived
    from the operating parameters ``p`` registered under ``OP_PARAM_KEY``::

        h     = A x                          # (B, ..., r)
        p_rep = shared_projection(p)         # (B, p_dim)
        c     = C p_rep,  s = S p_rep        # (B, r) each
        g     = sigmoid(||h||)               # (B, ..., 1), one gate per position
        delta = scale * B (h * (1 + c) + g * s)

    ``shared_projection`` is the class-level ``_shared_p_projection`` shared by all
    instances; a forward pass with an active adapter raises RuntimeError if it or the
    operating parameters are missing.

    Args:
        in_features: input size of the wrapped linear layer.
        out_features: output size of the wrapped linear layer.
        r: LoRA rank; ``r == 0`` disables the adapter (no LoRA parameters exist).
        lora_alpha: scaling numerator, ``scale = lora_alpha / r``.
        lora_dropout: dropout rate for the adapter input (applied via ``self.lora_dropout``).
        fan_in_fan_out: store the base weight transposed.
        merge_weights: merging folds only ``B A`` into the weight and therefore drops
            the conditioning terms; keep ``False``.
        post_layer_norm: apply LayerNorm to the output when the adapter is active.
        pre_batch_norm: apply a non-affine BatchNorm to the adapter input.
        **kwargs: forwarded to ``nn.Linear``.
    """

    # Class-level shared projection (set once for all instances)
    _shared_p_projection: ContinuousConditionEmbed | None = None

    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        fan_in_fan_out: bool = False,
        merge_weights: bool = False,  # keep false due to conditioning
        post_layer_norm: bool = False,
        pre_batch_norm: bool = False,
        **kwargs,
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoRALayer.__init__(
            self,
            r=r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            merge_weights=merge_weights,
        )

        self.fan_in_fan_out = fan_in_fan_out
        self.in_features = in_features
        self.out_features = out_features
        self.post_layer_norm = post_layer_norm
        self.pre_batch_norm = pre_batch_norm

        self.p_dim = fts_config.op_embedding_dim  # conditioning dimension

        # `lora_scale` is set inside _init_lora only for r > 0 (avoids dividing by zero).
        self._init_lora(r)
        self.reset_parameters()
        if fan_in_fan_out:
            self.weight.data = self.weight.data.transpose(0, 1)
        if self.post_layer_norm:
            self.post_ln = nn.LayerNorm(out_features)
            self.merge_weights = False
        if self.pre_batch_norm:
            self.pre_bn = nn.BatchNorm1d(in_features, affine=False)
            self.merge_weights = False

    def _init_lora(self, r):
        """(Re)create the LoRA factors for rank ``r`` (zero-filled) and freeze the base weight."""
        if r > 0:
            self.lora_A = nn.Parameter(self.weight.new_zeros((r, self.in_features)))
            self.lora_B = nn.Parameter(self.weight.new_zeros((self.out_features, r)))
            self.lora_C = nn.Parameter(self.weight.new_zeros((r, self.p_dim)))
            self.lora_S = nn.Parameter(self.weight.new_zeros((r, self.p_dim)))
            self.lora_scale = self.lora_alpha / r
        else:
            # Ensure LoRA state does not exist for rank 0. Each name is checked on its
            # own so one missing attribute does not leave the others behind.
            for name in ("lora_A", "lora_B", "lora_C", "lora_S", "lora_scale"):
                if hasattr(self, name):
                    delattr(self, name)
        self.weight.requires_grad = False
        self.r = r

    def _reset_lora_parameters(self):
        """Initialise LoRA factors: A and C Kaiming-uniform, B and S zero (so delta == 0)."""
        # TODO(review): consider initialising via PCA on the pretrained weights.
        if hasattr(self, "lora_C"):
            nn.init.kaiming_uniform_(self.lora_C, a=math.sqrt(5))
        if hasattr(self, "lora_S"):
            nn.init.zeros_(self.lora_S)
        if hasattr(self, "lora_A"):
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        if hasattr(self, "lora_B"):
            nn.init.zeros_(self.lora_B)

    def reset_parameters(self):
        """Reset the base linear weights (nn.Linear default) and the LoRA factors."""
        nn.Linear.reset_parameters(self)
        self._reset_lora_parameters()

    def change_lora_rank(self, new_rank):
        """Re-create the LoRA factors at ``new_rank`` with a fresh initialisation.

        Only the adapter parameters are re-initialised; the base weight is untouched.
        Without the re-initialisation every factor would stay at zero, so ``h`` and
        ``s`` are zero and neither A nor B ever receives a gradient.
        """
        if new_rank != self.r:
            self._init_lora(new_rank)
            self._reset_lora_parameters()

    def train(self, mode: bool = True):
        """Switch train/eval mode, (un)merging ``B A`` into the weight if merging is enabled.

        Merging only folds the unconditioned ``B A`` term; the condition-dependent
        FiLM/shift terms cannot be merged, which is why ``merge_weights`` stays False.
        """

        def T(w):
            return w.transpose(0, 1) if self.fan_in_fan_out else w

        nn.Linear.train(self, mode)
        if mode:
            if self.merge_weights and self.merged:
                # Make sure that the weights are not merged
                if self.r > 0:
                    self.weight.data -= T(self.lora_B @ self.lora_A) * self.lora_scale
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                # Merge the weights and mark it
                if self.r > 0:
                    self.weight.data += T(self.lora_B @ self.lora_A) * self.lora_scale
                self.merged = True

    def forward(self, x: torch.Tensor):
        """
        Args:
            x (torch.Tensor): (B, ..., in_features)

        Returns:
            torch.Tensor: (B, ..., out_features)

        Raises:
            RuntimeError: if the adapter is active but no operating parameters are
                registered under OP_PARAM_KEY or ``_shared_p_projection`` is unset.
        """

        def T(w):
            return w.transpose(0, 1) if self.fan_in_fan_out else w

        # Base linear projection
        y_base = F.linear(x, T(self.weight), bias=self.bias)  # (B, ..., out_features)
        if self.pre_batch_norm:
            # BatchNorm1d normalises dim 1 of an (N, C) / (N, C, L) input; flatten all
            # leading dims so statistics are per input feature for any input rank.
            x = self.pre_bn(x.reshape(-1, x.shape[-1])).reshape(x.shape)

        if self.r > 0 and not self.merged:
            p: torch.Tensor | None = ConditionRegistry.get(OP_PARAM_KEY)
            if p is not None and self._shared_p_projection is not None:
                p = p.to(self.lora_C.device)

                # Learnable representation of p: (B, p_dim)
                p_repr: torch.Tensor = self._shared_p_projection(p)

                # condition projection: (B, r) = (B, p_dim) @ (r, p_dim).T
                c = F.linear(p_repr, self.lora_C)

                # rank_space_shift: (B, r) = (B, p_dim) @ (r, p_dim).T
                s = F.linear(p_repr, self.lora_S)

                # input projection: (B, ..., r) = ( (B, ..., in_features) @ (r, in_features).T )
                h = F.linear(self.lora_dropout(x), self.lora_A)

                # Shift gate: one scalar per position from the norm of h over all ranks.
                # Since the norm is >= 0 the gate lies in [0.5, 1); it grows with the
                # overall magnitude of h, not per rank direction.
                s_gate = torch.sigmoid(h.norm(dim=-1, keepdim=True))

                # Bilinear FiLM-inspired modulation in rank space
                # (B, ..., r)
                h_mod = h * (
                    1.0 + expand_like(target=c, like=h)
                ) + s_gate * expand_like(target=s, like=h)

                # Project back to output space
                # (B, ..., out_features) = ( (B, ..., r) @ (out_features, r).T )
                delta_y = F.linear(h_mod, self.lora_B) * self.lora_scale
                y_base += delta_y

            else:
                raise RuntimeError(
                    "Operating parameters not found or _shared_p_projection is None during forward pass."
                )

            if self.post_layer_norm:
                y_base = self.post_ln(y_base)

        return y_base


# Give every RSSBilinearLoRA instance access to the single shared condition embedding
# (read in forward via the class attribute).
RSSBilinearLoRA._shared_p_projection = shared_p_projection

# Make the layer available under the name passed as `kind` to `convert`
# (presumably looked up in layer_dict — confirm in loralib).
layer_dict["RSSBilinearLoRA"] = RSSBilinearLoRA

In [None]:
# Attention projections and the output layer that receive conditioned LoRA adapters.
RSS_TARGET_MODULES = [
    "self_attention.q",
    "self_attention.k",
    "self_attention.v",
    "self_attention.o",
    "output_patch_embedding.output_layer",
]
model = RSSBilinearLoRA.convert(
    module=model,
    kind="RSSBilinearLoRA",
    lora_rank=8,
    lora_alpha=16,
    target_module_names=RSS_TARGET_MODULES,
)

# Attaching the shared projection as a submodule puts its weights into
# model.state_dict(); Module.to() works in place, so the class-level reference
# held by RSSBilinearLoRA sees the moved weights too.
model.shared_condition_projection = shared_p_projection
model.shared_condition_projection.to(fts_config.device)

In [None]:
from fusiontimeseries.loralib.utils import mark_only_lora_as_trainable


# Presumably freezes every parameter except the LoRA factors; bias="none" keeps biases frozen.
# NOTE(review): confirm whether `model.shared_condition_projection` (no "lora_" in its
# parameter names) is frozen as well — if so the condition embedding keeps its random init.
mark_only_lora_as_trainable(bias="none", model=model)

In [None]:
from fusiontimeseries.loralib.utils import print_trainable_parameters


# Record which parameters are trainable next to the run outputs.
trainable_params_path = output_dir / "trainable_params.json"
print_trainable_parameters(model, save_path=trainable_params_path)

In [None]:
from fusiontimeseries.finetuning.chronos2.dataset import Chronos2Dataset

# Split the training corpus according to the val_size / stratification settings in fts_config.
dataset_splits = Chronos2Dataset.train_val_split(fts_config)
train_dataset, val_dataset = dataset_splits

In [None]:
from transformers.training_args import TrainingArguments

training_arguments = TrainingArguments(
    # I/O and logging
    output_dir=str(output_dir),
    report_to="none",
    disable_tqdm=False,
    logging_strategy="steps",
    logging_steps=fts_config.eval_steps,
    # hardware / precision
    use_cpu=False,
    tf32=False,
    bf16=False,
    dataloader_num_workers=0,
    # batching and training length
    per_device_train_batch_size=fts_config.batch_size,
    per_device_eval_batch_size=fts_config.batch_size,
    gradient_accumulation_steps=fts_config.gradient_accumulation_steps,
    max_steps=fts_config.max_steps,
    # optimiser and schedule
    optim=fts_config.optimizer_type,
    learning_rate=fts_config.learning_rate,
    lr_scheduler_type=fts_config.lr_scheduler_type,
    warmup_ratio=fts_config.lr_scheduler_warmup_ratio,
    max_grad_norm=fts_config.max_grad_norm,
    # evaluation, checkpointing and best-model selection
    eval_strategy="steps",
    eval_steps=fts_config.eval_steps,
    prediction_loss_only=True,
    save_strategy="steps",
    save_steps=fts_config.eval_steps,
    save_total_limit=2,
    save_only_model=True,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    # batch layout: "future_target" holds the labels, and the extra
    # "operating_parameters" key must not be filtered out before compute_loss
    label_names=["future_target"],
    remove_unused_columns=False,
)
# NOTE(review): overrides a private attribute to force a single device, presumably so
# Trainer does not wrap the model in DataParallel on multi-GPU hosts — confirm.
training_arguments._n_gpu = 1

In [None]:
from transformers import Trainer

from chronos.chronos2.model import Chronos2Model, Chronos2Output


class ConditionedTrainer(Trainer):
    """Trainer that routes per-batch operating parameters to the conditioned LoRA layers.

    Batches carry an extra ``"operating_parameters"`` tensor that the model's forward
    does not accept. It is removed from the inputs and exposed to the layers through
    ``ConditionRegistry`` for the duration of the forward pass.
    """

    def compute_loss(
        self,
        model: Chronos2Model,
        inputs: dict[str, torch.Tensor],
        *args,
        return_outputs=False,
        **kwargs,
    ):
        """Compute the model loss with the batch's operating parameters registered.

        Raises:
            KeyError: if the batch has no ``"operating_parameters"`` entry.
            ValueError: if the model output carries no loss (e.g. labels missing).
        """
        # Tensor[B, N]; removed before forward, otherwise the model call raises TypeError
        p_raw: torch.Tensor | None = inputs.pop("operating_parameters", None)
        # Explicit check rather than `assert`, which is stripped under `python -O`
        # and would let a None reach the LoRA layers with a less clear error.
        if p_raw is None:
            raise KeyError("operating_parameters key is missing in inputs")

        with ConditionRegistry.patch(op_params=p_raw):
            outputs: Chronos2Output = model(**inputs)

        loss = outputs.loss if hasattr(outputs, "loss") else outputs[0]
        if loss is None:
            raise ValueError(
                "Model returned no loss; check that 'future_target' is present in the batch."
            )
        return (loss, outputs) if return_outputs else loss

In [None]:
import json

trainer = ConditionedTrainer(
    args=training_arguments,
    model=model,
    eval_dataset=val_dataset,
    train_dataset=train_dataset,
)

# Snapshot the effective training arguments next to the run outputs.
training_args_path = output_dir / "training_args.json"
with open(training_args_path, "w") as args_file:
    json.dump(trainer.args.to_dict(), args_file, indent=4)

In [None]:
import json

train_output = trainer.train()

# Persist the TrainOutput namedtuple (as a dict) for later inspection.
summary_path = output_dir / "train_summary.json"
with open(summary_path, "w") as summary_file:
    json.dump(train_output._asdict(), summary_file, indent=4)

In [None]:
from fusiontimeseries.loralib.utils import lora_state_dict

# Adapter weights as returned by lora_state_dict.
lora_weights = lora_state_dict(model)
torch.save(lora_weights, output_dir / "lora_weights.pt")

# The shared condition embedding is required to reproduce the adapted model, but its
# parameter names contain no "lora_" prefix, so it is presumably not part of
# `lora_weights` (confirm against lora_state_dict). Save it explicitly as its own file.
torch.save(
    model.shared_condition_projection.state_dict(),
    output_dir / "shared_condition_projection.pt",
)

In [None]:
# Inference mode for the autoregressive benchmark rollouts below.
model = model.eval()
benchmark_data = Chronos2Dataset.get_benchmark_flux_traces(fts_config)

In [None]:
from fusiontimeseries.lib.benchmarking import rmse_with_standard_error
from fusiontimeseries.lib.dataset import FluxData
import numpy as np

START_IDX: int = 80  # number of ground-truth steps used to seed each rollout

id_benchmark_data = benchmark_data["id"]
id_benchmark_forecasts: dict[int, list[float]] = {}
for flux_id, flux_data in id_benchmark_data.items():
    flux_data: FluxData
    energy_flux = np.array(flux_data.energy_flux)
    op_params = (
        torch.Tensor(flux_data.operating_parameters).unsqueeze(0).to(fts_config.device)
    )

    # Autoregressive rollout: start from the first START_IDX true values and keep
    # appending the median forecast until the original series length is reached.
    ctx: np.ndarray = energy_flux[:START_IDX]
    while len(ctx) < len(energy_flux):
        # The model input holds at most `context_length` steps. Once the rollout grows
        # past that, only the most recent window can be fed; otherwise the slice
        # assignment below raises a shape mismatch.
        window: np.ndarray = ctx[-fts_config.context_length :]
        with torch.no_grad():
            tctx = torch.full(
                size=(1, fts_config.context_length), fill_value=fts_config.padding_value
            )  # NaN
            tctx[0, -len(window) :] = torch.tensor(window)
            context_mask = torch.full_like(
                tctx, fill_value=fts_config.padding_mask_default
            )  # 0.0
            context_mask[0, -len(window) :] = fts_config.padding_mask_indicator  # 1.0

            with ConditionRegistry.patch(op_params=op_params):
                output: Chronos2Output = model(
                    context=tctx.to(fts_config.device),
                    context_mask=context_mask.to(fts_config.device),
                )
        # Without new forecast values `ctx` never grows and the loop would spin forever.
        if output.quantile_preds is None:
            raise RuntimeError(
                f"Model returned no quantile predictions for ID flux {flux_id}."
            )
        quantiles: torch.Tensor = output.quantile_preds  # (B, Qs=21, pred_len)
        # Middle quantile index; assumes quantile levels are symmetric around the median.
        median_quantile: int = quantiles.shape[1] // 2
        forecast: np.ndarray = (
            quantiles[:, median_quantile, :].cpu().numpy().flatten()
        )  # (pred_len,)
        if forecast.size == 0:
            raise RuntimeError(f"Model returned an empty forecast for ID flux {flux_id}.")
        ctx = np.concatenate([ctx, forecast])

    # Trim the overshoot of the last forecast chunk to the original length.
    id_benchmark_forecasts[flux_id] = ctx[: len(energy_flux)].tolist()


# Benchmark metric: per-series mean over the final `pred_tail_timestamps` steps.
# Averaging only the tail keeps the START_IDX ground-truth seed out of the forecast
# mean (a `[:-tail]` slice would average everything *except* the tail, seed included).
tail_len: int = fts_config.pred_tail_timestamps
id_benchmark_means: list[np.floating] = [
    np.mean(flux_data.energy_flux[-tail_len:]) for flux_data in id_benchmark_data.values()
]
id_forecast_means: list[np.floating] = [
    np.mean(forecast[-tail_len:]) for forecast in id_benchmark_forecasts.values()
]
id_rmse, id_rmse_se = rmse_with_standard_error(
    y_true=np.array(id_benchmark_means), y_pred=np.array(id_forecast_means)
)
id_rmse, id_rmse_se, id_benchmark_means, id_forecast_means

In [None]:
ood_benchmark_data = benchmark_data["ood"]
ood_benchmark_forecasts: dict[int, list[float]] = {}
for flux_id, flux_data in ood_benchmark_data.items():
    flux_data: FluxData
    energy_flux = np.array(flux_data.energy_flux)
    op_params = (
        torch.Tensor(flux_data.operating_parameters).unsqueeze(0).to(fts_config.device)
    )

    # Autoregressive rollout seeded with the first START_IDX true values.
    ctx: np.ndarray = energy_flux[:START_IDX]
    while len(ctx) < len(energy_flux):
        # Feed at most the last `context_length` steps so the slice assignment below
        # cannot hit a shape mismatch once the rollout outgrows the window.
        window: np.ndarray = ctx[-fts_config.context_length :]
        with torch.no_grad():
            tctx = torch.full(
                size=(1, fts_config.context_length), fill_value=fts_config.padding_value
            )  # NaN
            tctx[0, -len(window) :] = torch.tensor(window)
            context_mask = torch.full_like(
                tctx, fill_value=fts_config.padding_mask_default
            )  # 0.0
            context_mask[0, -len(window) :] = fts_config.padding_mask_indicator  # 1.0

            with ConditionRegistry.patch(op_params=op_params):
                output: Chronos2Output = model(
                    context=tctx.to(fts_config.device),
                    context_mask=context_mask.to(fts_config.device),
                )
        # Without new forecast values `ctx` never grows and the loop would spin forever.
        if output.quantile_preds is None:
            raise RuntimeError(
                f"Model returned no quantile predictions for OOD flux {flux_id}."
            )
        quantiles: torch.Tensor = output.quantile_preds  # (B, Qs=21, pred_len)
        # Middle quantile index; assumes quantile levels are symmetric around the median.
        median_quantile: int = quantiles.shape[1] // 2
        forecast: np.ndarray = (
            quantiles[:, median_quantile, :].cpu().numpy().flatten()
        )  # (pred_len,)
        if forecast.size == 0:
            raise RuntimeError(f"Model returned an empty forecast for OOD flux {flux_id}.")
        ctx = np.concatenate([ctx, forecast])

    # Trim the overshoot of the last forecast chunk to the original length.
    ood_benchmark_forecasts[flux_id] = ctx[: len(energy_flux)].tolist()

# Benchmark metric: per-series mean over the final `pred_tail_timestamps` steps, so
# the START_IDX ground-truth seed does not enter the forecast mean.
tail_len: int = fts_config.pred_tail_timestamps
ood_benchmark_means: list[np.floating] = [
    np.mean(flux_data.energy_flux[-tail_len:])
    for flux_data in ood_benchmark_data.values()
]
ood_forecast_means: list[np.floating] = [
    np.mean(forecast[-tail_len:]) for forecast in ood_benchmark_forecasts.values()
]
ood_rmse, ood_rmse_se = rmse_with_standard_error(
    y_true=np.array(ood_benchmark_means), y_pred=np.array(ood_forecast_means)
)
ood_rmse, ood_rmse_se

In [None]:
# Save the ID/OOD RMSE (with standard error), the per-series mean values that enter
# the RMSE, and the full per-series forecasts to `output_dir` as JSON.
import json

results = {
    "metrics": {
        "id": {
            "rmse": float(id_rmse),
            "standard_error": float(id_rmse_se),
        },
        "ood": {
            "rmse": float(ood_rmse),
            "standard_error": float(ood_rmse_se),
        },
    },
    # Per-series means keyed by flux id. The mean lists were built by iterating the
    # benchmark dicts (and the forecast dicts filled in that same order), so zipping
    # them with the benchmark keys aligns each mean with its series.
    "means": {
        "id": {
            flux_id: {"benchmark": float(true_mean), "forecast": float(pred_mean)}
            for flux_id, true_mean, pred_mean in zip(
                id_benchmark_data.keys(), id_benchmark_means, id_forecast_means
            )
        },
        "ood": {
            flux_id: {"benchmark": float(true_mean), "forecast": float(pred_mean)}
            for flux_id, true_mean, pred_mean in zip(
                ood_benchmark_data.keys(), ood_benchmark_means, ood_forecast_means
            )
        },
    },
    "forecasts": {
        "id": dict(id_benchmark_forecasts),
        "ood": dict(ood_benchmark_forecasts),
    },
}

with open(output_dir / "benchmark_results.json", "w") as f:
    json.dump(results, f, indent=4)

In [None]:
from matplotlib import pyplot as plt

id_flux_ids = list(id_benchmark_data.keys())
n_cols = 2
# Size the grid from the number of series: a fixed 3x2 grid raises IndexError for
# more than six series and leaves empty panels for fewer.
n_rows = max(1, -(-len(id_flux_ids) // n_cols))  # ceil division
fig, axes = plt.subplots(
    n_rows, n_cols, figsize=(3 * 3, 4 * n_rows / 3), sharex="col", squeeze=False
)
axes = axes.flatten()
for ax, flux_id in zip(axes, id_flux_ids):
    flux_data = id_benchmark_data[flux_id]
    energy_flux = np.array(flux_data.energy_flux)

    ax.plot(energy_flux, label="ground truth")
    ax.plot(id_benchmark_forecasts[flux_id], label="forecast")
    ax.text(
        0.5,
        0.9,
        f"OPs: {flux_data.operating_parameters}",
        horizontalalignment="center",
        verticalalignment="center",
        transform=ax.transAxes,
        fontsize=7,
    )
    ax.set_title(f"Flux ID: {flux_id}", fontdict={"fontsize": 5})
    ax.set_xlabel("Time Step", fontdict={"fontsize": 5})
    ax.set_ylabel("Energy Flux", fontdict={"fontsize": 5})
for unused_ax in axes[len(id_flux_ids) :]:
    unused_ax.set_visible(False)

# With labelled lines, a bare fig.legend() would collect one entry per panel;
# reuse the first panel's handles for a single shared legend.
handles, labels = axes[0].get_legend_handles_labels()
fig.legend(handles, labels)
# Save only after the legend is attached so it appears in the written PNG.
fig.savefig(output_dir / "id_benchmark_forecasts.png")
fig.show()

In [None]:
from matplotlib import pyplot as plt

ood_flux_ids = list(ood_benchmark_data.keys())
n_cols = 2
# Size the grid from the number of series: a fixed 3x2 grid raises IndexError for
# more than six series and leaves empty panels for fewer.
n_rows = max(1, -(-len(ood_flux_ids) // n_cols))  # ceil division
fig, axes = plt.subplots(
    n_rows, n_cols, figsize=(3 * 3, 4 * n_rows / 3), sharex="col", squeeze=False
)
axes = axes.flatten()
for ax, flux_id in zip(axes, ood_flux_ids):
    flux_data = ood_benchmark_data[flux_id]
    energy_flux = np.array(flux_data.energy_flux)

    ax.plot(energy_flux, label="ground truth")
    ax.plot(ood_benchmark_forecasts[flux_id], label="forecast")
    ax.text(
        0.5,
        0.9,
        f"OPs: {flux_data.operating_parameters}",
        horizontalalignment="center",
        verticalalignment="center",
        transform=ax.transAxes,
        fontsize=7,
    )
    ax.set_title(f"Flux ID: {flux_id}", fontdict={"fontsize": 5})
    ax.set_xlabel("Time Step", fontdict={"fontsize": 5})
    ax.set_ylabel("Energy Flux", fontdict={"fontsize": 5})
for unused_ax in axes[len(ood_flux_ids) :]:
    unused_ax.set_visible(False)

# With labelled lines, a bare fig.legend() would collect one entry per panel;
# reuse the first panel's handles for a single shared legend.
handles, labels = axes[0].get_legend_handles_labels()
fig.legend(handles, labels)
# Save only after the legend is attached so it appears in the written PNG.
fig.savefig(output_dir / "ood_benchmark_forecasts.png")
fig.show()