In [None]:
%load_ext autoreload
%autoreload 2

import os
from pathlib import Path
os.environ["CUDA_VISIBLE_DEVICES"]="2"
dataset_location = "/tmp/clevr-act-6-var-cam2"
dataset_location = "/data/lmbraid19/argusm/datasets/clevr-real-block-v1"

dataset_location = Path(dataset_location)


model_location = Path("/data/lmbraid19/argusm/models/")
model_path = model_location / "clevr-act-6-var-cam2_hf_af_lr3e5" / "checkpoint-4687"

print_head = lambda file_path, n=1: print("\n".join([line.strip() for _, line in zip(range(n), open(file_path))]))

print("dataset_location", dataset_location)
print()
print_head(dataset_location /"_annotations.train.jsonl", 5)
print_head(dataset_location / "info.json")

In [2]:
#!pip install --upgrade git+https://github.com/huggingface/transformers.git
#!pip install transformers==4.47.1

In [49]:
import random
from data_loader_jsonl import JSONLDataset

train_dataset = JSONLDataset(
    jsonl_file_path=f"{dataset_location}/_annotations.train.jsonl",
    image_directory_path=f"{dataset_location}/dataset",
)

# test_dataset = JSONLDataset(
#     jsonl_file_path=f"{dataset_location}/_annotations.valid.jsonl",
#     image_directory_path=f"{dataset_location}/dataset",
# )
test_dataset = JSONLDataset(
    jsonl_file_path=f"{dataset_location}/_annotations.valid.jsonl",
    image_directory_path=f"{dataset_location}/dataset",
    clean_prompt=True
)

In [None]:
import torch

from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', DEVICE)

MODEL_ID ="google/paligemma2-3b-pt-224"
processor = PaliGemmaProcessor.from_pretrained(MODEL_ID)

In [6]:
TORCH_DTYPE = torch.bfloat16

def augment_suffix(suffix):
    parts = suffix.split(' ; ')
    random.shuffle(parts)
    return ' ; '.join(parts)

def collate_fn(batch):
    images, labels = zip(*batch)

    paths = [label["image"] for label in labels]
    prefixes = ["<image>" + label["prefix"] for label in labels]
    suffixes = [augment_suffix(label["suffix"]) for label in labels]

    inputs = processor(
        text=prefixes,
        images=images,
        return_tensors="pt",
        suffix=suffixes,
        padding="longest"
    ).to(TORCH_DTYPE).to(DEVICE)

    return inputs

In [None]:
model = PaliGemmaForConditionalGeneration.from_pretrained(model_path, torch_dtype=TORCH_DTYPE, device_map="auto")

In [None]:

from tqdm.notebook import tqdm
from utils_vis import render_example
from utils_trajectory import DummyCamera

test_samples = len(test_dataset)
decode_dataset = [None, ]*test_samples
pred_list = []
html_imgs = ""

for i in tqdm(range(test_samples), total=test_samples):
    image, sample = test_dataset[i]
    prefix = "<image>" + sample["prefix"]
    suffix = sample["suffix"]
    inputs = processor(
        text=prefix,
        images=image,
        return_tensors="pt",
    ).to(TORCH_DTYPE).to(DEVICE)

    prefix_length = inputs["input_ids"].shape[-1]

    image_height, image_width = image.height, image.width
    camera_extrinsic = [[[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]]]
    camera_intrinsic = [[[410.029, 0.0, 224.0], [0.0, 410.029, 224.0], [0.0, 0.0, 1.0]]]
    camera = DummyCamera(camera_intrinsic, camera_extrinsic, width=image_width, height=image_height)

    with torch.inference_mode():
        generation = model.generate(**inputs, max_new_tokens=12, do_sample=False, use_cache=False)
        generation = generation[0][prefix_length:]
        decoded = processor.decode(generation, skip_special_tokens=True)
    pred_list.append(decoded)

    html_img = render_example(image, label=sample["suffix"], prediction=decoded, text=prefix, camera=camera)
    html_imgs += html_img
    

plot_images = True
if plot_images:
    from IPython.display import display, HTML
    display(HTML(html_imgs))

In [None]:
import re
import numpy as np
import matplotlib.pyplot as plt

results = []
for i, pred in enumerate(pred_list):
    suffix = test_dataset[i][1]["suffix"]
    suffix_p = [int(x) for x in re.findall(r"<(?:loc|seg)(\d+)>", suffix)]
    decode_p = [int(x) for x in re.findall(r"<(?:loc|seg)(\d+)>", pred)]
    if len(decode_p) != 12:
        continue
    pred_diff = np.array(suffix_p) - np.array(decode_p)
    results.append(pred_diff)
results = np.array(results)

results[:, 0:2] = results[:, 0:2]/1024*224 
results[:, 6:8] = results[:, 6:8]/1024*224 


keypoint = ["object","container"]
action_labels = ["x","y","depth","r1","r2","r3"]*2
units         = ["px","px","cm","rad","rad","rad"]*2
plot_hist=True
if plot_hist:
    fig, axes = plt.subplots(2, 3, figsize=(12, 12*2/3))  # 3 rows x 4 columns of histograms
    axes = axes.flatten()  # Flatten the 2D array of axes for easy iteration
    for i in range(2):
        for j in range(6):
            axes[j].hist(results[:, i*6+j], bins=20, alpha=0.7,  edgecolor='black', label=keypoint[i])
            axes[j].set_title(f'Hist. {action_labels[j]} err.')
            axes[j].set_xlabel(f"{action_labels[j]} err. [{units[j]}]")
            axes[j].set_ylabel('Frequency')
            axes[j].legend()
    plt.tight_layout()  # Adjust layout for better spacing
    plt.show()

print("Valid Samples:", len(results), "L1:",np.mean(np.abs(results)))


In [None]:
# Valid Samples: 157 L1: 35.674097664543524 
# Valid Samples: 160 L1: 34.34583333333333  (prompt cleaning)

In [None]:
from matplotlib import pyplot as plt
import re
import numpy as np
results = []
suffix_nums = []
predi_nums = []
for i in range(test_samples):
    suffix = test_dataset[i][1]["suffix"]
    prefix = test_dataset[i][1]["prefix"]
    decoded = decode_dataset[i]
    suffix_p = [int(x) for x in re.findall(r"<(?:loc|seg)(\d+)>", suffix)]
    try:
        decode_p = [int(x) for x in re.findall(r"<(?:loc|seg)(\d+)>", decoded)]
    except ValueError:
        continue
    if len(decode_p) != 12:
        continue
    pred_diff = np.array(suffix_p) - np.array(decode_p)
    results.append(pred_diff)
    suffix_nums.append(suffix_p)
    predi_nums.append(decode_p)

results = np.array(results)
#suffix_nums = np.array(suffix_nums) / 1024 * 448
#predi_nums = np.array(predi_nums) / 1024 * 448

print(results)
print(np.abs(results).mean(axis=0).round())

plot_histogram = True
if plot_histogram:
    fig, axes = plt.subplots(4, 3, figsize=(10, 12))  # 3 rows x 4 columns of histograms
    axes = axes.flatten()  # Flatten the 2D array of axes for easy iteration
    for i in range(12):
        axes[i].hist(results[:, i], bins=20, alpha=0.7, color='blue', edgecolor='black')
        axes[i].set_title(f'Histogram for Column {i + 1}')
        axes[i].set_xlabel('Value')
        axes[i].set_ylabel('Frequency')

    plt.tight_layout()  # Adjust layout for better spacing
    plt.show()

In [None]:
test_dataset[0]

# Simulation Eval

In [None]:
# import numpy as np
# class ModelWrapper:
#     def __init__(self, transformers_model=model):
#         self.model = transformers_model
    
#     def make_predictions(self, image, prefix):
#         prefix = "<image>" + prefix
#         image = Image.fromarray(image)
#         inputs = processor(text=prefix,
#                            images=image,
#                            return_tensors="pt").to(TORCH_DTYPE).to(DEVICE)
#         prefix_length = inputs["input_ids"].shape[-1]
#         with torch.inference_mode():
#             generation = model.generate(**inputs, max_new_tokens=12, do_sample=False, use_cache=False)
#             generation = generation[0][prefix_length:]
#             decoded = processor.decode(generation, skip_special_tokens=True)
#         return None, None, None, decoded
# model_wrapped = ModelWrapper(model)

# i = 0
# image, label = test_dataset[i]
# print(image)
# print(label["prefix"])
# res = model_wrapped.make_predictions(np.asarray(image), label["prefix"])
# print(res)


In [None]:
# %reload_ext autoreload
# %autoreload 2
# import json
# from PIL import Image
# from mani_skill.examples.run_env import Args, iterate_env, save_dataset

        
# parsed_args = Args()
# parsed_args.env_id = "ClevrMove-v1"
# parsed_args.render_mode = "rgb_array"
# parsed_args.control_mode = "pd_joint_pos"

# env_iter = iterate_env(parsed_args, vis=False, model=model_wrapped)

In [None]:
# for i in range(25):
#     next(env_iter)