<a href="https://colab.research.google.com/github/phaethonp/we-ai/blob/main/pic2struct.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Pix2Struct is an image-to-text model for visually-situated language understanding, meaning text that appears inside images such as screenshots, documents, and diagrams. It is pretrained by learning to parse masked screenshots of web pages into simplified HTML.

Features of Pix2Struct:

1. Pretrained on web screenshots: Uses the rich, diverse data available on the web for pretraining.

2. Fine-tuned for specific tasks: Fine-tuned checkpoints are available for tasks such as image captioning and visual question answering.

3. Flexible language and vision input integration: Language prompts, such as questions, are rendered directly on top of the input image.

4. Variable-resolution input representation: Capable of handling inputs of varying resolutions.

5. Single model, multiple domains: Effective on tasks across different domains like documents, illustrations, user interfaces, and natural images.

6. Contributed by the open-source community: Maintained and improved by a global community of developers and researchers.
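
The variable-resolution input in item 4 can be illustrated with a small calculation. The sketch below assumes the image is rescaled, with its aspect ratio roughly preserved, so that a grid of fixed-size patches fits within a patch budget. The function name `patch_grid` and the default values (16-pixel patches, a budget of 2048 patches) are assumptions chosen for illustration, not values taken from this notebook.

```python
import math

def patch_grid(height, width, patch_size=16, max_patches=2048):
    """Return (rows, cols) of the patch grid for an image of the given size.

    Illustrative sketch: pick the largest scale such that
    rows * cols <= max_patches while keeping the aspect ratio.
    """
    # Scale factor that makes the rescaled image cover about max_patches patches
    scale = math.sqrt(max_patches * (patch_size / height) * (patch_size / width))
    # Round down so the budget is never exceeded; keep at least one patch per axis
    rows = max(min(math.floor(scale * height / patch_size), max_patches), 1)
    cols = max(min(math.floor(scale * width / patch_size), max_patches), 1)
    return rows, cols

# A wide screenshot gets more columns than rows, within the same budget
rows, cols = patch_grid(height=300, width=800)
print(rows, cols, rows * cols)
```

Under this scheme a wide image and a tall image of the same area receive the same number of patches, distributed differently across rows and columns, instead of being squashed into a fixed square.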

In [None]:
!pip install git+https://github.com/huggingface/transformers.git
!pip install requests

In [None]:
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
from matplotlib import pyplot as plt
from google.colab import files
import PIL.Image  # import the submodule explicitly so PIL.Image.open is available


In [None]:
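# Upload one or more image files from your machine (Colab only)
def upload_files():
    uploaded = files.upload()  # maps filename -> file contents (bytes)
    images = []
    for filename, data in uploaded.items():
        # Save the upload to disk, closing the file before PIL reopens it
        with open(filename, 'wb') as f:
            f.write(data)
        images.append(PIL.Image.open(filename))
    return images

images = upload_files()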

In [None]:
model_urls = {
    "textcaps": "google/pix2struct-textcaps-large", # Finetuned on TextCaps
    "screen2words": "google/pix2struct-screen2words-large", # Finetuned on Screen2Words
    "widgetcaption": "google/pix2struct-widget-captioning-large", # Finetuned on Widget Captioning (captioning a UI component on a screen)
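    "infographics": "google/pix2struct-infographics-vqa-large", # Finetuned on InfographicVQA (questions about infographics)
    "docvqa": "google/pix2struct-docvqa-large", # Finetuned on DocVQA (questions about document images)
    "ai2d": "google/pix2struct-ai2d-large", # Finetuned on AI2D (questions about scientific diagrams)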
}

import torch

# Use the GPU when one is available; fall back to CPU otherwise
device = "cuda" if torch.cuda.is_available() else "cpu"

# Cache of loaded (model, processor) pairs, keyed by short model name
models = {}

def get_model(model_name):
    if model_name not in models:
        print(f"Loading {model_name} from {model_urls[model_name]}")
        model = Pix2StructForConditionalGeneration.from_pretrained(model_urls[model_name]).to(device)
        processor = Pix2StructProcessor.from_pretrained(model_urls[model_name])
        models[model_name] = (model, processor)
    return models[model_name]

def run_model(model_name, text=None):
    # Default question, used only by VQA checkpoints
    text = text or "where is the main button on this page?"
    model, processor = get_model(model_name)
    if processor.image_processor.is_vqa:
        # VQA checkpoints expect the question to be passed with the image
        print(f"Adding prompt for VQA model: '{text}'")
        inputs = processor(images[0], return_tensors="pt", text=text).to(device)
    else:
        inputs = processor(images[0], return_tensors="pt").to(device)
    predictions = model.generate(**inputs)
    print(f"Name: '{model_name}'")
    plt.imshow(images[0])
    plt.axis("off")
    plt.show()
    print(f"Output: '{processor.decode(predictions[0], skip_special_tokens=True)}'")

In [None]:
run_model("screen2words", text="what does this screen say?")

In [None]:
run_model("widgetcaption", text="what are the buttons on this page?")

In [None]:
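# If you run out of memory, drop the cached models and release what they held.
# The cache is emptied rather than deleted so that get_model() keeps working afterwards.
# I should probably switch the cache to some kind of LRU cache.
import gc
import torch
models.clear()
gc.collect()
torch.cuda.empty_cache()  # return freed GPU memory to the driver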
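
The model cache in this notebook grows every time a new checkpoint is loaded, which is why it has to be cleared by hand. One possible alternative is a small LRU (least recently used) cache that evicts the oldest entry once a capacity is reached. The sketch below is only an illustration: `LRUModelCache`, its `loader` argument, and the default capacity are hypothetical names and values, and the stand-in loader in the usage lines returns strings instead of real models.

```python
from collections import OrderedDict

class LRUModelCache:
    """Keep at most `capacity` loaded entries, evicting the least recently used."""

    def __init__(self, loader, capacity=2):
        self.loader = loader          # callable: name -> loaded object
        self.capacity = capacity
        self._entries = OrderedDict() # insertion order tracks recency of use

    def __contains__(self, name):
        return name in self._entries

    def get(self, name):
        if name in self._entries:
            # Cache hit: mark the entry as most recently used
            self._entries.move_to_end(name)
            return self._entries[name]
        # Cache miss: load, store, then evict the oldest entries if over capacity
        value = self.loader(name)
        self._entries[name] = value
        while len(self._entries) > self.capacity:
            self._entries.popitem(last=False)
        return value

# Usage with a stand-in loader; in the notebook the loader would build
# the (model, processor) pair for the given name instead.
cache = LRUModelCache(loader=lambda name: f"loaded:{name}", capacity=2)
cache.get("textcaps")
cache.get("docvqa")
cache.get("textcaps")      # refreshes "textcaps"
cache.get("ai2d")          # evicts "docvqa", the least recently used entry
print("docvqa" in cache, "textcaps" in cache)
```

Evicting an entry only drops the cache's reference to it. With real models, GPU memory is released once no other references remain and the garbage collector has run.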