In [1]:
# This file is modified from https://github.com/haotian-liu/LLaVA/

import argparse
import os
import os.path as osp
import re
from io import BytesIO

import numpy as np
import requests
import torch
from PIL import Image

def load_image(image_file):
    """Load an image from a local path or an http(s) URL, converted to RGB.

    Args:
        image_file: Local filesystem path, or a URL starting with
            "http://" or "https://".

    Returns:
        A PIL image in RGB mode.

    Raises:
        requests.HTTPError: If downloading a URL returns an error status.
    """
    # Match the URL scheme explicitly so local paths that merely begin with
    # "http" (e.g. "httpdata/x.png") are not sent to requests.
    if image_file.startswith(("http://", "https://")):
        # Report the URL actually being fetched (the args namespace in this
        # notebook has no `video_file` attribute).
        print("downloading image from url", image_file)
        # A timeout keeps a stalled server from hanging the notebook forever.
        response = requests.get(image_file, timeout=30)
        # Fail with a clear HTTP error instead of a confusing PIL decode error.
        response.raise_for_status()
        return Image.open(BytesIO(response.content)).convert("RGB")
    return Image.open(image_file).convert("RGB")


def load_images(image_files):
    """Return a list with `load_image` applied to each entry of `image_files`, in order."""
    return [load_image(file_ref) for file_ref in image_files]

def center_crop_and_resize(image: Image.Image, crop_size: int, resize_size: int) -> Image.Image:
    """Cut a centred square of side `crop_size` out of `image` and resample it.

    The crop offsets are rounded down when the slack is odd. The crop is then
    resized to `resize_size` x `resize_size` with the Lanczos filter. Callers in
    this notebook pass ``min(image.size)`` as `crop_size`, i.e. the largest
    centred square.
    """
    img_w, img_h = image.size
    x0 = (img_w - crop_size) // 2
    y0 = (img_h - crop_size) // 2
    crop_box = (x0, y0, x0 + crop_size, y0 + crop_size)
    square = image.crop(crop_box)
    return square.resize((resize_size, resize_size), Image.LANCZOS)


In [62]:
import cv2

def add_path_2d_to_img(
    image, points, line_size=1, circle_size=0, plot_lines=True, color="red"
):
    """Draw a 2D path on a copy of `image` using a dark-to-bright colour gradient.

    Args:
        image: Image array whose first two dimensions are (height, width)
            (assumed (H, W, C) as produced by ``np.array(pil_image)``). It is
            copied, never modified.
        points: Sequence or array of (x, y) points. If every coordinate is
            <= 1, the points are treated as normalised and scaled by
            (width, height). Otherwise they are used as pixel coordinates.
        line_size: Thickness of the segments joining consecutive points.
        circle_size: Radius of the filled dot drawn at each point; 0 disables dots.
        plot_lines: Whether to draw the segments between consecutive points.
        color: "red" or "blue". Selects which entry (0 or 2) of the colour
            tuple carries the gradient; for RGB arrays these render red / blue.

    Returns:
        A new image array with the path drawn on it.

    Raises:
        ValueError: If `color` is not "red" or "blue".
    """
    # Index of the colour-tuple entry that carries the gradient.
    gradient_channel = {"red": 0, "blue": 2}
    if color not in gradient_channel:
        raise ValueError(
            f"unsupported color {color!r}; expected one of {sorted(gradient_channel)}"
        )

    img_out = image.copy()

    # Accept lists as well as arrays; reshape keeps an empty path as shape (0, 2).
    pts = np.asarray(points, dtype=float).reshape(-1, 2)
    path_len = len(pts)
    if path_len == 0:
        return img_out

    if np.all(pts <= 1):
        # Normalised coordinates: x scales with width, y with height.
        height, width = image.shape[:2]
        pts = pts * np.array([width, height])

    # OpenCV expects plain Python ints for point coordinates.
    pixel_pts = [tuple(int(v) for v in p) for p in pts.astype(int)]

    # Gradient from dark (25) to bright (255) along the path.
    channel = gradient_channel[color]
    colors = []
    for intensity in np.linspace(25, 255, path_len).astype(int):
        rgb = [0, 0, 0]
        rgb[channel] = int(intensity)
        colors.append(tuple(rgb))

    radius = max(1, int(circle_size))
    for i in range(path_len - 1):
        if plot_lines:
            cv2.line(img_out, pixel_pts[i], pixel_pts[i + 1], colors[i], line_size)
        if circle_size > 0:
            cv2.circle(img_out, pixel_pts[i], radius, colors[i], -1, lineType=cv2.LINE_AA)

    # The loop stops before the final point, so draw its dot separately.
    if circle_size > 0:
        cv2.circle(img_out, pixel_pts[-1], radius, colors[-1], -1, lineType=cv2.LINE_AA)

    return img_out

In [2]:
# Model / checkpoint presets. Only the preset named by ACTIVE_PRESET is used.
# Keeping the alternatives in one table, rather than reassigning the same
# variables several times, makes it explicit which checkpoint this run loads.
#   model_path : directory holding the checkpoint folder named after the preset
#                (None -> the preset name itself is used as the model path)
#   base_name  : base model for LoRA checkpoints (None for standard checkpoints)
#   base_path  : optional directory prefix joined in front of base_name
#   prompt_type: "path", "mask" or "path_mask"
MODEL_PRESETS = {
    # SFT example
    "vila_3b_no_hamster": {
        "model_path": "/lustre/fs12/portfolios/nvr/users/mmemmel/projects/vila/checkpoints/finetuned/vila/",
        "base_name": None,
        "base_path": None,
        "prompt_type": None,  # TODO: prompt type not recorded for this checkpoint; set before use
    },
    # LoRA example -> set base_name (and path)
    "vila_3b_sft_all_lora_amazon": {
        "model_path": "/lustre/fs12/portfolios/nvr/users/mmemmel/projects/vila/checkpoints/finetuned/vila/",
        "base_name": "Efficient-Large-Model/VILA1.5-3b",
        "base_path": None,
        "prompt_type": None,  # TODO: prompt type not recorded for this checkpoint; set before use
    },
    "vila_3b_all_path_mask": {
        "model_path": "/lustre/fs12/portfolios/nvr/users/mmemmel/projects/vila/checkpoints/finetuned/nvila/",
        "base_name": None,
        "base_path": None,
        "prompt_type": "path_mask",
    },
    "nvila_lite_2b_oxe_robopoint": {
        "model_path": "/lustre/fs12/portfolios/nvr/users/mmemmel/projects/vila/checkpoints/finetuned/nvila/",
        "base_name": None,
        "base_path": None,
        "prompt_type": "path",
    },
}
ACTIVE_PRESET = "nvila_lite_2b_oxe_robopoint"

model_name = ACTIVE_PRESET
_preset = MODEL_PRESETS[ACTIVE_PRESET]
model_path = _preset["model_path"]
base_name = _preset["base_name"]
base_path = _preset["base_path"]
prompt_type = _preset["prompt_type"]

args_dict = {
    "model_path": model_name if model_path is None else os.path.join(model_path, model_name),
    "conv_mode": "vicuna_v1",  # alternative: "llava_v0"
    # Only prefix base_path when there is a base model to prefix.
    "model_base": (
        base_name
        if base_path is None or base_name is None
        else os.path.join(base_path, base_name)
    ),
    "temperature": 0.2,
    "top_p": None,
    "num_beams": 1,
    "max_new_tokens": 1024,
}
args = argparse.Namespace(**args_dict)

In [None]:
version = "nvila" if "nvila" in model_name else "vila"

# Import the loading API for the detected model family.
if version == "vila":
    from llava.constants import IMAGE_TOKEN_INDEX
    from llava.model.builder import load_pretrained_model
elif version == "nvila":
    import llava

from llava.mm_utils import (get_model_name_from_path, process_images, tokenizer_image_token)
from llava.conversation import conv_templates
from llava.utils import disable_torch_init

# Called once, before either loading path (both paths previously called it first).
disable_torch_init()

if args.model_base is None:
    # Standard checkpoint loaded directly from model_path.
    model_name = get_model_name_from_path(args.model_path)
    if version == "nvila":
        # Only `model` is set on this path; tokenizer / image_processor /
        # context_len are not defined here.
        model = llava.load(args.model_path)
    else:
        tokenizer, model, image_processor, context_len = load_pretrained_model(
            args.model_path, model_name, args.model_base
        )
    print("standard", args.model_path, args.model_base)
else:
    # LoRA: load the base model, attach the adapter stored at model_path, then
    # merge the adapter weights so `model` can be used like a plain model.
    from llava.model.builder import load_pretrained_model
    from peft import PeftModel

    tokenizer, base_model, image_processor, context_len = load_pretrained_model(
        args.model_base, get_model_name_from_path(args.model_base), model_base=None
    )
    model = PeftModel.from_pretrained(base_model, args.model_path).merge_and_unload()
    print("LoRA", args.model_path, args.model_base)


  from .autonotebook import tqdm as notebook_tqdm


[2025-05-19 18:54:25,550] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)


You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.


standard /lustre/fs12/portfolios/nvr/users/mmemmel/projects/vila/checkpoints/finetuned/nvila/nvila_lite_2b_oxe_robopoint None


In [42]:
def inference_nvila(message, args):
    """Generate a stripped text answer from the global NVILA `model`.

    Args:
        message: Content list handed unchanged to ``model.generate_content``
            (in this notebook: [question text, image]).
        args: Not read here. NOTE(review): temperature / top_p /
            max_new_tokens from `args` are therefore not applied for NVILA.

    Returns:
        The generated text with surrounding whitespace removed.
    """
    raw_answer = model.generate_content(message)
    return raw_answer.strip()

def inference_vila(message, args):
    """Run one VILA generation for a ``[question, image, ...]`` message.

    Uses the global `model`, `tokenizer` and `image_processor`.

    Args:
        message: List whose first element is the question text and whose
            remaining elements are the images to condition on.
        args: Namespace providing conv_mode, temperature, top_p, num_beams
            and max_new_tokens.

    Returns:
        The decoded text of the first generated sequence, whitespace-stripped.
    """
    quest, *images = message

    # Build the chat prompt: user turn with the question, then an empty
    # assistant turn for the model to complete.
    conv = conv_templates[args.conv_mode].copy()
    conv.append_message(conv.roles[0], quest)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    images_tensor = process_images(images, image_processor, model.config).to(model.device, dtype=torch.float16)
    # Keep the prompt on the same device as the model and images (was a
    # hard-coded .cuda(), which breaks when the model is elsewhere).
    input_ids = (
        tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
        .unsqueeze(0)
        .to(model.device)
    )

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=[images_tensor],
            do_sample=args.temperature > 0,
            temperature=args.temperature,
            top_p=args.top_p,
            num_beams=args.num_beams,
            max_new_tokens=args.max_new_tokens,
            use_cache=True,
        )

    outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
    return outputs.strip()

In [4]:
from vila_utils.utils.prompts import get_prompt

# Demo inputs: (instruction, image path). Only the selected example is loaded.
DEMO_EXAMPLES = {
    "drawer": (
        "close the drawer",
        "/lustre/fs12/portfolios/nvr/users/mmemmel/projects/vila/example_imgs/drawer_scene.png",
    ),
    "widowx": (
        "put the carrot in the sink",
        "/lustre/fs12/portfolios/nvr/users/mmemmel/projects/vila/example_imgs/widowx.png",
    ),
}
quest, demo_image_path = DEMO_EXAMPLES["widowx"]
image = load_image(demo_image_path)

# Largest centred square, resized to 384 x 384.
image = center_crop_and_resize(image, min(image.size), 384)
# NOTE(review): the prompt text shown in this cell's output says coordinates
# "should be integers ranging between 0.0 and 1.0", which is contradictory;
# check the template behind get_prompt.
message = [get_prompt(quest, prompt_type), image]
if version == "vila":
    outputs = inference_vila(message, args)
elif version == "nvila":
    outputs = inference_nvila(message, args)
outputs

In the image, please execute the command described in <quest>put the carrot in the sink</quest>.
Provide a sequence of points denoting the trajectory of a robot gripper to achieve the goal.
Format your answer as a list of tuples enclosed by <ans> and </ans> tags. For example:
<ans>[(0.25, 0.32), (0.32, 0.17), (0.13, 0.24), (0.74, 0.21), ...]</ans>
The tuple denotes point x and y location of the end effector of the gripper in the image.
The coordinates should be integers ranging between 0.0 and 1.0, indicating the relative locations of the points in the image.
'. Removed.[0m


'<ans>[(0.59, 0.2), (0.52, 0.2), (0.57, 0.25), (0.57, 0.18), (0.5, 0.21), (0.54, 0.32), (0.57, 0.25)]</ans>'

In [49]:
import numpy as np
from vila_utils.utils.encode import scale_path
from vila_utils.utils.decode import get_path_from_answer, add_mask_2d_to_img

def add_answer_to_img(img, answer, prompt_type, color="red"):
    """Parse a model answer and draw its mask and/or path onto `img`.

    Args:
        img: Image array; its first two dimensions are taken as (height, width).
        answer: Raw answer text, parsed with `get_path_from_answer`.
        prompt_type: Prompt variant containing "path", "mask", or both
            (e.g. "path", "mask", "path_mask").
        color: Path colour forwarded to `add_path_2d_to_img`.

    Returns:
        ``(img, path)``: the annotated image, and the parsed path in the
        answer's normalised coordinates (None when `prompt_type` has no path).
    """
    out = get_path_from_answer(answer, prompt_type)

    has_path = "path" in prompt_type
    has_mask = "mask" in prompt_type
    # For combined prompts the parser result is indexed as (path, mask); for a
    # single-output prompt it is the path or mask itself. Deciding by
    # prompt_type, rather than by `len(out) == 2`, avoids mistaking a two-point
    # path (or mask) for a (path, mask) pair.
    if has_path and has_mask:
        path, mask = out[0], out[1]
    else:
        path = out if has_path else None
        mask = out if has_mask else None

    h, w = img.shape[:2]
    # Map from normalised [0, 1] answer coordinates to pixel coordinates.
    norm_lo, norm_hi = np.zeros(2), np.ones(2)
    pix_lo, pix_hi = np.zeros(2), np.array([w, h])

    # Mask is drawn first so the path is rendered on top of it.
    if mask is not None:
        scaled_mask = scale_path(mask, min_in=norm_lo, max_in=norm_hi, min_out=pix_lo, max_out=pix_hi)
        if scaled_mask is not None:
            # mask_pixels tied to 15% of image height (meaning defined by add_mask_2d_to_img).
            img = add_mask_2d_to_img(img, scaled_mask, mask_pixels=int(h * 0.15))

    if path is not None:
        scaled_path = scale_path(path, min_in=norm_lo, max_in=norm_hi, min_out=pix_lo, max_out=pix_hi)
        if scaled_path is not None:
            img = add_path_2d_to_img(img, scaled_path, line_size=3, circle_size=0, plot_lines=True, color=color)

    return img, path

In [5]:
import json
import cv2
from tqdm import tqdm
from fastdtw import fastdtw
from scipy.spatial.distance import euclidean

out_path = "/lustre/fs12/portfolios/nvr/users/mmemmel/projects/vila/results"
save_path = os.path.join(out_path, model_name)
os.makedirs(save_path, exist_ok=True)

json_path = "/lustre/fs12/portfolios/nvr/users/mmemmel/projects/vila/data/oxe_processed_path_subtraj/bridge_v2_primary_path/train_bridge_v2_primary_path_conv.json"
with open(json_path, "r") as f:
    data = json.load(f)

# Number of samples to evaluate; set to None to evaluate the whole file.
N_EVAL_SAMPLES = 10
data = data[:N_EVAL_SAMPLES]


def _draw_metrics_banner(image, text):
    """Draw `text` in black on a white box near the bottom-left corner of `image` (in place)."""
    font, scale, thickness = cv2.FONT_HERSHEY_SIMPLEX, 0.75, 2
    (text_width, _), _ = cv2.getTextSize(text, font, scale, thickness)
    bottom = image.shape[0]
    cv2.rectangle(image, (8, bottom - 40), (8 + text_width + 4, bottom - 10), (255, 255, 255), -1)
    cv2.putText(image, text, (10, bottom - 20), font, scale, (0, 0, 0), thickness)


logs = {
    "img_path": [],
    "path_pred": [],
    "path_gt": [],
    "dtw_euclidian": [],
    "first_l2": [],
    "last_l2": [],
}
for sample_idx, qa in enumerate(tqdm(data)):
    src_img_path = qa["image"]
    prompt = qa["conversations"][0]["value"]
    answer_gt = qa["conversations"][1]["value"]

    image = load_image(src_img_path)
    image = center_crop_and_resize(image, min(image.size), 384)
    message = [prompt, image]
    if version == "vila":
        answer_pred = inference_vila(message, args)
    elif version == "nvila":
        answer_pred = inference_nvila(message, args)

    # Prediction in red, ground truth in blue, on the same frame.
    image = np.array(image)
    image, path_pred = add_answer_to_img(image, answer_pred, "path", color="red")
    image, path_gt = add_answer_to_img(image, answer_gt, "path", color="blue")

    # Metrics use the paths as returned by add_answer_to_img (unscaled answer
    # coordinates). float() keeps them JSON-serialisable even for float32 paths.
    dtw_euclidian, _ = fastdtw(path_gt, path_pred, dist=euclidean)
    dtw_euclidian = float(dtw_euclidian)
    first_l2 = float(np.linalg.norm(path_gt[0] - path_pred[0]))
    last_l2 = float(np.linalg.norm(path_gt[-1] - path_pred[-1]))

    _draw_metrics_banner(image, f"DTW {dtw_euclidian:.2f} first {first_l2:.2f} last {last_l2:.2f}")

    # Prefix with the sample index so samples whose source images share a file
    # name do not overwrite each other's visualisation.
    vis_path = os.path.join(save_path, f"{sample_idx:04d}_{os.path.basename(src_img_path)}.png")
    Image.fromarray(image).save(vis_path)

    logs["img_path"].append(vis_path)
    logs["path_pred"].append(path_pred)
    logs["path_gt"].append(path_gt)
    logs["dtw_euclidian"].append(dtw_euclidian)
    logs["first_l2"].append(first_l2)
    logs["last_l2"].append(last_l2)

# Build the per-sample payload before opening the file, so a failure here does
# not leave a truncated results.json behind.
results = [
    {
        "img_path": logs["img_path"][i],
        "path_pred": logs["path_pred"][i].tolist(),
        "path_gt": logs["path_gt"][i].tolist(),
        "dtw_euclidian": logs["dtw_euclidian"][i],
        "first_l2": logs["first_l2"][i],
        "last_l2": logs["last_l2"][i],
    }
    for i in range(len(logs["img_path"]))
]
with open(os.path.join(save_path, "results.json"), "w") as f:
    json.dump(results, f)

# Mean of each metric; None when no samples were evaluated (avoids NaN in JSON).
summary = {
    key: float(np.mean(logs[key])) if logs[key] else None
    for key in ("dtw_euclidian", "first_l2", "last_l2")
}
with open(os.path.join(save_path, "summary.json"), "w") as f:
    json.dump(summary, f)


In the image, please execute the command described in <quest>put small spoon from basket to tray</quest>.
Provide a sequence of points denoting the trajectory of a robot gripper to achieve the goal.
Format your answer as a list of tuples enclosed by <ans> and </ans> tags. For example:
<ans>[(0.25, 0.32), (0.32, 0.17), (0.13, 0.24), (0.74, 0.21), ...]</ans>
The tuple denotes point x and y location of the end effector of the gripper in the image.
The coordinates should be integers ranging between 0.0 and 1.0, indicating the relative locations of the points in the image.
'. Removed.[0m


In the image, please execute the command described in <quest>put the blue cube on the right side of the table on top of the rectangular block</quest>.
Provide a sequence of points denoting the trajectory of a robot gripper to achieve the goal.
Format your answer as a list of tuples enclosed by <ans> and </ans> tags. For example:
<ans>[(0.25, 0.32), (0.32, 0.17), (0.13, 0.24), (0.74, 0.21), ...]</ans>
The tuple denotes point x and y location of the end effector of the gripper in the image.
The coordinates should be integers ranging between 0.0 and 1.0, indicating the relative locations of the points in the image.
'. Removed.[0m
In the image, please execute the command described in <quest>put the red object into the pot</quest>.
Provide a sequence of points denoting the trajectory of a robot gripper to achieve the goal.
Format your answer as a list of tuples enclosed by <ans> and </ans> tags. For example:
<ans>[(0.25, 0.32), (0.32, 0.17), (0.13, 0.24), (0.74, 0.21), ...]</ans>
The tuple