# Create a Forecast plot given a stored model

### Notebook Configuration

In [None]:
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.
%config InlineBackend.figure_format = 'svg'
%load_ext autoreload
%autoreload 2
%matplotlib inline

import logging
import pickle
from copy import deepcopy
from pathlib import Path

import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import numpy as np
import pandas
import pandas as pd
import torch
import torchinfo
from pandas import DataFrame, Index, Series, Timedelta, Timestamp
from torch import Tensor, jit, tensor
from torch.utils.data import DataLoader

# Notebook-wide numerical and plotting setup.
rng = np.random.default_rng()  # no seed is passed
np.set_printoptions()
plt.style.use("bmh")
# Keep ticks and grid lines underneath the plotted artists.
plt.rcParams.update({"axes.axisbelow": True})
# logging.basicConfig(level=logging.INFO)

### Enter Path and Name

In [None]:
# SplitID indexes task.splits / task.samplers; KEY then selects one entry inside it.
Fold, Partition = 0, "test"
RunID, ExpID = 449, 15653
SplitID = (Fold, Partition)
KEY = (RunID, ExpID)

In [None]:
# files:  2021-11-12T00:52:10 2021-11-12T00:51:55 2021-11-12T00:51:48
# "checkpoints/2021-11-15T12:05:00/LinODEnet-0"
# "checkpoints/LinODEnet/KIWI_RUNS/skew_allways/2021-11-15T16:05:41/LinODEnet-0"
# "adam/2021-11-15T20:38:52/LinODEnet-0"

# Directory that holds the model checkpoint (and the pickled encoder loaded later).
PATH = Path("/home/rscholz/Projects/KIWI/tsdm/experiments/evaluation/checkpoints/")
NAME = "RecursiveScriptModule-30"
# The model checkpoint; should be a zip archive created by torch.save / torch.jit.save.
MODEL_FILE = PATH.joinpath(NAME)

# Device / dtype shared by the constants below.
DEVICE = torch.device("cpu")
DTYPE = torch.float32
# Scalar NaN tensor used as the "missing value" marker.
NAN = torch.tensor(float("nan"), dtype=DTYPE, device=DEVICE)

# Horizons are given as strings and passed on to the task's sampler.
PRD_HORIZON = "2h"
OBS_HORIZON = "2h"
# HORIZON = "3h"

## Initialize the task

In [None]:
from tsdm.tasks import KiwiTask

# Both horizons are handed to the task through its sampler keyword arguments.
task = KiwiTask(
    sampler_kwargs={
        "forecasting_horizon": PRD_HORIZON,
        "observation_horizon": OBS_HORIZON,
    }
)

## Load the Model

In [None]:
# Load the TorchScript checkpoint onto the notebook-wide DEVICE instead of a
# second, hard-coded device, so the model and the NAN constant always agree.
model = torch.jit.load(MODEL_FILE, map_location=DEVICE)
# Last expression of the cell: the notebook displays the summary.
torchinfo.summary(model, depth=1)

## Load the Encoder

In [None]:
# NOTE: unpickling can execute arbitrary code -- only load encoder files that
# come from a trusted source.
with (PATH / "encoder.pickle").open(mode="rb") as file:
    encoder = pickle.load(file)
# Last expression of the cell: display the loaded encoder.
encoder

## Import Task

In [None]:
from tsdm.tasks import TimeSeriesSampleGenerator

# Select the split and the matching sampler for the configured fold/partition and run.
split = task.splits[SplitID][KEY]
TS = split.timeseries
# encoder = task.encoders[SplitID]
sampler = task.samplers[SplitID][KEY]
sampler.shuffle = False  # presumably makes the sampler yield keys in a fixed order

generator = TimeSeriesSampleGenerator(
    split, observables=task.observables, targets=task.targets, covariates=task.covariates
)

# Materialize every sample the sampler yields (trailing ';' suppresses the cell output).
samples = [generator[idx] for idx in sampler];

In [None]:
# Encode the first sample, run the model on it and decode the prediction.
sample = samples[0]
tx, x = encoder.encode(sample.inputs.x).values()
ty, y = encoder.encode(sample.targets.y).values()
with torch.no_grad():  # inference only, no autograd graph needed
    yhat = model(tx, x)
# Map the model output back to the original data representation.
reconstructed = encoder.decode({"T": tx, "X": yhat})

In [None]:
# Display the index of the decoded model output.
reconstructed.index  # .loc[: task.observation_horizon]

### Helper Function for Batch post-processing

In [None]:
def prep_batch(
    batch: tuple[Tensor, Tensor],
) -> tuple[Tensor, Tensor, Tensor, Tensor]:
    """Split a batch into model inputs, forecasting targets and untouched originals.

    Everything from ``task.observation_horizon`` onwards along the
    second-to-last axis is treated as the forecasting window: target and
    observable channels inside that window are replaced by ``NAN`` in the
    model inputs, while the target channels of that window are returned
    separately.

    Relies on the module-level ``task`` and ``NAN``.

    Parameters
    ----------
    batch
        Pair ``(T, X)``. Assumes the second-to-last axis of ``X`` is time and the
        last axis is the channel -- TODO confirm against the data loader.

    Returns
    -------
    tuple
        ``(T, inputs, targets, originals)`` where ``inputs`` is the masked copy
        of ``X``, ``targets`` holds the target channels of the forecasting
        window and ``originals`` is an unmodified copy of ``X``.
    """
    T, X = batch
    forecast_window = slice(task.observation_horizon, None)

    targets = X[..., forecast_window, task.targets.index].clone()
    originals = X.clone()

    # Index with ``...`` exactly like the target extraction above, so that the
    # masking is not restricted to inputs with a single leading batch axis.
    inputs = X.clone()
    inputs[..., forecast_window, task.targets.index] = NAN
    inputs[..., forecast_window, task.observables.index] = NAN
    return T, inputs, targets, originals

## Helper function to create the plot

In [None]:
def make_plot(axes, task, sample):
    """Overlay the model output for one sample onto the given axes.

    The sample's inputs are encoded, passed through the module-level ``model``
    and decoded again with the module-level ``encoder``. For every target
    channel the following is drawn on the corresponding axis:

    - a grey span over the observation window ``[t0, t1]``,
    - a green span over the forecasting window ``[t1, t2]``,
    - the decoded model output as a dotted yellow line inside the observation
      window and as a solid red line inside the forecasting window.

    Parameters
    ----------
    axes
        Array of matplotlib axes (must provide ``.flatten()``); axes and
        ``task.targets`` are paired in order, surplus entries are ignored.
    task
        Provides ``observation_horizon``, ``forecasting_horizon`` and ``targets``.
    sample
        Sample whose ``inputs.x`` is fed to the encoder.
    """
    tx, x = encoder.encode(sample.inputs.x).values()
    with torch.no_grad():  # inference only, no autograd graph needed
        yhat = model(tx, x)
    reconstructed = encoder.decode({"T": tx, "X": yhat})

    time = reconstructed.index.to_series()
    # Positional access: ``time[0]`` on a non-integer index relies on the
    # deprecated positional fallback of ``Series.__getitem__``.
    t0 = time.iloc[0]
    t1 = t0 + pd.Timedelta(task.observation_horizon)
    t2 = t1 + pd.Timedelta(task.forecasting_horizon)

    # Loop-invariant values, computed once for all axes.
    t0_np, t1_np, t2_np = t0.to_numpy(), t1.to_numpy(), t2.to_numpy()
    observed_time = time.loc[t0:t1]
    forecast_time = time.loc[t1:t2]

    for ax, target in zip(axes.flatten(), task.targets):
        ax.axvspan(t0_np, t1_np, facecolor="grey", alpha=0.3)
        ax.axvspan(t1_np, t2_np, facecolor="green", alpha=0.3)
        ax.plot(
            observed_time, reconstructed.loc[t0:t1, target], ls=":", lw=2, color="y"
        )
        ax.plot(
            forecast_time, reconstructed.loc[t1:t2, target], ls="-", lw=2, color="r"
        )

### Create the Raw data plot

In [None]:
%matplotlib widget

# 3x3 grid of panels sharing the time axis; one panel per target (at most nine are drawn).
fig, axes = plt.subplots(
    nrows=3,
    ncols=3,
    sharex=True,
    figsize=(16, 8),
    constrained_layout=True,
)

# The time stamps are the same for every channel, so fetch them once.
times = TS.index.values
for ax, target in zip(axes.flatten(), task.targets):
    data = TS[target]
    mask = ~np.isnan(data)  # keep only the observed (non-NaN) values
    ax.plot(times[mask], data[mask], ls=":", lw=0.5, marker=".", ms=6)
    ax.legend([f"{target} - observations"])

# Trailing ';' suppresses the cell output.
fig.suptitle(f"{Fold=}  {Partition=}  {RunID=}  {ExpID=}");

### Add the model Forecast Plots

In [None]:
# Three sample positions spread evenly between the first and the last sample.
grid = np.linspace(start=0, stop=len(samples) - 1, num=3, dtype=int)
# grid = [0]
batch = [samples[idx] for idx in grid]

# Draw the model output for each selected sample on top of the raw-data figure.
for sample in batch:
    make_plot(axes, task, sample)

# '/' in the checkpoint name is replaced so the result is a single file name.
fig.savefig(f"{NAME.replace(r'/', r'_')}_{PRD_HORIZON}.pdf")

In [None]:
# A bare `raise` with no active exception raises RuntimeError. Presumably placed
# here on purpose to stop a "run all cells" execution before the cells below.
raise

In [None]:
from torch.optim import AdamW

In [None]:
# Optimizer description: class name, defining module and its keyword arguments.
optimizer_config = {
    "__name__": "AdamW",
    "__module__": "torch.optim",
    "lr": 0.001,
    "betas": (0.9, 0.999),
}
hparams = {"optimizer": optimizer_config}

In [None]:
import yaml

In [None]:
# Serialize the hyperparameters with the safe YAML dumper and write them to disk.
with open("hparams.yaml", mode="w") as file:
    file.write(yaml.safe_dump(hparams))