# Visualize Datasets

Compare the real and the simulated dataset to make sure they are reasonably similar.

## Load Data

In [None]:
import re
from pathlib import Path
from PIL import Image
from tqdm.notebook import tqdm
import numpy as np
from cvla.data_loader import JSONLDataset

# Load the real and sim validation splits and extract the location tokens of
# every annotation suffix.
ACTION_ENCODER = "xyzrotvec-cam-proj"

# Matches the 4-digit location tokens in an annotation suffix, e.g. "<loc0512>".
LOC_TOKEN_RE = re.compile(r"<loc(\d{4})>")


def _parse_loc_tokens(suffix):
    """Return the integer values of all ``<locNNNN>`` tokens in *suffix*, in order."""
    return [int(x) for x in LOC_TOKEN_RE.findall(suffix)]


real_dataset_location = Path("/data/lmbraid19/argusm/datasets/clevr-real-block-v1")
real_dataset = JSONLDataset(
    jsonl_file_path=f"{real_dataset_location}/dataset/_annotations.valid.jsonl",
    image_directory_path=f"{real_dataset_location}/dataset",
)
sim_dataset_location = Path("/data/lmbraid19/argusm/datasets/clevr-act-6-var-cam")
sim_dataset = JSONLDataset(
    jsonl_file_path=f"{sim_dataset_location}/dataset/_annotations.valid.jsonl",
    image_directory_path=f"{sim_dataset_location}/dataset",
)

# Currently unused placeholders.
pred_list = []
valid_sample = []

# Sample i of the real set is paired with sample i of the sim set, so only walk
# as far as the shorter dataset; indexing past the end of either would fail.
num_samples = min(len(real_dataset), len(sim_dataset))
suffix_rs = []
suffix_ss = []
for i in tqdm(range(num_samples)):
    image_r, label_real = real_dataset[i]
    image_s, label_sim = sim_dataset[i]
    suffix_rs.append(_parse_loc_tokens(label_real["suffix"]))
    suffix_ss.append(_parse_loc_tokens(label_sim["suffix"]))

# One row per sample, one column per location token; np.array only yields a
# 2-D array if every suffix contains the same number of tokens.
suffix_rs = np.array(suffix_rs)
suffix_ss = np.array(suffix_ss)


## Plot Raw Value Histograms

In [None]:
import matplotlib.pyplot as plt

# 1 row x 3 columns: image-position scatter, histogram of token 2, (spare panel).
fig, axes = plt.subplots(1, 3, figsize=(12, 12 * 1 / 3))

names = "real", "sim"
data = (suffix_rs, suffix_ss)

# Columns 0-5 describe the first pose and 6-11 the second ("-dst") pose; token 2
# sits at the same offset within each pose. Use one set of bin edges for every
# histogram on axes[1]: separate hist(bins=10) calls would each bin over their
# own data range, making the overlaid bars incomparable.
token2_values = np.concatenate([d[:, c] for d in data for c in (2, 2 + 6)])
token2_bins = np.histogram_bin_edges(token2_values, bins=10)

for name, values in zip(names, data):
    axes[0].scatter(values[:, 0], values[:, 1], alpha=0.7, label=name)
    axes[0].scatter(values[:, 6], values[:, 7], alpha=0.7, label=name + "-dst")
    axes[1].hist(values[:, 2], bins=token2_bins, alpha=0.5, edgecolor='black', label=name)
    axes[1].hist(values[:, 2 + 6], bins=token2_bins, alpha=0.5, edgecolor='black', label=name + "-dst")

axes[0].set_title("loc tokens 0/1 (image xy)")
axes[0].legend()
axes[1].set_title("loc token 2")
axes[1].set_ylabel("Count")
axes[1].legend()
# The third panel currently has no content; hide its empty frame.
axes[2].axis("off")



## Analyze Text

In [None]:
#!pip install spacy
#!python -m spacy download en_core_web_sm

In [None]:
import spacy
nlp = spacy.load("en_core_web_sm")

def extract_components(sentence):
    """Pull a coarse verb / object / preposition frame out of *sentence*.

    The sentence is parsed with the module-level spaCy pipeline ``nlp`` and its
    tokens are scanned left to right. A token whose coarse POS tag is ``VERB``
    fills ``verb``. Any other token is routed by its dependency label:
    ``dobj`` -> ``obj``, ``prep`` -> ``prep``, ``pobj`` -> ``prep_obj``.
    Later matches overwrite earlier ones, so each slot ends up holding the text
    of the last matching token, or None if nothing matched.

    Returns:
        dict with keys ``verb``, ``obj``, ``prep``, ``prep_obj`` (in that order).
    """
    frame = {"verb": None, "obj": None, "prep": None, "prep_obj": None}
    slot_for_dep = {"dobj": "obj", "prep": "prep", "pobj": "prep_obj"}

    for tok in nlp(sentence):
        if tok.pos_ == "VERB":
            frame["verb"] = tok.text
            continue
        slot = slot_for_dep.get(tok.dep_)
        if slot is not None:
            frame[slot] = tok.text

    return frame

def _normalize_prompt(prefix):
    """Return the text of *prefix* before its first "<", lower-cased, newlines removed.

    NOTE(review): the part from the first "<" onward is presumably the special
    tokens appended to the prompt — confirm against the dataset format.
    """
    return prefix.split("<")[0].replace("\n", "").lower()


# Parse every real and sim prompt into (verb, obj, prep, prep_obj) frames.
list_of_comp_real = []
list_of_comp_sim = []
texts = []  # normalized real prompts
# Both datasets are indexed with the same i, so stop at the shorter one.
num_samples = min(len(real_dataset), len(sim_dataset))
for i in tqdm(range(num_samples)):
    image_r, label_real = real_dataset[i]
    text_real = _normalize_prompt(label_real["prefix"])
    list_of_comp_real.append(extract_components(text_real))
    texts.append(text_real)

    image_s, label_sim = sim_dataset[i]
    # Normalize sim prompts exactly like the real ones; otherwise differences in
    # case or stray newlines alone show up as distinct values in the
    # real-vs-sim frequency tables and can change the spaCy parse.
    text_sim = _normalize_prompt(label_sim["prefix"])
    list_of_comp_sim.append(extract_components(text_sim))

In [None]:
from collections import Counter, defaultdict

def count_value_frequencies(list_of_dicts):
    """Tally, for every key, how often each value appears across *list_of_dicts*.

    Returns:
        ``defaultdict(Counter)`` mapping key -> Counter(value -> count), with
        keys in first-seen order. Looking up a key that never occurred yields
        (and inserts) an empty Counter.
    """
    tallies = defaultdict(Counter)
    for record in list_of_dicts:
        for field, val in record.items():
            # One occurrence of this value under this key.
            tallies[field].update((val,))
    return tallies

key_value_counts_real = count_value_frequencies(list_of_comp_real)
key_value_counts_sim = count_value_frequencies(list_of_comp_sim)

# For each component key, print how often each value occurs in real vs sim:
# first the values seen in real (most common first), then values seen only in sim.
for field, real_counter in key_value_counts_real.items():
    sim_counter = key_value_counts_sim[field]
    print(f"Key: {field}")
    for val, n_real in real_counter.most_common():
        print(f"  Value: {val}\t freq-real: {n_real} freq_sim: {sim_counter[val]}")
    sim_only = [(val, n_sim) for val, n_sim in sim_counter.items() if val not in real_counter]
    for val, n_sim in sim_only:
        print(f"  Value: {val}\t freq-real: {0} freq_sim: {n_sim}")
    print("-" * 40)


# One bar chart per component key (verb, obj, prep, prep_obj), real vs sim.
pos_keys = list(key_value_counts_real)
# squeeze=False keeps `axes` 2-D even for a single key, so flatten() always works.
fig, axes = plt.subplots(1, len(pos_keys), figsize=(18, 12 * 1 / 4), squeeze=False)
axes = axes.flatten()
width = 0.35  # bar width; the real and sim bars sit side by side

for ax, pos in zip(axes, pos_keys):
    real_counts = key_value_counts_real[pos]
    sim_counts = key_value_counts_sim[pos]
    # Sort with key=str: values may include None, and iterating a set of strings
    # gives a different order on each interpreter run (hash randomization),
    # which would shuffle the x-axis between runs.
    all_keys = sorted(set(real_counts) | set(sim_counts), key=str)
    real_freq = [real_counts[value] for value in all_keys]  # Counter -> 0 if absent
    sim_freq = [sim_counts[value] for value in all_keys]

    x = np.arange(len(all_keys))
    # Offset each series by half a bar so the two do not overlap.
    ax.bar(x - width / 2, real_freq, width, label='Real Freq.', alpha=0.7)
    ax.bar(x + width / 2, sim_freq, width, label='Sim. Freq.', alpha=0.7)

    ax.set_xlabel('Text')
    ax.set_ylabel(f'{pos.capitalize()} Frequency')
    ax.set_xticks(x)
    ax.set_xticklabels([str(value) for value in all_keys], rotation=90)
    ax.legend()
plt.tight_layout()  # Adjust layout for better spacing
plt.show()

# Generate Similar Texts

TODO(maxim): Write a function that generates texts mirroring those of the original dataset.

(The original dataset's texts should be simplified first: lowercase them with `.lower()`, and remove any trailing period and trailing whitespace.)

1. Option 1: define some templates + frequency
2. Option 2: count word frequencies and follow `/ManiSkill/mani_skill/examples/utils_env_interventions.py`

In [None]:

# Report the number of prompts, the number of distinct prompts, and the distinct prompts.
unique_texts = set(texts)
print(len(texts))
print(len(unique_texts))
print(unique_texts)

## Visualize Angles

TODO(maxim): The angle distributions should be similar. Find out why they differ: is it because of the data, or (hopefully) because of a shift between the coordinate systems?


In [None]:
import matplotlib.pyplot as plt
from math import pi

# Real-vs-sim histograms of the angle-related tokens (columns 3-5 of each suffix row).
fig, axes = plt.subplots(1, 4, figsize=(12, 12 * 1 / 4))
names = "real", "sim"
data = (suffix_rs, suffix_ss)

# Panels 0-2: raw token values of columns 3, 4 and 5.
# Panel 3: norm of those three values after the `/100 - pi` mapping.
# NOTE(review): `/100 - pi` presumably maps tokens back to radians — confirm
# against the xyzrotvec-cam-proj decoder.
panels = [[d[:, 3 + j] for d in data] for j in range(3)]
panels.append([np.linalg.norm(d[:, 3:6] / 100 - pi, axis=1) for d in data])
titles = ["token 3", "token 4", "token 5", "norm(tokens 3-5 / 100 - pi)"]

for ax, title, per_dataset in zip(axes, titles, panels):
    # Shared bin edges: separate hist() calls would otherwise bin each dataset
    # over its own range, making the overlaid real/sim bars incomparable.
    bins = np.histogram_bin_edges(np.concatenate(per_dataset), bins=10)
    for name, values in zip(names, per_dataset):
        ax.hist(values, bins=bins, label=name, alpha=.7)
    ax.set_title(title)
    ax.legend()
    
    


In [None]:
# Plot angles to gripper
# https://stackoverflow.com/questions/31768031/plotting-points-on-the-surface-of-a-sphere

from cvla.utils_traj_tokens import getActionEncDecFunction, decode_caption_xyzrotvec
from cvla.utils_trajectory import DummyCamera
from scipy.spatial.transform import Rotation as R
import torch

# Encoder/decoder pair for the action encoding used throughout this notebook;
# reuse the ACTION_ENCODER constant instead of repeating the literal.
enc, dec = getActionEncDecFunction(ACTION_ENCODER)
# NOTE(review): a 224x224 DummyCamera with empty intrinsic/extrinsic matrices is
# passed to decode_caption_xyzrotvec below — confirm the decoder does not need
# real camera matrices for the rotation part.
camera = DummyCamera(intrinsic_matrix=[], extrinsic_matrix=[], width=224, height=224)

def rotation_to_spherical(rot: R):
    """
    Convert a single 3D rotation (a scipy Rotation object) into spherical coordinates.

    The rotation is taken in axis-angle form; its unit rotation axis is
    expressed as a point on the unit sphere.

    Returns:
    - azimuth (longitude) φ in radians: arctan2(y, x) of the unit axis
    - elevation (latitude) θ in radians: arcsin(z) of the unit axis
    - rotation angle α in radians: norm of the rotation vector
    For a (near-)zero rotation the axis is undefined and (0, 0, 0) is returned.
    """
    rotvec = rot.as_rotvec()
    angle = np.linalg.norm(rotvec)
    if np.isclose(angle, 0):  # axis undefined for the identity rotation
        return 0, 0, 0
    x, y, z = rotvec / angle  # unit axis on the sphere
    azimuth = np.arctan2(y, x)  # Longitude φ
    # Clip: rounding can push |z| marginally past 1, where arcsin returns NaN.
    elevation = np.arcsin(np.clip(z, -1.0, 1.0))  # Latitude θ
    return azimuth, elevation, angle  # (φ, θ, α)

# Spherical coordinates (azimuth, elevation, angle) of the first pose's rotation
# for every sample of each dataset.
datasets = (real_dataset, sim_dataset)
dataset_names = ("real", "sim")
sphericals = ([], [])

# Take the same number of samples from each dataset so np.array below yields a
# regular (num_datasets, limit, 3) array. range(limit) also covers limit == 0,
# where a manual "break at limit - 1" would never fire.
limit = min(len(ds) for ds in datasets)
for ds, out in zip(datasets, sphericals):
    for j in range(limit):
        suffix = ds[j][1]["suffix"]
        dec_gt = decode_caption_xyzrotvec(suffix, camera)
        # NOTE(review): dec_gt[1] is assumed to hold scalar-first quaternions,
        # one per pose; only the first one is analysed here.
        orns_R = R.from_quat(dec_gt[1][:2], scalar_first=True)
        out.append(rotation_to_spherical(orns_R[0]))

sphericals = np.array(sphericals)

In [None]:
# TODO(maxim): ideally this plot looks similar to the xy image position ones where sim is a superset of real.

# 3D scatter of the per-sample (azimuth, elevation, angle) triples, real vs sim.
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
for name, coords in zip(dataset_names, sphericals):
    ax.scatter(coords[:, 0], coords[:, 1], coords[:, 2], s=20, label=name)
# Column order matches rotation_to_spherical's return value.
ax.set_xlabel("azimuth [rad]")
ax.set_ylabel("elevation [rad]")
ax.set_zlabel("rotation angle [rad]")
ax.legend()
plt.show()

# Visualize Images

In [None]:
from cvla.utils_vis import render_example

# Render the first NUM_EXAMPLES real/sim pairs (real first, then sim) as HTML.
NUM_EXAMPLES = 12
# Both datasets are indexed with the same i, so never exceed the shorter one.
num_examples = min(NUM_EXAMPLES, len(real_dataset), len(sim_dataset))

html_parts = []
for i in tqdm(range(num_examples)):
    image_r, label_real = real_dataset[i]
    image_s, label_sim = sim_dataset[i]
    html_parts.append(render_example(image_r, label=label_real["suffix"], prediction=None, text=label_real["prefix"]))
    html_parts.append(render_example(image_s, label=label_sim["suffix"], prediction=None, text=label_sim["prefix"]))
# Join once rather than growing a string with += on every iteration.
html_imgs = "".join(html_parts)

plot_images = True
if plot_images:
    from IPython.display import display, HTML
    display(HTML(html_imgs))