# Test Pix2Struct model on Synthetic Bootstrap dataset (mini version)

## Setup Envirnoment

In [2]:
pip install transformers==4.36.2


Defaulting to user installation because normal site-packages is not writeable
Collecting tokenizers<0.19,>=0.14 (from transformers==4.36.2)
  Using cached tokenizers-0.15.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.6 MB)
Installing collected packages: tokenizers
  Attempting uninstall: tokenizers
    Found existing installation: tokenizers 0.21.0
    Uninstalling tokenizers-0.21.0:
      Successfully uninstalled tokenizers-0.21.0
Successfully installed tokenizers-0.15.2

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.1.2[0m[39;49m -> [0m[32;49m25.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49m/opt/software-current/2023.06/x86_64/generic/software/Python/3.11.3-GCCcore-12.3.0/bin/python -m pip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [8]:
# from pathlib import Path

# # Directory where the HTML files are located
# html_dir = Path("data/WebSight_test_data/images")

# # Select all *.html files that are named with 4-digit or more numbers
# html_files = sorted(html_dir.glob("[0-9][0-9][0-9][0-9]*.png"))

# # Rename to 5-digit zero-padded format
# for file in html_files:
#     number_part = file.stem  # e.g., "1000"
#     if number_part.isdigit():
#         new_name = f"{int(number_part):05}.png"
#         new_path = html_dir / new_name
#         file.rename(new_path)
#         print(f"Renamed {file.name} → {new_name}")


## Import necessary libraries

In [1]:
import os
import zipfile
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import re
from transformers import Pix2StructForConditionalGeneration, AutoProcessor
import torch
from torch.nn import functional as F
from pathlib import Path
from nltk import edit_distance
import numpy as np
from tqdm import tqdm
from nltk.translate.bleu_score import corpus_bleu, sentence_bleu, SmoothingFunction
from torch.utils.data import random_split
import random

## Define variables and parameters

In [2]:
# FOLDER_CHECKPOINTS = 'checkpoints'
DATASET_NAME = 'WebSight/'
# ZIP_NAME = DATASET_NAME + '.zip'
DESTINATION_FOLDER= 'data/'
DATASET_FOLDER = DESTINATION_FOLDER + DATASET_NAME
HTML_FILES_FOLDER = DATASET_FOLDER + "html/"
OUTPUT_FOLDER = 'experiment/'

EXPERIMENT_NAME = "experiment"

MAX_SENTENCE_LEN = 4096

CHUNK_LENGTH = 1024
CONTEXT_OVERLAP_LENGTH = 256

MAX_PATCHES = 1024

DEBUG = False
VERBOSE = True

BATCH_SIZE = 10

TRAIN_SET_PERCENTAGE = 0.88
VALID_SET_PERCENTAGE = 0.02 # Use 20 for validation
# TEST_SET_PERCENTAGE is 1 - TRAIN_SET_PERCENTAGE - VALID_SET_PERCENTAGE # Use 100 for test

RANDOM_SEED = 123

LOAD_FROM_CHECKPOINT = True
LAST_CHECKPOINT_NAME = "model/SynthBootstrap_epoch[34]_bleu[0.89].pth"

# WebSight_epoch[29]_bleu[0.67].pth

In [3]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
MAX_N_CHUNKS_PER_SENTENCE = 1 + (MAX_SENTENCE_LEN - CHUNK_LENGTH) // (CHUNK_LENGTH - CONTEXT_OVERLAP_LENGTH)
print("MAX_N_CHUNKS_PER_SENTENCE", MAX_N_CHUNKS_PER_SENTENCE)

MAX_N_CHUNKS_PER_SENTENCE 5


## Load Model and Processor

In [5]:
repo_id = "google/pix2struct-base"

processor = AutoProcessor.from_pretrained(repo_id)
model = Pix2StructForConditionalGeneration.from_pretrained(repo_id, is_encoder_decoder=True)

## Create Dataset class

### Preprocessing functions

In [6]:
def round_floats_in_text(text, precision=0):
    # match float numbers with 2 or more decimal places in the text
    pattern = r"\b\d+\.\d{2,}\b"

    def replace(match):
        float_number = float(match.group())
        return f"{float_number:.{precision}f}"

    text = re.sub(pattern, replace, text)
    return text

In [7]:
def remove_html_comments(text):
    # match html comments
    pattern = r"<!--.*?-->"

    text = re.sub(pattern, '', text, flags=re.DOTALL)
    return text

In [8]:
def preprocess_html_file(html_text):
    text_cleaned = html_text.replace('\n', ' ')
    text_cleaned_without_multiple_spaces = re.sub(r'\s+', ' ', text_cleaned)
    text_without_comments = remove_html_comments(text_cleaned_without_multiple_spaces)
    text_without_long_floats = round_floats_in_text(text_without_comments)
    return text_without_long_floats

### Find max sentence length and new unknown tokens

In [9]:
# Find max length
all_paths = os.listdir(HTML_FILES_FOLDER)

In [10]:
print(HTML_FILES_FOLDER)

data/WebSight/html/


In [11]:
len(all_paths)

1011

In [12]:
all_paths = all_paths[:10000]
extra_paths = all_paths[1000:]

In [13]:
extra_paths

['00999.html',
 '01000.html',
 '01001.html',
 '01002.html',
 '01003.html',
 '01004.html',
 '01005.html',
 '01006.html',
 '01007.html',
 '01008.html',
 '01009.html']

In [14]:
# # Find max length
# max_length = 0

# # Read text files and add new tokens to dictionary
# tokens_to_add = set()

# for html_file_path in all_paths:
#     file_path = os.path.join(HTML_FILES_FOLDER, html_file_path)
    
#     if os.path.isdir(file_path):
#         continue  # Skip directories
    
#     with open(file_path, "r") as reader:
#         splitted_text = processor.tokenizer(preprocess_html_file(reader.read())).tokens()
#         tokens_to_add = tokens_to_add.union(set(splitted_text))


    
#     # with open(HTML_FILES_FOLDER + html_file_path, "r") as reader:


#     #     splitted_text = processor.tokenizer(preprocess_html_file(reader.read())).tokens()
#     #     tokens_to_add = tokens_to_add.union(set(splitted_text))

#     # Check if the current sentence has the largest number of tokens
#     if len(splitted_text) > max_length:
#         max_length = len(splitted_text)

# print(f"Max sentence length = {max_length}")

# newly_added_num = processor.tokenizer.add_tokens(list(tokens_to_add))
# print(f"Number of new tokens = {newly_added_num}")

# # Resize the model's token embeddings if there are new tokens
# if newly_added_num > 0:
#     model.decoder.resize_token_embeddings(len(processor.tokenizer))


import chardet

# Find max length
max_length = 0
tokens_to_add = set()

for html_file_path in all_paths:
    file_path = os.path.join(HTML_FILES_FOLDER, html_file_path)
    
    if os.path.isdir(file_path):
        continue  # Skip directories
    
    # Detect encoding first
    with open(file_path, 'rb') as raw_reader:
        raw_data = raw_reader.read()
        encoding = chardet.detect(raw_data)['encoding']
    
    if encoding is None:
        print(f"⚠️ Skipping file due to unknown encoding: {file_path}")
        continue

    try:
        decoded_text = raw_data.decode(encoding)
    except Exception as e:
        print(f"⚠️ Skipping file {file_path} due to decoding error: {e}")
        continue

    try:
        splitted_text = processor.tokenizer(preprocess_html_file(decoded_text)).tokens()
        tokens_to_add.update(splitted_text)

        if len(splitted_text) > max_length:
            max_length = len(splitted_text)
    except Exception as e:
        print(f"⚠️ Tokenization error in {file_path}: {e}")
        continue

print(f"Max sentence length = {max_length}")

newly_added_num = processor.tokenizer.add_tokens(list(tokens_to_add))
print(f"Number of new tokens = {newly_added_num}")

if newly_added_num > 0:
    model.decoder.resize_token_embeddings(len(processor.tokenizer))


The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


Max sentence length = 1493
Number of new tokens = 4525


The new lm_head weights will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


### Split files into training - validation - test sets

In [15]:
# random.seed(RANDOM_SEED)

# # Use the same seed, so that parts remain the same
# random.shuffle(all_paths)

train_len = int(TRAIN_SET_PERCENTAGE * len(all_paths))
valid_len = int(VALID_SET_PERCENTAGE * len(all_paths))

train_paths = all_paths[:train_len]
valid_paths = all_paths[train_len:train_len+valid_len]
test_paths = all_paths[train_len+valid_len:] + extra_paths

print(f"TRAIN_SET size = {len(train_paths)}")
print(f"VALID_SET size = {len(valid_paths)}")
print(f"TEST_SET size = {len(test_paths)}")

TRAIN_SET size = 889
VALID_SET size = 20
TEST_SET size = 113


In [16]:
class SythBootstrapDataset(Dataset):
    def __init__(self, root_dir, transform, text_files_paths):

        self.root_dir = root_dir
        self.transform = transform
        self.text_files_paths = text_files_paths

        self.max_patches = MAX_PATCHES
        self.max_length = MAX_SENTENCE_LEN
        self.ignore_id = -100

        self.encodings = []

        for text_file in tqdm(text_files_paths):
            image_file = text_file.replace('.html', '.png')

            # Directly process the text files, and save them in the ram
            # Do the same also for images, if there is enough space in memory
            text_file_path = os.path.join(root_dir + "html/", text_file)
            image_file_path = os.path.join(root_dir + "images/", image_file)
            # print(image_file_path)
            # Load image
            image = Image.open(image_file_path).convert('RGB')

            if DEBUG:
                image.show()

            if self.transform:
                image = self.transform(image)

            encoding = processor(images=image, max_patches=self.max_patches, return_tensors="pt")
            encoding = {k:v.squeeze() for k,v in encoding.items()}

            # Load text
            with open(text_file_path, 'r') as f:
                text = f.read()
                text_cleaned = preprocess_html_file(text)

            if DEBUG:
              print("text:")
              print(text)
              print("\n\n\ntext_cleaned:")
              print(text_cleaned)

            input_ids = processor.tokenizer(
                text_cleaned,
                max_length=self.max_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            ).input_ids

            labels = input_ids.squeeze().clone()
            labels[labels == processor.tokenizer.pad_token_id] = self.ignore_id  # model doesn't need to predict pad token

            encoding["labels"] = labels.to(torch.int32)

            # For each sample save directly the encoding of both text and image
            self.encodings.append(encoding)

    def __len__(self):
        return len(self.encodings)

    def __getitem__(self, idx):
        return self.encodings[idx], self.text_files_paths[idx].replace(".html", "")

In [17]:
# Transformations for the image
transform = transforms.Compose([
    transforms.ToTensor(),  # convert PIL Image to PyTorch Tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # normalize for pretrained models
])

# Instantiate the CustomDataset
test_dataset = SythBootstrapDataset(DATASET_FOLDER, transform, test_paths)

# Use DataLoader for batching and shuffling
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

100%|██████████| 113/113 [00:13<00:00,  8.17it/s]


In [18]:
print(f"test_dataloader size = {len(test_dataloader)}")

test_dataloader size = 12


In [19]:
batch = next(iter(test_dataloader))

In [20]:
encoding, text_file_paths = batch

In [21]:
print(len(encoding))

3


In [22]:
print(text_file_paths)

('00908', '00909', '00910', '00911', '00912', '00913', '00914', '00915', '00916', '00917')


### Main Testing function

In [23]:
START_TOKEN_ID = PAD_TOKEN_ID = processor.tokenizer.pad_token_id

In [24]:
def testing_loop(testing_dataloader, model, processor, config, description):
    model.eval()
    bleu_scores = []
    ed_scores = []

    with torch.no_grad():
        test_loop = tqdm(enumerate(testing_dataloader), total=len(testing_dataloader), desc=description)
        for i, batch in test_loop:
            encoding, text_file_paths = batch
            encoding = move_to_device(encoding)
            labels, flattened_patches, attention_mask = encoding["labels"], encoding["flattened_patches"], encoding["attention_mask"]

            # Initialize total_outputs with zeros
            total_outputs = None
            context_from_last = None

            # Initialize a mask to track which sentences are finished
            finished_sentences_mask = torch.zeros(flattened_patches.size(0), dtype=torch.bool, device=flattened_patches.device)

            for iteration in range(MAX_N_CHUNKS_PER_SENTENCE):

                generate_args = {
                    "flattened_patches": flattened_patches[~finished_sentences_mask],
                    "attention_mask": attention_mask[~finished_sentences_mask],
                    "max_new_tokens": CHUNK_LENGTH - (CONTEXT_OVERLAP_LENGTH if iteration else 0),
                }

                if iteration and context_from_last is not None:
                    generate_args["decoder_input_ids"] = context_from_last[~finished_sentences_mask]

                outputs = model.generate(**generate_args)

                # Remove context overlap only from the second iteration onwards
                new_chunks = outputs if iteration == 0 else outputs[:, CONTEXT_OVERLAP_LENGTH:]

                if iteration == 0:
                    total_outputs = new_chunks
                else:
                    # Update total_outputs by concatenating new chunks
                    new_chunks_with_padding_chunks = torch.full((flattened_patches.shape[0], new_chunks.shape[1]), PAD_TOKEN_ID, dtype=new_chunks.dtype, device=new_chunks.device)
                    new_chunks_with_padding_chunks[~finished_sentences_mask] = new_chunks
                    total_outputs = torch.cat((total_outputs, new_chunks_with_padding_chunks), dim=1)

                # Update the finished_sentences_mask
                finished_sentences_mask[~finished_sentences_mask] |= (outputs == processor.tokenizer.eos_token_id).any(dim=1)

                # If all sentences are finished, exit the loop
                if finished_sentences_mask.all():
                    break

                if outputs.shape[1] < CHUNK_LENGTH:
                    print("ERROR: !! should have already exited because all sentences reached the end!!")

                # -1 because it will put in front a START_TOKEN automatically
                context_from_last = total_outputs[:, -(CONTEXT_OVERLAP_LENGTH-1):]

            predictions = processor.tokenizer.batch_decode(total_outputs, skip_special_tokens=True)

            labels[labels == -100] = 0
            answers = processor.tokenizer.batch_decode(labels, skip_special_tokens=True)

            for pred, answer, text_file_path in zip(predictions, answers, text_file_paths):
                with open(f"{OUTPUT_FOLDER}/{text_file_path}_pred.txt", "w") as f:
                    print(pred, file=f)

                with open(f"{OUTPUT_FOLDER}/{text_file_path}_answer.txt", "w") as f:
                    print(answer, file=f)
    return

In [25]:
config = {
          "verbose": VERBOSE,
}

In [26]:
def validate_config(config):
    # Check required keys
    required_keys = [
        "verbose"
    ]
    for key in required_keys:
        if key not in config:
            raise ValueError(f"Key '{key}' must be present in the configuration.")

    # Check that values are in expected ranges
    if not isinstance(config["verbose"], bool):
        raise ValueError("verbose must be a boolean value.")

In [27]:
validate_config(config)
print(config)

{'verbose': True}


### Utility functions

In [28]:
def move_to_device(data):
    if isinstance(data, (list,tuple)):
        return [move_to_device(x) for x in data]
    elif isinstance(data, dict):
        return {k: move_to_device(v) for k, v in data.items()}
    elif isinstance(data, torch.Tensor):
        return data.to(DEVICE)
    else:
        return data

## Test the model

In [31]:
def test_model(config, processor, model):
    print("Loading model from checkpoint: ", LAST_CHECKPOINT_NAME)
    
    checkpoint = torch.load(LAST_CHECKPOINT_NAME)
    model.resize_token_embeddings(50244)  ##
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(DEVICE)
    testing_loop(test_dataloader, model, processor, config, f"Test loop")

In [32]:
test_model(config, processor, model)

Loading model from checkpoint:  model/SynthBootstrap_epoch[34]_bleu[0.89].pth


Test loop: 100%|██████████| 12/12 [30:16<00:00, 151.39s/it]
