In [None]:
%load_ext autoreload
%autoreload 2
import os
import json
import getpass
from pathlib import Path
from cvla.utils_traj_tokens import getActionEncInstance

# Pin the notebook to the first GPU and disable HF tokenizers' internal parallelism.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Per-user dataset / model roots. Either root can be overridden without editing the
# notebook by setting CVLA_DATASET_DIR / CVLA_MODEL_DIR in the environment.
username = getpass.getuser()
if username == "argusm":
    default_dataset_location = "/data/lmbraid19/argusm/datasets"
    default_model_location = "/data/lmbraid19/argusm/models"
else:
    default_dataset_location = "/home/houman/cVLA_test/"
    default_model_location = "/home/houman/cVLA_test/models"
dataset_location = Path(os.environ.get("CVLA_DATASET_DIR", default_dataset_location))
model_location = Path(os.environ.get("CVLA_MODEL_DIR", default_model_location))


# Alternative dataset / checkpoints; uncomment exactly one model_path to switch runs.
#dataset_location = dataset_location / "cvla-droid-1of5c-v1"

#model_path = model_location / "clevr-act-7-depth_text_aug" / "checkpoint-4687"
#model_path = model_location / "clevr-act-7-depth_e512s" / "checkpoint-4687"
#model_path = model_location / "mix30obj_text_debug" / "checkpoint-4687"
#model_path = model_location / "mix30obj_depth" / "checkpoint-4687"
model_path = model_location / "mix30obj_mask" / "checkpoint-4687"


# Read the run's cvla_info.json (stored in the run directory, i.e. the parent of the
# checkpoint) to learn which action encoder / input modality the model was trained with.
# A missing file, or a missing key inside it, falls back to the legacy defaults below.
info_file = model_path.parent / "cvla_info.json"
try:
    with open(info_file, "r") as f:
        model_info = json.load(f)
except FileNotFoundError:
    model_info = None

# Legacy defaults: runs without this metadata used this encoder, and depth input was
# only signalled by a "_depth" substring in the run path.
default_action_encoder = "xyzrotvec-cam-1024xy"
default_return_depth = "_depth" in str(model_path)

info = model_info or {}
action_encoder = info.get("action_encoder", default_action_encoder)
return_depth = info.get("return_depth", default_return_depth)

enc_model = getActionEncInstance(action_encoder)
dataset_name = dataset_location.name
model_name = model_path.parent.name

# Summary of what this notebook is about to evaluate.
print()
print("dataset:".ljust(10), dataset_name, dataset_location)
if model_path.is_dir():
    summary_rows = (
        ("model:", model_name, "\t", model_path),
        ("encoder", action_encoder),
        ("depth:", return_depth),
    )
    for row_label, *row_values in summary_rows:
        print(row_label.ljust(10), *row_values)

In [None]:
import numpy as np
from tqdm.notebook import tqdm
from cvla.data_loader_images import ImageFolderDataset
from cvla.utils_trajectory import DummyCamera
from torchvision.transforms import v2

# Raw (uncropped) frames: loaded only to learn the full image resolution.
test_dataset_no_crop = ImageFolderDataset(dataset_location / "cvla-imbit-1", startswith="rgb_")
image_width_no_crop, image_height_no_crop = test_dataset_no_crop[0][0].size
print("original image size", image_width_no_crop, image_height_no_crop)

# Calibration for the raw frames: identity extrinsic and a pinhole intrinsic matrix.
camera_extrinsic = [[[1, 0, 0.0, 0.0], [0, 1, 0, 0], [0, 0, 1, 0]]]
camera_intrinsic = [[[260.78692626953125, 0.0, 322.3820495605469],[ 0.0, 260.78692626953125, 180.76370239257812],[0.0, 0.0, 1.0]]]
camera_no_crop = DummyCamera(camera_intrinsic, camera_extrinsic, width=image_width_no_crop, height=image_height_no_crop)

# Frames actually fed to the model: a square center crop.
crop_size = 360
center_crop = v2.CenterCrop(crop_size)
test_dataset = ImageFolderDataset(dataset_location / "cvla-imbit-1", startswith="rgb_", transform=center_crop)
image_width, image_height = test_dataset[0][0].size
print("new image size", image_width, image_height)


def _center_crop_offset(full_size, cropped_size):
    """Return the leading-edge offset of a torchvision center crop along one axis.

    Cropping (full_size >= cropped_size): torchvision anchors the crop at
    ``int(round(diff / 2))``; plain ``int()`` truncation disagrees by one pixel
    for some odd differences. Padding (full_size < cropped_size): torchvision
    pads ``floor(|diff| / 2)`` pixels on the leading side, giving a negative offset.
    """
    diff = full_size - cropped_size
    if diff >= 0:
        return int(round(diff / 2.0))
    return -((-diff) // 2)


# Cropping/padding only moves the principal point; the focal lengths are unchanged.
dx = _center_crop_offset(image_width_no_crop, image_width)
dy = _center_crop_offset(image_height_no_crop, image_height)
K = np.array(camera_intrinsic[0])  # shape (3,3)
K_cropped = K.copy()
K_cropped[0, 2] -= dx
K_cropped[1, 2] -= dy
camera = DummyCamera([K_cropped.tolist()], camera_extrinsic, width=image_width, height=image_height)

# Empty labels with one independent dict per sample ([d] * n would alias a single dict).
test_dataset.labels = [dict(prefix="", suffix="") for _ in range(len(test_dataset))]

In [None]:
from IPython.display import display, HTML
from cvla.utils_vis import render_example
import torch

# Hand-picked TCP pose, encoded into the robot-state text used by the preview below.
# Float dtype for both tensors, so the orientation is not an integer tensor.
# NOTE(review): with the stated w, x, y, z convention, (0, 0, 0, 1) is a 180° rotation
# about z, whereas the inference cell uses the identity (1, 0, 0, 0). Confirm which is intended.
base_to_tcp_pos = torch.tensor([[[-0.7487, -0.3278, 0.7750]]])
base_to_tcp_orn = torch.tensor([[[0.0, 0.0, 0.0, 1.0]]])  # quaternion w, x, y, z
_, _, robot_state = enc_model.encode_trajectory(base_to_tcp_pos, base_to_tcp_orn, camera)

print("robot_state", robot_state)

def get_image(images):
    """Extract the image from a dataset item.

    A list/tuple item is treated as a bundle of modalities whose last entry is
    the image; any other value is returned unchanged.
    """
    is_bundle = isinstance(images, (list, tuple))
    return images[-1] if is_bundle else images
    
def get_depth(images):
    """Extract the depth entry from a dataset item, if it has one.

    A list/tuple item is treated as a bundle of modalities whose first entry is
    the depth map; any other value carries no depth, so ``None`` is returned.
    """
    if not isinstance(images, (list, tuple)):
        return None
    return images[0]
    
print(len(test_dataset))

# Render the first few samples together with the hand-built robot-state string.
num_preview = 3
num_samples = min(num_preview, len(test_dataset))
html_imgs = ""
for i in tqdm(range(num_samples)):
    images, sample = test_dataset[i]
    # Depth models preview the depth entry when the item carries one. Plain image items
    # have no depth (get_depth returns None), so fall back to the image itself.
    depth = get_depth(images) if return_depth else None
    image = depth if depth is not None else get_image(images)
    # f-string so the sample index is actually interpolated; space-separated from the
    # robot state, as the prompts elsewhere in this notebook are.
    html_imgs += render_example(image, text=f"{i} text " + robot_state, camera=camera, enc=enc_model)

display(HTML(html_imgs))

In [None]:

import torch
from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration

# Runtime settings: run on the GPU when one is available, otherwise the CPU; bfloat16 weights.
TORCH_DTYPE = torch.bfloat16
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print('Using device:', DEVICE)

# Text/image preprocessing comes from the base PaliGemma 2 release...
MODEL_ID = "google/paligemma2-3b-pt-224"
processor = PaliGemmaProcessor.from_pretrained(MODEL_ID)
print("loaded processor.")

# ...while the fine-tuned weights come from the selected local checkpoint.
# device_map="auto" lets transformers decide where the weights are placed.
model_load_kwargs = dict(torch_dtype=TORCH_DTYPE, device_map="auto")
model = PaliGemmaForConditionalGeneration.from_pretrained(model_path, **model_load_kwargs)

def collate_fn(batch):
    """Build model inputs from a list of ``(image, label)`` dataset items.

    Each label's ``"prefix"`` becomes the text prompt, preceded by the
    ``<image>`` placeholder token. Prompts are padded to the longest one in the
    batch. The returned batch is moved to ``DEVICE`` so it can be passed
    straight to ``model.generate``.
    """
    images, labels = zip(*batch)
    prefixes = ["<image>" + label["prefix"] for label in labels]
    inputs = processor(
        text=prefixes,
        images=images,
        return_tensors="pt",
        padding="longest",
    )
    # Match the device the model was loaded onto; CPU tensors fed to a GPU model fail in generate().
    return inputs.to(DEVICE)

# Own Data Experiments

In [None]:
from tqdm.notebook import tqdm
from math import ceil
from cvla.utils_vis import render_example

def collate_fn(batch):
    """Turn ``(image, label)`` dataset items into a device-resident model batch.

    The prompt for every item is the ``<image>`` placeholder followed by the
    label's ``"prefix"``. Padding is to the longest prompt, and the resulting
    tensors are moved to ``DEVICE``.
    """
    images, labels = zip(*batch)
    prompts = []
    for label in labels:
        prompts.append("<image>" + label["prefix"])
    encoded = processor(text=prompts, images=images, return_tensors="pt", padding="longest")
    return encoded.to(DEVICE)

# Current TCP pose, encoded into the robot-state tokens appended to the instruction.
# Float dtype for both tensors, so the orientation is not an integer tensor.
base_to_tcp_pos = torch.tensor([[[-0.7487, -0.3278, 0.7750]]])
base_to_tcp_orn = torch.tensor([[[1.0, 0.0, 0.0, 0.0]]])  # quaternion w, x, y, z (identity)
_, _, robot_state = enc_model.encode_trajectory(base_to_tcp_pos, base_to_tcp_orn, camera)

# Language instruction followed by the robot state.
x = "yellow cup"
y = "plate"
action_text = "put the {} inside the {}".format(x, y)
prefix = action_text + " " + robot_state
# Same prompt for every image, with one independent label dict per sample.
test_dataset.labels = [dict(prefix=prefix, suffix="") for _ in range(len(test_dataset))]

# Greedy decoding over the whole dataset in fixed-size batches; each prediction is
# rendered next to its input image.
eval_batch_size = 8
test_samples = len(test_dataset)
num_batches = ceil(test_samples / eval_batch_size)
pred_list = []
html_imgs = ""
for start_idx in tqdm(range(0, test_samples, eval_batch_size), total=num_batches):
    stop_idx = min(start_idx + eval_batch_size, test_samples)
    batch = [test_dataset[i] for i in range(start_idx, stop_idx)]
    inputs = collate_fn(batch)
    # generate() returns prompt + continuation; everything past the prompt length is new text.
    prefix_length = inputs["input_ids"].shape[-1]

    with torch.inference_mode():
        generation = model.generate(**inputs, max_new_tokens=12, do_sample=False, use_cache=False)
        new_tokens = generation[:, prefix_length:]
        decoded = [processor.decode(token_ids, skip_special_tokens=True) for token_ids in new_tokens]
    pred_list.extend(decoded)

    for batch_entry, decoded_str in zip(batch, decoded):
        entry_images, sample = batch_entry
        # Depth models yield (depth, image) pairs; the RGB image is rendered either way.
        if return_depth:
            depth, image = entry_images
        else:
            image = entry_images
        html_imgs += render_example(image, text=sample["prefix"], prediction=decoded_str, camera=camera, enc=enc_model)

from IPython.display import display, HTML

# Set to False to skip rendering the HTML gallery of predictions.
plot_images = True
if plot_images:
    gallery = HTML(html_imgs)
    display(gallery)

In [None]:
import torch
# Decode the robot-state string from the inference prompt with the same encoder
# (result kept in dec_result for inspection).
dec_result = enc_model.decode_trajectory(robot_state, camera)

# Encode a second, hand-picked pose and draw only its text on a blank canvas.
base_to_tcp_pos = torch.tensor([[[-0.7487, -0.3278, 0.7750]]])
#tcp_orn_in_camera = torch.tensor([[[ 0.5494,  0.4075, -0.5746,  0.4493]]])  # quaternion w, x, y, z
# NOTE(review): in w, x, y, z order this is a 180° rotation about z, not the identity; confirm intent.
base_to_tcp_orn = torch.tensor([[[0.0, 0.0, 0.0, 1.0]]])  # quaternion w, x, y, z
_, _, robot_state_str = enc_model.encode_trajectory(base_to_tcp_pos, base_to_tcp_orn, camera)
text = " " + robot_state_str

# White canvas at the cropped camera resolution. It is built explicitly instead of from
# whichever `image` a previous loop left behind. Assumes 3-channel uint8 RGB, like the
# dataset frames (TODO confirm for non-RGB datasets).
blank_canvas = np.full((image_height, image_width, 3), 255, dtype=np.uint8)
html_img = render_example(blank_canvas, text=text, label=None, prediction=None, camera=camera, enc=enc_model)
display(HTML(html_img))

print(robot_state_str)