In [1]:
#@markdown ### **Imports**
# diffusion policy import
from typing import Tuple, Sequence, Dict, Union, Optional, Callable
import numpy as np
import math
import torch
import torch.nn as nn
import zarr
from tqdm import tqdm
from PIL import Image
from skvideo.io import vwrite
from IPython.display import Video

import cv2
import pandas as pd
import os
import sys

import glob
from abc import ABC, abstractmethod

In [2]:
# General failure-detector interface; concrete detectors are meant to subclass it downstream.
class FailureDetector(ABC):
    """Abstract interface for an episode-level failure detector.

    Subclasses must implement :meth:`calibrate` and :meth:`detect_episode`.
    """

    def __init__(self, calibration_dataset_dir: str, config: Dict):
        """Store the calibration data location and detector configuration.

        :param calibration_dataset_dir: Directory holding the calibration episodes.
        :param config: Detector-specific configuration.
        """
        # Keep the constructor arguments so subclasses can use them in
        # ``calibrate`` / ``detect_episode`` instead of silently dropping them.
        self.calibration_dataset_dir = calibration_dataset_dir
        self.config = config

    @abstractmethod
    def detect_episode(self, episode_dataset_dir) -> Tuple[bool, Optional[float]]:
        """Decide whether a failure occurred in a single episode.

        :param episode_dataset_dir: Location of the episode to analyze.
        :return: ``(True, detection_time)`` if a failure is detected,
            otherwise ``(False, None)``.
        """

    @abstractmethod
    def calibrate(self) -> None:
        """Calibrate the failure detector.

        :return: None
        """

## Episode Dataset Code

In [3]:
from typing import Any, Dict, Union, List, Generator, Optional

import os
import glob
import pickle
import pathlib
import numpy as np
import pandas as pd
from torch.utils.data import IterableDataset
from copy import deepcopy


def load_pickle(path: Union[str, pathlib.Path]) -> pd.DataFrame:
    """Unpickle and return the object stored at ``path``.

    The annotation advertises a DataFrame, but the loaded object is returned
    as-is without any type check.

    NOTE(review): ``pickle.load`` can execute arbitrary code — only call this
    on trusted, locally produced files.
    """
    with open(path, mode="rb") as handle:
        return pickle.load(handle)


class EpisodeDataset(IterableDataset):
    """Iterate over fixed-stride windows of rows from pickled episodes.

    Every ``*.pkl`` file in ``dataset_path`` (sorted by name) is one episode,
    loaded with ``load_pickle`` and accessed via ``.iloc``, so each file is
    expected to unpickle to a ``pandas.DataFrame``.

    Each yielded sample holds ``sample_history + 1`` row dicts spaced
    ``exec_horizon`` rows apart and ending at the current row. When
    ``sample_history == 0`` the single row dict is yielded directly rather
    than as a one-element list.

    Args:
        dataset_path: Directory containing the episode ``*.pkl`` files.
        exec_horizon: Row stride between consecutive samples and between the
            rows inside one sample (must be >= 1).
        sample_history: Number of earlier rows included before the current
            row (must be >= 0).
        filter_success: If True, skip episodes whose first row has a truthy
            ``"success"`` (a missing ``"success"`` counts as success).
        filter_failure: If True, skip episodes whose first row has a falsy
            ``"success"``.
        filter_episodes: Optional collection of file indices (positions in the
            sorted file list) to skip. A string value is ignored.
        max_episode_length: If set, skip samples whose current row has
            ``"timestep" >= max_episode_length``.
        max_num_episodes: If set, only the first ``max_num_episodes`` files are
            considered. The limit counts files *before* the success/failure/
            index filters, so fewer episodes may actually be yielded.
    """

    def __init__(
        self,
        dataset_path: Union[str, pathlib.Path],
        exec_horizon: int,
        sample_history: int,
        filter_success: bool = False,
        filter_failure: bool = False,
        filter_episodes: Optional[List[int]] = None,
        max_episode_length: Optional[int] = None,
        max_num_episodes: Optional[int] = None,
    ) -> None:
        """Construct EpisodeDataset (see the class docstring for arguments)."""
        super().__init__()
        assert exec_horizon >= 1 and sample_history >= 0
        self._dataset_path = dataset_path
        # Sorted so file indices (used by filter_episodes / max_num_episodes)
        # are deterministic across runs.
        self._episode_files = sorted(glob.glob(os.path.join(dataset_path, "*.pkl")))
        self._exec_horizon = exec_horizon
        self._sample_history = sample_history
        self._filter_success = filter_success
        self._filter_failure = filter_failure
        self._filter_episodes = filter_episodes
        self._max_episode_length = max_episode_length
        self._max_num_episodes = max_num_episodes

    def __iter__(
        self,
    ) -> Generator[Union[Dict[str, Any], List[Dict[str, Any]]], None, None]:
        """Yield samples from every episode that passes the filters."""
        # Number of rows between the oldest and the current row of one sample.
        window_span = self._exec_horizon * self._sample_history

        for i, file_path in enumerate(self._episode_files):
            if self._max_num_episodes is not None and i >= self._max_num_episodes:
                # Indices only grow, so no later file can pass this check.
                break

            # Index-based filtering needs no file contents: skip before loading.
            if (
                self._filter_episodes is not None
                and not isinstance(self._filter_episodes, str)
                and i in self._filter_episodes
            ):
                continue

            episode = load_pickle(file_path)
            if len(episode) == 0:
                # An empty episode yields no windows; skipping it here also
                # avoids an IndexError from the iloc[0] lookup below.
                continue

            success = episode.iloc[0].to_dict().get("success", True)
            if (self._filter_success and success) or (
                self._filter_failure and not success
            ):
                continue

            for idx in range(window_span, len(episode), self._exec_horizon):
                if (
                    self._max_episode_length is not None
                    and episode.iloc[idx].to_dict()["timestep"]
                    >= self._max_episode_length
                ):
                    continue

                sample = [
                    episode.iloc[j].to_dict()
                    for j in range(idx - window_span, idx + 1, self._exec_horizon)
                ]
                yield sample[0] if len(sample) == 1 else sample

In [6]:
# Consecutive-row pairs ([previous, current]) from every episode, no filtering.
test_dataset = EpisodeDataset(
    dataset_path="logs/datasets/no_domain_randomization_v5",
    exec_horizon=1,
    sample_history=1,
    filter_success=False,
    filter_failure=False,
)

In [7]:
# Sanity check: with sample_history=1 every sample must be a
# [previous, current] pair of row dicts (a bare dict with two keys would
# also satisfy ``len(sample) == 2``, so check the type explicitly).
count = 0
try:
    for sample in test_dataset:
        assert isinstance(sample, list), f"expected a list of row dicts, got {type(sample).__name__}"
        assert len(sample) == 2
        assert all(isinstance(row, dict) and "timestep" in row for row in sample)
        count += 1
    print(f"✅ Completed iteration over dataset: {count} samples seen.")
except Exception as e:
    print(f"❌ Error after {count} samples: {e!r}")
    raise

✅ Completed iteration over dataset: 8654 samples seen.


In [8]:
# Pull the first window by hand to inspect the (previous, current) row dicts.
it = iter(test_dataset)
(prev_data, cur_data) = next(it)

### Testing STAC Pipeline Code

Error functions (`error_fns`) used in the eval script:
- "mmd_rbf_all_median"
- "kde_kl_all_for_eig"
- "kde_kl_all_rev_eig"

In [9]:
from typing import Union, Dict, Callable, Any, List, Optional
import os
import sys
import torch
import hydra
import random
import pickle
import imageio
import pathlib
import omegaconf
import numpy as np
import pandas as pd
from collections import defaultdict

from src.stac import utils
import src.stac.dataset_utils as data_utils
from src.stac import error_utils, metric_utils, action_utils

import omegaconf


In [10]:
# Reductions applied to the per-step errors of the CONSISTENCY_ERROR_FNS below.
CONSISTENCY_AGGR_FNS: Dict[str, Callable[[np.ndarray], float]] = {
    "min": np.min,
    "max": np.max,
    "mean": np.mean,
    "std_dev": np.std,
    "var": np.var,
}

# Error functions whose output is reduced with an entry of CONSISTENCY_AGGR_FNS.
CONSISTENCY_ERROR_FNS: Dict[str, Dict[str, Any]] = {
    "mse_all": {
        "error_fn": "mse",
        "ignore_gripper": True,
        "ignore_rotation": False,
    },
    "mse_pos": {
        "error_fn": "mse",
        "ignore_gripper": True,
        "ignore_rotation": True,
    },
    "ate_pos": {
        "error_fn": "ate",
        "ignore_gripper": True,
        "ignore_rotation": True,
    },
}

# Scales swept for both the MMD ``gamma`` and the KDE ``bandwidth``; they also
# form the numeric suffix of the corresponding table keys (e.g. "_0.1").
_SWEEP_SCALES = (0.1, 0.5, 5.0, 10.0, 100.0)


def _mmd_rbf_spec(
    gamma: Union[float, str, None], ignore_rotation: bool = False
) -> Dict[str, Any]:
    """Keyword arguments for an ``mmd_rbf`` error (always ignore_gripper=True)."""
    return {
        "error_fn": "mmd_rbf",
        "ignore_gripper": True,
        "ignore_rotation": ignore_rotation,
        "gamma": gamma,
    }


def _kde_kl_spec(forward: bool, bandwidth: Union[float, str]) -> Dict[str, Any]:
    """Keyword arguments for a ``kde_kl`` error (ignore_gripper=True, rotation kept)."""
    return {
        "error_fn": "kde_kl",
        "ignore_gripper": True,
        "ignore_rotation": False,
        "forward": forward,
        "bandwidth": bandwidth,
    }


# Distribution-level error functions: they compare whole sampled action sets.
CONSISTENCY_DIST_ERROR_FNS: Dict[str, Dict[str, Any]] = {
    # MMD - Maximum Mean Discrepancy.
    "mmd_rbf_pos": _mmd_rbf_spec(1.0, ignore_rotation=True),
    "mmd_rbf_all": _mmd_rbf_spec(None),
    "mmd_rbf_all_1.0": _mmd_rbf_spec(1.0),
    "mmd_rbf_all_median": _mmd_rbf_spec("median"),
    "mmd_rbf_all_eig": _mmd_rbf_spec("max_eig"),
    **{f"mmd_rbf_all_{scale}": _mmd_rbf_spec(scale) for scale in _SWEEP_SCALES},
    # KDE KL divergence: forward ("for") then reverse ("rev") direction, each
    # with the default bandwidth 1.0 (no suffix), "max_eig", then the sweep.
    **{
        f"kde_kl_all_{direction}{suffix}": _kde_kl_spec(forward, bandwidth)
        for direction, forward in (("for", True), ("rev", False))
        for suffix, bandwidth in (
            ("", 1.0),
            ("_eig", "max_eig"),
            *((f"_{scale}", scale) for scale in _SWEEP_SCALES),
        )
    },
}

# Experiment keys.
def temporal_consistency_exp_key(
    pred_horizon: int,
    sample_size: int,
    error_fn: str,
    aggr_fn: Optional[str] = None,
) -> str:
    """Build the column/experiment key for one temporal-consistency setting.

    Distribution-level error functions (keys of CONSISTENCY_DIST_ERROR_FNS)
    take no aggregation, so their key omits the ``_aggr_fn_...`` suffix; all
    other error functions append it (even when ``aggr_fn`` is None).
    """
    key = f"pred_horizon_{pred_horizon}_sample_size_{sample_size}_error_fn_{error_fn}"
    if error_fn in CONSISTENCY_DIST_ERROR_FNS:
        return key
    return f"{key}_aggr_fn_{aggr_fn}"
    
def quantile_exp_key(exp_key: str, quantile: float = 0.95) -> str:
    """Suffix ``exp_key`` with the detection quantile, e.g. ``"<key>_quantile_0.95"``."""
    return "{}_quantile_{}".format(exp_key, quantile)

def get_consistency_aggr_fns(
    cfg: omegaconf.DictConfig,
    error_fn: str,
) -> List[Optional[str]]:
    """Return the aggregation names to evaluate for ``error_fn``.

    Error functions listed in CONSISTENCY_ERROR_FNS are reduced with every
    name in ``cfg.eval.consistency.aggr_fns``; any other error function is
    evaluated once, with the placeholder ``[None]``.
    """
    if error_fn in CONSISTENCY_ERROR_FNS:
        return cfg.eval.consistency.aggr_fns
    return [None]

def compute_cum_scores(
    results_frame: pd.DataFrame,
    exp_keys: List[str],
) -> pd.DataFrame:
    """Append a ``"{exp_key}_cum_score"`` column for every experiment key.

    Each column is ``data_utils.aggr_episode_key_data`` applied with
    ``np.cumsum`` to the ``"{exp_key}_score"`` column (presumably a running
    sum within each episode — see data_utils).

    :param results_frame: Frame containing one ``"{exp_key}_score"`` column per key.
    :param exp_keys: Experiment keys to accumulate.
    :return: A new frame with the cumulative columns appended, or
        ``results_frame`` itself when ``exp_keys`` is empty.
    """
    cum_columns = [
        pd.Series(
            data_utils.aggr_episode_key_data(
                results_frame,
                f"{exp_key}_score",
                np.cumsum,
            ),
            name=f"{exp_key}_cum_score",
        )
        for exp_key in exp_keys
    ]
    if not cum_columns:
        return results_frame
    # One concat instead of one per key: each concat copies the whole frame.
    return pd.concat([results_frame, *cum_columns], axis=1)

def get_rgb(data: Dict[str, Any]) -> Optional[np.ndarray]:
    """Return the first element of ``data["rgb"]``, or None if there is no image.

    ``None`` is returned when the ``"rgb"`` key is absent, holds None, or is
    empty, matching the ``Optional`` return type that callers test with
    ``isinstance(..., np.ndarray)``.

    NOTE(review): which image index 0 selects (first camera? first frame?)
    is not visible here — TODO confirm against the dataset writer.
    """
    rgb = data.get("rgb")
    if rgb is None or len(rgb) == 0:
        return None
    return rgb[0]

In [14]:
def compute_temporal_consistency_errors(
    cfg: omegaconf.DictConfig, dataset: EpisodeDataset
) -> pd.DataFrame:
    """Score temporal consistency for every (previous, current) sample pair.

    ``dataset`` must yield ``[prev_data, curr_data]`` row-dict pairs whose
    ``"timestep"`` values differ by exactly ``cfg.model.ac_horizon``.

    One output row is produced per pair with ``episode``, ``timestep``,
    ``success`` (default True) and, when ``get_rgb`` returns an ndarray,
    ``rgb``. For every sample size / error fn / aggregation / prediction
    horizon combination (horizons <= ``cfg.model.ac_horizon`` are skipped) a
    ``"{exp_key}_score"`` column is filled in; ``compute_cum_scores`` then
    appends the matching ``"{exp_key}_cum_score"`` columns.

    Raises:
        AssertionError: if a pair's timesteps are not ``ac_horizon`` apart.
        ValueError: if an error fn is in neither CONSISTENCY_ERROR_FNS nor
            CONSISTENCY_DIST_ERROR_FNS.
    """
    rows = []
    seen_exp_keys = []

    for prev_data, curr_data in dataset:
        assert curr_data["timestep"] - prev_data["timestep"] == cfg.model.ac_horizon

        row = {
            "episode": curr_data["episode"],
            "timestep": curr_data["timestep"],
            "success": curr_data.get("success", True),
        }
        rows.append(row)
        rgb = get_rgb(curr_data)
        if isinstance(rgb, np.ndarray):
            row["rgb"] = rgb

        for sample_size in cfg.eval.consistency.sample_sizes:
            # Subsample the current action set first, then the previous one.
            curr_actions = action_utils.subsample_actions(
                curr_data["sampled_actions"], sample_size
            )
            curr_skip_steps = curr_data.get("skip_steps", None)  # expected to be None
            prev_actions = action_utils.subsample_actions(
                prev_data["sampled_actions"], sample_size
            )
            prev_skip_steps = prev_data.get("skip_steps", None)  # expected to be None

            for error_fn in cfg.eval.consistency.error_fns:
                if error_fn in CONSISTENCY_ERROR_FNS:
                    # Reference is the previous step's ``executed_action``.
                    reference_actions = prev_data["executed_action"]
                    error_fn_kwargs = CONSISTENCY_ERROR_FNS[error_fn]
                elif error_fn in CONSISTENCY_DIST_ERROR_FNS:
                    # Reference is the previous step's subsampled action set.
                    reference_actions = prev_actions
                    error_fn_kwargs = CONSISTENCY_DIST_ERROR_FNS[error_fn]
                else:
                    raise ValueError(f"Error function {error_fn} not supported.")

                for aggr_fn in get_consistency_aggr_fns(cfg, error_fn):
                    # Only horizons longer than the executed chunk can overlap.
                    active_horizons = (
                        horizon
                        for horizon in cfg.eval.consistency.pred_horizons
                        if horizon > cfg.model.ac_horizon
                    )
                    for pred_horizon in active_horizons:
                        exp_key = temporal_consistency_exp_key(
                            pred_horizon=pred_horizon,
                            sample_size=sample_size,
                            error_fn=error_fn,
                            aggr_fn=aggr_fn,
                        )
                        if exp_key not in seen_exp_keys:
                            seen_exp_keys.append(exp_key)

                        score = error_utils.compute_temporal_error(
                            curr_action=curr_actions,
                            prev_action=reference_actions,
                            pred_horizon=pred_horizon,
                            exec_horizon=cfg.model.ac_horizon,
                            sim_freq=cfg.env.args.freq,
                            num_robots=cfg.env.num_eef,
                            action_dim=cfg.env.dof,
                            skip_steps=False,
                            curr_skip_steps=curr_skip_steps,
                            prev_skip_steps=prev_skip_steps,
                            **error_fn_kwargs,
                        )
                        if error_fn in CONSISTENCY_ERROR_FNS:
                            score = CONSISTENCY_AGGR_FNS[aggr_fn](score)
                        row[f"{exp_key}_score"] = score

    return compute_cum_scores(pd.DataFrame(rows), seen_exp_keys)


def _build_episode_dataset(
    cfg: omegaconf.DictConfig,
    dataset_path: Union[str, pathlib.Path],
    split: str,
    filter_failure_default: bool,
) -> EpisodeDataset:
    """Build a consecutive-pair EpisodeDataset for ``split`` ("demo" or "test").

    Reads the optional ``cfg.eval`` keys ``filter_{split}_success``,
    ``filter_{split}_failure``, ``filter_{split}_episodes``,
    ``max_{split}_episode_length`` and the shared ``max_num_episodes``.
    ``DictConfig.get`` is used so a missing key falls back to its default on
    every OmegaConf version; older versions (2.0) return None from attribute
    access on a missing key, which ``getattr(..., default)`` does not catch.
    """
    eval_cfg = cfg.eval
    return EpisodeDataset(
        dataset_path=dataset_path,
        exec_horizon=1,
        sample_history=1,
        filter_success=eval_cfg.get(f"filter_{split}_success", False),
        filter_failure=eval_cfg.get(f"filter_{split}_failure", filter_failure_default),
        filter_episodes=eval_cfg.get(f"filter_{split}_episodes", None),
        max_episode_length=eval_cfg.get(f"max_{split}_episode_length", None),
        max_num_episodes=eval_cfg.get("max_num_episodes", None),
    )


def evaluate_temporal_consistency(
    cfg: omegaconf.DictConfig,
    demo_dataset_path: Union[str, pathlib.Path],
    test_dataset_path: Union[str, pathlib.Path],
) -> Dict[str, Union[Dict[str, Any], pd.DataFrame]]:
    """Compute temporal-consistency scores and detection metrics.

    Scores are computed on both datasets with
    ``compute_temporal_consistency_errors``; then, for every experiment key and
    quantile, ``metric_utils.compute_detection_results`` compares the test
    scores against the demo scores and may return an updated test frame.

    :return: dict with ``results_dict`` (filled by compute_detection_results),
        ``test_results_frame`` and ``demo_results_frame``.
    """
    # Demo episodes drop failures by default; test episodes keep everything.
    demo_dataset = _build_episode_dataset(
        cfg, demo_dataset_path, "demo", filter_failure_default=True
    )
    test_dataset = _build_episode_dataset(
        cfg, test_dataset_path, "test", filter_failure_default=False
    )

    # Compute scores for specified parameter sets.
    results_dict = defaultdict(dict)
    demo_results_frame = compute_temporal_consistency_errors(cfg, demo_dataset)
    test_results_frame = compute_temporal_consistency_errors(cfg, test_dataset)

    detector = cfg.eval.get("detector", "quantile")
    extra_detector_kwargs = cfg.eval.get("detector_kwargs", {})

    # Compute metrics for specified parameter sets (same key grid as scoring).
    for sample_size in cfg.eval.consistency.sample_sizes:
        for error_fn in cfg.eval.consistency.error_fns:
            for aggr_fn in get_consistency_aggr_fns(cfg, error_fn):
                for pred_horizon in cfg.eval.consistency.pred_horizons:
                    if cfg.model.ac_horizon >= pred_horizon:
                        continue

                    exp_key = temporal_consistency_exp_key(
                        pred_horizon=pred_horizon,
                        sample_size=sample_size,
                        error_fn=error_fn,
                        aggr_fn=aggr_fn,
                    )

                    for quantile in cfg.eval.quantiles:
                        test_results_frame = metric_utils.compute_detection_results(
                            exp_key=exp_key,
                            quantile_key=quantile_exp_key(exp_key, quantile),
                            results_dict=results_dict,
                            demo_results_frame=demo_results_frame,
                            test_results_frame=test_results_frame,
                            detector=detector,
                            detector_kwargs={
                                "quantile": quantile,
                                **extra_detector_kwargs,
                            },
                        )

    return {
        "results_dict": results_dict,
        "test_results_frame": test_results_frame,
        "demo_results_frame": demo_results_frame,
    }

In [None]:
# Evaluation configuration. ``model.ac_horizon`` must equal the timestep gap
# between consecutive dataset rows (it is asserted per pair in
# ``compute_temporal_consistency_errors``).
conf_dict = dict(
    env=dict(
        args=dict(freq=1, max_episode_length=300),
        dof=2,
        num_eef=1,
    ),
    model=dict(ac_horizon=8),
    eval=dict(
        consistency=dict(
            sample_sizes=[128, 256],
            error_fns=["mse_all", "mmd_rbf_all_median", "kde_kl_all_for_eig"],
            pred_horizons=[16],
            aggr_fns=["min"],
        ),
        quantiles=[0.95],
    ),
)

cfg = omegaconf.OmegaConf.create(conf_dict)

dataset_path = "logs/datasets/no_domain_randomization_v5"

In [18]:
# Run the evaluation.
# NOTE(review): the same directory is used for both the demo (reference) and
# the test split, so the demo set is the successful subset of the test
# episodes themselves — the metrics are not held-out. Confirm this is intended.
output_data = evaluate_temporal_consistency(
    cfg,
    demo_dataset_path=dataset_path,
    test_dataset_path=dataset_path,
)


Episode Results: ep_iid_cum | pred_horizon_16_sample_size_128_error_fn_mse_all_aggr_fn_min_quantile_0.95
TPR: 0.46 | TNR: 0.65 | Acc: 0.55 | Bal. Acc: 0.55
TP Time 108.78 (87.47)

Episode Results: ep_iid_cum | pred_horizon_16_sample_size_128_error_fn_mmd_rbf_all_median_quantile_0.95
TPR: 0.69 | TNR: 0.55 | Acc: 0.62 | Bal. Acc: 0.62
TP Time 127.33 (83.51)

Episode Results: ep_iid_cum | pred_horizon_16_sample_size_128_error_fn_kde_kl_all_for_eig_quantile_0.95
TPR: 0.70 | TNR: 0.55 | Acc: 0.63 | Bal. Acc: 0.62
TP Time 136.07 (84.02)
