In [None]:
import os
from pathlib import Path
# Pin this kernel to one physical GPU; done here, before torch is imported further down.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

dataset_location = Path("/tmp/clevr-act-6-var-cam")
model_location = Path("/data/lmbraid19/argusm/models")

# The checkpoint lives in <model_location>/<dataset name>_hf_af_lr5e5/checkpoint-4687.
model_path = model_location / (dataset_location.stem + "_hf_af_lr5e5") / "checkpoint-4687"
print(model_path)
# Fail early if the checkpoint directory is missing on this machine.
assert model_path.is_dir()

def print_head(file_path, n=1):
    """Print the first ``n`` lines of a text file, each stripped of surrounding whitespace.

    Args:
        file_path: Path (str or path-like) of the file to preview.
        n: Maximum number of lines to print. Files shorter than ``n`` lines
            are printed in full.

    The file is opened in a ``with`` block so the handle is closed as soon as
    the preview has been read, instead of being left for the garbage collector.
    """
    with open(file_path) as handle:
        # zip() stops once range(n) is exhausted, so at most n lines are read.
        head = [line.strip() for _, line in zip(range(n), handle)]
    print("\n".join(head))

# Echo the dataset root followed by a blank line, then preview the first
# training annotations and the first line of the dataset's info file.
print("dataset_location", dataset_location, end="\n\n")
print_head(dataset_location.joinpath("dataset", "_annotations.train.jsonl"), n=5)
print_head(dataset_location.joinpath("info.json"))

In [2]:
import os
import random
from PIL import Image
from data_loader import JSONLDataset

# Build one JSONLDataset per split. Every split reads its images from the same
# "<dataset_location>/dataset" directory; only the annotation file differs.
train_dataset = JSONLDataset(
    jsonl_file_path=f"{dataset_location}/dataset/_annotations.train.jsonl",
    image_directory_path=f"{dataset_location}/dataset",
)
valid_dataset = JSONLDataset(
    jsonl_file_path=f"{dataset_location}/dataset/_annotations.valid.jsonl",
    image_directory_path=f"{dataset_location}/dataset",
)
# NOTE(review): test_dataset reads the *valid* annotations, so it holds the same
# samples as valid_dataset — confirm this is intended (e.g. no separate test split).
test_dataset = JSONLDataset(
    jsonl_file_path=f"{dataset_location}/dataset/_annotations.valid.jsonl",
    image_directory_path=f"{dataset_location}/dataset",
)

In [None]:
import torch

from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')
print('Using device:', DEVICE)

# The processor is loaded from the base model id; the model weights themselves
# are loaded separately from the local checkpoint directory.
MODEL_ID = "google/paligemma2-3b-pt-224"
processor = PaliGemmaProcessor.from_pretrained(MODEL_ID)

In [4]:
# dtype applied to the processor outputs and passed as torch_dtype when the model is loaded.
TORCH_DTYPE = torch.bfloat16

def augment_suffix(suffix):
    """Return ``suffix`` with its ' ; '-separated segments in a random order.

    A suffix without the separator comes back unchanged. The permutation is
    drawn from the global ``random`` module state.
    """
    separator = ' ; '
    segments = suffix.split(separator)
    # random.shuffle permutes the list in place.
    random.shuffle(segments)
    return separator.join(segments)

def collate_fn(batch):
    """Turn a batch of ``(image, label)`` pairs into processor inputs.

    Args:
        batch: Iterable of ``(image, label)`` pairs where each ``label`` is a
            mapping with at least the keys ``"prefix"`` and ``"suffix"``.

    Returns:
        The output of ``processor`` for the whole batch, padded to the longest
        sequence, cast to ``TORCH_DTYPE`` and moved to ``DEVICE``.

    Each prefix is prepended with the ``<image>`` placeholder, and each suffix
    is passed through ``augment_suffix`` so its segment order is re-randomised
    every time the batch is collated.
    """
    images, labels = zip(*batch)

    prefixes = ["<image>" + label["prefix"] for label in labels]
    suffixes = [augment_suffix(label["suffix"]) for label in labels]

    inputs = processor(
        text=prefixes,
        images=images,
        return_tensors="pt",
        suffix=suffixes,
        padding="longest"
    ).to(TORCH_DTYPE).to(DEVICE)

    return inputs

In [None]:
# Load the fine-tuned weights from the local checkpoint directory in the
# notebook-wide dtype; device placement is left to device_map="auto".
model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_path,
    torch_dtype=TORCH_DTYPE,
    device_map="auto",
)

In [None]:
from tqdm.notebook import tqdm

# Greedy-decode a prediction for the first `test_samples` items of test_dataset.
# (augment_suffix is already defined in an earlier cell and is not needed for
# inference, so it is not redefined here.)
test_samples = 25
decode_dataset = [None] * test_samples
for i in tqdm(range(test_samples), total=test_samples):
    image, label = test_dataset[i]
    prefix = "<image>" + label["prefix"]
    # Only the prompt is fed to the processor. Passing the ground-truth suffix
    # here would place the answer inside input_ids, so the model would be
    # conditioned on the target and prefix_length would cover it as well.
    inputs = processor(
        text=prefix,
        images=image,
        return_tensors="pt",
    ).to(TORCH_DTYPE).to(DEVICE)
    # Number of prompt tokens; everything generated after this index is the prediction.
    prefix_length = inputs["input_ids"].shape[-1]

    with torch.inference_mode():
        generation = model.generate(**inputs, max_new_tokens=12, do_sample=False, use_cache=False)
        # Drop the echoed prompt and keep only the newly generated tokens.
        generation = generation[0][prefix_length:]
        decoded = processor.decode(generation, skip_special_tokens=True)
    decode_dataset[i] = decoded

print(decode_dataset)

In [None]:
from matplotlib import pyplot as plt
import re
import numpy as np
# Compare the <locNNNN> tokens of every ground-truth suffix with the decoded prediction.
loc_pattern = re.compile(r"<loc(\d{4})>")
n_loc_tokens = 12  # a sample is only scored when both strings hold exactly this many tokens

results = []
suffix_nums = []
predi_nums = []
for i in range(test_samples):
    suffix = test_dataset[i][1]["suffix"]
    decoded = decode_dataset[i]
    # \d{4} only ever captures digits, so int() cannot raise here.
    suffix_p = [int(x) for x in loc_pattern.findall(suffix)]
    decode_p = [int(x) for x in loc_pattern.findall(decoded)]
    # Skip samples where either side has the wrong token count; otherwise the
    # element-wise difference below would fail or be meaningless.
    if len(suffix_p) != n_loc_tokens or len(decode_p) != n_loc_tokens:
        continue
    pred_diff = np.array(suffix_p) - np.array(decode_p)
    results.append(pred_diff)
    suffix_nums.append(suffix_p)
    predi_nums.append(decode_p)

# reshape keeps the (n_valid, n_loc_tokens) layout even when every sample was skipped.
results = np.array(results).reshape(-1, n_loc_tokens)
# NOTE(review): presumably rescales token values from a 1024-bin range to a
# 448-pixel image — confirm against the tokenisation used for the dataset.
suffix_nums = np.array(suffix_nums) / 1024 * 448
predi_nums = np.array(predi_nums) / 1024 * 448

print(results)
if len(results):
    # Mean absolute error per token position, in token units.
    print(np.abs(results).mean(axis=0).round())
else:
    print(f"no prediction out of {test_samples} contained {n_loc_tokens} location tokens")

plot_histogram = False
if plot_histogram:
    fig, axes = plt.subplots(4, 3, figsize=(10, 12))  # 4 rows x 3 columns of histograms
    axes = axes.flatten()  # Flatten the 2D array of axes for easy iteration
    for i in range(n_loc_tokens):
        axes[i].hist(results[:, i], bins=20, alpha=0.7, color='blue', edgecolor='black')
        axes[i].set_title(f'Histogram for Column {i + 1}')
        axes[i].set_xlabel('Value')
        axes[i].set_ylabel('Frequency')

    plt.tight_layout()  # Adjust layout for better spacing
    plt.show()

In [None]:
import json
import matplotlib.pyplot as plt
from typing import List
from utils_traj_tokens import decode_caption_xyzrotvec
from utils_trajectory import DummyCamera

def read_n_lines(file_path: str, n: int) -> List[str]:
    """Return up to the first ``n`` lines of a text file, stripped of surrounding whitespace.

    Args:
        file_path: Path of the file to read.
        n: Maximum number of lines to return; values below 1 yield an empty list.

    Returns:
        A list with at most ``n`` stripped lines. When the file is shorter than
        ``n`` lines, all of its lines are returned instead of raising
        ``StopIteration``.
    """
    from itertools import islice

    with open(file_path, 'r') as file:
        # islice stops at end-of-file, unlike repeated next() calls.
        return [line.strip() for line in islice(file, max(n, 0))]

def plot_tokens(image, tokens, ax, color='green'):
    """Decode a trajectory token string and draw it on ``ax`` as a dotted, dashed line.

    Args:
        image: Array with a three-element ``shape``; the first two entries are
            used as the camera height and width.
        tokens: Token string handed to ``decode_caption_xyzrotvec``.
        ax: Matplotlib axes to draw on.
        color: Line colour.

    The camera uses a fixed intrinsic matrix and an extrinsic with identity
    rotation and zero translation. Only the first two columns of the decoded
    curve are plotted (presumably pixel x/y — confirm against
    ``decode_caption_xyzrotvec``). Nothing is returned.
    """
    height_px, width_px, _ = image.shape
    extrinsic = [[[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]]]
    intrinsic = [[[410.029, 0.0, 224.0], [0.0, 410.029, 224.0], [0.0, 0.0, 1.0]]]
    dummy_camera = DummyCamera(intrinsic, extrinsic, width=width_px, height=height_px)
    decoded_curve, _ = decode_caption_xyzrotvec(tokens, dummy_camera)
    xy = decoded_curve[:, :2]
    xs, ys = xy[:, 0], xy[:, 1]
    ax.plot(xs, ys, '.--', color=color)


# Overlay ground-truth (green) and predicted (lime) trajectories on the first
# test images. The samples come from test_dataset / decode_dataset, so the loop
# is driven by the number of decoded predictions rather than by lines read
# from the training annotations.
n_rows, n_cols = 5, 5
n_panels = min(len(decode_dataset), n_rows * n_cols)
fig, axes = plt.subplots(n_rows, n_cols, figsize=(10, 10))
for ax in axes.flat:
    # Hide the frame on every panel, including any that stay empty.
    ax.axis('off')
for i in range(n_panels):
    ax = axes[i // n_cols][i % n_cols]
    sample_image, sample_label = test_dataset[i]
    image = np.asarray(sample_image)
    ax.imshow(image)
    plot_tokens(image, sample_label["suffix"], ax, color='green')
    plot_tokens(image, decode_dataset[i], ax, color='lime')
plt.tight_layout()
plt.show()

# Simulation Eval

In [None]:
import numpy as np
class ModelWrapper:
    """Adapter that exposes ``make_predictions(image, prefix)`` around a generation model.

    An instance is passed as ``model=`` to ``iterate_env`` further down in this
    notebook.
    """

    def __init__(self, transformers_model=model):
        # The default is evaluated once, when this class statement runs, so it
        # refers to whatever ``model`` was bound to at that moment.
        self.model = transformers_model

    def make_predictions(self, image, prefix):
        """Generate a token string for one image/prompt pair.

        Args:
            image: Array accepted by ``PIL.Image.fromarray``.
            prefix: Prompt text; the ``<image>`` placeholder is prepended here.

        Returns:
            A 4-tuple ``(None, None, None, decoded)`` where ``decoded`` is the
            generated text with special tokens removed. The three ``None``
            placeholders presumably keep the tuple shape the caller expects —
            confirm against ``iterate_env``.
        """
        prefix = "<image>" + prefix
        image = Image.fromarray(image)
        inputs = processor(text=prefix,
                           images=image,
                           return_tensors="pt").to(TORCH_DTYPE).to(DEVICE)
        # Number of prompt tokens; generated tokens start after this index.
        prefix_length = inputs["input_ids"].shape[-1]
        with torch.inference_mode():
            # Use the wrapped model, not the notebook-global ``model``, so a
            # wrapper built around a different model actually uses it.
            generation = self.model.generate(**inputs, max_new_tokens=12, do_sample=False, use_cache=False)
            generation = generation[0][prefix_length:]
            decoded = processor.decode(generation, skip_special_tokens=True)
        return None, None, None, decoded
# Wrap the loaded model so it can be handed to the simulation loop below.
model_wrapped = ModelWrapper(model)

# Smoke test: run the wrapper on the first test sample and print what it returns.
i = 0
image, label = test_dataset[i]
print(image)
print(label["prefix"])
# make_predictions calls Image.fromarray on its input, so convert the dataset image to an array first.
res = model_wrapped.make_predictions(np.asarray(image), label["prefix"])
print(res)


In [10]:
%reload_ext autoreload
%autoreload 2
import json
from PIL import Image
from mani_skill.examples.run_env import Args, iterate_env, save_dataset

        
# Start from the default Args and override the environment id, render mode and control mode.
parsed_args = Args()
parsed_args.env_id = "ClevrMove-v1"
parsed_args.render_mode = "rgb_array"
parsed_args.control_mode = "pd_joint_pos"

# Build the environment iterator with the wrapped model supplying predictions;
# vis=False presumably disables interactive visualisation — confirm against iterate_env.
env_iter = iterate_env(parsed_args, vis=False, model=model_wrapped)

In [None]:
# Advance the environment iterator 25 times, discarding whatever each step yields.
# NOTE(review): next() raises StopIteration if env_iter finishes earlier — confirm it yields at least 25 items.
for i in range(25):
    next(env_iter)