In [1]:
from collections import defaultdict
from typing import Dict

import numpy as np
from sqlalchemy.orm import Session

from ares.configs.base import Rollout
from ares.configs.pydantic_sql_helpers import recreate_model
from ares.databases.embedding_database import (
    BASE_EMBEDDING_DB_PATH,
    TEST_EMBEDDING_DB_PATH,
    FaissIndex,
    IndexManager,
)
from ares.databases.structured_database import (
    TEST_ROBOT_DB_PATH,
    RolloutSQLModel,
    setup_database,
)

In [None]:
!pip install matplotlib

In [3]:
# IndexManager, FaissIndex and the DB path constants are already imported in the first cell.
# Point this at BASE_EMBEDDING_DB_PATH (presumably the non-test database) to run on it instead.
EMBEDDING_DB_PATH = TEST_EMBEDDING_DB_PATH
index_manager = IndexManager(EMBEDDING_DB_PATH, FaissIndex)


In [None]:
all_vecs = index_manager.get_all_matrices()

# Report the shape of every stored matrix, one line per index name.
for index_name, matrix in all_vecs.items():
    print(index_name, matrix.shape)

In [5]:

def get_dataset_statistics(arr: np.ndarray):
    """Summarize ``arr`` per feature dimension, pooling samples and timesteps.

    Args:
        arr: 3-D array laid out as ``[n_samples, n_timesteps, n_dims]``.

    Returns:
        Dict with per-dimension ``mean``/``std``/``min``/``max``/``q01``/``q99``
        arrays (each of length ``n_dims``), plus the integer counts
        ``num_transitions`` (samples * timesteps) and ``num_trajectories``.

    Raises:
        AssertionError: If ``arr`` is not 3-D.
    """
    assert len(arr.shape) == 3
    pooled = (0, 1)  # reduce over samples and timesteps, keep the last (feature) axis
    n_trajectories, n_timesteps, _ = arr.shape
    return dict(
        mean=np.mean(arr, axis=pooled),
        std=np.std(arr, axis=pooled),
        min=np.min(arr, axis=pooled),
        max=np.max(arr, axis=pooled),
        q01=np.percentile(arr, 1, axis=pooled),
        q99=np.percentile(arr, 99, axis=pooled),
        num_transitions=n_trajectories * n_timesteps,
        num_trajectories=n_trajectories,
    )

In [6]:
# Per-dimension summary statistics for every embedding matrix, keyed by index name.
norms = {key: get_dataset_statistics(matrix) for key, matrix in all_vecs.items()}

In [7]:
from enum import Enum
import numpy as np
# taken from openvla/prismatic https://github.com/openvla/openvla/blob/main/prismatic/vla/datasets/rlds/utils/data_utils.py

class NormalizationType(str, Enum):
    """Supported per-dimension normalization schemes.

    Subclasses ``str`` so members compare equal to their raw values,
    e.g. ``NormalizationType.BOUNDS == "bounds"``.
    """

    NORMAL = "normal"               # Normalize to Mean = 0, Stdev = 1
    BOUNDS = "bounds"               # Normalize to Interval = [-1, 1]
    BOUNDS_Q99 = "bounds_q99"       # Normalize [quantile_01, ..., quantile_99] --> [-1, ..., 1]


def normalize_dataset(
    vecs: np.ndarray,
    stats: Dict[str, np.ndarray],
    normalization_type: NormalizationType = NormalizationType.NORMAL,
    eps: float = 1e-8,
) -> np.ndarray:
    """Normalize ``vecs`` per last-axis dimension using precomputed statistics.

    Args:
        vecs: Array whose last axis holds the feature dimensions; the ``stats``
            arrays must broadcast against it (e.g. the per-dimension vectors
            returned by ``get_dataset_statistics``).
        stats: Mapping providing ``'mean'``/``'std'`` (NORMAL),
            ``'min'``/``'max'`` (BOUNDS) or ``'q01'``/``'q99'`` (BOUNDS_Q99).
            Extra keys are ignored.
        normalization_type: Scheme to apply; plain strings such as ``"bounds"``
            are accepted because the enum subclasses ``str``.
        eps: Added to the denominator so constant dimensions (zero std or zero
            range) yield finite values instead of inf/NaN.

    Returns:
        NORMAL: standardized values (zero mean, unit std; unbounded).
        BOUNDS / BOUNDS_Q99: values affinely mapped so ``[low, high] -> [-1, 1]``
        and then clipped to ``[-1, 1]``.

    Raises:
        ValueError: If ``normalization_type`` is not a supported scheme.
    """
    if normalization_type == NormalizationType.NORMAL:
        # Plain standardization; rescaling with 2*x - 1 here would move the mean to -1.
        return (vecs - stats['mean']) / (stats['std'] + eps)

    if normalization_type == NormalizationType.BOUNDS:
        low, high = stats['min'], stats['max']
    elif normalization_type == NormalizationType.BOUNDS_Q99:
        low, high = stats['q01'], stats['q99']
    else:
        raise ValueError(f"Unsupported normalization type: {normalization_type!r}")

    # Map [low, high] -> [-1, 1] first, then clip, so values outside the
    # bounds (e.g. beyond the 1st/99th percentiles) saturate at exactly +/-1.
    scaled = 2 * (vecs - low) / (high - low + eps) - 1
    return np.clip(scaled, -1, 1)



In [8]:
norm_type = NormalizationType.BOUNDS_Q99
# Look statistics up by key rather than zipping two dicts' values in parallel, so each
# matrix is always paired with its own stats (and a missing entry raises KeyError).
normed_vecs = {
    name: normalize_dataset(vecs, norms[name], norm_type)
    for name, vecs in all_vecs.items()
}


In [None]:
import matplotlib.pyplot as plt

n_graphs = len(all_vecs)
# One row per embedding matrix: raw histograms on the left, normalized on the right.
# squeeze=False keeps `axs` 2-D even when there is a single row, so `axs[i, 0]`
# works for any n_graphs >= 1 (the default squeezing breaks it when n_graphs == 1).
fig, axs = plt.subplots(
    n_graphs, 2, figsize=(15, 5 * n_graphs),  # Height scales with number of distributions
    gridspec_kw={'hspace': 0.5, 'wspace': 0.3},
    squeeze=False,
)

for i, (name, vecs) in enumerate(all_vecs.items()):
    raw_ax, normed_ax = axs[i, 0], axs[i, 1]

    # Unnormalized data: one overlaid histogram per dimension.
    for j in range(vecs.shape[-1]):
        raw_ax.hist(vecs[:, :, j].flatten(), bins=100, label=f"dim {j}")
    raw_ax.set_title(f"{name} (unnormalized)")
    raw_ax.legend()

    # Normalized data. The histogram range is fixed to [-1, 1], the intended output
    # range of the BOUNDS / BOUNDS_Q99 schemes; values outside it (e.g. NORMAL output)
    # are not counted.
    normed_vec = normed_vecs[name]
    for j in range(normed_vec.shape[-1]):
        normed_ax.hist(normed_vec[:, :, j].flatten(), bins=100, label=f"dim {j}", range=(-1, 1))
    normed_ax.set_title(f"{name} (normalized via {norm_type})")
    normed_ax.legend()

plt.tight_layout()
plt.show()


In [None]:
# Pick a handful of trajectories (positions along the first axis) to inspect in detail.
TRAJS = [0, 1, 2, 3, 4]

# Fail early with a clear message rather than an IndexError from fancy indexing.
_too_short = {k: v.shape[0] for k, v in all_vecs.items() if v.shape[0] <= max(TRAJS)}
assert not _too_short, f"TRAJS needs more trajectories than these matrices have: {_too_short}"

# Subsets are [n_selected_trajectories x n_timesteps x n_dims]; row i is trajectory TRAJS[i].
THESE_VECS = {k: v[TRAJS] for k, v in all_vecs.items()}
THESE_NORMED_VECS = {k: v[TRAJS] for k, v in normed_vecs.items()}

{k: v.shape for k, v in THESE_VECS.items()}


In [20]:
def plot_trajectories(normed_vecs_dict, traj_indices, highlight_idx=None, figsize=(15,15)):
    """Plot every dimension of a few trajectories over time, one figure per entry.

    For each ``(name, array)`` in ``normed_vecs_dict`` a square grid of subplots is
    drawn, one subplot per last-axis dimension, with one line per trajectory.

    Args:
        normed_vecs_dict: Mapping from name to a 3-D array indexed as
            ``[trajectory, timestep, dim]``. Row ``i`` is assumed to hold trajectory
            ``traj_indices[i]`` (as produced by ``v[traj_indices]``).
        traj_indices: Trajectory ids matching the rows of each array; used in labels.
        highlight_idx: Position in ``traj_indices`` of the trajectory to draw in bold
            red on top of the others; ``None`` draws every trajectory in gray.
        figsize: Figure size tuple (width, height) for each figure.

    Raises:
        IndexError: If ``highlight_idx`` is not a valid position in ``traj_indices``.
    """
    # Resolve the highlighted trajectory id once instead of on every inner iteration.
    highlighted_traj = None if highlight_idx is None else traj_indices[highlight_idx]

    for name, traj_vecs in normed_vecs_dict.items():
        n_dims = traj_vecs.shape[-1]

        # Smallest square grid that fits every dimension.
        grid_size = int(np.ceil(np.sqrt(n_dims)))

        # squeeze=False always returns a 2-D array of Axes, so flattening also works
        # when n_dims == 1 (a lone Axes object has no .flatten()).
        fig, axs = plt.subplots(
            grid_size, grid_size, figsize=figsize,
            gridspec_kw={'hspace': 0.4, 'wspace': 0.3},
            squeeze=False,
        )
        axs = axs.flatten()

        for dim in range(n_dims):
            ax = axs[dim]
            for i, traj in enumerate(traj_indices):
                if highlight_idx is not None and traj == highlighted_traj:
                    # Highlighted trajectory: bold red, raised zorder so later gray
                    # lines cannot paint over it.
                    ax.plot(traj_vecs[i, :, dim],
                            color='red', linewidth=2.5, zorder=3,
                            label=f'Traj {traj} (highlighted)')
                else:
                    # Other trajectories: thin translucent gray lines.
                    ax.plot(traj_vecs[i, :, dim],
                            color='gray', alpha=0.5, linewidth=1,
                            label=f'Traj {traj}')

            ax.set_title(f'{name} Dim {dim} (Normalized)')
            ax.set_xlabel('Timestep')
            ax.set_ylabel('Value')
            ax.legend()

        # Drop the unused cells of the square grid.
        for unused_ax in axs[n_dims:]:
            fig.delaxes(unused_ax)

        fig.suptitle(f'{name} Trajectories', y=1.02, fontsize=16)
        fig.tight_layout()
        plt.show()
        # Release the figure; this function may create one per entry.
        plt.close(fig)


In [None]:
# Overlay the selected trajectories, emphasising the one at position 2 of TRAJS.
plot_trajectories(
    normed_vecs_dict=THESE_NORMED_VECS,
    traj_indices=TRAJS,
    highlight_idx=2,
)
