In [136]:
import os
import numpy as np
import torch
import os
import re
import json
import argparse
import pandas as pd
import random
from transformers import T5Tokenizer, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer, T5ForConditionalGeneration
from rich.table import Column, Table
from rich import box
from rich.console import Console
console = Console(record=True)
from torch import cuda
import nltk
import evaluate
import pdfkit
from pdf2image import convert_from_path
import glob, sys, fitz
from torch.utils.data import Dataset, DataLoader
from collections import namedtuple

from PIL import Image, ImageDraw, ImageOps, ImageFont
import requests
from transformers import AutoProcessor, Pix2StructForConditionalGeneration, Pix2StructConfig

import textwrap
import math
import string
import pickle
import matplotlib.pyplot as plt
from rouge import Rouge
from typing import Any, Callable, Iterable, List, Optional

# AI2D

In [37]:
def read_json_file(filepath):
    """Load and return the parsed contents of a JSON file.

    Args:
        filepath: Path of the JSON file to read.

    Returns:
        The deserialized JSON value produced by ``json.load``.
    """
    # JSON text is UTF-8 by specification; being explicit avoids depending on
    # the platform's default text encoding.
    with open(filepath, encoding="utf-8") as json_file:
        return json.load(json_file)


In [38]:
# problem_list = read_json_file(os.path.join(os.getcwd(), "data", "scienceqa", "problems.json"))

In [115]:
def load_pickle_dataset(save_dir, source=""):
    """Unpickle ``<save_dir>/<source>.pkl`` and return the stored object.

    Args:
        save_dir: Directory that contains the pickle file.
        source: File name without the ``.pkl`` extension.

    Returns:
        Whatever object was stored in the pickle file.
    """
    # NOTE(review): unpickling can execute arbitrary code; only point this at
    # files produced by a trusted source.
    target_path = os.path.join(save_dir, "{}.pkl".format(source))
    with open(target_path, mode="rb") as handle:
        return pickle.load(handle)

In [39]:
def render_text_on_bounding_box(
    text: str,
    bounding_box: Iterable[Iterable[int]],
    image: Image.Image,
    font_path: str):
    """Blank out a bounding box on ``image`` and draw ``text`` centred in it.

    The image is modified in place. The font size is grown one step at a time
    until the next size would no longer fit within 95% of the box in either
    dimension.

    Args:
        text: Text to draw inside the box.
        bounding_box: Two ``(x, y)`` corner points, ``((x0, y0), (x1, y1))``.
        image: Image to draw on (mutated in place).
        font_path: Path of the TrueType font file to render with.

    Returns:
        None.
    """
    draw = ImageDraw.Draw(image)
    (x0, y0), (x1, y1) = bounding_box

    # Paint the box white first so the original content is hidden.
    draw.rectangle(xy=[(x0, y0), (x1, y1)], fill=(255, 255, 255, 255))

    fontsize = 1

    def _can_increment_font(ratio=0.95):
        next_font = ImageFont.truetype(
            font_path, encoding="UTF-8", size=fontsize + 1)
        # `FreeTypeFont.getsize` was removed in Pillow 10; the right/bottom
        # edges of `getbbox` give the same extent measured from the origin.
        _, _, width, height = next_font.getbbox(text)
        if width <= 0 or height <= 0:
            # Text with no measurable extent (e.g. an empty string) would pass
            # the fit test at every size and never let the loop terminate.
            return False
        return width < ratio * (x1 - x0) and height < ratio * (y1 - y0)

    while _can_increment_font():
        fontsize += 1
    font = ImageFont.truetype(font_path, encoding="UTF-8", size=fontsize)

    # anchor="mm" centres the text on the midpoint of the box.
    draw.text(
        xy=((x0 + x1)/2, (y0 + y1)/2),
        text=text,
        font=font,
        fill="black",
        anchor="mm"
    )
    

In [141]:
def render_text(text: str,
                text_size: int = 36,
                text_color: str = "black",
                background_color: str = "white",
                left_padding: int = 5,
                right_padding: int = 5,
                top_padding: int = 5,
                bottom_padding: int = 5,
                font_path: str = "") -> Image.Image:
    """Draw ``text`` onto a freshly created image and return that image.

    The text is wrapped at 80 characters per line. The image is sized to the
    measured text extent plus the requested padding on each side.

    Args:
        text: Text to render.
        text_size: Font size handed to ``ImageFont.truetype``.
        text_color: Fill colour of the text.
        background_color: Fill colour of the canvas.
        left_padding: Pixels added to the left of the text.
        right_padding: Pixels added to the right of the text.
        top_padding: Pixels added above the text.
        bottom_padding: Pixels added below the text.
        font_path: Path of the TrueType font file to render with.

    Returns:
        A new RGB image containing the wrapped text.
    """
    # Break the text into lines of at most 80 characters.
    wrapped_text = "\n".join(textwrap.wrap(text, width=80))

    font = ImageFont.truetype(font_path, encoding="UTF-8", size=text_size)

    # Measure on a throwaway 1x1 canvas; only the right and bottom edges of
    # the bounding box are used as the text extent.
    scratch = ImageDraw.Draw(Image.new("RGB", (1, 1), background_color))
    extent = scratch.textbbox((0, 0), wrapped_text, font)
    text_width, text_height = extent[2], extent[3]

    # Real canvas: the text extent plus padding on every side.
    canvas_size = (
        text_width + left_padding + right_padding,
        text_height + top_padding + bottom_padding,
    )
    image = Image.new("RGB", canvas_size, background_color)
    ImageDraw.Draw(image).text(
        xy=(left_padding, top_padding),
        text=wrapped_text,
        fill=text_color,
        font=font)

    return image

In [142]:
def render_header(image: Image.Image, header: str, font_path: str) -> Image.Image:
    """Stack a rendered text header on top of ``image`` in a new image.

    Both the header and the image are scaled to the wider of the two widths,
    keeping their aspect ratios (heights are truncated to integers).

    Args:
        image: Image to place below the header.
        header: Text to render as the header.
        font_path: Path of the TrueType font file passed to ``render_text``.

    Returns:
        A new RGB image with the header at the top and ``image`` beneath it.
    """
    banner = render_text(header, font_path=font_path)
    target_width = max(banner.width, image.width)

    # Scale each part to the shared width, preserving its aspect ratio.
    scaled_image_height = int(image.height *  (target_width / image.width))
    scaled_banner_height = int(
        banner.height * (target_width / banner.width))

    canvas_size = (target_width, scaled_image_height + scaled_banner_height)
    canvas = Image.new("RGB", canvas_size, "white")

    resized_banner = banner.resize((target_width, scaled_banner_height))
    canvas.paste(resized_banner, (0, 0))
    resized_image = image.resize((target_width, scaled_image_height))
    canvas.paste(resized_image, (0, scaled_banner_height))

    return canvas

In [143]:
def save_dataset(dataset, save_dir="", filename=""):
    """Pickle ``dataset`` to ``<save_dir>/<filename>``.

    Args:
        dataset: Any picklable object.
        save_dir: Destination directory; created if it does not exist yet.
            An empty string writes relative to the current directory.
        filename: Name of the pickle file inside ``save_dir``.

    Returns:
        None.
    """
    # Create the target directory up front so a first run does not fail with
    # FileNotFoundError.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    pickle_filename = os.path.join(save_dir, filename)
    with open(pickle_filename, 'wb') as f:
        pickle.dump(dataset, f, protocol=pickle.HIGHEST_PROTOCOL)

In [144]:
def convert_one_question_AI2D(input_path: str, data_dir: str, font_path: str, skip_image_gen: bool):
    """Convert one AI2D question file into per-question sample records.

    Reads ``<data_dir>/<input_path>`` and, for every question in it, builds a
    record keyed by the question's ``questionId`` holding the source image
    name, the correct answer text, the header text and the ``abcLabel`` flag.
    When image rendering is enabled, one PNG per question is written to
    ``<data_dir>/new_data/images/<questionId>.png``.

    Args:
        input_path: Question file path relative to ``data_dir``.
        data_dir: Root directory holding ``images/``, ``annotations/`` and the
            question files.
        font_path: Path of the TrueType font file used for rendering.
        skip_image_gen: NOTE(review): images are rendered and saved only when
            this flag is truthy, which is the opposite of what its name
            suggests. The behaviour is kept because the existing call in this
            notebook passes ``True`` and relies on the images being produced;
            confirm the intended meaning before flipping it.

    Returns:
        A ``(samples_dict, qid)`` tuple. ``qid`` is ``-1`` when the file
        contains questions; otherwise ``samples_dict`` is empty and ``qid`` is
        the file's ``imageName`` so the caller can record it as missing.
    """
    samples_dict = {}

    data = read_json_file(os.path.join(data_dir, input_path)) # till ai2d folder + question path

    if not data["questions"]:
        # No questions for this diagram: hand the image name back as the
        # non-(-1) marker the caller collects.
        return samples_dict, data["imageName"]

    annotation = read_json_file(os.path.join(data_dir, "annotations", f"{data['imageName']}.json"))

    if skip_image_gen:
        image_filepath = os.path.join(data_dir, "images", data["imageName"])
        # Copy the pixels into memory so the file handle is released at once.
        with Image.open(image_filepath) as opened_image:
            image = opened_image.copy()
        image_with_placeholders = image.copy()

        # Overwrite every annotated text region with its replacement text.
        for v in annotation["text"].values():
            render_text_on_bounding_box(
                text=v["replacementText"],
                bounding_box=v["rectangle"],
                image=image_with_placeholders,
                font_path = font_path)

        # Output directory for the rendered images; created once per file.
        save_dir = os.path.join(data_dir, "new_data","images")
        os.makedirs(save_dir, exist_ok=True)

    for k, v in data["questions"].items():
        question_id = v["questionId"]

        # Answer options rendered as "(a) ... (b) ...".
        options = " ".join(
            f"({string.ascii_lowercase[i]}) {a}"
            for i, a in enumerate(v["answerTexts"])
        )
        header_text = f"{k} {options}"

        if skip_image_gen:
            # Questions flagged with abcLabel use the diagram whose text
            # regions were replaced; the others use the untouched image.
            image_with_header = render_header(
                image=image_with_placeholders if v["abcLabel"] else image,
                header=header_text,
                font_path = font_path
            )
            image_with_header.save(os.path.join(save_dir, f"{question_id}.png"))

        # Expected output: the text of the correct answer option.
        samples_dict[question_id] = {
            "src_image_name": data["imageName"],
            "raw_output": v["answerTexts"][v["correctAnswer"]],
            "header_text": header_text,
            "abcLabel": v["abcLabel"],
        }

    return samples_dict, -1

In [145]:
def convert_AI2D(data_dir: str, font_path: str, skip_image_gen:bool=False):
    """Convert every AI2D question file and persist the results.

    Processes each file under ``<data_dir>/questions`` with
    ``convert_one_question_AI2D``, merges the per-file records, and pickles
    the merged records and the list of files without questions into
    ``<data_dir>/new_data``.

    Args:
        data_dir: Root directory of the dataset.
        font_path: Path of the TrueType font file used for rendering.
        skip_image_gen: Forwarded unchanged to ``convert_one_question_AI2D``.

    Returns:
        A ``(samples_dict, missing_question_list)`` tuple.
    """
    samples_dict = {}
    missing_question_list = []

    # Sorted so that the resulting dict order is reproducible across runs.
    for file in sorted(os.listdir(os.path.join(data_dir,"questions"))):
        filepath = os.path.join("questions", file)
        one_question_sample_dict, qid = convert_one_question_AI2D(filepath, data_dir, font_path, skip_image_gen)
        samples_dict.update(one_question_sample_dict)

        # Any value other than -1 marks a file that had no questions.
        if qid!=-1:
            missing_question_list.append(qid)

    # Make sure the output directory exists before writing either pickle.
    output_dir = os.path.join(data_dir,"new_data")
    os.makedirs(output_dir, exist_ok=True)

    save_dataset(
        samples_dict,
        save_dir = output_dir,
        filename = "samples_dict.pkl"
    )

    # Store the missing list itself under its own file name.
    save_dataset(
        missing_question_list,
        save_dir = output_dir,
        filename = "missing_question_list.pkl"
    )
    return samples_dict, missing_question_list


In [146]:
# Dataset root directory and the TrueType font file handed to the rendering
# helpers.
data_dir = "../ai2d"
font_path = "../ai2d/arial.ttf"

In [None]:
# Convert all question files and pickle the results under <data_dir>/new_data.
# NOTE(review): convert_one_question_AI2D renders and saves images only when
# skip_image_gen is truthy, so passing True here is what produces the images,
# despite the flag's name.
sample_dict, missing_question_list = convert_AI2D(data_dir, font_path, skip_image_gen=True)

In [130]:
# Flatten every cell of the test-id CSV into one flat list.
# NOTE(review): read_csv treats the first row as a header by default; if the
# CSV has no header row its first id is dropped -- confirm, and pass
# header=None if so.
test_split_src_image = list(pd.read_csv("../ai2d/ai2d_test_ids.csv").to_numpy().reshape(-1))

In [138]:
# Reload the records from ../ai2d/new_data/samples_dict.pkl, the file that
# convert_AI2D writes.
samples_dict = load_pickle_dataset("../ai2d/new_data","samples_dict")

In [133]:
# Every id in 1..4907 that is not a test id goes to the training split.
# A set makes each membership check constant time instead of a list scan.
# NOTE(review): the range starts at 1; confirm whether image id 0 exists and
# should be considered.
test_id_lookup = set(test_split_src_image)
train_split_src_image = [i for i in range(1, 4908) if i not in test_id_lookup]

In [134]:
# Sanity check: sizes of the train and test id lists.
print(len(train_split_src_image))
print(len(test_split_src_image))

3927
981


In [None]:
# Split the rendered question images into test/train by source-image id.
test_split_new_images = []
train_split_new_images = []
# convert_one_question_AI2D writes its images to <data_dir>/new_data/images.
image_dir = "../ai2d/new_data/images"
# Set lookup keeps the per-file membership test constant time.
test_source_ids = set(test_split_src_image)
for file in os.listdir(image_dir):
    # The text before the first "." is expected to be the numeric id of the
    # source image (e.g. "7.png-0.png" -> 7).
    source_id_text = file.split('.')[0]
    if not source_id_text.isdigit():
        # Skip stray entries (hidden files, checkpoints) rather than crash.
        continue

    if int(source_id_text) in test_source_ids:
        test_split_new_images.append(os.path.join(image_dir, file))
    else:
        train_split_new_images.append(os.path.join(image_dir, file))

In [86]:
# Load the processor and the model from the "google/pix2struct-ai2d-base"
# checkpoint.
processor = AutoProcessor.from_pretrained("google/pix2struct-ai2d-base")
model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-ai2d-base")

In [105]:
# Single rendered question image used for a manual smoke test of the model.
test_image = Image.open("../ai2d/new_data/images/7.png-0.png")
# test_image = Image.open("../ai2d/0.png-0.png")
# The text prompt is left empty.
# NOTE(review): presumably because the question is already rendered into the
# image header -- confirm.
question_text = ""

In [106]:
# Run the processor on the image and text, requesting PyTorch tensors ("pt").
inputs = processor(images=test_image, text=question_text, return_tensors="pt")

In [107]:
# Generate output token ids, limited to 50 new tokens.
generated_ids = model.generate(**inputs, max_new_tokens=50)

In [108]:
# Inspect the raw token ids before decoding.
print(generated_ids)

tensor([[    0,   666,  1585, 48156,  1885,     1]])


In [113]:
# Decode the ids to text without special tokens; keep the first (only) item.
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

In [114]:
# Show the decoded model answer.
print(generated_text)

Utricularia


In [None]:
def evaluate_model_AI2D(image_list, data_dict, model, processor):
    """Run the model on each image file and collect the decoded outputs.

    Args:
        image_list: Iterable of image file paths to run the model on.
        data_dict: Currently unused; kept so existing calls keep working.
        model: Object whose ``generate(**inputs, max_new_tokens=...)`` returns
            token ids.
        processor: Callable that builds the model inputs from an image and a
            text prompt, and whose ``batch_decode`` turns token ids into text.

    Returns:
        A dict mapping each image path to the text generated for it.
    """
    # The text prompt is left empty.
    # NOTE(review): presumably the question is already rendered into each
    # image -- confirm against how image_list is produced.
    question_text = ""
    predictions = {}
    for image_file in image_list:
        # Close each image file as soon as the processor has consumed it.
        with Image.open(image_file) as test_image:
            inputs = processor(images=test_image, text=question_text, return_tensors="pt")
        generated_ids = model.generate(**inputs, max_new_tokens=50)
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        predictions[image_file] = generated_text
    return predictions

In [None]:
def process_AI2D():
    """Placeholder with no behaviour yet; always returns ``None``."""
    return None

In [None]:
def get_train_test_split():
    """Placeholder with no behaviour yet; always returns ``None``."""
    return None

In [None]:
def evaluate_AI2D():
    """Placeholder with no behaviour yet; always returns ``None``."""
    return None