In [3]:
#@markdown ### **Imports**
# diffusion policy import
from typing import Tuple, Sequence, Dict, Union, Optional, Callable
import numpy as np
import math
import torch
import torch.nn as nn
import zarr
from tqdm import tqdm
from PIL import Image
from skvideo.io import vwrite
from IPython.display import Video

import cv2
import pandas as pd
import os
import sys

import glob
from abc import ABC, abstractmethod

In [6]:
# General failure detector interface --> ideally implement this downstream
class FailureDetector(ABC):
    def __init__(self, calibration_dataset_dir: str, config: Dict):
        pass
    @abstractmethod
    def detect_episode(self, episode_dataset_dir) -> bool:
        """
        Detects if a failure has occurred in the environment.
        :param obs: The observation from the environment.
        :return: True, detection time if a failure is detected, False, None otherwise.
        """
        pass
    @abstractmethod
    def calibrate(self):
        """
        Calibrates the failure detector.
        :return: None
        """
        pass

## Episode Dataset Code

In [7]:
from typing import Any, Dict, Union, List, Generator, Optional

import os
import glob
import pickle
import pathlib
import numpy as np
import pandas as pd
from torch.utils.data import IterableDataset
from copy import deepcopy


def load_pickle(path: Union[str, pathlib.Path]) -> pd.DataFrame:
    with open(path, "rb") as f:
        data = pickle.load(f)
    return data


class EpisodeDataset(IterableDataset):

    def __init__(
        self,
        dataset_path: Union[str, pathlib.Path],
        exec_horizon: int,
        sample_history: int,
        filter_success: bool = False,
        filter_failure: bool = False,
        filter_episodes: Optional[List[int]] = None,
        max_episode_length: Optional[int] = None,
        max_num_episodes: Optional[int] = None,
    ) -> None:
        """Construct EpisodeDataset."""
        super().__init__()
        assert exec_horizon >= 1 and sample_history >= 0
        self._dataset_path = dataset_path
        self._episode_files = sorted(glob.glob(os.path.join(dataset_path, "*.pkl")))
        self._exec_horizon = exec_horizon
        self._sample_history = sample_history
        self._filter_success = filter_success
        self._filter_failure = filter_failure
        self._filter_episodes = filter_episodes
        self._max_episode_length = max_episode_length
        self._max_num_episodes = max_num_episodes

    def __iter__(
        self,
    ) -> Generator[Union[Dict[str, Any], List[Dict[str, Any]]], None, None]:
        """Return sample."""
        num_episodes = 0
        for i, file_path in enumerate(self._episode_files):
            # if self._max_num_episodes is not None and num_episodes >= self._max_num_episodes:
            if self._max_num_episodes is not None and i >= self._max_num_episodes:
                continue

            episode = load_pickle(file_path)
            success = episode.iloc[0].to_dict().get("success", True)
            if (
                (self._filter_success and success)
                or (self._filter_failure and not success)
                or (
                    self._filter_episodes is not None
                    and not isinstance(self._filter_episodes, str)
                    and i in self._filter_episodes
                )
            ):
                continue
            else:
                num_episodes += 1

            for idx in range(
                self._exec_horizon * self._sample_history,
                len(episode), # rows of data in episode df
                self._exec_horizon,
            ):
                if (
                    self._max_episode_length is not None
                    and episode.iloc[idx].to_dict()["timestep"]
                    >= self._max_episode_length
                ):
                    continue

                sample = [
                    episode.iloc[j].to_dict()
                    for j in range(
                        idx - self._exec_horizon * self._sample_history,
                        idx + 1,
                        self._exec_horizon,
                    )
                ]

                # if len(sample) < 2:
                #     continue
                # assert all(x["episode"] == i for x in sample), not relevant for our dataset
                yield sample[0] if len(sample) == 1 else sample

In [11]:
test_dataset = EpisodeDataset(
    "logs/datasets/domain_randomization_v2",
    exec_horizon=1,
    sample_history=0,
    filter_success=False,
    filter_failure=False,
)

In [13]:
count = 0
try:
    for sample in test_dataset:
        # Optional sanity checks on each sample:
        # assert isinstance(sample, dict) or list
        # if list: assert all("timestep" in s for s in sample)
        print(sample["timestep"])
        count += 1
    print(f"✅ Completed iteration over dataset: {count} samples seen.")
except Exception as e:
    print(f"❌ Error after {count} samples: {e!r}")
    raise

0
8
16
24
32
40
48
56
64
72
80
88
96
104
112
120
128
136
144
152
160
168
176
184
192
200
208
216
224
232
240
248
256
264
272
280
288
296
0
8
16
24
32
40
48
56
64
72
80
88
96
104
112
120
128
136
144
152
160
168
176
184
192
200
208
216
224
232
240
248
256
264
272
280
288
296
0
8
16
24
32
40
48
56
64
72
80
88
96
104
112
120
128
136
144
152
160
168
176
184
192
200
208
216
224
232
240
248
256
264
272
280
288
296
0
8
16
24
32
40
48
56
64
72
80
88
96
104
112
120
128
136
144
152
160
168
176
184
192
200
208
216
224
232
240
248
256
264
272
280
288
296
0
8
16
24
32
40
48
56
64
72
80
88
96
104
112
120
128
136
144
152
160
168
176
184
192
200
208
216
224
232
240
248
256
264
272
280
288
296
0
8
16
24
32
40
48
56
64
72
80
88
96
104
112
120
128
136
144
152
160
168
176
184
192
200
208
216
224
232
240
248
256
264
272
280
288
296
0
8
16
24
32
40
48
56
64
72
80
88
96
104
112
120
128
136
144
152
160
168
176
184
192
200
208
216
224
232
240
248
256
264
272
280
288
296
0
8
16
24
32
40
48
56
64
72
80
88
96
104


In [14]:
it = iter(test_dataset)

sample = next(it) # does return the dict with just 

In [15]:
sample.keys()

dict_keys(['timestep', 'rgb', 'reward', 'sampled_actions', 'executed_action', 'action_index', 'agent_positions', 'agent_velocities', 'block_poses', 'goal_poses', 'step_image_features', 'step_agent_poses', 'success', 'episode'])

In [16]:
sample['step_image_features'].shape

(2, 512)

In [17]:
sample['agent_positions'][0]

array([ 96., 326.])

### Testing STAC Pipeline Code

error_fns in eval script:
- "mmd_rbf_all_median"
- "kde_kl_all_for_eig"
- "kde_kl_all_rev_eig"

In [46]:
from typing import Union, Dict, Callable, Any, List, Optional
import os
import sys
import torch
import hydra
import random
import pickle
import imageio
import pathlib
import omegaconf
import numpy as np
import pandas as pd
from collections import defaultdict

from src.stac import utils
import src.stac.dataset_utils as data_utils
from src.stac import error_utils, metric_utils, action_utils

import omegaconf


In [47]:
CONSISTENCY_AGGR_FNS: Dict[str, Callable[[np.ndarray], float]] = {
    "min": np.min,
    "max": np.max,
    "mean": np.mean,
    "std_dev": np.std,
    "var": np.var,
}

CONSISTENCY_ERROR_FNS: Dict[str, Dict[str, Any]] = {
    "mse_all": {
        "error_fn": "mse",
        "ignore_gripper": True,
        "ignore_rotation": False,
    },
    "mse_pos": {
        "error_fn": "mse",
        "ignore_gripper": True,
        "ignore_rotation": True,
    },
    "ate_pos": {
        "error_fn": "ate",
        "ignore_gripper": True,
        "ignore_rotation": True,
    },
}

CONSISTENCY_DIST_ERROR_FNS = {
    # MMD. - Maximum Mean Discrepancy (MMD)
    "mmd_rbf_pos": {
        "error_fn": "mmd_rbf",
        "ignore_gripper": True,
        "ignore_rotation": True,
        "gamma": 1.0,
    },
    "mmd_rbf_all": {
        "error_fn": "mmd_rbf",
        "ignore_gripper": True,
        "ignore_rotation": False,
        "gamma": None,
    },
    "mmd_rbf_all_1.0": {
        "error_fn": "mmd_rbf",
        "ignore_gripper": True,
        "ignore_rotation": False,
        "gamma": 1.0,
    },
    "mmd_rbf_all_median": {
        "error_fn": "mmd_rbf",
        "ignore_gripper": True,
        "ignore_rotation": False,
        "gamma": "median",
    },
    "mmd_rbf_all_eig": {
        "error_fn": "mmd_rbf",
        "ignore_gripper": True,
        "ignore_rotation": False,
        "gamma": "max_eig",
    },
    "mmd_rbf_all_0.1": {
        "error_fn": "mmd_rbf",
        "ignore_gripper": True,
        "ignore_rotation": False,
        "gamma": 0.1,
    },
    "mmd_rbf_all_0.5": {
        "error_fn": "mmd_rbf",
        "ignore_gripper": True,
        "ignore_rotation": False,
        "gamma": 0.5,
    },
    "mmd_rbf_all_5.0": {
        "error_fn": "mmd_rbf",
        "ignore_gripper": True,
        "ignore_rotation": False,
        "gamma": 5.0,
    },
    "mmd_rbf_all_10.0": {
        "error_fn": "mmd_rbf",
        "ignore_gripper": True,
        "ignore_rotation": False,
        "gamma": 10.0,
    },
    "mmd_rbf_all_100.0": {
        "error_fn": "mmd_rbf",
        "ignore_gripper": True,
        "ignore_rotation": False,
        "gamma": 100.0,
    },
    # KDE For. KL.
    "kde_kl_all_for": {
        "error_fn": "kde_kl",
        "ignore_gripper": True,
        "ignore_rotation": False,
        "forward": True,
        "bandwidth": 1.0,
    },
    "kde_kl_all_for_eig": {
        "error_fn": "kde_kl",
        "ignore_gripper": True,
        "ignore_rotation": False,
        "forward": True,
        "bandwidth": "max_eig",
    },
    "kde_kl_all_for_0.1": {
        "error_fn": "kde_kl",
        "ignore_gripper": True,
        "ignore_rotation": False,
        "forward": True,
        "bandwidth": 0.1,
    },
    "kde_kl_all_for_0.5": {
        "error_fn": "kde_kl",
        "ignore_gripper": True,
        "ignore_rotation": False,
        "forward": True,
        "bandwidth": 0.5,
    },
    "kde_kl_all_for_5.0": {
        "error_fn": "kde_kl",
        "ignore_gripper": True,
        "ignore_rotation": False,
        "forward": True,
        "bandwidth": 5.0,
    },
    "kde_kl_all_for_10.0": {
        "error_fn": "kde_kl",
        "ignore_gripper": True,
        "ignore_rotation": False,
        "forward": True,
        "bandwidth": 10.0,
    },
    "kde_kl_all_for_100.0": {
        "error_fn": "kde_kl",
        "ignore_gripper": True,
        "ignore_rotation": False,
        "forward": True,
        "bandwidth": 100.0,
    },
    # KDE Rev. KL.
    "kde_kl_all_rev": {
        "error_fn": "kde_kl",
        "ignore_gripper": True,
        "ignore_rotation": False,
        "forward": False,
        "bandwidth": 1.0,
    },
    "kde_kl_all_rev_eig": {
        "error_fn": "kde_kl",
        "ignore_gripper": True,
        "ignore_rotation": False,
        "forward": False,
        "bandwidth": "max_eig",
    },
    "kde_kl_all_rev_0.1": {
        "error_fn": "kde_kl",
        "ignore_gripper": True,
        "ignore_rotation": False,
        "forward": False,
        "bandwidth": 0.1,
    },
    "kde_kl_all_rev_0.5": {
        "error_fn": "kde_kl",
        "ignore_gripper": True,
        "ignore_rotation": False,
        "forward": False,
        "bandwidth": 0.5,
    },
    "kde_kl_all_rev_5.0": {
        "error_fn": "kde_kl",
        "ignore_gripper": True,
        "ignore_rotation": False,
        "forward": False,
        "bandwidth": 5.0,
    },
    "kde_kl_all_rev_10.0": {
        "error_fn": "kde_kl",
        "ignore_gripper": True,
        "ignore_rotation": False,
        "forward": False,
        "bandwidth": 10.0,
    },
    "kde_kl_all_rev_100.0": {
        "error_fn": "kde_kl",
        "ignore_gripper": True,
        "ignore_rotation": False,
        "forward": False,
        "bandwidth": 100.0,
    },
    "wasserstein": {
        "error_fn": "wass",
        "ignore_gripper": True,
        "ignore_rotation": False,
        "gamma": None,
    }
}

# Experiment keys.
def temporal_consistency_exp_key(
    pred_horizon: int,
    sample_size: int,
    error_fn: str,
    aggr_fn: Optional[str] = None,
) -> str:
    if error_fn in CONSISTENCY_DIST_ERROR_FNS:
        return (
            f"pred_horizon_{pred_horizon}_sample_size_{sample_size}_error_fn_{error_fn}"
        )
    else:
        return f"pred_horizon_{pred_horizon}_sample_size_{sample_size}_error_fn_{error_fn}_aggr_fn_{aggr_fn}"
    
def quantile_exp_key(exp_key: str, quantile: float = 0.95) -> str:
    return f"{exp_key}_quantile_{quantile}"

def get_consistency_aggr_fns(
    cfg: omegaconf.DictConfig,
    error_fn: str,
) -> List[Optional[str]]:
    if error_fn not in CONSISTENCY_ERROR_FNS:
        return [None]
    return cfg.eval.consistency.aggr_fns

def compute_cum_scores(
    results_frame: pd.DataFrame,
    exp_keys: List[str],
) -> pd.DataFrame:
    for exp_key in exp_keys:
        cum_scores = pd.Series(
            data_utils.aggr_episode_key_data(
                results_frame,
                f"{exp_key}_score",
                np.cumsum,
            ),
            name=f"{exp_key}_cum_score",
        )
        results_frame = pd.concat([results_frame, cum_scores], axis=1)

    return results_frame

# get RGB image from data
def get_rgb(data: Dict[str, Any]) -> Optional[np.ndarray]:
    return data['rgb'][0]

In [49]:
def compute_temporal_consistency_errors(
    cfg: omegaconf.DictConfig, dataset: EpisodeDataset
) -> pd.DataFrame:
    """Compute temporal consistency errors over dataset."""
    results = []
    exp_keys = []

    for prev_data, curr_data in iter(dataset):
        assert curr_data["timestep"] - prev_data["timestep"] == cfg.model.ac_horizon

        results.append(
            {
                "episode": curr_data["episode"],
                "timestep": curr_data["timestep"],
                "success": curr_data.get("success", True),
            }
        )
        rgb = get_rgb(curr_data)
        if isinstance(rgb, np.ndarray):
            results[-1]["rgb"] = rgb

        for sample_size in cfg.eval.consistency.sample_sizes:

            # Subsample current and previous actions.
            curr_actions = action_utils.subsample_actions(
                curr_data["sampled_actions"],
                sample_size,
            )
            curr_skip_steps = curr_data.get("skip_steps", None) # Should be None

            prev_actions = action_utils.subsample_actions(
                prev_data["sampled_actions"],
                sample_size,
            )
            prev_skip_steps = prev_data.get("skip_steps", None) # Should be None

            for error_fn in cfg.eval.consistency.error_fns:

                if error_fn in CONSISTENCY_ERROR_FNS:
                    prev_selected_actions = prev_data["executed_action"]
                    error_fn_kwargs = CONSISTENCY_ERROR_FNS[error_fn]
                elif error_fn in CONSISTENCY_DIST_ERROR_FNS:
                    prev_selected_actions = prev_actions
                    error_fn_kwargs = CONSISTENCY_DIST_ERROR_FNS[error_fn]
                else:
                    raise ValueError(f"Error function {error_fn} not supported.")

                for aggr_fn in get_consistency_aggr_fns(cfg, error_fn):
                    for pred_horizon in cfg.eval.consistency.pred_horizons:
                        if cfg.model.ac_horizon >= pred_horizon:
                            continue

                        exp_key = temporal_consistency_exp_key(
                            pred_horizon=pred_horizon,
                            sample_size=sample_size,
                            error_fn=error_fn,
                            aggr_fn=aggr_fn,
                        )
                        if exp_key not in exp_keys:
                            exp_keys.append(exp_key)

                        error = error_utils.compute_temporal_error(
                            curr_action=curr_actions,
                            prev_action=prev_selected_actions,
                            pred_horizon=pred_horizon,
                            exec_horizon=cfg.model.ac_horizon,
                            sim_freq=cfg.env.args.freq,
                            num_robots=cfg.env.num_eef,
                            action_dim=cfg.env.dof,
                            skip_steps=False,
                            curr_skip_steps=curr_skip_steps,
                            prev_skip_steps=prev_skip_steps,
                            **error_fn_kwargs,
                        )
                        if error_fn in CONSISTENCY_ERROR_FNS:
                            error = CONSISTENCY_AGGR_FNS[aggr_fn](error)
                        results[-1][f"{exp_key}_score"] = error

    results_frame = compute_cum_scores(pd.DataFrame(results), exp_keys)
    return results_frame


def evaluate_temporal_consistency(
    cfg: omegaconf.DictConfig,
    demo_dataset_path: Union[str, pathlib.Path],
    test_dataset_path: Union[str, pathlib.Path],
) -> Dict[str, Union[Dict[str, Any], pd.DataFrame]]:
    """Compute temporal consistency results."""
    # Construct episode iterable datasets.
    demo_dataset = EpisodeDataset(
        dataset_path=demo_dataset_path,
        exec_horizon=1,
        sample_history=1,
        filter_success=getattr(cfg.eval, "filter_demo_success", False),
        filter_failure=getattr(cfg.eval, "filter_demo_failure", True),
        filter_episodes=getattr(cfg.eval, "filter_demo_episodes", None),
        max_episode_length=getattr(cfg.eval, "max_demo_episode_length", None),
        max_num_episodes=getattr(cfg.eval, "max_num_episodes", None),
    )
    test_dataset = EpisodeDataset(
        dataset_path=test_dataset_path,
        exec_horizon=1,
        sample_history=1,
        filter_success=getattr(cfg.eval, "filter_test_success", False),
        filter_failure=getattr(cfg.eval, "filter_test_failure", False),
        filter_episodes=getattr(cfg.eval, "filter_test_episodes", None),
        max_episode_length=getattr(cfg.eval, "max_test_episode_length", None),
        max_num_episodes=getattr(cfg.eval, "max_num_episodes", None),
    )

    # Compute scores for specified parameter sets.
    results_dict = defaultdict(dict)
    demo_results_frame = compute_temporal_consistency_errors(cfg, demo_dataset)
    test_results_frame = compute_temporal_consistency_errors(cfg, test_dataset)

    # Compute metrics for specified parameter sets.
    for sample_size in cfg.eval.consistency.sample_sizes:
        for error_fn in cfg.eval.consistency.error_fns:
            for aggr_fn in get_consistency_aggr_fns(cfg, error_fn):
                for pred_horizon in cfg.eval.consistency.pred_horizons:
                    if cfg.model.ac_horizon >= pred_horizon:
                        continue

                    exp_key = temporal_consistency_exp_key(
                        pred_horizon=pred_horizon,
                        sample_size=sample_size,
                        error_fn=error_fn,
                        aggr_fn=aggr_fn,
                    )

                    for quantile in cfg.eval.quantiles:

                        test_results_frame = metric_utils.compute_detection_results(
                            exp_key=exp_key,
                            quantile_key=quantile_exp_key(exp_key, quantile),
                            results_dict=results_dict,
                            demo_results_frame=demo_results_frame,
                            test_results_frame=test_results_frame,
                            detector=getattr(cfg.eval, "detector", "quantile"),
                            detector_kwargs={
                                "quantile": quantile,
                                **getattr(cfg.eval, "detector_kwargs", {}),
                            },
                        )

    return {
        "results_dict": results_dict,
        "test_results_frame": test_results_frame,
        "demo_results_frame": demo_results_frame,
    }

In [50]:
conf_dict = {
    "env": {
        "args": {
            "freq": 1,
            "max_episode_length": 300
        },
        "dof": 2,
        "num_eef": 1
    },
    "model": {
        "ac_horizon": 8,
    },
    "eval": {
        "consistency": {
            "sample_sizes": [128, 256],
            "error_fns": ["mse_all", "mmd_rbf_all_median", "kde_kl_all_for_eig"], 
            "pred_horizons": [16],
            "aggr_fns": ["min"]
        },
        "quantiles": [0.95],
    }
}

cfg = omegaconf.OmegaConf.create(conf_dict)

dataset_path = "logs/datasets/domain_randomization_v2"

In [51]:
# Run eval
output_data = evaluate_temporal_consistency(cfg, dataset_path, dataset_path)


Episode Results: ep_iid_cum | pred_horizon_16_sample_size_128_error_fn_mse_all_aggr_fn_min_quantile_0.95
TPR: 0.88 | TNR: 0.90 | Acc: 0.88 | Bal. Acc: 0.89
TP Time 197.71 (57.96)

Episode Results: ep_iid_cum | pred_horizon_16_sample_size_128_error_fn_mmd_rbf_all_median_quantile_0.95
TPR: 1.00 | TNR: 0.90 | Acc: 0.98 | Bal. Acc: 0.95
TP Time 211.60 (13.85)

Episode Results: ep_iid_cum | pred_horizon_16_sample_size_128_error_fn_kde_kl_all_for_eig_quantile_0.95
TPR: 1.00 | TNR: 0.90 | Acc: 0.98 | Bal. Acc: 0.95
TP Time 216.40 (11.59)

Episode Results: ep_iid_cum | pred_horizon_16_sample_size_256_error_fn_mse_all_aggr_fn_min_quantile_0.95
TPR: 0.88 | TNR: 0.90 | Acc: 0.88 | Bal. Acc: 0.89
TP Time 196.57 (58.23)

Episode Results: ep_iid_cum | pred_horizon_16_sample_size_256_error_fn_mmd_rbf_all_median_quantile_0.95
TPR: 1.00 | TNR: 0.90 | Acc: 0.98 | Bal. Acc: 0.95
TP Time 211.00 (14.50)

Episode Results: ep_iid_cum | pred_horizon_16_sample_size_256_error_fn_kde_kl_all_for_eig_quantile_0.9

In [64]:
# Raw scores from compute_temporal_consistency_errors on the test dataset.
# For each combination of quantile_key and calib_key evaluated, a new prediction column is added.
from tabulate import tabulate

print(output_data['test_results_frame'].shape)

print(tabulate(output_data['test_results_frame'][900:905], headers='keys', tablefmt='psql'))


(1667, 46)
+-----+-----------+------------+-----------+-------------------+----------------------------------------------------------------------+---------------------------------------------------------------------+---------------------------------------------------------------------+----------------------------------------------------------------------+---------------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------+-------------------------------------------------------------------------+-------------------------------------------------------------------------+--------------------------------------------------------------------------+-------------------------------------------------------------------------+-------------------------------------------------------------------------+---------------------------------------------------------

In [62]:
# Raw scores from compute_temporal_consistency_errors on the demo dataset.
# For each combination of quantile_key and calib_key evaluated, a new prediction column is added.
print(output_data['demo_results_frame'].shape)

print(tabulate(output_data['demo_results_frame'][0:5], headers='keys', tablefmt='psql'))

(187, 16)
+----+-----------+------------+-----------+-------------------+----------------------------------------------------------------------+---------------------------------------------------------------------+---------------------------------------------------------------------+----------------------------------------------------------------------+---------------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------+-------------------------------------------------------------------------+-------------------------------------------------------------------------+--------------------------------------------------------------------------+-------------------------------------------------------------------------+-------------------------------------------------------------------------+
|    |   episode |   timestep | success   | rgb           

In [None]:
#r: Dict[str, Any] = results_dict[quantile_key][calib_key]["episode"]["metrics"]

# exp_key = "pred_horizon_{value}_sample_size_{value}_error_fn_{value}_aggr_fn_{value}"
# quantile_key = "{exp_key}_quantile_{value}"
# calib_key = "t_iid, ep_iid_max, ep_iid_per_t, ep_iid_cum, ep_iid_cum_per_t" 
# combinations = "{exp_key}_score or {exp_key}_cum_score" (done already in metric_utils)

r: Dict[str, Any] = output_data['results_dict']["pred_horizon_16_sample_size_128_error_fn_mse_all_aggr_fn_min_quantile_0.95"]["ep_iid_per_t"]["episode"]["metrics"]

print(r)

{'TP Time Mean': 46.72, 'TP Time STD': 29.991358755481553, 'FP Time Mean': 45.333333333333336, 'FP Time STD': 27.19477073916152, 'AUROC Score': 0.7825000000000001, 'TP': 25, 'TN': 7, 'FP': 3, 'FN': 15, 'TPR': 0.625, 'TNR': 0.7, 'FPR': 0.3, 'FNR': 0.375, 'Accuracy': 0.64, 'Balanced Accuracy': 0.6625, 'F1 Score': 0.7352941176470589}


# Wasserstein metric demonstration

In [13]:
import ot
import numpy as np

def compute_wasserstein_ot(x: np.ndarray, y: np.ndarray, p: int = 1) -> float:
    """
    Exact p-Wasserstein distance between two empirical measures on R^D.

    Args:
      x: [N, D] array 
      y: [M, D] array
      p: ground‐metric exponent (1 for 1-Wasserstein, 2 for 2-Wasserstein)

    Returns:
      W_p(x, y)
    """
    n, m = x.shape[0], y.shape[0]
    a = np.ones(n) / n   # uniform weights over x
    b = np.ones(m) / m   # uniform weights over y

    # cost matrix: ||x_i - y_j||_2^p
    M = ot.dist(x, y, metric='euclidean')**p  # shape [128,128]

    # emd2 returns the p-th power of W_p
    Wp_p = ot.emd2(a, b, M)
    return float(Wp_p**(1.0/p))

In [15]:
x = np.random.randn(128,16)
y = np.random.randn(128,16) * 10 + 50

compute_wasserstein_ot(x, y)

203.59145381588658

## Embedding Detector Scripts

In [8]:
EMBEDDING_SCORE_FNS: Dict[str, Dict[str, Any]] = {
    "top1_l2": {"method": "topk", "method_kwargs": {"error_fn": "l2", "k": 1}},
    "top5_l2": {"method": "topk", "method_kwargs": {"error_fn": "l2", "k": 5}},
    "top10_l2": {"method": "topk", "method_kwargs": {"error_fn": "l2", "k": 10}},
    "top1_cosine": {"method": "topk", "method_kwargs": {"error_fn": "cosine", "k": 1}},
    "top5_cosine": {"method": "topk", "method_kwargs": {"error_fn": "cosine", "k": 5}},
    "top10_cosine": {
        "method": "topk",
        "method_kwargs": {"error_fn": "cosine", "k": 10},
    },
    "mahal": {
        "method": "mahal",
    },
}

def embedding_similarity_exp_key(
    embedding: str,
    score_fn: int,
) -> str:
    return f"embedding_{embedding}_score_fn_{score_fn}"

def get_rgb(data: Dict[str, Any]) -> Optional[np.ndarray]:
    return data['rgb'][0]

# Custom function that returns the average latent embedding for one time step
# visual conditioning has two in the obs-history w current setup
def get_latent_embedding(data, average=True):
    image_feats = data['step_image_features']
    agent_pos = data['step_agent_poses']

    comb = np.concatenate([image_feats, agent_pos], axis=-1)

    if average:
        return comb.mean(axis=0)

    return comb[-1]

def compute_embedding_similarity_scores(
    cfg: omegaconf.DictConfig,
    test_dataset: Optional[EpisodeDataset] = None,
    demo_dataset: Optional[EpisodeDataset] = None,
    demo_frame: Optional[pd.DataFrame] = None,
    leave_timestep_out: bool = False,
    leave_episode_out: bool = False,
    demo_as_test: bool = False,
) -> pd.DataFrame:
    """Compute embedding similarity scores over dataset."""
    assert not (leave_timestep_out and leave_episode_out)
    assert (demo_dataset is not None) ^ (demo_frame is not None)
    assert (test_dataset is not None) ^ demo_as_test

    # Extract demo embeddings.
    if demo_dataset is not None:
        demo_frame = []
        for data in iter(demo_dataset):
            demo_frame.append(
                {
                    "episode": data["episode"],
                    "timestep": data["timestep"],
                    "success": data.get("success", True),
                }
            )
            rgb = get_rgb(data)
            if isinstance(rgb, np.ndarray):
                demo_frame[-1]["rgb"] = rgb

            for embedding in cfg.eval.embedding.embeddings:
                if embedding == "step_obs":
                    demo_frame[-1][embedding] = get_latent_embedding(data)
                else:
                    demo_frame[-1][embedding] = data[embedding].flatten()

        demo_frame = pd.DataFrame(demo_frame)
    assert isinstance(demo_frame, pd.DataFrame)

    # Extract test embeddings.
    if demo_as_test:
        test_frame = demo_frame.copy()
    else:
        test_frame = []
        for data in iter(test_dataset):
            test_frame.append(
                {
                    "episode": data["episode"],
                    "timestep": data["timestep"],
                    "success": data.get("success", True),
                }
            )
            rgb = get_rgb(data)
            if isinstance(rgb, np.ndarray):
                test_frame[-1]["rgb"] = rgb

            for embedding in cfg.eval.embedding.embeddings:
                if embedding == "step_obs":
                    test_frame[-1][embedding] = get_latent_embedding(data)
                else:
                    test_frame[-1][embedding] = data[embedding].flatten()

        test_frame = pd.DataFrame(test_frame)
    assert isinstance(test_frame, pd.DataFrame)

    # Compute embedding scores.
    exp_keys = []
    for embedding in cfg.eval.embedding.embeddings:
        for score_fn in cfg.eval.embedding.score_fns:

            exp_key = embedding_similarity_exp_key(
                embedding=embedding,
                score_fn=score_fn,
            )
            if exp_key not in exp_keys:
                exp_keys.append(exp_key)

            if leave_episode_out:
                test_frame = pd.concat(
                    [
                        test_frame,
                        pd.Series(np.zeros(len(test_frame)), name=f"{exp_key}_score"),
                    ],
                    axis=1,
                )
                for i in range(data_utils.num_episodes(test_frame)):
                    episode_frame = data_utils.get_episode(
                        test_frame, i, use_index=True
                    )
                    episode = episode_frame.iloc[0].to_dict()["episode"]
                    non_episode_frame: pd.DataFrame = demo_frame[
                        demo_frame["episode"] != episode
                    ]

                    episode_scores = error_utils.compute_embedding_scores(
                        data_embeddings=non_episode_frame[embedding].values,
                        test_embeddings=episode_frame[embedding].values,
                        **EMBEDDING_SCORE_FNS[score_fn],
                    )
                    test_frame.loc[
                        test_frame["episode"] == episode, f"{exp_key}_score"
                    ] = episode_scores
            else:
                test_scores = error_utils.compute_embedding_scores(
                    data_embeddings=demo_frame[embedding].values,
                    test_embeddings=test_frame[embedding].values,
                    leave_one_out=leave_timestep_out,
                    **EMBEDDING_SCORE_FNS[score_fn],
                )
                test_frame = pd.concat(
                    [test_frame, pd.Series(test_scores, name=f"{exp_key}_score")],
                    axis=1,
                )

    test_frame = compute_cum_scores(test_frame, exp_keys)
    return test_frame


def evaluate_embedding_similarity(
    cfg: omegaconf.DictConfig,
    demo_dataset_path: Union[str, pathlib.Path],
    test_dataset_path: Union[str, pathlib.Path],
) -> Dict[str, Union[Dict[str, Any], pd.DataFrame]]:
    """Compute embedding similarity results."""
    # Construct episode iterable datasets.
    demo_dataset = EpisodeDataset(
        dataset_path=demo_dataset_path,
        exec_horizon=1,
        sample_history=0,
        filter_success=getattr(cfg.eval, "filter_demo_success", False),
        filter_failure=getattr(cfg.eval, "filter_demo_failure", True),
        filter_episodes=getattr(cfg.eval, "filter_demo_episodes", None),
        max_episode_length=getattr(cfg.eval, "max_demo_episode_length", None),
        max_num_episodes=getattr(cfg.eval, "max_num_episodes", None),
    )
    test_dataset = EpisodeDataset(
        dataset_path=test_dataset_path,
        exec_horizon=1,
        sample_history=0,
        filter_success=getattr(cfg.eval, "filter_test_success", False),
        filter_failure=getattr(cfg.eval, "filter_test_failure", False),
        filter_episodes=getattr(cfg.eval, "filter_test_episodes", None),
        max_episode_length=getattr(cfg.eval, "max_test_episode_length", None),
        max_num_episodes=getattr(cfg.eval, "max_num_episodes", None),
    )

    # Compute scores for specified parameter sets.
    results_dict = defaultdict(dict)
    demo_results_frame = compute_embedding_similarity_scores(
        cfg,
        demo_dataset=demo_dataset,
        demo_as_test=True,
        leave_episode_out=getattr(cfg.eval.embedding, "leave_episode_out", True),
        leave_timestep_out=getattr(cfg.eval.embedding, "leave_timestep_out", False),
    )
    test_results_frame = compute_embedding_similarity_scores(
        cfg,
        test_dataset=test_dataset,
        demo_frame=demo_results_frame,
    )

    # Compute metrics for specified parameter sets.
    for embedding in cfg.eval.embedding.embeddings:
        for score_fn in cfg.eval.embedding.score_fns:

            exp_key = embedding_similarity_exp_key(
                embedding=embedding,
                score_fn=score_fn,
            )

            for quantile in cfg.eval.quantiles:

                test_results_frame = metric_utils.compute_detection_results(
                    exp_key=exp_key,
                    quantile_key=quantile_exp_key(exp_key, quantile),
                    results_dict=results_dict,
                    demo_results_frame=demo_results_frame,
                    test_results_frame=test_results_frame,
                    detector=getattr(cfg.eval, "detector", "quantile"),
                    detector_kwargs={
                        "quantile": quantile,
                        **getattr(cfg.eval, "detector_kwargs", {}),
                    },
                )

    return {
        "results_dict": results_dict,
        "test_results_frame": test_results_frame,
        "demo_results_frame": demo_results_frame,
    }

In [15]:
conf_dict = {
    "env": {
        "args": {
            "freq": 1,
            "max_episode_length": 300
        },
        "dof": 2,
        "num_eef": 1
    },
    "model": {
        "ac_horizon": 8,
    },
    "eval": {
        "consistency": {
            "sample_sizes": [128, 256],
            "error_fns": ["mse_all", "mmd_rbf_all_median", "kde_kl_all_for_eig"], 
            "pred_horizons": [16],
            "aggr_fns": ["min"]
        },
        "quantiles": [0.95],
        "embedding": {
            "embeddings": ["step_obs"],
            "score_fns": ["top5_cosine", "top10_cosine", "mahal"]
        }
    }
}

cfg = omegaconf.OmegaConf.create(conf_dict)

demo_dataset_path = "logs/datasets/no_domain_randomization_v8_simple_env"
test_dataset_path = "logs/datasets/domain_randomization_v2"


In [16]:
# Run eval
output_data = evaluate_embedding_similarity(cfg, demo_dataset_path, test_dataset_path)


Episode Results: ep_iid_cum | embedding_step_obs_score_fn_top5_cosine_quantile_0.95
TPR: 0.03 | TNR: 0.90 | Acc: 0.20 | Bal. Acc: 0.46
TP Time 0.00 (0.00)

Episode Results: ep_iid_cum | embedding_step_obs_score_fn_top10_cosine_quantile_0.95
TPR: 0.03 | TNR: 1.00 | Acc: 0.22 | Bal. Acc: 0.51
TP Time 0.00 (0.00)

Episode Results: ep_iid_cum | embedding_step_obs_score_fn_mahal_quantile_0.95
TPR: 1.00 | TNR: 1.00 | Acc: 1.00 | Bal. Acc: 1.00
TP Time 220.60 (17.05)


In [None]:
df = pd.read_pickle("logs/datasets/no_domain_randomization_v8_simple_env/episode_000_failure.pkl")

img_feats = df.step_image_features[0]
agent_pos = df.step_agent_poses[0]

## Testing the embedding variance code

In [23]:
ENSEMBLE_ACTION_SPACES: Dict[str, Dict[str, Any]] = {
    "all": {
        "ignore_gripper": True,
        "ignore_rotation": False,
    },
    "pos": {
        "ignore_gripper": True,
        "ignore_rotation": True,
    },
    "traj": {"ignore_gripper": True, "ignore_rotation": True, "use_trajectory": True},
}

def diffusion_ensemble_exp_key(
    pred_horizon: int,
    sample_size: int,
    action_space: str,
) -> str:
    return f"pred_horizon_{pred_horizon}_sample_size_{sample_size}_action_space_{action_space}"


def compute_diffusion_ensemble_variances(
    cfg: omegaconf.DictConfig,
    dataset: EpisodeDataset,
) -> pd.DataFrame:
    """Compute diffusion ensemble variances over dataset."""
    results = []
    exp_keys = []

    for data in iter(dataset):

        results.append(
            {
                "episode": data["episode"],
                "timestep": data["timestep"],
                "success": data.get("success", True),
            }
        )
        rgb = get_rgb(data)
        if isinstance(rgb, np.ndarray):
            results[-1]["rgb"] = rgb

        for sample_size in cfg.eval.ensemble.sample_sizes:

            # Subsample current actions.
            actions = action_utils.subsample_actions(
                data["sampled_actions"],
                sample_size,
            )

            for pred_horizon in cfg.eval.ensemble.pred_horizons:
                for action_space in cfg.eval.ensemble.action_spaces:

                    exp_key = diffusion_ensemble_exp_key(
                        pred_horizon=pred_horizon,
                        sample_size=sample_size,
                        action_space=action_space,
                    )
                    if exp_key not in exp_keys:
                        exp_keys.append(exp_key)

                    # Compute variance (vectorized).
                    variance = error_utils.compute_action_variance(
                        actions=actions,
                        pred_horizon=pred_horizon,
                        sim_freq=cfg.env.args.freq,
                        num_robots=cfg.env.num_eef,
                        action_dim=cfg.env.dof,
                        **ENSEMBLE_ACTION_SPACES[action_space],
                    )
                    results[-1][f"{exp_key}_score"] = variance

    results_frame = compute_cum_scores(pd.DataFrame(results), exp_keys)
    return results_frame


def evaluate_diffusion_ensemble(
    cfg: omegaconf.DictConfig,
    demo_dataset_path: Union[str, pathlib.Path],
    test_dataset_path: Union[str, pathlib.Path],
) -> Dict[str, Union[Dict[str, Any], pd.DataFrame]]:
    """Compute diffusion ensemble results."""
    # Construct episode iterable datasets.
    demo_dataset = EpisodeDataset(
        dataset_path=demo_dataset_path,
        exec_horizon=1,
        sample_history=0,
        filter_success=getattr(cfg.eval, "filter_demo_success", False),
        filter_failure=getattr(cfg.eval, "filter_demo_failure", True),
        filter_episodes=getattr(cfg.eval, "filter_demo_episodes", None),
        max_episode_length=getattr(cfg.eval, "max_demo_episode_length", None),
        max_num_episodes=getattr(cfg.eval, "max_num_episodes", None),
    )
    test_dataset = EpisodeDataset(
        dataset_path=test_dataset_path,
        exec_horizon=1,
        sample_history=0,
        filter_success=getattr(cfg.eval, "filter_test_success", False),
        filter_failure=getattr(cfg.eval, "filter_test_failure", False),
        filter_episodes=getattr(cfg.eval, "filter_test_episodes", None),
        max_episode_length=getattr(cfg.eval, "max_test_episode_length", None),
        max_num_episodes=getattr(cfg.eval, "max_num_episodes", None),
    )

    # Compute scores for specified parameter sets.
    results_dict = defaultdict(dict)
    demo_results_frame = compute_diffusion_ensemble_variances(cfg, demo_dataset)
    test_results_frame = compute_diffusion_ensemble_variances(cfg, test_dataset)

    # Compute metrics for specified parameter sets.
    for sample_size in cfg.eval.ensemble.sample_sizes:
        for pred_horizon in cfg.eval.ensemble.pred_horizons:
            for action_space in cfg.eval.ensemble.action_spaces:

                exp_key = diffusion_ensemble_exp_key(
                    pred_horizon=pred_horizon,
                    sample_size=sample_size,
                    action_space=action_space,
                )

                for quantile in cfg.eval.quantiles:

                    test_results_frame = metric_utils.compute_detection_results(
                        exp_key=exp_key,
                        quantile_key=quantile_exp_key(exp_key, quantile),
                        results_dict=results_dict,
                        demo_results_frame=demo_results_frame,
                        test_results_frame=test_results_frame,
                        detector=getattr(cfg.eval, "detector", "quantile"),
                        detector_kwargs={
                            "quantile": quantile,
                            **getattr(cfg.eval, "detector_kwargs", {}),
                        },
                    )

    return {
        "results_dict": results_dict,
        "test_results_frame": test_results_frame,
        "demo_results_frame": demo_results_frame,
    }

In [24]:
conf_dict = {
    "env": {
        "args": {
            "freq": 1,
            "max_episode_length": 300
        },
        "dof": 2,
        "num_eef": 1
    },
    "model": {
        "ac_horizon": 8,
    },
    "eval": {
        "consistency": {
            "sample_sizes": [128, 256],
            "error_fns": ["mse_all", "mmd_rbf_all_median", "kde_kl_all_for_eig"], 
            "pred_horizons": [16],
            "aggr_fns": ["min"]
        },
        "quantiles": [0.95],
        "embedding": {
            "embeddings": ["step_obs"],
            "score_fns": ["top5_cosine", "top10_cosine", "mahal", ""]
        },
        "ensemble": {
            "sample_sizes": [128, 256],
            "pred_horizons": [16],
            "action_spaces": ["traj"]
        }
    }
}


cfg = omegaconf.OmegaConf.create(conf_dict)

demo_dataset_path = "logs/datasets/no_domain_randomization_v8_simple_env"
test_dataset_path = "logs/datasets/domain_randomization_v2"

In [None]:
# Eval code
output_data = evaluate_diffusion_ensemble(cfg, demo_dataset_path, test_dataset_path)


Episode Results: ep_iid_cum | pred_horizon_16_sample_size_128_action_space_traj_quantile_0.95
TPR: 0.00 | TNR: 0.90 | Acc: 0.18 | Bal. Acc: 0.45
TP Time -1.00 (-1.00)

Episode Results: ep_iid_cum | pred_horizon_16_sample_size_256_action_space_traj_quantile_0.95
TPR: 0.00 | TNR: 0.90 | Acc: 0.18 | Bal. Acc: 0.45
TP Time -1.00 (-1.00)


{'results_dict': defaultdict(dict,
             {'pred_horizon_16_sample_size_128_action_space_traj_quantile_0.95': {'t_iid': {'timestep': {'metrics': {'AUROC Score': 0.4275881645738713,
                  'TP': 36,
                  'TN': 189,
                  'FP': 8,
                  'FN': 1484,
                  'TPR': 0.02368421052631579,
                  'TNR': 0.9593908629441624,
                  'FPR': 0.04060913705583756,
                  'FNR': 0.9763157894736842,
                  'Accuracy': 0.1310425160163075,
                  'Balanced Accuracy': 0.4915375367352391,
                  'F1 Score': 0.04603580562659847},
                 'data': {'calib_scores': array([ 36.16705322,   9.51525116,   4.91235924, ..., 411.98022461,
                          11.48771191,   3.42585588]),
                  'test_scores': array([1102.70385742,   57.31099701,   12.93008232, ...,   15.6844635 ,
                            7.15046215,  176.56411743]),
                  'test_preds