In [None]:
!pip install matplotlib

In [None]:

import base64
import os
import typing as t

import numpy as np
import vertexai
from jinja2 import Environment, FileSystemLoader
from litellm import completion, completion_cost
from litellm.utils import ModelResponse
from PIL import Image
from pydantic import BaseModel, Field
from vertexai.generative_models import GenerativeModel, Part

from ares.configs.base import pydantic_to_example_dict, pydantic_to_field_instructions
from ares.utils.image_utils import (
    choose_and_preprocess_frames,
    encode_image,
    split_video_to_frames,
)
from ares.extras.pi_demo_utils import PI_DEMO_PATH, PI_DEMO_TASKS

# Sub-model used as the type of RolloutDescription.nested. The attribute names look
# like placeholders, presumably there to exercise nested-model handling in the
# pydantic prompt helpers — TODO confirm.
class Nested(BaseModel):
    # Required string attribute (placeholder name).
    the_nested_attr: str
    # Second required string attribute (placeholder name).
    other_nested_attr: str

# Structured annotation of one robot rollout. This schema is passed to
# pydantic_to_field_instructions and pydantic_to_example_dict to build the prompt,
# so the Field descriptions below presumably surface in the text the VLM sees —
# keep them precise.
class RolloutDescription(BaseModel):
    nested: Nested
    robot_setup: t.Literal["one arm", "two arms"]
    environment: t.Literal["floor", "table", "other"]
    lighting_conditions: t.Literal["normal", "dim", "bright"]
    # task: str = Field(max_length=50, description="Short task description")
    description: str = Field(
        max_length=1000,
        description="A detailed description of the robot's actions over the course of the images. Don't include fluff like 'Let's describe...'. Just describe the episode.",
    )
    success_str: str = Field(
        max_length=1000,
        description="""
    A detailed description of whether or not the robot successfully completes the task. 
    Be very specific and critical about whether or not the robot has met the intended goal state of the task and include lots of details pertaining to partial success.
    In order to be successful, the robot must have completed the task in a way that is consistent with the task description. Any error or deviation from the task description is a failure.
    """.strip(),
    )
    # ge/le enforce the 0-1 range promised in the description; without them a
    # validated response could carry an out-of-range score such as 1.5 or -0.2.
    success_score: float = Field(
        ge=0.0,
        le=1.0,
        description="A float score between 0 and 1, representing the success of the task. A score of 0 means the task was not completed at all, and a score of 1 means the task was completed absolutely perfectly.",
    )

# Per-field guidance derived from the RolloutDescription schema; it is later
# newline-joined and substituted for `{field_instructions}` below.
field_instructions = pydantic_to_field_instructions(RolloutDescription)

# Prompt template for the jinja2 prompt. `{task}` and `{field_instructions}` are
# left as str.format placeholders and filled in when a query is built.
instructions = "\n".join(
    [
        "Look at the images provided and consider the following task description:",
        "TASK: {task}",
        "",
        "Create a response to the task by answering the following questions:",
        "{field_instructions}",
    ]
)

# Output-format guidance: a free-text narrative first, then an example dict
# rendered from the schema fields.
response_format = (
    "For the response, first respond with about 500 words that describe the entire video, "
    "focusing on the robot's actions and the task.\n"
    f"Then, respond with a python dict, e.g. {pydantic_to_example_dict(RolloutDescription)} "
    "that fulfills the above specifications."
)


In [None]:
import os
from ares.image_utils import split_video_to_frames, choose_and_preprocess_frames


def get_frames(
    task: str,
    success: str,
    n_frames: t.Optional[int] = None,
    resize: tuple[int, int] = (512, 512),
):
    """Load one PI demo episode video and return its preprocessed frames.

    Args:
        task: Key into ``PI_DEMO_TASKS``; its ``"filename_prefix"`` names the video file.
        success: Outcome suffix of the video file name, e.g. ``"success"`` or ``"fail"``
            (the file read is ``<prefix>_<success>.mp4`` under ``PI_DEMO_PATH``).
        n_frames: Number of frames to request from ``choose_and_preprocess_frames``.
            ``None`` (or 0) requests every frame of the video.
        resize: Target size forwarded to ``choose_and_preprocess_frames``.

    Returns:
        Whatever ``choose_and_preprocess_frames`` returns for the requested frames.
    """
    # NOTE(review): the cell defining this function re-imports split_video_to_frames
    # and choose_and_preprocess_frames from `ares.image_utils`, while the imports cell
    # uses `ares.utils.image_utils`; confirm which module path is current.
    video_path = os.path.join(
        PI_DEMO_PATH, f"{PI_DEMO_TASKS[task]['filename_prefix']}_{success}.mp4"
    )
    all_frames = split_video_to_frames(video_path)
    print(f"split video into {len(all_frames)} frames")
    # A falsy n_frames (None or 0) means "use every frame".
    return choose_and_preprocess_frames(
        all_frames,
        n_frames or len(all_frames),
        specified_frames=None,
        resize=resize,
    )


In [None]:

from ares.models.base import VLM

# Optional LiteLLM debugging (uncomment to enable):
# os.environ["LITELLM_LOG"] = "DEBUG"
# litellm.set_verbose=True

# Episode to analyse. Alternative tasks:
#   "Eggs in carton", "Grocery Bagging", "Toast out of toaster", "Towel fold",
#   "Stack bowls", "Tupperware in microwave", "Items in drawer",
#   "Laundry fold (shirts)", "Laundry fold (shorts)", "Food in to go box"
task = "Paper towel in holder"
# Which recording of the task to load: "fail" or "success".
success = "fail"

# Model used for annotation. Alternatives:
#   provider = "gemini"    -> name = f"{provider}/gemini-1.5-flash"
#   provider = "openai"    -> name = f"{provider}/gpt-4o-mini" or f"{provider}/gpt-4-turbo"
#   provider = "anthropic" -> name = f"{provider}/claude-3-5-sonnet-20240620"
provider = "openai"
name = f"{provider}/gpt-4o"

# vlm = GeminiVideoVLM("gemini", "gemini-1.5-flash", dict())
vlm = VLM(provider=provider, name=name)


In [None]:
# Every frame of the episode (n_frames left at its None default); only the
# commented-out frame-similarity exploration cells below consume this.
all_frames = get_frames(task, success)

In [None]:
# diffs = [np.mean(np.abs(np.array(all_frames[i]) - np.array(all_frames[i+1]))) for i in range(len(all_frames) -1)]

In [None]:
# from transformers import CLIPProcessor, CLIPModel
# model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
# processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

In [None]:
# import torch
# def get_embedding(img):
#     with torch.no_grad():
#         inputs = processor(text=["a picture"], images=img, return_tensors="pt", padding=True)
#         outputs = model(**inputs)
#     return outputs['image_embeds'].detach().numpy()

In [None]:
# embeds = [get_embedding(all_frames[i]) for i in range(len(all_frames)) if i% 5 == 0]

In [None]:
# def cosine_similarity(a, b):
#     """Calculate cosine similarity between two vectors."""
#     dot_product = np.dot(a.flatten(), b.flatten())
#     norm_a = np.linalg.norm(a)
#     norm_b = np.linalg.norm(b)
#     return dot_product / (norm_a * norm_b)

In [None]:
# angles = [cosine_similarity(embeds[i], embeds[i+1]) for i in range(len(embeds)-1)]  # Calculate cosine similarity with other embeddings
# # find the derivative of the changes
# # Calculate the derivative (rate of change) of angles
# # angle_changes = np.diff(angles)
# angle_changes = np.gradient(angles)


In [None]:
# plot the angles 
# import matplotlib.pyplot as plt

# fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(10, 4))

# ax1.scatter(range(len(angles)), angles, s=10)
# ax1.set_xlabel('Frame Index')
# ax1.set_ylabel('CLIP Cosine Similarity to Next Frame')
# ax1.set_title('Similarity per Frame')

# ax2.scatter(range(len(angle_changes)), angle_changes, s=10)
# ax2.set_xlabel('Frame Index')
# ax2.set_ylabel('Angles derivative')
# ax2.set_title('Angles derivative')


# ax3.hist(angles, bins=30, edgecolor='black')
# ax3.set_xlabel('CLIP Cosine Similarity')
# ax3.set_ylabel('Frequency')
# ax3.set_title('Distribution of Frame Similarities')

# plt.tight_layout()
# plt.show()

In [None]:
# Heatmap of all-to-all frame similarity.
# Create similarity matrix of all embeddings compared to all other embeddings
# n = len(embeds)
# similarity_matrix = np.zeros((n, n))
# for i in range(n):
#     for j in range(n):
#         similarity_matrix[i,j] = cosine_similarity(embeds[i], embeds[j])

# # Plot heatmap
# plt.figure(figsize=(8, 8))
# plt.imshow(similarity_matrix, cmap='viridis')
# plt.colorbar(label='Cosine Similarity')
# plt.xlabel('Frame Index')
# plt.ylabel('Frame Index') 
# plt.title('All-to-All Frame Similarity Matrix')
# plt.show()


In [None]:
# angles = np.array(angles)
# angles[angles <.96]

In [None]:
# examples = np.where(np.array(angle_changes) < -.01)[0]

In [None]:
# fig, ax = plt.subplots(len(examples), 2, figsize=(20,20))
# for i, ex in enumerate(examples):
#     ax[i][0].imshow(all_frames[5*(ex)])
#     ax[i][1].imshow(all_frames[5*(ex+1)])
#     # title the row with the similarity
#     ax[i][0].set_title(f'Frame {5*ex}')
#     ax[i][1].set_title(f'Frame {5*(ex+1)}, Similarity: {angles[ex]:.3f}')
# plt.show()


In [None]:
# fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))

# ax1.scatter(range(len(diffs)), diffs, s=10)
# ax1.set_xlabel('Frame Index')
# ax1.set_ylabel('Mean Absolute Difference')
# ax1.set_title('Diff per Frame')

# ax2.hist(diffs, bins=30, edgecolor='black')
# ax2.set_xlabel('Abs Frame Difference')
# ax2.set_ylabel('Frequency')
# ax2.set_title('Distribution of Frame Differences')

# plt.tight_layout()
# plt.show()

In [None]:
import numpy as np
import base64
from PIL import Image
import io
import matplotlib.pyplot as plt

def display_messages(messages: list[dict]) -> None:
    """Print the text of chat messages and plot any embedded images in a grid.

    Args:
        messages: Chat-style messages, each a dict with a ``"role"`` and a
            ``"content"``. Content may be ``None``, a plain string, or a list of
            typed parts. ``"text"`` parts are printed. ``"image_url"`` parts must
            carry a base64 ``data:`` URL (or bare base64) and are decoded with PIL
            and shown in a roughly square grid. Other part types are ignored.
    """
    images = []
    for message in messages:
        print(message["role"])
        contents = message["content"]
        # Content can be None or a bare string rather than a list of parts;
        # iterating a string would yield characters and fail on part["type"].
        if contents is None:
            continue
        if isinstance(contents, str):
            print(contents)
            continue
        for part in contents:
            part_type = part["type"]
            if part_type == "text":
                print(part["text"])
            elif part_type == "image_url":
                url = part["image_url"]["url"]
                # Drop everything up to the first comma so the data-URL header is
                # removed whatever its media type (jpeg, png, ...); a bare base64
                # string without a header is decoded as-is.
                _, sep, payload = url.partition(",")
                encoded = payload if sep else url
                images.append(Image.open(io.BytesIO(base64.b64decode(encoded))))

    if not images:
        return

    # Roughly square grid large enough to hold every image.
    cols = int(np.ceil(np.sqrt(len(images))))
    rows = int(np.ceil(len(images) / cols))

    fig = plt.figure(figsize=(8, 8))
    for idx, img in enumerate(images, start=1):
        ax = fig.add_subplot(rows, cols, idx)
        ax.imshow(img)
        ax.axis("off")

    plt.tight_layout()
    plt.show()


In [None]:
# Fill the prompt template with this episode's task text and the schema-derived
# field guidance, one instruction per line.
info_dict = dict(
    instructions=instructions.format(
        task=PI_DEMO_TASKS[task]["task"],
        field_instructions="\n".join(field_instructions),
    ),
    response_format=response_format,
)
# Ten preprocessed frames of the episode, all sent together in a single query.
frames = get_frames(task, success, n_frames=10)

messages, res = vlm.ask(
    "extractor_prompt.jinja2",
    info_dict,
    images=frames,
    double_prompt=True,
)

In [None]:
# Inspect the messages returned by vlm.ask (presumably the prompt that was sent):
# roles and text are printed, images are plotted as a grid.
display_messages(messages)

In [None]:
# Model answer first, then the request cost on its own line so it is not glued
# onto the end of the answer text.
print(res.choices[0].message.content)
print(f"cost: {completion_cost(res)}")

In [None]:
# Success criteria written for the "Paper towel in holder" task only; rewrite them
# if `task` is changed in the configuration cell.
constraints_list = """
1. paper towel roll is vertically aligned with and fully inserted onto the spindle.
2. roll sits securely within the holder frame.
3. no significant lateral gap between the roll's core and spindle.
4. spindle is not empty; paper towel roll is no longer on the table.
5. no visible collisions with the table or objects during the task.
6. roll remains intact and upright after placement.
"""

# Ask, for each frame on its own, whether the task is complete per the constraints.
info_dict = {
    "instructions": f"Tell me if the TASK: `{PI_DEMO_TASKS[task]['task']}` has been completed according to the constraints list: {constraints_list}. Tell me why or why not.",
    "response_format": "respond in a single string",
}
N_EVAL_FRAMES = 10
frames = get_frames(task, success, n_frames=N_EVAL_FRAMES)

# Iterate over the frames actually returned instead of a hard-coded count, so a
# shorter frame list cannot raise IndexError.
outputs = []
for frame in frames:
    messages, res = vlm.ask(
        "extractor_prompt.jinja2",
        info_dict,
        images=[frame],
        double_prompt=True,
    )
    outputs.append(res.choices[0].message.content)


In [None]:
# One VLM answer per frame, in frame order, collected by the loop above.
outputs

In [None]:
# Merge the per-frame verdicts into one episode-level judgement. Each answer is
# labelled with its frame index because answers may span several lines; a raw
# str(list) would instead put everything on one line as a Python repr.
frame_answers = "\n\n".join(
    f"Frame {i}: {answer}" for i, answer in enumerate(outputs)
)
info_dict = {
    "instructions": (
        "summarize the following text, paying extra attention towards the end. "
        "each entry, labelled 'Frame N:', is an answer about one frame of video, in order.\n\n"
        f"{frame_answers}"
    ),
    "response_format": "reply in a single string",
}

messages, res = vlm.ask(
    "extractor_prompt.jinja2",
    info_dict,
    images=[],
    # double_prompt=True,
)

In [None]:
# print renders the summary's newlines instead of showing an escaped string repr.
print(res.choices[0].message.content)