In [36]:
import os
import sys


def set_environment_variables_from_file(filename):
    with open(filename, "r") as f:
        for line in f:
            name, value = line.strip().split("=", 1)
            os.environ[name] = value

    # After setting environment variables, update sys.path based on PYTHONPATH
    pythonpath = os.getenv("PYTHONPATH")
    if pythonpath:
        # Split PYTHONPATH into paths, add them to sys.path if not already present
        for path in pythonpath.split(":"):
            if path not in sys.path:
                sys.path.insert(0, path)


# Path to the file where you saved the environment variables
environment_variables_filename = "/home/preston/ADAM-ROS/environment_variables.txt"
set_environment_variables_from_file(environment_variables_filename)

import numpy as np
import torch
from torch.utils.data import DataLoader
from safety_filter.learning.cvae_utils import elbo_loss, sample_outputs
from dataclasses import dataclass, asdict
from datetime import datetime
from pathlib import Path
import yaml

import plotly.graph_objects as go

from safety_filter.learning.models import CVAE
from safety_filter.learning.DI.data import DIDataset
from safety_filter.learning.DI.train import VAL_DATASET_PATH, TRAIN_DATASET_PATH, TrainConfig

model_path = Path("models/2024-03-19_15-52-36")
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [37]:
# Load the config from file
config_file = model_path / "config.yaml"
with open(config_file, "r") as f:
    config_dict = yaml.safe_load(f)
    config = TrainConfig(**config_dict)

# Load the latest checkpoint (of all .pth files under model_path)
model_file = max(model_path.glob("*.pth"), key=os.path.getctime)
model = CVAE(
    output_dim=config.output_dim,
    latent_dim=config.latent_dim,
    cond_dim=config.cond_dim,
    encoder_layers=config.encoder_layers,
    decoder_layers=config.decoder_layers,
    prior_layers=config.prior_layers,
)
model.load_state_dict(torch.load(model_file))
model.to(config.device)
model.eval()

train_dataset = DIDataset(TRAIN_DATASET_PATH)
val_dataset = DIDataset(VAL_DATASET_PATH)

In [41]:
# Get a sequential batch of val states + query the model for residuals
batch_size = 500
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Get a batch of data
start_iter = batch_size
batch = val_dataset[start_iter : start_iter + batch_size]


n_samples = 10000

# Query the model for normalized disturbances.
d, cond = batch["d"].to(config.device), batch["cond"].to(config.device)
d = model.output_normalizer.normalize(d)

pred_mean, pred_var = sample_outputs(model, cond, n_samples)
index = 1

# Plot the actual d and predicted mean.
fig = go.Figure()
fig.add_trace(
    go.Scatter(
        x=np.arange(batch_size),
        y=batch["d"][:, index].cpu().numpy(),
        name="Actual",
        line=dict(color="blue"),
    )
)

fig.add_trace(
    go.Scatter(
        x=np.arange(batch_size),
        y=pred_mean[:, index].cpu().numpy(),
        mode="lines",
        name="Predicted",
        line=dict(color="red"),
    )
)

# Fill between the 2 std devs.
fig.add_trace(
    go.Scatter(
        x=np.arange(batch_size),
        y=pred_mean[:, index].cpu().numpy()
        - 2 * pred_var[:, index, index].sqrt().cpu().numpy(),
        mode="lines",
        name="Predicted - 2 std",
        line=dict(color="green"),
    )
)

fig.add_trace(
    go.Scatter(
        x=np.arange(batch_size),
        y=pred_mean[:, index].cpu().numpy()
        + 2 * pred_var[:, index, index].sqrt().cpu().numpy(),
        mode="lines",
        name="Predicted + 2 std",
        line=dict(color="green"),
    )
)

fig.show()

In [42]:
# Plot the actual mean and std for each data point.
from safety_filter.learning.DI import gen_di_data

index = 1
mean_true = torch.zeros_like(batch["d"])
var_true = torch.zeros(batch_size, batch["d"].shape[-1], batch["d"].shape[-1])

for ii in range(batch_size):
    mean_true[ii], var_true[ii] = gen_di_data.d_mean(batch["x"][ii]), gen_di_data.d_var(
        batch["x"][ii]
    )


fig = go.Figure()
fig.add_trace(
    go.Scatter(
        x=np.arange(batch_size),
        y=mean_true[:, index].numpy(),
        mode="lines",
        name="Actual",
        line=dict(color="blue"),
    )
)

fig.add_trace(
    go.Scatter(
        x=np.arange(batch_size),
        y=mean_true[:, index].numpy() - 2 * var_true[:, index, index].sqrt().numpy(),
        mode="lines",
        name="Actual - 2 std",
        line=dict(color="green"),
    )
)

fig.add_trace(
    go.Scatter(
        x=np.arange(batch_size),
        y=mean_true[:, index].numpy() + 2 * var_true[:, index, index].sqrt().numpy(),
        mode="lines",
        name="Actual + 2 std",
        line=dict(color="green"),
    )
)

# Plot disturbance
fig.add_trace(
    go.Scatter(
        x=np.arange(batch_size),
        y=batch["d"][:, index].numpy(),
        mode="lines",
        name="Disturbance",
        line=dict(color="red"),
    )
)

In [43]:
# Plot covariances vs. each other.
fig = go.Figure()
index = 1
fig.add_trace(
    go.Scatter(
        x=np.arange(batch_size),
        y=var_true[:, index, 1].numpy(),
        mode="lines",
        name="Actual",
        line=dict(color="blue"),
    )
)
fig.add_trace(
    go.Scatter(
        x=np.arange(batch_size),
        y=pred_var[:, index, 1].cpu().numpy(),
        mode="lines",
        name="Predicted",
        line=dict(color="red"),
    )
)
fig.show()

In [114]:
pred_var, var_true

(tensor([[[1.4114e-04, 2.5064e-05],
          [2.5064e-05, 9.6638e-05]],
 
         [[1.4083e-04, 2.4210e-05],
          [2.4210e-05, 9.3666e-05]],
 
         [[1.4308e-04, 2.6840e-05],
          [2.6840e-05, 9.6440e-05]],
 
         ...,
 
         [[6.8519e-05, 9.8204e-06],
          [9.8204e-06, 1.0666e-04]],
 
         [[6.9986e-05, 1.1609e-05],
          [1.1609e-05, 1.0809e-04]],
 
         [[1.4199e-04, 2.5728e-05],
          [2.5728e-05, 9.6249e-05]]], device='cuda:0'),
 tensor([[[1.5000e-04, 5.0000e-05],
          [5.0000e-05, 1.0000e-04]],
 
         [[1.4999e-04, 4.9222e-05],
          [4.9222e-05, 9.9216e-05]],
 
         [[1.4999e-04, 4.8917e-05],
          [4.8917e-05, 9.8905e-05]],
 
         ...,
 
         [[5.0574e-05, 1.8565e-06],
          [1.8565e-06, 9.2443e-05]],
 
         [[5.0684e-05, 1.8309e-06],
          [1.8309e-06, 9.1756e-05]],
 
         [[1.5000e-04, 5.0000e-05],
          [5.0000e-05, 1.0000e-04]]]))

In [143]:
# Plot state data.
fig = go.Figure()
fig.add_trace(
    go.Scatter(
        x=np.arange(batch_size),
        y=gen_di_data.f_nom(batch["x"])[:, 0].numpy(),
        mode="lines",
        name="State",
        line=dict(color="blue"),
    )
)

18