In [1]:
import pickle
import numpy as np
import torch

In [None]:
# Load both pickle files.
# Each file unpickles to a 4-tuple: (obs, inputs, images, n_state).
# NOTE: pickle.load can execute arbitrary code while deserialising, so only point
# this cell at pickle files that come from a trusted source.
# `with` guarantees the file handles are closed even if unpickling fails.
with open("../obs_low.pkl", "rb") as pkl_file:
    obs_low, inputs_low, images_low, n_state_low = pickle.load(pkl_file)
with open("../obs_high.pkl", "rb") as pkl_file:
    obs_high, inputs_high, images_high, n_state_high = pickle.load(pkl_file)

print("Loaded both pickle files")
print(f"obs_low keys: {list(obs_low.keys())}")
print(f"obs_high keys: {list(obs_high.keys())}")
print(f"inputs_low keys: {list(inputs_low.keys())}")
print(f"inputs_high keys: {list(inputs_high.keys())}")
print(f"n_state_low shape: {n_state_low.shape}")
print(f"n_state_high shape: {n_state_high.shape}")


Loaded both pickle files
obs_low keys: ['timestamp', 'left', 'right', 'top_camera', 'right_camera', 'left_camera', 'gt_action_chunks', 'timestamp_end', 'cur_step']
obs_high keys: ['timestamp', 'left', 'right', 'top_camera', 'right_camera', 'left_camera', 'gt_action_chunks', 'timestamp_end', 'cur_step']
inputs_low keys: ['input_ids', 'attention_mask', 'images', 'pooled_patches_idx', 'image_masks']
inputs_high keys: ['input_ids', 'attention_mask', 'images', 'pooled_patches_idx', 'image_masks']
n_state_low shape: (14,)
n_state_high shape: (14,)


In [4]:
# Compare n_state: report shape/dtype of both arrays and, when the shapes agree,
# the element-wise absolute difference plus summary statistics of each array.
banner = "=" * 80
print(banner)
print("COMPARING n_state")
print(banner)

n_state_pairs = (("n_state_low", n_state_low), ("n_state_high", n_state_high))
for label, state in n_state_pairs:
    print(f"{label} shape: {state.shape}, dtype: {state.dtype}")

shapes_agree = n_state_low.shape == n_state_high.shape
print(f"Shapes match: {shapes_agree}")

if not shapes_agree:
    print("Shapes don't match, cannot compare directly")
else:
    abs_delta = np.abs(n_state_low - n_state_high)
    print(f"Max difference: {abs_delta.max()}")
    print(f"Mean difference: {abs_delta.mean()}")
    print(f"Min difference: {abs_delta.min()}")
    print(f"Number of non-zero differences: {np.count_nonzero(abs_delta)}")
    for label, state in n_state_pairs:
        print(f"{label} stats: min={state.min():.6f}, max={state.max():.6f}, mean={state.mean():.6f}")


COMPARING n_state
n_state_low shape: (14,), dtype: float32
n_state_high shape: (14,), dtype: float32
Shapes match: True
Max difference: 0.0
Mean difference: 0.0
Min difference: 0.0
Number of non-zero differences: 0
n_state_low stats: min=-0.759144, max=0.998686, mean=-0.092655
n_state_high stats: min=-0.759144, max=0.998686, mean=-0.092655


In [6]:
# Compare inputs dictionary: for every key of inputs_low, print a type-specific
# comparison against the value stored under the same key in inputs_high.
def _tensor_stats(tensor):
    """Return a 'min=..., max=..., mean=...' summary string for a torch tensor.

    Floating-point tensors are formatted with six decimals. Every other dtype
    (signed/unsigned integers, bool) prints min/max as-is and takes the mean on
    a float copy, because torch's mean() rejects non-floating-point tensors.
    """
    if tensor.is_floating_point():
        return (f"min={tensor.min().item():.6f}, max={tensor.max().item():.6f}, "
                f"mean={tensor.mean().item():.6f}")
    return (f"min={tensor.min().item()}, max={tensor.max().item()}, "
            f"mean={tensor.float().mean().item():.6f}")


print("=" * 80)
print("COMPARING inputs")
print("=" * 80)
print(f"inputs_low keys: {list(inputs_low.keys())}")
print(f"inputs_high keys: {list(inputs_high.keys())}")
print(f"Keys match: {set(inputs_low.keys()) == set(inputs_high.keys())}")

for key in inputs_low.keys():
    print(f"\n--- Comparing '{key}' ---")
    # A key present only in inputs_low is reported instead of raising KeyError.
    if key not in inputs_high:
        print(f"  Key missing from inputs_high!")
        continue
    val_low = inputs_low[key]
    val_high = inputs_high[key]

    if isinstance(val_low, torch.Tensor) and isinstance(val_high, torch.Tensor):
        print(f"  Type: torch.Tensor")
        print(f"  Shape low: {val_low.shape}, dtype: {val_low.dtype}")
        print(f"  Shape high: {val_high.shape}, dtype: {val_high.dtype}")
        if val_low.shape != val_high.shape:
            print(f"  Shapes don't match!")
        elif val_low.numel() == 0:
            # max()/min() are undefined on zero-element tensors.
            print(f"  Empty tensors, nothing to compare")
        else:
            # Subtract on float copies so integer tensors cannot wrap around.
            diff = torch.abs(val_low.float() - val_high.float())
            print(f"  Max difference: {diff.max().item():.6f}")
            print(f"  Mean difference: {diff.mean().item():.6f}")
            print(f"  Number of non-zero differences: {torch.count_nonzero(diff).item()}")
            # The dtype decides the formatting independently for each side.
            print(f"  Low stats: {_tensor_stats(val_low)}")
            print(f"  High stats: {_tensor_stats(val_high)}")
    elif isinstance(val_low, list) and isinstance(val_high, list):
        print(f"  Type: list, length low: {len(val_low)}, length high: {len(val_high)}")
        if len(val_low) == len(val_high):
            for i, (v_low, v_high) in enumerate(zip(val_low, val_high)):
                if isinstance(v_low, torch.Tensor) and isinstance(v_high, torch.Tensor):
                    if v_low.shape == v_high.shape:
                        diff = torch.abs(v_low.float() - v_high.float())
                        print(f"    Item {i}: max_diff={diff.max().item():.6f}, mean_diff={diff.mean().item():.6f}")
                    else:
                        print(f"    Item {i}: shapes don't match ({v_low.shape} vs {v_high.shape})")
                elif isinstance(v_low, np.ndarray) and isinstance(v_high, np.ndarray):
                    if v_low.shape == v_high.shape:
                        # Widen to float64 before subtracting: unsigned arrays would
                        # wrap around and bool arrays do not support subtraction.
                        diff = np.abs(v_low.astype(np.float64) - v_high.astype(np.float64))
                        print(f"    Item {i}: max_diff={diff.max():.6f}, mean_diff={diff.mean():.6f}")
                    else:
                        print(f"    Item {i}: shapes don't match ({v_low.shape} vs {v_high.shape})")
                else:
                    print(f"    Item {i}: types don't match or not comparable")
        else:
            print(f"  List lengths don't match!")
    elif isinstance(val_low, (torch.Tensor, list)) or isinstance(val_high, (torch.Tensor, list)):
        # One side is a tensor/list and the other is not: report the types only,
        # rather than dumping a whole tensor or attempting an element-wise ==.
        print(f"  Types don't match: low={type(val_low)}, high={type(val_high)}")
    else:
        print(f"  Type: {type(val_low)}, value_low: {val_low}, value_high: {val_high}")
        print(f"  Values equal: {val_low == val_high}")


COMPARING inputs
inputs_low keys: ['input_ids', 'attention_mask', 'images', 'pooled_patches_idx', 'image_masks']
inputs_high keys: ['input_ids', 'attention_mask', 'images', 'pooled_patches_idx', 'image_masks']
Keys match: True

--- Comparing 'input_ids' ---
  Type: torch.Tensor
  Shape low: torch.Size([1, 665]), dtype: torch.int64
  Shape high: torch.Size([1, 665]), dtype: torch.int64
  Max difference: 0.000000
  Mean difference: 0.000000
  Number of non-zero differences: 0
  Low stats: min=13, max=152067, mean=145751.656250
  High stats: min=13, max=152067, mean=145751.656250

--- Comparing 'attention_mask' ---
  Type: torch.Tensor
  Shape low: torch.Size([1, 665]), dtype: torch.int64
  Shape high: torch.Size([1, 665]), dtype: torch.int64
  Max difference: 0.000000
  Mean difference: 0.000000
  Number of non-zero differences: 0
  Low stats: min=1, max=1, mean=1.000000
  High stats: min=1, max=1, mean=1.000000

--- Comparing 'images' ---
  Type: torch.Tensor
  Shape low: torch.Size([1,

In [7]:
# Compare obs dictionary: print a banner, the top-level keys of both observation
# dicts, and whether the two key sets are identical.
rule = "=" * 80
print(rule)
print("COMPARING obs")
print(rule)
obs_low_keys = list(obs_low.keys())
obs_high_keys = list(obs_high.keys())
print(f"obs_low keys: {obs_low_keys}")
print(f"obs_high keys: {obs_high_keys}")
keys_identical = set(obs_low_keys) == set(obs_high_keys)
print(f"Keys match: {keys_identical}")

def compare_dict_recursive(dict_low, dict_high, prefix=""):
    """Recursively compare two dictionaries and print a report for every key.

    Keys are visited in sorted order over the union of both key sets. Nested
    dicts are descended into, numpy arrays are compared element-wise, plain
    scalars/strings are compared with ``!=``, and any other type is printed
    verbatim.

    Args:
        dict_low: First dictionary (labelled "low" in the output).
        dict_high: Second dictionary (labelled "high" in the output).
        prefix: Dotted path of the enclosing keys, used to label nested
            entries; empty at the top level.

    Returns:
        None. All results are written to stdout.
    """
    all_keys = set(dict_low.keys()) | set(dict_high.keys())
    for key in sorted(all_keys):
        key_path = f"{prefix}.{key}" if prefix else key
        if key not in dict_low:
            print(f"{key_path}: Only in high")
            continue
        if key not in dict_high:
            print(f"{key_path}: Only in low")
            continue

        val_low = dict_low[key]
        val_high = dict_high[key]

        if isinstance(val_low, dict):
            # Recursing requires a dict on both sides; otherwise report and move on.
            if not isinstance(val_high, dict):
                print(f"{key_path}: type mismatch (low={type(val_low)}, high={type(val_high)})")
                continue
            print(f"{key_path}: dict (nested)")
            compare_dict_recursive(val_low, val_high, key_path)
        elif isinstance(val_low, np.ndarray):
            if not isinstance(val_high, np.ndarray):
                print(f"{key_path}: type mismatch (low={type(val_low)}, high={type(val_high)})")
                continue
            print(f"{key_path}:")
            print(f"  Shape low: {val_low.shape}, dtype: {val_low.dtype}")
            print(f"  Shape high: {val_high.shape}, dtype: {val_high.dtype}")
            if val_low.shape != val_high.shape or val_low.dtype != val_high.dtype:
                print(f"  Shapes or dtypes don't match!")
                continue
            if val_low.size == 0:
                # max()/mean() are undefined on zero-size arrays.
                print(f"  Empty arrays, nothing to compare")
                continue
            kind = val_low.dtype.kind
            if kind in "biu":
                # Unsigned subtraction wraps around (uint8: 5 - 10 == 251) and bool
                # arrays reject '-', so integer/bool data is widened to float64 first.
                arr_low = val_low.astype(np.float64)
                arr_high = val_high.astype(np.float64)
            elif kind in "fc":
                arr_low, arr_high = val_low, val_high
            else:
                # Non-numeric dtypes (object, string, datetime, ...) have no meaningful
                # numeric difference; fall back to an equality check.
                print(f"  Arrays equal: {np.array_equal(val_low, val_high)}")
                continue
            diff = np.abs(arr_low - arr_high)
            n_nonzero = np.count_nonzero(diff)
            print(f"  Max difference: {diff.max():.6f}")
            print(f"  Mean difference: {diff.mean():.6f}")
            print(f"  Number of non-zero differences: {n_nonzero}")
            # Per-array statistics are only shown when the arrays actually differ.
            if n_nonzero > 0:
                print(f"  Low stats: min={arr_low.min():.6f}, max={arr_low.max():.6f}, mean={arr_low.mean():.6f}")
                print(f"  High stats: min={arr_high.min():.6f}, max={arr_high.max():.6f}, mean={arr_high.mean():.6f}")
        elif isinstance(val_low, (int, float, str, bool)):
            if val_low != val_high:
                print(f"{key_path}: {val_low} != {val_high}")
            else:
                print(f"{key_path}: {val_low} (equal)")
        else:
            print(f"{key_path}: type {type(val_low)}, low={val_low}, high={val_high}")

# Walk both observation dicts from the top level (empty prefix); the report is printed to stdout.
compare_dict_recursive(obs_low, obs_high)


COMPARING obs
obs_low keys: ['timestamp', 'left', 'right', 'top_camera', 'right_camera', 'left_camera', 'gt_action_chunks', 'timestamp_end', 'cur_step']
obs_high keys: ['timestamp', 'left', 'right', 'top_camera', 'right_camera', 'left_camera', 'gt_action_chunks', 'timestamp_end', 'cur_step']
Keys match: True
cur_step: 0 (equal)
gt_action_chunks:
  Shape low: (16, 14), dtype: float64
  Shape high: (16, 14), dtype: float64
  Max difference: 0.000000
  Mean difference: 0.000000
  Number of non-zero differences: 0
left: dict (nested)
left.gripper_pos:
  Shape low: (1,), dtype: float64
  Shape high: (1,), dtype: float64
  Max difference: 0.000000
  Mean difference: 0.000000
  Number of non-zero differences: 0
left.joint_pos:
  Shape low: (6,), dtype: float64
  Shape high: (6,), dtype: float64
  Max difference: 0.000000
  Mean difference: 0.000000
  Number of non-zero differences: 0
left_camera: dict (nested)
left_camera.images: dict (nested)
left_camera.images.rgb:
  Shape low: (168, 224, 3

In [8]:
# Compare images in detail: first the tensors stored under inputs['images'],
# then the separately pickled images_low / images_high lists.
print("=" * 80)
print("COMPARING IMAGES")
print("=" * 80)

# Compare images from inputs dictionary
print("\n--- Images from inputs['images'] ---")
img_low = inputs_low['images']
img_high = inputs_high['images']
print(f"Shape low: {img_low.shape}, dtype: {img_low.dtype}")
print(f"Shape high: {img_high.shape}, dtype: {img_high.dtype}")

if img_low.shape == img_high.shape:
    # Work on float copies so integer image tensors cannot wrap around on
    # subtraction and so mean() is defined; this is a no-op for float32 input.
    img_low_f = img_low.float()
    img_high_f = img_high.float()
    diff = torch.abs(img_low_f - img_high_f)
    n_nonzero = torch.count_nonzero(diff).item()
    print(f"Max difference: {diff.max().item():.6f}")
    print(f"Mean difference: {diff.mean().item():.6f}")
    print(f"Number of non-zero differences: {n_nonzero}")
    print(f"Percentage of pixels with differences: {100 * n_nonzero / diff.numel():.2f}%")

    # Locate the axis of size 3 that is treated as the channel axis:
    # dim 1 for 4-D (N, C, H, W) input, dim 0 for 3-D (C, H, W) input.
    # NOTE(review): the tensor seen here is [1, 3, 729, 588]; that may be a
    # crops x patches x patch-pixels layout rather than C x H x W — confirm
    # before reading the numbers below as per-colour statistics.
    if diff.dim() == 4 and diff.shape[1] == 3:
        channel_dim = 1
    elif diff.dim() == 3 and diff.shape[0] == 3:
        channel_dim = 0
    else:
        channel_dim = None

    # Per-channel statistics
    if channel_dim is not None:
        print("\nPer-channel differences:")
        for c in range(3):
            # select() slices along the detected channel axis, so the check above
            # and the indexing here can never disagree about which axis is meant.
            channel_diff = diff.select(channel_dim, c)
            print(f"  Channel {c}: max={channel_diff.max().item():.6f}, mean={channel_diff.mean().item():.6f}")

    # Show where differences are largest
    if channel_dim is not None:
        max_diff_per_pixel = diff.max(dim=channel_dim)[0]
    elif diff.dim() == 4:
        max_diff_per_pixel = diff.max(dim=1)[0]
    else:
        max_diff_per_pixel = diff
    print(f"\nMax difference per pixel (across channels): max={max_diff_per_pixel.max().item():.6f}, mean={max_diff_per_pixel.mean().item():.6f}")

    print(f"\nLow image stats: min={img_low_f.min().item():.6f}, max={img_low_f.max().item():.6f}, mean={img_low_f.mean().item():.6f}")
    print(f"High image stats: min={img_high_f.min().item():.6f}, max={img_high_f.max().item():.6f}, mean={img_high_f.mean().item():.6f}")
else:
    print("Shapes don't match!")

# Compare separate images if they exist (they are only defined once the load cell has run)
if 'images_low' in locals() and 'images_high' in locals():
    print("\n--- Separate images (images_low/images_high) ---")
    if isinstance(images_low, list) and isinstance(images_high, list):
        print(f"Number of images: low={len(images_low)}, high={len(images_high)}")
        if len(images_low) != len(images_high):
            # zip() stops at the shorter list; say so instead of truncating silently.
            print(f"List lengths differ; comparing only the first {min(len(images_low), len(images_high))} image(s)")
        for i, (img_l, img_h) in enumerate(zip(images_low, images_high)):
            print(f"\n  Image {i}:")
            if isinstance(img_l, np.ndarray) and isinstance(img_h, np.ndarray):
                print(f"  Shape low: {img_l.shape}, dtype: {img_l.dtype}")
                print(f"  Shape high: {img_h.shape}, dtype: {img_h.dtype}")
                if img_l.shape == img_h.shape and img_l.dtype == img_h.dtype:
                    # float32 copies avoid unsigned wrap-around when subtracting.
                    diff = np.abs(img_l.astype(np.float32) - img_h.astype(np.float32))
                    n_nonzero = np.count_nonzero(diff)
                    print(f"  Max difference: {diff.max():.6f}")
                    print(f"  Mean difference: {diff.mean():.6f}")
                    print(f"  Number of non-zero differences: {n_nonzero}")
                    print(f"  Percentage of pixels with differences: {100 * n_nonzero / diff.size:.2f}%")
                    print(f"  Low stats: min={img_l.min()}, max={img_l.max()}, mean={img_l.mean():.6f}")
                    print(f"  High stats: min={img_h.min()}, max={img_h.max()}, mean={img_h.mean():.6f}")
                else:
                    print(f"  Shapes or dtypes don't match!")
            else:
                print(f"  Types: low={type(img_l)}, high={type(img_h)}")
    else:
        print(f"Types: low={type(images_low)}, high={type(images_high)}")


COMPARING IMAGES

--- Images from inputs['images'] ---
Shape low: torch.Size([1, 3, 729, 588]), dtype: torch.float32
Shape high: torch.Size([1, 3, 729, 588]), dtype: torch.float32
Max difference: 0.376471
Mean difference: 0.009732
Number of non-zero differences: 805739
Percentage of pixels with differences: 62.66%

Per-channel differences:
  Channel 0: max=0.376471, mean=0.012190
  Channel 1: max=0.243137, mean=0.007640
  Channel 2: max=0.329412, mean=0.009366

Max difference per pixel (across channels): max=0.376471, mean=0.019236

Low image stats: min=-0.992157, max=0.788235, mean=-0.047408
High image stats: min=-0.976471, max=0.882353, mean=-0.047634


In [None]:
# Compare camera images from obs dictionary.
# For each camera present in both observation dicts, the 'rgb' frame stored under
# 'images' is compared element-wise; when only the shapes differ, the low frame
# is resized to the high frame's size with PIL and compared after that.
rule = "=" * 80
print(rule)
print("COMPARING CAMERA IMAGES FROM OBS")
print(rule)

camera_names = ['top_camera', 'left_camera', 'right_camera']
for camera_name in camera_names:
    # Cameras missing from either side are skipped without any output.
    if camera_name not in obs_low or camera_name not in obs_high:
        continue
    print(f"\n--- {camera_name} ---")

    cam_low = obs_low[camera_name]
    cam_high = obs_high[camera_name]
    if 'images' not in cam_low or 'images' not in cam_high:
        continue
    if 'rgb' not in cam_low['images'] or 'rgb' not in cam_high['images']:
        continue

    img_low = cam_low['images']['rgb']
    img_high = cam_high['images']['rgb']
    print(f"  Shape low: {img_low.shape}, dtype: {img_low.dtype}")
    print(f"  Shape high: {img_high.shape}, dtype: {img_high.dtype}")

    directly_comparable = img_low.shape == img_high.shape and img_low.dtype == img_high.dtype
    if not directly_comparable:
        print(f"  Shapes or dtypes don't match!")
        print(f"  Note: Low image is {img_low.shape}, High image is {img_high.shape}")
        if img_low.dtype != img_high.dtype:
            continue
        # Try to resize and compare if shapes are different.
        # PIL's resize takes (width, height), hence shape[1] before shape[0].
        from PIL import Image
        target_size = (img_high.shape[1], img_high.shape[0])
        img_low_resized = np.array(Image.fromarray(img_low).resize(target_size))
        diff = np.abs(img_low_resized.astype(np.float32) - img_high.astype(np.float32))
        changed = np.count_nonzero(diff)
        print(f"  After resizing low to high shape:")
        print(f"    Max difference: {diff.max():.6f}")
        print(f"    Mean difference: {diff.mean():.6f}")
        print(f"    Number of non-zero differences: {changed}")
        print(f"    Percentage of pixels with differences: {100 * changed / diff.size:.2f}%")
        continue

    # float32 copies keep the subtraction from wrapping around in the source dtype.
    diff = np.abs(img_low.astype(np.float32) - img_high.astype(np.float32))
    changed = np.count_nonzero(diff)
    print(f"  Max difference: {diff.max():.6f}")
    print(f"  Mean difference: {diff.mean():.6f}")
    print(f"  Number of non-zero differences: {changed}")
    print(f"  Percentage of pixels with differences: {100 * changed / diff.size:.2f}%")

    # Per-channel statistics, only for 3-D arrays whose last axis has size 3 (HWC format).
    if img_low.ndim == 3 and img_low.shape[2] == 3:
        print(f"  Per-channel differences:")
        for c in range(3):
            channel_diff = diff[:, :, c]
            print(f"    Channel {c}: max={channel_diff.max():.6f}, mean={channel_diff.mean():.6f}")

    print(f"  Low stats: min={img_low.min()}, max={img_low.max()}, mean={img_low.mean():.6f}")
    print(f"  High stats: min={img_high.min()}, max={img_high.max()}, mean={img_high.mean():.6f}")
