In [None]:
import tsdm

In [None]:
# Fetch the ODE-RNN model through tsdm.
# NOTE(review): presumably this populates the .tsdm/models/ODE-RNN folder that is
# imported with path_import further down — confirm against tsdm.load_model.
tsdm.load_model("ODE-RNN")

In [None]:
from tsdm.datasets import Electricity

# Load the Electricity dataset. Unpacking `.shape` into two names requires a 2-D table.
# NOTE(review): presumably rows = time stamps (see `x.index` use below) and
# columns = individual series — confirm.
x = Electricity.dataset
n_data, n_dim = x.shape

In [None]:
import importlib
import importlib.util
import sys
from contextlib import contextmanager
from pathlib import Path
from types import ModuleType
from typing import Iterator, Optional, Union


@contextmanager
def add_to_path(p: Union[str, Path]) -> Iterator[None]:
    """Temporarily prepend ``p`` to ``sys.path`` for the duration of a ``with`` block.

    Source: https://stackoverflow.com/a/41904558/9318372

    A shallow copy of ``sys.path`` is installed before inserting ``p``, so the
    original list object is never mutated. The original list is restored on
    exit, even if the body of the ``with`` block raises.

    Parameters
    ----------
    p: str or Path
        Directory to make importable while the context is active.

    Yields
    ------
    None
    """
    old_path = sys.path
    # Work on a copy so the original list object stays untouched.
    sys.path = sys.path[:]
    sys.path.insert(0, str(p))
    try:
        yield
    finally:
        sys.path = old_path


def path_import(module_path: Union[str, Path], module_name: Optional[str] = None) -> ModuleType:
    """Import a package directly from its folder, without installing it.

    Implementation taken from
    https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly
    Source: https://stackoverflow.com/a/41904558/9318372

    While the package's ``__init__.py`` executes, the folder itself is prepended
    to ``sys.path`` (via ``add_to_path``). Modules that live inside the folder
    can therefore be imported by their plain name during loading.

    Parameters
    ----------
    module_path: str or Path
        Path to the folder where the module is located. Must contain ``__init__.py``.
    module_name: str, optional
        Name given to the loaded module. Defaults to the folder name.

    Returns
    -------
    ModuleType
        The executed module object. It is NOT added to ``sys.modules``.

    Raises
    ------
    AssertionError
        If the folder has no ``__init__.py`` (check is skipped under ``python -O``).
    ImportError
        If ``importlib`` cannot build a spec/loader for ``__init__.py``.
    """
    # Accept plain strings as well as Path objects.
    module_path = Path(module_path)
    module_name = module_name or module_path.name
    module_init = module_path.joinpath("__init__.py")
    assert module_init.exists(), f"Module {module_path} has no __init__ file !!!"

    with add_to_path(module_path):
        spec = importlib.util.spec_from_file_location(module_name, str(module_init))
        # spec_from_file_location may return None (no loader found). Fail with a
        # clear ImportError instead of an AttributeError on `spec.loader` below.
        if spec is None or spec.loader is None:
            raise ImportError(f"Cannot create an import spec for {module_init}")
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return module

In [None]:
# Folder holding the ODE-RNN sources, built from the current user's home directory
# instead of a hard-coded absolute path, so the notebook runs for any user.
# NOTE(review): assumes tsdm stores models under ~/.tsdm/models — confirm against tsdm config.
MODEL_DIR = Path.home() / ".tsdm" / "models" / "ODE-RNN"
module = path_import(MODEL_DIR)

# Pull the building blocks used below out of the loaded package.
create_net = module.lib.utils.create_net
ODEFunc = module.lib.ode_func.ODEFunc
DiffeqSolver = module.lib.diffeq_solver.DiffeqSolver
ODE_RNN = module.lib.ode_rnn.ODE_RNN

In [None]:
import torch
from torch import nn, optim

In [None]:
# Hyperparameters read by the model-construction and training cells below.
HP = {
    # Size of the latent state
    "n_ode_gru_dims": 6,
    # Number of layers in ODE func in recognition ODE
    "n_layers": 1,
    # Number of units per layer in ODE func
    "n_units": 100,
    # nonlinearity used (class passed to create_net)
    "nonlinear": nn.Tanh,
    # NOTE(review): presumably whether the observation mask is concatenated to the
    # GRU input — confirm in ODE_RNN
    "concat_mask": True,
    # dimensionality of input
    "input_dim": n_dim,
    # device: 'cpu' or 'cuda'
    "device": torch.device("cpu"),
    # Number of units per layer in each of GRU update networks
    "n_gru_units": 100,
    # measurement error (observation noise std)
    "obsrv_std": 0.01,
    # NOTE(review): presumably enables a binary-classification head; False here,
    # so the model is used for reconstruction only — confirm in ODE_RNN
    "use_binary_classif": False,
    # NOTE(review): presumably trains classification jointly with reconstruction — confirm
    "train_classif_w_reconstr": False,
    # NOTE(review): presumably one label per time point instead of per series — confirm
    "classif_per_tp": False,
    # number of outputs
    "n_labels": 1,
    # relative tolerance of ODE solver
    "odeint_rtol": 1e-3,
    # absolute tolerance of ODE solver
    "odeint_atol": 1e-4,
    # batch_size
    # NOTE(review): hyphenated key, unlike every other key, and not read by any cell shown here
    "batch-size": 50,
    # learn-rate
    "lr": 1e-2,
}

In [None]:
# Network whose input and output sizes are both the latent size (n_ode_gru_dims).
# NOTE(review): presumably this is the ODE right-hand side dz/dt = f(z) for the latent
# state — confirm in lib.utils.create_net.
ode_func_net = create_net(
    HP["n_ode_gru_dims"],
    HP["n_ode_gru_dims"],
    n_layers=HP["n_layers"],
    n_units=HP["n_units"],
    nonlinear=HP["nonlinear"],
)

In [None]:
# Wrap the network above as the ODE function on the latent state (latent_dim = n_ode_gru_dims),
# then move it to the configured device.
rec_ode_func = ODEFunc(
    ode_func_net=ode_func_net,
    input_dim=HP["input_dim"],
    latent_dim=HP["n_ode_gru_dims"],
    device=HP["device"],
).to(HP["device"])

In [None]:
# Solver that integrates `rec_ode_func`. Positional args as passed: input dim, the ODE
# function, the string "euler" (presumably the integration method) and the latent size.
# NOTE(review): "euler" is hard-coded rather than taken from HP. If it is a fixed-step
# method, the rtol/atol tolerances below likely have no effect — confirm with the backend.
z0_diffeq_solver = DiffeqSolver(
    HP["input_dim"],
    rec_ode_func,
    "euler",
    HP["n_ode_gru_dims"],
    odeint_rtol=HP["odeint_rtol"],
    odeint_atol=HP["odeint_atol"],
    device=HP["device"],
)

In [None]:
# Assemble the ODE-RNN from the solver above and the GRU/observation settings in HP.
# The classification options are all switched off by the HP flags.
# The final bare `model` displays the module structure.
model = ODE_RNN(
    HP["input_dim"],
    HP["n_ode_gru_dims"],
    device=HP["device"],
    z0_diffeq_solver=z0_diffeq_solver,
    n_gru_units=HP["n_gru_units"],
    concat_mask=HP["concat_mask"],
    obsrv_std=HP["obsrv_std"],
    use_binary_classif=HP["use_binary_classif"],
    classif_per_tp=HP["classif_per_tp"],
    n_labels=HP["n_labels"],
    train_classif_w_reconstr=HP["train_classif_w_reconstr"],
).to(HP["device"])

model

In [None]:
# Only the first `n_steps` observations are used for this quick experiment.
n_steps = 100

# Time stamps in units of the first sampling interval (0, 1, 2, ... on a regular grid).
relative_time = (x.index[:n_steps] - x.index[0]) / (x.index[1] - x.index[0])
T = torch.from_numpy(relative_time.to_numpy()).float().to(HP["device"])
# `.iloc` makes the positional row slice explicit; `unsqueeze(0)` adds a batch axis of size 1.
X = torch.from_numpy(x.iloc[:n_steps].to_numpy()).float().to(HP["device"]).unsqueeze(0)

# Observed and target data/time grids are the same tensors, and both masks are all ones.
# This is a pure reconstruction setup.
# NOTE(review): presumably mask value 1 = observed, i.e. no missing values.
batch_dict = {
    "observed_tp": T,
    "tp_to_predict": T,
    "mask_predicted_data": torch.ones_like(X),
    "data_to_predict": X,
    "observed_data": X,
    "observed_mask": torch.ones_like(X),
    "labels": None,
    "mode": "interp",
}

In [None]:
# Sanity check of the model inputs: show tensors by shape, everything else by value.
# `isinstance` (rather than `type(...) ==`) also covers Tensor subclasses.
for key, val in batch_dict.items():
    if isinstance(val, torch.Tensor):
        print(key, val.shape)
    else:
        print(key, val)

In [None]:
# Reconstruct the series on `tp_to_predict` from the observed data and mask,
# using a single trajectory sample. The final bare `info` displays the returned auxiliary info.
pred_y, info = model.get_reconstruction(
    batch_dict["tp_to_predict"],
    batch_dict["observed_data"],
    truth_time_steps=batch_dict["observed_tp"],
    mask=batch_dict["observed_mask"],
    n_traj_samples=1,
    mode=batch_dict["mode"],
)
info

In [None]:
# Shape of the reconstruction.
# NOTE(review): presumably (n_traj_samples, batch, time, n_dim) — confirm.
pred_y.shape

In [None]:
# Re-displays `info`, which was already shown at the end of the reconstruction cell.
info

In [None]:
# Prediction time grid (the same tensor as the observed time grid here).
batch_dict["tp_to_predict"]

In [None]:
# Observed values, with the leading batch axis of size 1 added via unsqueeze(0).
batch_dict["observed_data"]

In [None]:
# Losses of the untrained model, using 20 trajectory samples.
# NOTE(review): kl_coef=0 presumably disables the KL term — confirm.
model.compute_all_losses(batch_dict, n_traj_samples=20, kl_coef=0)

In [None]:
from tqdm.auto import trange

In [None]:
# Fit the model to the single series in `batch_dict`. Every reported loss
# component is shown on the progress bar after each step.
n_iters = 1000  # number of optimisation steps
n_traj_samples_train = 3  # trajectory samples per loss evaluation during training

optimizer = optim.Adamax(model.parameters(), lr=HP["lr"])
model.train()  # make the training mode explicit instead of relying on the module default

for _ in (pbar := trange(n_iters)):
    optimizer.zero_grad()
    # NOTE(review): kl_coef=0 presumably disables the KL term — confirm.
    train_res = model.compute_all_losses(
        batch_dict, n_traj_samples=n_traj_samples_train, kl_coef=0
    )
    pbar.set_postfix({key: float(val) for key, val in train_res.items()})
    train_res["loss"].backward()
    optimizer.step()

In [None]:
# Loss dict from the last training iteration.
train_res

In [None]:
# Time index of the dataset. Its first spacing is the time unit used for `relative_time`.
x.index

In [None]:
# Exploratory call; the result is only displayed, not stored.
# NOTE(review): presumably converts `x` into (time, variable, value) triplets — confirm in tsdm.
tsdm.make_dense_triplets(x)

In [None]:
from typing import TypedDict


class ODE_RNN_HP(TypedDict):
    """Schema of an ODE-RNN hyperparameter dict: key names and value types only.

    TypedDict fields cannot declare default values (PEP 589). Type checkers
    reject them, and at runtime they are never filled into instances. The
    defaults are therefore kept in ``ODE_RNN_HP_DEFAULTS``.
    """

    param1: int
    param2: str


# Default values matching the ODE_RNN_HP schema.
# Copy before changing them, e.g. {**ODE_RNN_HP_DEFAULTS, "param1": 5}.
ODE_RNN_HP_DEFAULTS: ODE_RNN_HP = {"param1": 3, "param2": "gaga"}

In [None]:
# Calling a TypedDict class builds an ordinary dict from the keyword arguments
# (none here). Class-level default values are not filled in.
isinstance(ODE_RNN_HP(), dict)

In [None]:
# `ODE_RNN_HP` is a class, not an instance. `ODE_RNN_HP.keys()` would call the unbound
# `dict.keys` without a dict and raise TypeError. The declared field names are
# available from the class annotations instead.
ODE_RNN_HP.__annotations__.keys()