## Import the data

Create the data points: convert each Mind2Web task into a sequence of low-level browser steps.

In [6]:
import base64
import json
from io import BytesIO
from pathlib import Path
from typing import Any, Dict, Generator, List, Optional

from datasets import Dataset, Features
from datasets import Image as DSImage
from datasets import Sequence, Value, load_dataset
from PIL import Image
from tqdm import tqdm

In [None]:
# Load Multimodal-Mind2Web from the Hugging Face Hub; later cells read its "train" split.
ds = load_dataset("osunlp/Multimodal-Mind2Web")

In [3]:
# Declare classes
import dataclasses
from dataclasses import dataclass
from typing import List, Literal, Tuple
import json


@dataclass
class Coordinate:
    """A point in viewport pixel space, measured from the window's top-left corner.

    ``process_step`` stores the (fractional) centre of an element's bounding box
    here, so the fields hold floats as well as ints.
    """

    x: float  # horizontal offset from the viewport's left edge
    y: float  # vertical offset from the viewport's top edge


@dataclasses.dataclass
class ScrollBar:
    """Vertical scrollbar geometry, as fractions of the full page height."""

    offset: float  # top of the visible window divided by the page height
    height: float  # visible window height divided by the page height


@dataclass
class BrowserState:
    """What the agent observes right before it takes an action."""

    screenshot: str  # base64-encoded JPEG of the visible viewport
    height: float  # viewport height in pixels; process_step passes width * 10 / 16 (a float)
    width: int  # viewport width in pixels (the full screenshot width)
    scrollbar: ScrollBar  # position/size of the visible window relative to the page
    url: str  # page URL; process_step always leaves this empty
    mouse: Coordinate  # cursor position in viewport coordinates


@dataclass
class BrowserAction:
    action: Literal[
        "success",
        "failure",
        "key",
        "type",
        "mouse_move",
        "left_click",
        "left_click_drag",
        "right_click",
        "middle_click",
        "double_click",
        "screenshot",
        "cursor_position",
        "scroll_up",
        "scroll_down",
    ]
    # TODO: Do we want to use Coordinate class here, or easier to just construct with tuple
    coordinate: tuple[int, int] | None
    text: str | None
    reasoning: str
    id: str


@dataclasses.dataclass
class BrowserStep:
    """Pairs the state the agent observed with the action it then performed."""

    state: BrowserState
    action: BrowserAction

In [4]:
# Declare functions

import random


def generate_tool_id() -> str:
    """Return a tool-call id: ``"toolu_01"`` followed by 22 random alphanumerics.

    Draws from the ``random`` module, so ids are reproducible under
    ``random.seed`` and are not suitable as secrets.
    """
    alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
    suffix = "".join(random.choice(alphabet) for _ in range(22))
    return "toolu_01" + suffix


def is_in_viewport(viewport, point):
    """Return True when ``point`` lies inside ``viewport``; edges count as inside.

    Args:
        viewport: ``(left, top, right, bottom)`` rectangle.
        point: ``(x, y)`` in the same coordinate space.
    """
    left, top, right, bottom = viewport
    px, py = point
    within_x = left <= px <= right
    within_y = top <= py <= bottom
    return within_x and within_y


def scroll_viewport(direction, viewport, y_max, scroll_fraction=0.75):
    """Simulate a Page Up / Page Down by sliding ``viewport`` vertically.

    The window keeps its height and moves by ``scroll_fraction`` of that height,
    clamped so it stays within the page extent ``[0, y_max]``.

    Args:
        direction: ``"up"`` or ``"down"``.
        viewport: ``(x1, y1, x2, y2)`` rectangle in full-page pixel coordinates.
        y_max: Height of the full page (its bottom edge).
        scroll_fraction: Fraction of the viewport height moved per scroll
            (default 0.75).

    Returns:
        The new ``(x1, y1, x2, y2)`` tuple; x coordinates are unchanged.

    Raises:
        ValueError: If ``direction`` is neither ``"up"`` nor ``"down"``.
    """
    x1, y1, x2, y2 = viewport
    height = y2 - y1
    scroll_amount = scroll_fraction * height

    if direction == "up":
        # The page starts at y=0 (process_step's initial viewport is (0, 0, ...)),
        # so scrolling up must be able to return exactly to the top.
        new_y1 = max(0, y1 - scroll_amount)
        new_y2 = new_y1 + height
    elif direction == "down":
        new_y2 = min(y_max, y2 + scroll_amount)
        new_y1 = new_y2 - height
    else:
        raise ValueError("Direction must be 'up' or 'down'")

    # Re-clamp while preserving the height. Assign y_max directly (not
    # new_y1 + height) so callers comparing against y_max see an exact match.
    if new_y2 > y_max:
        new_y2 = y_max
        new_y1 = new_y2 - height
    # Checked last so that, when the page is shorter than the window, the window
    # stays pinned to the top instead of getting a negative y1.
    if new_y1 < 0:
        new_y1 = 0
        new_y2 = height

    return (x1, new_y1, x2, new_y2)


def viewport_screenshot(screenshot, viewport):
    """Crop ``screenshot`` to ``viewport`` and return it as a base64 JPEG string.

    Args:
        screenshot: PIL image of the full page.
        viewport: ``(x1, y1, x2, y2)`` crop box; coordinates are truncated to int.

    Returns:
        Base64 text (str) of the cropped region, JPEG-encoded at quality 85.
    """
    import base64
    from io import BytesIO

    x1, y1, x2, y2 = map(int, viewport)
    # crop() returns a new image, so the source screenshot is never modified.
    cropped_image = screenshot.crop((x1, y1, x2, y2))
    # JPEG cannot store alpha or palette data; saving RGBA/P/LA images raises OSError.
    if cropped_image.mode not in ("RGB", "L"):
        cropped_image = cropped_image.convert("RGB")

    buffered = BytesIO()
    cropped_image.save(buffered, format="JPEG", quality=85)
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


# Key pressed, and the reasoning recorded, for each simulated scroll direction.
_SCROLL_KEYS = {
    "up": ("PAGE_UP", "Press the Page Up key to scroll up"),
    "down": ("PAGE_DOWN", "Press the Page Down key to scroll down"),
}


def _capture_state(frame, viewport, width, viewport_height, y_max, mouse):
    """Build the BrowserState for an already-encoded viewport image ``frame``.

    The scrollbar is expressed as fractions of the full page height ``y_max``.
    """
    return BrowserState(
        url="",
        screenshot=frame,
        height=viewport_height,
        width=width,
        scrollbar=ScrollBar(
            offset=float(viewport[1]) / y_max,
            height=float(viewport_height) / y_max,
        ),
        mouse=mouse,
    )


def process_step(
    step, mouse_coordinates: Coordinate
) -> Tuple[List[BrowserStep], Coordinate]:
    """Expand one Mind2Web action into low-level browser steps.

    The full-page screenshot is viewed through a 16:10 window that starts at the
    top of the page. The window is scrolled with Page Up/Page Down until the
    centre of the first positive candidate's bounding box is visible. The mouse
    is then moved there and clicked. TYPE/SELECT operations add a final ``type``
    action carrying the operation's value.

    Args:
        step: One Multimodal-Mind2Web row. Reads ``screenshot`` (PIL image of the
            whole page), ``pos_candidates`` (JSON strings whose ``attributes``
            field is itself JSON containing ``bounding_box_rect`` as
            ``"x,y,width,height"``) and ``operation`` (JSON with ``op``/``value``).
        mouse_coordinates: Cursor position carried over from the previous step.

    Returns:
        ``(steps, mouse)``: the generated BrowserSteps and the cursor position
        after them. Returns ``([], mouse_coordinates)`` when there is no positive
        candidate or its centre falls outside the screenshot.
    """
    # Initialize the viewport to the top 16:10 part of the screenshot.
    screenshot = step["screenshot"]
    width, height = screenshot.size
    viewport_height = width * 10 / 16
    viewport = (0, 0, width, viewport_height)
    y_max = float(height)

    if len(step["pos_candidates"]) == 0:
        return ([], mouse_coordinates)

    candidate = json.loads(step["pos_candidates"][0])
    attributes = json.loads(candidate["attributes"])
    x, y, box_width, box_height = map(
        float, attributes["bounding_box_rect"].split(",")
    )
    center_x = x + box_width / 2
    center_y = y + box_height / 2

    if not (0 <= center_x <= width and 0 <= center_y <= height):
        print("Bounding box coordinates outside of provided screenshot, skipping step")
        return ([], mouse_coordinates)

    cerebellum_steps: List[BrowserStep] = []

    # Scroll the viewport until the centre of the bounding box is in view.
    while not is_in_viewport(viewport, (center_x, center_y)):
        if center_y < viewport[1]:
            direction = "up"
        elif center_y > viewport[3]:
            direction = "down"
        else:
            # Only reachable if the centre were outside the window horizontally,
            # which the bounds check above rules out; stop instead of spinning.
            break

        key, reasoning = _SCROLL_KEYS[direction]
        state = _capture_state(
            viewport_screenshot(screenshot, viewport),
            viewport,
            width,
            viewport_height,
            y_max,
            mouse_coordinates,
        )
        scroll_action = BrowserAction(
            action="key",
            coordinate=None,
            text=key,
            reasoning=reasoning,
            id=generate_tool_id(),
        )
        cerebellum_steps.append(BrowserStep(state=state, action=scroll_action))
        viewport = scroll_viewport(direction, viewport, y_max)

    # The window no longer moves: encode it once and reuse it for every later state.
    frame = viewport_screenshot(screenshot, viewport)

    # Remap the element centre into coordinates relative to the visible window.
    center_x_relative = center_x - viewport[0]
    center_y_relative = center_y - viewport[1]

    move_action = BrowserAction(
        action="mouse_move",
        coordinate=(center_x_relative, center_y_relative),
        text=None,
        reasoning="Move mouse to the center of the element",
        id=generate_tool_id(),
    )
    cerebellum_steps.append(
        BrowserStep(
            state=_capture_state(
                frame, viewport, width, viewport_height, y_max, mouse_coordinates
            ),
            action=move_action,
        )
    )

    # Pretend the mouse has now been moved onto the element.
    mouse_coordinates = Coordinate(x=center_x_relative, y=center_y_relative)

    click_action = BrowserAction(
        action="left_click",
        coordinate=None,
        text=None,
        reasoning="Perform a left click on element",
        id=generate_tool_id(),
    )
    cerebellum_steps.append(
        BrowserStep(
            state=_capture_state(
                frame, viewport, width, viewport_height, y_max, mouse_coordinates
            ),
            action=click_action,
        )
    )

    # TYPE and SELECT operations are replayed as typing the target value.
    operation = json.loads(step["operation"])
    if operation["op"] in ("TYPE", "SELECT"):
        type_action = BrowserAction(
            action="type",
            coordinate=None,
            text=operation["value"],
            reasoning="Typing text set to desired value",
            id=generate_tool_id(),
        )
        cerebellum_steps.append(
            BrowserStep(
                state=_capture_state(
                    frame, viewport, width, viewport_height, y_max, mouse_coordinates
                ),
                action=type_action,
            )
        )

    return (cerebellum_steps, mouse_coordinates)

In [None]:
import itertools

train = ds.get("train")

print(list(train[0].keys()))

# Make sure the per-task output directory exists before writing into it.
output_dir = Path("mind2web")
output_dir.mkdir(parents=True, exist_ok=True)

# Group rows by annotation_id, which identifies a task. Grouping by the
# task text would merge distinct annotations that share a description.
# NOTE: groupby only merges consecutive rows, so this assumes each annotation's
# rows are contiguous in the split.
for task_id, task_rows in itertools.groupby(train, key=lambda row: row["annotation_id"]):
    steps = list(task_rows)
    goal = steps[0]["confirmed_task"]

    print("Grabbing steps for:", goal, task_id)

    cerebellum_steps: List[BrowserStep] = []

    # The cursor position is threaded through consecutive steps of the same task.
    mouse = Coordinate(x=1, y=1)
    for raw_step in steps:
        decomposed_steps, mouse = process_step(raw_step, mouse)
        cerebellum_steps += decomposed_steps

    # One JSONL file per task: a goal header line, then one line per BrowserStep.
    with open(output_dir / f"{task_id}.jsonl", "w") as outfile:
        outfile.write(json.dumps({"goal": goal}) + "\n")
        for this_step in cerebellum_steps:
            outfile.write(json.dumps(dataclasses.asdict(this_step)) + "\n")

In [6]:
import base64

CURSOR_64 = "iVBORw0KGgoAAAANSUhEUgAAAAoAAAAQCAYAAAAvf+5AAAAAw3pUWHRSYXcgcHJvZmlsZSB0eXBlIGV4aWYAAHjabVBRDsMgCP33FDuC8ijF49i1S3aDHX9YcLFLX+ITeOSJpOPzfqVHBxVOvKwqVSQbuHKlZoFmRzu5ZD55rvX8Uk9Dz2Ql2A1PVaJ/1MvPwK9m0TIZ6TOE7SpUDn/9M4qH0CciC/YwqmEEcqGEQYsvSNV1/sJ25CvUTxqBjzGJU86rbW9f7B0QHSjIxoD6AOiHE1oXjAlqjQVyxmTMkJjEFnK3p4H0BSRiWUv/cuYLAAABhWlDQ1BJQ0MgcHJvZmlsZQAAeJx9kT1Iw0AYht+2SqVUHCwo0iFD1cWCqIijVqEIFUKt0KqDyaV/0KQhSXFxFFwLDv4sVh1cnHV1cBUEwR8QZwcnRRcp8buk0CLGg7t7eO97X+6+A/yNClPNrnFA1SwjnUwI2dyqEHxFCFEM0DoqMVOfE8UUPMfXPXx8v4vzLO+6P0evkjcZ4BOIZ5luWMQbxNObls55nzjCSpJCfE48ZtAFiR+5Lrv8xrnosJ9nRoxMep44QiwUO1juYFYyVOIp4piiapTvz7qscN7irFZqrHVP/sJwXltZ5jrNKJJYxBJECJBRQxkVWIjTrpFiIk3nCQ//kOMXySWTqwxGjgVUoUJy/OB/8Lu3ZmFywk0KJ4DuF9v+GAaCu0Czbtvfx7bdPAECz8CV1vZXG8DMJ+n1thY7Avq2gYvrtibvAZc7wOCTLhmSIwVo+gsF4P2MvikH9N8CoTW3b61znD4AGepV6gY4OARGipS97vHuns6+/VvT6t8Ph1lyr0hzlCAAAA14aVRYdFhNTDpjb20uYWRvYmUueG1wAAAAAAA8P3hwYWNrZXQgYmVnaW49Iu+7vyIgaWQ9Ilc1TTBNcENlaGlIenJlU3pOVGN6a2M5ZCI/Pgo8eDp4bXBtZXRhIHhtbG5zOng9ImFkb2JlOm5zOm1ldGEvIiB4OnhtcHRrPSJYTVAgQ29yZSA0LjQuMC1FeGl2MiI+CiA8cmRmOlJERiB4bWxuczpyZGY9Imh0dHA6Ly93d3cudzMub3JnLzE5OTkvMDIvMjItcmRmLXN5bnRheC1ucyMiPgogIDxyZGY6RGVzY3JpcHRpb24gcmRmOmFib3V0PSIiCiAgICB4bWxuczp4bXBNTT0iaHR0cDovL25zLmFkb2JlLmNvbS94YXAvMS4wL21tLyIKICAgIHhtbG5zOnN0RXZ0PSJodHRwOi8vbnMuYWRvYmUuY29tL3hhcC8xLjAvc1R5cGUvUmVzb3VyY2VFdmVudCMiCiAgICB4bWxuczpkYz0iaHR0cDovL3B1cmwub3JnL2RjL2VsZW1lbnRzLzEuMS8iCiAgICB4bWxuczpHSU1QPSJodHRwOi8vd3d3LmdpbXAub3JnL3htcC8iCiAgICB4bWxuczp0aWZmPSJodHRwOi8vbnMuYWRvYmUuY29tL3RpZmYvMS4wLyIKICAgIHhtbG5zOnhtcD0iaHR0cDovL25zLmFkb2JlLmNvbS94YXAvMS4wLyIKICAgeG1wTU06RG9jdW1lbnRJRD0iZ2ltcDpkb2NpZDpnaW1wOjFiYzFkZjE3LWM5YmMtNGYzZi1hMmEzLTlmODkyNWNiZjY4OSIKICAgeG1wTU06SW5zdGFuY2VJRD0ieG1wLmlpZDo4YTUyMWJhMC00YmNlLTQzZWEtYjgyYS04ZGM2MTBjYmZlOTgiCiAgIHhtcE1NOk9yaWdpbmFsRG9jdW1lbnRJRD0ieG1wLmRpZDplODQ3ZjUxNC00MWVlLTQ2ZjYtOTllNC1kNjI3MjMxMjhlZTIiCiAgIGRjOkZvcm1hdD0iaW1hZ2UvcG5nIgogICBHSU1QOkFQST0iMi4wIgogICBHSU1QOlBsYXRmb3JtPSJMaW51eCIKICAgR0lNUDpUaW1lU3R
hbXA9IjE3MzAxNTc3NjY5MTI3ODciCiAgIEdJTVA6VmVyc2lvbj0iMi4xMC4zOCIKICAgdGlmZjpPcmllbnRhdGlvbj0iMSIKICAgeG1wOkNyZWF0b3JUb29sPSJHSU1QIDIuMTAiCiAgIHhtcDpNZXRhZGF0YURhdGU9IjIwMjQ6MTA6MjhUMTY6MjI6NDYtMDc6MDAiCiAgIHhtcDpNb2RpZnlEYXRlPSIyMDI0OjEwOjI4VDE2OjIyOjQ2LTA3OjAwIj4KICAgPHhtcE1NOkhpc3Rvcnk+CiAgICA8cmRmOlNlcT4KICAgICA8cmRmOmxpCiAgICAgIHN0RXZ0OmFjdGlvbj0ic2F2ZWQiCiAgICAgIHN0RXZ0OmNoYW5nZWQ9Ii8iCiAgICAgIHN0RXZ0Omluc3RhbmNlSUQ9InhtcC5paWQ6ZTVjOTM2ZDYtYjMzYi00NzM4LTlhNWUtYjM3YTA5MzdjZDAxIgogICAgICBzdEV2dDpzb2Z0d2FyZUFnZW50PSJHaW1wIDIuMTAgKExpbnV4KSIKICAgICAgc3RFdnQ6d2hlbj0iMjAyNC0xMC0yOFQxNjoyMjo0Ni0wNzowMCIvPgogICAgPC9yZGY6U2VxPgogICA8L3htcE1NOkhpc3Rvcnk+CiAgPC9yZGY6RGVzY3JpcHRpb24+CiA8L3JkZjpSREY+CjwveDp4bXBtZXRhPgogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgCiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIAogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgCiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIAogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgCiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIAogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICA
gICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgCiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIAogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgCiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIAogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgCiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIAogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgCiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAKICAgICAgICAgICAgICAgICAgICAgICAgICAgCjw/eHBhY2tldCBlbmQ9InciPz5/5aQ8AAAABmJLR0QAcgByAAAtJLTuAAAACXBIWXMAAABZAAAAWQGqnamGAAAAB3RJTUUH6AocFxYuv5vOJAAAAHhJREFUKM+NzzEOQXEMB+DPYDY5iEVMIpzDfRxC3mZyBK7gChZnELGohaR58f7a7dd8bVq4YaVQgTvWFVjCUcXxA28qcBBHFUcVRwWPPuFfXVsbt0PPnLBL+dKHL+wxxhSPhBcZznuDXYKH1uGzBJ+YtPAZRyy/jTd7qEoydWUQ7QAAAABJRU5ErkJggg=="
# Decode the embedded cursor icon (a base64-encoded PNG) into raw bytes; it is
# later opened as `cursor_img` and pasted onto screenshots in mark_screenshot.
CURSOR_BYTES = base64.b64decode(CURSOR_64)

In [None]:
# Post process
# System prompt prepended to every training conversation. The capability section
# must be opened as well as closed so the tags are balanced.
system_prompt = """<SYSTEM_CAPABILITY>
You are an intelligent web browsing agent operating in fullscreen mode to accomplish a specified user goal, detailed in <USER_TASK>.
* Use only the Page Down or Page Up keys for scrolling.
* If the webpage is scrollable, a gray rectangular scrollbar will appear on the right edge of the screenshot.
* Adhere strictly to the instructions in the <IMPORTANT> section below.
</SYSTEM_CAPABILITY>

Your task is to execute user requests using their browser. After each action, capture a screenshot and thoroughly assess whether the desired outcome has been achieved. Clearly articulate your reasoning for each function call: "I have evaluated step X..." If the result is incorrect, attempt the step again. Proceed to the next step only after confirming successful execution. Always utilize a tool for actions and ensure to return a tool call. Remember to invoke the stop_browsing tool upon achieving the task's goal. Prioritize keyboard shortcuts for navigation whenever feasible.

<IMPORTANT>
* Utilize the user's <USER_DATA> to complete forms as you progress towards the goal.
* Ensure a UI element is fully visible before interacting with it.
</IMPORTANT>"""


import os
import json
from PIL import Image
import io
import math

# Open the cursor icon once at module execution; mark_screenshot pastes it onto frames.
_cursor_stream = io.BytesIO(CURSOR_BYTES)
cursor_img = Image.open(_cursor_stream)


def mark_screenshot(
    img_buffer: bytes,
    mouse_position: Coordinate,
    scrollbar: ScrollBar,
    max_size: Tuple[int, int] = (640, 400),
) -> str:
    """Add scrollbar and cursor overlays to a screenshot and downscale it.

    Args:
        img_buffer: Raw (encoded) bytes of the screenshot image.
        mouse_position: Cursor position in the screenshot's pixel space; the
            coordinates are truncated to int and negative values clamped to 0.
        scrollbar: Scrollbar offset/height as fractions of the image height.
        max_size: ``(max_width, max_height)`` bound. Larger images are shrunk
            to fit, preserving aspect ratio. Defaults to 640x400.

    Returns:
        Base64 string of the overlaid screenshot, JPEG-encoded at quality 85.

    Raises:
        OSError: If Pillow cannot decode ``img_buffer`` or encode the result.
    """
    max_width, max_height = max_size

    with Image.open(io.BytesIO(img_buffer)) as img:
        width, height = img.size

        # Scrollbar overlay: a 10px-wide bar on the right edge.
        scrollbar_width = 10
        scrollbar_height = int(height * scrollbar.height)
        scrollbar_top = int(height * scrollbar.offset)

        composite = img.copy()
        if scrollbar_height > 0:
            # Gray rectangle with 80% opacity (alpha = 0.8 * 255).
            scrollbar_img = Image.new(
                "RGBA",
                (scrollbar_width, scrollbar_height),
                (128, 128, 128, int(255 * 0.8)),
            )
            composite.paste(
                scrollbar_img, (width - scrollbar_width, scrollbar_top), scrollbar_img
            )

        # paste() requires integer coordinates; Coordinate may hold floats.
        cursor_x = max(0, int(mouse_position.x))
        cursor_y = max(0, int(mouse_position.y))
        composite.paste(cursor_img, (cursor_x, cursor_y), cursor_img)

        # Shrink to fit within max_size while preserving the aspect ratio.
        if composite.width > max_width or composite.height > max_height:
            aspect_ratio = composite.width / composite.height
            if aspect_ratio > max_width / max_height:
                new_width = max_width
                new_height = max(1, int(max_width / aspect_ratio))
            else:
                new_height = max_height
                new_width = max(1, int(max_height * aspect_ratio))
            composite = composite.resize((new_width, new_height), Image.LANCZOS)

        # JPEG cannot store alpha or palette data.
        if composite.mode not in ("RGB", "L"):
            composite = composite.convert("RGB")

        output_buffer = io.BytesIO()
        composite.save(output_buffer, "JPEG", quality=85)
        return base64.b64encode(output_buffer.getvalue()).decode("utf-8")


def _step_messages(step, last_function_name, include_screenshot):
    """Return the [tool result, user observation, assistant action] messages for one step.

    ``last_function_name`` names the previous tool call, whose result the tool
    message reports. The marked screenshot is attached only when
    ``include_screenshot`` is true.
    """
    state = step["state"]
    action = step["action"]
    view_width = float(state["width"])
    view_height = float(state["height"])

    scrollbar = ScrollBar(
        offset=state["scrollbar"]["offset"],
        height=state["scrollbar"]["height"],
    )
    mouse = Coordinate(
        x=math.floor(state["mouse"]["x"]),
        y=math.floor(state["mouse"]["y"]),
    )

    tool_message = {
        "role": "tool",
        "name": last_function_name,
        "content": {"result": "Action completed successfully"},
    }

    # Mouse position is reported normalised to the viewport size.
    observation = [
        {
            "type": "text",
            "value": f"After action mouse cursor is at X: {mouse.x / view_width}, Y: {mouse.y / view_height}",
        }
    ]
    if include_screenshot:
        marked_image = mark_screenshot(
            base64.b64decode(state["screenshot"]), mouse, scrollbar
        )
        observation.append(
            {
                "type": "image_url",
                "image_url": {"url": f"data:image/jpeg;base64,{marked_image}"},
            }
        )

    action_arg = {"reason": action["reasoning"]}
    if action["coordinate"]:
        x, y = action["coordinate"]
        action_arg["coordinate"] = (float(x) / view_width, float(y) / view_height)
    if action["text"]:
        action_arg["text"] = action["text"]

    assistant_message = {
        "role": "assistant",
        "content": "",
        "tool_calls": [{"name": action["action"], "arguments": action_arg}],
    }
    return [tool_message, {"role": "user", "content": observation}, assistant_message]


# Input directory of per-task JSONL files and output directory for training examples.
directory = "mind2web"
output_directory = "molmo"
os.makedirs(output_directory, exist_ok=True)

for filename in os.listdir(directory):
    if not filename.endswith(".jsonl"):
        continue

    file_path = os.path.join(directory, filename)
    print("Processing", filename)

    # Skip blank lines (e.g. a trailing newline) instead of failing to parse them.
    with open(file_path, "r") as file:
        lines = [json.loads(line) for line in file if line.strip()]

    # The first line is the {"goal": ...} header; the rest are BrowserSteps.
    goal = lines.pop(0)["goal"]

    starting_directions = [
        {"role": "system", "content": system_prompt},
        {
            "role": "user",
            "content": f"<USER_TASK>{goal}</USER_TASK>\n<USER_DATA>NONE</USER_DATA>",
        },
        # Play with this for best behavior
        {
            "role": "assistant",
            "content": "",
            "tool_calls": [
                {
                    "name": "screenshot",
                    "arguments": {
                        "reason": "Take a screenshot of the browser to understand the current webpage"
                    },
                }
            ],
        },
    ]

    # Example i covers steps 0..i. Only the final step carries a screenshot, so
    # earlier steps are kept in an image-free history shared by later examples.
    examples = []
    history = []
    last_function_name = "screenshot"
    for step in lines:
        examples.append(
            starting_directions
            + history
            + _step_messages(step, last_function_name, include_screenshot=True)
        )
        history.extend(
            _step_messages(step, last_function_name, include_screenshot=False)
        )
        last_function_name = step["action"]["action"]

    with open(os.path.join(output_directory, filename), "w") as jsonl_file:
        for example in examples:
            jsonl_file.write(json.dumps(example) + "\n")

In [12]:
"""Data processing utilities for Molmo training data."""


def convert_base64_to_image(base64_string: str) -> Image.Image:
    """Decode a base64 image into a PIL Image.

    Accepts either bare base64 or a data URL such as
    ``data:image/jpeg;base64,<payload>``; in the latter case the text after the
    first comma is decoded.
    """
    pieces = base64_string.split(",")
    payload = pieces[1] if len(pieces) > 1 else base64_string
    return Image.open(BytesIO(base64.b64decode(payload)))


def yield_jsonl_records(
    mind2web_dir: Path, max_files: Optional[int] = None, max_lines: Optional[int] = None
) -> Generator[Any, None, None]:
    """Yield the parsed JSON value of every line in the ``*.jsonl`` files of a directory.

    Files are visited in sorted order, so ``max_files`` selects a deterministic
    subset. Blank lines are skipped. Lines that fail to parse are reported and
    skipped, and do not count towards ``max_lines``.

    Args:
        mind2web_dir: Directory containing JSONL files.
        max_files: Maximum number of files to process (all if None).
        max_lines: Maximum number of records to yield per file (all if None).

    Yields:
        The decoded JSON value of each line. For the training files this is a
        list of chat-message dicts, as consumed by ``process_record``.
    """
    # glob() order is filesystem-dependent; sort so max_files is reproducible.
    jsonl_files = sorted(mind2web_dir.glob("*.jsonl"))

    if max_files is not None:
        jsonl_files = jsonl_files[:max_files]

    print(f"Processing {len(jsonl_files)} files...")

    for file_path in tqdm(jsonl_files, desc="Processing files"):
        line_count = 0

        with open(file_path, "r", encoding="utf-8") as file:
            for line in file:
                if max_lines is not None and line_count >= max_lines:
                    break
                if not line.strip():
                    continue

                try:
                    record = json.loads(line)
                except json.JSONDecodeError as e:
                    print(f"Error in {file_path}: {e}")
                    continue

                yield record
                line_count += 1


def process_record(record: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Normalise one chat transcript into uniform content entries plus decoded images.

    Every message's content becomes a list of entries sharing the keys
    ``index``, ``text``, ``image_index`` and ``type``, so Arrow can infer one
    struct type for the column.
    - ``image_url`` entries are decoded into PIL images, appended to ``images``,
      and replaced by ``{"type": "image", "image_index": i}``.
    - ``value`` keys are renamed to ``text``.
    - Non-list content becomes a single text entry; non-string content (e.g. a
      tool-result dict) is serialised to JSON so ``text`` is always a string
      or None.

    Args:
        record: List of messages from the JSONL record.

    Returns:
        ``{"messages": [...], "images": [...]}``.
    """
    empty_content = {"index": None, "text": None, "image_index": None}
    messages = []
    images = []

    for msg in record:
        raw_content = msg["content"]
        if isinstance(raw_content, list):
            initial_content = raw_content
        else:
            if raw_content is None or isinstance(raw_content, str):
                text = raw_content
            else:
                text = json.dumps(raw_content)
            initial_content = [{"type": "text", "text": text}]

        empty_filled_content = [
            {**empty_content, **content_entry} for content_entry in initial_content
        ]

        for entry in empty_filled_content:
            if entry["type"] == "image_url":
                image = convert_base64_to_image(entry["image_url"]["url"])
                images.append(image)

                entry["type"] = "image"
                entry["image_index"] = len(images) - 1
                entry["text"] = None
                entry.pop("image_url", None)
            elif "value" in entry:
                entry["text"] = entry.pop("value")

        messages.append({**msg, "content": empty_filled_content})

    return {"messages": messages, "images": images}


def create_dataset_features() -> Features:
    """Return the ``datasets.Features`` schema describing a processed Molmo record."""
    content_entry = {
        "index": Value("null"),
        "text": Value("string"),
        "type": Value("string"),
        "image_index": Value("int64"),
    }
    tool_call = {"name": Value("string"), "arguments": Value("string")}
    message = {
        "role": Value("string"),
        "content": Sequence(content_entry),
        "tool_calls": Sequence(tool_call, length=-1),
    }
    return Features({"messages": Sequence(message), "images": Sequence(DSImage())})


def process_and_save_dataset(
    mind2web_dir: Path,
    save_dir: Path,
    max_files: Optional[int] = None,
    max_lines: Optional[int] = None,
    test_mode: bool = False,
) -> Dataset:
    """Build a Dataset from JSONL transcripts, save it, then reload it as a sanity check.

    Args:
        mind2web_dir: Input directory containing JSONL files.
        save_dir: Output directory for the processed dataset.
        max_files: Maximum number of files to process.
        max_lines: Maximum number of lines to process per file.
        test_mode: If true, save under ``save_dir / "test"`` instead.

    Returns:
        The dataset as reloaded from disk.
    """
    print("Starting data processing...")

    dataset = Dataset.from_generator(
        lambda: (
            process_record(record)
            for record in yield_jsonl_records(mind2web_dir, max_files, max_lines)
        )
    ).cast_column("images", Sequence(DSImage()))

    print(f"\nProcessed {len(dataset)} examples")

    final_save_dir = save_dir / "test" if test_mode else save_dir
    final_save_dir.mkdir(parents=True, exist_ok=True)

    dataset.save_to_disk(str(final_save_dir))
    print(f"\nDataset saved to {final_save_dir}")

    # Reload from disk to confirm the saved copy is readable.
    print("\nVerifying saved dataset...")
    loaded_dataset = Dataset.load_from_disk(str(final_save_dir))
    print("Dataset size:", len(loaded_dataset))

    if len(loaded_dataset) == 0:
        # Nothing to preview; indexing [0] would raise IndexError.
        print("\nSaved dataset is empty; skipping first-example preview.")
        return loaded_dataset

    print("\nFirst example:")
    first_example = loaded_dataset[0]
    print("Number of messages:", len(first_example["messages"]))
    print("Number of images:", len(first_example["images"]))

    return loaded_dataset

In [10]:
# Placeholders: point these at real directories before running the cells below.
# MIND2WEB_DIR is scanned for *.jsonl files by yield_jsonl_records;
# SAVE_DIR is where process_and_save_dataset writes the processed dataset.
MIND2WEB_DIR = "<update/to/dir>"
SAVE_DIR = "<update/to/dir>"

In [None]:
# Process the Molmo dataset: read JSONL records from MIND2WEB_DIR, convert them,
# save under SAVE_DIR, and keep the reloaded copy for the inspection cells below.
processed_dataset = process_and_save_dataset(
    mind2web_dir=Path(MIND2WEB_DIR),
    save_dir=Path(SAVE_DIR),
)

In [None]:
# Find the top_N records whose messages serialise to the longest JSON strings.
# Assumes the simplest Llava -> Molmo conversion, i.e. dumping the messages as a string.
import heapq

top_N = 5
# (length, index) pairs, longest first; ties keep dataset order.
top_N_list = heapq.nlargest(
    top_N,
    (
        (len(json.dumps(item["messages"])), idx)
        for idx, item in enumerate(processed_dataset)
    ),
    key=lambda tup: tup[0],
)

In [None]:
for rank, (length, record_idx) in enumerate(top_N_list):
    print(f"{rank}. Item {record_idx} has length {length}")

In [None]:
# Test some basic tokenization
from transformers import AutoProcessor

# Load the Molmo processor used to tokenize the serialized conversations.
processor = AutoProcessor.from_pretrained(
    "allenai/Molmo-7B-O-0924",
    trust_remote_code=True,
    torch_dtype="auto",
    device_map="auto",
)

# Report the token count of each of the longest records found above.
for rank, (_, record_idx) in enumerate(top_N_list):
    serialized = json.dumps(processed_dataset[record_idx]["messages"])
    tokens = processor.process(text=serialized)
    print(f"{rank}. Item {record_idx} has tokenization length {len(tokens['input_ids'])}")