In [None]:
from pathlib import Path
from PIL import Image
from tqdm.notebook import tqdm
from jax_eval import VLAModel

model_path= "/data/lmbraid19/argusm/models/clevr-act-6-var-cam_jax-native-long/model/paligemma-3b-pt-224-step005624.f16.npz"
ACTION_ENCODER = "xyzrotvec-cam-proj"
#dataset_location = "/tmp/clevr-act-6-var-cam"
dataset_location = "/data/lmbraid19/argusm/datasets/clevr-real-block-v1"


#model_path = "/data/lmbraid19/argusm/models/clevr-act-6-var-rbt_jax-native-long-rbt/model/paligemma-3b-pt-224-step005624.f16.npz"
#dataset_location = "/tmp/clevr-act-6-var-rbt"
# ACTION_ENCODER = "xyzrotvec-rbt"

model = VLAModel(model_path, gpu_index=3)  # takes a while

# Eval on Validation Dataset

In [3]:
#for i in tqdm(range(len(valid_dataset))):
#    image, sample = valid_dataset[i]
#    suffix = sample["suffix"]
    #print(suffix)
    #suffix_p = [int(x) for x in re.findall(r"<loc(\d{4})>", suffix)]
    #print(suffix_p[3:6] == suffix_p[9:12])

In [None]:
from data_loader import JSONLDataset
from utils_vis import render_example

dataset_location = Path(dataset_location)

valid_dataset = JSONLDataset(
    jsonl_file_path=f"{dataset_location}/dataset/_annotations.valid.jsonl",
    image_directory_path=f"{dataset_location}/dataset",
    clean_prompt=True
)

pred_list = []
valid_sample = []
#print(len(valid_dataset))

#num_samples = 10
num_samples = len(valid_dataset)
html_imgs = ""
for i in tqdm(range(num_samples)):
    image, sample = valid_dataset[i+0]
    prefix = sample["prefix"]
    img_out, text, label, token_pred = model.make_predictions(image, prefix)

    pred_list.append(token_pred)
    valid_sample.append(i)
    
    html_imgs += render_example(image, label=sample["suffix"], prediction=token_pred, text=prefix)
print(pred_list)

plot_images = True
if plot_images:
    from IPython.display import display, HTML
    display(HTML(html_imgs))

In [None]:
import re
import numpy as np
import matplotlib.pyplot as plt
results = []
for j, pred in zip(valid_sample, pred_list):
    suffix = valid_dataset[j][1]["suffix"]
    suffix_p = [int(x) for x in re.findall(r"<loc(\d{4})>", suffix)]
    decode_p = [int(x) for x in re.findall(r"<loc(\d{4})>", pred)]
    if len(decode_p) != 12:
        continue
    pred_diff = np.array(suffix_p) - np.array(decode_p)
    results.append(pred_diff)
results = np.array(results)

results[:,0:2] = results[:, 0:2]/1024*224 
results[:,6:8] = results[:, 6:8]/1024*224 


keypoint = ["object","container"]
action_labels = ["x","y","depth","r1","r2","r3"]*2
units         = ["px","px","cm","rad","rad","rad"]*2
plot_hist=True
if plot_hist:
    fig, axes = plt.subplots(2, 3, figsize=(12, 12*2/3))  # 3 rows x 4 columns of histograms
    axes = axes.flatten()  # Flatten the 2D array of axes for easy iteration
    for i in range(2):
        for j in range(6):
            axes[j].hist(results[:, i*6+j], bins=20, alpha=0.7,  edgecolor='black', label=keypoint[i])
            axes[j].set_title(f'Hist. {action_labels[j]} err.')
            axes[j].set_xlabel(f"{action_labels[j]} err. [{units[j]}]")
            axes[j].set_ylabel('Frequency')
            axes[j].legend()
    plt.tight_layout()  # Adjust layout for better spacing
    plt.show()

print("Valid Samples:", len(results), "L1:",np.mean(np.abs(results)))


Valid Samples: 122 L1: 35.994535519125684 (no propt cleaning)
Valid Samples: 127 L1: 35.69750656167979  (prompt clearning)

In [None]:
from mani_skill.examples.utils_traj_tokens import getActionEncDecFunction, decode_caption_xyzrotvec
from mani_skill.examples.utils_trajectory import DummyCamera
from scipy.spatial.transform import Rotation as R
import torch

enc, dec = getActionEncDecFunction('xyzrotvec-cam-proj')
camera = DummyCamera(intrinsic_matrix=[], extrinsic_matrix=[], width=224, height=224)
decode_caption_xyzrotvec(valid_dataset[j][1]["suffix"], camera=camera)

delta_degrees = []
for j, pred in zip(valid_sample, pred_list):
    suffix = valid_dataset[j][1]["suffix"]
    dec_gt = decode_caption_xyzrotvec(suffix, camera)
    dec_pd = decode_caption_xyzrotvec(pred, camera)
    orns_R = R.from_quat(dec_gt[1][:2], scalar_first=True)
    orns_est_R = R.from_quat(dec_pd[1][:2], scalar_first=True)    
    magnitude_radians = (orns_est_R * orns_R.inv()).magnitude()
    angle_degrees = magnitude_radians * (180.0 / torch.pi)
    delta_degrees.append(angle_degrees)

delta_degrees = np.array(delta_degrees)

action_labels = ["Obj Rot.","Dest. Rot."]
units         = ["deg.", "deg."]
plot_hist=True
if plot_hist:
    fig, axes = plt.subplots(1, 2, figsize=(8,4))  # 3 rows x 4 columns of histograms
    axes = axes.flatten()  # Flatten the 2D array of axes for easy iteration
    for i in range(2):
        axes[i].hist(delta_degrees[:, i], bins=20, alpha=0.7, color='blue', edgecolor='black')
        axes[i].set_title(f'Hist {action_labels[i]}')
        axes[i].set_xlabel(units[i])
        axes[i].set_ylabel('Frequency')
        axes[i].set_xticks(np.arange(0, max(delta_degrees[:, i])+1, 20))
    plt.tight_layout()  # Adjust layout for better spacing
    plt.show()



In [None]:
dec_gt[0][:2]

# Eval on Env Frames

In [5]:
%reload_ext autoreload
%autoreload 2

# Run on Env
import json
from mani_skill.examples.run_env import Args, iterate_env

seed_fn = "/ihome/argusm/lang/ManiSkill/mani_skill/examples/seeds_valid.json"
with open(seed_fn, "r") as f_obj:
    seeds = json.load(f_obj)


env_args = Args()
env_args.env_id = "ClevrMove-v1"
env_args.render_mode = "rgb_array"
env_args.control_mode = "pd_joint_pos"
env_args.obs_mode = "rgb"
#env_args.shader = "rt"
env_args.seed = seeds
env_args.run_mode = "first"
env_iter = iterate_env(env_args, vis=False)

In [None]:
from IPython.display import display, HTML

images, meta_infos = [], []

for _ in tqdm(range(25)):
    image_before, json_dict, rnd_seed = next(env_iter)
    #plt.imshow(image_before)
    images.append(image_before)
    meta_infos.append(json_dict)

plot_images = True
if plot_images:
    html_imgs = ""
    for i,(image, info) in enumerate(zip(images, meta_infos)):
        html_imgs += render_example(Image.fromarray(image), info["suffix"], text=info["prefix"], prediction=None)
        if i == 2:
            break
    display(HTML(html_imgs))

In [7]:
preds_env = []
for image, info in zip(images, meta_infos):
    img_out, text, label, token_pred = model.make_predictions(image, info["prefix"])
    preds_env.append(token_pred)

In [None]:
from IPython.display import display, HTML
html_imgs = ""
for image, info, pred in zip(images, meta_infos, preds_env):
    html_imgs += render_example(Image.fromarray(image), info["suffix"], prediction=pred, text=info["prefix"])
display(HTML(html_imgs))

In [None]:
import re
import numpy as np
import matplotlib.pyplot as plt

results = []
for pred, info in zip(preds_env, meta_infos):
    suffix = info["suffix"]
    suffix_p = [int(x) for x in re.findall(r"<loc(\d{4})>", suffix)]
    decode_p = [int(x) for x in re.findall(r"<loc(\d{4})>", pred)]
    pred_diff = np.array(suffix_p) - np.array(decode_p)
    results.append(pred_diff)
results = np.array(results)

print(results.shape)
print(results)

# Eval on Env Simulation

In [10]:
%reload_ext autoreload
%autoreload 2
import json
from mani_skill.examples.run_env import Args, iterate_env

seed_fn = "/ihome/argusm/lang/ManiSkill/mani_skill/examples/seeds_valid.json"
with open(seed_fn, "r") as f_obj:
    seeds = json.load(f_obj)

env_args = Args()
env_args.env_id = "ClevrMove-v1"
env_args.render_mode = "rgb_array"
env_args.control_mode = "pd_joint_pos"
env_args.obs_mode = "rgb"
#env_args.shader = "rt"
env_args.action_encoder = ACTION_ENCODER
env_args.seed = seeds
env_args.run_mode = "script"
env_iter = iterate_env(env_args, vis=False, model=model)

In [None]:
images, meta_infos = [], []

for _ in tqdm(range(50)):
    image_before, json_dict, rnd_seed = next(env_iter)


In [None]:
text = """reward 0.89 seed 383329927
reward 0.90 seed 3324115916
reward 0.00 seed 2811363264
reward 0.79 seed 404488628
reward 0.59 seed 3160032368
reward 0.00 seed 3081541439
reward 0.83 seed 3376120482
reward 0.84 seed 3606691181
reward 0.66 seed 1934392872
reward 0.02 seed 784044718
reward 0.00 seed 2765379632
reward 0.17 seed 1934709049
reward 0.00 seed 395720737
reward 0.98 seed 1188673835
reward 0.78 seed 2712977935
reward 0.51 seed 709653492
reward 0.00 seed 4169116268
reward 0.18 seed 3343131662
reward 0.00 seed 3263743247
reward 0.06 seed 835966883
reward 0.68 seed 2004551446
reward 0.20 seed 2347498143
reward 0.00 seed 662668321
reward 0.44 seed 3192775210
screw plan failed
reward 0.26 seed 2933672914
reward 0.00 seed 4155422658
reward 0.95 seed 3889246497
reward 0.99 seed 813773289
reward 0.95 seed 1418182452
reward 0.00 seed 4038583929
reward 0.84 seed 3007615710
reward 0.00 seed 417585214
reward 0.83 seed 3297864500
screw plan failed
reward 0.47 seed 3574528627
reward 0.12 seed 3615039049
reward 0.82 seed 1238359775
reward 0.96 seed 1028892848
reward 0.06 seed 858599190
reward 0.89 seed 3457015366
reward 0.80 seed 2855512684
reward 0.41 seed 2025365462
reward 0.59 seed 1190228316
screw plan failed
reward 0.00 seed 3353205654
reward 0.00 seed 2171562312
reward 0.85 seed 2442724835
reward 0.00 seed 160040410
reward 0.92 seed 2870768860
reward 0.00 seed 2428259942
reward 0.89 seed 2377605420
reward 0.48 seed 2401776466"""

lines = text.split("\n")
rewards = []
for line in lines:
    if line.startswith("reward "):
        reward = float(line.split()[1])
        assert reward < 1.0 and reward >= 0
        rewards.append(reward)
print(len)
print(np.mean(rewards))