<a href="https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/VisionTransformer/Quick_demo_of_HuggingFace_version_of_Vision_Transformer_inference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Adapted from Quick demo: Vision Transformer (ViT) by Google Brain

In [1]:
import logging
import sys
logger = logging.getLogger(__name__)
logging.basicConfig(format='%(asctime)s %(levelname)s:%(message)s', level=logging.DEBUG, datefmt='%I:%M:%S')
logger.info("Logger set up")

02:54:52 INFO:Logger set up


# Part 1: Preprocess Data
Mark all pixels that belongs to the bounding boxes of positive candidates as targets

### Preprocess data

In [2]:
from transformers import set_seed
from datasets import load_dataset

set_seed(123)
# logger.info(f"Use model {cfg.model.pretrained_model_name_or_path}")
# output_dir = HydraConfig.get().runtime.output_dir
train_dataset = load_dataset("osunlp/Multimodal-Mind2Web", split="train").select(range(1000))
# train_dataset = train_dataset.map(
#     flatten_actions,
#     batched=True,
#     remove_columns=train_dataset.column_names, # remove all original columns?
#     batch_size=10,
#     num_proc=4,
# )
print(train_dataset)
train_dataset = train_dataset.remove_columns(["neg_candidates", "raw_html", "cleaned_html"])

# Add column for previous_actions
previous_actions = []
curr_actions = None
num_actions = 0
step = 0
for i in range(len(train_dataset)):    
    if step == num_actions:
        step = 0
        curr_actions = train_dataset[i]["action_reprs"]
        num_actions = len(curr_actions)
    previous_actions.append(curr_actions[:step]) 
    step += 1

train_dataset = train_dataset.add_column("previous_actions", previous_actions)

# filter out those without pos_candidates
train_dataset = train_dataset.filter(lambda example: len(example["pos_candidates"]) == 1, num_proc=20) #TODO
train_dataset = train_dataset.remove_columns('action_reprs')
print(train_dataset)

02:54:54 INFO:PyTorch version 2.0.1 available.
02:54:54 DEBUG:Starting new HTTPS connection (1): huggingface.co:443
02:54:54 DEBUG:https://huggingface.co:443 "GET /api/datasets/osunlp/Multimodal-Mind2Web HTTP/1.1" 200 5343
02:54:54 DEBUG:Starting new HTTPS connection (1): s3.amazonaws.com:443
02:54:54 DEBUG:https://s3.amazonaws.com:443 "HEAD /datasets.huggingface.co/datasets/datasets/osunlp/Multimodal-Mind2Web/osunlp/Multimodal-Mind2Web.py HTTP/1.1" 404 0
02:54:54 DEBUG:https://huggingface.co:443 "GET /api/datasets/osunlp/Multimodal-Mind2Web HTTP/1.1" 200 5343
02:54:54 DEBUG:Starting new HTTPS connection (1): huggingface.co:443
02:54:55 DEBUG:https://huggingface.co:443 "HEAD /datasets/osunlp/Multimodal-Mind2Web/resolve/f27b6362acc6efe0e97289620307ca42cb177e5b/README.md HTTP/1.1" 200 0
02:54:55 DEBUG:https://huggingface.co:443 "GET /api/datasets/osunlp/Multimodal-Mind2Web/revision/f27b6362acc6efe0e97289620307ca42cb177e5b HTTP/1.1" 200 5343
02:54:55 DEBUG:https://huggingface.co:443 "POST

Resolving data files:   0%|          | 0/27 [00:00<?, ?it/s]

02:54:56 DEBUG:https://huggingface.co:443 "GET /api/datasets/osunlp/Multimodal-Mind2Web/revision/f27b6362acc6efe0e97289620307ca42cb177e5b HTTP/1.1" 200 5343
02:54:56 DEBUG:https://huggingface.co:443 "GET /api/datasets/osunlp/Multimodal-Mind2Web/revision/f27b6362acc6efe0e97289620307ca42cb177e5b HTTP/1.1" 200 5343
02:54:56 DEBUG:https://huggingface.co:443 "GET /api/datasets/osunlp/Multimodal-Mind2Web/revision/f27b6362acc6efe0e97289620307ca42cb177e5b HTTP/1.1" 200 5343
02:54:56 DEBUG:https://huggingface.co:443 "GET /api/datasets/osunlp/Multimodal-Mind2Web/revision/f27b6362acc6efe0e97289620307ca42cb177e5b HTTP/1.1" 200 5343
02:54:56 DEBUG:https://huggingface.co:443 "GET /api/datasets/osunlp/Multimodal-Mind2Web/revision/f27b6362acc6efe0e97289620307ca42cb177e5b HTTP/1.1" 200 5343
02:54:56 DEBUG:https://huggingface.co:443 "GET /api/datasets/osunlp/Multimodal-Mind2Web/revision/f27b6362acc6efe0e97289620307ca42cb177e5b HTTP/1.1" 200 5343
02:54:56 DEBUG:https://huggingface.co:443 "GET /api/datase

Dataset({
    features: ['action_uid', 'raw_html', 'cleaned_html', 'operation', 'pos_candidates', 'neg_candidates', 'website', 'domain', 'subdomain', 'annotation_id', 'confirmed_task', 'screenshot', 'action_reprs'],
    num_rows: 1000
})
Dataset({
    features: ['action_uid', 'operation', 'pos_candidates', 'website', 'domain', 'subdomain', 'annotation_id', 'confirmed_task', 'screenshot', 'previous_actions'],
    num_rows: 892
})


### Generate prompt and label
The full prompt is:

[patch embeddings] \n Based on the webpage screenshot, try to complete the following task:\n Task: [task] \n Previous actions:\n [actions] \n Which image patch contains the element to interact with next?"

In [3]:
import json
def get_prompt_target(example):
    """
    Use the bounding boxes of pos_candidates (as list of lists, [left, bottom, width, height]
    """
    boxes = []
    for cand in example["pos_candidates"]:
        json_data = json.loads(cand)
        attributes = json.loads(json_data['attributes'])
        bounding_box_rect_str = attributes['bounding_box_rect']
        boxes.append(list(map(float, bounding_box_rect_str.split(','))))

    # NOTE: Don't prune, just include the whole webpage
    seq_input = (
        "Based on the HTML webpage, try to complete the following task:\n"
        f"Task: {example['confirmed_task']}\n"
        f"Previous actions:\n"
    )
    # TODO: hard-coded
    previous_k = 5
    if len(example["previous_actions"]) > 0:
        for action in example["previous_actions"][-previous_k:]:
            seq_input += f"{action}\n"
    else:
        seq_input += "None\n"
        
    seq_input += (
        "What should be the element to interact with next?"
    )

    example["question"] = seq_input
    example["boxes"] = boxes

    return example


In [4]:
cols_to_remove = set(train_dataset.column_names)
cols_to_remove.remove("screenshot")
train_dataset = train_dataset.map(
    get_prompt_target,
    batched=False,
    remove_columns=list(cols_to_remove)
)
train_dataset[2]

Map:   0%|          | 0/892 [00:00<?, ? examples/s]

02:55:01 DEBUG:open file: /scr/wychow/.cache/huggingface/datasets/osunlp___multimodal-mind2_web/default/0.0.0/f27b6362acc6efe0e97289620307ca42cb177e5b/tmpx5aoovv6


{'screenshot': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1280x5429>,
 'question': 'Based on the HTML webpage, try to complete the following task:\nTask: rent a car in Brooklyn - Central, NY on from April 9 to April 15.\nPrevious actions:\n[heading]  CAR -> CLICK\n[combobox]  Enter pick up city, airport name, or airport code. -> TYPE: Brooklyn Central\nWhat should be the element to interact with next?',
 'boxes': [[114.59375, 365.1875, 306.8125, 25.6875]]}

### Prepare Model

In [5]:
from transformers import PreTrainedModel
import torch.nn as nn

class MultimodalAgent(PreTrainedModel):
    def __init__(self, config, image_encoder, lm):
        super().__init__(config)
        self.config = config
        self.supports_gradient_checkpointing = True
        self.image_encoder = image_encoder
        self.projector = nn.Linear(image_encoder.config.hidden_size, lm.config.hidden_size) 
        self.lm = lm

    def forward(self, pixel_values, input_ids, attention_mask, labels=None):
        # embed pixel_values with image_encoder
        # h_image = self.image_encoder(flattened_patches, attention_mask_image).last_hidden_state
        h_image = self.image_encoder(pixel_values, interpolate_pos_encoding=True).last_hidden_state
        # linear layer to project hidden states to lm's input dimension
        h_image = self.projector(h_image)
        # look up token embedding for text
        h_text = self.lm.model.embed_tokens(input_ids)
        # concatenate image represenation with question
        inputs_embeds = torch.cat([h_image, h_text], dim=1)
        # also concat attention mask
        # attention_mask = torch.cat([torch.ones(h_image.shape), attention_mask], dim=-1)
        # TODO: need to add some sort of separator, like \n?
        return self.lm(inputs_embeds=inputs_embeds, output_hidden_states=True).hidden_states[-1] # Not passing attention mask, no need for now since batch size is 1
        

In [6]:
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel, AutoConfig
import torch
# from transformers import Pix2StructVisionModel, ViTImageProcessor, Pix2StructVisionConfig

### Config for notebook
from transformers import AutoConfig
config = AutoConfig.from_pretrained("mistralai/Mistral-7B-v0.1")
config.return_dict = True
config.use_cache = False
config.low_cpu_mem_usage = True
config.rope_theta = 10000.0
config.attn_implementation = "flash_attention_2"
###

# TODO: Move config to somewhere else

# image_encoder_config = Pix2StructVisionConfig.from_pretrained("google/pix2struct-base")
# TODO: try different hidden size?
# image_encoder_config.seq_len = 27145
# image_encoder_config.patch_size = 16

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# image_encoder = Pix2StructVisionModel.from_pretrained("google/pix2struct-textcaps-base", config=image_encoder_config, torch_dtype=torch.bfloat16)
# image_encoder.to(device)

image_encoder_path = "google/vit-base-patch16-224"
image_encoder_config = AutoConfig.from_pretrained(image_encoder_path)
image_encoder = AutoModel.from_pretrained(image_encoder_path, config=image_encoder_config)
image_encoder.to(device)

lm_path = "mistralai/Mistral-7B-v0.1"
lm = AutoModelForCausalLM.from_pretrained(lm_path, config=config, torch_dtype=torch.bfloat16)
lm.to(device)

model = MultimodalAgent(config, image_encoder, lm)
model.to(device)
print(torch.cuda.memory_allocated())

tokenizer = AutoTokenizer.from_pretrained(lm_path)
tokenizer.pad_token = tokenizer.eos_token # should be ok for casual LM

02:55:03 DEBUG:https://huggingface.co:443 "HEAD /mistralai/Mistral-7B-v0.1/resolve/main/config.json HTTP/1.1" 200 0
02:55:03 DEBUG:https://huggingface.co:443 "HEAD /google/vit-base-patch16-224/resolve/main/config.json HTTP/1.1" 200 0
Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

02:55:12 DEBUG:https://huggingface.co:443 "HEAD /mistralai/Mistral-7B-v0.1/resolve/main/generation_config.json HTTP/1.1" 200 0
02:55:14 DEBUG:https://huggingface.co:443 "HEAD /mistralai/Mistral-7B-v0.1/resolve/main/tokenizer_config.json HTTP/1.1" 200 0


15378507776


### Tokenize Train Data

In [7]:
from transformers import AutoImageProcessor
import math
max_patches = 2000
# max_patches = 200
patch_height, patch_width = 16, 16
processor = AutoImageProcessor.from_pretrained(image_encoder_path) # TODO: define this somewhere else

def preprocess_training_examples(examples, tokenizer):
    """
    Tokenize and map char index of the target to token index
    """
    inputs = tokenizer(examples["question"] + " [ACT]")
    inputs["labels"] = examples["boxes"]

    return inputs

def preprocess_training_examples_with_tokenizer(tokenizer):
    return lambda examples: preprocess_training_examples(examples, tokenizer)

def preprocess_image(example):
    """ 
    Aspect ratio preserving, fixed size patches 
    reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/pix2struct/image_processing_pix2struct.py
    """
    
    image_width, image_height = example["screenshot"][0].size
    # maximize scale s.t.
    scale = math.sqrt(max_patches * (patch_height / image_height) * (patch_width / image_width))
    num_feasible_rows = max(min(math.floor(scale * image_height / patch_height), max_patches), 1)
    num_feasible_cols = max(min(math.floor(scale * image_width / patch_width), max_patches), 1)
    resized_height = max(num_feasible_rows * patch_height, 1)
    resized_width = max(num_feasible_cols * patch_width, 1)
    
    processor.size = {"height":resized_height, "width":resized_width}
    inputs = processor(images=example["screenshot"], return_tensors="pt")
    # example["screenshot"] = inputs["flattened_patches"]
    example["screenshot"] = inputs["pixel_values"]
    all_scaled_boxes = []
    x_scale = image_width / resized_width
    y_scale = image_height / resized_height
    for boxes in example["labels"]:
        scaled_boxes = []
        for box in boxes:
            scaled_boxes.append([box[0]/x_scale, box[1]/y_scale, box[2]/x_scale, box[3]/y_scale])
        all_scaled_boxes.append(scaled_boxes)
    example["labels"] = all_scaled_boxes
    # example["attention_mask_image"] = inputs["attention_mask"]
    return example
    # return {"pixel_values": processor(images=example["screenshot"], return_tensors="pt").pixel_values} #[1, 3, 224, 224]

02:55:14 DEBUG:https://huggingface.co:443 "HEAD /google/vit-base-patch16-224/resolve/main/preprocessor_config.json HTTP/1.1" 200 0
02:55:14 DEBUG:https://huggingface.co:443 "HEAD /google/vit-base-patch16-224/resolve/main/config.json HTTP/1.1" 200 0


In [8]:
cols = train_dataset.column_names
cols.remove("screenshot")
train_dataset = train_dataset.map(
    preprocess_training_examples_with_tokenizer(tokenizer),
    remove_columns=cols,
    )
# processor = ViTImageProcessor(size={"height": 5429, "width": 1280})
# train_dataset["pixel_values"] = train_dataset.map(preprocess_image, remove_columns=train_dataset.column_names,
#     )

# train_dataset.set_format("pt", columns=["input_ids", "attention_mask", "label"], output_all_columns=True)
print(train_dataset[0])
train_dataset.set_transform(preprocess_image, output_all_columns=True) # process images on the fly
# split the train_dataset into train and validation
dataset = train_dataset.train_test_split(test_size=0.1) 
train_dataset, eval_dataset = dataset["train"], dataset["test"]
print(train_dataset[0])
logger.info(f"Use device {'gpu' if torch.cuda.is_available() else 'cpu'}")
# logger.info(f"Use batch size {cfg.train.batch_size}")
logger.info(f"Training data size {len(train_dataset)}")
logger.info(f"Eval data size {len(eval_dataset)}")

Map:   0%|          | 0/892 [00:00<?, ? examples/s]

02:55:14 DEBUG:open file: /scr/wychow/.cache/huggingface/datasets/osunlp___multimodal-mind2_web/default/0.0.0/f27b6362acc6efe0e97289620307ca42cb177e5b/tmp7n5tfm1v
02:55:15 DEBUG:open file: /scr/wychow/.cache/huggingface/datasets/osunlp___multimodal-mind2_web/default/0.0.0/f27b6362acc6efe0e97289620307ca42cb177e5b/tmp67uiu604
02:55:15 DEBUG:open file: /scr/wychow/.cache/huggingface/datasets/osunlp___multimodal-mind2_web/default/0.0.0/f27b6362acc6efe0e97289620307ca42cb177e5b/tmprks_l0i2
02:55:16 INFO:Use device gpu
02:55:16 INFO:Training data size 802
02:55:16 INFO:Eval data size 90


{'screenshot': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1280x5429 at 0x7F94AC6DA350>, 'input_ids': [1, 17158, 356, 272, 13987, 4686, 3005, 28725, 1464, 298, 4160, 272, 2296, 3638, 28747, 13, 4818, 28747, 7358, 264, 1253, 297, 21491, 387, 7993, 28725, 11800, 356, 477, 3999, 28705, 28774, 298, 3999, 28705, 28740, 28782, 28723, 13, 28284, 6768, 28747, 13, 5364, 13, 3195, 1023, 347, 272, 2442, 298, 14113, 395, 1679, 28804, 733, 7637, 28793], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'labels': [[283.1875, 220.390625, 93.59375, 33.0]]}
{'screenshot': tensor([[[-0.9922, -0.9922, -0.9922,  ..., -0.9922, -0.9922, -0.9922],
         [-0.9922, -0.9922, -0.9922,  ..., -0.9922, -0.9922, -0.9922],
         [-0.9922, -0.9922, -0.9922,  ..., -0.9922, -0.9922, -0.9922],
         ...,
         [ 0.7961,  0.7961,  0.7961,  ...,  0.7961,  0.7

### Set up LoRA

In [9]:
from peft import get_peft_model, LoraConfig, TaskType, prepare_model_for_int8_training

lora_config = LoraConfig(
    # task_type=TaskType.CAUSAL_LM, # task type is not necessary, but this is needed to get the label
    inference_mode=False,
    r=16,
    lora_alpha=32, 
    lora_dropout=0.05,
    target_modules = "all-linear"
)

# model.lm.enable_input_require_grads()
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

for name, param in model.named_parameters():
    if param.requires_grad:
        print(name)

trainable params: 45,277,184 || all params: 7,376,548,352 || trainable%: 0.6137990539670769
base_model.model.image_encoder.encoder.layer.0.attention.attention.query.lora_A.default.weight
base_model.model.image_encoder.encoder.layer.0.attention.attention.query.lora_B.default.weight
base_model.model.image_encoder.encoder.layer.0.attention.attention.key.lora_A.default.weight
base_model.model.image_encoder.encoder.layer.0.attention.attention.key.lora_B.default.weight
base_model.model.image_encoder.encoder.layer.0.attention.attention.value.lora_A.default.weight
base_model.model.image_encoder.encoder.layer.0.attention.attention.value.lora_B.default.weight
base_model.model.image_encoder.encoder.layer.0.attention.output.dense.lora_A.default.weight
base_model.model.image_encoder.encoder.layer.0.attention.output.dense.lora_B.default.weight
base_model.model.image_encoder.encoder.layer.0.intermediate.dense.lora_A.default.weight
base_model.model.image_encoder.encoder.layer.0.intermediate.dense.lora

### Set up Trainer

In [10]:
from transformers import TrainingArguments, Trainer

#         'lora_config': lora_config,
#         'learning_rate': cfg.train.learning_rate,
#         'num_train_epochs': cfg.train.epoch,
#         'gradient_accumulation_steps': cfg.train.gradient_accumulation_steps,
#         'per_device_train_batch_size': cfg.train.batch_size,
#         'per_device_eval_batch_size': cfg.eval.eval_batch_size,
#         'eval_accumulation_steps': cfg.eval.eval_accumulation_steps,
#         'gradient_checkpointing': True,
# }
config = {
        'lora_config': lora_config,
        'learning_rate': 3e-5,
        'num_train_epochs': 1,
        'gradient_accumulation_steps': 8,
        'per_device_train_batch_size': 1,
        'per_device_eval_batch_size': 1,
        'eval_accumulation_steps': 32,
        'gradient_checkpointing': True,
}

# class CustomTrainer(Trainer):
    
#     def compute_loss(self, model, inputs, return_outputs=False):

#         hidden_states = model(inputs["input_ids"], inputs["attention_mask"], output_hidden_states=True).hidden_states[-1]
#         # compute cosine simularity between last token and every token before
#         temperature = 0.1 # TODO: hard coded
#         sim = torch.nn.functional.cosine_similarity(hidden_states[:,:-3,:], hidden_states[:,-1:,:], dim=2) # Last 3 tokens are "[", "ACT", "]"
#         target_idx = inputs["labels"]

#         loss = torch.nn.functional.cross_entropy(sim / temperature, target_idx)

#         if return_outputs:
#             # instead of returning all hidden_states which would be too much memory,
#             # return the similarity scores as "logits"
#             # but different than sim because sin only calculates for 
#             # scores = torch.nn.functional.cosine_similarity(hidden_states[:,:-1,:], hidden_states[:,-1:,:], dim=2)
#             return loss, {"similarity": sim}
#         return loss

In [11]:
import math
# def boxes_to_patch_idx_multitarget(box, num_cols):
#     """ box is a tensor. Returns a list """
#     # pos_idxs = set()
#     l, b, w, h = box[0], box[1], box[2], box[3]
#     # unscaled 2d idx -> scaled 2d idx
#     x1, x2 = l//downscale_factor, (l+w)//downscale_factor
#     y1, y2 = b//downscale_factor, (b+h)//downscale_factor
#     # scaled 2d idx -> patch 2d idx
#     x1, x2 = math.floor(x1/16), math.ceil(x2/16)
#     y1, y2 = math.floor(y1/16), math.ceil(y2/16)
#     # 2d -> 1d
#     return [num_cols*r + c for c in range(x1, x2) for r in range(y1, y2)]

def boxes_to_patch_idx(box, num_cols):
    """ returns the patch closest to the center of the element """
    # pos_idxs = set()
    l, b, w, h = box[0], box[1], box[2], box[3]
    # scaled 2d coordinate -> patch 2d coordinate
    x1, x2 = l/patch_width, (l+w)/patch_width
    y1, y2 = b/patch_height, (b+h)/patch_height
    # patch 2d coordinate -> 1d idx
    c = math.floor((x1+x2)/2)
    r = math.floor((y1+y2)/2)
    # if x2 - x1 >= 2: # element at least contains 1 whole patch
    # else: # element within 2 patches
    return [num_cols*r + c]

def patch_idx_to_click(patch_idx, num_cols):
    """ (x, y), default to clicking the centre of the patch"""
    r, c = patch_idx // num_cols, patch_idx % num_cols
    return patch_width * (c+0.5), patch_height * (r+0.5)
    
class MultimodalTrainer(Trainer):
    
    def compute_loss(self, model, inputs, return_outputs=False):
        
        # hidden_states = model(flattened_patches=inputs["flattened_patches"], attention_mask_image=inputs["attention_mask_image"], input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
        hidden_states = model(pixel_values=inputs["pixel_values"], input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
        # compute cosine simularity between last token and every token before
        temperature = 0.1 # TODO: hard coded
        # sim = torch.nn.functional.cosine_similarity(hidden_states[:,:-3,:], hidden_states[:,-1:,:], dim=2) # Last 3 tokens are "[", "ACT", "]"
        num_cols = inputs["pixel_values"].shape[-1] // patch_width
        num_rows = inputs["pixel_values"].shape[-2] // patch_width
        num_patches = num_cols * num_rows
        sim = torch.nn.functional.cosine_similarity(hidden_states[:,1:num_patches+1,:], hidden_states[:,-1:,:], dim=2) # Last 3 tokens are "[", "ACT", "]"
        pos_idxs = set()
        
        for box in inputs["labels"][0]: # TODO: only for batch size 
            pos_idxs.update(boxes_to_patch_idx(box, num_cols))
        # +1 because first idx is CLS
        # target_idx = torch.tensor([idx + 1 for idx in pos_idxs]).to(device)
        
        target_idx = torch.tensor(list(pos_idxs)).to(device)
        
        # print("box", inputs["labels"][0])
        # print("click coordinate", patch_idx_to_click(target_idx, num_cols))
        loss = torch.nn.functional.cross_entropy(sim / temperature, target_idx) # TODO: use BCE for multitarget?
        # print(loss)
        # print("prediction", torch.argmax(sim).item(), "actual", target_idx.item())
        # print(torch.max(sim), sim[0,target_idx])
        if return_outputs:
            # instead of returning all hidden_states which would be too much memory,
            # return the similarity scores as "logits"
            # but different than sim because sin only calculates for 
            # scores = torch.nn.functional.cosine_similarity(hidden_states[:,:-1,:], hidden_states[:,-1:,:], dim=2)
            return loss, {"sim":sim, "target_idx":target_idx}
        return loss

### Set up Evaluation

In [12]:
import numpy as np
def custom_collate(data):
    # flattened_patches = torch.stack([d['screenshot'] for d in data])
    pixel_values = torch.stack([d['screenshot'] for d in data])
    # input_ids = torch.stack([d['input_ids'] for d in data])
    input_ids = torch.tensor([d['input_ids'] for d in data]) # set_transform resets set_format :(
    attention_mask = torch.tensor([d['attention_mask'] for d in data])
    # attention_mask_image = torch.stack([d['attention_mask_image'] for d in data])
    labels = torch.tensor([d['labels'] for d in data]) # todo: only uses first positive
    return { 
        'pixel_values': pixel_values,
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        # 'attention_mask_image': attention_mask_image,
        'labels': labels,
    }

def compute_metrics(pred):
    sims, target_idxs = pred.predictions[0], pred.predictions[1]
    accuracy = []
    # Need to use a for loop because sequence length is different for each input
    preds = sims.argmax(axis=1)
    print(preds)
    print(target_idxs)
    accuracy = preds == target_idxs # TODO: use information from bounding box to get more metrics
    # bounding box stored in pred.label_ids
    return {
        'accuracy': np.array(accuracy).mean(),
    }

### Run Training

In [13]:
# training_args = TrainingArguments(
#     output_dir=output_dir,
#     overwrite_output_dir=True,
#     optim="adamw_torch_fused",
#     bf16=True,  # Use BF16 for flash attention
#     # evlaution
#     evaluation_strategy="steps",
#     eval_steps=cfg.eval.eval_steps,
#     include_inputs_for_metrics=True,
#     # logging strategies
#     logging_dir=f"{output_dir}/logs",
#     logging_strategy="steps",
#     logging_steps=10,
#     save_strategy="no",
#     **{k:v for k,v in config.items() if k != 'lora_config'}
# ) # TODO: move train arguments to config

training_args = TrainingArguments(
    output_dir="output",
    overwrite_output_dir=True,
    optim="adamw_torch_fused",
    bf16=True,  # Use BF16 for flash attention
    # evlaution
    label_names=["labels"], # so that trainer will call compute_loss
    evaluation_strategy="steps",
    eval_steps=8,
    include_inputs_for_metrics=True,
    log_level="info",
    # logging strategies
    logging_dir=f"output/logs",
    logging_strategy="steps",
    logging_steps=8,
    save_strategy="no",
    remove_unused_columns=False,
    **{k:v for k,v in config.items() if k != 'lora_config'}
) # TODO: move train arguments to config
trainer = MultimodalTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,
    tokenizer=tokenizer,
    data_collator=custom_collate,
)
trainer.train()

Using auto half precision backend
***** Running training *****
  Num examples = 802
  Num Epochs = 1
  Instantaneous batch size per device = 1
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 8
  Total optimization steps = 100
  Number of trainable parameters = 45,277,184


Step,Training Loss,Validation Loss,Accuracy
8,7.528,7.290866,0.022222
16,7.0946,6.914609,0.022222
24,6.5816,6.47703,0.0
32,6.524,6.376817,0.0
40,6.0073,6.511876,0.0
48,6.1718,6.313993,0.0
56,6.0196,6.255442,0.0
64,5.8871,6.292933,0.011111
72,5.8389,6.28259,0.011111
80,5.9123,6.199437,0.022222


***** Running Evaluation *****
  Num examples = 90
  Batch size = 1


[   0  227  366   90   69  383  572    6  201  319  137   83 1238  105
  523 1175   57   91  149  276   13   57   63 1605  129  340   76  108
  373   98  179   58  597   35   10   19    8  224  862  157  296   14
   48  294   73  327  183   46 1083   48 1352    8   19  554   45    9
  617   58   63  214   55  373   42   16  204   10   64  374 1192 1518
   31  171 1719   43    6  298   99  143  352   31   70  244   80   30
   95  758  334  236    3 1518]
[ 165  430  581   57 1771 1145 1264  298   87  430   27  702  535  155
   86  533   57  197   33  501   58  379  168 1299   90  126  108  703
  155  853  179  133  214   52  183  194  515  580   21  407   94  146
  111  216   27  648  226  297  112   40  149  315  358  639  396 1763
   36   20  167  275  122  247  822  203  736  198  118  125 1306  598
  104  152   82  129  403  437  193  151  149  417  217  158  177  180
 1141  107  581  351  621  598]


***** Running Evaluation *****
  Num examples = 90
  Batch size = 1


[ 116   20  334   97   69  383    5    6   88  176  497  169  516    4
   82   36   57  268  149  632   13   88    8  446  129   70  117   44
  155   66   99   61   71   47   10   19    8   44  293   40  296   14
   48  294   14  327  183   50 1083   25  113    8   19  459  122  275
 1921   60    8   99   55  234   46   16  202   10   64   27  525   53
  266   13 1584   60    8  352   44  155   85   29   99  244   31   44
   95   27  263  236    3   53]
[ 165  430  581   57 1771 1145 1264  298   87  430   27  702  535  155
   86  533   57  197   33  501   58  379  168 1299   90  126  108  703
  155  853  179  133  214   52  183  194  515  580   21  407   94  146
  111  216   27  648  226  297  112   40  149  315  358  639  396 1763
   36   20  167  275  122  247  822  203  736  198  118  125 1306  598
  104  152   82  129  403  437  193  151  149  417  217  158  177  180
 1141  107  581  351  621  598]


***** Running Evaluation *****
  Num examples = 90
  Batch size = 1


[ 116   20   29   96   69  383    5    6  138  176  497  168  516    4
   81    6   64    5  149  121   13   57    8  446  129   70  117   44
  151   66  166   25   70   47    9   19    8   84  293   40  296   14
   48  215   14  327  183   50 1082   25   95   91   19   93   61  275
    1   79    8  129   55  234   46    4  204    9   64  178  436   48
   19   13   32   60    8   14   44  155   63   29   99  244   31   37
   99   51   29   92    3   48]
[ 165  430  581   57 1771 1145 1264  298   87  430   27  702  535  155
   86  533   57  197   33  501   58  379  168 1299   90  126  108  703
  155  853  179  133  214   52  183  194  515  580   21  407   94  146
  111  216   27  648  226  297  112   40  149  315  358  639  396 1763
   36   20  167  275  122  247  822  203  736  198  118  125 1306  598
  104  152   82  129  403  437  193  151  149  417  217  158  177  180
 1141  107  581  351  621  598]


***** Running Evaluation *****
  Num examples = 90
  Batch size = 1


[ 116   20   29   96   56  383   76    6  138  166    2  168  516   22
   81    6   64   24   21  105   13   88    2  131   93   70  113   44
  151   81  166   25   72   47    9    9    8   84  293   38   69   14
   48   26   14  394  183   50 1082   25   31   91   11   94   61  275
    1   58    2  216   55  234   65    4  204    9   64  178  436   48
   19   13   32   60    8   14   44  146   82   29   99  244   31   37
  214  178   62   91   16   48]
[ 165  430  581   57 1771 1145 1264  298   87  430   27  702  535  155
   86  533   57  197   33  501   58  379  168 1299   90  126  108  703
  155  853  179  133  214   52  183  194  515  580   21  407   94  146
  111  216   27  648  226  297  112   40  149  315  358  639  396 1763
   36   20  167  275  122  247  822  203  736  198  118  125 1306  598
  104  152   82  129  403  437  193  151  149  417  217  158  177  180
 1141  107  581  351  621  598]


***** Running Evaluation *****
  Num examples = 90
  Batch size = 1


[ 116   20   29   94   56  158    5    6  138  176    5   83  516    4
   77    6   64   16   21  105   13   57    8 1842   93  125  117   44
  151   81   91   25   70   47    9   19    8   84  293   38   69   14
   48    9   14  394  183   50 1082   25   28    5   19   94   61  275
    1   58    8  130   55  234   46    4   51    9   10  177  401   48
   19   13   32   60    8   14   44  146  139   29  129  244   31   18
   98   51   29   91    3   48]
[ 165  430  581   57 1771 1145 1264  298   87  430   27  702  535  155
   86  533   57  197   33  501   58  379  168 1299   90  126  108  703
  155  853  179  133  214   52  183  194  515  580   21  407   94  146
  111  216   27  648  226  297  112   40  149  315  358  639  396 1763
   36   20  167  275  122  247  822  203  736  198  118  125 1306  598
  104  152   82  129  403  437  193  151  149  417  217  158  177  180
 1141  107  581  351  621  598]


***** Running Evaluation *****
  Num examples = 90
  Batch size = 1


[ 116   20  263   96   56  383    5    6  138  166    2  169  516    4
   77   87   64   16   21  105   13   88   87  386   93  125  113   44
  151  252  166   28   72   47    9   10    8   84  293   38   69   14
   48    9   14  394  183   74 1082   35   28  256   19   94   61  275
   25   58   87  130   55  234   46    4   51    9   64  292  436   48
   33   11   32   60    8   14   44  146  139   29  129  244   31   37
  214  172  263   91    3   48]
[ 165  430  581   57 1771 1145 1264  298   87  430   27  702  535  155
   86  533   57  197   33  501   58  379  168 1299   90  126  108  703
  155  853  179  133  214   52  183  194  515  580   21  407   94  146
  111  216   27  648  226  297  112   40  149  315  358  639  396 1763
   36   20  167  275  122  247  822  203  736  198  118  125 1306  598
  104  152   82  129  403  437  193  151  149  417  217  158  177  180
 1141  107  581  351  621  598]


***** Running Evaluation *****
  Num examples = 90
  Batch size = 1


[116  20 334  96  56 383   5   6 138 284   5 198 516   4  77  87  64  16
  21 105  13  57 128 446  93 125 117  44 151 252 178  28  72  47   9  12
   8  84 293  38  30  14  48  26   5 394 183  74 570  14  28 256  11  94
  61 275  25  58 124 130  55 234  42   4 204   9 170 292 436  48  33 159
  32  60 120  14  44 146  87  29 129 244  31  16 214 172 275 140   3  48]
[ 165  430  581   57 1771 1145 1264  298   87  430   27  702  535  155
   86  533   57  197   33  501   58  379  168 1299   90  126  108  703
  155  853  179  133  214   52  183  194  515  580   21  407   94  146
  111  216   27  648  226  297  112   40  149  315  358  639  396 1763
   36   20  167  275  122  247  822  203  736  198  118  125 1306  598
  104  152   82  129  403  437  193  151  149  417  217  158  177  180
 1141  107  581  351  621  598]


***** Running Evaluation *****
  Num examples = 90
  Batch size = 1


[116  20 275  96  56 383   5   6 136 311   5  80 516   4  77  87  64  16
  21 105  13   4   1 446  93 125 127   3  85 252 178  28  72  47   9  12
   8  84 317  38  30  14  48  26   5 394 183  77 570  14  28   5  19  94
  22 276  25  79 124 130  55 234  46   4 204   9  17 126 466  48 104  14
  32  60 120  14  44   7  87  29 130 244  31  17  98 126 275 236   3  48]
[ 165  430  581   57 1771 1145 1264  298   87  430   27  702  535  155
   86  533   57  197   33  501   58  379  168 1299   90  126  108  703
  155  853  179  133  214   52  183  194  515  580   21  407   94  146
  111  216   27  648  226  297  112   40  149  315  358  639  396 1763
   36   20  167  275  122  247  822  203  736  198  118  125 1306  598
  104  152   82  129  403  437  193  151  149  417  217  158  177  180
 1141  107  581  351  621  598]


***** Running Evaluation *****
  Num examples = 90
  Batch size = 1


[116  20 275  92  56 383  88   6 136 311   5  80 516   2 137 145  64  16
   5 105  13   4   1 446  93  96 127   3  85 252 202  28  72  47   9  12
   8  84 317  38  28  14  48  26   5 394 183  77 570  25  28   5  19  94
  22 276  25  60   1 130  55 214  42   4  44   9  17 126 466  48 104  14
  32  60   8  14  21   7  87  29 130 252  31  17  99 126 275 236   3  48]
[ 165  430  581   57 1771 1145 1264  298   87  430   27  702  535  155
   86  533   57  197   33  501   58  379  168 1299   90  126  108  703
  155  853  179  133  214   52  183  194  515  580   21  407   94  146
  111  216   27  648  226  297  112   40  149  315  358  639  396 1763
   36   20  167  275  122  247  822  203  736  198  118  125 1306  598
  104  152   82  129  403  437  193  151  149  417  217  158  177  180
 1141  107  581  351  621  598]


***** Running Evaluation *****
  Num examples = 90
  Batch size = 1


[116  20 275  92  56 383  88   6 136 311   5  80 516 155 137 145  64  16
   5 105  13   4   1 386  93  89 110   3 183 252 202  25  70  47   9  12
   8 433 317  38  28  14  48  26   5 394 183  77 570  25  28 372  11  94
  22 276  25  60   1 130  55 657  42   4 170   9  90 126 466  48 104 159
  32  60 120  14  21   7  87  29 130 252  31  17  99 126 275 236   3  48]
[ 165  430  581   57 1771 1145 1264  298   87  430   27  702  535  155
   86  533   57  197   33  501   58  379  168 1299   90  126  108  703
  155  853  179  133  214   52  183  194  515  580   21  407   94  146
  111  216   27  648  226  297  112   40  149  315  358  639  396 1763
   36   20  167  275  122  247  822  203  736  198  118  125 1306  598
  104  152   82  129  403  437  193  151  149  417  217  158  177  180
 1141  107  581  351  621  598]


***** Running Evaluation *****
  Num examples = 90
  Batch size = 1


[116  20 275  92  56 383  88   6 136 339   5  80 516 155 137 145  64  16
   5 105  13   4   1 446  74  89 117   3 151 252 202  25  70  78   9  12
  18 433 317  38  28  14  48  26   5 394 183  77 570  39  28 372  11 460
  22 276   1  60   1 130  55 657  42   4 170   9  90 126 466  48 104 159
  32  56 120  14  21   7  87  29 130 252  31  16  99 126 275 236   3  48]
[ 165  430  581   57 1771 1145 1264  298   87  430   27  702  535  155
   86  533   57  197   33  501   58  379  168 1299   90  126  108  703
  155  853  179  133  214   52  183  194  515  580   21  407   94  146
  111  216   27  648  226  297  112   40  149  315  358  639  396 1763
   36   20  167  275  122  247  822  203  736  198  118  125 1306  598
  104  152   82  129  403  437  193  151  149  417  217  158  177  180
 1141  107  581  351  621  598]


***** Running Evaluation *****
  Num examples = 90
  Batch size = 1


[116  20 275  92  56 383  88   6 136 311   5  80 516 155 137 145  64  16
   5 105  13   4   1 446  74  89 110   3 183 252 202  25  70  78   9  12
  18 433 317  38  28  14  48  26   5 394 183  77 570  39  28 372  11 460
  22 276   1  60   1 130  55 657  42   4 170   9  90 126 466  48 104 159
  32  56 120  14  21   7  87  29 130 252  31  16  99 126 275 236   3  48]
[ 165  430  581   57 1771 1145 1264  298   87  430   27  702  535  155
   86  533   57  197   33  501   58  379  168 1299   90  126  108  703
  155  853  179  133  214   52  183  194  515  580   21  407   94  146
  111  216   27  648  226  297  112   40  149  315  358  639  396 1763
   36   20  167  275  122  247  822  203  736  198  118  125 1306  598
  104  152   82  129  403  437  193  151  149  417  217  158  177  180
 1141  107  581  351  621  598]




Training completed. Do not forget to share your model on huggingface.co/models =)




TrainOutput(global_step=100, training_loss=6.274795722961426, metrics={'train_runtime': 1838.2028, 'train_samples_per_second': 0.436, 'train_steps_per_second': 0.054, 'total_flos': 3369929015126016.0, 'train_loss': 6.274795722961426, 'epoch': 1.0})

# Sanity check

Pix2Struct, reference: https://github.com/huggingface/notebooks/blob/main/examples/image_captioning_pix2struct.ipynb

### Load model and processor

In [14]:
from transformers import AutoProcessor, Pix2StructVisionModel

# processor = AutoProcessor.from_pretrained("google/pix2struct-base")
# model = Pix2StructVisionModel.from_pretrained("google/pix2struct-base")
train_dataset[997]

IndexError: Invalid key: 997 is out of bounds for size 802

In [None]:
import requests
from PIL import Image
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor, Pix2StructVisionModel


train_dataset = load_dataset("osunlp/Multimodal-Mind2Web", split="train").select(range(10))
image = train_dataset[3]["screenshot"]

text = "A picture of"

model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base")
processor = Pix2StructProcessor.from_pretrained("google/pix2struct-textcaps-base")
# processor = Pix2StructImageProcessor.from_pretrained("google/pix2struct-textcaps-base")

# image only
inputs = processor(images=train_dataset["screenshot"], text=text, return_tensors="pt")
print(inputs.keys())
predictions = model.generate(**inputs)
print(processor.decode(predictions[0], skip_special_tokens=True))


### Using Pix2Struct

In [None]:
from transformers import Pix2StructImageProcessor, Pix2StructVisionModel, Pix2StructConfig, Pix2StructForConditionalGeneration
from datasets import load_dataset
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_dataset = load_dataset("osunlp/Multimodal-Mind2Web", split="train").select(range(10))
image = train_dataset[3]["screenshot"]

# TODO: Move config to somewhere else
image_encoder_path = "google/pix2struct-textcaps-base"
image_encoder_config = Pix2StructConfig.from_pretrained(image_encoder_path)
# TODO: try different hidden size?
max_patches = 2000
patch_size = 16
# image_encoder_config.vision_config.seq_len = max_patches
# image_encoder_config.vision_config.patch_size = patch_size
print(image_encoder_config)

image_encoder = Pix2StructForConditionalGeneration.from_pretrained(image_encoder_path, config=image_encoder_config).encoder
print(image_encoder)
image_encoder.to(device)

processor = Pix2StructImageProcessor.from_pretrained(image_encoder_path) # TODO: define this somewhere else
processor.max_patches = max_patches
processor.patch_size = {"height":patch_size, "width":patch_size}
inputs = processor(images=image, return_tensors="pt").to(device)
print(image_encoder(**inputs))
print(torch.cuda.memory_summary())
# 2000 -> 7G
# 3000 -> 14G
# 4000 -> 25G
# 5000 -> 37G

### Using VIT

In [None]:
from transformers import AutoConfig, AutoImageProcessor, AutoModel
from datasets import load_dataset
import torch
import matplotlib.pyplot as plt

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_dataset = load_dataset("osunlp/Multimodal-Mind2Web", split="train").select(range(10))
image = train_dataset[3]["screenshot"]

# TODO: Move config to somewhere else
image_encoder_path = "google/vit-base-patch16-224"
image_encoder_config = AutoConfig.from_pretrained(image_encoder_path)
# TODO: try different hidden size?
# print(image_encoder_config)

image_encoder = AutoModel.from_pretrained(image_encoder_path, config=image_encoder_config)
# print(image_encoder)
image_encoder.to(device)

downscale_factor = 4
processor = AutoImageProcessor.from_pretrained(image_encoder_path) # TODO: define this somewhere else
processor.size = {"height":5429//downscale_factor, "width":1280//downscale_factor}
inputs = processor(images=image, return_tensors="pt").to(device)
print(inputs.pixel_values.shape)
plt.figure(figsize=(12, 40))
plt.imshow(inputs.pixel_values.cpu()[0].permute((1,2,0)))
plt.show()
h = image_encoder(inputs["pixel_values"], interpolate_pos_encoding=True).last_hidden_state
# print(torch.cuda.memory_summary())
h.shape

### Match downscaled image patch index to target

In [None]:
train_dataset[3]["pos_candidates"]

### Match target index to patch index


bounding_box_rect is in the format of (left, bottom, width, height), so pixel_values[:,bottom:bottom+height,left:left+width] should be marked as positive

unscaled index 2d -> scaled index 2d -> patch index 2d -> patch index 1d

Shortest width / height:

In [None]:
train_dataset = load_dataset("osunlp/Multimodal-Mind2Web", split="train")
cands = train_dataset["pos_candidates"]
shortest = 100
widths = []
heights = []
import json
for cand_list in cands:
    for cand in cand_list:
        json_data = json.loads(cand)
        attributes = json.loads(json_data['attributes'])
        bounding_box_rect_str = attributes['bounding_box_rect']
        lbwh = tuple(map(float, bounding_box_rect_str.split(',')))
        widths.append(lbwh[2])
        heights.append(lbwh[3])
        # if lbwh[2] <= 0 or lbwh[3] <= 0:
        #     print(cand_list)
        #     print(shortest)

        # shortest = min(shortest, lbwh[2], lbwh[3])
        
import matplotlib.pyplot as plt
# plt.hist(widths, bins=100)
# plt.show()
heights = np.array(heights)
plt.hist(heights[heights < 200], bins=100)
plt.axvline(x=32, color='r', linestyle='--')
plt.title("Pos candidates height")


In [None]:
from transformers import ViTImageProcessor
import torch
sample = train_dataset[3]
print(sample["pos_candidates"])

image = sample["screenshot"]
print(image.size)
processor = ViTImageProcessor(size={"height": 5429, "width": 1280})
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
inputs = processor(images=image, return_tensors="pt").to(device)
pixel_values = inputs.pixel_values # [1, 3, 224, 224]
print(pixel_values.shape)

import matplotlib.pyplot as plt
import numpy as np
plt.figure()
plt.imshow(np.transpose(pixel_values[0,:,410:410+46,96:96+106].cpu(), (1,2,0)))


processor2 = ViTImageProcessor(size={"height": 5429//2, "width": 1280//2})
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
inputs2 = processor2(images=image, return_tensors="pt").to(device)
pixel_values2 = inputs2.pixel_values # [1, 3, 224, 224]
print(pixel_values2.shape)

import matplotlib.pyplot as plt
import numpy as np
plt.figure()
plt.imshow(np.transpose(pixel_values2[0,:,410//2:(410+46)//2,96//2:(96+106)//2].cpu(), (1,2,0)))

# plt.figure(figsize=(15, 15))
# plt.imshow(np.transpose(pixel_values[0].cpu(), (1,2,0)))
# for i in range(0, 1000, 100):
#     plt.figure()
#     plt.imshow(np.transpose(pixel_values[0,:,i:i+160,i:i+160].cpu(), (1,2,0)))

In [None]:
import math
def boxes_to_patch_idx_multitarget(box, num_cols):
    """ box is a tensor. Returns a list """
    # pos_idxs = set()
    l, b, w, h = box[0], box[1], box[2], box[3]
    # unscaled 2d idx -> scaled 2d idx
    x1, x2 = l//2, (l+w)//2
    y1, y2 = b//2, (b+h)//2
    # scaled 2d idx -> patch 2d idx
    x1, x2 = math.floor(x1/16), math.ceil(x2/16)
    y1, y2 = math.floor(y1/16), math.ceil(y2/16)
    # 2d -> 1d
    return [num_cols*r + c for c in range(x1, x2) for r in range(y1, y2)]

def boxes_to_patch_idx(box, num_cols):
    """ returns the patch closest to the center of the element """
    # pos_idxs = set()
    l, b, w, h = box[0], box[1], box[2], box[3]
    # unscaled 2d idx -> scaled 2d idx
    x1, x2 = l//2, (l+w)//2
    y1, y2 = b//2, (b+h)//2
    # scaled 2d idx -> patch 2d idx
    x1, x2 = x1/16, x2/16
    y1, y2 = y1/16, y2/16
    # 2d -> 1d
    c = math.floor((x1+x2)/2)
    r = math.floor((y1+y2)/2)
    # if x2 - x1 >= 2: # element at least contains 1 whole patch
    # else: # element within 2 patches
    return num_cols*r + c

# for i in range(16):
#     print([i*16+j for j in range(16)])
print(boxes_to_patch_idx_multitarget([96,410.390625,106,46], 640//16))
boxes_to_patch_idx([96,410.390625,106,46], 640//16)

In [None]:
from transformers import AutoConfig, AutoImageProcessor, AutoModel
from datasets import load_dataset
import torch
import matplotlib.pyplot as plt

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_dataset = load_dataset("osunlp/Multimodal-Mind2Web", split="train").select(range(20))
train_dataset["screenshot"]