In [None]:
%load_ext autoreload
%autoreload 2
import os
import json
from pathlib import Path
from cvla.utils_traj_tokens import getActionEncInstance

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

dataset_location = "/tmp/clevr-act-7-depth"
#dataset_location = "/data/lmbraid19/argusm/datasets/cvla-droid-block-simple-v3"
#dataset_location = "/data/lmbraid19/argusm/datasets/cvla-droid-block-v3"
#dataset_location = "/data/lmbraid19/argusm/datasets/cvla-droid-1of5c-v1"
dataset_location = Path(dataset_location)

model_location = Path("/data/lmbraid19/argusm/models/")
#model_path = model_location / "clevr-act-7-depth_text_aug" / "checkpoint-4687"
#model_path = model_location / "clevr-act-7-depth_e512s" / "checkpoint-4687"
model_path = model_location / "mix30obj_text_debug" / "checkpoint-4687"
#model_path = model_location / "mix30obj_depth" / "checkpoint-4687"

#model_path = model_location / "clevr-act-7-depth_e512s_depth" / "checkpoint-4687"
#model_path = model_location / "clevr-act-7-depth_depthaug" / "checkpoint-4687"

# Model metadata: prefer the `cvla_info.json` stored next to the checkpoint
# directory; fall back to heuristics when the file (or a key in it) is missing.
DEFAULT_ACTION_ENCODER = "xyzrotvec-cam-1024xy"
info_file = model_path.parent / "cvla_info.json"
try:
    with open(info_file, "r") as f:
        model_info = json.load(f)
except FileNotFoundError:
    model_info = None

# Heuristic defaults: depth models are recognised by "_depth" in their path.
action_encoder = DEFAULT_ACTION_ENCODER
return_depth = "_depth" in str(model_path)
if model_info is not None:
    # .get() so an info file that lacks a key falls back instead of raising KeyError.
    action_encoder = model_info.get("action_encoder", action_encoder)
    return_depth = model_info.get("return_depth", return_depth)

enc_model = getActionEncInstance(action_encoder)
dataset_name = dataset_location.name
model_name = model_path.parent.name

print()
print("dataset:".ljust(10), dataset_name, dataset_location)
if model_path.is_dir():
    print("model:".ljust(10), model_name, "\t", model_path)
    print("encoder:".ljust(10), action_encoder)
    print("depth:".ljust(10), return_depth)
else:
    # Make a wrong checkpoint path visible here rather than only at load time.
    print("model:".ljust(10), "checkpoint directory not found:", model_path)

In [None]:
#!nvidia-smi

In [None]:
from cvla.data_loader_jsonl import JSONLDataset
from cvla.data_loader_h5 import H5Dataset
from cvla.data_augmentations import CropMiddle


# Pick the loader from the dataset path: anything whose path contains "real"
# is read from JSONL annotations, everything else from the H5 format.
# NOTE(review): the substring test looks at the full path, so a parent
# directory containing "real" also selects the JSONL branch — confirm intended.
if "real" in str(dataset_location):
    crop_augmentation = CropMiddle(crop_size=600, object_size=100, valid=True)
    test_dataset = JSONLDataset(
        jsonl_file_path=f"{dataset_location}/_annotations.valid.jsonl",
        image_directory_path=f"{dataset_location}/dataset",
        clean_prompt=True,
        return_depth=return_depth,
        limit_samples=200,
        augment_crop=crop_augmentation,
    )
    test_dataset.action_encoder = getActionEncInstance("xyzrotvec-cam-1024xy")
else:
    test_dataset = H5Dataset(dataset_location, return_depth=return_depth, action_encoder="xyzrotvec-cam-512xy")

print("dataset len", len(test_dataset))
# Labels are decoded with the dataset's encoder (`enc_data`) and predictions
# with the model's encoder (`enc_model`, set in the config cell); the two may
# use different schemes, so both are kept.
enc_data = test_dataset.action_encoder

In [None]:
from IPython.display import display, HTML
from tqdm.notebook import tqdm
from cvla.utils_vis import render_example

def get_image(images):
    """Return the RGB frame of a sample's image field.

    A list/tuple of images yields its last element; any other value is
    returned unchanged.
    """
    if not isinstance(images, (list, tuple)):
        return images
    return images[-1]
    
def get_depth(images):
    """Return the depth frame of a sample's image field, if there is one.

    A list/tuple of images yields its first element; a single (non-sequence)
    image means no depth is available, so None is returned.
    """
    return images[0] if isinstance(images, (list, tuple)) else None
    
print(len(test_dataset))
# Render a handful of ground-truth samples (depth frame when the model uses depth).
num_samples = min(6, len(test_dataset))
rendered_examples = []
for i in tqdm(range(num_samples)):
    images, sample = test_dataset[i]
    if return_depth:
        image = get_depth(images)
    else:
        image = get_image(images)
    rendered_examples.append(
        render_example(
            image,
            label=sample["suffix"],
            text=sample["prefix"],
            camera=sample["camera"],
            enc=enc_data,
            enc_pred=enc_model,
        )
    )
html_imgs = "".join(rendered_examples)

display(HTML(html_imgs))

In [None]:
# Sanity check: decode the robot state contained in a sample's prompt.
# Uses the last sample shown in the preview cell via an explicit index
# (previously this relied on the loop variable `i` leaking from that cell).
preview_index = max(num_samples - 1, 0)
preview_images, preview_label = test_dataset[preview_index]
prefix = preview_label["prefix"]
# Assumes the robot state is the last space-separated token of the prompt — TODO confirm.
robot_state = prefix.split(" ")[-1]
enc_data.decode_trajectory(robot_state, preview_label["camera"])


In [None]:
import torch
from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration

TORCH_DTYPE = torch.bfloat16
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print('Using device:', DEVICE)

# Processor (tokenizer + image preprocessing) comes from the base release;
# the weights come from the fine-tuned checkpoint at `model_path`.
MODEL_ID = "google/paligemma2-3b-pt-224"
processor = PaliGemmaProcessor.from_pretrained(MODEL_ID)
print("loaded processor.")
model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_path,
    torch_dtype=TORCH_DTYPE,
    device_map="auto",
)

In [None]:
def collate_fn(batch):
    """Convert a list of ``(images, label)`` dataset samples into model inputs.

    Each prompt gets one ``<image>`` token per image. Without depth, every
    sample carries a single image. With ``return_depth`` every sample's
    ``images`` is a sequence that is flattened in sample order; the prompt
    prefix assumes exactly two images per sample (TODO confirm for all datasets).

    Returns the processor batch, cast with ``.to(TORCH_DTYPE)`` and moved to
    ``DEVICE``.
    """
    images, labels = zip(*batch)
    if return_depth:
        image_tokens = "<image><image>"
        processor_images = [img for sample_images in images for img in sample_images]
    else:
        image_tokens = "<image>"
        processor_images = images
    prefixes = [image_tokens + label["prefix"] for label in labels]
    inputs = processor(
        text=prefixes,
        images=processor_images,
        return_tensors="pt",
        padding="longest",
    ).to(TORCH_DTYPE).to(DEVICE)
    return inputs

In [None]:
from math import ceil
from tqdm.notebook import tqdm
from cvla.utils_vis import render_example

eval_batch_size = 8

#test_samples = eval_batch_size*3
test_samples = min(len(test_dataset), 10)
predictions = {}
batch_starts = range(0, test_samples, eval_batch_size)
for start_idx in tqdm(batch_starts, total=ceil(test_samples / eval_batch_size)):
    batch_end = min(start_idx + eval_batch_size, test_samples)
    batch_i = range(start_idx, batch_end)
    batch = [test_dataset[idx] for idx in batch_i]
    inputs = collate_fn(batch)
    # generate() returns prompt + new tokens; remember the prompt width to strip it.
    prefix_length = inputs["input_ids"].shape[-1]

    with torch.inference_mode():
        generation = model.generate(**inputs, max_new_tokens=12, do_sample=False, use_cache=False)
        new_token_ids = generation[:, prefix_length:]
        decoded = [processor.decode(token_ids, skip_special_tokens=True) for token_ids in new_token_ids]

    predictions.update(zip(batch_i, decoded))

In [None]:
from IPython.display import display, HTML

plot_images = 160

# Render label vs. prediction for (at most) the first `plot_images` predictions.
rendered_predictions = []
for dataset_index, decoded_str in list(predictions.items())[:plot_images]:
    batch_entry = test_dataset[dataset_index]
    if return_depth:
        (depth, image), sample = batch_entry
    else:
        image, sample = batch_entry
    rendered_predictions.append(
        render_example(
            image,
            text=sample["prefix"],
            label=sample["suffix"],
            prediction=decoded_str,
            camera=sample["camera"],
            enc=enc_data,
            enc_pred=enc_model,
        )
    )
html_imgs = "".join(rendered_predictions)

display(HTML(html_imgs))

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.spatial.transform import Rotation as R

dect_data = test_dataset.action_encoder.decode_trajectory
dect_model = enc_model.decode_trajectory
decc_data = test_dataset.action_encoder.decode_caption
decc_model = enc_model.decode_caption

# mode -> (decoder for ground-truth labels, decoder for model predictions):
# "cam" uses caption decoding, "cart" uses trajectory decoding.
decoders_by_mode = {
    "cam": (decc_data, decc_model),
    "cart": (dect_data, dect_model),
}

# Accumulators indexed as all_data[mode][split][quantity], split in {"pred", "data"}.
all_data = {
    mode_key: {split_key: {"orn": [], "pos": []} for split_key in ("pred", "data")}
    for mode_key in ("cam", "cart")
}

for i, pred in predictions.items():
    entry = test_dataset[i][1]
    camera = entry["camera"]
    suffix = entry["suffix"]
    if i == 0:
        print("image size:", camera.width, camera.height)

    for j, mode in enumerate(decoders_by_mode):
        dec_func_data, dec_func_model = decoders_by_mode[mode]
        try:
            pos_data, orn_data = dec_func_data(suffix, camera=camera)
            pos_pred, orn_pred = dec_func_model(pred, camera=camera)
        except ValueError:
            # Decoder rejected the text; report once per sample (first mode) and skip.
            if j == 0:
                print("skipping", i, pred, len(pred))
            continue

        if mode == "cart":
            # Trajectory decoding yields a sequence; only index 0 is evaluated.
            pos_data, orn_data = pos_data[0], orn_data[0]
            pos_pred, orn_pred = pos_pred[0], orn_pred[0]

        mode_store = all_data[mode]
        mode_store["data"]["pos"].append(pos_data.numpy())
        mode_store["pred"]["pos"].append(pos_pred.numpy())
        # Quaternions are interpreted scalar-first (w, x, y, z).
        mode_store["data"]["orn"].append(R.from_quat(orn_data.numpy(), scalar_first=True))
        mode_store["pred"]["orn"].append(R.from_quat(orn_pred.numpy(), scalar_first=True))

# Stack positions into arrays; orientations stay as lists of Rotation objects.
for mode_store in all_data.values():
    for split_store in mode_store.values():
        split_store["pos"] = np.array(split_store["pos"])

keypoint = ["object", "container"]
action_labels = ["x", "y", "d", "orn"] * 2
units = dict(cam=["px", "px", "cm", "deg"] * 2, cart=["cm", "cm", "cm", "deg"] * 2)

# valid_diffs[mode] is an array (n_valid, n_keypoints, 4): label minus
# prediction for the three position columns, then the orientation error in deg.
valid_diffs = {}
for mode in ("cam", "cart"):
    data_pos = all_data[mode]["data"]["pos"]
    pred_pos = all_data[mode]["pred"]["pos"]
    if len(data_pos) == 0:
        # Nothing decoded for this mode: keep an empty array of the expected
        # shape so the plotting/printing cells do not crash on 3-D indexing.
        print(f"no valid samples for mode {mode}")
        valid_diffs[mode] = np.empty((0, len(keypoint), len(action_labels[:4])))
        continue

    valid_diff = data_pos - pred_pos
    if mode == "cart":
        valid_diff = valid_diff * 100  # all three coordinates: m to cm
    elif mode == "cam":
        valid_diff[:, :, 2] = valid_diff[:, :, 2] * 100  # depth only: m to cm; x/y stay in px

    # Angular error = magnitude of the relative rotation data^-1 * pred (rad -> deg).
    orn_err_deg = np.array([
        (r_data.inv() * r_pred).magnitude()
        for r_data, r_pred in zip(all_data[mode]["data"]["orn"], all_data[mode]["pred"]["orn"])
    ]) * 180 / np.pi
    valid_diffs[mode] = np.concatenate((valid_diff, orn_err_deg[:, :, np.newaxis]), axis=-1)

plot_hist = True
if plot_hist:
    for mode, valid_diff in valid_diffs.items():
        # One row of four panels per mode (x, y, depth, orientation error),
        # with the object and container keypoints overlaid in each panel.
        fig, axes = plt.subplots(1, 4, figsize=(12, 3))
        fig.suptitle(f"{mode} errors")
        for j, ax in enumerate(axes):
            for i in range(len(keypoint)):
                ax.hist(valid_diff[:, i, j], bins=20, alpha=0.7, edgecolor='black', label=keypoint[i])
            ax.set_title(f'Hist. {action_labels[j]} err.')
            ax.set_xlabel(f"{action_labels[j]} err. [{units[mode][j]}]")
            ax.set_ylabel('Frequency')
            ax.legend()
        # Lay out and show each figure individually (previously only the last one got tight_layout).
        fig.tight_layout()
        plt.show()

depth_str = "depth" if return_depth else "rgb   "
print(f"model {model_name} ({depth_str})\ndataset {dataset_name}")

# The valid rate is measured on camera-frame decoding; guard against an empty run.
num_valid = len(valid_diffs["cam"])
num_predictions = len(predictions)
valid_rate = num_valid / num_predictions if num_predictions else float("nan")
# Computed outside the f-string: reusing the same quote type inside an
# f-string replacement field is a SyntaxError before Python 3.12 (PEP 701).
print("valid_samples", num_valid, "samples", num_predictions, "valid_rate", f"{valid_rate:0.2f}")

for mode, valid_diff in valid_diffs.items():
    mode_units = units[mode]
    for i, action_label in enumerate(action_labels[:4]):
        column_err = valid_diff[:, :, i]
        # "l2" is the norm over all samples and keypoints (sqrt of the sum of
        # squares), so it grows with sample count; "l1" is the mean absolute error.
        print(f"{action_label.ljust(2)} l2: {np.linalg.norm(column_err):0.2f} {mode_units[i]} "
              f"l1: {np.mean(np.abs(column_err)):0.2f} {mode_units[i]}")
    # NOTE: `l1` averages all four columns, i.e. it mixes units (px/cm with deg).
    l1 = np.mean(np.abs(valid_diff))
    l1_2d = np.mean(np.abs(valid_diff[:, :, :2]))
    l1_depth = np.mean(np.abs(valid_diff[:, :, 2]))
    l1_depth_obj = np.mean(np.abs(valid_diff[:, 0, 2]))  # keypoint 0 = "object"
    out_str = f"{model_name} ({depth_str}) L1: {l1:0.3f} L1_2d {l1_2d:0.3f} "
    out_str += f"{mode_units[0]} L1_depth: {l1_depth:0.3f} {mode_units[2]} L1_depth_obj: {l1_depth_obj:0.3f} {mode_units[2]} {mode} "
    out_str += f"{dataset_name}-len={len(test_dataset)}"
    print(out_str)
    print()

In [None]:
import re

# Parses "move <object> onto <target>" instructions; each name is one or more
# whitespace-separated tokens of word characters plus ' , . -
MOVE_REGEX = re.compile(r"^move\s+([\w',.-]+(?:\s+[\w',.-]+)*)\s+onto\s+([\w',.-]+(?:\s+[\w',.-]+)*)$")

text = "move mr. potohead onto blue block"
match = MOVE_REGEX.match(text.strip())
# Guard: .groups() on a failed match (None) would raise AttributeError.
if match is None:
    print("no match:", repr(text))
else:
    print(match.groups())


In [None]:
# Logged summary-cell outputs (kept as comments so this cell runs):
# mix30obj_text (rgb   ) L1: 75.920 L1_2d 60.521 px L1_depth: 14.226 cm L1_depth_obj: 13.250 cm cam cvla-droid-1of5c-v1-len=200
# clevr-act-7-depth_e512s (rgb   ) L1: 77.966 L1_2d 67.927 px L1_depth: 12.432 cm L1_depth_obj: 11.803 cm cam cvla-droid-1of5c-v1-len=200
# clevr-act-7-depth_text_aug (rgb   ) L1: 92.865 L1_2d 95.115 px L1_depth: 16.793 cm L1_depth_obj: 13.647 cm cam cvla-droid-1of5c-v1-len=200


#clevr-act-7-depth_depthaug (depth) L1: 39.406 L1_depth: 8.897 cm L1_depth_obj: 5.356 cm cam
#clevr-act-7-depth_depthaug (depth) L1: 16.043 L1_depth: 3.531 cm L1_depth_obj: 2.930 cm cart

#clevr-act-7-depth_text_aug (rgb   ) L1: 35.151 L1_depth: 11.916 cm L1_depth_obj: 7.575 cm cam
#clevr-act-7-depth_text_aug (rgb   ) L1: 17.952 L1_depth: 4.256 cm L1_depth_obj: 3.442 cm cart



# RGB Networks
# Valid Samples: 157 L1: 35.67 (no-aug)
# Valid Samples: 160 L1: 34.35 (prompt cleaning) 33.89
# Valid Samples: 160 L1: 34.62 (prompt cleaning + simplify text)  -- similar

# Valid Samples: 160 L1: 24.96 (rgb20 = augmentation + random background (CVPR09-dataset))
# Valid Samples: 160 L1: 26.61 (rgb20 + simplify text) -- worse
# Valid Samples: 150 L1: 40.55 (augmentation + DROID background)  -- way worse

# With depth (only 31 samples for real data)
# Valid Samples: 31 L1: 29.24 (no text aug, NaN-to-Max eval)
# Valid Samples: 31 L1: 28.58  L1_depth 6.29375 (no text aug, NaN-to-Min eval)
# Valid Samples: 160 L1: 27.93 6.29375 (no text aug, all-zero)

# clevr-act-7-depth_text_aug (rgb  ) L1: 13.994 L1_depth: 11.916 L1_depth_obj: 7.575
# clevr-act-7-depth_depthaug (depth) L1: 14.677 L1_depth: 8.897 L1_depth_obj: 5.356 



In [None]:
import numpy as np
import matplotlib.pyplot as plt
# Per-column errors from earlier runs, kept for reference (previously these
# arrays were bare expressions whose values were discarded). The meaning of
# the 12 columns is not recorded here — TODO document.
errors_without_masking_v3 = np.array([13.68125, 48.0125, 8.025, 8.475, 20.1875, 29.4,
                                      5.2875, 4.00625, 15.20625, 8.475, 20.1875, 29.4])
# with masking v2 (wrong orientation)
errors_with_masking_v2 = np.array([5.76875, 14.11875, 7.275, 36.9875, 53.45, 18.25625,
                                   8.94375, 7.0375, 16.825, 36.9875, 53.45, 18.25625])

# Mean L1 error per training variant (values from the logged runs above).
eval_on_blocks = {"baseline": 35.67, "+clean prompt": 33.89, "(simplify text)": 34.62,
                  "rgb-20%": 24.96, "(droid bg)": 40.55, "text aug": 25.93}

labels, values = zip(*eval_on_blocks.items())

fig, ax = plt.subplots(figsize=(10, 5))
ax.bar(labels, values, color='skyblue', edgecolor='black')
ax.set_ylabel("L1 Error Mean")
ax.set_title("Evaluation over Training Data")
ax.set_xlabel("Experiment Runs")
plt.setp(ax.get_xticklabels(), rotation=45, ha="right")  # rotated for readability
ax.grid(axis="y", linestyle="--", alpha=0.7)
plt.show()

In [None]:
# Inspect the samples with the largest depth error. `valid_diffs` is a dict
# keyed by mode, so pick one mode explicitly (indexing the dict raised TypeError).
worst_mode = "cam"
mode_diffs = valid_diffs[worst_mode]  # (n_valid, n_keypoints, 4)
plt.imshow(np.abs(mode_diffs[:, :, 0:3].reshape(-1, 6)))
plt.show()

# Rank rows by their largest absolute depth error (column 2) over both keypoints.
worst_position = np.argsort(-np.max(np.abs(mode_diffs[:, :, 2]), axis=1))

# Rows of `mode_diffs` only cover predictions that decoded successfully, so
# mapping row k to the k-th prediction is exact only when nothing was skipped.
prediction_indices = list(predictions)
if len(prediction_indices) != len(mode_diffs):
    print(f"warning: {len(prediction_indices) - len(mode_diffs)} predictions were skipped; "
          "worst-case indices may be misaligned")

num_worst = min(5, len(worst_position))
html_imgs = ""
for i in range(num_worst):
    dataset_i = prediction_indices[worst_position[i]]
    batch_entry = test_dataset[dataset_i]
    if return_depth:
        (depth, image), sample = batch_entry
    else:
        image, sample = batch_entry

    decoded_str = predictions[dataset_i]
    # Same encoders as the other render_example calls in this notebook.
    html_img = render_example(image, text=sample["prefix"], label=sample["suffix"], prediction=decoded_str,
                              camera=sample["camera"], enc=enc_data, enc_pred=enc_model)
    html_imgs += html_img

display(HTML(html_imgs))
    

# Simulation Eval

In [None]:
%reload_ext autoreload
%autoreload 2
from pathlib import Path
from cvla.hf_model_class import cVLA_wrapped

# Checkpoint used for the simulation rollout below.
model_location = Path("/data/lmbraid19/argusm/models/")
model_path = model_location.joinpath("clevr-act-7-depth_e512s", "checkpoint-4687")
model_wrapped = cVLA_wrapped(model_path=model_path)

In [None]:
from mani_skill.examples.run_env import Args, iterate_env

# Environment settings; the action encoder must match the wrapped model's.
env_settings = dict(
    env_id="ClevrMove-v1",
    render_mode="rgb_array",
    control_mode="pd_joint_pos",
    action_encoder=model_wrapped.enc_model.NAME,
)
parsed_args = Args()
for setting_name, setting_value in env_settings.items():
    setattr(parsed_args, setting_name, setting_value)
env_iter = iterate_env(parsed_args, vis=False, model=model_wrapped)

In [None]:
N_ENV_STEPS = 10  # number of items to pull from env_iter
for i in range(N_ENV_STEPS):
    _ = next(env_iter)  # yielded value is not used; this only drives the rollout