# Visualize Data

In [None]:
%load_ext autoreload
%autoreload 2
from pathlib import Path
from data_loader_h5 import H5Dataset
from data_loader_jsonl import JSONLDataset
from data_augmentations import RandomizeBackgrounds, augment_image_rgb, complexify_text

# Build the training dataset with the augmentations passed below
# (augment_rgbds, augment_rgb, augment_text). The commented-out lines are
# alternative dataset configurations kept for quick switching.
dataset_location = "/tmp/clevr-act-7-depth"
#train_dataset = H5Dataset(dataset_location, augment_text=complexify_text)
# NOTE(review): p=0.2 is presumably the per-sample probability of swapping in a
# background image from background_images — confirm in RandomizeBackgrounds.
randomize_background = RandomizeBackgrounds(p=0.2, background_images = "/tmp/indoorCVPR/Images")
train_dataset = H5Dataset(dataset_location, augment_rgbds=randomize_background, augment_rgb=augment_image_rgb, augment_text=complexify_text)



#train_dataset = JSONLDataset(jsonl_file_path="/data/lmbraid19/argusm/datasets/clevr-real-block-v1")

#dataset_location = Path("/tmp/clevr-act-7-depth")
#randomize_background = RandomizeBackgrounds(p=0.2, background_images = "/tmp/indoorCVPR/Images")
#train_dataset = H5Dataset(dataset_location, augment_rgbds=randomize_background, augment_rgb=augment_image_rgb)

# Number of samples in the dataset.
print(len(train_dataset))

In [None]:
from IPython.display import display, HTML
from tqdm.notebook import tqdm
from utils_vis import render_example

print(len(train_dataset))
# Render the first samples as one HTML blob. Clamp to the dataset size so a
# small dataset does not raise IndexError.
num_samples = min(25, len(train_dataset))
html_parts = []
for i in tqdm(range(num_samples)):
    image, sample = train_dataset[i]
    # NOTE(review): image[1] is the element passed to render_example as the
    # frame — confirm against H5Dataset.__getitem__ which element this selects.
    html_parts.append(
        render_example(image[1], label=sample["suffix"], text=sample["prefix"], camera=sample["camera"])
    )
# Join once instead of growing the string with += on every iteration.
html_imgs = "".join(html_parts)
display(HTML(html_imgs))

# Depth: decode color-encoded depth images and compare them with the TCP z

In [6]:
# Re-create the dataset with return_depth=True and no augment_* arguments.
# The cells below unpack each item as ((depth, image), sample).
dataset_location = "/tmp/clevr-act-7-depth"
train_dataset = H5Dataset(dataset_location, return_depth=True)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.cm import viridis

class DepthColorToNorm:
    """Decode colormap-encoded images back to normalized values in [0, 1].

    The ``k``-th of ``K`` listed colors encodes the value ``k / (K - 1)``
    (i.e. ``np.linspace(0, 1, K)``). Each input color is matched to its
    nearest listed color by largest per-channel difference. A match within
    ``atol`` returns that color's value; a color farther from every listed
    color returns 0.

    Args:
        depth_min: Stored as ``self.depth_min``; not used by the lookup.
        depth_max: Stored as ``self.depth_max``; not used by the lookup.
        colors: RGB or RGBA colors with channels in [0, 1], in colormap order.
            Defaults to matplotlib's ``viridis.colors``.
        atol: Largest per-channel difference (in [0, 1] units) still counted
            as a match. The default of two 8-bit steps tolerates colors that
            were rounded or truncated to uint8 when an image was stored.
    """

    # Distinct colors matched per batch. This bounds the temporary distance
    # table at _CHUNK x len(colors) x 3 floats.
    _CHUNK = 4096

    def __init__(self, depth_min=0, depth_max=1023, colors=None, atol=2 / 255):
        if colors is None:
            colors = viridis.colors
        lut_rgb = np.asarray(colors, dtype=np.float64)[:, :3]
        lut_values = np.linspace(0, 1, len(lut_rgb))
        # Exact color -> value dict, kept for code that reads it directly.
        # __call__ uses the array form below, which tolerates small color errors.
        self.viridis_lut = {tuple(rgb): value for rgb, value in zip(lut_rgb.tolist(), lut_values)}
        self._lut_rgb = lut_rgb
        self._lut_values = lut_values
        self.atol = atol
        self.depth_min = depth_min
        self.depth_max = depth_max

    def __call__(self, array):
        """Return the normalized value encoded by each color in ``array``.

        Args:
            array: Either an ``np.ndarray`` with colors along the last axis,
                or a tuple holding a single color. The last axis needs at
                least 3 channels; extra channels such as alpha are ignored.
                Integer dtypes are taken as 0-255 and rescaled to [0, 1].
                Floats are taken as already in [0, 1].

        Returns:
            A float array of shape ``array.shape[:-1]`` for array input, or a
            single float for tuple input. Unmatched colors give 0.

        Raises:
            ValueError: If ``array`` is neither an ``np.ndarray`` nor a tuple,
                or its last axis has fewer than 3 channels.
        """
        if isinstance(array, np.ndarray):
            return self._lookup(array)
        if isinstance(array, tuple):
            return float(self._lookup(np.asarray(array)[np.newaxis])[0])
        raise ValueError(f"expected np.ndarray or tuple, got {type(array).__name__}")

    def _lookup(self, array):
        """Map every color along the last axis of ``array`` to its value."""
        if array.ndim == 0 or array.shape[-1] < 3:
            raise ValueError(f"expected colors along the last axis, got shape {array.shape}")
        rgb = array[..., :3]
        if np.issubdtype(rgb.dtype, np.integer):
            rgb = rgb / 255.0  # 8-bit channels -> [0, 1], the LUT's scale
        flat = rgb.reshape(-1, 3).astype(np.float64, copy=False)
        out_shape = array.shape[:-1]
        if flat.shape[0] == 0:
            return np.zeros(out_shape)
        # Resolve each distinct color once. Images hold far fewer distinct
        # colors than pixels, which keeps the distance table small.
        uniq, inverse = np.unique(flat, axis=0, return_inverse=True)
        uniq_values = np.zeros(len(uniq))
        for start in range(0, len(uniq), self._CHUNK):
            batch = uniq[start:start + self._CHUNK]
            dist = np.abs(batch[:, None, :] - self._lut_rgb[None, :, :]).max(axis=-1)
            nearest = dist.argmin(axis=1)
            matched = dist[np.arange(len(batch)), nearest] <= self.atol
            uniq_values[start:start + len(batch)] = np.where(matched, self._lut_values[nearest], 0.0)
        # reshape(-1) normalizes the inverse-index shape across NumPy versions.
        return uniq_values[inverse.reshape(-1)].reshape(out_shape)

# Round-trip sanity check: encode random values in [0, 1) with the viridis
# colormap, decode them back with DepthColorToNorm, and evaluate (as the
# cell's last expression) whether every value is recovered to within 0.01.
random_array = np.random.rand(4, 4)
viridis_mapped = viridis(random_array)[..., :3]  # Take only RGB, ignore alpha
# color_to_norm is also used by later cells to decode depth-image pixels.
color_to_norm = DepthColorToNorm()
recovered_array = color_to_norm(viridis_mapped)
np.all(np.isclose(random_array, recovered_array, atol=.01))

In [None]:
from utils_traj_tokens import decode_caption_xyzrotvec2
from tqdm.notebook import tqdm

# For each sample, compare the TCP z decoded from the caption with the depth
# read from the color-encoded depth image at the TCP's pixel. Both lists stay
# index-aligned with the dataset, so a list index is also a dataset index.
num_depth_samples = min(200, len(train_dataset))
all_tcp_zs = []
all_depths = []
for i in tqdm(range(num_depth_samples)):
    (depth, image), sample = train_dataset[i]
    curve_25d, quat_c = decode_caption_xyzrotvec2(sample["suffix"], sample["camera"])
    # curve_25d[0] appears to be (x_px, y_px, z) for the first TCP pose.
    # NOTE(review): confirm against decode_caption_xyzrotvec2.
    x, y = curve_25d[0][:2].round().numpy().astype(int)
    tcp_z_m = curve_25d[0][2].numpy()

    image_depth_color = tuple(depth[y, x])
    image_depth_norm = color_to_norm(image_depth_color)
    # Normalized value -> meters, using the decoder's depth_max rather than a
    # duplicated literal. NOTE(review): assumes depth_max is in millimetres.
    image_depth_m = image_depth_norm * color_to_norm.depth_max / 1000

    all_tcp_zs.append(tcp_z_m)
    all_depths.append(image_depth_m)

In [None]:
tcp_zs = np.asarray(all_tcp_zs)
depths = np.asarray(all_depths)
print(np.mean(tcp_zs), np.mean(depths))

fig, ax = plt.subplots(1)
ax.scatter(tcp_zs, depths)
ax.set_xlabel("TCP zs [m]")
ax.set_ylabel("image depth [m]")
# y = x reference line: points on it mean caption z and image depth agree.
# It spans the data rather than a fixed range, so no point lies beyond it.
line_lo = min(tcp_zs.min(), depths.min())
line_hi = max(tcp_zs.max(), depths.max())
ax.plot((line_lo, line_hi), (line_lo, line_hi), color='k')
plt.show()

# Show the sample whose image depth disagrees most with its caption z. The
# lists are index-aligned with the dataset, so bad_idx is a dataset index.
bad_idx = int(np.argmax(np.abs(tcp_zs - depths)))
(bad_depth, bad_image), bad_sample = train_dataset[bad_idx]  # load the sample once
plt.imshow(bad_image)
print(bad_sample["prefix"])
plt.show()


In [None]:
rows, cols = 2, 3
# Each pair of rows shows the depth images (upper row) and the RGB images
# (lower row) of the first `cols` samples.
fig, axes = plt.subplots(rows, cols, figsize=(12, 12 * 2 / 3))
for c in range(cols):
    (depth, image), sample = train_dataset[c]

    curve_25d, quat_c = decode_caption_xyzrotvec2(sample["suffix"], sample["camera"])

    # Decode the depth-image color at the TCP pixel and print it next to the
    # caption's TCP z for comparison.
    x, y = curve_25d[0][:2].round().numpy().astype(int)
    depth_val = curve_25d[0][2].numpy()
    depth_color = tuple(depth[y, x])
    print(depth_color)
    depth_norm_float = color_to_norm(depth_color)
    # Convert to meters so the value is comparable with depth_val.
    # NOTE(review): assumes depth_max is in millimetres.
    depth_float = depth_norm_float * color_to_norm.depth_max / 1000

    print(depth_float, depth_val)

    for r in range(rows // 2):
        # Row 2r holds the depth image and row 2r+1 the RGB image, so pairs
        # never overwrite each other when rows > 2.
        axes[2 * r][c].imshow(depth)
        axes[2 * r][c].set_title("depth")
        axes[2 * r + 1][c].imshow(image)
        axes[2 * r + 1][c].set_title(sample["prefix"].split("<")[0])
plt.show()


# ManiSkill H5 Datasets

In [None]:
import glob
from pathlib import Path
import h5py

def print_h5_structure(name, obj):
    """Print a one-line summary of a single HDF5 member.

    ``name`` is printed as given. A dataset prints as
    ``Dataset: <name> | Shape: <shape> | Type: <dtype>``. A group prints as
    ``Group: <name>``. Any other object prints nothing. The ``(name, obj)``
    signature fits the callback expected by h5py's ``visititems``.
    """
    if isinstance(obj, h5py.Dataset):
        line = f"Dataset: {name} | Shape: {obj.shape} | Type: {obj.dtype}"
    elif isinstance(obj, h5py.Group):
        line = f"Group: {name}"
    else:
        line = None
    if line is not None:
        print(line)

traj_path = Path("/data/lmbraid19/argusm/datasets/clevr-act-9-ms-small/20250205_182607.h5")
idx = 0
traj_key = f"traj_{idx}"

with h5py.File(traj_path, "r") as h5_file:
    print(h5_file[f"{traj_key}/obs/sensor_param/render_camera"].keys())
    traj_group = h5_file[traj_key]
    # Print the trajectory group itself, then walk every member below it.
    # Calling print_h5_structure on the file root prints only a single line;
    # visititems visits each member, passing paths relative to traj_group,
    # so prefix them with traj_key to get full paths.
    print_h5_structure(traj_key, traj_group)
    traj_group.visititems(lambda name, obj: print_h5_structure(f"{traj_key}/{name}", obj))

In [None]:
def h5_tree(val, pre=''):
    """Print an ASCII tree of an HDF5 group's members.

    Groups are printed by name and recursed into. Every other member is
    printed with its ``len()`` in parentheses, or ``(scalar)`` when ``len()``
    raises TypeError (e.g. scalar datasets).

    Args:
        val: h5py group (or file) whose members are printed.
        pre: Prefix put before every line; grows as the recursion descends.
    """
    last_index = len(val) - 1
    for index, (key, member) in enumerate(val.items()):
        is_last = index == last_index
        connector = '└── ' if is_last else '├── '
        # isinstance (not type ==) also accepts Group subclasses and avoids
        # referencing h5py's private _hl module.
        if isinstance(member, h5py.Group):
            print(pre + connector + key)
            # Continue the vertical rule only if more siblings follow.
            h5_tree(member, pre + ('    ' if is_last else '│   '))
        else:
            try:
                size = f' ({len(member)})'
            except TypeError:
                size = ' (scalar)'
            print(pre + connector + key + size)

# Print the member tree of the HDF5 file held by train_dataset (its h5_file
# attribute; presumably an open h5py.File — confirm in H5Dataset).
h5_tree(train_dataset.h5_file)