In [None]:
# !nvidia-smi

In [318]:
import os
import numpy as np
import torch
import os
import re
import json
import argparse
import pandas as pd
import random
from transformers import T5Tokenizer, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer, T5ForConditionalGeneration
from rich.table import Column, Table
from rich import box
from rich.console import Console
console = Console(record=True)
from torch import cuda
import nltk
import evaluate
import pdfkit
from pdf2image import convert_from_path
import glob, sys, fitz
from torch.utils.data import Dataset, DataLoader
from collections import namedtuple

from PIL import Image, ImageDraw, ImageOps, ImageFont
import requests
from transformers import AutoProcessor, Pix2StructForConditionalGeneration, Pix2StructConfig, AutoConfig

import time
import textwrap
import math
import string
import pickle
import matplotlib.pyplot as plt
from rouge import Rouge
from typing import Any, Callable, Iterable, List, Optional

# AI2D

In [3]:
def read_json_file(filepath):
    with open(filepath) as json_file:
        data = json.load(json_file)
    return data


In [4]:
# problem_list = read_json_file(os.path.join(os.getcwd(), "data", "scienceqa", "problems.json"))

In [5]:
def load_pickle_dataset(save_dir, source=""):
    pickle_filename = os.path.join(save_dir, f"{source}.pkl")
    with open(pickle_filename, 'rb') as f:
        pickle_data = pickle.load(f)
    return pickle_data

In [6]:
def render_text_on_bounding_box(
    text: str,
    bounding_box: Iterable[Iterable[int]],
    image: Image.Image,
    font_path: str):
    
    """Render text on top of a specific bounding box."""
    draw = ImageDraw.Draw(image)
    (x0, y0), (x1, y1) = bounding_box
    
    draw.rectangle(xy=[(x0, y0), (x1, y1)], fill=(255, 255, 255, 255))
    
    fontsize = 1
    def _can_increment_font(ratio=0.95):
        next_font = ImageFont.truetype(
            font_path, encoding="UTF-8", size=fontsize + 1)
        width, height = next_font.getsize(text)
        return width < ratio * (x1 - x0) and height < ratio * (y1 - y0)

    while _can_increment_font():
        fontsize += 1
    font = ImageFont.truetype(font_path, encoding="UTF-8", size=fontsize)

    draw.text(
        xy=((x0 + x1)/2, (y0 + y1)/2),
        text=text,
        font=font,
        fill="black",
        anchor="mm"
    )
    

In [7]:
def render_text(text: str,
                text_size: int = 36,
                text_color: str = "black",
                background_color: str = "white",
                left_padding: int = 5,
                right_padding: int = 5,
                top_padding: int = 5,
                bottom_padding: int = 5,
                font_path: str = "") -> Image.Image:

    """Render text."""
    # Add new lines so that each line is no more than 80 characters.
    wrapper = textwrap.TextWrapper(width=80)
    lines = wrapper.wrap(text=text)
    wrapped_text = "\n".join(lines)

    font = ImageFont.truetype(font_path, encoding="UTF-8", size=text_size)

    # Use a temporary canvas to determine the width and height in pixels when
    # rendering the text.
    temp_draw = ImageDraw.Draw(Image.new("RGB", (1, 1), background_color))
    _, _, text_width, text_height = temp_draw.textbbox((0, 0), wrapped_text, font)

    # Create the actual image with a bit of padding around the text.
    image_width = text_width + left_padding + right_padding
    image_height = text_height + top_padding + bottom_padding
    image = Image.new("RGB", (image_width, image_height), background_color)
    draw = ImageDraw.Draw(image)
    draw.text(
      xy=(left_padding, top_padding),
      text=wrapped_text,
      fill=text_color,
      font=font)
    
    return image

In [24]:
def render_header(image: Image.Image, header: str, font_path: str) -> Image.Image:
    """Renders a header on a PIL image and returns a new PIL image."""
    header_image = render_text(header, font_path=font_path)
    new_width = max(header_image.width, image.width)

    new_height = int(image.height *  (new_width / image.width))
    new_header_height = int(
        header_image.height * (new_width / header_image.width))

    new_image = Image.new(
        "RGB",
        (new_width, new_height + new_header_height),
        "white")
    new_image.paste(header_image.resize((new_width, new_header_height)), (0, 0))
    new_image.paste(image.resize((new_width, new_height)), (0, new_header_height))

    return new_image

In [25]:
def save_dataset(dataset, save_dir="", filename=""):
    pickle_filename = os.path.join(save_dir, filename)
    with open(pickle_filename, 'wb') as f:
        pickle.dump(dataset, f, protocol=pickle.HIGHEST_PROTOCOL)

In [75]:
def convert_one_question_AI2D(input_path: str, data_dir: str, font_path: str, skip_image_gen: bool):
  
    """Convert example."""
    samples_dict = {}
    
    data = read_json_file(os.path.join(data_dir, input_path)) # till ai2d folder + question path
    
    qid = -1
    if not data["questions"]:
        return samples_dict, int(data["imageName"].split('.')[0])
    
    annotation = read_json_file(os.path.join(data_dir, "annotations", f"{data['imageName']}.json"))
    
    if skip_image_gen == False:
        image_filepath = os.path.join(data_dir, "images", data["imageName"])
        image = Image.open(image_filepath)
        image_with_placeholders = image.copy()

        for v in annotation["text"].values():
            render_text_on_bounding_box(
                text=v["replacementText"],
                bounding_box=v["rectangle"],
                image=image_with_placeholders,
                font_path = font_path)

    for k, v in data["questions"].items():
        
        samples_dict[v["questionId"]] = {}
        # The `image_id` field is only used to ensure correct splitting of the data.
        options = " ".join(
            f"({string.ascii_lowercase[i]}) {a}"
            for i, a in enumerate(v["answerTexts"])
        )
        
        if skip_image_gen == False:
            image_with_header = render_header(
                image=image_with_placeholders if v["abcLabel"] else image,
                header=f"{k} {options}",
                font_path = font_path
            )

            # save new image
            save_dir = os.path.join(data_dir, "new_data","images")
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)

            image_with_header.save(os.path.join(save_dir, f"{v['questionId']}.png"))

        # get output for this sample
        parse = v["answerTexts"][v["correctAnswer"]]
        
        
        # update sample dict with info and meta
        samples_dict[v["questionId"]]["src_image_name"] = data["imageName"]
        samples_dict[v["questionId"]]["raw_output"] = parse
        samples_dict[v["questionId"]]["header_text"] = f"{k} {options}"
        samples_dict[v["questionId"]]["abcLabel"] = v["abcLabel"]
    
    
    return samples_dict, qid

In [76]:
def convert_AI2D(data_dir: str, font_path: str, skip_image_gen: bool):
    
    # list question folder files
    # create sample dict
    samples_dict = {}
    missing_question_list = []
    
    for file in os.listdir(os.path.join(data_dir,"questions")):
        filepath = os.path.join("questions", file)
        one_question_sample_dict, qid = convert_one_question_AI2D(filepath, data_dir, font_path, skip_image_gen)
        samples_dict.update(one_question_sample_dict)
        
        if qid!=-1:
            missing_question_list.append(qid)
        
    # save sample dict
    save_dataset(
        samples_dict,
        save_dir = os.path.join(data_dir,"new_data"),
        filename = "samples_dict.pkl"
    )
    
    save_dataset(
        missing_question_list,
        save_dir = os.path.join(data_dir,"new_data"),
        filename = "missing_question_list.pkl"
    )
    return samples_dict, missing_question_list


In [77]:
data_dir = os.getenv("AI2D_DATA_DIR")
font_path = os.path.join(os.getenv("SCIENCEQA_ASSETS_DIR"), "arial.ttf")
print(data_dir, font_path)

/local1/rwadhawan7/ai2d/data /home/rwadhawan7/eeevqa/eeevqa/assets/arial.ttf


In [78]:
skip_image_gen = True

In [79]:
sample_dict, missing_question_list = convert_AI2D(data_dir, font_path, skip_image_gen)

In [80]:
samples_dict = load_pickle_dataset(os.path.join(data_dir, "new_data"),"samples_dict")

In [81]:
# print(samples_dict.keys())

In [82]:
missing_question_list = load_pickle_dataset(os.path.join(data_dir, "new_data"),"missing_question_list")

In [83]:
print(missing_question_list)
print(len(missing_question_list))

[3742, 2881, 1177, 3070, 3071, 1176, 4477, 2989, 3927, 2459, 4537, 3799, 472, 3798, 1705, 1226, 1040, 1633, 4508, 3859, 4509, 1100, 3006, 1367, 2865, 3095, 1193, 3950, 3965, 4435, 3700, 1134, 570, 2982, 1075, 1074, 165, 3824, 2580, 3748, 3749, 1280, 4020, 3104, 3450, 1731, 2698, 4338, 4339, 3559, 4498, 4230, 4117, 3418, 3419, 4278, 1416, 3511, 3045, 1142, 2642, 4281, 3341, 3472, 4520, 2944, 3309, 1069, 3930, 4862, 3755, 4863, 3067, 525, 3082, 2872, 2873, 950, 3946, 4815, 4484, 3545, 3, 3277, 3011, 2, 553, 2471, 3151, 1230, 3150, 1231, 3336, 3895, 147, 3806, 4264, 4265, 4704, 3391, 3390, 1204, 2994, 1751, 3570, 566, 3025, 2846, 3479, 1400, 3506, 4454, 3594, 3845, 3728, 3620, 3447, 3621, 2879, 3789, 1027, 350, 4865, 828, 4864, 1167, 2666, 2890, 3295, 4600, 4482, 3270, 1182, 3085, 1183, 2874, 3378, 1019, 3801, 1159, 2934, 2750, 3156, 3330, 3892, 3402, 4790, 4564, 4702, 4031, 1290, 4030, 3758, 4699, 4698, 3618, 2840, 1685, 3577, 1471, 1125, 3898, 4361, 4360, 679, 3592, 4851, 2307, 3115, 31

In [281]:
def create_metadata(data_dir):
    
    metadata_dict = {}
    tlist = []
    for file in os.listdir(os.path.join(data_dir, "images")):
        fidx = int(file.split(".")[0])
        tlist.append(fidx)
    
    skipped_image_num_list = []
    for i in range(4907+1):
        if i not in tlist:
            skipped_image_num_list.append(i)
    print(skipped_image_num_list)
    
    metadata_dict["skipped_image_num_list"] = skipped_image_num_list
    
    all_samples_list = []
    for file in os.listdir(os.path.join(data_dir, "images")):
        fidx = int(file.split(".")[0])
        all_samples_list.append(fidx)
        
    print(len(all_samples_list))
    metadata_dict["all_samples_list"] = all_samples_list
    
    all_test_list = list(pd.read_csv(os.path.join(data_dir,"ai2d_test_ids.csv")).to_numpy().reshape(-1))
    print(len(all_test_list))
    metadata_dict["all_test_list"] = all_test_list
    
    missing_from_all_test_list = []
    all_test_present_list = []
    for sidx in all_test_split:
        if sidx not in all_samples_list:
            missing_from_all_test_list.append(sidx)
        else:
            all_test_present_list.append(sidx)
    
    print(missing_from_all_test_list)
    print(len(missing_from_all_test_list))
    print(len(all_test_present_list))
    
    metadata_dict["missing_from_all_test_list"] = missing_from_all_test_list
    metadata_dict["all_test_present_list"] = all_test_present_list
    
    all_samples_question_list = []
    for sidx in all_samples_list:
        if sidx not in missing_question_list:
            all_samples_question_list.append(sidx)
            
    print(len(all_samples_question_list))
    print(len(missing_question_list))
    assert((len(all_samples_question_list) + len(missing_question_list)) == len(all_samples_list))
    metadata_dict["all_samples_question_list"] = all_samples_question_list
    
    
    test_split_src_image = []
    cnt = 0
    for sidx in all_samples_question_list:
        if sidx in all_test_present_list:
            test_split_src_image.append(sidx)
        else:
            cnt+=1

    print(len(test_split_src_image))
    assert((cnt+len(test_split_src_image)) == len(all_samples_question_list))

    missing_test_question_list = []
    cnt2 = 0
    for sidx in all_test_present_list:
        if sidx not in test_split_src_image:
            missing_test_question_list.append(sidx)
        else:
            cnt2+=1

    print(len(missing_test_question_list))
    assert((cnt2+len(missing_test_question_list)) == len(all_test_present_list))
    
    metadata_dict["test_split_src_image"] = test_split_src_image
    metadata_dict["missing_test_question_list"] = missing_test_question_list
    
    train_split_src_image = []

    cnt3 = 0
    for sidx in all_samples_question_list:
        if sidx not in test_split_src_image:
            train_split_src_image.append(sidx)
        else:
            cnt3+=1

    print(len(train_split_src_image))
    assert((cnt3+len(train_split_src_image))==len(all_samples_question_list))

    missing_train_question_list = []
    cnt4 = 0
    for sidx in all_samples_list:
        if sidx not in all_test_present_list and sidx not in train_split_src_image:
            missing_train_question_list.append(sidx)
        else:
            cnt4+=1
    print(len(missing_train_question_list))
    assert((cnt4+len(missing_train_question_list))==(len(all_samples_list)))
    
        
    metadata_dict["train_split_src_image"] = train_split_src_image
    metadata_dict["missing_train_question_list"] = missing_train_question_list
    
    print(len(train_split_src_image))
    print(len(test_split_src_image))
    print(f"total src images used for vqa ai2d: {len(all_samples_question_list)} = {len(train_split_src_image)+len(test_split_src_image)}")
    
    print(len(missing_train_question_list))
    print(len(missing_test_question_list))
    print(f"total src images skipped for vqa ai2d: {len(missing_question_list)} = {len(missing_train_question_list)+len(missing_test_question_list)}")

    test_split_new_images = []
    train_split_new_images = []
    image_dir = os.path.join(data_dir,"new_data","images")
    cnt5=0
    for file in os.listdir(image_dir):
        if int(file.split('.')[0]) in test_split_src_image:
            test_split_new_images.append(os.path.join(image_dir, file))

        else:
            train_split_new_images.append(os.path.join(image_dir, file))

        cnt5+=1
    
    print(len(train_split_new_images))
    print(len(test_split_new_images))
    print(cnt5)
    assert(cnt5==(len(train_split_new_images) + len(test_split_new_images)))
    
    metadata_dict["train_split_new_images"] = train_split_new_images    
    metadata_dict["test_split_new_images"] = test_split_new_images
    
    return metadata_dict


In [282]:
metadata_dict = create_metadata(data_dir)

[4189, 4191, 4325, 4420, 4703]
4903
981
[4325]
1
980
4400
503
889
91
3511
412
3511
889
total src images used for vqa ai2d: 4400 = 4400
412
91
total src images skipped for vqa ai2d: 503 = 503
12425
3076
15501


In [202]:
tlist = []
for file in os.listdir(os.path.join(data_dir, "images")):
    fidx = int(file.split(".")[0])
    tlist.append(fidx)
    
skipped_image_num_list = []
for i in range(4907+1):
    if i not in tlist:
        skipped_image_num_list.append(i)
print(skipped_image_num_list)

[4189, 4191, 4325, 4420, 4703]


In [153]:
all_samples_list = []
for file in os.listdir(os.path.join(data_dir, "images")):
    fidx = int(file.split(".")[0])
    all_samples_list.append(fidx)


In [154]:
# print(all_samples_list)
print(len(all_samples_list))

4903


In [155]:
all_test_list = list(pd.read_csv(os.path.join(data_dir,"ai2d_test_ids.csv")).to_numpy().reshape(-1))

In [171]:
print(type(all_test_list[0]))

<class 'numpy.int64'>


In [156]:
print(len(all_test_list))

981


In [157]:
missing_all_test_list = []
all_test_present_list = []
for sidx in all_test_split:
    if sidx not in all_samples_list:
        missing_all_test_list.append(sidx)
    else:
        all_test_present_list.append(sidx)

In [172]:
print(len(missing_all_test_list))
print(missing_all_test_list)
print(len(all_test_present_list))

1
[4325]
980


In [159]:
all_samples_question_list = []
for sidx in all_samples_list:
    if sidx not in missing_question_list:
        all_samples_question_list.append(sidx)

In [174]:
print(len(all_samples_question_list))
print(len(missing_question_list))
print(len(all_samples_question_list) + len(missing_question_list))

4400
503
4903


In [183]:
# all_test_present_list - test_split_src_image

In [189]:
test_split_src_image = []

cnt = 0
for sidx in all_samples_question_list:
    if sidx in all_test_present_list:
        test_split_src_image.append(sidx)
    else:
        cnt+=1

print(len(test_split_src_image))
print((cnt+len(test_split_src_image)) == len(all_samples_question_list))
        
missing_test_question_list = []
cnt2 = 0
for sidx in all_test_present_list:
    if sidx not in test_split_src_image:
        missing_test_question_list.append(sidx)
    else:
        cnt2+=1

print(len(missing_test_question_list))
print((cnt2+len(missing_test_question_list)) == len(all_test_present_list))
    

889
True
91
True


In [199]:
train_split_src_image = []

cnt3 = 0
for sidx in all_samples_question_list:
    if sidx not in test_split_src_image:
        train_split_src_image.append(sidx)
    else:
        cnt3+=1

print(len(train_split_src_image))
print((cnt3+len(train_split_src_image))==len(all_samples_question_list))

missing_train_question_list = []
cnt4 = 0
for sidx in all_samples_list:
    if sidx not in all_test_present_list and sidx not in train_split_src_image:
        missing_train_question_list.append(sidx)
    else:
        cnt4+=1
print(len(missing_train_question_list))
print((cnt4+len(missing_train_question_list))==(len(all_samples_list)))

3511
True
412
True


In [201]:
print(len(train_split_src_image))
print(len(test_split_src_image))
print(f"total src images used for vqa ai2d: {len(all_samples_question_list)} = {len(train_split_src_image)+len(test_split_src_image)}")

3511
889
total src images used for vqa ai2d: 4400 = 4400


In [200]:
print(len(missing_train_question_list))
print(len(missing_test_question_list))
print(f"total src images skipped for vqa ai2d: {len(missing_question_list)} = {len(missing_train_question_list)+len(missing_test_question_list)}")



412
91
total src images skipped for vqa ai2d: 503 = 503


In [307]:
test_split_new_images = []
train_split_new_images = []
image_dir = os.path.join(data_dir,"new_data","images")
cnt5=0
for file in os.listdir(image_dir):
    if int(file.split('.')[0]) in test_split_src_image:
        test_split_new_images.append(os.path.join(image_dir, file))
    
    elif int(file.split('.')[0]) in train_split_src_image:
        train_split_new_images.append(os.path.join(image_dir, file))
        
    cnt5+=1

In [308]:
print(train_split_new_images[:5])
print(len(train_split_new_images))
print(len(test_split_new_images))
print(cnt5)
print(cnt5==(len(train_split_new_images) + len(test_split_new_images)))

['/local1/rwadhawan7/ai2d/data/new_data/images/839.png-0.png', '/local1/rwadhawan7/ai2d/data/new_data/images/839.png-1.png', '/local1/rwadhawan7/ai2d/data/new_data/images/839.png-2.png', '/local1/rwadhawan7/ai2d/data/new_data/images/3524.png-1.png', '/local1/rwadhawan7/ai2d/data/new_data/images/3524.png-0.png']
12425
3076
15501
True


In [360]:
config = AutoConfig.from_pretrained("google/pix2struct-ai2d-base")

In [361]:
print(config)

Pix2StructConfig {
  "_commit_hash": "a1883ef7c6b2e731e8814bc12e6b49e65da4b60b",
  "_name_or_path": "google/pix2struct-ai2d-base",
  "architectures": [
    "Pix2StructForConditionalGeneration"
  ],
  "decoder_start_token_id": 0,
  "eos_token_id": 1,
  "initializer_factor": 1.0,
  "initializer_range": 0.02,
  "is_encoder_decoder": true,
  "is_vqa": false,
  "model_type": "pix2struct",
  "pad_token_id": 0,
  "text_config": {
    "_name_or_path": "",
    "add_cross_attention": false,
    "architectures": null,
    "bad_words_ids": null,
    "begin_suppress_tokens": null,
    "bos_token_id": null,
    "chunk_size_feed_forward": 0,
    "cross_attention_hidden_size": null,
    "d_ff": 2048,
    "d_kv": 64,
    "decoder_start_token_id": 0,
    "dense_act_fn": "gelu_new",
    "diversity_penalty": 0.0,
    "do_sample": false,
    "dropout_rate": 0.1,
    "early_stopping": false,
    "encoder_hidden_size": 768,
    "encoder_no_repeat_ngram_size": 0,
    "eos_token_id": 1,
    "exponential_decay_

In [362]:
# config_dict = config.to_dict()

In [363]:
# print(config_dict["use_bfloat16"])

False


In [365]:
# config.update({"use_bfloat16":True})

In [366]:
# print(config)

Pix2StructConfig {
  "_commit_hash": "a1883ef7c6b2e731e8814bc12e6b49e65da4b60b",
  "_name_or_path": "google/pix2struct-ai2d-base",
  "architectures": [
    "Pix2StructForConditionalGeneration"
  ],
  "decoder_start_token_id": 0,
  "eos_token_id": 1,
  "initializer_factor": 1.0,
  "initializer_range": 0.02,
  "is_encoder_decoder": true,
  "is_vqa": false,
  "model_type": "pix2struct",
  "pad_token_id": 0,
  "text_config": {
    "_name_or_path": "",
    "add_cross_attention": false,
    "architectures": null,
    "bad_words_ids": null,
    "begin_suppress_tokens": null,
    "bos_token_id": null,
    "chunk_size_feed_forward": 0,
    "cross_attention_hidden_size": null,
    "d_ff": 2048,
    "d_kv": 64,
    "decoder_start_token_id": 0,
    "dense_act_fn": "gelu_new",
    "diversity_penalty": 0.0,
    "do_sample": false,
    "dropout_rate": 0.1,
    "early_stopping": false,
    "encoder_hidden_size": 768,
    "encoder_no_repeat_ngram_size": 0,
    "eos_token_id": 1,
    "exponential_decay_

In [244]:
processor = AutoProcessor.from_pretrained("google/pix2struct-ai2d-base")
model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-ai2d-base")

In [325]:
model.config.use_bfloat16 = True

In [326]:
model.config

Pix2StructConfig {
  "_commit_hash": "a1883ef7c6b2e731e8814bc12e6b49e65da4b60b",
  "_name_or_path": "google/pix2struct-ai2d-base",
  "architectures": [
    "Pix2StructForConditionalGeneration"
  ],
  "decoder_start_token_id": 0,
  "eos_token_id": 1,
  "initializer_factor": 1.0,
  "initializer_range": 0.02,
  "is_encoder_decoder": true,
  "is_vqa": false,
  "model_type": "pix2struct",
  "pad_token_id": 0,
  "text_config": {
    "_name_or_path": "",
    "add_cross_attention": false,
    "architectures": null,
    "bad_words_ids": null,
    "begin_suppress_tokens": null,
    "bos_token_id": null,
    "chunk_size_feed_forward": 0,
    "cross_attention_hidden_size": null,
    "d_ff": 2048,
    "d_kv": 64,
    "decoder_start_token_id": 0,
    "dense_act_fn": "gelu_new",
    "diversity_penalty": 0.0,
    "do_sample": false,
    "dropout_rate": 0.1,
    "early_stopping": false,
    "encoder_hidden_size": 768,
    "encoder_no_repeat_ngram_size": 0,
    "eos_token_id": 1,
    "exponential_decay_

In [311]:
# processor

In [245]:
test_image = Image.open(os.path.join(data_dir,"new_data","images","7.png-0.png"))
# test_image = Image.open(os.path.join(data_dir,"images","7.png")
question_text = ""

In [246]:
inputs = processor(images=test_image, text=question_text, return_tensors="pt")

In [247]:
generated_ids = model.generate(**inputs, max_new_tokens=50)

In [248]:
print(generated_ids)

tensor([[    0,   666,  1585, 48156,  1885,     1]])


In [249]:
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

In [250]:
print(generated_text)

Utricularia


In [251]:
processor

Pix2StructProcessor:
- image_processor: Pix2StructImageProcessor {
  "do_convert_rgb": true,
  "do_normalize": true,
  "image_processor_type": "Pix2StructImageProcessor",
  "is_vqa": true,
  "max_patches": 2048,
  "patch_size": {
    "height": 16,
    "width": 16
  },
  "processor_class": "Pix2StructProcessor"
}

- tokenizer: T5TokenizerFast(name_or_path='google/pix2struct-ai2d-base', vocab_size=50344, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '<pad>', 'additional_special_tokens': ['<extra_id_0>', '<extra_id_1>', '<extra_id_2>', '<extra_id_3>', '<extra_id_4>', '<extra_id_5>', '<extra_id_6>', '<extra_id_7>', '<extra_id_8>', '<extra_id_9>', '<extra_id_10>', '<extra_id_11>', '<extra_id_12>', '<extra_id_13>', '<extra_id_14>', '<extra_id_15>', '<extra_id_16>', '<extra_id_17>', '<extra_id_18>', '<extra_id_19>', '<extra_id_20>', '<extra_id_21>', '<extra_

In [256]:
def evaluate_model_AI2D(image_list, samples_dict, model, processor, max_patches):
    question_text = ""
    cnt = 0
    correct_cnt = 0
    start_time = time.time()
    model = model.to('cuda')
    for image_file in image_list:
        
        image_idx = (image_file.split('.')[0]+ "."+ image_file.split('.')[1]).split("/")[-1]
        
        test_image = Image.open(image_file)
        
        inputs = processor(images=test_image, text=question_text, return_tensors="pt", max_patches=max_patches).to('cuda')
        generated_ids = model.generate(**inputs, max_new_tokens=50)
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        
        
#         print(generated_text)
#         print(samples_dict[image_idx]["raw_output"])
        
        if generated_text==samples_dict[image_idx]["raw_output"]:
            correct_cnt+=1
        
        cnt+=1
        if cnt == 20:
            break
    
    end_time = time.time()
    print(end_time-start_time)
    print(f"exact match accuracy {(1.0*correct_cnt)/cnt}")
    
#             break
    

In [257]:
# tstr = "0.png-0.png"
# print(tstr.split('.')[0]+ "."+ tstr.split('.')[1])

In [283]:
evaluate_model_AI2D(metadata_dict["test_split_new_images"], samples_dict, model, processor, max_patches=4096)

12.683918237686157
exact match accuracy 0.2


In [None]:
## cpu
# 217.14571237564087
# exact match accuracy 0.2

In [None]:
## gpu
# 13.092732191085815
# exact match accuracy 0.2

In [264]:
base_processor = AutoProcessor.from_pretrained("google/pix2struct-base")

In [298]:
def create_text_data(samples_dict, max_new_tokens=128, processor=None):
    
    text_data = {
        "targets": [],
        "raw_output": [],
        "sample_list":[]
    }

    sample_keys = list(samples_dict.keys())
    sample_keys.sort()
    
    text_data["sample_list"] = sample_keys

    for key in sample_keys:  
        text_data["raw_output"].append(samples_dict[key]["raw_output"])
        
    text_data["targets"] =  processor(text=text_data["raw_output"], 
                              padding=True, 
                              truncation=True, 
                              return_tensors="pt", 
                              add_special_tokens=True, 
                              max_length=max_new_tokens)
    
    return text_data

In [299]:
text_data = create_text_data(samples_dict, max_new_tokens=52, processor = base_processor)

In [367]:
text_data["sample_list"]

['0.png-0',
 '0.png-1',
 '10.png-0',
 '10.png-1',
 '100.png-0',
 '100.png-1',
 '100.png-2',
 '1000.png-0',
 '1000.png-1',
 '1000.png-2',
 '1001.png-0',
 '1001.png-1',
 '1001.png-2',
 '1002.png-0',
 '1002.png-1',
 '1003.png-0',
 '1003.png-1',
 '1004.png-0',
 '1004.png-1',
 '1005.png-0',
 '1005.png-1',
 '1005.png-2',
 '1005.png-3',
 '1006.png-0',
 '1007.png-0',
 '1007.png-1',
 '1007.png-2',
 '1008.png-0',
 '1009.png-0',
 '101.png-0',
 '101.png-1',
 '101.png-2',
 '101.png-3',
 '101.png-4',
 '1010.png-0',
 '1010.png-1',
 '1010.png-2',
 '1012.png-0',
 '1012.png-1',
 '1013.png-0',
 '1013.png-1',
 '1014.png-0',
 '1014.png-1',
 '1015.png-0',
 '1016.png-0',
 '1017.png-0',
 '1017.png-1',
 '1018.png-0',
 '102.png-0',
 '1020.png-0',
 '1020.png-1',
 '1020.png-2',
 '1021.png-0',
 '1021.png-1',
 '1022.png-0',
 '1022.png-1',
 '1022.png-2',
 '1022.png-3',
 '1023.png-0',
 '1024.png-0',
 '1024.png-1',
 '1025.png-0',
 '1025.png-1',
 '1026.png-0',
 '1026.png-1',
 '1026.png-2',
 '1026.png-3',
 '1028.png-0',

In [368]:
metadata_dict["test_split_new_images"]

['/local1/rwadhawan7/ai2d/data/new_data/images/4875.png-1.png',
 '/local1/rwadhawan7/ai2d/data/new_data/images/4875.png-0.png',
 '/local1/rwadhawan7/ai2d/data/new_data/images/4683.png-2.png',
 '/local1/rwadhawan7/ai2d/data/new_data/images/4683.png-0.png',
 '/local1/rwadhawan7/ai2d/data/new_data/images/4683.png-1.png',
 '/local1/rwadhawan7/ai2d/data/new_data/images/1422.png-1.png',
 '/local1/rwadhawan7/ai2d/data/new_data/images/1422.png-2.png',
 '/local1/rwadhawan7/ai2d/data/new_data/images/1422.png-0.png',
 '/local1/rwadhawan7/ai2d/data/new_data/images/4874.png-1.png',
 '/local1/rwadhawan7/ai2d/data/new_data/images/4874.png-0.png',
 '/local1/rwadhawan7/ai2d/data/new_data/images/2676.png-3.png',
 '/local1/rwadhawan7/ai2d/data/new_data/images/2676.png-1.png',
 '/local1/rwadhawan7/ai2d/data/new_data/images/2676.png-0.png',
 '/local1/rwadhawan7/ai2d/data/new_data/images/2676.png-2.png',
 '/local1/rwadhawan7/ai2d/data/new_data/images/4719.png-0.png',
 '/local1/rwadhawan7/ai2d/data/new_data/

In [373]:
(metadata_dict["test_split_new_images"][5].split('.')[0] + "." + metadata_dict["test_split_new_images"][5].split('.')[1]).split('/')[-1]

'1422.png-1'

In [379]:
processor.image_processor.is_vqa = False

In [380]:
processor

Pix2StructProcessor:
- image_processor: Pix2StructImageProcessor {
  "do_convert_rgb": true,
  "do_normalize": true,
  "hello": true,
  "image_processor_type": "Pix2StructImageProcessor",
  "is_vqa": false,
  "max_patches": 2048,
  "patch_size": {
    "height": 16,
    "width": 16
  },
  "processor_class": "Pix2StructProcessor"
}

- tokenizer: T5TokenizerFast(name_or_path='google/pix2struct-ai2d-base', vocab_size=50344, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '<pad>', 'additional_special_tokens': ['<extra_id_0>', '<extra_id_1>', '<extra_id_2>', '<extra_id_3>', '<extra_id_4>', '<extra_id_5>', '<extra_id_6>', '<extra_id_7>', '<extra_id_8>', '<extra_id_9>', '<extra_id_10>', '<extra_id_11>', '<extra_id_12>', '<extra_id_13>', '<extra_id_14>', '<extra_id_15>', '<extra_id_16>', '<extra_id_17>', '<extra_id_18>', '<extra_id_19>', '<extra_id_20>', '<extra